package greycatMLTest.neuralnet;

import greycat.ml.neuralnet.activation.Activation;
import greycat.ml.neuralnet.activation.Activations;
import greycat.ml.neuralnet.loss.Loss;
import greycat.ml.neuralnet.loss.Losses;
import greycat.ml.neuralnet.process.ExMatrix;
import greycat.ml.neuralnet.process.ProcessGraph;
import greycat.struct.DMatrix;
import greycat.struct.matrix.VolatileDMatrix;
import org.junit.Test;

/**
 * Verifies one forward pass, one backward pass and one plain gradient-descent step of
 * {@link ProcessGraph} on a small two-layer network (2 inputs, 2 hidden units, 2 outputs)
 * against precomputed reference values.
 *
 * <p>The inputs, targets and initial parameters appear to be those of the widely circulated
 * step-by-step backpropagation walkthrough (inputs 0.05/0.10, targets 0.01/0.99, initial total
 * error ~0.29837). The expected numbers below are pinned to what this implementation produces.
 */
public class TestCalcGraph {

    /**
     * Absolute tolerance used by {@link #assertClose(double, double)}.
     *
     * <p>One ulp of a double in [0.5, 1) is about 1.1e-16. For such values this tolerance
     * therefore effectively demands bit-exact agreement.
     * NOTE(review): relax it if the implementation legitimately reorders floating-point
     * arithmetic.
     */
    private static final double EPS = 1.0E-16d;

    @Test
    public void testcalc() {
        // Network input and target, both 2x1 column vectors.
        ExMatrix input = matrixOf(2, 1, 0.05d, 0.1d);
        ExMatrix target = matrixOf(2, 1, 0.01d, 0.99d);

        // Layer parameters; the 2x2 weight matrices are listed in row-major order.
        ExMatrix hiddenWeights = matrixOf(2, 2, 0.15d, 0.2d, 0.25d, 0.3d);
        ExMatrix hiddenBias = matrixOf(2, 1, 0.35d, 0.35d);
        ExMatrix outputWeights = matrixOf(2, 2, 0.4d, 0.45d, 0.5d, 0.55d);
        ExMatrix outputBias = matrixOf(2, 1, 0.6d, 0.6d);

        // NOTE(review): loss id 0 and activation id 1 are opaque here. The expected values
        // are consistent with a half sum-of-squares loss and a logistic sigmoid. Confirm
        // against Losses/Activations and prefer their named constants if they exist.
        Loss loss = Losses.getUnit(0);
        Activation activation = Activations.getUnit(1, (double[]) null);
        // NOTE(review): the meaning of this constructor flag is not visible here; presumably
        // it enables recording for backpropagate() -- confirm.
        ProcessGraph graph = new ProcessGraph(true);

        // Forward pass. Keep the operation order (hidden layer, then output layer, then loss)
        // in case the graph records operations for backpropagation.
        ExMatrix hidden =
                graph.activate(activation, graph.add(graph.mul(hiddenWeights, input), hiddenBias));
        ExMatrix output =
                graph.activate(activation, graph.add(graph.mul(outputWeights, hidden), outputBias));
        assertClose(
                Losses.sumOfLosses(graph.applyLoss(loss, output, target, true)),
                0.2983711087600027d);

        graph.backpropagate();

        // One gradient-descent step over every trainable parameter.
        double learningRate = 0.5d;
        applyLearningRate(hiddenWeights, learningRate);
        applyLearningRate(outputWeights, learningRate);
        applyLearningRate(hiddenBias, learningRate);
        applyLearningRate(outputBias, learningRate);

        assertClose(hiddenWeights.get(0, 0), 0.1497807161327628d);
        assertClose(hiddenWeights.get(0, 1), 0.19956143226552567d);
        assertClose(hiddenWeights.get(1, 0), 0.24975114363236958d);
        assertClose(hiddenWeights.get(1, 1), 0.29950228726473915d);
        assertClose(hiddenBias.get(0, 0), 0.3456143226552565d);
        assertClose(hiddenBias.get(1, 0), 0.3450228726473914d);
        assertClose(outputWeights.get(0, 0), 0.35891647971788465d);
        assertClose(outputWeights.get(0, 1), 0.4086661860762334d);
        assertClose(outputWeights.get(1, 0), 0.5113012702387375d);
        assertClose(outputWeights.get(1, 1), 0.5613701211079891d);
        assertClose(outputBias.get(0, 0), 0.5307507191857215d);
        assertClose(outputBias.get(1, 0), 0.6190491182582781d);
    }

    /**
     * Fails unless {@code actual} is within {@link #EPS} of {@code expected}.
     *
     * <p>The check is written as {@code !(diff <= EPS)} rather than {@code diff > EPS}. Every
     * comparison involving NaN is false, so the negated form makes a NaN result fail instead
     * of silently passing.
     *
     * @param actual   value produced by the code under test
     * @param expected reference value
     * @throws AssertionError if the values differ by more than {@link #EPS}, or either is NaN
     */
    private static void assertClose(double actual, double expected) {
        double diff = Math.abs(actual - expected);
        if (!(diff <= EPS)) {
            throw new AssertionError("expected " + expected + " but was " + actual
                    + " (|diff| = " + diff + ", EPS = " + EPS + ")");
        }
    }

    /**
     * Wraps a freshly filled {@link VolatileDMatrix} in an {@link ExMatrix} via
     * {@code ExMatrix.createFromW}.
     *
     * @param rows   number of rows
     * @param cols   number of columns
     * @param values matrix entries in row-major order
     * @return the wrapped matrix
     * @throws IllegalArgumentException if {@code values.length != rows * cols}
     */
    private static ExMatrix matrixOf(int rows, int cols, double... values) {
        if (values.length != rows * cols) {
            throw new IllegalArgumentException("expected " + (rows * cols) + " values for a "
                    + rows + "x" + cols + " matrix, got " + values.length);
        }
        VolatileDMatrix weights = VolatileDMatrix.empty(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                weights.set(r, c, values[r * cols + c]);
            }
        }
        return ExMatrix.createFromW(weights);
    }

    /**
     * Applies one plain gradient-descent step in place and then clears the gradient.
     *
     * <p>For every element, {@code w[i] -= learningRate * dw[i]}, where {@code dw} is the
     * matrix returned by {@code getDw()} (presumably the gradient written by
     * {@link ProcessGraph#backpropagate()}). Afterwards {@code dw} is zero-filled so it cannot
     * leak into a later step.
     *
     * @param exMatrix     parameter matrix to update
     * @param learningRate step size
     */
    private static void applyLearningRate(ExMatrix exMatrix, double learningRate) {
        int length = exMatrix.length();
        DMatrix dw = exMatrix.getDw();
        for (int i = 0; i < length; i++) {
            exMatrix.unsafeSet(i, exMatrix.unsafeGet(i) - learningRate * dw.unsafeGet(i));
        }
        dw.fill(0.0d);
    }
}
