package us.ihmc.robotics.optimization.constrainedOptimization;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.optimization.CostFunction;

/* loaded from: input_file:us/ihmc/robotics/optimization/constrainedOptimization/MultiblockADMMProblem.class */
/**
 * Multi-block problem that couples several isolated augmented-Lagrangian sub-problems (one per
 * block of decision variables) through global "block" constraints evaluated over all blocks.
 *
 * <p>Global constraint values follow the convention printed by {@link #printResults}: equality
 * constraints are satisfied at {@code == 0}, inequality constraints at {@code >= 0}.
 *
 * <p>Call order expected by this class:
 * <ol>
 *   <li>add isolated problems and constraints;</li>
 *   <li>{@link #initialize(double, double)};</li>
 *   <li>{@link #saveOptimalBlocksForLastIteration(DMatrixD1[])} before evaluating any cost function
 *       from {@link #getAugmentedCostFunctionForBlock(int)}.</li>
 * </ol>
 *
 * <p>This class performs no synchronization.
 */
public class MultiblockADMMProblem {
    private final List<AugmentedLagrangeOptimizationProblem> isolatedOptimizationProblems = new ArrayList<>();
    private final List<BlockConstraintFunction> inequalityConstraints = new ArrayList<>();
    private final List<BlockConstraintFunction> equalityConstraints = new ArrayList<>();
    /** Created by {@link #initialize}; sized to the global constraint counts at that moment. */
    private AugmentedLagrangeConstructor multiblockAugmentedLagrangeConstructor;
    /** The array passed to the last {@link #saveOptimalBlocksForLastIteration} call (aliased, not copied). */
    private DMatrixD1[] optimalBlocksFromLastIteration;

    /** Adds a sub-problem; its index in insertion order is its block index. */
    public void addIsolatedProblem(AugmentedLagrangeOptimizationProblem augmentedLagrangeOptimizationProblem) {
        isolatedOptimizationProblems.add(augmentedLagrangeOptimizationProblem);
    }

    /** Adds a global inequality constraint (satisfied when its value is {@code >= 0}). */
    public void addInequalityConstraint(BlockConstraintFunction blockConstraintFunction) {
        inequalityConstraints.add(blockConstraintFunction);
    }

    /** Adds a global equality constraint (satisfied when its value is {@code == 0}). */
    public void addEqualityConstraint(BlockConstraintFunction blockConstraintFunction) {
        equalityConstraints.add(blockConstraintFunction);
    }

    /** Removes all global equality and inequality constraints. */
    public void clearConstraints() {
        equalityConstraints.clear();
        inequalityConstraints.clear();
    }

    /** Removes all isolated sub-problems (blocks). */
    public void clearIsolatedProblems() {
        isolatedOptimizationProblems.clear();
    }

    /**
     * Initializes every isolated problem and (re)creates the multiblock augmented Lagrangian
     * constructor.
     *
     * <p>The constructor is sized from the number of global constraints registered at the time of
     * this call, so all constraints should be added before calling it.
     *
     * @param penalty first parameter, forwarded unchanged to each isolated problem and to the
     *     {@link AugmentedLagrangeConstructor}. NOTE(review): presumably the initial penalty
     *     weight; confirm against AugmentedLagrangeConstructor.
     * @param penaltyIncreaseFactor second parameter, forwarded the same way. NOTE(review):
     *     presumably the penalty growth factor; confirm.
     */
    public void initialize(double penalty, double penaltyIncreaseFactor) {
        for (AugmentedLagrangeOptimizationProblem problem : isolatedOptimizationProblems) {
            problem.initialize(penalty, penaltyIncreaseFactor);
        }
        multiblockAugmentedLagrangeConstructor = new AugmentedLagrangeConstructor(penalty, penaltyIncreaseFactor,
                equalityConstraints.size(), inequalityConstraints.size());
    }

    /**
     * Returns a cost function over a single block.
     *
     * <p>The block's candidate value is combined with the blocks saved by the most recent
     * {@link #saveOptimalBlocksForLastIteration} call. The saved blocks are read at evaluation
     * time, not when this method is called, so later saves are picked up automatically.
     *
     * @param blockIndex index of the block the returned function varies
     * @return a cost function; evaluating it throws {@link IllegalStateException} if no blocks
     *     have been saved yet
     */
    public CostFunction getAugmentedCostFunctionForBlock(final int blockIndex) {
        return new CostFunction() {
            @Override
            public double calculate(DMatrixD1 candidateBlock) {
                DMatrixD1[] savedBlocks = MultiblockADMMProblem.this.optimalBlocksFromLastIteration;
                if (savedBlocks == null) {
                    throw new IllegalStateException(
                            "saveOptimalBlocksForLastIteration(...) must be called before evaluating the cost of block " + blockIndex);
                }
                return MultiblockADMMProblem.this.calculateDualCostForBlock(blockIndex, candidateBlock, savedBlocks);
            }
        };
    }

    /**
     * Evaluates the augmented cost of one block when that block takes {@code blockValue} and all
     * other blocks keep their values from {@code allBlocks}.
     *
     * @param blockIndex index of the block being varied
     * @param blockValue candidate value for that block
     * @param allBlocks values of all blocks; must contain exactly {@link #getNumBlocks()} entries.
     *     The array is not modified; a shallow copy is made before substituting {@code blockValue}.
     * @return the isolated problem's dual cost for this block, augmented with the global equality
     *     and inequality constraint values by the multiblock {@link AugmentedLagrangeConstructor}
     * @throws IllegalArgumentException if {@code allBlocks} has the wrong length
     * @throws IllegalStateException if {@link #initialize} has not been called
     */
    public double calculateDualCostForBlock(int blockIndex, DMatrixD1 blockValue, DMatrixD1[] allBlocks) {
        checkBlockCount(allBlocks);
        DMatrixD1[] blocksWithCandidate = allBlocks.clone();
        blocksWithCandidate[blockIndex] = blockValue;
        return calculateDualCostForBlock(blockIndex, blocksWithCandidate);
    }

    /** Augmented cost of {@code blocks[blockIndex]} given the full (already substituted) block array. */
    private double calculateDualCostForBlock(int blockIndex, DMatrixD1[] blocks) {
        AugmentedLagrangeConstructor constructor = requireInitialized();
        double isolatedCost = isolatedOptimizationProblems.get(blockIndex).calculateDualProblemCost(blocks[blockIndex]);
        // Same evaluation order as before: inequality constraints first, then equality constraints.
        DMatrixD1 inequalityValues = evaluateGlobalInequalityConstraints(blocks);
        DMatrixD1 equalityValues = evaluateEqualityConstraints(blocks);
        return constructor.getAugmentedLagrangeCost(isolatedCost, equalityValues, inequalityValues);
    }

    /** Column vector holding each global inequality constraint evaluated on {@code blocks}. */
    private DMatrixD1 evaluateGlobalInequalityConstraints(DMatrixD1[] blocks) {
        return evaluateConstraints(inequalityConstraints, blocks);
    }

    /** Column vector holding each global equality constraint evaluated on {@code blocks}. */
    private DMatrixD1 evaluateEqualityConstraints(DMatrixD1[] blocks) {
        return evaluateConstraints(equalityConstraints, blocks);
    }

    /** Evaluates each constraint, in list order, into a new {@code size x 1} matrix. */
    private static DMatrixD1 evaluateConstraints(List<BlockConstraintFunction> constraints, DMatrixD1[] blocks) {
        int count = constraints.size();
        DMatrixRMaj values = new DMatrixRMaj(count, 1);
        for (int row = 0; row < count; row++) {
            values.set(row, constraints.get(row).calculate(blocks));
        }
        return values;
    }

    /**
     * Updates the global multipliers from the constraint values at {@code blocks}, then lets each
     * isolated problem update its own multipliers from its block.
     *
     * @param blocks values of all blocks; must contain exactly {@link #getNumBlocks()} entries
     * @throws IllegalArgumentException if {@code blocks} has the wrong length
     * @throws IllegalStateException if {@link #initialize} has not been called
     */
    public void updateLagrangeMultipliers(DMatrixD1[] blocks) {
        checkBlockCount(blocks);
        AugmentedLagrangeConstructor constructor = requireInitialized();
        DMatrixD1 equalityValues = evaluateEqualityConstraints(blocks);
        DMatrixD1 inequalityValues = evaluateGlobalInequalityConstraints(blocks);
        constructor.updateLagrangeMultipliers(equalityValues, inequalityValues);
        for (int i = 0; i < isolatedOptimizationProblems.size(); i++) {
            isolatedOptimizationProblems.get(i).updateLagrangeMultipliers(blocks[i]);
        }
    }

    /**
     * Stores the block values used by cost functions from {@link #getAugmentedCostFunctionForBlock}.
     *
     * <p>The array reference itself is stored, not a copy. Later changes to its elements by the
     * caller are therefore visible to those cost functions.
     *
     * @param blocks values of all blocks; must contain exactly {@link #getNumBlocks()} entries
     * @throws IllegalArgumentException if {@code blocks} has the wrong length
     */
    public void saveOptimalBlocksForLastIteration(DMatrixD1[] blocks) {
        checkBlockCount(blocks);
        optimalBlocksFromLastIteration = blocks;
    }

    /** Returns the live (mutable) list of isolated sub-problems, indexed by block. */
    public List<AugmentedLagrangeOptimizationProblem> getIsolatedOptimizationProblems() {
        return isolatedOptimizationProblems;
    }

    /** Number of blocks, i.e. the number of isolated sub-problems. */
    public int getNumBlocks() {
        return isolatedOptimizationProblems.size();
    }

    /** Prints per-block results followed by the global constraint values evaluated at {@code blocks}. */
    public void printResults(DMatrixD1[] blocks) {
        DMatrixD1 equalityValues = evaluateEqualityConstraints(blocks);
        DMatrixD1 inequalityValues = evaluateGlobalInequalityConstraints(blocks);
        printResults(equalityValues, inequalityValues, blocks);
    }

    /**
     * Prints each isolated problem's results, then the non-empty global constraint vectors.
     *
     * @param equalityValues global equality constraint values (ideally all {@code == 0})
     * @param inequalityValues global inequality constraint values (ideally all {@code >= 0})
     * @param blocks block values; element {@code i} is passed to isolated problem {@code i}
     */
    public void printResults(DMatrixD1 equalityValues, DMatrixD1 inequalityValues, DMatrixD1[] blocks) {
        LogTools.info("");
        for (int i = 0; i < isolatedOptimizationProblems.size(); i++) {
            System.out.println("-- Isolated Problem " + i + ": --");
            isolatedOptimizationProblems.get(i).printResults(blocks[i]);
        }
        if (equalityValues.getNumElements() > 0) {
            System.out.println("-- Global Equality Constraints J(x[]): --");
            for (int i = 0; i < equalityValues.getNumElements(); i++) {
                System.out.println("\t" + equalityValues.get(i) + " == 0");
            }
        }
        if (inequalityValues.getNumElements() > 0) {
            System.out.println("-- Global Inequality Constraints K(x[]): --");
            for (int i = 0; i < inequalityValues.getNumElements(); i++) {
                System.out.println("\t" + inequalityValues.get(i) + " >= 0");
            }
        }
    }

    /** Throws if {@code blocks} does not hold exactly one entry per isolated problem. */
    private void checkBlockCount(DMatrixD1[] blocks) {
        int expected = getNumBlocks();
        if (blocks.length != expected) {
            throw new IllegalArgumentException("Expected " + expected + " blocks (one per isolated problem) but "
                    + blocks.length + " were provided");
        }
    }

    /** Returns the multiblock constructor, failing clearly if {@link #initialize} was never called. */
    private AugmentedLagrangeConstructor requireInitialized() {
        if (multiblockAugmentedLagrangeConstructor == null) {
            throw new IllegalStateException("initialize(...) must be called before evaluating costs or updating multipliers");
        }
        return multiblockAugmentedLagrangeConstructor;
    }
}
