package us.ihmc.robotics.functionApproximation.NeuralNetwork.importing;

import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import org.yaml.snakeyaml.Yaml;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.NeuralNetwork;

/* loaded from: input_file:us/ihmc/robotics/functionApproximation/NeuralNetwork/importing/NeuralNetworkYamlHelper.class */
/**
 * Static helpers for reading a {@link NeuralNetwork} from YAML and for writing a
 * {@link NeuralNetworkConfiguration} to a YAML file using SnakeYAML.
 *
 * <p>This class only holds static methods and cannot be instantiated.
 */
public class NeuralNetworkYamlHelper {
    private NeuralNetworkYamlHelper() {
        // Utility class: no instances.
    }

    /**
     * Builds a {@link NeuralNetwork} from the YAML document read from the given stream.
     *
     * <p>The stream is not closed by this method; the caller owns it.
     *
     * @param inputStream stream positioned at a YAML document that deserializes to a
     *                    {@link NeuralNetworkConfiguration}
     * @return a network constructed from the loaded configuration
     * @throws ClassCastException if the loaded document is not a {@link NeuralNetworkConfiguration}
     */
    public static NeuralNetwork createNeuralNetworkFromYamlFile(InputStream inputStream) {
        // SECURITY NOTE(review): Yaml.load() with a default-constructed Yaml can instantiate
        // types named by tags in the document. Only feed this trusted input, or restrict the
        // loader to NeuralNetworkConfiguration - confirm against the SnakeYAML version in use.
        Object loaded = new Yaml().load(inputStream);
        return new NeuralNetwork((NeuralNetworkConfiguration) loaded);
    }

    /**
     * Serializes the given configuration to YAML and writes it to the file at {@code str},
     * replacing any existing content.
     *
     * <p>I/O failures are not propagated: the stack trace is printed and the method returns.
     *
     * @param neuralNetworkConfiguration configuration to serialize
     * @param str                        path of the file to write
     */
    public static void saveNeuralNetworkConfigurationToYamlFile(NeuralNetworkConfiguration neuralNetworkConfiguration, String str) {
        String yamlText = new Yaml().dump(neuralNetworkConfiguration);
        // try-with-resources guarantees the writer is closed even when write() throws.
        try (FileWriter writer = new FileWriter(str)) {
            writer.write(yamlText);
        } catch (IOException e) {
            // Best-effort save, as before: report the failure but do not abort the caller.
            e.printStackTrace();
        }
    }

    /**
     * Round-trip demo: builds a hard-coded three-layer configuration (16, 4 and 1 neurons),
     * saves it to {@code testNNParam.yaml}, loads it back, evaluates one input vector and
     * prints the single output value.
     *
     * @param strArr unused command-line arguments
     * @throws IOException if the saved file cannot be opened or closed
     */
    public static void main(String[] strArr) throws IOException {
        NeuralNetworkConfiguration neuralNetworkConfiguration = new NeuralNetworkConfiguration();

        // NOTE(review): 14 names are listed here while the input layer below has 16 neurons
        // and the input vector has 16 values - confirm whether two names are missing.
        neuralNetworkConfiguration.setInputVariableNames(new String[]{"LEFT_ANKLE_PITCH_joint_q", "LEFT_ANKLE_PITCH_joint_q_previous", "LEFT_ANKLE_PITCH_joint_qd", "LEFT_ANKLE_PITCH_joint_qd_previous", "LEFT_ANKLE_PITCH_joint_qdd", "LEFT_ANKLE_PITCH_joint_qdd_previous", "LEFT_ANKLE_PITCH_tauMeasured_previous", "LEFT_KNEE_PITCH_joint_q", "LEFT_KNEE_PITCH_joint_q_previous", "LEFT_KNEE_PITCH_joint_qd", "LEFT_KNEE_PITCH_joint_qd_previous", "LEFT_KNEE_PITCH_joint_qdd", "LEFT_KNEE_PITCH_joint_qdd_previous", "LEFT_KNEE_PITCH_tauMeasured_previous"});
        neuralNetworkConfiguration.setActivationFunctionsPerLayer(new String[]{"", "RELU", ""});

        // One bias array per layer (lengths 16, 4, 1).
        double[][] bias = new double[][]{
            new double[16],
            new double[]{0.8374003d, 0.03701664d, -0.28729317d, 0.01319704d},
            new double[]{-0.3590592d}
        };
        neuralNetworkConfiguration.setBias(bias);
        neuralNetworkConfiguration.setNumberOfNeuronsPerLayer(new int[]{16, 4, 1});

        // One weight matrix per layer: 16x1 (all zeros), 4x16, 1x4.
        double[][][] weights = new double[][][]{
            new double[16][1],
            new double[][]{
                new double[]{-0.46510178d, -0.05915822d, 0.3717158d, -0.45220408d, 0.1210492d, -0.12430278d, -0.29257497d, 0.3202149d, 0.10641811d, 0.45007122d, 0.24379735d, -0.13266943d, -0.1242077d, 0.12984337d, 0.62552196d, 0.3651869d},
                new double[]{-0.03552369d, 0.7017831d, -0.16389109d, 0.2945423d, -0.10112676d, 0.10519658d, 0.7274673d, -0.2770513d, -0.5848242d, -0.08602782d, 0.1309489d, -0.2813644d, 0.09275991d, -0.099305d, 1.612328d, -0.42894414d},
                new double[]{0.15122245d, 0.5346764d, 0.01194376d, 0.33071777d, -0.05709987d, 0.05674522d, -0.5670206d, 0.0768759d, -0.31645796d, -0.13648617d, 0.38817182d, -0.31028637d, 0.02570818d, -0.03258159d, 0.1956599d, -0.17547709d},
                new double[]{-0.50018334d, -0.5318647d, -0.24139157d, 0.05813099d, 0.15277725d, -0.15938585d, -0.71538013d, 0.7469889d, 0.17892408d, 0.76543343d, 0.12041978d, 0.10104298d, -0.15151536d, 0.160927d, -2.4808238d, 0.554581d}
            },
            new double[][]{
                new double[]{0.46576336d, 0.41093794d, -0.01115055d, -0.27883196d}
            }
        };
        neuralNetworkConfiguration.setWeights(weights);

        saveNeuralNetworkConfigurationToYamlFile(neuralNetworkConfiguration, "testNNParam.yaml");

        // Reload what was just written; close the file stream once the network is built.
        NeuralNetwork network;
        try (InputStream yamlStream = new FileInputStream("testNNParam.yaml")) {
            network = createNeuralNetworkFromYamlFile(yamlStream);
        }

        network.setInput(new double[]{0.7041736d, 0.7041736d, -0.82175085d, -0.7660606d, -55.92842308d, -55.55744086d, -0.19960167d, 0.20344828d, 0.48069055d, 0.48069055d, 0.07482435d, 0.0718136d, 3.15835607d, 2.86869009d, 1.14953754d, -1.47241379d});
        double[] output = new double[1];
        network.compute(output);
        System.out.println(output[0]);
    }
}
