package uk.ac.sussex.gdsc.smlm.function;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Arrays;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.util.Precision;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import uk.ac.sussex.gdsc.core.data.DataException;
import uk.ac.sussex.gdsc.core.math.QuadraticUtils;
import uk.ac.sussex.gdsc.core.utils.DoubleEquality;
import uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils;
import uk.ac.sussex.gdsc.core.utils.rng.SamplerUtils;
import uk.ac.sussex.gdsc.smlm.GdscSmlmTestUtils;
import uk.ac.sussex.gdsc.smlm.function.gaussian.AstigmatismZModel;
import uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.GaussianFunctionFactory;
import uk.ac.sussex.gdsc.smlm.utils.StdMath;
import uk.ac.sussex.gdsc.test.api.Predicates;
import uk.ac.sussex.gdsc.test.api.TestAssertions;
import uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate;
import uk.ac.sussex.gdsc.test.junit5.SeededTest;
import uk.ac.sussex.gdsc.test.rng.RngFactory;
import uk.ac.sussex.gdsc.test.utils.RandomSeed;
import uk.ac.sussex.gdsc.test.utils.TestLogging;
import uk.ac.sussex.gdsc.test.utils.functions.FormatSupplier;
import uk.ac.sussex.gdsc.test.utils.functions.IntArrayFormatSupplier;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/function/ScmosLikelihoodWrapperTest.class */
/**
 * Tests for {@link ScmosLikelihoodWrapper}.
 *
 * <p>Simulated per-pixel camera data is created once per random seed: a Poisson-sampled offset,
 * an exponentially-sampled read variance and a Gaussian-sampled gain for every pixel. Analytic
 * gradients reported by the wrapper are compared against central finite differences, and the
 * per-pixel likelihood is compared against the static likelihood functions.
 */
class ScmosLikelihoodWrapperTest {
    private static Logger logger;
    /** Simulated camera data, keyed by random seed (values are {@link SCcmosLikelihoodWrapperTestData}). */
    private static ConcurrentHashMap<RandomSeed, Object> dataCache;

    /** Cumulative probability used to bound the upper end of the Poisson count range. */
    static final double P_LIMIT = 0.999999d;
    /** Parameter name for each Gaussian 2D parameter index (filled in the static initialiser). */
    static final String[] NAME = new String[8];

    static {
        for (int i = 0; i < NAME.length; i++) {
            NAME[i] = Gaussian2DFunction.getName(i);
        }
    }

    /** Width (and height) of the simulated square image. */
    private static final int MAXX = 10;
    /** Mean of the exponential distribution used for the per-pixel read variance. */
    private static final float VAR = 57.9f;
    /** Mean per-pixel gain. */
    private static final float G = 2.2f;
    /** Standard deviation of the per-pixel gain. */
    private static final float G_SD = 0.2f;
    /** Mean per-pixel offset. */
    private static final float O = 100.0f;
    /** Step used (via {@link Precision#representableDelta}) for the finite-difference gradient. */
    private static final double STEP_H = 0.01d;

    // Parameter values used by the current gradient test (a subset of the *Options arrays)
    private double[] testbackground;
    private double[] testsignal1;
    private double[] testangle1;
    private double[] testcx1;
    private double[] testcy1;
    private double[] testcz1;
    private double[][] testw1;

    private final double[] photons = {1.0d, 1.5d, 2.0d, 2.5d, 3.0d, 4.0d, 5.0d, 7.5d, 10.0d, 100.0d, 1000.0d};
    // NOTE(review): a maximum relative error of 5e12 makes the relative comparison accept almost
    // any pair of values, so checks using this tolerance are very weak. Confirm the intended value
    // (e.g. 5e-2 or 5e-12).
    DoubleEquality eqPerDatum = new DoubleEquality(5.0E12d, 0.01d);
    DoubleEquality eq = new DoubleEquality(0.005d, 0.001d);
    private final int[] testx = {4, 5, 6};
    private final int[] testy = {4, 5, 6};
    private final double[] testbackgroundOptions = {0.1d, 1.0d, 10.0d};
    private final double[] testsignal1Options = {15.0d, 55.0d, 105.0d};
    private final double[] testangle1Options = {0.6283185307179586d, 1.0471975511965976d};
    private final double[] testcx1Options = {4.9d, 5.3d};
    private final double[] testcy1Options = {4.8d, 5.2d};
    private final double[] testcz1Options = {-1.5d, 1.0d};
    private final double[][] testw1Options = {new double[]{1.1d, 1.4d}, new double[]{1.1d, 1.7d}, new double[]{1.5d, 1.2d}, new double[]{1.3d, 1.7d}};

    /**
     * Single-parameter function used by the p-value tests. Subclasses provide {@code eval(int)};
     * all other evaluation methods return zero.
     */
    private abstract static class BaseNonLinearFunction implements NonLinearFunction {
        double[] params;
        String name;

        BaseNonLinearFunction(String name) {
            this.name = name;
        }

        public void initialise(double[] a) {
            this.params = a;
        }

        public int[] gradientIndices() {
            return new int[1];
        }

        public double evalw(int x, double[] dyda, double[] weight) {
            return 0.0d;
        }

        public double evalw(int x, double[] weight) {
            return 0.0d;
        }

        public double eval(int x, double[] dyda) {
            return 0.0d;
        }

        public boolean canComputeWeights() {
            return false;
        }

        public int getNumberOfGradients() {
            return 1;
        }
    }

    /** Per-pixel camera calibration shared by all tests that use the same seed. */
    private static class SCcmosLikelihoodWrapperTestData {
        float[] var;
        float[] gain;
        float[] offset;
        float[] sd;

        private SCcmosLikelihoodWrapperTestData() {
        }
    }

    @BeforeAll
    public static void beforeAll() {
        logger = Logger.getLogger(ScmosLikelihoodWrapperTest.class.getName());
        dataCache = new ConcurrentHashMap<>();
    }

    @AfterAll
    public static void afterAll() {
        dataCache.clear();
        dataCache = null;
        logger = null;
    }

    /**
     * Simulates per-pixel camera calibration for the seed.
     *
     * @param seed the random seed
     * @return the {@link SCcmosLikelihoodWrapperTestData}
     */
    private static Object createData(RandomSeed seed) {
        final int n = MAXX * MAXX;
        final SCcmosLikelihoodWrapperTestData data = new SCcmosLikelihoodWrapperTestData();
        data.var = new float[n];
        data.gain = new float[n];
        data.offset = new float[n];
        data.sd = new float[n];
        final RestorableUniformRandomProvider rng = RngFactory.create(seed.get());
        final DiscreteSampler offsetSampler = GdscSmlmTestUtils.createPoissonSampler(rng, O);
        final SharedStateContinuousSampler gainSampler = SamplerUtils.createGaussianSampler(rng, G, G_SD);
        final SharedStateContinuousSampler varSampler = SamplerUtils.createExponentialSampler(rng, VAR);
        for (int i = 0; i < n; i++) {
            // Sampling order is fixed so the data is reproducible for a given seed
            data.offset[i] = offsetSampler.sample();
            data.var[i] = (float) varSampler.sample();
            data.sd[i] = (float) Math.sqrt(data.var[i]);
            data.gain[i] = (float) gainSampler.sample();
        }
        return data;
    }

    /** Gets (creating on first use) the cached camera data for the seed. */
    private static SCcmosLikelihoodWrapperTestData getData(RandomSeed seed) {
        return (SCcmosLikelihoodWrapperTestData) dataCache.computeIfAbsent(seed,
            ScmosLikelihoodWrapperTest::createData);
    }

    @SeededTest
    void fitFixedComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 17);
    }

    @SeededTest
    void fitCircleComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 21);
    }

    @SeededTest
    void fitFreeCircleComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 29);
    }

    @SeededTest
    void fitEllipticalComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 31);
    }

    @SeededTest
    void fitNbFixedComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 528);
    }

    @SeededTest
    void fitNbCircleComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 532);
    }

    @SeededTest
    void fitNbFreeCircleComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 540);
    }

    @SeededTest
    void fitNbEllipticalComputesGradientPerDatum(RandomSeed seed) {
        functionComputesGradientPerDatum(seed, 542);
    }

    /**
     * Checks the per-datum gradient for every parameter the function evaluates.
     *
     * @param seed the random seed
     * @param flags the {@link GaussianFunctionFactory} flags
     */
    private void functionComputesGradientPerDatum(RandomSeed seed, int flags) {
        final Gaussian2DFunction func =
            GaussianFunctionFactory.create2D(1, MAXX, MAXX, flags, (AstigmatismZModel) null);
        initialiseTestParameters(func);
        for (final int target : getTargetParameters(func)) {
            functionComputesTargetGradientPerDatum(seed, func, target);
        }
    }

    /**
     * Compares the analytic per-datum gradient of one parameter with a central finite difference
     * at a set of pixels. Data is simulated with the target parameter scaled by 1.1.
     */
    private void functionComputesTargetGradientPerDatum(RandomSeed seed, Gaussian2DFunction func,
            int targetParameter) {
        final int[] indices = func.gradientIndices();
        final int gradientIndex = findGradientIndex(func, targetParameter);
        final double[] dyda = new double[indices.length];
        final int n = MAXX * MAXX;
        int count = 0;
        int total = 0;

        final SCcmosLikelihoodWrapperTestData data = getData(seed);
        final RestorableUniformRandomProvider rng = RngFactory.create(seed.get());
        final SharedStateContinuousSampler gs = SamplerUtils.createGaussianSampler(rng, 0.0d, 1.0d);

        for (final double[] a : createParameterSets()) {
            // Simulate from a function whose target parameter is offset from the test value
            final double[] a2 = a.clone();
            a2[targetParameter] *= 1.1d;
            func.initialise(a2);
            final double[] x = simulateData(func, rng, gs, data, n);

            final ScmosLikelihoodWrapper ff =
                new ScmosLikelihoodWrapper(func, a, x, n, data.var, data.gain, data.offset);

            final double xx = a[targetParameter];
            final double h = Precision.representableDelta(xx, STEP_H);
            for (final int x0 : testx) {
                for (final int y0 : testy) {
                    final int i = y0 * MAXX + x0;
                    a[targetParameter] = xx;
                    ff.likelihood(getVariables(indices, a), dyda, i);
                    a[targetParameter] = xx + h;
                    final double upper = ff.likelihood(getVariables(indices, a), i);
                    a[targetParameter] = xx - h;
                    final double lower = ff.likelihood(getVariables(indices, a), i);
                    final double gradient = (upper - lower) / (2.0d * h);
                    checkGradientSign(targetParameter, gradient, dyda[gradientIndex]);
                    if (eqPerDatum.almostEqualRelativeOrAbsolute(gradient, dyda[gradientIndex])) {
                        count++;
                    }
                    total++;
                }
            }
        }
        final double p = (100.0d * count) / total;
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO, "Per Datum %s : %s = %d / %d (%.2f)",
            new Object[] {func.getClass().getSimpleName(), NAME[targetParameter], Integer.valueOf(count),
                Integer.valueOf(total), Double.valueOf(p)}));
        Assertions.assertTrue(p > 90.0d, () -> NAME[targetParameter] + " fraction too low per datum: " + p);
    }

    @SeededTest
    void fitFixedComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 17);
    }

    @SeededTest
    void fitCircleComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 21);
    }

    @SeededTest
    void fitFreeCircleComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 29);
    }

    @SeededTest
    void fitEllipticalComputesGradient(RandomSeed seed) {
        // Use the looser per-datum tolerance without mutating the shared field
        functionComputesGradient(seed, 31, eqPerDatum);
    }

    @SeededTest
    void fitNbFixedComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 528);
    }

    @SeededTest
    void fitNbCircleComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 532);
    }

    @SeededTest
    void fitNbFreeCircleComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 540);
    }

    @SeededTest
    void fitNbEllipticalComputesGradient(RandomSeed seed) {
        functionComputesGradient(seed, 542, eqPerDatum);
    }

    /** Checks the full-likelihood gradient using the default {@link #eq} tolerance. */
    private void functionComputesGradient(RandomSeed seed, int flags) {
        functionComputesGradient(seed, flags, eq);
    }

    /**
     * Checks the full-likelihood gradient for every parameter the function evaluates.
     *
     * @param seed the random seed
     * @param flags the {@link GaussianFunctionFactory} flags
     * @param equality the tolerance used to count matching gradients
     */
    private void functionComputesGradient(RandomSeed seed, int flags, DoubleEquality equality) {
        final Gaussian2DFunction func =
            GaussianFunctionFactory.create2D(1, MAXX, MAXX, flags, (AstigmatismZModel) null);
        initialiseTestParameters(func);
        for (final int target : getTargetParameters(func)) {
            functionComputesTargetGradient(seed, func, target, 85.0d, equality);
        }
    }

    /**
     * Compares the analytic gradient of the full likelihood for one parameter with a central
     * finite difference. Data is simulated with the target parameter scaled by 1.3.
     *
     * @param threshold minimum percentage of matching gradients
     */
    private void functionComputesTargetGradient(RandomSeed seed, Gaussian2DFunction func, int targetParameter,
            double threshold, DoubleEquality equality) {
        final int[] indices = func.gradientIndices();
        final int gradientIndex = findGradientIndex(func, targetParameter);
        final double[] dyda = new double[indices.length];
        final int n = MAXX * MAXX;
        int count = 0;
        int total = 0;

        final SCcmosLikelihoodWrapperTestData data = getData(seed);
        final RestorableUniformRandomProvider rng = RngFactory.create(seed.get());
        final SharedStateContinuousSampler gs = SamplerUtils.createGaussianSampler(rng, 0.0d, 1.0d);

        for (final double[] a : createParameterSets()) {
            final double[] a2 = a.clone();
            a2[targetParameter] *= 1.3d;
            func.initialise(a2);
            final double[] x = simulateData(func, rng, gs, data, n);

            final ScmosLikelihoodWrapper ff =
                new ScmosLikelihoodWrapper(func, a, x, n, data.var, data.gain, data.offset);

            final double xx = a[targetParameter];
            final double h = Precision.representableDelta(xx, STEP_H);
            ff.likelihood(getVariables(indices, a), dyda);
            a[targetParameter] = xx + h;
            final double upper = ff.likelihood(getVariables(indices, a));
            a[targetParameter] = xx - h;
            final double lower = ff.likelihood(getVariables(indices, a));
            final double gradient = (upper - lower) / (2.0d * h);
            checkGradientSign(targetParameter, gradient, dyda[gradientIndex]);
            if (equality.almostEqualRelativeOrAbsolute(gradient, dyda[gradientIndex])) {
                count++;
            }
            total++;
        }
        final double p = (100.0d * count) / total;
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO, "%s : %s = %d / %d (%.2f)",
            new Object[] {func.getClass().getSimpleName(), NAME[targetParameter], Integer.valueOf(count),
                Integer.valueOf(total), Double.valueOf(p)}));
        Assertions.assertTrue(p > threshold,
            FormatSupplier.getSupplier("%s fraction too low: %s", new Object[] {NAME[targetParameter], Double.valueOf(p)}));
    }

    /**
     * Fails when the finite-difference and analytic gradients disagree in sign and also differ by
     * at least 0.1.
     */
    private static void checkGradientSign(int targetParameter, double gradient, double analytic) {
        if (!(Math.signum(gradient) == Math.signum(analytic) || Math.abs(gradient - analytic) < 0.1d)) {
            Assertions.fail(NAME[targetParameter] + ": " + gradient + " != " + analytic);
        }
    }

    /**
     * Selects the test parameter values for the function. Values of parameters that the function
     * does not evaluate are reduced to a single value. New arrays are created so the option arrays
     * are never modified.
     */
    private void initialiseTestParameters(Gaussian2DFunction func) {
        testbackground = func.evaluatesBackground() ? testbackgroundOptions
            : new double[] {testbackgroundOptions[0]};
        testsignal1 = func.evaluatesSignal() ? testsignal1Options : new double[] {testsignal1Options[0]};
        testcx1 = testcx1Options;
        testcy1 = testcy1Options;
        testcz1 = func.evaluatesZ() ? testcz1Options : new double[] {0.0d};
        if (!func.evaluatesSD0()) {
            testw1 = symmetricWidths(1);
        } else if (!func.evaluatesSD1()) {
            testw1 = symmetricWidths(2);
        } else {
            testw1 = testw1Options;
        }
        testangle1 = func.evaluatesAngle() ? testangle1Options : new double[] {0.0d};
    }

    /** Creates new {@code {w, w}} width pairs from the first width of the first {@code count} options. */
    private double[][] symmetricWidths(int count) {
        final double[][] widths = new double[count][];
        for (int i = 0; i < count; i++) {
            final double w = testw1Options[i][0];
            widths[i] = new double[] {w, w};
        }
        return widths;
    }

    /**
     * Gets the parameter indices to test, in order: background (0) and signal (1) if evaluated,
     * X (2) and Y (3) always, then Z (4), SD0 (5), SD1 (6) and angle (7) if evaluated.
     */
    private static int[] getTargetParameters(Gaussian2DFunction func) {
        final int[] targets = new int[NAME.length];
        int count = 0;
        if (func.evaluatesBackground()) {
            targets[count++] = 0;
        }
        if (func.evaluatesSignal()) {
            targets[count++] = 1;
        }
        targets[count++] = 2;
        targets[count++] = 3;
        if (func.evaluatesZ()) {
            targets[count++] = 4;
        }
        if (func.evaluatesSD0()) {
            targets[count++] = 5;
        }
        if (func.evaluatesSD1()) {
            targets[count++] = 6;
        }
        if (func.evaluatesAngle()) {
            targets[count++] = 7;
        }
        return Arrays.copyOf(targets, count);
    }

    /**
     * Creates every combination of the current test parameter values. Background varies slowest
     * and angle fastest. Each combination is a new array.
     */
    private double[][] createParameterSets() {
        final int size = testbackground.length * testsignal1.length * testcx1.length * testcy1.length
            * testcz1.length * testw1.length * testangle1.length;
        final double[][] sets = new double[size][];
        int index = 0;
        for (final double background : testbackground) {
            for (final double signal : testsignal1) {
                for (final double cx : testcx1) {
                    for (final double cy : testcy1) {
                        for (final double cz : testcz1) {
                            for (final double[] w : testw1) {
                                for (final double angle : testangle1) {
                                    sets[index++] = createParameters(background, signal, cx, cy, cz, w[0], w[1], angle);
                                }
                            }
                        }
                    }
                }
            }
        }
        return sets;
    }

    /**
     * Simulates sCMOS counts for each pixel. Each value is a Poisson sample of the function value,
     * scaled by the gain, plus the offset and Gaussian read noise.
     */
    private static double[] simulateData(Gaussian2DFunction func, RestorableUniformRandomProvider rng,
            SharedStateContinuousSampler gs, SCcmosLikelihoodWrapperTestData data, int n) {
        final double[] x = new double[n];
        for (int i = 0; i < n; i++) {
            // Left-to-right evaluation: Poisson sample first, then the read-noise sample
            x[i] = GdscSmlmTestUtils.createPoissonSampler(rng, func.eval(i)).sample() * data.gain[i]
                + data.offset[i] + gs.sample() * data.sd[i];
        }
        return x;
    }

    /** Extracts the parameters at the given indices. */
    private static double[] getVariables(int[] indices, double[] a) {
        final double[] variables = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            variables[i] = a[indices[i]];
        }
        return variables;
    }

    /** Finds the gradient index of the parameter, failing the test if it is not evaluated. */
    private static int findGradientIndex(Gaussian2DFunction func, int parameterIndex) {
        final int index = func.findGradientIndex(parameterIndex);
        Assertions.assertTrue(index >= 0, "Cannot find gradient index");
        return index;
    }

    double[] createParameters(double... args) {
        return args;
    }

    @Test
    void cumulativeProbabilityIsOneWithRealDataForCountAbove8() {
        for (final double mu : photons) {
            // Integrate across the ADU range that covers the bulk of the Poisson counts
            final int max = new PoissonDistribution(mu).inverseCumulativeProbability(P_LIMIT);
            final int min = (int) Math.max(0.0d, mu - 4.0d * Math.sqrt(mu));
            cumulativeProbabilityIsOneWithRealData(mu, min * G + O, max * G + O, mu > 8.0d);
        }
    }

    /**
     * Integrates the likelihood over [min, max] with Simpson's rule. If {@code test} is true,
     * asserts the result is within 0.02 of {@link #P_LIMIT}.
     */
    private static void cumulativeProbabilityIsOneWithRealData(double mu, double min, double max, boolean test) {
        final double p = new SimpsonIntegrator().integrate(20000,
            x -> ScmosLikelihoodWrapper.likelihood(mu, VAR, G, O, x), min, max);
        if (test) {
            Assertions.assertEquals(P_LIMIT, p, 0.02d, () -> "mu=" + mu);
        }
    }

    @Test
    void instanceLikelihoodMatches() {
        for (final double mu : photons) {
            instanceLikelihoodMatches(mu, mu > 8.0d);
        }
    }

    /**
     * Checks the wrapper's per-pixel and total likelihoods against the static negative
     * log-likelihood. The wrapper uses a constant-mean function over a fine grid of ADU values.
     */
    private static void instanceLikelihoodMatches(final double mu, boolean test) {
        final double step = 0.1d;
        final int maxCount = new PoissonDistribution(mu).inverseCumulativeProbability(P_LIMIT);
        final int n = (int) Math.ceil((maxCount * G) / step);
        final double[] k = SimpleArrayUtils.newArray(n, O, step);
        final double[] a = new double[0];
        final double[] gradient = new double[0];
        final float[] var = newArray(n, VAR);
        final float[] g = newArray(n, G);
        final float[] o = newArray(n, O);
        final NonLinearFunction func = new NonLinearFunction() {
            public void initialise(double[] a) {
            }

            public int[] gradientIndices() {
                return new int[0];
            }

            public double evalw(int x, double[] dyda, double[] weight) {
                return 0.0d;
            }

            public double evalw(int x, double[] weight) {
                return 0.0d;
            }

            public double eval(int x) {
                return mu;
            }

            public double eval(int x, double[] dyda) {
                return mu;
            }

            public boolean canComputeWeights() {
                return false;
            }

            public int getNumberOfGradients() {
                return 0;
            }
        };
        final ScmosLikelihoodWrapper f = new ScmosLikelihoodWrapper(func, a, k, n, var, g, o);
        final IntArrayFormatSupplier msg1 = new IntArrayFormatSupplier("computeLikelihood @ %d", 1);
        final IntArrayFormatSupplier msg2 = new IntArrayFormatSupplier("computeLikelihood+gradient @ %d", 1);
        double total = 0.0d;
        double sumP = 0.0d;
        double maxP = 0.0d;
        int maxIndex = 0;
        final DoubleDoubleBiPredicate predicate = Predicates.doublesAreClose(1.0E-10d, 0.0d);
        for (int i = 0; i < n; i++) {
            final double nll = f.computeLikelihood(i);
            final double nll2 = f.computeLikelihood(gradient, i);
            final double expected = ScmosLikelihoodWrapper.negativeLogLikelihood(mu, var[i], g[i], o[i], k[i]);
            total += nll;
            TestAssertions.assertTest(expected, nll, predicate, msg1.set(0, i));
            TestAssertions.assertTest(expected, nll2, predicate, msg2.set(0, i));
            final double pp = StdMath.exp(-nll);
            if (maxP < pp) {
                maxP = pp;
                maxIndex = i;
            }
            sumP += pp * step;
        }

        // The ADU with maximum likelihood should correspond to the Poisson mode
        TestAssertions.assertTest(((Math.floor(mu) + (Math.ceil(mu) - 1.0d)) * 0.5d * G) + O, k[maxIndex],
            Predicates.doublesAreClose(0.001d, 0.0d), "k-max");

        if (test) {
            Assertions.assertEquals(P_LIMIT, sumP, 0.02d, () -> "mu=" + mu);
        }

        final double ll = f.computeLikelihood();
        final double ll2 = f.computeLikelihood(gradient);
        TestAssertions.assertTest(total, ll, predicate, "computeLikelihood");
        TestAssertions.assertTest(total, ll2, predicate, "computeLikelihood with gradient");

        final ScmosLikelihoodWrapper f2 = f.build(func, a);
        final double ll3 = f2.computeLikelihood();
        final double ll4 = f2.computeLikelihood(gradient);
        TestAssertions.assertTest(total, ll3, predicate, "computeLikelihood");
        TestAssertions.assertTest(total, ll4, predicate, "computeLikelihood with gradient");
    }

    /** Creates a float array of the given length filled with the value. */
    private static float[] newArray(int length, float value) {
        final float[] array = new float[length];
        Arrays.fill(array, value);
        return array;
    }

    @SeededTest
    void canComputePValue(RandomSeed seed) {
        final double centre = MAXX * MAXX * 0.5d;
        canComputePValue(seed, new BaseNonLinearFunction("Linear") {
            public double eval(int i) {
                return params[0] * (i - centre);
            }
        });
        canComputePValue(seed, new BaseNonLinearFunction("Quadratic") {
            public double eval(int i) {
                return params[0] * (i - centre) * (i - centre);
            }
        });
        canComputePValue(seed, new BaseNonLinearFunction("Linear+C") {
            public double eval(int i) {
                return (10.0d * params[0]) + (i - centre);
            }
        });
        canComputePValue(seed, new BaseNonLinearFunction("Gaussian") {
            public double eval(int i) {
                return 100.0d * StdMath.exp(((-0.5d) * Math.pow(i - centre, 2.0d)) / (params[0] * params[0]));
            }
        });
    }

    /**
     * Simulates data using parameter a=1 and then checks the wrapper:
     * <ul>
     * <li>the observed likelihood matches a direct computation;</li>
     * <li>the log-likelihood ratio matches an explicit product for a in [0.5, 1.5];</li>
     * <li>a quadratic fit through the likelihoods near the minimum locates a close to 1.</li>
     * </ul>
     */
    private static void canComputePValue(RandomSeed seed, BaseNonLinearFunction func) {
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO, func.name));

        final int n = MAXX * MAXX;
        final double[] a = {1.0d};
        func.initialise(a);

        final SCcmosLikelihoodWrapperTestData data = getData(seed);
        final float[] var = data.var;
        final float[] g = data.gain;
        final float[] o = data.offset;
        final float[] sd = data.sd;
        final RestorableUniformRandomProvider rng = RngFactory.create(seed.get());
        final SharedStateContinuousSampler gs = SamplerUtils.createGaussianSampler(rng, 0.0d, 1.0d);

        final double[] k = new double[n];
        for (int i = 0; i < n; i++) {
            double count = func.eval(i);
            // Only sample a Poisson count for a positive mean
            if (count > 0.0d) {
                count = GdscSmlmTestUtils.createPoissonSampler(rng, count).sample();
            }
            k[i] = (count * g[i]) + o[i] + (gs.sample() * sd[i]);
        }

        final ScmosLikelihoodWrapper f = new ScmosLikelihoodWrapper(func, a, k, n, var, g, o);

        // Observed likelihood: each datum evaluated at its own (offset/gain corrected) value
        final double oll = f.computeObservedLikelihood();
        double oll2 = 0.0d;
        final double[] op = new double[n];
        for (int j = 0; j < n; j++) {
            op[j] = ScmosLikelihoodWrapper.likelihood((k[j] - o[j]) / g[j], var[j], g[j], o[j], k[j]);
            oll2 -= Math.log(op[j]);
        }
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO, "oll=%f, oll2=%f",
            new Object[] {Double.valueOf(oll), Double.valueOf(oll2)}));
        final DoubleDoubleBiPredicate predicate = Predicates.doublesAreClose(1.0E-10d, 0.0d);
        TestAssertions.assertTest(oll2, oll, predicate, "Observed Log-likelihood");

        final DoubleArrayList likelihoods = new DoubleArrayList();
        for (int step = 5; step <= 15; step++) {
            a[0] = step / 10.0d;
            final double ll = f.likelihood(a);
            likelihoods.add(ll);
            final double llr = f.computeLogLikelihoodRatio(ll);
            // Exact product of the per-datum likelihood ratios
            BigDecimal product = BigDecimal.ONE;
            double ll2 = 0.0d;
            for (int j = 0; j < n; j++) {
                final double v = ScmosLikelihoodWrapper.likelihood(func.eval(j), var[j], g[j], o[j], k[j]);
                ll2 -= Math.log(v);
                product = product.multiply(new BigDecimal(v / op[j]));
            }
            final double llr2 = (-2.0d) * Math.log(product.doubleValue());
            final double q = f.computeQValue(ll);
            logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO,
                "a=%f, ll=%f, ll2=%f, llr=%f, llr2=%f, product=%s, p=%f",
                new Object[] {Double.valueOf(a[0]), Double.valueOf(ll), Double.valueOf(ll2), Double.valueOf(llr),
                    Double.valueOf(llr2), product.round(new MathContext(4)).toString(), Double.valueOf(q)}));
            // The product can underflow to zero; the ratio is then not comparable
            if (product.doubleValue() > 0.0d) {
                TestAssertions.assertTest(llr, llr2, predicate, "Log-likelihood");
            }
        }

        final double[] values = likelihoods.toDoubleArray();
        int pos = SimpleArrayUtils.findMinIndex(values);
        final double minA = (5 + pos) / 10.0d;
        double fitA = minA;
        // Fit a quadratic through three consecutive points around the minimum. If the minimum is
        // at either end of the scanned range, move the centre inwards.
        pos = Math.max(1, Math.min(pos, values.length - 2));
        try {
            fitA = QuadraticUtils.findMinMax((4 + pos) / 10.0d, values[pos - 1], (5 + pos) / 10.0d, values[pos],
                (6 + pos) / 10.0d, values[pos + 1]);
        } catch (final DataException ignored) {
            // NOTE(review): presumably raised for a degenerate fit; keep the discrete minimum
        }
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO, "min fit = %g => %g",
            new Object[] {Double.valueOf(minA), Double.valueOf(fitA)}));
        Assertions.assertEquals(1.0d, fitA, 0.199d, "min");
    }
}
