package us.ihmc.robotics.optimization;

import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.VectorVectorMult_DDRM;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizationProblem;
import us.ihmc.robotics.optimization.constrainedOptimization.MultiblockADMMOptimizer;
import us.ihmc.robotics.optimization.constrainedOptimization.MultiblockADMMProblem;

/* loaded from: input_file:us/ihmc/robotics/optimization/MultiblockADMMOptimizerTest.class */
/**
 * Tests {@link MultiblockADMMOptimizer} on a small two-block problem.
 * <p>
 * Each block owns a one-element vector {@code x_i}, minimizes {@code x_i^T x_i}, and carries the
 * inequality constraint {@code x_i[0] - 1} (sign convention as defined by
 * {@link AugmentedLagrangeOptimizationProblem#addInequalityConstraint}). The blocks are coupled by
 * the equality constraint {@code x_1[0] + x_2[0] - 4 = 0}. Minimizing {@code x_1^2 + x_2^2} on the
 * line {@code x_1 + x_2 = 4} gives the expected optimum {@code x_1 = x_2 = 2}.
 */
public class MultiblockADMMOptimizerTest {
    /** Number of outer augmented-Lagrangian iterations handed to the ADMM solver. */
    private static final int NUM_LAGRANGE_ITERATIONS = 15;
    /** Initial penalty passed to {@link MultiblockADMMProblem#initialize}. */
    private static final double INITIAL_PENALTY = 0.5;
    /** Penalty growth factor passed to {@link MultiblockADMMProblem#initialize}. */
    private static final double PENALTY_INCREASE_FACTOR = 1.1;
    /** Value every block is expected to converge to. */
    private static final double EXPECTED_BLOCK_VALUE = 2.0;
    /** Absolute tolerance used when checking the converged values. */
    private static final double TOLERANCE = 1.0e-3;

    /** Per-block cost: the squared Euclidean norm {@code x^T x}. */
    private static double sumOfSquares(DMatrixD1 x) {
        return VectorVectorMult_DDRM.innerProd(x, x);
    }

    /** Per-block inequality constraint {@code x[0] - 1}. */
    private static double lowerBoundConstraint(DMatrixD1 x) {
        return x.get(0) - 1.0;
    }

    /** Coupling equality constraint {@code x_1[0] + x_2[0] - 4}. */
    private static double blockSumConstraint(DMatrixD1... blocks) {
        return (blocks[0].get(0) + blocks[1].get(0)) - 4.0;
    }

    /**
     * Builds fresh starting points for the two blocks. The optimizer may write into these
     * matrices (not visible from here), so they are not shared as fields.
     */
    private static DMatrixD1[] createInitialValues() {
        return new DMatrixD1[] {new DMatrixRMaj(new double[] {4.0}), new DMatrixRMaj(new double[] {0.0})};
    }

    /** Builds one isolated block problem: the sum-of-squares cost plus the lower-bound inequality. */
    private static AugmentedLagrangeOptimizationProblem createBlockProblem() {
        AugmentedLagrangeOptimizationProblem problem = new AugmentedLagrangeOptimizationProblem(MultiblockADMMOptimizerTest::sumOfSquares);
        problem.addInequalityConstraint(MultiblockADMMOptimizerTest::lowerBoundConstraint);
        return problem;
    }

    /** Asserts that {@code actual} is within {@link #TOLERANCE} of the expected block value, reporting the actual value on failure. */
    private static void assertConverged(String name, double actual) {
        String message = String.format("%s = %f, expected %f +/- %f", name, actual, EXPECTED_BLOCK_VALUE, TOLERANCE);
        Assert.assertTrue(message, MathTools.epsilonCompare(actual, EXPECTED_BLOCK_VALUE, TOLERANCE));
    }

    @Test
    public void testSimpleBlockConstraint() {
        MultiblockADMMProblem admmProblem = new MultiblockADMMProblem();
        admmProblem.addIsolatedProblem(createBlockProblem());
        admmProblem.addIsolatedProblem(createBlockProblem());
        admmProblem.addEqualityConstraint(MultiblockADMMOptimizerTest::blockSumConstraint);
        admmProblem.initialize(INITIAL_PENALTY, PENALTY_INCREASE_FACTOR);

        // One independent inner optimizer per block.
        int numBlocks = admmProblem.getNumBlocks();
        Optimizer[] blockOptimizers = new Optimizer[numBlocks];
        for (int i = 0; i < numBlocks; i++) {
            blockOptimizers[i] = new WrappedGradientDescent();
        }

        MultiblockADMMOptimizer optimizer = new MultiblockADMMOptimizer(admmProblem, blockOptimizers);
        optimizer.setVerbose(true);
        DMatrixD1[] solution = optimizer.solveOverNIterations(NUM_LAGRANGE_ITERATIONS, createInitialValues());

        assertConverged("x1", solution[0].get(0));
        assertConverged("x2", solution[1].get(0));
        LogTools.debug("Test completed successfully");
    }
}
