package us.ihmc.robotics.math.trajectories.generators;

import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.robotics.time.ExecutionTimer;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/* loaded from: input_file:us/ihmc/robotics/math/trajectories/generators/TrajectoryPointOptimizer.class */
/**
 * Optimizes the timing of waypoints along a multi-dimensional trajectory built from cubic spline segments.
 * <p>
 * Time is normalized so that the trajectory runs from 0.0 to 1.0. For every dimension a
 * {@link MultiCubicSpline1DSolver} is solved through the endpoints and waypoints. The per-dimension costs are summed;
 * the method name {@code solveMinAcceleration} suggests an acceleration cost (TODO confirm against the solver).
 * The interval durations are then refined by a finite-difference gradient descent. Both the perturbations and the
 * updates keep the interval durations summing to one.
 */
public class TrajectoryPointOptimizer {
    /** Maximum number of waypoints accepted by {@link #setWaypoints(List)}. */
    public static final int maxWaypoints = 200;
    /** Number of time-update iterations used by {@link #compute()}. */
    public static final int maxIterations = 200;
    /** Finite-difference step used to estimate the cost gradient with respect to the interval times. */
    private static final double epsilon = 1.0E-7d;
    /** Gradient-descent gain used at the start of every optimization. */
    private static final double initialTimeGain = 0.001d;
    /** Absolute cost change below which the time optimization is considered converged. */
    private static final double costEpsilon = 0.01d;
    /** Number of polynomial coefficients per interval and dimension (cubic segments). */
    public static final int coefficients = 4;
    /** Number of step-halvings the line search tries before accepting a step unconditionally. */
    private static final int maxLineSearchSteps = 10;
    /** Largest allowed change of any interval time, as a fraction of the currently shortest interval. */
    private static final double maxRelativeTimeStep = 0.4d;

    private final YoRegistry registry;
    private final YoInteger dimensions;
    private final YoInteger nWaypoints;
    private final YoInteger intervals;
    private final YoInteger problemSize;
    private final YoInteger iteration;
    private final TDoubleArrayList x0;
    private final TDoubleArrayList x1;
    private final TDoubleArrayList xd0;
    private final TDoubleArrayList xd1;
    private final ArrayList<DMatrixRMaj> waypoints;
    private final MultiCubicSpline1DSolver solver;
    private final TDoubleArrayList w0;
    private final TDoubleArrayList w1;
    private final TDoubleArrayList wd0;
    private final TDoubleArrayList wd1;
    private final DMatrixRMaj intervalTimes;
    private final DMatrixRMaj saveIntervalTimes;
    private final TDoubleArrayList costs;
    /** One solution vector per dimension; {@link #coefficients} entries per interval. */
    private final RecyclingArrayList<DMatrixRMaj> x;
    private final DMatrixRMaj timeGradient;
    private final DMatrixRMaj timeUpdate;
    private final YoDouble timeGain;
    private final ExecutionTimer computeTimer;
    private final ExecutionTimer timeUpdateTimer;
    private final DMatrixRMaj tempCoeffs;
    private final DMatrixRMaj tempLine;

    /**
     * Creates an optimizer whose registry is attached to {@code parentRegistry}.
     *
     * @param numberOfDimensions number of independent trajectory dimensions (negative values are treated as 0)
     * @param parentRegistry     registry the optimizer's own registry is added to as a child
     */
    public TrajectoryPointOptimizer(int numberOfDimensions, YoRegistry parentRegistry) {
        this("", numberOfDimensions, parentRegistry);
    }

    /**
     * Creates a prefixed optimizer whose registry is attached to {@code parentRegistry}.
     *
     * @param namePrefix         prefix for the registry and all YoVariable names
     * @param numberOfDimensions number of independent trajectory dimensions (negative values are treated as 0)
     * @param parentRegistry     registry the optimizer's own registry is added to as a child
     */
    public TrajectoryPointOptimizer(String namePrefix, int numberOfDimensions, YoRegistry parentRegistry) {
        this(namePrefix, numberOfDimensions);
        parentRegistry.addChild(this.registry);
    }

    /**
     * Creates an optimizer with a standalone registry.
     *
     * @param numberOfDimensions number of independent trajectory dimensions (negative values are treated as 0)
     */
    public TrajectoryPointOptimizer(int numberOfDimensions) {
        this("", numberOfDimensions);
    }

    /**
     * Creates a prefixed optimizer with a standalone registry.
     * <p>
     * All endpoints start at zero and all endpoint weights start at {@link Double#POSITIVE_INFINITY}.
     *
     * @param namePrefix         prefix for the registry and all YoVariable names
     * @param numberOfDimensions number of independent trajectory dimensions (negative values are treated as 0)
     */
    public TrajectoryPointOptimizer(String namePrefix, int numberOfDimensions) {
        this.waypoints = new ArrayList<>();
        this.solver = new MultiCubicSpline1DSolver();
        this.intervalTimes = new DMatrixRMaj(1, 1);
        this.saveIntervalTimes = new DMatrixRMaj(1, 1);
        // One initial cost plus one per default iteration.
        this.costs = new TDoubleArrayList(maxIterations + 1);
        this.x = new RecyclingArrayList<>(0, () -> new DMatrixRMaj(1, 1));
        this.timeGradient = new DMatrixRMaj(1, 1);
        this.timeUpdate = new DMatrixRMaj(1, 1);
        this.tempCoeffs = new DMatrixRMaj(1, 1);
        this.tempLine = new DMatrixRMaj(1, 1);

        this.registry = new YoRegistry(namePrefix + getClass().getSimpleName());
        this.dimensions = new YoInteger(namePrefix + "Dimensions", this.registry);
        this.nWaypoints = new YoInteger(namePrefix + "NumberOfWaypoints", this.registry);
        this.intervals = new YoInteger(namePrefix + "NumberOfIntervals", this.registry);
        this.problemSize = new YoInteger(namePrefix + "ProblemSize", this.registry);
        this.iteration = new YoInteger(namePrefix + "Iteration", this.registry);
        this.computeTimer = new ExecutionTimer(namePrefix + "ComputeTimer", 0.0d, this.registry);
        this.timeUpdateTimer = new ExecutionTimer(namePrefix + "TimeUpdateTimer", 0.0d, this.registry);
        this.timeGain = new YoDouble(namePrefix + "TimeGain", this.registry);

        int dims = Math.max(numberOfDimensions, 0);
        this.dimensions.set(dims);
        this.timeGain.set(initialTimeGain);

        this.x0 = new TDoubleArrayList(dims);
        this.x1 = new TDoubleArrayList(dims);
        this.xd0 = new TDoubleArrayList(dims);
        this.xd1 = new TDoubleArrayList(dims);
        for (int dim = 0; dim < dims; dim++) {
            this.x0.add(0.0d);
            this.xd0.add(0.0d);
            this.x1.add(0.0d);
            this.xd1.add(0.0d);
        }

        // Waypoint storage is preallocated so setWaypoints never allocates.
        for (int k = 0; k < maxWaypoints; k++) {
            this.waypoints.add(new DMatrixRMaj(dims, 1));
        }

        this.w0 = new TDoubleArrayList(dims);
        this.w1 = new TDoubleArrayList(dims);
        this.wd0 = new TDoubleArrayList(dims);
        this.wd1 = new TDoubleArrayList(dims);
        // Use the private helpers rather than the overridable clearWeights() while still constructing.
        fillWithInfinity(this.w0);
        fillWithInfinity(this.w1);
        fillWithInfinity(this.wd0);
        fillWithInfinity(this.wd1);

        this.tempCoeffs.reshape(coefficients, 1);
    }

    /**
     * Resets every endpoint weight of every dimension to {@link Double#POSITIVE_INFINITY}.
     * Infinite weights are presumably treated by the solver as hard constraints; TODO confirm.
     */
    public void clearWeights() {
        fillWithInfinity(this.w0);
        fillWithInfinity(this.w1);
        fillWithInfinity(this.wd0);
        fillWithInfinity(this.wd1);
    }

    /** Sets the first {@code dimensions} entries of {@code weights} to positive infinity. */
    private void fillWithInfinity(TDoubleArrayList weights) {
        weights.fill(0, this.dimensions.getValue(), Double.POSITIVE_INFINITY);
    }

    /**
     * Sets the start/end position and velocity of a single dimension.
     *
     * @param i          dimension index in {@code [0, dimensions)}
     * @param startPos   position at time 0
     * @param startVel   velocity at time 0
     * @param targetPos  position at time 1
     * @param targetVel  velocity at time 1
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public void setEndPoints(int i, double startPos, double startVel, double targetPos, double targetVel) {
        checkDimensionIndex(i);
        this.x0.set(i, startPos);
        this.xd0.set(i, startVel);
        this.x1.set(i, targetPos);
        this.xd1.set(i, targetVel);
    }

    /**
     * Sets the start/end positions and velocities of all dimensions at once.
     *
     * @param startPos  start positions, one per dimension
     * @param startVel  start velocities, one per dimension
     * @param targetPos target positions, one per dimension
     * @param targetVel target velocities, one per dimension
     * @throws RuntimeException if any list does not have exactly {@code dimensions} entries
     */
    public void setEndPoints(TDoubleArrayList startPos, TDoubleArrayList startVel, TDoubleArrayList targetPos,
                             TDoubleArrayList targetVel) {
        checkInputSize(startPos);
        checkInputSize(startVel);
        checkInputSize(targetPos);
        checkInputSize(targetVel);
        for (int dim = 0; dim < this.dimensions.getIntegerValue(); dim++) {
            this.x0.set(dim, startPos.get(dim));
            this.xd0.set(dim, startVel.get(dim));
            this.x1.set(dim, targetPos.get(dim));
            this.xd1.set(dim, targetVel.get(dim));
        }
    }

    /**
     * Sets the endpoint weights of a single dimension.
     *
     * @param i                 dimension index in {@code [0, dimensions)}
     * @param startPosWeight    weight on the start position ({@code w0})
     * @param startVelWeight    weight on the start velocity ({@code wd0})
     * @param targetPosWeight   weight on the target position ({@code w1})
     * @param targetVelWeight   weight on the target velocity ({@code wd1})
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public void setEndPointWeights(int i, double startPosWeight, double startVelWeight, double targetPosWeight,
                                   double targetVelWeight) {
        checkDimensionIndex(i);
        this.w0.set(i, startPosWeight);
        this.wd0.set(i, startVelWeight);
        this.w1.set(i, targetPosWeight);
        this.wd1.set(i, targetVelWeight);
    }

    /**
     * Sets the endpoint weights of all dimensions. The argument order matches the scalar overload. A {@code null}
     * list resets the corresponding weights to {@link Double#POSITIVE_INFINITY}.
     *
     * @param startPosWeights  weights for {@code w0}, or {@code null}
     * @param startVelWeights  weights for {@code wd0}, or {@code null}
     * @param targetPosWeights weights for {@code w1}, or {@code null}
     * @param targetVelWeights weights for {@code wd1}, or {@code null}
     * @throws RuntimeException if a non-null list does not have exactly {@code dimensions} entries; no weights are
     *                          modified in that case
     */
    public void setEndPointWeights(TDoubleArrayList startPosWeights, TDoubleArrayList startVelWeights,
                                   TDoubleArrayList targetPosWeights, TDoubleArrayList targetVelWeights) {
        // Validate everything before mutating anything.
        if (startPosWeights != null) {
            checkInputSize(startPosWeights);
        }
        if (startVelWeights != null) {
            checkInputSize(startVelWeights);
        }
        if (targetPosWeights != null) {
            checkInputSize(targetPosWeights);
        }
        if (targetVelWeights != null) {
            checkInputSize(targetVelWeights);
        }
        // Each target is guarded by, and copied from, its own argument (previously w1/wd0 were cross-wired).
        copyWeightsOrReset(this.w0, startPosWeights);
        copyWeightsOrReset(this.wd0, startVelWeights);
        copyWeightsOrReset(this.w1, targetPosWeights);
        copyWeightsOrReset(this.wd1, targetVelWeights);
    }

    /** Copies {@code source} into {@code target}, or resets {@code target} to infinity if {@code source} is null. */
    private void copyWeightsOrReset(TDoubleArrayList target, TDoubleArrayList source) {
        if (source == null) {
            fillWithInfinity(target);
            return;
        }
        for (int dim = 0; dim < this.dimensions.getValue(); dim++) {
            target.set(dim, source.get(dim));
        }
    }

    /**
     * Sets the waypoints the trajectory must pass through, in order.
     *
     * @param list waypoints, each holding one value per dimension; at most {@link #maxWaypoints} entries
     * @throws RuntimeException if there are too many waypoints or a waypoint has the wrong size; the previously set
     *                          waypoints are left untouched in that case
     */
    public void setWaypoints(List<TDoubleArrayList> list) {
        if (list.size() > maxWaypoints) {
            throw new RuntimeException("Too Many Waypoints");
        }
        // Validate every waypoint first so a bad entry cannot leave a partially updated state behind.
        for (TDoubleArrayList waypoint : list) {
            checkInputSize(waypoint);
        }
        this.nWaypoints.set(list.size());
        for (int k = 0; k < list.size(); k++) {
            list.get(k).toArray(this.waypoints.get(k).data);
        }
    }

    /** Optimizes the waypoint times for up to {@link #maxIterations} iterations, starting from uniform intervals. */
    public void compute() {
        compute(maxIterations);
    }

    /**
     * Optimizes the waypoint times, starting from uniformly distributed intervals.
     *
     * @param iterationLimit maximum number of time-update iterations; 0 only solves for the initial times
     */
    public void compute(int iterationLimit) {
        this.intervals.set(this.nWaypoints.getIntegerValue() + 1);
        this.intervalTimes.reshape(this.intervals.getValue(), 1);
        CommonOps_DDRM.fill(this.intervalTimes, 1.0d / this.intervals.getValue());
        computeInternal(iterationLimit);
    }

    /**
     * Solves the trajectory for the given waypoint times without optimizing them.
     *
     * @param waypointTimes absolute, non-decreasing waypoint times in {@code [0, 1]}, one per waypoint
     */
    public void computeForFixedTime(TDoubleArrayList waypointTimes) {
        compute(0, waypointTimes);
    }

    /**
     * Optimizes the waypoint times, starting from the given initial times.
     *
     * @param iterationLimit maximum number of time-update iterations
     * @param waypointTimes  initial absolute waypoint times in {@code [0, 1]}, one per waypoint
     * @throws RuntimeException if the number of times is wrong or the times are not valid
     */
    public void compute(int iterationLimit, TDoubleArrayList waypointTimes) {
        this.intervals.set(this.nWaypoints.getIntegerValue() + 1);
        setIntervalTimes(waypointTimes);
        computeInternal(iterationLimit);
    }

    /** Solves for the current interval times, then runs time updates until convergence or {@code iterationLimit}. */
    private void computeInternal(int iterationLimit) {
        this.computeTimer.startMeasurement();
        this.timeGain.set(initialTimeGain);
        this.problemSize.set(this.dimensions.getIntegerValue() * coefficients * this.intervals.getValue());
        this.costs.reset();
        this.costs.add(solveMinAcceleration());
        this.iteration.set(0);
        for (int iter = 0; iter < iterationLimit; iter++) {
            if (doFullTimeUpdate()) {
                break;
            }
        }
        this.computeTimer.stopMeasurement();
    }

    /** Converts absolute waypoint times into interval durations and stores them in {@code intervalTimes}. */
    private void setIntervalTimes(TDoubleArrayList waypointTimes) {
        int waypointCount = this.nWaypoints.getValue();
        if (waypointTimes.size() != waypointCount) {
            throw new RuntimeException("Unexpected number of waypoint times. Need " + this.nWaypoints.getValue() + ", got " + waypointTimes.size() + ".");
        }
        this.intervalTimes.reshape(this.intervals.getValue(), 1);
        for (int k = 0; k < this.intervals.getValue(); k++) {
            double end = k == waypointCount ? 1.0d : waypointTimes.get(k);
            double start = k == 0 ? 0.0d : waypointTimes.get(k - 1);
            double duration = end - start;
            if (duration < 0.0d || duration > 1.0d) {
                throw new RuntimeException("Time in this trajectory is from 0.0 to 1.0. Got invalid waypoint times:\n" + waypointTimes.toString());
            }
            this.intervalTimes.set(k, duration);
        }
    }

    /**
     * Performs one gradient step on the interval times and records the resulting cost.
     *
     * @return {@code true} if the cost changed by less than the convergence threshold
     */
    public boolean doFullTimeUpdate() {
        double previousCost = this.costs.get(this.iteration.getIntegerValue());
        double newCost = computeTimeUpdate(previousCost);
        this.costs.add(newCost);
        this.iteration.increment();
        return Math.abs(previousCost - newCost) < costEpsilon;
    }

    /**
     * Estimates the cost gradient and applies a backtracking step to the interval times.
     *
     * @param currentCost cost at the current interval times
     * @return the cost after the accepted step
     */
    private double computeTimeUpdate(double currentCost) {
        this.timeUpdateTimer.startMeasurement();
        try {
            computeTimeGradient(currentCost);
            for (int step = 0; step < maxLineSearchSteps; step++) {
                double newCost = applyTimeUpdate();
                if (newCost <= currentCost) {
                    return newCost;
                }
                // The step increased the cost: halve the gain and retry from the saved times.
                this.timeGain.set(this.timeGain.getDoubleValue() * 0.5d);
                this.intervalTimes.set(this.saveIntervalTimes);
            }
            return applyTimeUpdate();
        } finally {
            // Always stop the timer, including on the early return above.
            this.timeUpdateTimer.stopMeasurement();
        }
    }

    /**
     * Fills {@code timeGradient} with a forward-difference estimate of d(cost)/d(interval time). Each interval is
     * perturbed by {@code epsilon} while the others shrink equally, so the total time is preserved. The gradient is
     * then projected to zero mean. Leaves {@code saveIntervalTimes} holding the unperturbed interval times.
     */
    private void computeTimeGradient(double currentCost) {
        int intervalCount = this.intervals.getIntegerValue();
        this.timeGradient.reshape(intervalCount, 1);
        this.saveIntervalTimes.set(this.intervalTimes);
        for (int perturbed = 0; perturbed < intervalCount; perturbed++) {
            for (int k = 0; k < intervalCount; k++) {
                if (k == perturbed) {
                    this.intervalTimes.add(k, 0, epsilon);
                } else {
                    // Only reached when intervalCount > 1, so the division is safe.
                    this.intervalTimes.add(k, 0, -epsilon / (intervalCount - 1));
                }
            }
            this.timeGradient.set(perturbed, (solveMinAcceleration() - currentCost) / epsilon);
            this.intervalTimes.set(this.saveIntervalTimes);
        }
        // Remove the mean so that a step along the gradient keeps the total time constant.
        CommonOps_DDRM.add(this.timeGradient, -CommonOps_DDRM.elementSum(this.timeGradient) / intervalCount);
    }

    /**
     * Moves the interval times against the gradient, then re-solves. The step is scaled by the current gain and
     * clamped to a fraction of the shortest interval.
     *
     * @return the cost at the updated interval times
     */
    private double applyTimeUpdate() {
        this.timeUpdate.set(this.timeGradient);
        CommonOps_DDRM.scale(-this.timeGain.getDoubleValue(), this.timeUpdate);
        double largestChange = CommonOps_DDRM.elementMaxAbs(this.timeUpdate);
        double shortestInterval = CommonOps_DDRM.elementMin(this.intervalTimes);
        double allowedChange = maxRelativeTimeStep * shortestInterval;
        if (largestChange > allowedChange) {
            CommonOps_DDRM.scale(allowedChange / largestChange, this.timeUpdate);
        }
        for (int k = 0; k < this.intervals.getIntegerValue(); k++) {
            this.intervalTimes.add(k, 0, this.timeUpdate.get(k));
        }
        return solveMinAcceleration();
    }

    /** Solves every dimension for the current interval times, storing solutions in {@code x}; returns the total cost. */
    private double solveMinAcceleration() {
        double totalCost = 0.0d;
        this.x.clear();
        for (int dim = 0; dim < this.dimensions.getValue(); dim++) {
            totalCost += solveDimension(dim, this.x.add());
        }
        return totalCost;
    }

    /** Configures the spline solver for dimension {@code dim}, solves into {@code solution}, and returns its cost. */
    private double solveDimension(int dim, DMatrixRMaj solution) {
        this.solver.setEndpoints(this.x0.get(dim), this.xd0.get(dim), this.x1.get(dim), this.xd1.get(dim));
        this.solver.setEndpointWeights(this.w0.get(dim), this.wd0.get(dim), this.w1.get(dim), this.wd1.get(dim));
        this.solver.clearWaypoints();
        double time = 0.0d;
        for (int k = 0; k < this.nWaypoints.getValue(); k++) {
            time += this.intervalTimes.get(k);
            this.solver.addWaypoint(this.waypoints.get(k).get(dim), time);
        }
        return this.solver.solve(solution);
    }

    /**
     * Writes the absolute time of every waypoint into {@code waypointTimes}, replacing its contents.
     *
     * @param waypointTimes output list, reset before use
     */
    public void getWaypointTimes(TDoubleArrayList waypointTimes) {
        waypointTimes.reset();
        double time = 0.0d;
        for (int k = 0; k < this.nWaypoints.getIntegerValue(); k++) {
            time += this.intervalTimes.get(k);
            waypointTimes.add(time);
        }
    }

    /**
     * Returns the absolute time of a single waypoint.
     *
     * @param i waypoint index in {@code [0, nWaypoints)}
     * @return the cumulative time up to the end of interval {@code i}
     * @throws RuntimeException if {@code i} is out of range
     */
    public double getWaypointTime(int i) {
        if (i < 0 || i > this.nWaypoints.getIntegerValue() - 1) {
            throw new RuntimeException("Unexpected Waypoint Index");
        }
        double time = 0.0d;
        for (int k = 0; k <= i; k++) {
            time += this.intervalTimes.get(k);
        }
        return time;
    }

    /**
     * Copies the cubic coefficients of every interval for one dimension.
     *
     * @param list output lists, exactly one per interval; each is reset and filled with {@link #coefficients} values
     * @param i    dimension index
     * @throws RuntimeException if the output size or the dimension is wrong
     */
    public void getPolynomialCoefficients(List<TDoubleArrayList> list, int i) {
        if (list.size() != this.intervals.getIntegerValue()) {
            throw new RuntimeException("Unexpected Size of Output");
        }
        if (i > this.dimensions.getIntegerValue() - 1 || i < 0) {
            throw new RuntimeException("Unknown Dimension");
        }
        DMatrixRMaj solution = this.x.get(i);
        for (int interval = 0; interval < this.intervals.getIntegerValue(); interval++) {
            int row = interval * coefficients;
            CommonOps_DDRM.extract(solution, row, row + coefficients, 0, 1, this.tempCoeffs, 0, 0);
            TDoubleArrayList out = list.get(interval);
            out.reset();
            out.add(this.tempCoeffs.getData());
        }
    }

    /**
     * Evaluates the velocity of every dimension at waypoint {@code i}, using the polynomial of interval {@code i}.
     *
     * @param velocities output list, reset before use
     * @param i          waypoint index
     * @throws RuntimeException if {@code i} is out of range
     */
    public void getWaypointVelocity(TDoubleArrayList velocities, int i) {
        double waypointTime = getWaypointTime(i);
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(waypointTime, 0, 0, this.tempLine);
        evaluateInterval(i, velocities);
    }

    /** Evaluates the position of every dimension at time 0 into {@code positions} (reset before use). */
    public void getStartPosition(TDoubleArrayList positions) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getPositionConstraintABlock(0.0d, 0, 0, this.tempLine);
        evaluateInterval(0, positions);
    }

    /** Evaluates the velocity of every dimension at time 0 into {@code velocities} (reset before use). */
    public void getStartVelocity(TDoubleArrayList velocities) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(0.0d, 0, 0, this.tempLine);
        evaluateInterval(0, velocities);
    }

    /** Evaluates the position of every dimension at time 1 into {@code positions} (reset before use). */
    public void getTargetPosition(TDoubleArrayList positions) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getPositionConstraintABlock(1.0d, 0, 0, this.tempLine);
        evaluateInterval(this.nWaypoints.getValue(), positions);
    }

    /** Evaluates the velocity of every dimension at time 1 into {@code velocities} (reset before use). */
    public void getTargetVelocity(TDoubleArrayList velocities) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(1.0d, 0, 0, this.tempLine);
        evaluateInterval(this.nWaypoints.getValue(), velocities);
    }

    /**
     * Computes {@code tempLine · coefficients(interval)} for every dimension and appends the results to {@code out},
     * after resetting it. {@code tempLine} must already hold the basis row to evaluate.
     */
    private void evaluateInterval(int interval, TDoubleArrayList out) {
        out.reset();
        int row = interval * coefficients;
        for (int dim = 0; dim < this.dimensions.getIntegerValue(); dim++) {
            CommonOps_DDRM.extract(this.x.get(dim), row, row + coefficients, 0, 1, this.tempCoeffs, 0, 0);
            out.add(CommonOps_DDRM.dot(this.tempCoeffs, this.tempLine));
        }
    }

    /** Throws {@link IllegalArgumentException} if {@code i} is not a valid dimension index. */
    private void checkDimensionIndex(int i) {
        if (i < 0 || i >= this.dimensions.getValue()) {
            throw new IllegalArgumentException("Illegal dimension, expected to be in [0, " + this.dimensions.getValue() + "[, but was: " + i);
        }
    }

    /** Throws if {@code input} does not hold exactly one value per dimension. */
    private void checkInputSize(TDoubleArrayList input) {
        if (input.size() != this.dimensions.getIntegerValue()) {
            throw new RuntimeException("Unexpected Size of Input");
        }
    }
}
