package us.ihmc.perception.neural;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_highgui;
import org.bytedeco.opencv.global.opencv_imgproc;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Size;
import us.ihmc.commons.Conversions;
import us.ihmc.log.LogTools;
import us.ihmc.perception.logging.PerceptionDataLoader;
import us.ihmc.perception.logging.PerceptionLoggerConstants;
import us.ihmc.perception.tools.PerceptionDebugTools;
import us.ihmc.tools.IHMCCommonPaths;
import us.ihmc.tools.io.WorkspaceFile;
import us.ihmc.tools.io.WorkspaceResourceDirectory;

/* loaded from: input_file:us/ihmc/perception/neural/HeightMapAutoencoder.class */
/**
 * Runs the ONNX model {@code height_map_autoencoder.onnx} (loaded from the {@code /weights/} workspace
 * resource directory) on single-channel {@value #IMAGE_HEIGHT}x{@value #IMAGE_WIDTH} height map images.
 *
 * <p>The input image is converted to float, divided by {@link #HEIGHT_SCALE} and shifted by a caller-supplied
 * offset before inference; the network output is shifted back, multiplied by {@link #HEIGHT_SCALE} and returned
 * as a {@code CV_16UC1} {@link Mat}.
 */
public class HeightMapAutoencoder {
    private static final int IMAGE_HEIGHT = 201;
    private static final int IMAGE_WIDTH = 201;
    /**
     * Raw pixel values are divided by this before inference, and network outputs are multiplied by it afterwards.
     * NOTE(review): presumably the fixed-point scale of the 16-bit height map encoding — confirm.
     */
    private static final float HEIGHT_SCALE = 10000.0f;

    private final WorkspaceResourceDirectory modelDirectory = new WorkspaceResourceDirectory(getClass(), "/weights/");
    private final WorkspaceFile onnxFile = new WorkspaceFile(this.modelDirectory, "height_map_autoencoder.onnx");
    private final OrtEnvironment onnxRuntimeEnvironment = OrtEnvironment.getEnvironment();
    private final OrtSession.SessionOptions onnxRuntimeSessionOptions = new OrtSession.SessionOptions();
    private final OrtSession onnxRuntimeSession;

    /**
     * Creates the ONNX Runtime session for the autoencoder model and logs the model's input and output node metadata.
     *
     * @throws RuntimeException wrapping the {@link OrtException} if the session cannot be created
     */
    public HeightMapAutoencoder() {
        try {
            this.onnxRuntimeSession = this.onnxRuntimeEnvironment.createSession(this.onnxFile.getFilesystemFile().toString(),
                                                                                this.onnxRuntimeSessionOptions);
            for (Map.Entry<String, NodeInfo> input : this.onnxRuntimeSession.getInputInfo().entrySet()) {
                LogTools.info("Input Name: {}", input.getKey());
                LogTools.info("Input Info: {}", input.getValue().toString());
            }
            for (Map.Entry<String, NodeInfo> output : this.onnxRuntimeSession.getOutputInfo().entrySet()) {
                LogTools.info("{}: {}", output.getKey(), output.getValue());
            }
        } catch (OrtException e) {
            throw new RuntimeException("Failed to create ONNX Runtime session for " + this.onnxFile.getFilesystemFile(), e);
        }
    }

    /**
     * Runs {@link #predict(Mat, float)} and logs the inference time at debug level.
     *
     * @param mat input height map; must be {@value #IMAGE_HEIGHT}x{@value #IMAGE_WIDTH}
     * @param f   offset subtracted from the scaled input and added back to the network output
     * @return a new {@code CV_16UC1} height map produced by the network
     * @throws RuntimeException wrapping any {@link OrtException} raised during inference
     */
    public Mat denoiseHeightMap(Mat mat, float f) {
        try {
            long startNanos = System.nanoTime();
            Mat prediction = predict(mat, f);
            LogTools.debug("Inference time: {} ms", Conversions.nanosecondsToMilliseconds(System.nanoTime() - startNanos));
            return prediction;
        } catch (OrtException e) {
            throw new RuntimeException("Height map autoencoder inference failed", e);
        }
    }

    /**
     * Runs one forward pass of the autoencoder.
     *
     * @param mat input height map; must be {@value #IMAGE_HEIGHT}x{@value #IMAGE_WIDTH}. It is converted to
     *            {@code CV_32FC1} internally and is not modified.
     * @param f   offset subtracted from each scaled input value and added back to each output value
     * @return a newly allocated {@code CV_16UC1} {@link Mat} of the same size; the caller owns it
     * @throws RuntimeException if the input size or the model's output shape is not as expected
     * @throws OrtException     if ONNX Runtime fails to create the tensor or run the session
     */
    public Mat predict(Mat mat, float f) throws OrtException {
        if (mat.rows() != IMAGE_HEIGHT || mat.cols() != IMAGE_WIDTH) {
            throw new RuntimeException("Image height and width must be " + IMAGE_HEIGHT + " and " + IMAGE_WIDTH
                                       + " (got " + mat.rows() + "x" + mat.cols() + ")");
        }
        LogTools.debug("Image Input Size: {}x{}", mat.rows(), mat.cols());

        FloatBuffer inputBuffer = FloatBuffer.allocate(IMAGE_HEIGHT * IMAGE_WIDTH);
        // convertTo writes into a separate Mat, so the caller's image is left untouched; the temporary is freed here.
        try (Mat floatInput = new Mat()) {
            mat.convertTo(floatInput, opencv_core.CV_32FC1, 1.0, 0.0);
            for (int row = 0; row < IMAGE_HEIGHT; row++) {
                for (int col = 0; col < IMAGE_WIDTH; col++) {
                    inputBuffer.put(floatInput.ptr(row, col).getFloat() / HEIGHT_SCALE - f);
                }
            }
        }
        inputBuffer.rewind();

        String inputName = this.onnxRuntimeSession.getInputNames().iterator().next();
        long[] inputShape = {1, 1, IMAGE_HEIGHT, IMAGE_WIDTH};
        float[][] predictedHeights;
        // Both the input tensor and the run result hold native memory and must be closed after every call.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(this.onnxRuntimeEnvironment, inputBuffer, inputShape)) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put(inputName, inputTensor);
            try (OrtSession.Result result = this.onnxRuntimeSession.run(inputs)) {
                Map.Entry<String, ?> firstOutput = result.iterator().next();
                LogTools.debug("{}: {}", firstOutput.getKey(), firstOutput.getValue());
                OnnxTensor outputTensor = (OnnxTensor) firstOutput.getValue();
                LogTools.debug("Output: {}", outputTensor.getInfo());

                // Validate the shape before casting so a mismatched model fails with a clear message
                // instead of a ClassCastException or ArrayIndexOutOfBoundsException.
                long[] outputShape = outputTensor.getInfo().getShape();
                if (outputShape.length != 4 || outputShape[0] < 1 || outputShape[1] < 1
                    || outputShape[2] != IMAGE_HEIGHT || outputShape[3] != IMAGE_WIDTH) {
                    throw new RuntimeException("Output size must be " + (IMAGE_HEIGHT * IMAGE_WIDTH) + " (expected shape [N, C, "
                                               + IMAGE_HEIGHT + ", " + IMAGE_WIDTH + "]) but got " + outputTensor.getInfo());
                }
                // getValue() copies into Java arrays, so the data stays valid after the result is closed.
                float[][][][] outputValues = (float[][][][]) outputTensor.getValue();
                predictedHeights = outputValues[0][0];
            }
        }

        Mat output = new Mat(IMAGE_HEIGHT, IMAGE_WIDTH, opencv_core.CV_32FC1);
        for (int row = 0; row < IMAGE_HEIGHT; row++) {
            for (int col = 0; col < IMAGE_WIDTH; col++) {
                output.ptr(row, col).putFloat((predictedHeights[row][col] + f) * HEIGHT_SCALE);
            }
        }
        output.convertTo(output, opencv_core.CV_16UC1, 1.0, 0.0);
        return output;
    }

    /**
     * Debug harness: loads cropped height maps from a perception log, runs the autoencoder on each, and shows the
     * colorized input and prediction side by side. Press {@code q} to quit; any other key advances.
     */
    public static void main(String[] strArr) throws OrtException {
        String path = IHMCCommonPaths.PERCEPTION_LOGS_DIRECTORY.resolve("20231023_131517_PerceptionLog.hdf5").toString();
        HeightMapAutoencoder heightMapAutoencoder = new HeightMapAutoencoder();
        PerceptionDataLoader perceptionDataLoader = new PerceptionDataLoader();
        perceptionDataLoader.openLogFile(path);
        BytePointer compressedBuffer = new BytePointer(1000000L);
        Mat heightMap = new Mat(IMAGE_HEIGHT, IMAGE_WIDTH, opencv_core.CV_16UC1);
        for (int index = 1; index < 100; index++) {
            perceptionDataLoader.loadCompressedDepth(PerceptionLoggerConstants.CROPPED_HEIGHT_MAP_NAME, index, compressedBuffer, heightMap);
            long startNanos = System.nanoTime();
            Mat prediction = heightMapAutoencoder.predict(heightMap, 0.0f);
            LogTools.info("Inference time: {} ms", Conversions.nanosecondsToMilliseconds(System.nanoTime() - startNanos));

            Mat predictionCopy = prediction.clone();
            Mat predictionColor = new Mat(predictionCopy.rows(), predictionCopy.cols(), opencv_core.CV_8UC3);
            PerceptionDebugTools.convertDepthCopyToColor(predictionCopy, predictionColor);

            Mat inputCopy = heightMap.clone();
            Mat inputColor = new Mat(inputCopy.rows(), inputCopy.cols(), opencv_core.CV_8UC3);
            PerceptionDebugTools.convertDepthCopyToColor(inputCopy, inputColor);

            Mat sideBySide = new Mat(inputColor.rows(), inputColor.cols() * 2, opencv_core.CV_8UC3);
            opencv_core.hconcat(inputColor, predictionColor, sideBySide);
            opencv_core.convertScaleAbs(sideBySide, sideBySide, 4.0d, 0.0d);
            opencv_imgproc.resize(sideBySide, sideBySide, new Size(2000, 1000));
            opencv_highgui.imshow("Display", sideBySide);
            if (opencv_highgui.waitKeyEx(0) == 'q') {
                return;
            }
        }
    }
}
