package uk.ac.sussex.gdsc.smlm.function;

import java.util.function.Supplier;
import java.util.logging.Logger;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import uk.ac.sussex.gdsc.core.utils.MathUtils;
import uk.ac.sussex.gdsc.test.api.Predicates;
import uk.ac.sussex.gdsc.test.api.TestAssertions;
import uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate;
import uk.ac.sussex.gdsc.test.junit5.SpeedTag;
import uk.ac.sussex.gdsc.test.utils.TestComplexity;
import uk.ac.sussex.gdsc.test.utils.TestLogging;
import uk.ac.sussex.gdsc.test.utils.TestSettings;

/**
 * Tests for {@link PoissonGaussianFunction}.
 *
 * <p>Most tests sweep every combination of {@link #gain}, {@link #photons} and {@link #noise} and
 * run with {@code setUsePicardApproximation(true)} (Picard) and {@code false} (Pade).
 */
class PoissonGaussianFunctionTest {
    private static Logger logger;

    /** Gain values to test. */
    static double[] gain = {0.25d, 0.5d, 1.0d, 2.0d, 4.0d};
    /** Mean photon values to test; includes a negative and a zero mean as edge cases. */
    static double[] photons = {-1.0d, 0.0d, 0.1d, 0.25d, 0.5d, 1.0d, 2.0d, 4.0d, 10.0d, 100.0d, 1000.0d};
    /** Noise standard deviations to test (squared where a variance is required). */
    static double[] noise = {1.0d, 2.0d, 4.0d, 8.0d};

    /** Receives benchmark results so the JIT cannot discard the timed calls as dead code. */
    private static volatile double sink;

    @BeforeAll
    public static void beforeAll() {
        logger = Logger.getLogger(PoissonGaussianFunctionTest.class.getName());
    }

    @AfterAll
    public static void afterAll() {
        logger = null;
    }

    @Test
    void cumulativeProbabilityIsOneWithPicard() {
        for (final double g : gain) {
            for (final double mu : photons) {
                for (final double sd : noise) {
                    cumulativeProbabilityIsOne(g, mu, sd, true);
                }
            }
        }
    }

    @Test
    void cumulativeProbabilityIsOneWithPade() {
        for (final double g : gain) {
            for (final double mu : photons) {
                for (final double sd : noise) {
                    cumulativeProbabilityIsOne(g, mu, sd, false);
                }
            }
        }
    }

    @Test
    void cumulativeProbabilityIsNotOneWhenMeanIsLowAndNoiseIsLow() {
        // Low mean (0.25) with low noise: the integral deviates from 1 by more than 2%
        // for some noise levels.
        Assertions.assertTrue(1.02d < cumulativeProbability(1.7d, 0.25d, 0.01d, true));
        Assertions.assertTrue(1.02d < cumulativeProbability(1.7d, 0.25d, 0.1d, true));
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 0.25d, 0.3d, true), 0.02d);
        Assertions.assertTrue(0.98d > cumulativeProbability(1.7d, 0.25d, 0.5d, true));
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 0.25d, 0.75d, true), 0.02d);
        // A higher mean (10) integrates to 1 within 2% at all of these noise levels.
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 10.0d, 0.01d, true), 0.02d);
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 10.0d, 0.1d, true), 0.02d);
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 10.0d, 0.3d, true), 0.02d);
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 10.0d, 0.5d, true), 0.02d);
        Assertions.assertEquals(1.0d, cumulativeProbability(1.7d, 10.0d, 0.75d, true), 0.02d);
    }

    /**
     * Asserts that the integral computed by
     * {@link #cumulativeProbability(double, double, double, boolean)} is 1 within 0.02.
     */
    private static void cumulativeProbabilityIsOne(double g, double mu, double sd, boolean usePicard) {
        Assertions.assertEquals(1.0d, cumulativeProbability(g, mu, sd, usePicard), 0.02d,
                () -> String.format("g=%f, mu=%f, s=%f", g, mu, sd));
    }

    /**
     * Integrates the probability of a {@link PoissonGaussianFunction} over a range covering its
     * non-negligible mass.
     *
     * <p>When {@code mu > 0}, the search starts from {@link #getRange(double, double, double)};
     * otherwise it starts at 0 (downwards) and 1 (upwards). Each bound is moved one integer at a
     * time until the probability there is zero or is below 1e-6 relative to the running discrete
     * sum. The density is then integrated over the final bounds with a {@link SimpsonIntegrator}.
     * Results outside [0.98, 1.02] are logged.
     *
     * @param g the gain (the function is created with {@code 1 / g} and {@code sd * g})
     * @param mu the mean
     * @param sd the noise standard deviation
     * @param usePicard passed to {@code setUsePicardApproximation}
     * @return the integral of the probability between the computed bounds
     */
    private static double cumulativeProbability(double g, double mu, double sd, boolean usePicard) {
        final PoissonGaussianFunction function =
                PoissonGaussianFunction.createWithStandardDeviation(1.0d / g, mu, sd * g);
        function.setUsePicardApproximation(usePicard);

        // Running sum of the discrete probabilities: used for the relative stopping criterion
        // of the tail search and for logging.
        double sum = 0.0d;
        int min = 1;
        int max = 0;
        if (mu > 0.0d) {
            final int[] range = getRange(g, mu, sd);
            min = range[0];
            max = range[1];
            for (int x = min; x <= max; x++) {
                sum += function.probability(x);
            }
        }

        // Extend the lower bound until the tail contribution is negligible.
        int lower = min - 1;
        while (true) {
            final double p = function.probability(lower);
            sum += p;
            if (p == 0.0d || p / sum < 1.0E-6d) {
                break;
            }
            lower--;
        }

        // Extend the upper bound until the tail contribution is negligible.
        int upper = max + 1;
        while (true) {
            final double p = function.probability(upper);
            sum += p;
            if (p == 0.0d || p / sum < 1.0E-6d) {
                break;
            }
            upper++;
        }

        final UnivariateFunction density = x -> function.probability(x);
        // Arguments: relative accuracy, absolute accuracy, min and max iteration counts.
        final double integral = new SimpsonIntegrator(1.0E-6d, 1.0E-6d, 4, 64)
                .integrate(Integer.MAX_VALUE, density, lower, upper);
        if (integral < 0.98d || integral > 1.02d) {
            logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO,
                    "g=%f, mu=%f, s=%f p=%f  %f", new Object[] {g, mu, sd, sum, integral}));
        }
        return integral;
    }

    /**
     * Gets an integer range {@code {min, max}} expected to contain most of the probability mass.
     *
     * <p>The bounds are {@code mu -/+ 3 * max(sd, sqrt(mu))}, scaled by the gain, then floored and
     * ceiled respectively.
     *
     * <p>NOTE(review): for {@code mu < 0}, {@code Math.sqrt(mu)} is NaN, so both bounds cast to 0
     * and the range is the single value 0.
     *
     * @param g the gain
     * @param mu the mean
     * @param sd the noise standard deviation
     * @return {@code {min, max}}
     */
    public static int[] getRange(double g, double mu, double sd) {
        final double width = 3.0d * Math.max(sd, Math.sqrt(mu));
        final int min = (int) Math.floor(g * (mu - width));
        final int max = (int) Math.ceil(g * (mu + width));
        return new int[] {min, max};
    }

    @Test
    void probabilityMatchesLogProbability() {
        for (final double g : gain) {
            for (final double mu : photons) {
                for (final double sd : noise) {
                    probabilityMatchesLogProbability(g, mu, sd, true);
                    probabilityMatchesLogProbability(g, mu, sd, false);
                }
            }
        }
    }

    /**
     * Asserts that {@code log(probability(x))} is close to {@code logProbability(x)} (relative
     * error 1e-3) for each integer {@code x} in {@link #getRange(double, double, double)} where
     * the probability is non-zero.
     */
    private static void probabilityMatchesLogProbability(double g, double mu, double sd,
            boolean usePicard) {
        final PoissonGaussianFunction function =
                PoissonGaussianFunction.createWithStandardDeviation(1.0d / g, mu, sd * g);
        function.setUsePicardApproximation(usePicard);
        final int[] range = getRange(g, mu, sd);
        final Supplier<String> message = () -> String.format("g=%f, mu=%f, s=%f", g, mu, sd);
        final DoubleDoubleBiPredicate predicate = Predicates.doublesAreClose(0.001d, 0.0d);
        for (int x = range[0]; x <= range[1]; x++) {
            final double p = function.probability(x);
            // log(0) is -Infinity; skip points with no probability.
            if (p != 0.0d) {
                TestAssertions.assertTest(Math.log(p), function.logProbability(x), predicate, message);
            }
        }
    }

    @SpeedTag
    @Test
    void padeIsFaster() {
        Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));

        final double[] noiseVariance = new double[noise.length];
        for (int i = 0; i < noise.length; i++) {
            noiseVariance[i] = noise[i] * noise[i];
        }

        // For each mean, 100 sample points 0, step, ..., 99 * step with step = 2 * mu / 100.
        final int samples = 100;
        final double[][] values = new double[photons.length][samples];
        for (int i = 0; i < photons.length; i++) {
            final double step = (photons[i] * 2.0d) / samples;
            for (int j = 0; j < samples; j++) {
                values[i][j] = step * j;
            }
        }

        final long picardTime = getTime(noiseVariance, values, true);
        final long padeTime = getTime(noiseVariance, values, false);
        // Floating-point division: long division would truncate the ratio and would throw
        // ArithmeticException if padeTime were 0.
        logger.log(TestLogging.getRecord(TestLogging.TestLevel.TEST_INFO,
                "Picard %d : Pade %d (%fx)",
                new Object[] {picardTime, padeTime, (double) picardTime / padeTime}));
        Assertions.assertTrue(padeTime < picardTime,
                () -> String.format("Pade %d < Picard %d", padeTime, picardTime));
    }

    /**
     * Times evaluation of the static {@code PoissonGaussianFunction.probability} over all inputs.
     *
     * <p>An untimed warm-up pass runs first to give the JIT a chance to compile the code. Results
     * are written to {@link #sink} outside the timed window so the calls are not dead code.
     *
     * @param noiseVariance the noise variances to evaluate
     * @param values the sample points, one row per entry of {@link #photons}
     * @param usePicard the approximation flag passed to the function
     * @return the elapsed time of the timed pass in nanoseconds
     */
    private static long getTime(double[] noiseVariance, double[][] values, boolean usePicard) {
        sink = evaluateAll(noiseVariance, values, usePicard);
        final long start = System.nanoTime();
        final double total = evaluateAll(noiseVariance, values, usePicard);
        final long elapsed = System.nanoTime() - start;
        sink = total;
        return elapsed;
    }

    /** Evaluates the probability for every variance, mean and sample point; returns the sum. */
    private static double evaluateAll(double[] noiseVariance, double[][] values, boolean usePicard) {
        double total = 0.0d;
        for (final double variance : noiseVariance) {
            for (int i = 0; i < photons.length; i++) {
                final double mu = photons[i];
                for (final double x : values[i]) {
                    total += PoissonGaussianFunction.probability(x, mu, variance, usePicard);
                }
            }
        }
        return total;
    }

    @Test
    void staticMethodsMatchInstanceMethods() {
        for (final double g : gain) {
            for (final double mu : photons) {
                for (final double sd : noise) {
                    staticMethodsMatchInstanceMethods(g, mu, sd, true);
                    staticMethodsMatchInstanceMethods(g, mu, sd, false);
                }
            }
        }
    }

    /**
     * Asserts, within 1e-10, that the instance methods agree with the static ones. The instance
     * probability at {@code x} must equal the static probability at {@code x / g} (with variance
     * {@code sd^2}) divided by {@code g}. The instance log probability must equal the static log
     * probability minus {@code log(g)}.
     */
    private static void staticMethodsMatchInstanceMethods(double g, double mu, double sd,
            boolean usePicard) {
        final PoissonGaussianFunction function =
                PoissonGaussianFunction.createWithStandardDeviation(1.0d / g, mu, sd * g);
        function.setUsePicardApproximation(usePicard);
        final int[] range = getRange(g, mu, sd);
        final double logGain = Math.log(g);
        final double variance = MathUtils.pow2(sd);
        final Supplier<String> probabilityMessage =
                () -> String.format("probability g=%f, mu=%f, s=%f", g, mu, sd);
        final Supplier<String> logProbabilityMessage =
                () -> String.format("logProbability g=%f, mu=%f, s=%f", g, mu, sd);
        for (int x = range[0]; x <= range[1]; x++) {
            final double scaled = x / g;
            Assertions.assertEquals(function.probability(x),
                    PoissonGaussianFunction.probability(scaled, mu, variance, usePicard) / g,
                    1.0E-10d, probabilityMessage);
            Assertions.assertEquals(function.logProbability(x),
                    PoissonGaussianFunction.logProbability(scaled, mu, variance, usePicard) - logGain,
                    1.0E-10d, logProbabilityMessage);
        }
    }
}
