package us.ihmc.robotics.optimization;

import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import java.util.function.Function;
import org.ejml.MatrixDimensionException;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commons.Conversions;
import us.ihmc.commons.RandomNumbers;
import us.ihmc.euclid.transform.RigidBodyTransform;
import us.ihmc.euclid.transform.interfaces.RigidBodyTransformReadOnly;
import us.ihmc.matrixlib.MatrixTools;

/* loaded from: input_file:us/ihmc/robotics/optimization/LevenbergMarquardtParameterOptimizer.class */
/**
 * Iterative least-squares optimizer over a parameter vector.
 *
 * <p>Each {@link #iterate()} call does the following:
 * <ol>
 *   <li>Builds a forward-finite-difference Jacobian of the {@link OutputCalculator} outputs, using only
 *       the rows that currently "correspond" (output value strictly below the correspondence threshold,
 *       optionally randomly sub-sampled).</li>
 *   <li>Takes a damped Gauss-Newton step {@code input -= (J^T J + D)^-1 J^T r}.</li>
 * </ol>
 * The damping matrix {@code D} is a constant diagonal that is set in {@link #initialize()}. It is not
 * adapted between iterations.
 *
 * <p>The {@code inputFunction} is only used by {@link #convertInputToTransform} to turn a parameter
 * vector into a {@link RigidBodyTransform}.
 *
 * <p>This class keeps mutable internal state and reuses its work matrices between calls.
 */
public class LevenbergMarquardtParameterOptimizer {
    private static final boolean DEBUG = false;
    private final int inputDimension;
    private final Function<DMatrixRMaj, RigidBodyTransform> inputFunction;
    private final OutputCalculator outputCalculator;
    private final DMatrixRMaj currentInput;
    private final OutputSpace currentOutputSpace;
    private final DMatrixRMaj perturbationVector;
    private final DMatrixRMaj perturbedInput;
    private final DMatrixRMaj jacobian;
    private final DMatrixRMaj squaredJacobian;
    private final DMatrixRMaj dampingMatrix;
    private final DMatrixRMaj invMultJacobianTranspose;
    private final DMatrixRMaj optimizeDirection;
    private static final double DEFAULT_PERTURBATION = 1.0E-4d;
    private static final double DEFAULT_DAMPING_COEFFICIENT = 0.001d;
    private int iteration;
    /** Number of rows used to build the Jacobian in the most recent {@link #iterate()} call. */
    private int numberOfCorrespondences;
    // NOTE(review): only ever assigned false (in initialize()); isSolved() therefore always returns false.
    private boolean optimized;
    private double correspondenceThreshold = 1.0d;
    // No setter exists, so damping is always applied.
    private boolean useDamping = true;
    private int maximumNumberOfCorrespondences = Integer.MAX_VALUE;

    /**
     * Holds the latest full output vector and the subset of rows that currently correspond.
     *
     * <p>This is a non-static inner class on purpose: it reads the enclosing optimizer's threshold and
     * sample limit.
     */
    private class OutputSpace {
        private final DMatrixRMaj output;
        /** Output values of the (possibly sub-sampled) corresponding rows, in index order. */
        private DMatrixRMaj correspondingOutput;
        /** Marks every row below the threshold, including rows later dropped by random sampling. */
        private final boolean[] correspondence;
        private final TIntArrayList correspondingIndices = new TIntArrayList();
        private final Random random = new Random();
        private double correspondingQuality;
        private double quality;

        private OutputSpace(int outputDimension) {
            this.output = new DMatrixRMaj(outputDimension, 1);
            this.correspondence = new boolean[outputDimension];
        }

        /** Copies {@code newOutput} into this space's output vector. */
        void updateOutputSpace(DMatrixRMaj newOutput) {
            this.output.set(newOutput);
        }

        /**
         * Recomputes which rows correspond and gathers their output values.
         *
         * <p>A row corresponds when its output value is strictly below the threshold. This is a signed
         * comparison; presumably the outputs are non-negative distances (TODO confirm against the
         * OutputCalculator implementations). If more rows qualify than the configured maximum, rows are
         * removed at random until the limit is met.
         *
         * @return {@code true} if at least one corresponding row remains
         */
        boolean computeCorrespondence() {
            this.correspondingIndices.clear();
            double threshold = LevenbergMarquardtParameterOptimizer.this.correspondenceThreshold;
            for (int row = 0; row < this.output.getNumRows(); row++) {
                boolean corresponds = this.output.get(row, 0) < threshold;
                this.correspondence[row] = corresponds;
                if (corresponds) {
                    this.correspondingIndices.add(row);
                }
            }
            randomlySampleCorrespondences(this.correspondingIndices, LevenbergMarquardtParameterOptimizer.this.maximumNumberOfCorrespondences);

            int count = this.correspondingIndices.size();
            this.correspondingOutput = new DMatrixRMaj(count, 1);
            for (int k = 0; k < count; k++) {
                this.correspondingOutput.set(k, 0, this.output.get(this.correspondingIndices.get(k)));
            }
            return count != 0;
        }

        /**
         * Removes randomly chosen entries, by position, until at most {@code maximumSize} remain.
         * The remaining entries keep their relative order.
         */
        private void randomlySampleCorrespondences(TIntArrayList indices, int maximumSize) {
            while (indices.size() > maximumSize) {
                // removeAt drops by position; TIntArrayList.remove(int) would remove by *value*, which can
                // remove the wrong entry or, if no value matches, never shrink the list.
                indices.removeAt(RandomNumbers.nextInt(this.random, 0, indices.size() - 1));
            }
        }

        /**
         * Computes two sums of squared output values:
         * <ul>
         *   <li>{@code quality}: over all rows.</li>
         *   <li>{@code correspondingQuality}: over rows flagged in {@code correspondence}, i.e. every row
         *       below the threshold, not just the sampled subset.</li>
         * </ul>
         */
        void computeQuality() {
            double total = 0.0d;
            double corresponding = 0.0d;
            for (int row = 0; row < this.output.getNumRows(); row++) {
                double value = this.output.get(row, 0);
                double squared = value * value;
                total += squared;
                if (this.correspondence[row]) {
                    corresponding += squared;
                }
            }
            this.quality = total;
            this.correspondingQuality = corresponding;
        }

        DMatrixRMaj getOutput() {
            return this.output;
        }

        DMatrixRMaj getCorrespondingOutput() {
            return this.correspondingOutput;
        }

        int getNumberOfCorrespondingPoints() {
            return this.correspondingIndices.size();
        }

        double getCorrespondingQuality() {
            return this.correspondingQuality;
        }

        double getQuality() {
            return this.quality;
        }
    }

    /**
     * Creates an optimizer.
     *
     * @param inputFunction    maps a parameter vector to a transform; used only by {@link #convertInputToTransform}
     * @param outputCalculator evaluates the output vector for a given parameter vector
     * @param inputDimension   number of parameters
     * @param outputDimension  number of output rows
     */
    public LevenbergMarquardtParameterOptimizer(Function<DMatrixRMaj, RigidBodyTransform> inputFunction, OutputCalculator outputCalculator, int inputDimension, int outputDimension) {
        this.inputFunction = inputFunction;
        this.inputDimension = inputDimension;
        this.outputCalculator = outputCalculator;
        this.currentInput = new DMatrixRMaj(inputDimension, 1);
        this.currentOutputSpace = new OutputSpace(outputDimension);
        this.perturbationVector = new DMatrixRMaj(inputDimension, 1);
        CommonOps_DDRM.fill(this.perturbationVector, DEFAULT_PERTURBATION);
        this.perturbedInput = new DMatrixRMaj(inputDimension, 1);
        this.jacobian = new DMatrixRMaj(outputDimension, inputDimension);
        this.squaredJacobian = new DMatrixRMaj(inputDimension, inputDimension);
        this.dampingMatrix = new DMatrixRMaj(inputDimension, inputDimension);
        this.invMultJacobianTranspose = new DMatrixRMaj(inputDimension, outputDimension);
        this.optimizeDirection = new DMatrixRMaj(inputDimension, 1);
    }

    /**
     * Sets the per-parameter finite-difference step used to build the Jacobian.
     *
     * @param perturbationVector step sizes; must have the same shape as the internal vector
     *                           ({@code inputDimension x 1})
     * @throws MatrixDimensionException if the shape does not match
     */
    public void setPerturbationVector(DMatrixRMaj perturbationVector) {
        // Compare both dimensions: both sides are column vectors, so a column-only check never fires.
        if (this.perturbationVector.getNumRows() != perturbationVector.getNumRows()
                || this.perturbationVector.getNumCols() != perturbationVector.getNumCols()) {
            throw new MatrixDimensionException("dimension is wrong. "
                    + this.perturbationVector.getNumRows() + "x" + this.perturbationVector.getNumCols() + " "
                    + perturbationVector.getNumRows() + "x" + perturbationVector.getNumCols());
        }
        this.perturbationVector.set(perturbationVector);
    }

    /** Sets the threshold; an output row corresponds when its value is strictly below it. */
    public void setCorrespondenceThreshold(double correspondenceThreshold) {
        this.correspondenceThreshold = correspondenceThreshold;
    }

    /** Sets the cap on corresponding rows; extra rows are randomly discarded. */
    public void setMaximumNumberOfCorrespondences(int maximumNumberOfCorrespondences) {
        this.maximumNumberOfCorrespondences = maximumNumberOfCorrespondences;
    }

    /** Copies {@code initialGuess} into the current parameter vector. */
    public void setInitialOptimalGuess(DMatrixRMaj initialGuess) {
        this.currentInput.set(initialGuess);
    }

    /**
     * Resets the iteration counter and damping, and evaluates all outputs at the current parameters.
     *
     * @return {@code true} if at least one output row corresponds
     */
    public boolean initialize() {
        this.iteration = 0;
        this.optimized = false;
        MatrixTools.setDiagonal(this.dampingMatrix, DEFAULT_DAMPING_COEFFICIENT);
        this.outputCalculator.resetIndicesToCompute();
        this.currentOutputSpace.updateOutputSpace((DMatrixRMaj) this.outputCalculator.apply(this.currentInput));
        boolean hasCorrespondence = this.currentOutputSpace.computeCorrespondence();
        this.currentOutputSpace.computeQuality();
        return hasCorrespondence;
    }

    /**
     * Performs one damped Gauss-Newton step using the current correspondences, then re-evaluates all outputs.
     *
     * @return the sum of squared outputs over corresponding rows after the step, or {@code -1} if there
     *         were no corresponding rows or the normal matrix could not be inverted. In the failure case
     *         the parameters are left unchanged.
     */
    public double iterate() {
        this.iteration++;
        if (this.currentOutputSpace.getNumberOfCorrespondingPoints() < 1) {
            return -1.0d;
        }
        // Restrict evaluation to the corresponding rows. The Jacobian loop below assumes each perturbed
        // output lists those rows in the same order as getCorrespondingOutput().
        this.outputCalculator.setIndicesToCompute(this.currentOutputSpace.correspondingIndices);
        this.numberOfCorrespondences = this.currentOutputSpace.getNumberOfCorrespondingPoints();
        this.jacobian.reshape(this.numberOfCorrespondences, this.inputDimension);
        this.invMultJacobianTranspose.reshape(this.inputDimension, this.numberOfCorrespondences);

        DMatrixRMaj correspondingOutput = this.currentOutputSpace.getCorrespondingOutput();
        this.perturbedInput.set(this.currentInput);
        for (int col = 0; col < this.inputDimension; col++) {
            double step = this.perturbationVector.get(col);
            double unperturbed = this.perturbedInput.get(col, 0);
            this.perturbedInput.set(col, 0, unperturbed + step);
            DMatrixRMaj perturbedOutput = (DMatrixRMaj) this.outputCalculator.apply(this.perturbedInput);
            for (int row = 0; row < this.numberOfCorrespondences; row++) {
                this.jacobian.set(row, col, (perturbedOutput.get(row) - correspondingOutput.get(row)) / step);
            }
            // Restore the saved value; "+step then -step" is not exact in floating point.
            this.perturbedInput.set(col, 0, unperturbed);
        }

        // Normal equations: (J^T J + D) delta = J^T r.
        CommonOps_DDRM.multInner(this.jacobian, this.squaredJacobian);
        if (this.useDamping) {
            CommonOps_DDRM.addEquals(this.squaredJacobian, this.dampingMatrix);
        }
        if (!CommonOps_DDRM.invert(this.squaredJacobian)) {
            // Do not apply a step built from a failed inversion; keep the current parameters.
            this.outputCalculator.resetIndicesToCompute();
            return -1.0d;
        }
        CommonOps_DDRM.multTransB(this.squaredJacobian, this.jacobian, this.invMultJacobianTranspose);
        CommonOps_DDRM.mult(this.invMultJacobianTranspose, correspondingOutput, this.optimizeDirection);
        CommonOps_DDRM.subtractEquals(this.currentInput, this.optimizeDirection);

        // Re-evaluate every output row at the new parameters.
        this.outputCalculator.resetIndicesToCompute();
        this.currentOutputSpace.updateOutputSpace((DMatrixRMaj) this.outputCalculator.apply(this.currentInput));
        this.currentOutputSpace.computeCorrespondence();
        this.currentOutputSpace.computeQuality();
        return this.currentOutputSpace.getCorrespondingQuality();
    }

    /**
     * Writes {@code inputFunction.apply(input)} into {@code transformToPack}.
     *
     * @throws MatrixDimensionException if {@code input} does not have exactly {@code inputDimension} elements
     */
    public void convertInputToTransform(DMatrixRMaj input, RigidBodyTransform transformToPack) {
        // getNumElements(), not getData().length: the backing array may be larger than the logical size.
        if (input.getNumElements() != this.inputDimension) {
            throw new MatrixDimensionException("dimension is wrong. " + input.getNumElements() + " " + this.inputDimension);
        }
        transformToPack.set(this.inputFunction.apply(input));
    }

    /** Returns the number of rows used to build the Jacobian in the last {@link #iterate()} call. */
    public int getNumberOfCorrespondingPoints() {
        return this.numberOfCorrespondences;
    }

    /** Returns the internal current parameter vector (not a copy). */
    public DMatrixRMaj getOptimalParameter() {
        return this.currentInput;
    }

    /** Returns the solved flag; in this implementation it is never set to {@code true}. */
    public boolean isSolved() {
        return this.optimized;
    }

    /** Returns the sum of squared outputs over rows below the correspondence threshold, from the last evaluation. */
    public double getQuality() {
        return this.currentOutputSpace.getCorrespondingQuality();
    }

    /** Returns the sum of squared outputs over all rows, from the last evaluation. */
    public double getPureQuality() {
        return this.currentOutputSpace.getQuality();
    }

    /** Returns the number of {@link #iterate()} calls since the last {@link #initialize()}. */
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Creates a function from a parameter vector to a transform.
     * <ul>
     *   <li>Indices 0-2 are always the translation (x, y, z).</li>
     *   <li>If {@code includePitchAndRoll} is true, indices 3, 4 and 5 are roll, pitch and yaw.</li>
     *   <li>Otherwise index 3 is the yaw.</li>
     * </ul>
     */
    public static Function<DMatrixRMaj, RigidBodyTransform> createSpatialInputFunction(final boolean includePitchAndRoll) {
        return input -> {
            RigidBodyTransform transform = new RigidBodyTransform();
            if (includePitchAndRoll) {
                transform.setRotationYawPitchRollAndZeroTranslation(input.get(5), input.get(4), input.get(3));
            } else {
                transform.setRotationYawAndZeroTranslation(input.get(3));
            }
            transform.getTranslation().set(input.get(0), input.get(1), input.get(2));
            return transform;
        };
    }

    /**
     * Creates the inverse of {@link #createSpatialInputFunction(boolean)}.
     *
     * <p>It produces a 6x1 vector (x, y, z, roll, pitch, yaw) when {@code includePitchAndRoll} is true,
     * and a 4x1 vector (x, y, z, yaw) otherwise.
     */
    public static Function<RigidBodyTransformReadOnly, DMatrixRMaj> createInverseSpatialInputFunction(boolean includePitchAndRoll) {
        return transform -> {
            DMatrixRMaj parameters = new DMatrixRMaj(includePitchAndRoll ? 6 : 4, 1);
            if (includePitchAndRoll) {
                parameters.set(3, transform.getRotation().getRoll());
                parameters.set(4, transform.getRotation().getPitch());
                parameters.set(5, transform.getRotation().getYaw());
            } else {
                parameters.set(3, transform.getRotation().getYaw());
            }
            // Fills rows 0-2 with the translation.
            transform.getTranslation().get(parameters);
            return parameters;
        };
    }
}
