package uk.ac.sussex.gdsc.smlm.fitting;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.ejml.data.DenseMatrix64F;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.opentest4j.AssertionFailedError;
import uk.ac.sussex.gdsc.core.utils.MathUtils;
import uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils;
import uk.ac.sussex.gdsc.core.utils.rng.RandomUtils;
import uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.GradientCalculatorUtils;
import uk.ac.sussex.gdsc.smlm.function.gaussian.AstigmatismZModel;
import uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.GaussianFunctionFactory;
import uk.ac.sussex.gdsc.test.junit5.SeededTest;
import uk.ac.sussex.gdsc.test.rng.RngFactory;
import uk.ac.sussex.gdsc.test.utils.RandomSeed;
import uk.ac.sussex.gdsc.test.utils.TestLogging;
import uk.ac.sussex.gdsc.test.utils.functions.FormatSupplier;

/**
 * Tests for {@link FisherInformationMatrix}: computing the CRLB via {@link FisherInformationMatrix#crlb()}
 * and {@link FisherInformationMatrix#crlbReciprocal()}, and extracting parameter subsets.
 */
class FisherInformationMatrixTest {
    /** Width and height passed to the Gaussian function factory (the function has SIZE * SIZE data points). */
    private static final int SIZE = 10;

    /**
     * Raw function flags passed to {@link GaussianFunctionFactory#create2D}. The value is
     * 285 (= 256 + 16 + 8 + 4 + 1). NOTE(review): presumably the named erf "free circle" constant
     * with fitted background and signal; confirm and replace with the named constant.
     */
    private static final int FUNCTION_FLAGS = 285;

    /** Stride between consecutive peaks in the function parameter array (index 0 is shared). */
    private static final int PARAMETERS_PER_PEAK = 7;

    // Offsets within a peak's parameter block. NOTE(review): names are assumed from the value
    // ranges used below and the Gaussian2DFunction parameter layout; confirm.
    private static final int SIGNAL_OFFSET = 1;
    private static final int X_OFFSET = 2;
    private static final int Y_OFFSET = 3;
    private static final int X_SD_OFFSET = 5;
    private static final int Y_SD_OFFSET = 6;

    /** Tolerance argument passed to the {@link FisherInformationMatrix} constructor. */
    private static final double MATRIX_TOLERANCE = 1e-3;

    private static Logger logger;
    private final Level level = TestLogging.TestLevel.TEST_DEBUG;

    @BeforeAll
    public static void beforeAll() {
        logger = Logger.getLogger(FisherInformationMatrixTest.class.getName());
    }

    @AfterAll
    public static void afterAll() {
        logger = null;
    }

    /** Checks {@code crlb()} returns a result for matrices of 1 to 9 columns. */
    @SeededTest
    void canComputeCrlb(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        for (int columns = 1; columns < 10; columns++) {
            testComputeCrlb(rng, columns, 0, true);
        }
    }

    /** Checks {@code crlb()} returns a result when some rows/columns are zeroed. */
    @SeededTest
    void canComputeCrlbWithZeros(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        for (int columns = 2; columns < 10; columns++) {
            testComputeCrlb(rng, columns, 1, true);
            testComputeCrlb(rng, columns, columns / 2, true);
        }
    }

    /** Checks {@code crlbReciprocal()} returns a result for matrices of 1 to 9 columns. */
    @SeededTest
    void canComputeCrlbWithReciprocal(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        for (int columns = 1; columns < 10; columns++) {
            testComputeCrlb(rng, columns, 0, false);
        }
    }

    /** Checks {@code crlbReciprocal()} returns a result when some rows/columns are zeroed. */
    @SeededTest
    void canComputeCrlbWithReciprocalWithZeros(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        for (int columns = 2; columns < 10; columns++) {
            testComputeCrlb(rng, columns, 1, false);
            testComputeCrlb(rng, columns, columns / 2, false);
        }
    }

    /**
     * Checks that the summed CRLB from full inversion differs from the summed CRLB from
     * {@code crlbReciprocal()} for matrices with more than one column.
     */
    @SeededTest
    void inversionDoesNotMatchReciprocal(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        for (int columns = 1; columns < 10; columns++) {
            final FisherInformationMatrix matrix = createFisherInformationMatrix(rng, columns, 0);
            final double[] crlb = matrix.crlb();
            final double[] crlbReciprocal = matrix.crlbReciprocal();
            if (logger.isLoggable(level)) {
                logger.log(level, FormatSupplier.getSupplier("%s =? %s", Arrays.toString(crlb),
                        Arrays.toString(crlbReciprocal)));
            }
            // Only compared for columns > 1: presumably a 1x1 inverse equals the reciprocal.
            if (columns > 1) {
                final int n = columns;
                Assertions.assertNotNull(crlb, () -> "crlb failed: columns=" + n);
                Assertions.assertNotNull(crlbReciprocal, () -> "crlbReciprocal failed: columns=" + n);
                Assertions.assertNotEquals(MathUtils.sum(crlb), MathUtils.sum(crlbReciprocal),
                        () -> "Inversion matches reciprocal: columns=" + n);
            }
        }
    }

    /**
     * Builds a random Fisher information matrix and asserts the chosen CRLB computation succeeds.
     *
     * @param rng the random source
     * @param columns the number of rows/columns of the matrix
     * @param zeroColumns the number of randomly chosen rows/columns to set to zero
     * @param inversion true to use {@code crlb()}, false to use {@code crlbReciprocal()}
     * @return the (non-null) CRLB
     */
    private double[] testComputeCrlb(UniformRandomProvider rng, int columns, int zeroColumns,
            boolean inversion) {
        final FisherInformationMatrix matrix = createFisherInformationMatrix(rng, columns, zeroColumns);
        final double[] crlb = inversion ? matrix.crlb() : matrix.crlbReciprocal();
        if (logger.isLoggable(level)) {
            logger.log(level, FormatSupplier.getSupplier("columns=%d, zeroColumns=%d : %s", columns,
                    zeroColumns, Arrays.toString(crlb)));
        }
        Assertions.assertNotNull(crlb,
                () -> String.format("Crlb failed: columns=%d, zeroColumns=%d", columns, zeroColumns));
        return crlb;
    }

    /**
     * Creates a Fisher information matrix from a multi-peak Gaussian 2D function with random
     * parameters, truncated to {@code columns} x {@code columns}.
     *
     * @param rng the random source
     * @param columns the size of the returned square matrix
     * @param zeroColumns the number of randomly chosen rows/columns to set to zero
     * @return the matrix
     */
    private static FisherInformationMatrix createFisherInformationMatrix(UniformRandomProvider rng,
            int columns, int zeroColumns) {
        // Use the fewest peaks whose function has at least the requested number of gradients.
        int npeaks = 1;
        Gaussian2DFunction func = createFunction(SIZE, npeaks);
        while (func.getNumberOfGradients() < columns) {
            npeaks++;
            func = createFunction(SIZE, npeaks);
        }

        // Random parameters. The RNG call order must not change, so seeded runs stay reproducible.
        final double[] params = new double[1 + npeaks * PARAMETERS_PER_PEAK];
        params[0] = nextUniform(rng, 1, 5);
        for (int peak = 0, offset = 0; peak < npeaks; peak++, offset += PARAMETERS_PER_PEAK) {
            params[offset + SIGNAL_OFFSET] = nextUniform(rng, 100, 300);
            // Each peak is shifted by 2 so peaks occupy different positions.
            final double low = 2 + peak * 2;
            params[offset + X_OFFSET] = nextUniform(rng, low, low + 2);
            params[offset + Y_OFFSET] = nextUniform(rng, low, low + 2);
            params[offset + X_SD_OFFSET] = nextUniform(rng, 1.5, 2);
            params[offset + Y_SD_OFFSET] = nextUniform(rng, 1.5, 2);
        }
        func.initialise(params);

        final double[][] full = GradientCalculatorUtils.newCalculator(func.getNumberOfGradients())
                .fisherInformationMatrix(SIZE * SIZE, params, func);

        // Truncate to the requested square size (copies, so the full matrix is not aliased).
        final double[][] matrix = Arrays.copyOf(full, columns);
        for (int row = 0; row < columns; row++) {
            matrix[row] = Arrays.copyOf(matrix[row], columns);
        }

        if (zeroColumns > 0) {
            for (final int index : RandomUtils.sample(zeroColumns, columns, rng)) {
                for (int k = 0; k < columns; k++) {
                    matrix[k][index] = 0;
                    matrix[index][k] = 0;
                }
            }
        }
        return new FisherInformationMatrix(matrix, MATRIX_TOLERANCE);
    }

    /** Returns a uniform sample in {@code [min, max)} computed as {@code min + u * (max - min)}. */
    private static double nextUniform(UniformRandomProvider rng, double min, double max) {
        return min + rng.nextDouble() * (max - min);
    }

    /**
     * Creates the Gaussian 2D function used by these tests.
     *
     * @param size passed as both the second and third arguments of {@code create2D}
     *     (presumably the x and y dimensions; confirm)
     * @param npeaks the number of peaks
     * @return the function
     */
    private static Gaussian2DFunction createFunction(int size, int npeaks) {
        return GaussianFunctionFactory.create2D(npeaks, size, size, FUNCTION_FLAGS,
                (AstigmatismZModel) null);
    }

    /** Checks that {@code subset} copies exactly the selected rows/columns of the matrix. */
    @SeededTest
    void canProduceSubset(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        final FisherInformationMatrix fim = createRandomMatrix(rng, 10);
        final DenseMatrix64F matrix = fim.getMatrix();
        if (logger.isLoggable(level)) {
            logger.log(level, String.valueOf(matrix));
        }
        for (int repeat = 1; repeat < 10; repeat++) {
            final int[] indices = RandomUtils.sample(5, 10, rng);
            Arrays.sort(indices);
            final DenseMatrix64F subset = fim.subset(indices).getMatrix();
            if (logger.isLoggable(level)) {
                // Plain message: the index list is data and must not be used as a format string.
                logger.log(level, Arrays.toString(indices));
                logger.log(level, String.valueOf(subset));
            }
            for (int i = 0; i < indices.length; i++) {
                for (int j = 0; j < indices.length; j++) {
                    Assertions.assertEquals(matrix.get(indices[i], indices[j]), subset.get(i, j));
                }
            }
        }
    }

    /**
     * Creates an {@code n} x {@code n} matrix filled with uniform random values. The matrix is
     * not symmetric; this is acceptable because it is only used for subset extraction.
     */
    private static FisherInformationMatrix createRandomMatrix(UniformRandomProvider rng, int n) {
        final double[] data = new double[n * n];
        for (int i = 0; i < data.length; i++) {
            data[i] = rng.nextDouble();
        }
        return new FisherInformationMatrix(data, n);
    }

    /**
     * Checks that the CRLB computed on a subset (background plus one peak's gradients) is lower
     * than the corresponding CRLB computed on the full matrix for that parameter.
     */
    @SeededTest
    void computeWithSubsetReducesTheCrlb(RandomSeed seed) {
        final UniformRandomProvider rng = RngFactory.create(seed.get());
        final int perPeak = createFunction(SIZE, 1).getGradientParametersPerPeak();
        // Background column followed by two blocks of per-peak gradient columns.
        final FisherInformationMatrix full = createFisherInformationMatrix(rng, 1 + 2 * perPeak, 0);

        // Subset 1: background and the first block.
        final int[] indices = SimpleArrayUtils.natural(1 + perPeak);
        final FisherInformationMatrix subset1 = full.subset(indices);
        // Subset 2: background and the second block.
        for (int i = 1; i < indices.length; i++) {
            indices[i] += perPeak;
        }
        final FisherInformationMatrix subset2 = full.subset(indices);

        final double[] crlb = full.crlb();
        final double[] crlb1 = subset1.crlb();
        final double[] crlb2 = subset2.crlb();
        Assertions.assertNotNull(crlb, "Full matrix crlb failed");
        Assertions.assertNotNull(crlb1, "Subset 1 crlb failed");
        Assertions.assertNotNull(crlb2, "Subset 2 crlb failed");

        // Combine: background and first block from subset 1, second block from subset 2.
        final double[] combined = Arrays.copyOf(crlb1, crlb.length);
        System.arraycopy(crlb2, 1, combined, crlb1.length, perPeak);
        for (int i = 0; i < crlb.length; i++) {
            final int index = i;
            Assertions.assertTrue(combined[i] < crlb[i],
                    () -> String.format("Subset CRLB not lower at index %d: %s >= %s", index,
                            combined[index], crlb[index]));
        }
    }
}
