package tokyo.peya.lib;

import java.util.Arrays;

/* loaded from: input_file:tokyo/peya/lib/LearnMath.class */
/**
 * Static helpers for simple neural-network code: scalar activation functions, their first
 * derivatives (methods suffixed {@code Def}), and softmax over vectors and matrices.
 *
 * <p>All methods are pure functions of their arguments.
 */
public class LearnMath {
    /** Negative-side slope used by the single-argument leaky-ReLU overloads. */
    private static final double DEFAULT_LRELU_SLOPE = 0.01d;

    /** Scale/alpha used by the single-argument ELU overloads. */
    private static final double DEFAULT_ELU_ALPHA = 1.0d;

    /**
     * Logistic sigmoid {@code 1 / (1 + e^-d)}.
     *
     * @param d input value
     * @return the sigmoid of {@code d}, in the closed range [0, 1]
     */
    public static double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * Derivative of {@link #sigmoid(double)}: {@code s * (1 - s)} where {@code s = sigmoid(d)}.
     *
     * @param d input value
     * @return the sigmoid derivative at {@code d}
     */
    public static double sigmoidDef(double d) {
        double s = sigmoid(d);
        return s * (1.0d - s);
    }

    /**
     * Swish activation {@code d * sigmoid(d)}.
     *
     * @param d input value
     * @return the swish of {@code d}
     */
    public static double swish(double d) {
        return d * sigmoid(d);
    }

    /**
     * Derivative of {@link #swish(double)}: {@code swish(d) + sigmoid(d) * (1 - swish(d))}.
     *
     * @param d input value
     * @return the swish derivative at {@code d}
     */
    public static double swishDef(double d) {
        double s = sigmoid(d);
        double sw = d * s;
        return sw + s * (1.0d - sw);
    }

    /**
     * Heaviside step function.
     *
     * @param d input value
     * @return {@code 1.0} when {@code d >= 0}, otherwise {@code 0.0} (including for NaN)
     */
    public static double step(double d) {
        return d >= 0.0d ? 1.0d : 0.0d;
    }

    /**
     * Rectified linear unit {@code max(d, 0)}.
     *
     * @param d input value
     * @return {@code d} when positive, otherwise {@code 0.0}
     */
    public static double relu(double d) {
        // Previously multiplied by d, which squared positive inputs.
        return Math.max(d, 0.0d);
    }

    /**
     * Derivative of {@link #relu(double)}; uses 0 as the subgradient at {@code d == 0}.
     *
     * @param d input value
     * @return {@code 1.0} when {@code d > 0}, otherwise {@code 0.0}
     */
    public static double reluDef(double d) {
        return d > 0.0d ? 1.0d : 0.0d;
    }

    /**
     * Leaky ReLU with a negative-side slope of 0.01.
     *
     * @param d input value
     * @return {@code d} when {@code d >= 0}, otherwise {@code 0.01 * d}
     */
    public static double lrelu(double d) {
        return lrelu(d, DEFAULT_LRELU_SLOPE);
    }

    /**
     * Leaky ReLU with a caller-chosen negative-side slope.
     *
     * @param d input value
     * @param d2 slope applied to negative inputs
     * @return {@code d} when {@code d >= 0}, otherwise {@code d2 * d}
     */
    public static double lrelu(double d, double d2) {
        return d >= 0.0d ? d : d2 * d;
    }

    /**
     * Derivative of {@link #lrelu(double)}.
     *
     * @param d input value
     * @return {@code 1.0} when {@code d >= 0}, otherwise {@code 0.01}
     */
    public static double lreluDef(double d) {
        return lreluDef(d, DEFAULT_LRELU_SLOPE);
    }

    /**
     * Derivative of {@link #lrelu(double, double)}.
     *
     * @param d input value
     * @param d2 slope applied to negative inputs
     * @return {@code 1.0} when {@code d >= 0}, otherwise {@code d2}
     */
    public static double lreluDef(double d, double d2) {
        // The branches were previously inverted (slope on the positive side, 0.01 on the negative).
        return d >= 0.0d ? 1.0d : d2;
    }

    /**
     * Exponential linear unit with alpha = 1.
     *
     * @param d input value
     * @return {@code d} when {@code d > 0}, otherwise {@code e^d - 1}
     */
    public static double elu(double d) {
        return elu(d, DEFAULT_ELU_ALPHA);
    }

    /**
     * Exponential linear unit with a caller-chosen alpha.
     *
     * @param d input value
     * @param d2 alpha, the scale of the negative branch
     * @return {@code d} when {@code d > 0}, otherwise {@code d2 * (e^d - 1)}
     */
    public static double elu(double d, double d2) {
        // expm1 keeps precision for inputs close to zero.
        return d > 0.0d ? d : d2 * Math.expm1(d);
    }

    /**
     * Derivative of {@link #elu(double)}.
     *
     * @param d input value
     * @return {@code 1.0} when {@code d > 0}, otherwise {@code e^d}
     */
    public static double eluDef(double d) {
        return eluDef(d, DEFAULT_ELU_ALPHA);
    }

    /**
     * Derivative of {@link #elu(double, double)}.
     *
     * @param d input value
     * @param d2 alpha, the scale of the negative branch
     * @return {@code 1.0} when {@code d > 0}, otherwise {@code d2 * e^d}
     */
    public static double eluDef(double d, double d2) {
        if (d > 0.0d) {
            return 1.0d;
        }
        // Equivalent to elu(d, d2) + d2, but avoids the cancellation (e^d - 1) + 1 that
        // collapses to 0 for very negative d.
        return d2 * Math.exp(d);
    }

    /**
     * Scaled exponential linear unit: {@code d2 * (d > 0 ? d : d3 * (e^d - 1))}.
     *
     * @param d input value
     * @param d2 lambda, the overall output scale
     * @param d3 alpha, the scale of the negative branch
     * @return the SELU of {@code d}
     */
    public static double selu(double d, double d2, double d3) {
        // Previously "d2 * d > 0 ? ..." applied lambda only inside the condition.
        return d2 * (d > 0.0d ? d : d3 * Math.expm1(d));
    }

    /**
     * Derivative of {@link #selu(double, double, double)}.
     *
     * @param d input value
     * @param d2 lambda, the overall output scale
     * @param d3 alpha, the scale of the negative branch
     * @return {@code d2} when {@code d > 0}, otherwise {@code d2 * d3 * e^d}
     */
    public static double seluDef(double d, double d2, double d3) {
        return d2 * (d > 0.0d ? 1.0d : d3 * Math.exp(d));
    }

    /**
     * Hyperbolic tangent.
     *
     * @param d input value
     * @return {@code tanh(d)}, in [-1, 1]
     */
    public static double tanH(double d) {
        // The explicit (e^d - e^-d) / (e^d + e^-d) form yields inf/inf = NaN for |d| > ~710.
        return Math.tanh(d);
    }

    /**
     * Derivative of {@link #tanH(double)}: {@code 1 - tanh(d)^2}.
     *
     * @param d input value
     * @return the tanh derivative at {@code d}
     */
    public static double tanHDef(double d) {
        double t = Math.tanh(d);
        return 1.0d - t * t;
    }

    /**
     * Softplus {@code ln(1 + e^d)}, evaluated in an overflow-free form.
     *
     * @param d input value
     * @return the softplus of {@code d}
     */
    public static double softplus(double d) {
        // ln(1 + e^d) == max(d, 0) + ln(1 + e^-|d|); the right side never overflows.
        return Math.max(d, 0.0d) + Math.log1p(Math.exp(-Math.abs(d)));
    }

    /**
     * Derivative of {@link #softplus(double)}, which equals {@link #sigmoid(double)}.
     *
     * @param d input value
     * @return the softplus derivative at {@code d}
     */
    public static double softplusDef(double d) {
        return sigmoid(d);
    }

    /**
     * Auxiliary term omega of the closed-form Mish derivative:
     * {@code 4(d + 1) + 4e^(2d) + e^(3d) + e^d (4d + 6)}.
     *
     * <p>Overflows to infinity for large {@code d} (e^(3d) overflows above roughly d = 236).
     *
     * @param d input value
     * @return omega(d)
     */
    public static double omega(double d) {
        return (4.0d * (d + 1.0d)) + (4.0d * Math.exp(2.0d * d)) + Math.exp(3.0d * d) + (Math.exp(d) * ((4.0d * d) + 6.0d));
    }

    /**
     * Auxiliary term delta of the closed-form Mish derivative: {@code 2e^d + e^(2d) + 2}.
     *
     * @param d input value
     * @return delta(d)
     */
    public static double delta(double d) {
        return (2.0d * Math.exp(d)) + Math.exp(2.0d * d) + 2.0d;
    }

    /**
     * Mish activation {@code d * tanh(softplus(d))}.
     *
     * @param d input value
     * @return the Mish of {@code d}
     */
    public static double mish(double d) {
        return d * Math.tanh(softplus(d));
    }

    /**
     * Derivative of {@link #mish(double)}.
     *
     * <p>Mathematically equal to {@code e^d * omega(d) / delta(d)^2}. It is computed as
     * {@code tanh(sp) + d * (1 - tanh(sp)^2) * sigmoid(d)} with {@code sp = softplus(d)},
     * because the omega/delta form overflows to NaN for large {@code d}.
     *
     * @param d input value
     * @return the Mish derivative at {@code d}
     */
    public static double mishDef(double d) {
        double t = Math.tanh(softplus(d));
        return t + d * (1.0d - t * t) * sigmoid(d);
    }

    /**
     * Identity activation.
     *
     * @param d input value
     * @return {@code d} unchanged
     */
    public static double identity(double d) {
        return d;
    }

    /**
     * Derivative of {@link #identity(double)}.
     *
     * @return always {@code 1.0}
     */
    public static double identityDef() {
        return 1.0d;
    }

    /**
     * Derivative of {@link #identity(double)}, with the same signature shape as the other
     * {@code *Def(double)} methods.
     *
     * @param d input value (ignored)
     * @return always {@code 1.0}
     */
    public static double identityDef(double d) {
        return 1.0d;
    }

    /**
     * Softmax of a vector. The maximum element is subtracted before exponentiation for
     * numerical stability; this does not change the result.
     *
     * @param dArr input values (must not be {@code null})
     * @return a new array of the same length whose entries sum to 1; empty for empty input
     */
    public static double[] softmax(double[] dArr) {
        if (dArr.length == 0) {
            return new double[0];
        }
        // Max and sum are computed once; previously both were recomputed per element (O(n^2)).
        double max = Arrays.stream(dArr).max().getAsDouble();
        double[] exps = Arrays.stream(dArr).map(v -> Math.exp(v - max)).toArray();
        double sum = Arrays.stream(exps).sum();
        return Arrays.stream(exps).map(e -> e / sum).toArray();
    }

    /**
     * Row-wise softmax of a matrix: each row is normalised independently with
     * {@link #softmax(double[])}.
     *
     * @param dArr input rows (neither the array nor any row may be {@code null})
     * @return a new matrix with the same row count, each row summing to 1
     */
    public static double[][] softmax(double[][] dArr) {
        double[][] result = new double[dArr.length][];
        Arrays.setAll(result, i -> softmax(dArr[i]));
        return result;
    }
}
