package us.ihmc.avatar.slamTools;

import cern.colt.list.BooleanArrayList;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Function;
import javax.swing.JFrame;
import javax.swing.JPanel;
import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.RandomNumbers;
import us.ihmc.commons.thread.ThreadTools;
import us.ihmc.euclid.tools.EuclidCoreRandomTools;
import us.ihmc.euclid.tools.EuclidCoreTestTools;
import us.ihmc.euclid.transform.RigidBodyTransform;
import us.ihmc.euclid.tuple2D.Point2D;
import us.ihmc.euclid.tuple4D.Quaternion;
import us.ihmc.robotics.geometry.AngleTools;
import us.ihmc.robotics.optimization.LevenbergMarquardtParameterOptimizer;
import us.ihmc.robotics.optimization.OutputCalculator;

/**
 * Exercises {@link LevenbergMarquardtParameterOptimizer} as a 2D ICP (iterative closest point) solver.
 * <p>
 * The synthetic scene is a "doughnut": two concentric ellipses joined by two axis-aligned segments.
 * {@link #fullModel} is the complete scene, {@link #data1} and {@link #data2} are partial scans of it.
 * Tests drift one cloud by a planar transform (x, y, yaw) and check that the closest-point error
 * function, its numerical Jacobian and the optimizer recover that drift.
 * <p>
 * Set {@link #visualize} to {@code true} to open a Swing window showing the clouds. The test then
 * blocks forever, so this is intended for manual inspection only.
 */
@Tag("point-cloud-drift-correction-test")
public class LevenbergMarquardtICPTest {
    /** Semi-axes of the two ellipses; the "long" semi-axis lies along x, the "short" one along y. */
    private static final double INNER_LONG_SEMI_AXIS = 2.0;
    private static final double INNER_SHORT_SEMI_AXIS = 1.0;
    private static final double OUTER_LONG_SEMI_AXIS = 4.0;
    private static final double OUTER_SHORT_SEMI_AXIS = 3.0;

    /** Data points closer than this to the model are treated as inliers (overlapping the model). */
    private static final double INLIER_DISTANCE_THRESHOLD = 0.2;
    /** Forward-difference step used to build the numerical Jacobian. */
    private static final double FINITE_DIFFERENCE_STEP = 0.001;
    /** Size of the planar correction vector: x, y, yaw. */
    private static final int PLANAR_PARAMETER_DIMENSION = 3;

    /** Maps a (x, y, yaw) column vector to a rigid-body transform in the XY plane. */
    private final Function<DMatrixRMaj, RigidBodyTransform> inputFunction = LevenbergMarquardtICPTest::toPlanarTransform;
    private final List<Point2D> fullModel = new ArrayList<>();
    private final List<Point2D> data1 = new ArrayList<>();
    private final List<Point2D> data2 = new ArrayList<>();
    private XYPlaneDrawer drawer;
    /** When true, each test shows its point clouds in a window and then blocks forever. */
    private boolean visualize = false;

    /**
     * Minimal Swing panel plotting 2D points in the XY plane, used only for manual inspection.
     * <p>
     * Plane coordinates are mapped to pixels at {@value #SCALE} pixels per unit, x to the right and y up.
     */
    public static class XYPlaneDrawer extends JPanel {
        private static final long serialVersionUID = 1L;
        private static final int SCALE = 40;
        private static final int POINT_SIZE = 4;

        private final List<Point2D> pointCloud = new ArrayList<>();
        private final List<Color> pointCloudColors = new ArrayList<>();
        private final BooleanArrayList fillers = new BooleanArrayList();
        private final double xLower;
        private final double yUpper;
        private final int sizeU;
        private final int sizeV;

        /**
         * @param xUpper largest x shown (right edge)
         * @param xLower smallest x shown (left edge)
         * @param yUpper largest y shown (top edge)
         * @param yLower smallest y shown (bottom edge)
         */
        XYPlaneDrawer(double xUpper, double xLower, double yUpper, double yLower) {
            this.xLower = xLower;
            this.yUpper = yUpper;
            // u (width) spans the x range and v (height) spans the y range, matching x2u / y2v.
            this.sizeU = (int) Math.round((xUpper - xLower) * SCALE);
            this.sizeV = (int) Math.round((yUpper - yLower) * SCALE);
        }

        /** Queues a single point; {@code filled} selects a filled or hollow marker. */
        public void addPoint(Point2D point, Color color, boolean filled) {
            pointCloud.add(point);
            pointCloudColors.add(color);
            fillers.add(filled);
        }

        /** Queues every point of {@code points} with the same color and marker style. */
        public void addPointCloud(List<Point2D> points, Color color, boolean filled) {
            for (Point2D point : points) {
                addPoint(point, color, filled);
            }
        }

        /** Draws the x axis (red), y axis (green), origin (black) and then all queued points. */
        @Override
        protected void paintComponent(Graphics graphics) {
            super.paintComponent(graphics);
            graphics.setColor(Color.red);
            graphics.drawLine(x2u(1.5), y2v(0.0), x2u(0.0), y2v(0.0));
            graphics.setColor(Color.green);
            graphics.drawLine(x2u(0.0), y2v(1.5), x2u(0.0), y2v(0.0));
            graphics.setColor(Color.black);
            pointFill(graphics, 0.0, 0.0, POINT_SIZE);
            for (int i = 0; i < pointCloud.size(); i++) {
                graphics.setColor(pointCloudColors.get(i));
                Point2D point = pointCloud.get(i);
                if (fillers.get(i)) {
                    pointFill(graphics, point.getX(), point.getY(), POINT_SIZE);
                } else {
                    point(graphics, point.getX(), point.getY(), POINT_SIZE);
                }
            }
        }

        /** Draws a hollow circle of diameter {@code size} pixels centered on (x, y). */
        public void point(Graphics graphics, double x, double y, int size) {
            graphics.drawOval(x2u(x) - size / 2, y2v(y) - size / 2, size, size);
        }

        /** Draws a filled circle of diameter {@code size} pixels centered on (x, y). */
        public void pointFill(Graphics graphics, double x, double y, int size) {
            graphics.fillOval(x2u(x) - size / 2, y2v(y) - size / 2, size, size);
        }

        /** Converts a plane x coordinate to a horizontal pixel coordinate (0 at {@code xLower}). */
        public int x2u(double x) {
            return (int) Math.round((x - xLower) * SCALE);
        }

        /** Converts a plane y coordinate to a vertical pixel coordinate (0 at {@code yUpper}, growing downward). */
        public int y2v(double y) {
            return (int) Math.round((yUpper - y) * SCALE);
        }

        /** @return the preferred panel size in pixels covering the configured plane bounds. */
        public Dimension getDimension() {
            return new Dimension(sizeU, sizeV);
        }
    }

    /**
     * Creates the drawer and fills {@link #fullModel}, {@link #data1} and {@link #data2}.
     * <p>
     * No Swing window is created here (see {@link #showIfRequested()}), so tests that do not
     * visualize also run in headless environments.
     */
    private void setupPointCloud() {
        drawer = new XYPlaneDrawer(10.0, -10.0, 10.0, -10.0);

        fullModel.addAll(generatePointsOnEllipsoid(50, Math.toRadians(90.0), Math.toRadians(359.9), INNER_LONG_SEMI_AXIS, INNER_SHORT_SEMI_AXIS));
        fullModel.addAll(generatePointsOnEllipsoid(70, Math.toRadians(90.0), Math.toRadians(359.9), OUTER_LONG_SEMI_AXIS, OUTER_SHORT_SEMI_AXIS));
        fullModel.addAll(generatePointsOnLine(10, INNER_SHORT_SEMI_AXIS, OUTER_SHORT_SEMI_AXIS, 0.0, true));
        fullModel.addAll(generatePointsOnLine(10, INNER_LONG_SEMI_AXIS, OUTER_LONG_SEMI_AXIS, 0.0, false));

        data1.addAll(generatePointsOnEllipsoid(35, Math.toRadians(90.0), Math.toRadians(200.0), INNER_LONG_SEMI_AXIS, INNER_SHORT_SEMI_AXIS));
        data1.addAll(generatePointsOnEllipsoid(50, Math.toRadians(90.0), Math.toRadians(320.0), OUTER_LONG_SEMI_AXIS, OUTER_SHORT_SEMI_AXIS));
        data1.addAll(generatePointsOnLine(10, INNER_SHORT_SEMI_AXIS, OUTER_SHORT_SEMI_AXIS, 0.0, true));

        data2.addAll(generatePointsOnEllipsoid(50, Math.toRadians(130.0), Math.toRadians(359.9), INNER_LONG_SEMI_AXIS, INNER_SHORT_SEMI_AXIS));
        data2.addAll(generatePointsOnEllipsoid(60, Math.toRadians(120.0), Math.toRadians(359.9), OUTER_LONG_SEMI_AXIS, OUTER_SHORT_SEMI_AXIS));
        data2.addAll(generatePointsOnLine(10, INNER_LONG_SEMI_AXIS, OUTER_LONG_SEMI_AXIS, 0.0, false));
    }

    /**
     * When {@link #visualize} is set, opens a window showing {@link #drawer} and blocks forever.
     * The {@link JFrame} is only created here because constructing one throws in headless environments.
     */
    private void showIfRequested() {
        if (!visualize) {
            return;
        }
        JFrame frame = new JFrame("2D_LM_ICP_TEST");
        frame.setPreferredSize(drawer.getDimension());
        frame.setLocation(200, 100);
        frame.add(drawer);
        frame.pack();
        frame.setVisible(true);
        ThreadTools.sleepForever();
    }

    @Test
    public void testForwardBackwardsInputConverter() {
        double epsilon = 1.0e-7;
        Random random = new Random(1738L);

        // Converter without pitch/roll: translation and yaw must survive inverse + forward,
        // and a 4-element parameter vector must survive forward + inverse.
        Function<DMatrixRMaj, RigidBodyTransform> planarInputFunction = LevenbergMarquardtParameterOptimizer.createSpatialInputFunction(false);
        Function<RigidBodyTransform, DMatrixRMaj> inversePlanarInputFunction = LevenbergMarquardtParameterOptimizer.createInverseSpatialInputFunction(false);
        for (int i = 0; i < 1000; i++) {
            RigidBodyTransform expectedTransform = EuclidCoreRandomTools.nextRigidBodyTransform(random);
            RigidBodyTransform roundTripTransform = planarInputFunction.apply(inversePlanarInputFunction.apply(expectedTransform));
            EuclidCoreTestTools.assertVector3DGeometricallyEquals(expectedTransform.getTranslation(), roundTripTransform.getTranslation(), epsilon);
            double yawError = AngleTools.computeAngleDifferenceMinusPiToPi(expectedTransform.getRotation().getYaw(),
                                                                          roundTripTransform.getRotation().getYaw());
            // abs(): a large negative error must fail too.
            Assertions.assertTrue(Math.abs(yawError) < epsilon, "yaw mismatch after round trip: " + yawError);

            DMatrixRMaj expectedParameters = new DMatrixRMaj(4, 1);
            expectedParameters.set(0, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(1, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(2, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(3, RandomNumbers.nextDouble(random, Math.PI));
            EjmlUnitTests.assertEquals(expectedParameters, inversePlanarInputFunction.apply(planarInputFunction.apply(expectedParameters)), epsilon);
        }

        // Full spatial converter: the whole transform must survive inverse + forward, and the orientation
        // encoded by a 6-element parameter vector must survive forward + inverse.
        Function<DMatrixRMaj, RigidBodyTransform> spatialInputFunction = LevenbergMarquardtParameterOptimizer.createSpatialInputFunction(true);
        Function<RigidBodyTransform, DMatrixRMaj> inverseSpatialInputFunction = LevenbergMarquardtParameterOptimizer.createInverseSpatialInputFunction(true);
        for (int i = 0; i < 1000; i++) {
            RigidBodyTransform expectedTransform = EuclidCoreRandomTools.nextRigidBodyTransform(random);
            EuclidCoreTestTools.assertEquals(expectedTransform, spatialInputFunction.apply(inverseSpatialInputFunction.apply(expectedTransform)), epsilon);

            DMatrixRMaj expectedParameters = new DMatrixRMaj(6, 1);
            expectedParameters.set(0, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(1, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(2, RandomNumbers.nextDouble(random, 100.0));
            expectedParameters.set(3, RandomNumbers.nextDouble(random, Math.PI));
            expectedParameters.set(4, RandomNumbers.nextDouble(random, Math.PI));
            expectedParameters.set(5, RandomNumbers.nextDouble(random, Math.PI));
            DMatrixRMaj roundTripParameters = inverseSpatialInputFunction.apply(spatialInputFunction.apply(expectedParameters));
            EuclidCoreTestTools.assertOrientation3DGeometricallyEquals(new Quaternion(expectedParameters.get(3), expectedParameters.get(4), expectedParameters.get(5)),
                                                                       new Quaternion(roundTripParameters.get(3), roundTripParameters.get(4), roundTripParameters.get(5)),
                                                                       epsilon);
        }
    }

    @Test
    public void testVisualization() {
        setupPointCloud();
        transformPointCloud(data1, inputFunction.apply(createPlanarParameters(1.0, 2.0, Math.toRadians(30.0))));
        transformPointCloud(data2, inputFunction.apply(createPlanarParameters(-1.0, -2.0, Math.toRadians(-90.0))));
        drawer.addPointCloud(fullModel, Color.black, false);
        drawer.addPointCloud(data1, Color.red, false);
        drawer.addPointCloud(data2, Color.green, false);
        showIfRequested();
    }

    @Test
    public void testFindingClosestPointWithFullModel() {
        setupPointCloud();
        transformPointCloud(data1, inputFunction.apply(createPlanarParameters(0.3, 0.5, Math.toRadians(10.0))));
        drawer.addPointCloud(fullModel, Color.black, false);
        // Filled markers: drifted data points that still overlap the model (inliers).
        for (Point2D point : data1) {
            boolean isInlier = computeClosestDistance(point, fullModel) < INLIER_DISTANCE_THRESHOLD;
            drawer.addPoint(point, Color.red, isInlier);
        }
        showIfRequested();
    }

    @Test
    public void testErrorFunctionDerivationAndJacobian() {
        setupPointCloud();
        transformPointCloud(data1, inputFunction.apply(createPlanarParameters(0.3, 0.5, Math.toRadians(10.0))));
        drawer.addPointCloud(fullModel, Color.black, false);

        int numberOfPoints = data1.size();
        // Residual of each data point: distance to its closest model point (filled marker = inlier).
        DMatrixRMaj residuals = new DMatrixRMaj(numberOfPoints, 1);
        for (int i = 0; i < numberOfPoints; i++) {
            double distance = computeClosestDistance(data1.get(i), fullModel);
            residuals.set(i, distance);
            drawer.addPoint(data1.get(i), Color.red, distance < INLIER_DISTANCE_THRESHOLD);
        }

        DMatrixRMaj jacobian = computeNumericalJacobian(residuals);

        // Gauss-Newton direction: (J^T J)^-1 J^T r.
        DMatrixRMaj jacobianTranspose = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, numberOfPoints);
        CommonOps_DDRM.transpose(jacobian, jacobianTranspose);
        DMatrixRMaj normalMatrix = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, PLANAR_PARAMETER_DIMENSION);
        CommonOps_DDRM.mult(jacobianTranspose, jacobian, normalMatrix);
        Assertions.assertTrue(CommonOps_DDRM.invert(normalMatrix), "J^T J is singular, cannot compute the optimization direction.");
        DMatrixRMaj pseudoInverse = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, numberOfPoints);
        CommonOps_DDRM.mult(normalMatrix, jacobianTranspose, pseudoInverse);
        DMatrixRMaj direction = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, 1);
        CommonOps_DDRM.mult(pseudoInverse, residuals, direction);

        System.out.println("direction of the optimization is,");
        direction.print();
        // The drift was positive on every axis, so the direction must be positive (the step is its negation).
        Assertions.assertTrue(direction.get(0) > 0.0d, "direction of the translation x     is correct.");
        Assertions.assertTrue(direction.get(1) > 0.0d, "direction of the translation y     is correct.");
        Assertions.assertTrue(direction.get(2) > 0.0d, "direction of the translation theta is correct.");

        // Apply one Gauss-Newton step (negated direction) to a copy of the data for visualization.
        CommonOps_DDRM.scale(-1.0, direction);
        List<Point2D> correctedData = copyPointCloud(data1);
        transformPointCloud(correctedData, inputFunction.apply(direction));
        drawer.addPointCloud(correctedData, Color.green, true);
        showIfRequested();
    }

    @Test
    public void testIteration() {
        setupPointCloud();
        // Here the model is drifted while data1 stays in place; the optimizer must find the drift.
        DMatrixRMaj drift = createPlanarParameters(-0.3, 0.5, Math.toRadians(-30.0));
        transformPointCloud(fullModel, inputFunction.apply(drift));
        drawer.addPointCloud(fullModel, Color.black, false);
        drawer.addPointCloud(data1, Color.red, false);

        // Output of the optimized function: closest-point distance of every corrected data1 point.
        OutputCalculator outputCalculator = new OutputCalculator() {
            @Override
            public DMatrixRMaj apply(DMatrixRMaj parameters) {
                List<Point2D> correctedData = copyPointCloud(data1);
                transformPointCloud(correctedData, inputFunction.apply(parameters));
                DMatrixRMaj distances = new DMatrixRMaj(correctedData.size(), 1);
                for (int i = 0; i < correctedData.size(); i++) {
                    distances.set(i, computeClosestDistance(correctedData.get(i), fullModel));
                }
                return distances;
            }
        };
        LevenbergMarquardtParameterOptimizer optimizer = new LevenbergMarquardtParameterOptimizer(inputFunction,
                                                                                                   outputCalculator,
                                                                                                   PLANAR_PARAMETER_DIMENSION,
                                                                                                   data1.size());
        DMatrixRMaj perturbationVector = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, 1);
        CommonOps_DDRM.fill(perturbationVector, 1.0e-5);
        optimizer.setPerturbationVector(perturbationVector);

        boolean isSolved = false;
        for (int iteration = 0; iteration < 30; iteration++) {
            optimizer.iterate();
            if (optimizer.getQuality() < 0.4) {
                isSolved = true;
                break;
            }
            System.out.println(iteration + " " + optimizer.getQuality());
        }
        System.out.println("is solved? " + isSolved + " " + optimizer.getIteration() + " " + optimizer.getQuality());
        DMatrixRMaj optimalParameter = optimizer.getOptimalParameter();
        optimalParameter.print();

        List<Point2D> correctedData = copyPointCloud(data1);
        transformPointCloud(correctedData, inputFunction.apply(optimalParameter));

        // The optimal correction must move a reference point to where the drift moves it.
        Point2D referencePoint = new Point2D(0.0, INNER_LONG_SEMI_AXIS);
        Point2D driftedPoint = new Point2D(referencePoint);
        Point2D correctedPoint = new Point2D(referencePoint);
        transformPoint(driftedPoint, inputFunction.apply(drift));
        transformPoint(correctedPoint, inputFunction.apply(optimalParameter));
        Assertions.assertTrue(correctedPoint.distance(driftedPoint) < 0.06d,
                              "a point on the drifted doughnut corrected with icp. " + correctedPoint.distance(driftedPoint));

        drawer.addPointCloud(correctedData, Color.green, true);
        showIfRequested();
    }

    /**
     * Builds the forward-difference Jacobian of the closest-point residuals of {@link #data1} with respect
     * to a planar correction (x, y, yaw), linearized around the identity correction.
     * Entries whose perturbed distance is not below {@link #INLIER_DISTANCE_THRESHOLD} are left at zero.
     *
     * @param residuals unperturbed closest-point distance of every {@link #data1} point (column vector)
     * @return a {@code data1.size() x 3} Jacobian
     */
    private DMatrixRMaj computeNumericalJacobian(DMatrixRMaj residuals) {
        int numberOfPoints = data1.size();
        DMatrixRMaj jacobian = new DMatrixRMaj(numberOfPoints, PLANAR_PARAMETER_DIMENSION);
        for (int column = 0; column < PLANAR_PARAMETER_DIMENSION; column++) {
            DMatrixRMaj perturbedParameters = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, 1);
            perturbedParameters.set(column, FINITE_DIFFERENCE_STEP);
            List<Point2D> perturbedData = copyPointCloud(data1);
            transformPointCloud(perturbedData, inputFunction.apply(perturbedParameters));
            for (int row = 0; row < numberOfPoints; row++) {
                double perturbedDistance = computeClosestDistance(perturbedData.get(row), fullModel);
                if (perturbedDistance < INLIER_DISTANCE_THRESHOLD) {
                    jacobian.set(row, column, (perturbedDistance - residuals.get(row, 0)) / FINITE_DIFFERENCE_STEP);
                }
            }
        }
        return jacobian;
    }

    /** Converts a (x, y, yaw) column vector into a yaw rotation plus an XY translation (z = 0). */
    private static RigidBodyTransform toPlanarTransform(DMatrixRMaj parameters) {
        RigidBodyTransform transform = new RigidBodyTransform();
        transform.setRotationYawAndZeroTranslation(parameters.get(2));
        transform.getTranslation().set(parameters.get(0), parameters.get(1), 0.0);
        return transform;
    }

    /** Packs (x, y, yaw) into the 3x1 parameter vector understood by {@link #inputFunction}. */
    private static DMatrixRMaj createPlanarParameters(double x, double y, double yaw) {
        DMatrixRMaj parameters = new DMatrixRMaj(PLANAR_PARAMETER_DIMENSION, 1);
        parameters.set(0, x);
        parameters.set(1, y);
        parameters.set(2, yaw);
        return parameters;
    }

    /** @return a new list holding independent copies of every point, so transforming it leaves the source untouched. */
    private static List<Point2D> copyPointCloud(List<Point2D> pointCloud) {
        List<Point2D> copy = new ArrayList<>(pointCloud.size());
        for (Point2D point : pointCloud) {
            copy.add(new Point2D(point));
        }
        return copy;
    }

    /**
     * Brute-force nearest-neighbour distance from {@code point} to {@code pointCloud}.
     *
     * @return the smallest distance, or {@link Double#MAX_VALUE} when {@code pointCloud} is empty
     */
    private static double computeClosestDistance(Point2D point, List<Point2D> pointCloud) {
        double closestDistance = Double.MAX_VALUE;
        for (Point2D candidate : pointCloud) {
            double distance = point.distance(candidate);
            if (distance < closestDistance) {
                closestDistance = distance;
            }
        }
        return closestDistance;
    }

    /** Transforms every point of {@code pointCloud} in place. */
    private static void transformPointCloud(List<Point2D> pointCloud, RigidBodyTransform transform) {
        pointCloud.forEach(point -> transformPoint(point, transform));
    }

    /** Transforms {@code point} in place. */
    private static void transformPoint(Point2D point, RigidBodyTransform transform) {
        transform.transform(point);
    }

    /**
     * Value at sample {@code index} of {@code count} evenly spaced samples from {@code start} to {@code end}
     * (both included). A single sample is placed at {@code start} instead of dividing by zero.
     */
    private static double interpolate(double start, double end, int index, int count) {
        if (count <= 1) {
            return start;
        }
        return start + ((end - start) * index) / (count - 1);
    }

    /**
     * Generates evenly spaced points on an axis-aligned segment.
     *
     * @param numberOfPoints number of points, endpoints included
     * @param start          first coordinate along the segment
     * @param end            last coordinate along the segment
     * @param offset         constant coordinate across the segment
     * @param isVertical     {@code true}: x = offset and y varies; {@code false}: y = offset and x varies
     */
    private static List<Point2D> generatePointsOnLine(int numberOfPoints, double start, double end, double offset, boolean isVertical) {
        List<Point2D> points = new ArrayList<>();
        for (int i = 0; i < numberOfPoints; i++) {
            double alongSegment = interpolate(start, end, i, numberOfPoints);
            points.add(isVertical ? new Point2D(offset, alongSegment) : new Point2D(alongSegment, offset));
        }
        return points;
    }

    /**
     * Generates points on an axis-aligned ellipse centered at the origin, evenly spaced in polar angle.
     *
     * @param numberOfPoints number of points, both end angles included
     * @param startAngle     first polar angle in radians
     * @param endAngle       last polar angle in radians
     * @param semiAxisX      semi-axis along x
     * @param semiAxisY      semi-axis along y
     */
    private static List<Point2D> generatePointsOnEllipsoid(int numberOfPoints, double startAngle, double endAngle, double semiAxisX, double semiAxisY) {
        List<Point2D> points = new ArrayList<>();
        double a2 = semiAxisX * semiAxisX;
        double b2 = semiAxisY * semiAxisY;
        for (int i = 0; i < numberOfPoints; i++) {
            double angle = interpolate(startAngle, endAngle, i, numberOfPoints);
            double sin = Math.sin(angle);
            double cos = Math.cos(angle);
            // Polar form of the ellipse: r^2 = a^2 b^2 / (b^2 cos^2 + a^2 sin^2).
            double radius = Math.sqrt(a2 * semiAxisY * semiAxisY / (b2 - b2 * sin * sin + a2 * sin * sin));
            points.add(new Point2D(radius * cos, radius * sin));
        }
        return points;
    }
}
