package us.ihmc.robotics.functionApproximation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.robotics.Assert;

/* loaded from: input_file:us/ihmc/robotics/functionApproximation/LinearRegressionTest.class */
public class LinearRegressionTest {
    private static final boolean VERBOSE = false;

    @Test
    public void testTypicalExampleOne() {
        Random random = new Random(1984L);
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < 500; i++) {
            double nextDouble = (random.nextDouble() * 2.0d) - 1.0d;
            double[] dArr = {1.0d, nextDouble, nextDouble * nextDouble};
            double nextDouble2 = (random.nextDouble() * 0.1d) + (1.0d * nextDouble) + (5.0d * nextDouble * nextDouble);
            arrayList.add(dArr);
            arrayList2.add(Double.valueOf(nextDouble2));
        }
        LinearRegression linearRegression = new LinearRegression(arrayList, arrayList2);
        double solveAndReturnRuntimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] dArr2 = new double[3];
        linearRegression.getCoefficientVector(dArr2);
        printResults(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr2);
        assertResultsAreAsExpected(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr2, 0.1d, 0.001d, 7.2d, new double[]{0.05d, 1.0d, 5.0d});
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    @Test
    public void testTypicalExampleTwo() {
        Random random = new Random(1776L);
        ?? r0 = new double[500];
        double[] dArr = new double[500];
        for (int i = 0; i < 500; i++) {
            double nextDouble = (random.nextDouble() * 2.0d) - 1.0d;
            double nextDouble2 = (random.nextDouble() * 2.0d) - 1.0d;
            double[] dArr2 = new double[6];
            dArr2[0] = 1.0d;
            dArr2[1] = nextDouble;
            dArr2[2] = nextDouble * nextDouble;
            dArr2[3] = nextDouble2;
            dArr2[4] = nextDouble2 * nextDouble2;
            dArr2[5] = nextDouble * nextDouble2;
            double nextDouble3 = 4.0d + (random.nextDouble() * 0.1d) + (1.0d * nextDouble) + (0.0d * nextDouble * nextDouble) + (0.2d * nextDouble2) + ((-3.0d) * nextDouble2 * nextDouble2) + (5.0d * nextDouble * nextDouble2);
            r0[i] = dArr2;
            dArr[i] = nextDouble3;
        }
        LinearRegression linearRegression = new LinearRegression((double[][]) r0, dArr);
        double solveAndReturnRuntimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] dArr3 = new double[6];
        linearRegression.getCoefficientVector(dArr3);
        printResults(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr3);
        assertResultsAreAsExpected(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr3, 0.1d, 0.001d, 5.0d, new double[]{4.05d, 1.0d, 0.0d, 0.2d, -3.0d, 5.0d});
    }

    @Test
    public void testPerfectMatch() {
        Random random = new Random(2000L);
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < 500; i++) {
            double nextDouble = (random.nextDouble() * 4.0d) - 1.0d;
            double nextDouble2 = (random.nextDouble() * 4.0d) - 1.0d;
            arrayList.add(new double[]{1.0d, nextDouble, nextDouble * nextDouble, nextDouble2, nextDouble2 * nextDouble2, nextDouble * nextDouble2});
            arrayList2.add(Double.valueOf(90.0d + (5.0d * nextDouble) + (6.0d * nextDouble * nextDouble) + (7.0d * nextDouble2) + ((-8.0d) * nextDouble2 * nextDouble2) + (13.0d * nextDouble * nextDouble2)));
        }
        LinearRegression linearRegression = new LinearRegression(arrayList, arrayList2);
        double solveAndReturnRuntimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] dArr = new double[6];
        linearRegression.getCoefficientVector(dArr);
        printResults(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr);
        assertResultsAreAsExpected(linearRegression, solveAndReturnRuntimeInMilliseconds, dArr, 1.0E-10d, 1.0E-14d, 5.0d, new double[]{90.0d, 5.0d, 6.0d, 7.0d, -8.0d, 13.0d});
    }

    @Test
    public void testRandomness() {
        Random random = new Random(1776L);
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < 5; i++) {
            double[] dArr = {random.nextDouble() * 100.0d, random.nextDouble() * 10.0d};
            double nextDouble = random.nextDouble() * 500.0d;
            arrayList.add(dArr);
            arrayList2.add(Double.valueOf(nextDouble));
        }
        LinearRegression linearRegression = new LinearRegression(arrayList, arrayList2);
        Assert.assertTrue(linearRegression.solve());
        linearRegression.getSquaredError();
    }

    @Test
    public void testNotEnoughPoints() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            ArrayList arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList();
            arrayList.add(new double[]{1.0d});
            arrayList2.add(Double.valueOf(2.0d));
            arrayList2.add(Double.valueOf(3.0d));
            new LinearRegression(arrayList, arrayList2).solve();
        });
    }

    @Test
    public void testAskingForAnswerBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            new LinearRegression((double[][]) new double[]{new double[]{1.0d}, new double[]{1.0d}}, new double[]{1.0d, 1.0d}).getCoefficientVector(new double[1]);
        });
    }

    @Test
    public void testAskingForSquaredErrorBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            new LinearRegression((double[][]) new double[]{new double[]{1.0d}, new double[]{1.0d}}, new double[]{1.0d, 1.0d}).getSquaredError();
        });
    }

    @Test
    public void testAskingForCoefficientVectorAsMatrixBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            new LinearRegression((double[][]) new double[]{new double[]{1.0d}, new double[]{1.0d}}, new double[]{1.0d, 1.0d}).getCoefficientVectorAsMatrix();
        });
    }

    private void assertResultsAreAsExpected(LinearRegression linearRegression, double d, double[] dArr, double d2, double d3, double d4, double[] dArr2) {
        Assert.assertEquals(dArr2.length, dArr.length);
        for (int i = 0; i < dArr2.length; i++) {
            Assert.assertEquals(dArr2[i], dArr[i], d2);
        }
        Assert.assertTrue("linearRegression.getSquaredError() was less than 0.0!", linearRegression.getSquaredError() > 0.0d);
        Assert.assertTrue("linearRegression.getSquaredError() = " + linearRegression.getSquaredError(), linearRegression.getSquaredError() < d3);
    }

    private void printResults(LinearRegression linearRegression, double d, double[] dArr) {
    }

    private double solveAndReturnRuntimeInMilliseconds(LinearRegression linearRegression) {
        long nanoTime = System.nanoTime();
        Assert.assertTrue(linearRegression.solve());
        return (System.nanoTime() - nanoTime) / 1000000.0d;
    }
}
