package uk.ac.sussex.gdsc.smlm.fitting;

import java.util.logging.Logger;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.InverseTransformDiscreteSampler;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeAll;
import uk.ac.sussex.gdsc.test.junit5.SeededTest;
import uk.ac.sussex.gdsc.test.rng.RngFactory;
import uk.ac.sussex.gdsc.test.utils.RandomSeed;
import uk.ac.sussex.gdsc.test.utils.TestComplexity;
import uk.ac.sussex.gdsc.test.utils.TestLogging;
import uk.ac.sussex.gdsc.test.utils.TestSettings;
import uk.ac.sussex.gdsc.test.utils.functions.FormatSupplier;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/fitting/BinomialFitterTest.class */
/**
 * Tests for {@link BinomialFitter}.
 *
 * <p>Each test samples data from a binomial distribution with known parameters (n, p) and checks
 * the fitted parameters. Every (n, p) combination is fitted {@link #TRIALS} times and at most
 * {@link #FAILURES} failed trials are tolerated before the test fails.
 */
class BinomialFitterTest {
    private static Logger logger;

    /** The binomial number of trials (n) to test. */
    static final int[] N = {2, 3, 4, 6};
    /** The binomial probability of success (p) to test. */
    static final double[] P = {0.3d, 0.5d, 0.7d};
    /** The number of repeat fits for each (n, p) combination. */
    static final int TRIALS = 10;
    /** The maximum number of failed trials allowed for each (n, p) combination. */
    static final int FAILURES = 3;
    /** The number of samples generated for each fit. */
    private static final int SAMPLE_SIZE = 2000;
    /** The absolute tolerance used when comparing values of p. */
    private static final double P_DELTA = 0.05d;

    TestComplexity optionalTestComplexity = TestComplexity.HIGH;
    TestComplexity nonEssentialTestComplexity = TestComplexity.VERY_HIGH;

    @BeforeAll
    public static void beforeAll() {
        logger = Logger.getLogger(BinomialFitterTest.class.getName());
    }

    @AfterAll
    public static void afterAll() {
        logger = null;
    }

    @SeededTest
    void canFitBinomialWithKnownNUsingLeastSquaresEstimator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), false, false, true);
    }

    @SeededTest
    void canFitBinomialWithKnownNUsingMaximumLikelihood(RandomSeed randomSeed) {
        final RestorableUniformRandomProvider rng = RngFactory.create(randomSeed.get());
        if (!TestSettings.allow(optionalTestComplexity)) {
            // Lower complexity: run a single representative case so the MLE path is always tested.
            fitBinomial(rng, 2, 0.5d, false, true, 2, 2);
            return;
        }
        fitAll(rng, false, true, true);
    }

    @SeededTest
    void canFitBinomialWithUnknownNUsingLeastSquaresEstimator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), false, false, false);
    }

    @SeededTest
    void canFitBinomialWithUnknownNUsingMaximumLikelihood(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(optionalTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), false, true, false);
    }

    @SeededTest
    void canFitZeroTruncatedBinomialWithKnownNUsingLeastSquaresEstimator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), true, false, true);
    }

    @SeededTest
    void canFitZeroTruncatedBinomialWithKnownNUsingMaximumLikelihood(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), true, true, true);
    }

    @SeededTest
    void canFitZeroTruncatedBinomialWithUnknownNUsingLeastSquaresEstimator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAll(RngFactory.create(randomSeed.get()), true, false, false);
    }

    @SeededTest
    void canFitZeroTruncatedBinomialWithUnknownNUsingMaximumLikelihood(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        // Previously passed maximumLikelihood=false, duplicating the least-squares test above.
        fitAll(RngFactory.create(randomSeed.get()), true, true, false);
    }

    @SeededTest
    void sameFitBinomialWithKnownNUsingLseOrMle(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAllUsingLseOrMle(RngFactory.create(randomSeed.get()), false);
    }

    @SeededTest
    void sameFitZeroTruncatedBinomialWithKnownNUsingLseOrMle(RandomSeed randomSeed) {
        Assumptions.assumeTrue(TestSettings.allow(nonEssentialTestComplexity));
        fitAllUsingLseOrMle(RngFactory.create(randomSeed.get()), true);
    }

    /**
     * Runs {@link #fitBinomial} for every (n, p) combination in {@link #N} x {@link #P}.
     *
     * @param rng the source of randomness
     * @param zeroTruncated true to fit the zero-truncated binomial model
     * @param maximumLikelihood true to fit using maximum likelihood; false for least squares
     * @param knownN true to pass the true n as both the min and max n; false to pass a range of
     *        [1, true n]
     */
    private static void fitAll(UniformRandomProvider rng, boolean zeroTruncated,
            boolean maximumLikelihood, boolean knownN) {
        for (final int n : N) {
            final int minN = knownN ? n : 1;
            for (final double p : P) {
                fitBinomial(rng, n, p, zeroTruncated, maximumLikelihood, minN, n);
            }
        }
    }

    /**
     * Runs {@link #fitBinomialUsingLseOrMle} with a known n for every (n, p) combination in
     * {@link #N} x {@link #P}.
     *
     * @param rng the source of randomness
     * @param zeroTruncated true to fit the zero-truncated binomial model
     */
    private static void fitAllUsingLseOrMle(UniformRandomProvider rng, boolean zeroTruncated) {
        for (final int n : N) {
            for (final double p : P) {
                fitBinomialUsingLseOrMle(rng, n, p, zeroTruncated, n, n);
            }
        }
    }

    /**
     * Repeatedly samples binomial data and checks the fitter recovers n exactly and p within
     * {@link #P_DELTA}. Fails if more than {@link #FAILURES} of {@link #TRIALS} trials fail.
     *
     * @param rng the source of randomness
     * @param n the true binomial n
     * @param p the true binomial p
     * @param zeroTruncated passed to the fitter as the zero-truncation flag
     * @param maximumLikelihood true to fit using maximum likelihood; false for least squares
     * @param minN the minimum n passed to the fitter
     * @param maxN the maximum n passed to the fitter
     */
    private static void fitBinomial(UniformRandomProvider rng, int n, double p,
            boolean zeroTruncated, boolean maximumLikelihood, int minN, int maxN) {
        final BinomialFitter fitter = new BinomialFitter((Logger) null);
        fitter.setMaximumLikelihood(maximumLikelihood);
        logger.log(TestLogging.TestLevel.TEST_INFO,
                FormatSupplier.getSupplier("Fitting (n=%d, p=%f)", new Object[] {n, p}));
        int failures = 0;
        for (int trial = 0; trial < TRIALS; trial++) {
            // NOTE(review): data are always sampled from the full (non-truncated) binomial, even
            // when fitting the zero-truncated model; confirm the fitter discards the zero bin.
            final int[] data = createData(rng, n, p, false);
            final double[] fit = fitter.fitBinomial(data, minN, maxN, zeroTruncated);
            final int fittedN = (int) fit[0];
            final double fittedP = fit[1];
            logger.log(TestLogging.TestLevel.TEST_INFO, FormatSupplier
                    .getSupplier("  Fitted (n=%d, p=%f)", new Object[] {fittedN, fittedP}));
            try {
                Assertions.assertEquals(n, fittedN, "Failed to fit n");
                Assertions.assertEquals(p, fittedP, P_DELTA, "Failed to fit p");
            } catch (final AssertionError ex) {
                // Random data occasionally produce a poor fit: count it rather than fail at once.
                failures++;
                logger.log(TestLogging.TestLevel.TEST_INFO, "    " + ex.getMessage());
            }
        }
        Assertions.assertTrue(failures <= FAILURES, FormatSupplier.getSupplier(
                "Too many failures (n=%d, p=%f): %d", new Object[] {n, p, failures}));
    }

    /**
     * Repeatedly samples binomial data, fits it with both least squares (LSE) and maximum
     * likelihood (MLE), and checks the two fits agree (n exactly, p within {@link #P_DELTA}).
     * Fails if more than {@link #FAILURES} of {@link #TRIALS} trials disagree.
     *
     * @param rng the source of randomness
     * @param n the true binomial n
     * @param p the true binomial p
     * @param zeroTruncated passed to the fitter as the zero-truncation flag
     * @param minN the minimum n passed to the fitter
     * @param maxN the maximum n passed to the fitter
     */
    private static void fitBinomialUsingLseOrMle(UniformRandomProvider rng, int n, double p,
            boolean zeroTruncated, int minN, int maxN) {
        final BinomialFitter fitter = new BinomialFitter((Logger) null);
        logger.log(TestLogging.TestLevel.TEST_INFO,
                FormatSupplier.getSupplier("Fitting (n=%d, p=%f)", new Object[] {n, p}));
        int failures = 0;
        int lseClosest = 0;
        for (int trial = 0; trial < TRIALS; trial++) {
            // NOTE(review): as in fitBinomial, data are never zero-truncated at generation time.
            final int[] data = createData(rng, n, p, false);
            // Fit the same data with both estimators.
            fitter.setMaximumLikelihood(false);
            final double[] lse = fitter.fitBinomial(data, minN, maxN, zeroTruncated);
            fitter.setMaximumLikelihood(true);
            final double[] mle = fitter.fitBinomial(data, minN, maxN, zeroTruncated);
            final int lseN = (int) lse[0];
            final double lseP = lse[1];
            final int mleN = (int) mle[0];
            final double mleP = mle[1];
            logger.log(TestLogging.TestLevel.TEST_INFO,
                    FormatSupplier.getSupplier("  Fitted LSE (n=%d, p=%f) == MLE (n=%d, p=%f)",
                            new Object[] {lseN, lseP, mleN, mleP}));
            try {
                Assertions.assertEquals(lseN, mleN, "Failed to match n");
                Assertions.assertEquals(lseP, mleP, P_DELTA, "Failed to match p");
            } catch (final AssertionError ex) {
                failures++;
                logger.log(TestLogging.TestLevel.TEST_INFO, "    " + ex.getMessage());
            }
            // Informational only: which estimator lands closer to the true p (ties count as MLE).
            if (Math.abs(lseP - p) < Math.abs(mleP - p)) {
                lseClosest++;
            }
        }
        logger.log(TestLogging.TestLevel.TEST_INFO, FormatSupplier.getSupplier(
                "  Closest LSE %d, MLE %d", new Object[] {lseClosest, TRIALS - lseClosest}));
        Assertions.assertTrue(failures <= FAILURES, FormatSupplier.getSupplier(
                "Too many failures (n=%d, p=%f): %d", new Object[] {n, p, failures}));
    }

    /**
     * Creates {@link #SAMPLE_SIZE} samples from a binomial(n, p) distribution using inverse
     * transform sampling.
     *
     * @param rng the source of randomness
     * @param n the binomial n
     * @param p the binomial p
     * @param zeroTruncated true to reject zero samples (zero-truncated binomial)
     * @return the samples
     * @throws IllegalArgumentException if zero-truncated and p is not positive (rejection of
     *         zeros would never terminate)
     */
    private static int[] createData(UniformRandomProvider rng, int n, double p,
            boolean zeroTruncated) {
        // Written as !(p > 0) so that NaN is also rejected.
        if (zeroTruncated && !(p > 0.0d)) {
            throw new IllegalArgumentException("p must be positive");
        }
        final BinomialDistribution distribution =
                new BinomialDistribution((RandomGenerator) null, n, p);
        final InverseTransformDiscreteSampler sampler =
                new InverseTransformDiscreteSampler(rng, distribution::inverseCumulativeProbability);
        final int[] data = new int[SAMPLE_SIZE];
        for (int i = 0; i < data.length; i++) {
            int sample = sampler.sample();
            // Zero-truncation by rejection: resample until a non-zero value is drawn.
            while (zeroTruncated && sample == 0) {
                sample = sampler.sample();
            }
            data[i] = sample;
        }
        return data;
    }
}
