package us.ihmc.robotics.optimization.constrainedOptimization;

import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.optimization.CostFunction;

/* loaded from: input_file:us/ihmc/robotics/optimization/constrainedOptimization/AugmentedLagrangeOptimizationProblem.class */
/**
 * Bundles a cost function with equality and inequality constraints and exposes them as a single
 * augmented-Lagrangian cost that an unconstrained optimizer can minimize.
 *
 * <p>Usage implied by the public methods:
 * <ol>
 *   <li>Add constraints.</li>
 *   <li>Call {@link #initialize(double, double)}.</li>
 *   <li>Minimize {@link #getAugmentedCostFunction()}.</li>
 *   <li>Feed the result to {@link #updateLagrangeMultipliers(DMatrixD1)}.</li>
 *   <li>Repeat from step 3.</li>
 * </ol>
 * NOTE(review): the outer-loop schedule and stopping criterion live with the caller — confirm
 * against call sites.
 *
 * <p>{@link #printResults} reports equality constraints as {@code G(x) == 0} and inequality
 * constraints as {@code H(x) >= 0}.
 *
 * <p>No synchronization is performed; the constraint lists and multiplier state are plain mutable
 * fields.
 */
public class AugmentedLagrangeOptimizationProblem {
    private final CostFunction costFunction;
    private final List<ConstraintFunction> inequalityConstraints = new ArrayList<>();
    private final List<ConstraintFunction> equalityConstraints = new ArrayList<>();
    /** Multiplier/penalty state; {@code null} until {@link #initialize(double, double)} runs. */
    private AugmentedLagrangeConstructor augmentedLagrangeConstructor;

    /**
     * Creates a problem with the given objective and no constraints.
     *
     * @param costFunction objective f(x) to minimize
     */
    public AugmentedLagrangeOptimizationProblem(CostFunction costFunction) {
        this.costFunction = costFunction;
    }

    /**
     * Adds an inequality constraint, reported as {@code H(x) >= 0} by {@link #printResults}.
     *
     * @param constraintFunction constraint to append
     */
    public void addInequalityConstraint(ConstraintFunction constraintFunction) {
        this.inequalityConstraints.add(constraintFunction);
    }

    /**
     * Adds an equality constraint, reported as {@code G(x) == 0} by {@link #printResults}.
     *
     * @param constraintFunction constraint to append
     */
    public void addEqualityConstraint(ConstraintFunction constraintFunction) {
        this.equalityConstraints.add(constraintFunction);
    }

    /**
     * Removes all equality and inequality constraints.
     *
     * <p>The multiplier state built by {@link #initialize(double, double)} was sized from the
     * constraint counts at that time and is not reset here. Call {@code initialize} again after
     * changing the constraint set.
     */
    public void clearConstraints() {
        this.equalityConstraints.clear();
        this.inequalityConstraints.clear();
    }

    /**
     * Creates the multiplier/penalty state for the constraints currently registered.
     *
     * <p>{@code d} and {@code d2} are forwarded unchanged as the first two arguments of
     * {@code AugmentedLagrangeConstructor}. NOTE(review): they are presumably the initial penalty
     * weight and its growth factor — confirm against that class.
     *
     * @param d first tuning parameter forwarded to {@code AugmentedLagrangeConstructor}
     * @param d2 second tuning parameter forwarded to {@code AugmentedLagrangeConstructor}
     */
    public void initialize(double d, double d2) {
        this.augmentedLagrangeConstructor =
                new AugmentedLagrangeConstructor(d, d2, this.equalityConstraints.size(), this.inequalityConstraints.size());
    }

    /**
     * Returns a {@link CostFunction} view of {@link #calculateDualProblemCost(DMatrixD1)}.
     *
     * <p>The view reads the current multiplier state on each call, so it reflects later calls to
     * {@link #updateLagrangeMultipliers(DMatrixD1)}.
     */
    public CostFunction getAugmentedCostFunction() {
        return this::calculateDualProblemCost;
    }

    /**
     * Evaluates the augmented-Lagrangian cost at {@code x}.
     *
     * @param x point at which to evaluate
     * @return augmented cost combining f(x) with the constraint values
     * @throws IllegalStateException if {@link #initialize(double, double)} has not been called
     */
    public double calculateDualProblemCost(DMatrixD1 x) {
        AugmentedLagrangeConstructor constructor = requireInitialized();
        double cost = calculateCost(x);
        // Evaluation order (cost, inequality, equality) matches the original implementation.
        DMatrixD1 inequalityValues = evaluateInequalityConstraints(x);
        DMatrixD1 equalityValues = evaluateEqualityConstraints(x);
        return constructor.getAugmentedLagrangeCost(cost, equalityValues, inequalityValues);
    }

    private double calculateCost(DMatrixD1 x) {
        return this.costFunction.calculate(x);
    }

    /** Evaluates every inequality constraint at {@code x} into a fresh column vector. */
    private DMatrixD1 evaluateInequalityConstraints(DMatrixD1 x) {
        return evaluateConstraints(this.inequalityConstraints, x);
    }

    /** Evaluates every equality constraint at {@code x} into a fresh column vector. */
    private DMatrixD1 evaluateEqualityConstraints(DMatrixD1 x) {
        return evaluateConstraints(this.equalityConstraints, x);
    }

    /** Evaluates each constraint in list order and stores the results in a new (n x 1) matrix. */
    private static DMatrixD1 evaluateConstraints(List<ConstraintFunction> constraints, DMatrixD1 x) {
        int count = constraints.size();
        DMatrixRMaj values = new DMatrixRMaj(count, 1);
        for (int i = 0; i < count; i++) {
            values.set(i, constraints.get(i).calculate(x));
        }
        return values;
    }

    /** Returns the multiplier state, failing clearly if {@link #initialize} was never called. */
    private AugmentedLagrangeConstructor requireInitialized() {
        if (this.augmentedLagrangeConstructor == null) {
            throw new IllegalStateException(
                    "initialize(double, double) must be called before evaluating the augmented cost or updating multipliers");
        }
        return this.augmentedLagrangeConstructor;
    }

    /**
     * Updates the Lagrange multipliers from the constraint values at {@code x}.
     *
     * @param x point (typically the latest inner-loop solution) at which to evaluate the constraints
     * @throws IllegalStateException if {@link #initialize(double, double)} has not been called
     */
    public void updateLagrangeMultipliers(DMatrixD1 x) {
        AugmentedLagrangeConstructor constructor = requireInitialized();
        constructor.updateLagrangeMultipliers(evaluateEqualityConstraints(x), evaluateInequalityConstraints(x));
    }

    /**
     * Prints {@code x}, the raw cost f(x), and all constraint values at {@code x}.
     *
     * @param x point to report
     */
    public void printResults(DMatrixD1 x) {
        printResults(x, this.costFunction.calculate(x), evaluateEqualityConstraints(x), evaluateInequalityConstraints(x));
    }

    /**
     * Prints a solution, its cost, and precomputed constraint values to standard output.
     *
     * @param x solution vector
     * @param cost objective value f(x)
     * @param equalityValues equality constraint values G(x); the section is skipped when empty
     * @param inequalityValues inequality constraint values H(x); the section is skipped when empty
     */
    public void printResults(DMatrixD1 x, double cost, DMatrixD1 equalityValues, DMatrixD1 inequalityValues) {
        LogTools.info("");
        System.out.println("Solution x:");
        // Print the elements to the same stream as their header. Previously they went to
        // LogTools.debug and were hidden at default log levels.
        for (int i = 0; i < x.getNumElements(); i++) {
            System.out.println("\t" + x.get(i) + ",");
        }
        System.out.println("Cost f(x):");
        System.out.println("\t" + cost);
        if (equalityValues.getNumElements() > 0) {
            System.out.println("Equality Constraints G(x):");
            for (int i = 0; i < equalityValues.getNumElements(); i++) {
                System.out.println("\t" + equalityValues.get(i) + " == 0");
            }
        }
        if (inequalityValues.getNumElements() > 0) {
            System.out.println("Inequality Constraints H(x):");
            for (int i = 0; i < inequalityValues.getNumElements(); i++) {
                System.out.println("\t" + inequalityValues.get(i) + " >= 0");
            }
        }
    }
}
