package us.ihmc.ekf.filter;

import java.util.Random;
import org.apache.commons.math3.util.Precision;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.simple.SimpleMatrix;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.Conversions;
import us.ihmc.ekf.TestTools;

/**
 * Compares the results of {@link NativeFilterMatrixOps} against reference computations done with EJML's
 * {@link SimpleMatrix} on randomly sized, randomly filled matrices.
 * <p>
 * Each test creates its own {@link Random} from {@link #SEED}. The generated inputs of a test therefore do not depend
 * on which other tests ran before it or in what order the test framework executes them.
 * </p>
 * <p>
 * {@link #main(String[])} is a rough timing comparison of the native {@code computeABAt} against an equivalent EJML
 * computation.
 * </p>
 */
public class NativeFilterMatrixOpsTest {
    private static final double EPSILON = 1.0E-10d;
    /** Seed used by every test; change it if a randomly generated case is rejected as poorly conditioned. */
    private static final long SEED = 86526826L;
    /** Number of random cases checked per test. */
    private static final int TEST_ITERATIONS = 50;
    /** Random matrix dimensions are drawn uniformly from {@code [1, MAX_SIZE]}. */
    private static final int MAX_SIZE = 100;
    /** Determinant magnitude below which the reference inverse in the Kalman gain test is not trusted. */
    private static final double MIN_DETERMINANT = 1.0E-5d;
    /** Dimension of the square matrices used by the benchmark in {@link #main(String[])}. */
    private static final int BENCHMARK_SIZE = 100;
    /** Number of calls per warm-up and per timed loop in {@link #main(String[])}. */
    private static final int BENCHMARK_ITERATIONS = 1000;

    /** Checks {@code computeABAt(result, A, B)} against {@code A * B * A^T}. */
    @Test
    public void testABAt() {
        Random random = new Random(SEED);
        for (int i = 0; i < TEST_ITERATIONS; i++) {
            int rows = nextSize(random);
            int cols = nextSize(random);
            DMatrixRMaj a = TestTools.nextMatrix(rows, cols, random, -1.0d, 1.0d);
            // B must be cols x cols for A * B * A^T to be defined.
            DMatrixRMaj b = TestTools.nextMatrix(cols, random, -1.0d, 1.0d);

            DMatrixRMaj result = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.computeABAt(result, a, b);

            SimpleMatrix aSimple = new SimpleMatrix(a);
            SimpleMatrix expected = aSimple.mult(new SimpleMatrix(b).mult(aSimple.transpose()));
            TestTools.assertEquals(expected.getMatrix(), result, EPSILON);
        }
    }

    /** Checks {@code predictErrorCovariance(result, F, P, Q)} against {@code F * P * F^T + Q}. */
    @Test
    public void testPredictErrorCovariance() {
        Random random = new Random(SEED);
        for (int i = 0; i < TEST_ITERATIONS; i++) {
            int n = nextSize(random);
            DMatrixRMaj f = TestTools.nextMatrix(n, random, -1.0d, 1.0d);
            DMatrixRMaj p = TestTools.nextSymmetricMatrix(n, random, 0.1d, 1.0d);
            DMatrixRMaj q = TestTools.nextDiagonalMatrix(n, random, 0.1d, 1.0d);

            DMatrixRMaj result = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.predictErrorCovariance(result, f, p, q);

            SimpleMatrix fSimple = new SimpleMatrix(f);
            SimpleMatrix expected = fSimple.mult(new SimpleMatrix(p).mult(fSimple.transpose())).plus(new SimpleMatrix(q));
            TestTools.assertEquals(expected.getMatrix(), result, EPSILON);
        }
    }

    /** Checks {@code updateErrorCovariance(result, K, H, P)} against {@code (I - K * H) * P}. */
    @Test
    public void testUpdateErrorCovariance() {
        Random random = new Random(SEED);
        for (int i = 0; i < TEST_ITERATIONS; i++) {
            int m = nextSize(random);
            int n = nextSize(random);
            DMatrixRMaj k = TestTools.nextMatrix(n, m, random, -1.0d, 1.0d);
            DMatrixRMaj h = TestTools.nextMatrix(m, n, random, -1.0d, 1.0d);
            DMatrixRMaj p = TestTools.nextSymmetricMatrix(n, random, 0.1d, 1.0d);

            DMatrixRMaj result = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.updateErrorCovariance(result, k, h, p);

            SimpleMatrix kh = new SimpleMatrix(k).mult(new SimpleMatrix(h));
            SimpleMatrix expected = SimpleMatrix.identity(n).minus(kh).mult(new SimpleMatrix(p));
            TestTools.assertEquals(expected.getMatrix(), result, EPSILON);
        }
    }

    /** Checks {@code computeKalmanGain(result, P, H, R)} against {@code P * H^T * (H * P * H^T + R)^-1}. */
    @Test
    public void testComputeKalmanGain() {
        Random random = new Random(SEED);
        for (int i = 0; i < TEST_ITERATIONS; i++) {
            int m = nextSize(random);
            int n = nextSize(random);
            DMatrixRMaj p = TestTools.nextSymmetricMatrix(n, random, 0.1d, 1.0d);
            DMatrixRMaj h = TestTools.nextMatrix(m, n, random, -1.0d, 1.0d);
            DMatrixRMaj r = TestTools.nextDiagonalMatrix(m, random, 1.0d, 100.0d);

            DMatrixRMaj result = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.computeKalmanGain(result, p, h, r);

            SimpleMatrix pSimple = new SimpleMatrix(p);
            SimpleMatrix hSimple = new SimpleMatrix(h);
            SimpleMatrix toInvert = hSimple.mult(pSimple.mult(hSimple.transpose())).plus(new SimpleMatrix(r));

            // The reference result needs an inverse; reject near-singular random cases instead of comparing garbage.
            double determinant = toInvert.determinant();
            if (Math.abs(determinant) < MIN_DETERMINANT) {
                Assertions.fail("Poorly conditioned matrix. Change random seed or skip. Determinant is " + determinant);
            }

            SimpleMatrix expected = pSimple.mult(hSimple.transpose()).mult(toInvert.invert());
            TestTools.assertEquals(expected.getMatrix(), result, EPSILON);
        }
    }

    /** Checks {@code updateState(result, x, K, r)} against {@code x + K * r}. */
    @Test
    public void testUpdateState() {
        Random random = new Random(SEED);
        for (int i = 0; i < TEST_ITERATIONS; i++) {
            int n = nextSize(random);
            int m = nextSize(random);
            DMatrixRMaj x = TestTools.nextMatrix(n, 1, random, -1.0d, 1.0d);
            DMatrixRMaj k = TestTools.nextMatrix(n, m, random, -1.0d, 1.0d);
            DMatrixRMaj residual = TestTools.nextMatrix(m, 1, random, -1.0d, 1.0d);

            DMatrixRMaj result = new DMatrixRMaj(0, 0);
            NativeFilterMatrixOps.updateState(result, x, k, residual);

            SimpleMatrix expected = new SimpleMatrix(x).plus(new SimpleMatrix(k).mult(new SimpleMatrix(residual)));
            TestTools.assertEquals(expected.getMatrix(), result, EPSILON);
        }
    }

    /**
     * Prints the average time per call of the native {@code computeABAt} and of an equivalent EJML computation.
     * Both are run on the same random {@code BENCHMARK_SIZE x BENCHMARK_SIZE} inputs.
     */
    public static void main(String[] args) {
        Random random = new Random(SEED);
        DMatrixRMaj a = TestTools.nextMatrix(BENCHMARK_SIZE, BENCHMARK_SIZE, random, -1.0d, 1.0d);
        DMatrixRMaj b = TestTools.nextMatrix(BENCHMARK_SIZE, random, -1.0d, 1.0d);

        // Warm-up pass over both code paths before timing (presumably so JIT compilation is not measured).
        for (int i = 0; i < BENCHMARK_ITERATIONS; i++) {
            NativeFilterMatrixOps.computeABAt(new DMatrixRMaj(0, 0), a, b);
            computeABAtWithEjml(a, b);
        }

        long nativeStart = System.nanoTime();
        for (int i = 0; i < BENCHMARK_ITERATIONS; i++) {
            NativeFilterMatrixOps.computeABAt(new DMatrixRMaj(0, 0), a, b);
        }
        System.out.println("Native computation took: " + averageMillisSince(nativeStart, BENCHMARK_ITERATIONS) + "ms");

        long ejmlStart = System.nanoTime();
        for (int i = 0; i < BENCHMARK_ITERATIONS; i++) {
            computeABAtWithEjml(a, b);
        }
        System.out.println("EJML computation took: " + averageMillisSince(ejmlStart, BENCHMARK_ITERATIONS) + "ms");
    }

    /** Draws a random matrix dimension in {@code [1, MAX_SIZE]}. */
    private static int nextSize(Random random) {
        return random.nextInt(MAX_SIZE) + 1;
    }

    /** Computes {@code A * (B * A^T)} with EJML, allocating fresh temporary and output matrices on every call. */
    private static DMatrixRMaj computeABAtWithEjml(DMatrixRMaj a, DMatrixRMaj b) {
        DMatrixRMaj bAt = new DMatrixRMaj(b.numRows, a.numRows);
        DMatrixRMaj result = new DMatrixRMaj(a.numRows, a.numRows);
        CommonOps_DDRM.multTransB(b, a, bAt);
        CommonOps_DDRM.mult(a, bAt, result);
        return result;
    }

    /** Returns the elapsed time since {@code startNanos}, divided by {@code iterations}, in ms rounded to 2 decimals. */
    private static double averageMillisSince(long startNanos, int iterations) {
        return Precision.round(Conversions.nanosecondsToMilliseconds((System.nanoTime() - startNanos) / iterations), 2);
    }
}
