package org.nd4j.linalg.ops.transforms;

import java.util.ArrayList;
import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.api.ops.impl.transforms.Ceil;
import org.nd4j.linalg.api.ops.impl.transforms.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.Floor;
import org.nd4j.linalg.api.ops.impl.transforms.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.Identity;
import org.nd4j.linalg.api.ops.impl.transforms.Log;
import org.nd4j.linalg.api.ops.impl.transforms.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.Pow;
import org.nd4j.linalg.api.ops.impl.transforms.Round;
import org.nd4j.linalg.api.ops.impl.transforms.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.Stabilize;
import org.nd4j.linalg.api.ops.impl.transforms.Tanh;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.LessThanOrEqual;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/ops/transforms/Transforms.class */
/**
 * Static helpers for element-wise transforms, simple pooling / resampling, and a few vector
 * utilities on {@link INDArray}s.
 *
 * <p>Most transforms come in two overloads. The one-argument form defers to the global
 * {@code Nd4j.copyOnOps} flag. The other takes an explicit {@code dup} flag. When {@code dup}
 * is {@code true}, the op is constructed against a {@code dup()} of the input. When it is
 * {@code false}, the input array itself is handed to the op; presumably the op then writes in
 * place, but confirm against the individual op constructors.
 */
public class Transforms {

    /**
     * Max-pools the two trailing (spatial) dimensions of a 4-d array.
     *
     * <p>The input is viewed as {@code [size(0) * size(1), 1, rows, cols]}. Result cell
     * {@code (r, c)} holds the maximum of the scanned input cells {@code (i, j)} with
     * {@code i / ds[0] == r} and {@code j / ds[1] == c}. Cells that no scanned input cell maps
     * to stay zero.
     *
     * @param input     4-d array to pool; must have at least 2 elements
     * @param ds        down-sampling factors {@code {rowFactor, colFactor}}
     * @param limitScan when {@code true}, only the first {@code (int) (rows / ds[0]^2)} rows and
     *                  {@code (int) (cols / ds[1]^2)} columns are scanned (at least 1 of each);
     *                  when {@code false}, every row and column is scanned
     * @return a new array of shape {@code [size(0) * size(1), 1, rows, cols]}
     */
    public static INDArray maxPool(INDArray input, int[] ds, boolean limitScan) {
        assert input.length() >= 2 : "Max pooling requires an ndarray of >= length 2";
        assert ds.length == 2 : "Down sampling must be of length 2 (the factors used for each image size";
        assert input.shape().length == 4 : "Only supports 4 dimensional tensors";

        int images = input.size(0) * input.size(1);
        int rows = input.size(2);
        int cols = input.size(3);
        // Collapse the two leading dimensions so every image is addressed by a single index.
        INDArray reshaped = input.reshape(images, 1, rows, cols);
        INDArray ret = Nd4j.create(reshaped.shape());

        int rowLimit = Math.max(1, limitScan ? (int) (rows / Math.pow(ds[0], 2.0d)) : rows);
        int colLimit = Math.max(1, limitScan ? (int) (cols / Math.pow(ds[1], 2.0d)) : cols);
        for (int img = 0; img < reshaped.size(0); img++) {
            for (int ch = 0; ch < reshaped.size(1); ch++) {
                for (int row = 0; row < rowLimit; row++) {
                    int pooledRow = row / ds[0];
                    for (int col = 0; col < colLimit; col++) {
                        int pooledCol = col / ds[1];
                        // Read from the reshaped view: img ranges over size(0) * size(1),
                        // which is out of bounds for the original array's first dimension.
                        double value = reshaped.getDouble(img, ch, row, col);
                        int[] target = {img, ch, pooledRow, pooledCol};
                        // Rows are scanned in the outer loop and columns in the inner one, so
                        // the first input cell reaching a pooled cell has indices that are exact
                        // multiples of the factors. Seed the cell with that value instead of
                        // max-ing it against the zero-initialised buffer, which would clamp
                        // all-negative windows to 0.
                        boolean firstInWindow = row % ds[0] == 0 && col % ds[1] == 0;
                        ret.putScalar(target,
                                firstInWindow ? value : Math.max(value, ret.getDouble(target)));
                    }
                }
            }
        }
        return ret;
    }

    /**
     * Down-samples by convolving with an averaging kernel and striding over the result.
     *
     * @param toDownSample array to down-sample
     * @param stride       window / stride size per dimension
     * @return the strided "valid" convolution of the input with a kernel of shape
     *         {@code stride} whose weights are all {@code 1 / prod(stride)}
     */
    public static INDArray downSample(INDArray toDownSample, int[] stride) {
        INDArray kernel = Nd4j.ones(stride);
        // Averaging kernel: every weight is 1 / (number of cells in the window).
        kernel.divi(Integer.valueOf(ArrayUtil.prod(stride)));
        INDArray input = toDownSample;
        int inputRank = input.shape().length;
        if (stride.length > inputRank) {
            // NOTE(review): this reshapes the *input* to the kernel's (padded) shape, which can
            // only succeed when both have the same number of elements. Padding the input's own
            // shape with leading 1s looks like the intent; confirm before changing.
            input = input.reshape(padShapeWithOnes(kernel.shape(), stride.length));
        } else if (stride.length < inputRank) {
            kernel = kernel.reshape(padShapeWithOnes(kernel.shape(), inputRank));
        }
        INDArray convolved = Convolution.convn(input, kernel, Convolution.Type.VALID);
        INDArrayIndex[] indices = new INDArrayIndex[input.shape().length];
        for (int d = 0; d < indices.length; d++) {
            if (d < stride.length) {
                indices[d] = NDArrayIndex.interval(0, stride[d], input.size(d), true);
            } else {
                indices[d] = NDArrayIndex.interval(0, input.size(d), true);
            }
        }
        return convolved.get(indices);
    }

    /**
     * Returns an array of length {@code rank} filled with 1s, except that positions
     * {@code k >= offset} receive {@code shape[k - offset]}, where
     * {@code offset = |shape.length - rank|}. When {@code rank >= shape.length}, this
     * right-aligns {@code shape} behind leading 1s.
     */
    private static int[] padShapeWithOnes(int[] shape, int rank) {
        int[] padded = new int[rank];
        Arrays.fill(padded, 1);
        int offset = Math.abs(shape.length - rank);
        for (int k = rank - 1; k >= offset; k--) {
            padded[k] = shape[k - offset];
        }
        return padded;
    }

    /**
     * Average-pools the two trailing dimensions of an array of rank >= 3.
     *
     * @param toPool array to pool
     * @param stride {@code {rowStride, colStride}}
     * @return a new array with the same shape as {@code toPool}
     */
    public static INDArray avgPooling(INDArray toPool, int[] stride) {
        return poolTrailingDims(toPool, stride, true);
    }

    /**
     * Sum-pools the two trailing dimensions of an array of rank >= 3.
     *
     * @param toPool array to pool
     * @param stride {@code {rowStride, colStride}}
     * @return a new array with the same shape as {@code toPool}
     */
    public static INDArray sumPooling(INDArray toPool, int[] stride) {
        return poolTrailingDims(toPool, stride, false);
    }

    /**
     * Shared body of {@link #avgPooling} and {@link #sumPooling}. Each window is reduced with
     * {@code sum(lastDim)}, followed by {@code mean(lastDim)} when {@code average} is true or
     * {@code sum(lastDim)} otherwise. The reduced value is written into a zero array shaped
     * like the input.
     */
    private static INDArray poolTrailingDims(INDArray toPool, int[] stride, boolean average) {
        int nDims = toPool.shape().length;
        assert nDims >= 3 : "NDArray must have 3 dimensions";
        int nRows = toPool.shape()[nDims - 2];
        int nCols = toPool.shape()[nDims - 1];
        int rowStride = stride[0];
        int colStride = stride[1];
        int lastDim = nDims - 1;
        INDArray ret = Nd4j.create(toPool.shape());
        // NOTE(review): nRows / rowStride (and nCols / colStride) is integer division, so
        // Math.ceil has no effect and a trailing partial window is skipped.
        for (int r = 0; r < Math.ceil(nRows / rowStride); r++) {
            // NOTE(review): an inclusive interval whose end equals its start presumably
            // selects a single row, not a rowStride-high window.
            INDArrayIndex rows = NDArrayIndex.interval(r * rowStride, r * rowStride, true);
            for (int c = 0; c < Math.ceil(nCols / colStride); c++) {
                // NOTE(review): this presumably selects two columns regardless of colStride.
                INDArrayIndex cols = NDArrayIndex.interval(c * colStride, (c * colStride) + 1, true);
                INDArray summed = toPool.get(rows, cols).sum(lastDim);
                INDArray reduced = average ? summed.mean(lastDim) : summed.sum(lastDim);
                // NOTE(review): repmat is applied to put(...)'s return value and that result is
                // discarded. Tiling the pooled value over the window may have been the intent.
                ret.put(new INDArrayIndex[] {rows, cols}, reduced.permute(1, 2, 0))
                        .repmat(rows.length(), cols.length());
            }
        }
        return ret;
    }

    /**
     * Nearest-neighbour up-sampling: every element is repeated {@code (int) scale[d]} times
     * along each dimension {@code d}.
     *
     * @param toUpSample array to up-sample
     * @param scale      one factor per dimension of {@code toUpSample}, read with
     *                   {@code getDouble(d)} and truncated to {@code int}
     * @return a new array of shape {@code shape[d] * (int) scale[d]} with
     *         {@code ret[o_0, ..., o_k] == toUpSample[o_0 / s_0, ..., o_k / s_k]}
     * @throws IllegalArgumentException if any factor truncates to less than 1
     */
    public static INDArray upSample(INDArray toUpSample, INDArray scale) {
        int[] shape = toUpSample.shape();
        int rank = shape.length;
        int[] factors = new int[rank];
        int[] outShape = new int[rank];
        for (int d = 0; d < rank; d++) {
            factors[d] = (int) scale.getDouble(d);
            if (factors[d] < 1) {
                throw new IllegalArgumentException("Up-sampling factor for dimension " + d
                        + " must be >= 1 but was " + scale.getDouble(d));
            }
            outShape[d] = shape[d] * factors[d];
        }
        INDArray ret = Nd4j.create(outShape);
        int[] outIdx = new int[rank];
        int[] inIdx = new int[rank];
        int total = ArrayUtil.prod(outShape);
        for (int linear = 0; linear < total; linear++) {
            // Decompose the linear output index into per-dimension indices (row-major order).
            // Each output index maps back to its source element by integer division.
            int rem = linear;
            for (int d = rank - 1; d >= 0; d--) {
                outIdx[d] = rem % outShape[d];
                rem /= outShape[d];
                inIdx[d] = outIdx[d] / factors[d];
            }
            ret.putScalar(outIdx, toUpSample.getDouble(inIdx));
        }
        return ret;
    }

    /**
     * Runs the {@link CosineSimilarity} accumulation over both arrays.
     *
     * @return the accumulation's {@code currentResult()} as a double
     */
    public static double cosineSim(INDArray d1, INDArray d2) {
        CosineSimilarity op = new CosineSimilarity(d1, d2, d1.length());
        return Nd4j.getExecutioner().execAndReturn((Accumulation) op).currentResult().doubleValue();
    }

    /**
     * Standardises each column of {@code toNormalize}. It subtracts the dimension-0 mean and
     * divides by the dimension-0 standard deviation plus {@code Nd4j.EPS_THRESHOLD}.
     *
     * <p>The input is modified via {@code subiRowVector} / {@code diviRowVector} and returned.
     */
    public static INDArray normalizeZeroMeanAndUnitVariance(INDArray toNormalize) {
        INDArray columnMeans = toNormalize.mean(0);
        INDArray columnStds = toNormalize.std(0);
        toNormalize.subiRowVector(columnMeans);
        // Guard against division by zero for constant columns.
        columnStds.addi(Double.valueOf(Nd4j.EPS_THRESHOLD));
        toNormalize.diviRowVector(columnStds);
        return toNormalize;
    }

    /**
     * Scales {@code toScale} by {@code 1 / norm2} using the BLAS wrapper's {@code scal}.
     * The scalar is single-precision for FLOAT buffers and double-precision otherwise.
     *
     * @return the BLAS result, or {@code toScale} itself when its norm is not positive
     */
    public static INDArray unitVec(INDArray toScale) {
        double norm = toScale.norm2Number().doubleValue();
        if (norm > 0.0d) {
            if (toScale.data().dataType() == DataBuffer.Type.FLOAT) {
                return Nd4j.getBlasWrapper().scal(1.0f / ((float) norm), toScale);
            }
            return Nd4j.getBlasWrapper().scal(1.0d / norm, toScale);
        }
        return toScale;
    }

    /** Applies {@link Negative}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray neg(INDArray ndArray) {
        return neg(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Floor}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray floor(INDArray ndArray) {
        return floor(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Ceil}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray ceiling(INDArray ndArray) {
        return ceiling(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Ceil}; the output is a {@code dup()} when {@code dup} is true, else the input. */
    public static INDArray ceiling(INDArray ndArray, boolean dup) {
        return exec(dup ? new Ceil(ndArray, ndArray.dup()) : new Ceil(ndArray, ndArray));
    }

    /** Applies {@link Sign}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray sign(INDArray ndArray) {
        return sign(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Stabilize} with parameter {@code k}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray stabilize(INDArray ndArray, double k) {
        return stabilize(ndArray, k, Nd4j.copyOnOps);
    }

    /**
     * Applies {@link Abs} on a copy of the input. Unlike the other one-argument overloads, this
     * always copies and ignores {@code Nd4j.copyOnOps}.
     */
    public static INDArray abs(INDArray ndArray) {
        return abs(ndArray, true);
    }

    /** Applies {@link Exp}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray exp(INDArray ndArray) {
        return exp(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link HardTanh}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray hardTanh(INDArray ndArray) {
        return hardTanh(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Identity}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray identity(INDArray ndArray) {
        return identity(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Pow} with the given exponent, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray pow(INDArray ndArray, Number power) {
        return pow(ndArray, power, Nd4j.copyOnOps);
    }

    /** Applies {@link Round}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray round(INDArray ndArray) {
        return round(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Sigmoid}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray sigmoid(INDArray ndArray) {
        return sigmoid(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Sqrt}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray sqrt(INDArray ndArray) {
        return sqrt(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Tanh}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray tanh(INDArray ndArray) {
        return tanh(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Log}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray log(INDArray ndArray) {
        return log(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link Eps}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray eps(INDArray ndArray) {
        return eps(ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link GreaterThanOrEqual}, duplicating {@code first} when {@code Nd4j.copyOnOps} is set. */
    public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray) {
        return greaterThanOrEqual(first, ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link LessThanOrEqual}, duplicating {@code first} when {@code Nd4j.copyOnOps} is set. */
    public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray) {
        return lessThanOrEqual(first, ndArray, Nd4j.copyOnOps);
    }

    /** Applies {@link LessThanOrEqual}; {@code first} is duplicated when {@code dup} is true. */
    public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
        return exec(dup ? new LessThanOrEqual(first.dup(), ndArray) : new LessThanOrEqual(first, ndArray));
    }

    /** Applies {@link GreaterThanOrEqual}; {@code first} is duplicated when {@code dup} is true. */
    public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) {
        return exec(dup ? new GreaterThanOrEqual(first.dup(), ndArray) : new GreaterThanOrEqual(first, ndArray));
    }

    /** Applies {@link Eps}; the op runs on a {@code dup()} when {@code dup} is true. */
    public static INDArray eps(INDArray ndArray, boolean dup) {
        return exec(dup ? new Eps(ndArray.dup()) : new Eps(ndArray));
    }

    /** Applies {@link Floor}; the op runs on a {@code dup()} when {@code dup} is true. */
    public static INDArray floor(INDArray ndArray, boolean dup) {
        return exec(dup ? new Floor(ndArray.dup()) : new Floor(ndArray));
    }

    /** Applies {@link Sign}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray sign(INDArray ndArray, boolean dup) {
        return exec(dup ? new Sign(ndArray, ndArray.dup()) : new Sign(ndArray));
    }

    /** Applies {@link ScalarMax} with scalar {@code k}; the op runs on a {@code dup()} when {@code dup} is true. */
    public static INDArray max(INDArray ndArray, double k, boolean dup) {
        return exec(dup ? new ScalarMax(ndArray.dup(), Double.valueOf(k)) : new ScalarMax(ndArray, Double.valueOf(k)));
    }

    /** Applies {@link ScalarMax} with scalar {@code k}, on a copy when {@code Nd4j.copyOnOps} is set. */
    public static INDArray max(INDArray ndArray, double k) {
        return max(ndArray, k, Nd4j.copyOnOps);
    }

    /** Applies {@link Stabilize} with parameter {@code k}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray stabilize(INDArray ndArray, double k, boolean dup) {
        return exec(dup ? new Stabilize(ndArray, ndArray.dup(), k) : new Stabilize(ndArray, k));
    }

    /** Applies {@link Abs}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray abs(INDArray ndArray, boolean dup) {
        return exec(dup ? new Abs(ndArray, ndArray.dup()) : new Abs(ndArray));
    }

    /** Applies {@link Exp}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray exp(INDArray ndArray, boolean dup) {
        return exec(dup ? new Exp(ndArray, ndArray.dup()) : new Exp(ndArray));
    }

    /** Applies {@link HardTanh}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray hardTanh(INDArray ndArray, boolean dup) {
        return exec(dup ? new HardTanh(ndArray, ndArray.dup()) : new HardTanh(ndArray));
    }

    /** Applies {@link Identity}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray identity(INDArray ndArray, boolean dup) {
        return exec(dup ? new Identity(ndArray, ndArray.dup()) : new Identity(ndArray));
    }

    /** Applies {@link Pow} with the given exponent; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray pow(INDArray ndArray, Number power, boolean dup) {
        return exec(dup ? new Pow(ndArray, ndArray.dup(), power.doubleValue()) : new Pow(ndArray, power.doubleValue()));
    }

    /** Applies {@link Round}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray round(INDArray ndArray, boolean dup) {
        return exec(dup ? new Round(ndArray, ndArray.dup()) : new Round(ndArray));
    }

    /** Applies {@link Sigmoid}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray sigmoid(INDArray ndArray, boolean dup) {
        return exec(dup ? new Sigmoid(ndArray, ndArray.dup()) : new Sigmoid(ndArray));
    }

    /** Applies {@link Sqrt}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray sqrt(INDArray ndArray, boolean dup) {
        return exec(dup ? new Sqrt(ndArray, ndArray.dup()) : new Sqrt(ndArray));
    }

    /** Applies {@link Tanh}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray tanh(INDArray ndArray, boolean dup) {
        return exec(dup ? new Tanh(ndArray, ndArray.dup()) : new Tanh(ndArray));
    }

    /** Applies {@link Log}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray log(INDArray ndArray, boolean dup) {
        return exec(dup ? new Log(ndArray, ndArray.dup()) : new Log(ndArray));
    }

    /** Applies {@link Negative}; the output is a {@code dup()} when {@code dup} is true. */
    public static INDArray neg(INDArray ndArray, boolean dup) {
        return exec(dup ? new Negative(ndArray, ndArray.dup()) : new Negative(ndArray));
    }

    /**
     * Executes a scalar op and returns its {@code z()} array.
     *
     * @throws IllegalStateException if the op's input array has already been cleaned up
     */
    private static INDArray exec(ScalarOp op) {
        if (op.x().isCleanedUp()) {
            throw new IllegalStateException("NDArray already freed");
        }
        return Nd4j.getExecutioner().exec(op).z();
    }

    /**
     * Executes a transform op via {@code execAndReturn}.
     *
     * @throws IllegalStateException if the op's input array has already been cleaned up
     */
    private static INDArray exec(TransformOp op) {
        if (op.x().isCleanedUp()) {
            throw new IllegalStateException("NDArray already freed");
        }
        return Nd4j.getExecutioner().execAndReturn(op);
    }
}
