package uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient;

import java.util.ArrayList;
import org.apache.commons.math3.util.Precision;
import org.ejml.data.DenseMatrix64F;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import uk.ac.sussex.gdsc.smlm.function.FakeGradientFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.HoltzerAstigmatismZModel;
import uk.ac.sussex.gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.erf.MultiFreeCircularErfGaussian2DFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.erf.SingleAstigmatismErfGaussian2DFunction;
import uk.ac.sussex.gdsc.smlm.function.gaussian.erf.SingleFreeCircularErfGaussian2DFunction;
import uk.ac.sussex.gdsc.test.api.Predicates;
import uk.ac.sussex.gdsc.test.api.TestAssertions;
import uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate;
import uk.ac.sussex.gdsc.test.junit5.SeededTest;
import uk.ac.sussex.gdsc.test.junit5.SpeedTag;
import uk.ac.sussex.gdsc.test.rng.RngFactory;
import uk.ac.sussex.gdsc.test.utils.AssertionErrorCounter;
import uk.ac.sussex.gdsc.test.utils.RandomSeed;
import uk.ac.sussex.gdsc.test.utils.functions.IndexSupplier;

/**
 * Tests for {@link FastMleJacobianGradient2Procedure}.
 *
 * <p>Several tests inherited from {@link FastMleGradient2ProcedureTest} are overridden here to be
 * skipped. This class instead compares the procedure with the unrolled base procedure and checks
 * its first derivatives, second derivatives and Jacobian against central finite differences.
 */
class FastMleJacobianGradient2ProcedureTest extends FastMleGradient2ProcedureTest {

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @Test
    void gradientProcedureFactoryCreatesOptimisedProcedures() {
        Assumptions.assumeTrue(false);
    }

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @SeededTest
    void gradientProcedureComputesSameLogLikelihoodAsMleGradientCalculator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(false);
    }

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @SpeedTag
    @SeededTest
    void gradientProcedureIsNotSlowerThanGradientCalculator(RandomSeed randomSeed) {
        Assumptions.assumeTrue(false);
    }

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @SeededTest
    void gradientProcedureComputesSameWithPrecomputed(RandomSeed randomSeed) {
        Assumptions.assumeTrue(false);
    }

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @SeededTest
    void gradientProcedureUnrolledComputesSameAsGradientProcedure(RandomSeed randomSeed) {
        Assumptions.assumeTrue(false);
    }

    /** Skipped for this subclass (aborted via a failed assumption). */
    @Override
    @SpeedTag
    @SeededTest
    void gradientProcedureIsFasterUnrolledThanGradientProcedure(RandomSeed randomSeed) {
        Assumptions.assumeTrue(false);
    }

    /**
     * Checks that the Jacobian procedure computes the same first and second derivatives as the
     * unrolled {@link FastMleGradient2Procedure} for several parameter counts.
     */
    @SeededTest
    void gradientProcedureComputesSameAsBaseGradientProcedure(RandomSeed randomSeed) {
        final DoubleDoubleBiPredicate equality = Predicates.doublesAreClose(1e-5, 0.0);
        for (final int nparams : new int[] {4, 5, 6, 11, 21}) {
            gradientProcedureComputesSameAsBaseGradientProcedure(randomSeed, nparams, equality);
        }
    }

    /**
     * Compares d1 and d2 from {@link FastMleJacobianGradient2Procedure#computeSecondDerivative}
     * with those from the unrolled base procedure, using a {@link FakeGradientFunction}.
     *
     * @param seed the random seed
     * @param nparams the number of parameters of the fake function
     * @param equality the predicate used to compare values
     */
    private void gradientProcedureComputesSameAsBaseGradientProcedure(RandomSeed seed, int nparams,
            DoubleDoubleBiPredicate equality) {
        final int iter = 10;
        final ArrayList<double[]> paramsList = new ArrayList<>(iter);
        // Values passed to the procedure constructors (presumably the observed data)
        final ArrayList<double[]> yList = new ArrayList<>(iter);
        createFakeData(RngFactory.create(seed.get()), nparams, iter, paramsList, yList);
        final FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams);

        for (int n = 0; n < paramsList.size(); n++) {
            final double[] y = yList.get(n);
            final double[] params = paramsList.get(n);
            final FastMleGradient2Procedure expected =
                    FastMleGradient2ProcedureUtils.createUnrolled(y, func);
            final FastMleJacobianGradient2Procedure actual =
                    new FastMleJacobianGradient2Procedure(y, func);
            expected.computeSecondDerivative(params);
            actual.computeSecondDerivative(params);
            TestAssertions.assertArrayTest(expected.d1, actual.d1, equality);
            TestAssertions.assertArrayTest(expected.d2, actual.d2, equality);
        }
    }

    /**
     * Checks the analytical gradients of the Jacobian procedure against numerical gradients for
     * single and multiple peak ERF Gaussian functions.
     */
    @Override
    @SeededTest
    void gradientCalculatorComputesGradient(RandomSeed randomSeed) {
        gradientCalculatorComputesGradient(randomSeed, 1,
                new SingleFreeCircularErfGaussian2DFunction(blockWidth, blockWidth));
        gradientCalculatorComputesGradient(randomSeed, 2,
                new MultiFreeCircularErfGaussian2DFunction(2, blockWidth, blockWidth));
        gradientCalculatorComputesGradient(randomSeed, 1,
                new SingleAstigmatismErfGaussian2DFunction(blockWidth, blockWidth,
                        HoltzerAstigmatismZModel.create(1.08, 1.01, 0.389, 0.531, -0.0708, -0.073,
                                0.164, 0.0417)));
    }

    /**
     * Compares the values produced by {@link FastMleJacobianGradient2Procedure#computeJacobian}
     * with central finite differences:
     *
     * <ul>
     * <li>d1[i] against the difference of the log-likelihood;
     * <li>d2[i] against the difference of d1[i] with respect to parameter i;
     * <li>Jacobian (i, j) against the difference of d1[i] with respect to parameter j.
     * </ul>
     *
     * <p>Up to the failure limit derived from 10% of the iterations is tolerated per checked index.
     *
     * @param seed the random seed
     * @param npeaks the number of peaks used to create the data
     * @param func the function under test
     */
    private void gradientCalculatorComputesGradient(RandomSeed seed, int npeaks,
            ErfGaussian2DFunction func) {
        final int nparams = func.getNumberOfGradients();
        final int[] indices = func.gradientIndices();

        final int iter = 100;
        final ArrayList<double[]> paramsList = new ArrayList<>(iter);
        // Values passed to the procedure constructor (presumably the observed data)
        final ArrayList<double[]> yList = new ArrayList<>(iter);
        createData(RngFactory.create(seed.get()), npeaks, iter, paramsList, yList, true);

        final DoubleDoubleBiPredicate equality = Predicates.doublesAreClose(0.05, 1e-16);
        final IndexSupplier msg1 = new IndexSupplier(2).setMessagePrefix("Gradient1 ");
        final IndexSupplier msg2 = new IndexSupplier(2).setMessagePrefix("Gradient2 ");
        final IndexSupplier msgJ = new IndexSupplier(3).setMessagePrefix("GradientJ ");
        final int failureLimit = AssertionErrorCounter.computeFailureLimit(iter, 0.1);
        // Index i: first derivative i; index nparams + i: second derivative i
        final AssertionErrorCounter failCounter =
                new AssertionErrorCounter(failureLimit, 2 * nparams);
        // Index nparams * i + j: Jacobian entry (i, j)
        final AssertionErrorCounter failCounterJ =
                new AssertionErrorCounter(failureLimit, nparams * nparams);

        for (int n = 0; n < paramsList.size(); n++) {
            msg1.set(0, n);
            msg2.set(0, n);
            msgJ.set(0, n);
            final double[] y = yList.get(n);
            final double[] params = paramsList.get(n);
            // Working copy: one parameter is perturbed at a time and then restored
            final double[] work = params.clone();
            final FastMleJacobianGradient2Procedure procedure =
                    new FastMleJacobianGradient2Procedure(y, func);
            procedure.computeJacobian(params);
            // Snapshot the analytical values before the evaluations below reuse the procedure
            final double[] d1 = procedure.d1.clone();
            final double[] d2 = procedure.d2.clone();
            final DenseMatrix64F jacobian =
                    DenseMatrix64F.wrap(nparams, nparams, procedure.getJacobianLinear());

            for (int i = 0; i < nparams; i++) {
                final int row = i;
                msg1.set(1, i);
                msg2.set(1, i);
                msgJ.set(1, i);
                final int k = indices[i];
                final double delta = finiteDifferenceStep(params[k]);

                work[k] = params[k] + delta;
                final double llUpper = procedure.computeLogLikelihood(work);
                procedure.computeFirstDerivative(work);
                final double d1Upper = procedure.d1[i];
                work[k] = params[k] - delta;
                final double llLower = procedure.computeLogLikelihood(work);
                procedure.computeFirstDerivative(work);
                final double d1Lower = procedure.d1[i];
                work[k] = params[k];

                final double gradient1 = (llUpper - llLower) / (2 * delta);
                final double gradient2 = (d1Upper - d1Lower) / (2 * delta);
                failCounter.run(i,
                        () -> TestAssertions.assertTest(gradient1, d1[row], equality, msg1));
                failCounter.run(nparams + i,
                        () -> TestAssertions.assertTest(gradient2, d2[row], equality, msg2));

                // Jacobian row i: numerical derivative of d1[i] with respect to each parameter j.
                // The diagonal (j == i) repeats the finite difference used for the d2 check above.
                for (int j = 0; j < nparams; j++) {
                    final int col = j;
                    msgJ.set(2, j);
                    final int kj = indices[j];
                    final double deltaJ = finiteDifferenceStep(params[kj]);

                    work[kj] = params[kj] + deltaJ;
                    procedure.computeFirstDerivative(work);
                    final double upper = procedure.d1[i];
                    work[kj] = params[kj] - deltaJ;
                    procedure.computeFirstDerivative(work);
                    final double lower = procedure.d1[i];
                    work[kj] = params[kj];

                    final double gradientJ = (upper - lower) / (2 * deltaJ);
                    failCounterJ.run(nparams * i + j, () -> TestAssertions.assertTest(gradientJ,
                            jacobian.get(row, col), equality, msgJ));
                }
            }
        }
    }

    /**
     * Gets the central-difference step for a parameter value. The nominal step is 1e-4 relative to
     * the value, or 1e-4 absolute when the value is zero. It is then passed through
     * {@link Precision#representableDelta(double, double)}. The step is negative for negative
     * values, which the central-difference quotient tolerates.
     *
     * @param value the parameter value
     * @return the step
     */
    private static double finiteDifferenceStep(double value) {
        return Precision.representableDelta(value, value == 0 ? 1e-4 : value * 1e-4);
    }
}
