package us.ihmc.convexOptimization.quadraticProgram;

import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.convexOptimization.exceptions.NoConvergenceException;
import us.ihmc.log.LogTools;
import us.ihmc.matrixlib.MatrixTestTools;

/* loaded from: input_file:us/ihmc/convexOptimization/quadraticProgram/ConstrainedQPSolverTest.class */
/**
 * Tests the {@link ConstrainedQPSolver} implementations returned by {@link #createSolvers()} and
 * {@link JavaQuadProgSolver} on small quadratic programs whose optima are known.
 *
 * <p>NOTE(review): the expected answers below are consistent with minimizing
 * {@code 0.5 x^T Q x + f^T x} subject to {@code Aeq x = beq} and {@code Ain x <= bin}; confirm this
 * convention against the solver implementations.
 */
public class ConstrainedQPSolverTest {
    /** How many times each problem is re-solved, to check that reusing a solver stays repeatable. */
    private static final int NUMBER_OF_REPEATS = 10000;
    /** Tolerance on the entries of the solution vector. */
    private static final double SOLUTION_EPSILON = 1.0E-10d;
    /** Tolerance on equality residuals and inequality slack. */
    private static final double CONSTRAINT_EPSILON = 1.0E-7d;
    /** Tolerance used by {@link #testSolveProblemWithParallelConstraints()}. */
    private static final double PARALLEL_CONSTRAINTS_EPSILON = 1.0E-9d;

    /**
     * Repeatedly solves small QPs with both equality and inequality constraints and checks the
     * known optimum as well as constraint satisfaction.
     */
    @Test
    public void testSolveContrainedQP() throws NoConvergenceException {
        ConstrainedQPSolver[] solvers = createSolvers();
        JavaQuadProgSolver javaSolver = new JavaQuadProgSolver();

        // Problem 1: Q = I, f = [1, 0], subject to x0 + x1 = 0 and 2 x0 + x1 <= 0.
        // On x1 = -x0 the cost is x0^2 + x0, minimized at x0 = -0.5 (inequality inactive).
        DMatrixRMaj quadraticCost1 = new DMatrixRMaj(2, 2, true, new double[]{1.0d, 0.0d, 0.0d, 1.0d});
        DMatrixRMaj linearCost1 = new DMatrixRMaj(2, 1, true, new double[]{1.0d, 0.0d});
        DMatrixRMaj equalityMatrix1 = new DMatrixRMaj(1, 2, true, new double[]{1.0d, 1.0d});
        DMatrixRMaj equalityVector1 = new DMatrixRMaj(1, 1, true, new double[]{0.0d});
        DMatrixRMaj inequalityMatrix1 = new DMatrixRMaj(1, 2, true, new double[]{2.0d, 1.0d});
        DMatrixRMaj inequalityVector1 = new DMatrixRMaj(1, 1, true, new double[]{0.0d});
        solveRepeatedlyAndCheck(solvers, javaSolver,
                                quadraticCost1, linearCost1,
                                equalityMatrix1, equalityVector1,
                                inequalityMatrix1, inequalityVector1,
                                new double[]{-1.0d, 1.0d}, new double[]{-0.5d, 0.5d});

        // Problem 2 is currently DISABLED: the solver list is intentionally empty, and the
        // JavaQuadProgSolver check is skipped along with it. Populate the list to re-enable.
        // NOTE(review): Q below is not symmetric. The expected optimum satisfies stationarity for the
        // gradient Q x + f (Q used as-is) but not for the symmetric part of Q. That is presumably why
        // the problem was disabled; confirm before re-enabling.
        ConstrainedQPSolver[] problem2Solvers = new ConstrainedQPSolver[0];
        if (problem2Solvers.length > 0) {
            DMatrixRMaj quadraticCost2 = new DMatrixRMaj(3, 3, true, new double[]{1.0d, 0.0d, 1.0d, 0.0d, 1.0d, 2.0d, 1.0d, 3.0d, 7.0d});
            DMatrixRMaj linearCost2 = new DMatrixRMaj(3, 1, true, new double[]{1.0d, 0.0d, 9.0d});
            DMatrixRMaj equalityMatrix2 = new DMatrixRMaj(2, 3, true, new double[]{1.0d, 1.0d, 1.0d, 2.0d, 3.0d, 4.0d});
            DMatrixRMaj equalityVector2 = new DMatrixRMaj(2, 1, true, new double[]{0.0d, 7.0d});
            DMatrixRMaj inequalityMatrix2 = new DMatrixRMaj(1, 3, true, new double[]{2.0d, 1.0d, 3.0d});
            DMatrixRMaj inequalityVector2 = new DMatrixRMaj(1, 1, true, new double[]{0.0d});
            solveRepeatedlyAndCheck(problem2Solvers, javaSolver,
                                    quadraticCost2, linearCost2,
                                    equalityMatrix2, equalityVector2,
                                    inequalityMatrix2, inequalityVector2,
                                    new double[]{-1.0d, 1.0d, 3.0d}, new double[]{-7.75d, 8.5d, -0.75d});
        }
    }

    /**
     * Minimizes {@code 0.5 x^2} subject to two parallel constraints {@code x <= -1} and
     * {@code x <= -2} (identical rows, different bounds). Only the tighter bound should bind,
     * so every solver must return {@code x = -2}.
     */
    @Test
    public void testSolveProblemWithParallelConstraints() throws NoConvergenceException {
        DMatrixRMaj quadraticCost = new DMatrixRMaj(1, 1);
        quadraticCost.set(0, 0, 1.0d);
        DMatrixRMaj linearCost = new DMatrixRMaj(1, 1);

        // Two identical constraint rows ...
        DMatrixRMaj inequalityMatrix = new DMatrixRMaj(2, 1);
        inequalityMatrix.set(0, 0, 1.0d);
        inequalityMatrix.set(1, 0, 1.0d);
        // ... with distinct bounds, one per row.
        DMatrixRMaj inequalityVector = new DMatrixRMaj(2, 1);
        inequalityVector.set(0, -1.0d);
        inequalityVector.set(1, -2.0d);

        // No equality constraints: zero rows, one column per decision variable.
        DMatrixRMaj equalityMatrix = new DMatrixRMaj(0, 1);
        DMatrixRMaj equalityVector = new DMatrixRMaj(0, 1);

        for (ConstrainedQPSolver solver : createSolvers()) {
            LogTools.info("Attempting to solve problem with: " + solver.getClass().getSimpleName());
            // Fresh output per solver, so one solver cannot pass on another solver's leftover answer.
            DMatrixRMaj solution = new DMatrixRMaj(1, 1);
            solver.solve(quadraticCost, linearCost, equalityMatrix, equalityVector, inequalityMatrix, inequalityVector, solution, true);
            assertParallelConstraintsSolution(solution);
        }

        JavaQuadProgSolver javaSolver = new JavaQuadProgSolver();
        LogTools.info("Attempting to solve problem with: " + javaSolver.getClass().getSimpleName());
        DMatrixRMaj javaSolution = new DMatrixRMaj(1, 1);
        javaSolver.clear();
        javaSolver.setQuadraticCostFunction(quadraticCost, linearCost, 0.0d);
        javaSolver.setLinearInequalityConstraints(inequalityMatrix, inequalityVector);
        javaSolver.setLinearEqualityConstraints(equalityMatrix, equalityVector);
        javaSolver.solve(javaSolution);
        assertParallelConstraintsSolution(javaSolution);
    }

    /** Creates one instance of each {@link ConstrainedQPSolver} implementation under test. */
    private ConstrainedQPSolver[] createSolvers() {
        return new ConstrainedQPSolver[]{new OASESConstrainedQPSolver(), new QuadProgSolver()};
    }

    /**
     * Solves one QP {@link #NUMBER_OF_REPEATS} times, each time with every solver in
     * {@code solvers} followed by {@code javaSolver}, and validates every result.
     *
     * @param initialGuess copied into a fresh output matrix before every solve
     * @param expectedSolution known optimum, compared entry-wise within {@link #SOLUTION_EPSILON}
     */
    private static void solveRepeatedlyAndCheck(ConstrainedQPSolver[] solvers,
                                                JavaQuadProgSolver javaSolver,
                                                DMatrixRMaj quadraticCost,
                                                DMatrixRMaj linearCost,
                                                DMatrixRMaj equalityMatrix,
                                                DMatrixRMaj equalityVector,
                                                DMatrixRMaj inequalityMatrix,
                                                DMatrixRMaj inequalityVector,
                                                double[] initialGuess,
                                                double[] expectedSolution) throws NoConvergenceException {
        int size = initialGuess.length;
        for (int repeat = 0; repeat < NUMBER_OF_REPEATS; repeat++) {
            for (ConstrainedQPSolver solver : solvers) {
                DMatrixRMaj solution = new DMatrixRMaj(size, 1, true, initialGuess.clone());
                solver.solve(quadraticCost, linearCost, equalityMatrix, equalityVector, inequalityMatrix, inequalityVector, solution, false);
                assertValidSolution(expectedSolution, solution, equalityMatrix, equalityVector, inequalityMatrix, inequalityVector,
                                    "repeat = " + repeat + ", optimizer = " + solver.getClass().getSimpleName());
            }

            DMatrixRMaj javaSolution = new DMatrixRMaj(size, 1, true, initialGuess.clone());
            javaSolver.clear();
            javaSolver.setQuadraticCostFunction(quadraticCost, linearCost, 0.0d);
            javaSolver.setLinearInequalityConstraints(inequalityMatrix, inequalityVector);
            javaSolver.setLinearEqualityConstraints(equalityMatrix, equalityVector);
            javaSolver.solve(javaSolution);
            assertValidSolution(expectedSolution, javaSolution, equalityMatrix, equalityVector, inequalityMatrix, inequalityVector,
                                "repeat = " + repeat + ", Java solver");
        }
    }

    /**
     * Asserts that {@code solution} matches {@code expectedSolution}, satisfies
     * {@code equalityMatrix * x = equalityVector}, and satisfies
     * {@code inequalityMatrix * x <= inequalityVector} (all within tolerance).
     */
    private static void assertValidSolution(double[] expectedSolution,
                                            DMatrixRMaj solution,
                                            DMatrixRMaj equalityMatrix,
                                            DMatrixRMaj equalityVector,
                                            DMatrixRMaj inequalityMatrix,
                                            DMatrixRMaj inequalityVector,
                                            String description) {
        Assertions.assertArrayEquals(expectedSolution, solution.getData(), SOLUTION_EPSILON, description);

        DMatrixRMaj equalityResult = new DMatrixRMaj(equalityMatrix.getNumRows(), 1);
        CommonOps_DDRM.mult(equalityMatrix, solution, equalityResult);
        MatrixTestTools.assertMatrixEquals(equalityVector, equalityResult, CONSTRAINT_EPSILON);

        // Inequalities are compared against their own bound vector, with a small tolerance.
        DMatrixRMaj inequalityResult = new DMatrixRMaj(inequalityMatrix.getNumRows(), 1);
        CommonOps_DDRM.mult(inequalityMatrix, solution, inequalityResult);
        for (int row = 0; row < inequalityResult.getNumRows(); row++) {
            Assertions.assertTrue(inequalityResult.get(row, 0) <= inequalityVector.get(row, 0) + CONSTRAINT_EPSILON,
                                  description + ", inequality row " + row);
        }
    }

    /** Asserts (and logs on failure) that the 1x1 {@code solution} equals -2.0 within tolerance. */
    private static void assertParallelConstraintsSolution(DMatrixRMaj solution) {
        boolean epsilonEquals = MathTools.epsilonEquals(-2.0d, solution.get(0), PARALLEL_CONSTRAINTS_EPSILON);
        if (!epsilonEquals) {
            LogTools.info("Failed. Result was " + solution.get(0) + ", expected -2.0");
        }
        Assertions.assertTrue(epsilonEquals);
    }
}
