package us.ihmc.robotics.optimization;

import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.VectorVectorMult_DDRM;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizationProblem;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizer;

/* loaded from: input_file:us/ihmc/robotics/optimization/AugmentedLagrangeOptimizerTest.class */
/**
 * Tests for {@link AugmentedLagrangeOptimizer} driven by a {@code WrappedGradientDescent}
 * inner solver.
 *
 * <p>Each test builds an {@link AugmentedLagrangeOptimizationProblem} from one of the static
 * cost functions below plus one or more of the static constraint functions, runs the
 * optimizer from a fixed seed and checks every coordinate of the result against a
 * hand-computed optimum.
 *
 * <p>NOTE(review): the expected optima in {@link #testIsolatedConstraints()} (x1 = 6 for
 * {@code x[0] - 6}, x3 = 0 for {@code -x[2] + 3}) suggest inequality constraints are treated
 * as {@code g(x) >= 0} -- confirm against {@code AugmentedLagrangeOptimizationProblem}.
 */
public class AugmentedLagrangeOptimizerTest {
    /** Tolerance passed to {@code MathTools.epsilonCompare} by every assertion in this class. */
    private static final double EPSILON = 0.001d;

    /** Problem under test; reassigned at the start of every test method. */
    private AugmentedLagrangeOptimizationProblem augmentedLagrangeProblem;

    /** Quadratic cost: the inner product of {@code x} with itself. */
    private static double costFunctionQuadratic(DMatrixD1 x) {
        return VectorVectorMult_DDRM.innerProd(x, x);
    }

    /**
     * Nonconvex cost {@code -5 * sin(r) / r}, where {@code r} is the square root of the inner
     * product of {@code x} with itself.
     *
     * @return {@code -5} when {@code r} is exactly zero (avoiding the division by zero),
     *         otherwise the expression above
     */
    private static double costFunctionNonconvex(DMatrixD1 x) {
        final double radius = Math.sqrt(VectorVectorMult_DDRM.innerProd(x, x));
        // Guard the 0/0 case first; the general formula is only valid for a nonzero radius.
        if (radius != 0.0d) {
            return -5.0d * Math.sin(radius) / radius;
        }
        return -5.0d;
    }

    /** Constraint value {@code x[1] - 5}. */
    private static double constraint1(DMatrixD1 x) {
        final double second = x.get(1);
        return second - 5.0d;
    }

    /** Constraint value {@code x[0] - 6}. */
    private static double constraint2(DMatrixD1 x) {
        final double first = x.get(0);
        return first - 6.0d;
    }

    /** Constraint value {@code -x[2] + 3}. */
    private static double constraint3(DMatrixD1 x) {
        final double third = x.get(2);
        return -third + 3.0d;
    }

    /** Constraint value {@code x[0] + x[1] + x[2] - 6}, coupling all three coordinates. */
    private static double constraint4(DMatrixD1 x) {
        final double sum = x.get(0) + x.get(1) + x.get(2);
        return sum - 6.0d;
    }

    /** Constraint value {@code x[0] + x[1] - 6.354061535}, used with the nonconvex cost. */
    private static double constraintNonconvex(DMatrixD1 x) {
        final double sum = x.get(0) + x.get(1);
        return sum - 6.354061535d;
    }

    /**
     * Wraps the current {@link #augmentedLagrangeProblem} and the given inner solver in an
     * {@link AugmentedLagrangeOptimizer}, enables verbose output and runs it from {@code seed}.
     *
     * @param solver inner solver handed to the optimizer's constructor
     * @param seed   starting point passed to {@code optimize}
     * @return whatever {@code optimize(10, seed)} returns
     *         (the 10 is presumably an iteration count -- TODO confirm)
     */
    private DMatrixD1 runOptimizer(WrappedGradientDescent solver, DMatrixRMaj seed) {
        final AugmentedLagrangeOptimizer optimizer = new AugmentedLagrangeOptimizer(solver, this.augmentedLagrangeProblem);
        optimizer.setVerbose(true);
        return optimizer.optimize(10, seed);
    }

    /**
     * Asserts that coordinate {@code index} of {@code solution} matches {@code expected}
     * according to {@code MathTools.epsilonCompare} with {@link #EPSILON}.
     */
    private static void assertCoordinate(String message, DMatrixD1 solution, int index, double expected) {
        Assert.assertTrue(message, MathTools.epsilonCompare(solution.get(index), expected, EPSILON));
    }

    /**
     * Quadratic cost with one equality and two inequality constraints, each touching a single
     * coordinate. Expects the result (6, 5, 0).
     */
    @Test
    public void testIsolatedConstraints() {
        final DMatrixRMaj seed = new DMatrixRMaj(new double[]{10.0d, 14.5d, 16.0d});

        final AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionQuadratic);
        this.augmentedLagrangeProblem = problem;
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraint1);
        problem.addInequalityConstraint(AugmentedLagrangeOptimizerTest::constraint2);
        problem.addInequalityConstraint(AugmentedLagrangeOptimizerTest::constraint3);
        // NOTE(review): presumably (initial penalty, penalty growth factor) -- confirm.
        problem.initialize(1.0d, 1.5d);

        final DMatrixD1 solution = runOptimizer(new WrappedGradientDescent(), seed);

        assertCoordinate("x1 arrived on desired value", solution, 0, 6.0d);
        assertCoordinate("x2 arrived on desired value", solution, 1, 5.0d);
        assertCoordinate("x3 arrived on desired value", solution, 2, 0.0d);
    }

    /**
     * Quadratic cost with a single equality constraint coupling all three coordinates.
     * Expects the result (2, 2, 2).
     */
    @Test
    public void testJointConstraints() {
        final DMatrixRMaj seed = new DMatrixRMaj(new double[]{10.0d, 14.5d, 16.0d});

        final AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionQuadratic);
        this.augmentedLagrangeProblem = problem;
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraint4);
        // NOTE(review): presumably (initial penalty, penalty growth factor) -- confirm.
        problem.initialize(1.0d, 1.5d);

        final DMatrixD1 solution = runOptimizer(new WrappedGradientDescent(), seed);

        assertCoordinate("x1 arrived on desired value", solution, 0, 2.0d);
        assertCoordinate("x2 arrived on desired value", solution, 1, 2.0d);
        assertCoordinate("x3 arrived on desired value", solution, 2, 2.0d);
    }

    /**
     * Nonconvex cost in two dimensions with one equality constraint on the coordinate sum,
     * using an inner solver with an explicit initial step size and learning rate.
     * Expects both coordinates to equal 3.1770307678.
     */
    @Test
    public void testNonconvex() {
        final DMatrixRMaj seed = new DMatrixRMaj(new double[]{13.0d, 14.0d});

        final AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionNonconvex);
        this.augmentedLagrangeProblem = problem;
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraintNonconvex);
        // NOTE(review): presumably (initial penalty, penalty growth factor) -- confirm.
        problem.initialize(1.0d, 1.5d);

        final WrappedGradientDescent solver = new WrappedGradientDescent();
        solver.setInitialStepSize(10.0d);
        solver.setLearningRate(0.9d);

        final DMatrixD1 solution = runOptimizer(solver, seed);

        assertCoordinate("x1 arrived on desired value", solution, 0, 3.1770307678d);
        assertCoordinate("x2 arrived on desired value", solution, 1, 3.1770307678d);
    }
}
