package uk.ac.sussex.gdsc.smlm.math3.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.util.Pair;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
import org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import uk.ac.sussex.gdsc.core.utils.LocalList;
import uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils;
import uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter;
import uk.ac.sussex.gdsc.core.utils.rng.RandomUtils;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution;
import uk.ac.sussex.gdsc.test.api.Predicates;
import uk.ac.sussex.gdsc.test.api.TestAssertions;
import uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate;
import uk.ac.sussex.gdsc.test.junit5.SeededTest;
import uk.ac.sussex.gdsc.test.rng.RngFactory;
import uk.ac.sussex.gdsc.test.utils.RandomSeed;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/distribution/fitting/MultivariateGaussianMixtureExpectationMaximizationTest.class */
class MultivariateGaussianMixtureExpectationMaximizationTest {
    private static final MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate DEFAULT_CONVERGENCE_CHECKER;

    MultivariateGaussianMixtureExpectationMaximizationTest() {
    }

    @SeededTest
    void canComputeCovariance(RandomSeed randomSeed) {
        RestorableUniformRandomProvider create = RngFactory.create(randomSeed.get());
        double[][] dArr = new double[20][3];
        for (int i = 0; i < 20; i++) {
            for (int i2 = 0; i2 < 3; i2++) {
                dArr[i][i2] = create.nextDouble();
            }
        }
        Assertions.assertArrayEquals(new Covariance(dArr).getCovarianceMatrix().getData(), MultivariateGaussianMixtureExpectationMaximization.covariance(getColumnMeans(dArr), dArr));
    }

    @Test
    void testCreateMultivariateGaussianDistributionThrows() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(new double[2], new double[3][3]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(new double[2], new double[2][3]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(new double[]{3.4d, 5.6d}, (double[][]) new double[]{new double[]{1.1d, 2.3d}, new double[]{2.3d, 2.0d}});
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(new double[]{3.2394652282161466E-159d, 0.9874091290326361d, 0.3148579631286939d, 0.3718812084490761d}, (double[][]) new double[]{new double[]{4.7656408265957816E-164d, 1.1864890699789289E-160d, 3.486364577999923E-160d, 2.799059273090178E-160d}, new double[]{1.1864890699789289E-160d, 0.24556306971562533d, 0.03418519684554344d, 0.04106282423640166d}, new double[]{3.486364577999923E-160d, 0.03418519684554344d, 0.005991172243435974d, 0.007750670597351695d}, new double[]{2.799059273090178E-160d, 0.04106282423640166d, 0.007750670597351695d, 0.03725445634104687d}});
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    @Test
    void canCreateMultivariateGaussianDistribution() {
        ?? r0 = {new double[]{1.0d, 2.0d}, new double[]{2.5d, 1.5d}, new double[]{3.5d, 1.0d}};
        double[] columnMeans = getColumnMeans(r0);
        double[][] covariance = getCovariance(r0);
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution create = MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(columnMeans, covariance);
        Assertions.assertSame(columnMeans, create.getMeans());
        Assertions.assertSame(covariance, create.getCovariances());
        double[] standardDeviations = create.getStandardDeviations();
        Assertions.assertEquals(covariance.length, standardDeviations.length);
        for (int i = 0; i < standardDeviations.length; i++) {
            Assertions.assertEquals(Math.sqrt(covariance[i][i]), standardDeviations[i]);
        }
        MultivariateNormalDistribution multivariateNormalDistribution = new MultivariateNormalDistribution(columnMeans, covariance);
        for (double[] dArr : r0) {
            Assertions.assertEquals(multivariateNormalDistribution.density(dArr), create.density(dArr));
        }
    }

    @Test
    void testCreateUnmixedMultivariateGaussianDistributionThrows() {
        Assertions.assertThrows(NullPointerException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createUnmixed((double[][]) null);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createUnmixed((double[][]) new double[0]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createUnmixed((double[][]) new double[1]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createUnmixed(new double[2][1]);
        });
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    @Test
    void canCreateUnmixedMultivariateGaussianDistribution() {
        ?? r0 = {new double[]{1.0d, 2.0d}, new double[]{2.5d, 1.5d}, new double[]{3.5d, 1.0d}};
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution create = MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution.create(getColumnMeans(r0), getCovariance(r0));
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution createUnmixed = MultivariateGaussianMixtureExpectationMaximization.createUnmixed((double[][]) r0);
        DoubleDoubleBiPredicate doublesAreRelativelyClose = Predicates.doublesAreRelativelyClose(1.0E-6d);
        TestAssertions.assertArrayTest(create.getMeans(), createUnmixed.getMeans(), doublesAreRelativelyClose);
        TestAssertions.assertArrayTest(create.getCovariances(), createUnmixed.getCovariances(), doublesAreRelativelyClose);
    }

    @Test
    void testCreateMixedMultivariateGaussianDistributionThrows() {
        double[][] dArr = new double[2][2];
        int[] iArr = {0, 1};
        Assertions.assertThrows(NullPointerException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed((double[][]) null, iArr);
        });
        Assertions.assertThrows(NullPointerException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(dArr, (int[]) null);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(new double[0][2], iArr);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(new double[1][2], iArr);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(new double[3][2], iArr);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(new double[2][1], iArr);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(dArr, new int[1]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.createMixed(dArr, new int[2]);
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][], java.lang.Object[]] */
    /* JADX WARN: Type inference failed for: r0v9, types: [double[], double[][], java.lang.Object[]] */
    /* JADX WARN: Type inference failed for: r1v18, types: [double[], java.lang.Object[]] */
    @SeededTest
    void canCreateMixedMultivariateGaussianDistribution(RandomSeed randomSeed) {
        double[] dArr = {1.0d, 1.0d};
        ?? r0 = {new double[]{1.0d, 2.0d}, new double[]{2.5d, 1.5d}, new double[]{3.5d, 1.0d}};
        ?? r02 = {new double[]{4.0d, 2.0d}, new double[]{3.5d, -1.5d}, new double[]{-3.5d, 1.0d}};
        double[] dArr2 = {getColumnMeans(r0), getColumnMeans(r02)};
        double[][] dArr3 = {getCovariance(r0), getCovariance(r02)};
        LocalList localList = new LocalList();
        localList.addAll(Arrays.asList(r0));
        localList.addAll(Arrays.asList(r02));
        double[][] dArr4 = (double[][]) localList.toArray((Object[]) new double[0]);
        int[] iArr = {-1, -1, -1, 3, 3, 3};
        for (int i = 0; i < 3; i++) {
            long asLong = i + randomSeed.getAsLong();
            RandomUtils.shuffle(dArr4, RngFactory.create(asLong));
            RandomUtils.shuffle(iArr, RngFactory.create(asLong));
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution createMixed = MultivariateGaussianMixtureExpectationMaximization.createMixed(dArr4, iArr);
            Assertions.assertArrayEquals(new double[]{0.5d, 0.5d}, createMixed.getWeights());
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution[] distributions = createMixed.getDistributions();
            Assertions.assertEquals(dArr.length, distributions.length);
            DoubleDoubleBiPredicate doublesAreRelativelyClose = Predicates.doublesAreRelativelyClose(1.0E-8d);
            for (int i2 = 0; i2 < dArr2.length; i2++) {
                TestAssertions.assertArrayTest(dArr2[i2], distributions[i2].getMeans(), doublesAreRelativelyClose);
                TestAssertions.assertArrayTest(dArr3[i2], distributions[i2].getCovariances(), doublesAreRelativelyClose);
            }
        }
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v9, types: [double[], double[][]] */
    @Test
    void testCreateMixtureMultivariateGaussianDistributionThrows() {
        double[] dArr = {1.0d, 3.0d};
        ?? r0 = {new double[]{1.0d, 2.0d}, new double[]{2.5d, 1.5d}, new double[]{3.5d, 1.0d}};
        ?? r02 = {new double[]{4.0d, 2.0d}, new double[]{3.5d, -1.5d}, new double[]{-3.5d, 1.0d}};
        ?? r03 = {getColumnMeans(r0), getColumnMeans(r02)};
        ?? r04 = {getCovariance(r0), getCovariance(r02)};
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(new double[]{-1.0d, 1.0d}, r03, r04);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(new double[]{Double.MAX_VALUE, Double.MAX_VALUE}, r03, r04);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            double[][] dArr2 = (double[][]) r03.clone();
            dArr2[0] = Arrays.copyOf(dArr2[0], dArr2[0].length + 1);
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(dArr, dArr2, r04);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            double[][][] dArr2 = (double[][][]) r04.clone();
            dArr2[0] = new double[3][2];
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(dArr, r03, dArr2);
        });
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution create = MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(dArr, (double[][]) r03, (double[][][]) r04);
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(new double[]{Double.MAX_VALUE, Double.MAX_VALUE}, create.getDistributions());
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(new double[]{1.0d}, create.getDistributions());
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v5, types: [double[][], double[][][]] */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v9, types: [double[], double[][]] */
    @Test
    void canCreateMixtureMultivariateGaussianDistribution() {
        double[] dArr = {1.0d, 3.0d};
        ?? r0 = {new double[]{1.0d, 2.0d}, new double[]{2.5d, 1.5d}, new double[]{3.5d, 1.0d}};
        ?? r02 = {new double[]{4.0d, 2.0d}, new double[]{3.5d, -1.5d}, new double[]{-3.5d, 1.0d}};
        ?? r03 = {getColumnMeans(r0), getColumnMeans(r02)};
        ?? r04 = {getCovariance(r0), getCovariance(r02)};
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution create = MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(dArr, (double[][]) r03, (double[][][]) r04);
        Assertions.assertArrayEquals(new double[]{0.25d, 0.75d}, create.getWeights());
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution[] distributions = create.getDistributions();
        Assertions.assertEquals(dArr.length, distributions.length);
        for (int i = 0; i < r03.length; i++) {
            Assertions.assertArrayEquals(r03[i], distributions[i].getMeans());
            Assertions.assertArrayEquals(r04[i], distributions[i].getCovariances());
        }
        MixtureMultivariateNormalDistribution mixtureMultivariateNormalDistribution = new MixtureMultivariateNormalDistribution(dArr, (double[][]) r03, (double[][][]) r04);
        for (double[] dArr2 : r0) {
            Assertions.assertEquals(mixtureMultivariateNormalDistribution.density(dArr2), create.density(dArr2), 1.0E-10d);
        }
        Assertions.assertArrayEquals(new double[]{1.0d, 3.0d}, dArr);
        Assertions.assertArrayEquals(dArr, MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.create(dArr, distributions).getWeights());
        Assertions.assertArrayEquals(new double[]{0.25d, 0.75d}, dArr);
    }

    @Test
    void testEstimateInitialMixtureThrows() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.estimate(new double[1][2], 2);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.estimate(new double[2][2], 1);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.estimate(new double[2][2], 3);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            MultivariateGaussianMixtureExpectationMaximization.estimate(new double[2][1], 2);
        });
    }

    @SeededTest
    void canEstimateInitialMixture(RandomSeed randomSeed) {
        RestorableUniformRandomProvider create = RngFactory.create(randomSeed.get());
        DoubleDoubleBiPredicate doublesAreClose = Predicates.doublesAreClose(1.0E-5d, 1.0E-16d);
        for (int i = 2; i <= 3; i++) {
            double[][] createData2d = createData2d(1000, create, createWeights(i, create), create(i, 2, create, -5.0d, 5.0d), create(i, 2, create, 1.0d, 10.0d), create(i, create, -0.9d, 0.9d));
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution estimate = MultivariateGaussianMixtureExpectationMaximization.estimate(createData2d, i);
            List components = MultivariateNormalMixtureExpectationMaximization.estimate(createData2d, i).getComponents();
            double[] weights = estimate.getWeights();
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution[] distributions = estimate.getDistributions();
            Assertions.assertEquals(i, components.size());
            Assertions.assertEquals(i, weights.length);
            Assertions.assertEquals(i, distributions.length);
            for (int i2 = 0; i2 < i; i2++) {
                Assertions.assertEquals((Double) ((Pair) components.get(i2)).getFirst(), weights[i2], "weight");
                MultivariateNormalDistribution multivariateNormalDistribution = (MultivariateNormalDistribution) ((Pair) components.get(i2)).getSecond();
                TestAssertions.assertArrayTest(multivariateNormalDistribution.getMeans(), distributions[i2].getMeans(), doublesAreClose, "means");
                TestAssertions.assertArrayTest(multivariateNormalDistribution.getCovariances().getData(), distributions[i2].getCovariances(), doublesAreClose, "covariances");
            }
        }
    }

    @Test
    void testCreateMultivariateGaussianMixtureExpectationMaximizationThrows() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new MultivariateGaussianMixtureExpectationMaximization(new double[0][2]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new MultivariateGaussianMixtureExpectationMaximization(new double[2][1]);
        });
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            new MultivariateGaussianMixtureExpectationMaximization((double[][]) new double[]{new double[]{0.0d, 1.0d, 2.0d}, new double[]{3.0d, 4.0d}});
        });
    }

    /* JADX WARN: Type inference failed for: r9v1, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r9v3, types: [double[], double[][]] */
    @Test
    void testFitThrows() {
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution mixtureMultivariateGaussianDistribution = null;
        MultivariateGaussianMixtureExpectationMaximization multivariateGaussianMixtureExpectationMaximization = new MultivariateGaussianMixtureExpectationMaximization(new double[2][3]);
        Assertions.assertEquals(0.0d, multivariateGaussianMixtureExpectationMaximization.getLogLikelihood());
        Assertions.assertEquals(0, multivariateGaussianMixtureExpectationMaximization.getIterations());
        Assertions.assertNull(multivariateGaussianMixtureExpectationMaximization.getFittedModel());
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            multivariateGaussianMixtureExpectationMaximization.fit(mixtureMultivariateGaussianDistribution, 0, DEFAULT_CONVERGENCE_CHECKER);
        });
        Assertions.assertThrows(NullPointerException.class, () -> {
            multivariateGaussianMixtureExpectationMaximization.fit(mixtureMultivariateGaussianDistribution, 10, (MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) null);
        });
        Assertions.assertThrows(NullPointerException.class, () -> {
            multivariateGaussianMixtureExpectationMaximization.fit(mixtureMultivariateGaussianDistribution, 10, DEFAULT_CONVERGENCE_CHECKER);
        });
        MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution mixtureMultivariateGaussianDistribution2 = new MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution(new double[]{0.5d, 0.5d}, new MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution[]{new MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution(new double[2], (double[][]) new double[]{new double[]{1.0d, 0.0d}, new double[]{0.0d, 2.0d}}), new MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution(new double[2], (double[][]) new double[]{new double[]{1.0d, 0.0d}, new double[]{0.0d, 2.0d}})});
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            multivariateGaussianMixtureExpectationMaximization.fit(mixtureMultivariateGaussianDistribution2, 10, DEFAULT_CONVERGENCE_CHECKER);
        });
    }

    @SeededTest
    void canFit(RandomSeed randomSeed) {
        RestorableUniformRandomProvider create = RngFactory.create(randomSeed.get());
        DoubleDoubleBiPredicate doublesAreClose = Predicates.doublesAreClose(1.0E-5d, 1.0E-16d);
        for (int i = 2; i <= 3; i++) {
            double[][] createData2d = createData2d(1000, create, createWeights(i, create), create(i, 2, create, -5.0d, 5.0d), create(i, 2, create, 1.0d, 10.0d), create(i, create, -0.9d, 0.9d));
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution estimate = MultivariateGaussianMixtureExpectationMaximization.estimate(createData2d, i);
            MultivariateGaussianMixtureExpectationMaximization multivariateGaussianMixtureExpectationMaximization = new MultivariateGaussianMixtureExpectationMaximization(createData2d);
            Assertions.assertTrue(multivariateGaussianMixtureExpectationMaximization.fit(estimate));
            MultivariateNormalMixtureExpectationMaximization multivariateNormalMixtureExpectationMaximization = new MultivariateNormalMixtureExpectationMaximization(createData2d);
            multivariateNormalMixtureExpectationMaximization.fit(MultivariateNormalMixtureExpectationMaximization.estimate(createData2d, i));
            double logLikelihood = multivariateGaussianMixtureExpectationMaximization.getLogLikelihood() / 1000.0d;
            Assertions.assertNotEquals(0.0d, logLikelihood);
            TestAssertions.assertTest(multivariateNormalMixtureExpectationMaximization.getLogLikelihood(), logLikelihood, doublesAreClose);
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution fittedModel = multivariateGaussianMixtureExpectationMaximization.getFittedModel();
            Assertions.assertNotNull(fittedModel);
            List components = multivariateNormalMixtureExpectationMaximization.getFittedModel().getComponents();
            double[] weights = fittedModel.getWeights();
            MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution[] distributions = fittedModel.getDistributions();
            Assertions.assertEquals(i, components.size());
            Assertions.assertEquals(i, weights.length);
            Assertions.assertEquals(i, distributions.length);
            for (int i2 = 0; i2 < i; i2++) {
                TestAssertions.assertTest(((Double) ((Pair) components.get(i2)).getFirst()).doubleValue(), weights[i2], doublesAreClose, "weight");
                MultivariateNormalDistribution multivariateNormalDistribution = (MultivariateNormalDistribution) ((Pair) components.get(i2)).getSecond();
                TestAssertions.assertArrayTest(multivariateNormalDistribution.getMeans(), distributions[i2].getMeans(), doublesAreClose, "means");
                TestAssertions.assertArrayTest(multivariateNormalDistribution.getCovariances().getData(), distributions[i2].getCovariances(), doublesAreClose, "covariances");
            }
            int iterations = multivariateGaussianMixtureExpectationMaximization.getIterations();
            Assertions.assertNotEquals(0, iterations);
            if (iterations > 2) {
                Assertions.assertFalse(multivariateGaussianMixtureExpectationMaximization.fit(estimate, 2, DEFAULT_CONVERGENCE_CHECKER));
            }
        }
    }

    private static double[] getColumnMeans(double[][] dArr) {
        Array2DRowRealMatrix array2DRowRealMatrix = new Array2DRowRealMatrix(dArr);
        Mean mean = new Mean();
        return IntStream.range(0, dArr[0].length).mapToDouble(i -> {
            return mean.evaluate(array2DRowRealMatrix.getColumn(i));
        }).toArray();
    }

    private static double[][] getCovariance(double[][] dArr) {
        return new Covariance(dArr).getCovarianceMatrix().getData();
    }

    private static double[] createWeights(int i, UniformRandomProvider uniformRandomProvider) {
        double[] dArr = new double[i];
        double d = 0.0d;
        for (int i2 = 0; i2 < i; i2++) {
            dArr[i2] = uniformRandomProvider.nextDouble();
            d += dArr[i2];
        }
        if (d == 0.0d) {
            return createWeights(i, uniformRandomProvider);
        }
        SimpleArrayUtils.multiply(dArr, 1.0d / d);
        return dArr;
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    private static double[][] create(int i, int i2, UniformRandomProvider uniformRandomProvider, double d, double d2) {
        ?? r0 = new double[i];
        for (int i3 = 0; i3 < i; i3++) {
            r0[i3] = create(i2, uniformRandomProvider, d, d2);
        }
        return r0;
    }

    private static double[] create(int i, UniformRandomProvider uniformRandomProvider, double d, double d2) {
        SharedStateContinuousSampler of = ContinuousUniformSampler.of(uniformRandomProvider, d, d2);
        double[] dArr = new double[i];
        for (int i2 = 0; i2 < i; i2++) {
            dArr[i2] = of.sample();
        }
        return dArr;
    }

    /* JADX WARN: Type inference failed for: r0v17, types: [double[], double[][]] */
    private static double[][] createData2d(int i, UniformRandomProvider uniformRandomProvider, double[] dArr, double[][] dArr2, double[][] dArr3, double[] dArr4) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < dArr.length; i2++) {
            double d = dArr3[i2][0];
            double d2 = dArr3[i2][1];
            double d3 = dArr4[i2] * d * d2;
            arrayList.add(new Pair(Double.valueOf(dArr[i2]), new MultivariateNormalDistribution(dArr2[i2], (double[][]) new double[]{new double[]{d * d, d3}, new double[]{d3, d2 * d2}})));
        }
        return new MixtureMultivariateNormalDistribution(new RandomGeneratorAdapter(uniformRandomProvider), arrayList).sample(i);
    }

    static {
        DoubleDoubleBiPredicate doublesAreAbsolutelyClose = Predicates.doublesAreAbsolutelyClose(1.0E-5d);
        doublesAreAbsolutelyClose.getClass();
        DEFAULT_CONVERGENCE_CHECKER = doublesAreAbsolutelyClose::test;
    }
}
