package us.ihmc.robotics.optimization;

import gnu.trove.list.array.TDoubleArrayList;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.robotics.numericalMethods.GradientDescentModule;
import us.ihmc.robotics.numericalMethods.SingleQueryFunction;

/* loaded from: input_file:us/ihmc/robotics/optimization/WrappedGradientDescent.class */
/**
 * Adapts {@link GradientDescentModule} to the {@link Optimizer} interface.
 * <p>
 * Callers work with {@link DMatrixD1} parameter vectors. Internally these are converted to and from the
 * {@link TDoubleArrayList} representation that {@link GradientDescentModule} consumes. Parameter vectors handed to the
 * cost function and returned by {@link #getOptimalParameters()} are shaped as column vectors (N x 1).
 */
public class WrappedGradientDescent implements Optimizer {
    /** Module created by the most recent {@link #optimize(DMatrixD1)} call; {@code null} until the first run. */
    private GradientDescentModule gradientDescentModule;
    private CostFunction costFunction;
    /** Scratch matrix reused for every cost-function evaluation. */
    private final DMatrixD1 vectorInputToCostFunction = new DMatrixRMaj();
    /** Matrix returned by {@link #getOptimalParameters()}; refreshed on every call to that method. */
    private final DMatrixD1 optimalInput = new DMatrixRMaj();
    private double stepSize = 10.0d;
    private double learningRate = 0.9d;

    /**
     * Sets the cost function minimized by subsequent calls to {@link #optimize(DMatrixD1)}.
     *
     * @param costFunction the cost function to evaluate on candidate parameter vectors
     */
    @Override
    public void setCostFunction(CostFunction costFunction) {
        this.costFunction = costFunction;
    }

    /**
     * Wraps a matrix-based {@link CostFunction} as the list-based {@link SingleQueryFunction} expected by
     * {@link GradientDescentModule}.
     * <p>
     * Each query copies the list into {@link #vectorInputToCostFunction} (as an N x 1 matrix) and evaluates the cost
     * function on it. The same scratch matrix is reused for every query.
     *
     * @param costFunction the cost function to wrap
     * @return a query function delegating to {@code costFunction}
     */
    private SingleQueryFunction createUnwrappedCostFunction(final CostFunction costFunction) {
        return new SingleQueryFunction() {
            @Override
            public double getQuery(TDoubleArrayList values) {
                convertArrayToMatrix(vectorInputToCostFunction, values);
                return costFunction.calculate(vectorInputToCostFunction);
            }
        };
    }

    /**
     * Replaces the contents of {@code matrix} with a copy of {@code values}, shaped as a column vector.
     *
     * @param matrix destination matrix; its backing array is replaced
     * @param values source values
     */
    private static void convertArrayToMatrix(DMatrixD1 matrix, TDoubleArrayList values) {
        matrix.setData(values.toArray());
        matrix.reshape(values.size(), 1);
    }

    /**
     * Replaces the contents of {@code values} with the elements of {@code matrix}, in row-major order.
     *
     * @param matrix source matrix
     * @param values destination list; cleared before being filled
     */
    private static void convertMatrixToArray(DMatrixD1 matrix, TDoubleArrayList values) {
        values.reset();
        // The backing array can be longer than the logical matrix (e.g. after shrinking via reshape), so only the
        // first getNumElements() entries are meaningful. Copying matrix.data wholesale would add trailing garbage.
        values.add(matrix.data, 0, matrix.getNumElements());
    }

    /**
     * Sets the step size passed to {@link GradientDescentModule#setStepSize} at the start of each optimization.
     *
     * @param d initial step size (default {@code 10.0})
     */
    public void setInitialStepSize(double d) {
        this.stepSize = d;
    }

    /**
     * Sets the learning rate. The module receives {@code 1.0 / learningRate} as its reducing step size ratio.
     * NOTE(review): a value of 0 yields an infinite ratio; confirm the valid range against GradientDescentModule.
     *
     * @param d learning rate (default {@code 0.9})
     */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /**
     * Single-step iteration is not supported by this wrapper.
     *
     * @return always {@code null}
     */
    @Override
    public DMatrixD1 stepOneIteration() {
        return null;
    }

    /**
     * Runs gradient descent from the given initial parameters using the configured cost function, step size and
     * learning rate.
     *
     * @param initialParameters starting point; its elements are read in row-major order
     * @return the optimal parameters found, as an N x 1 matrix (same object as {@link #getOptimalParameters()})
     * @throws IllegalStateException if no cost function has been set
     */
    @Override
    public DMatrixD1 optimize(DMatrixD1 initialParameters) {
        if (costFunction == null) {
            throw new IllegalStateException("A cost function must be set before calling optimize().");
        }

        TDoubleArrayList initialInput = new TDoubleArrayList();
        convertMatrixToArray(initialParameters, initialInput);

        gradientDescentModule = new GradientDescentModule(createUnwrappedCostFunction(costFunction), initialInput);
        gradientDescentModule.setStepSize(stepSize);
        gradientDescentModule.setReducingStepSizeRatio(1.0d / learningRate);
        gradientDescentModule.run();

        return getOptimalParameters();
    }

    /**
     * Returns the optimal parameters from the most recent optimization as an N x 1 matrix. The returned matrix is an
     * internal field and is overwritten by later calls.
     *
     * @return the optimal parameters
     * @throws IllegalStateException if {@link #optimize(DMatrixD1)} has not been called yet
     */
    @Override
    public DMatrixD1 getOptimalParameters() {
        convertArrayToMatrix(optimalInput, requireOptimized().getOptimalInput());
        return optimalInput;
    }

    /**
     * Returns the cost at the optimal parameters from the most recent optimization.
     *
     * @return the optimal cost reported by the gradient descent module
     * @throws IllegalStateException if {@link #optimize(DMatrixD1)} has not been called yet
     */
    @Override
    public double getOptimumCost() {
        return requireOptimized().getOptimalQuery();
    }

    /** Returns the module from the last run, failing with a clear message if no run has happened yet. */
    private GradientDescentModule requireOptimized() {
        if (gradientDescentModule == null) {
            throw new IllegalStateException("optimize() must be called before querying the optimization result.");
        }
        return gradientDescentModule;
    }
}
