package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.allocator.tad.TADManager;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.CopyOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/JCudaExecutioner.class */
public class JCudaExecutioner extends DefaultOpExecutioner {
    // Native bridge through which every kernel launched by this class is dispatched (device-side NativeOps).
    // NOTE(review): not final, although nothing in the visible part of the file reassigns it.
    private static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // Shared allocator: supplies flow control (prepareAction/registerAction), the device id, and
    // host/device synchronisation (synchronizeHostData/tickHostWrite) for arrays handed to native calls.
    private static final Allocator allocator = AtomicAllocator.getInstance();
    private static Logger log = LoggerFactory.getLogger(JCudaExecutioner.class);
    // Resolves TAD info for an (array, dimensions) pair. The first buffer of the returned pair is passed
    // to native code as TAD shape info; the second may be null and is passed alongside it
    // (presumably TAD offsets — confirm against DeviceTADManager).
    private static TADManager tadManager = new DeviceTADManager();

    /**
     * Exposes the native bridge used by this executioner.
     *
     * @return the device {@link NativeOps} instance held in this class's static state
     */
    public NativeOps getNativeOps() {
        return JCudaExecutioner.nativeOps;
    }

    /**
     * Runs a broadcast op through the single-argument {@code exec} overload.
     *
     * @param op the broadcast op to execute; the value returned by {@code exec} is not used
     */
    protected void doBroadcastOp(BroadcastOp op) {
        exec(op);
    }

    /**
     * Executes {@code broadcastOp} along the given dimensions on the device.
     *
     * <p>The dimensions are copied, negative indices are wrapped by the rank of {@code x}, and the
     * copy is sorted, so the caller's array is left untouched.
     *
     * @param broadcastOp the op to run; {@code x}, {@code y} and {@code z} must all be non-null
     *     because their device pointers and shape info are read unconditionally
     * @param iArr the dimensions to broadcast along (negative values count from the end)
     * @return {@code broadcastOp.z()}
     */
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        // Private copy: sorting/normalising the caller's varargs array in place is a visible side effect.
        int[] dimension = iArr.clone();
        int rank = broadcastOp.x().rank();
        for (int i = 0; i < dimension.length; i++) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);

        AtomicAllocator atomic = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());

        long hostYShapeInfo = broadcastOp.y() == null ? 0L : AddressRetriever.retrieveHostAddress(broadcastOp.y().shapeInfoDataBuffer());
        long hostZShapeInfo = broadcastOp.z() == null ? 0L : AddressRetriever.retrieveHostAddress(broadcastOp.z().shapeInfoDataBuffer());

        long x = atomic.getPointer(broadcastOp.x(), context).address();
        long y = atomic.getPointer(broadcastOp.y(), context).address();
        long z = atomic.getPointer(broadcastOp.z(), context).address();
        long xShapeInfo = atomic.getPointer(broadcastOp.x().shapeInfoDataBuffer(), context).address();

        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        DataBuffer tadOffsets = tad.getSecond();

        // Pointer bundle passed as the first native argument (same slot layout as the other exec paths).
        long[] extras = {
                AddressRetriever.retrieveHostAddress(broadcastOp.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                AddressRetriever.retrieveHostAddress(tadShapeInfo),
                atomic.getPointer(tadShapeInfo, context).address(),
                // The second TAD buffer may be absent; pass a null pointer as the reduction paths do.
                tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address()
        };
        long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();
        long yShapeInfo = atomic.getPointer(broadcastOp.y().shapeInfoDataBuffer(), context).address();
        long zShapeInfo = atomic.getPointer(broadcastOp.z().shapeInfoDataBuffer(), context).address();

        if (broadcastOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extras, broadcastOp.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            nativeOps.execBroadcastFloat(extras, broadcastOp.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        return broadcastOp.z();
    }

    /**
     * Reduces {@code accumulation.x()} along {@code iArr} on the device and returns the result array.
     *
     * <p>The dimensions are copied, negative indices are wrapped by the rank of {@code x}, and only then
     * is the copy sorted; the caller's array is not modified. Requesting as many dimensions as {@code x}
     * has is treated as a whole-array reduction ({@code Integer.MAX_VALUE} marker).
     *
     * <p>A fresh result array is installed with {@code setZ}. For a scalar result the native scalar entry
     * point is used, the value is written into the result on the host, and {@code setFinalResult} is
     * called. Otherwise the dimensional kernel writes into {@code z} and the action is registered.
     *
     * @param accumulation the reduction ({@link Variance}, pairwise when {@code y} is set, or plain)
     * @param iArr dimensions to reduce along; negative values count from the end
     * @return the result array, or {@code accumulation.noOp()} when {@code x} is a vector whose length
     *     equals the result length
     */
    public INDArray exec(Accumulation accumulation, int... iArr) {
        // Normalise on a private copy and sort only once every index is non-negative; sorting first
        // would leave e.g. {-1, 0} as the unsorted {rank - 1, 0}.
        int[] dimension = iArr.clone();
        int rank = accumulation.x().rank();
        for (int i = 0; i < dimension.length; i++) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }

        int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(accumulation.x().shape(), dimension);
        if (retShape.length == 1) {
            // Keep results 2-D: a row when the first reduced dimension is 0, a column otherwise.
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(retShape)) {
            return accumulation.noOp();
        }

        // Pre-fill with zeroDouble(); values inside (-0.01f, 0.01f) get a plain zeros array.
        double zero = accumulation.zeroDouble();
        INDArray result = (zero <= -0.01f || zero >= 0.01f) ? Nd4j.valueArrayOf(retShape, zero) : Nd4j.zeros(retShape);
        accumulation.setZ(result);

        CudaContext context = allocator.getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());
        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long hostYShapeInfo = accumulation.y() == null ? 0L : AddressRetriever.retrieveHostAddress(accumulation.y().shapeInfoDataBuffer());
        long hostZShapeInfo = accumulation.z() == null ? 0L : AddressRetriever.retrieveHostAddress(accumulation.z().shapeInfoDataBuffer());

        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(accumulation.x(), dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        long hostTadShapeInfo = AddressRetriever.retrieveHostAddress(tadShapeInfo);
        long devTadShapeInfo = atomic.getPointer(tadShapeInfo, context).address();
        DataBuffer tadOffsets = tad.getSecond();
        long devTadOffsets = tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address();

        long x = atomic.getPointer(accumulation.x(), context).address();
        long xShapeInfo = atomic.getPointer(accumulation.x().shapeInfoDataBuffer(), context).address();
        long[] extras = {
                AddressRetriever.retrieveHostAddress(accumulation.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                devTadOffsets
        };
        // NOTE(review): extra args are forwarded only for Variance here, whereas invoke(Accumulation, int[])
        // forwards them for every op — confirm which is intended before unifying.
        long extraArgs = (accumulation.extraArgs() == null || !(accumulation instanceof Variance)) ? 0L : atomic.getPointer(accumulation.extraArgsDataBuff(), context).address();
        long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();

        boolean isDouble = accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE;
        boolean isVariance = accumulation instanceof Variance;
        // Variance takes precedence over the pairwise (reduce3) form, as in the original dispatch.
        boolean pairwise = !isVariance && accumulation.y() != null;

        if (result.isScalar()) {
            allocator.tickHostWrite(result);
            if (isVariance) {
                // Honour the op's bias-correction flag, exactly like the dimensional path below.
                boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
                if (isDouble) {
                    result.putScalar(0, nativeOps.execSummaryStatsScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
                } else {
                    result.putScalar(0, nativeOps.execSummaryStatsScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, biasCorrected));
                }
            } else if (pairwise) {
                long y = atomic.getPointer(accumulation.y(), context).address();
                long yShapeInfo = atomic.getPointer(accumulation.y().shapeInfoDataBuffer(), context).address();
                if (isDouble) {
                    result.putScalar(0, nativeOps.execReduce3ScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
                } else {
                    result.putScalar(0, nativeOps.execReduce3ScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo));
                }
            } else if (isDouble) {
                result.putScalar(0, nativeOps.execReduceScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs));
            } else {
                result.putScalar(0, nativeOps.execReduceScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs));
            }
            if (isDouble) {
                accumulation.setFinalResult(Double.valueOf(result.getDouble(0)));
            } else {
                accumulation.setFinalResult(Float.valueOf(result.getFloat(0)));
            }
            return result;
        }

        long y = 0L;
        long yShapeInfo = 0L;
        if (pairwise) {
            y = atomic.getPointer(accumulation.y(), context).address();
            yShapeInfo = atomic.getPointer(accumulation.y().shapeInfoDataBuffer(), context).address();
        }
        long z = atomic.getPointer(accumulation.z(), context).address();
        long zShapeInfo = atomic.getPointer(accumulation.z().shapeInfoDataBuffer(), context).address();

        if (isVariance) {
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            if (isDouble) {
                nativeOps.execSummaryStatsDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
            } else {
                nativeOps.execSummaryStatsFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
            }
        } else if (pairwise) {
            if (isDouble) {
                nativeOps.execReduce3Double(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execReduce3Float(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else if (isDouble) {
            nativeOps.execReduceDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            nativeOps.execReduceFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
        return result;
    }

    /**
     * Runs an index reduction (e.g. arg-max style ops) over {@code indexAccumulation.x()} along
     * {@code iArr} on the device.
     *
     * <p>The dimensions are copied, negative indices are wrapped by the rank of {@code x}, and the copy
     * is then sorted; the caller's array is not modified. Requesting as many dimensions as {@code x} has
     * becomes a whole-array reduction ({@code Integer.MAX_VALUE} marker).
     *
     * @param indexAccumulation the index reduction to run
     * @param iArr dimensions to reduce along; negative values count from the end
     * @return {@code indexAccumulation.z()}, the freshly installed result array
     */
    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        // Normalise on a private copy, then sort (sorting before normalising can mis-order the result).
        int[] dimension = iArr.clone();
        int rank = indexAccumulation.x().rank();
        for (int i = 0; i < dimension.length; i++) {
            if (dimension[i] < 0) {
                dimension[i] += rank;
            }
        }
        Arrays.sort(dimension);
        if (dimension.length == rank) {
            dimension = new int[]{Integer.MAX_VALUE};
        }

        int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(indexAccumulation.x().shape(), dimension);
        if (retShape.length == 1) {
            // Keep results 2-D: a row when the first reduced dimension is 0, a column otherwise.
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (indexAccumulation.x().isVector() && indexAccumulation.x().length() == ArrayUtil.prod(retShape)) {
            // Every reduced slice holds exactly one element, so every resulting index is 0.
            // (Returning x here would hand back data values instead of indices.)
            INDArray indices = Nd4j.zeros(retShape);
            indexAccumulation.setZ(indices);
            return indices;
        }

        // Pre-fill with zeroDouble(); values inside (-0.01f, 0.01f) get a plain zeros array.
        double zero = indexAccumulation.zeroDouble();
        indexAccumulation.setZ((zero <= -0.01f || zero >= 0.01f) ? Nd4j.valueArrayOf(retShape, zero) : Nd4j.zeros(retShape));

        CudaContext context = allocator.getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long hostYShapeInfo = indexAccumulation.y() == null ? 0L : AddressRetriever.retrieveHostAddress(indexAccumulation.y().shapeInfoDataBuffer());
        long hostZShapeInfo = indexAccumulation.z() == null ? 0L : AddressRetriever.retrieveHostAddress(indexAccumulation.z().shapeInfoDataBuffer());
        long x = atomic.getPointer(indexAccumulation.x(), context).address();
        long xShapeInfo = atomic.getPointer(indexAccumulation.x().shapeInfoDataBuffer(), context).address();
        long z = atomic.getPointer(indexAccumulation.z(), context).address();
        long zShapeInfo = atomic.getPointer(indexAccumulation.z().shapeInfoDataBuffer(), context).address();

        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        long hostTadShapeInfo = AddressRetriever.retrieveHostAddress(tadShapeInfo);
        long devTadShapeInfo = atomic.getPointer(tadShapeInfo, context).address();
        DataBuffer tadOffsets = tad.getSecond();
        long[] extras = {
                AddressRetriever.retrieveHostAddress(indexAccumulation.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address()
        };
        long extraArgs = indexAccumulation.extraArgs() != null ? atomic.getPointer(indexAccumulation.extraArgsDataBuff(), context).address() : 0L;
        long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();

        if (indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execIndexReduceDouble(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            nativeOps.execIndexReduceFloat(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        return indexAccumulation.z();
    }

    /**
     * Sorts the requested dimensions and defers to {@link DefaultOpExecutioner}'s dimensional exec.
     *
     * <p>The sort is performed in place, so the caller's array is reordered.
     *
     * @param op the op to execute
     * @param iArr the dimensions to execute along
     * @return whatever the superclass implementation returns
     */
    public Op exec(Op op, int... iArr) {
        int[] sortedInPlace = iArr;
        Arrays.sort(sortedInPlace);
        Op executed = super.exec(op, sortedInPlace);
        return executed;
    }

    /**
     * Executes {@code op}, either through the host-side {@link DefaultOpExecutioner} implementation or
     * on the device.
     *
     * <p>The host path is taken for complex inputs, for {@code ExecutionMode.JAVA}, and for
     * {@link CopyOp}. Inputs are synchronised to the host first, and {@code z} is marked as host-written
     * afterwards. Otherwise the op is dispatched by type to the matching private {@code invoke}
     * method; reductions are invoked with a {@code null} dimension list. An op matching none of the
     * listed types is returned without being executed.
     *
     * @param op the op to execute
     * @return {@code op} on every path (the host path previously returned {@code null})
     */
    public Op exec(Op op) {
        boolean runOnHost = (op.x() instanceof IComplexNDArray)
                || executionMode() == OpExecutioner.ExecutionMode.JAVA
                || (op instanceof CopyOp);
        if (runOnHost) {
            // Inputs must be current on the host before the Java implementation reads them.
            if (op.x() != null) {
                allocator.synchronizeHostData(op.x());
            }
            if (op.y() != null) {
                allocator.synchronizeHostData(op.y());
            }
            super.exec(op);
            // The result was produced on the host; record that with the allocator.
            if (op.z() != null) {
                allocator.tickHostWrite(op.z());
            }
            return op;
        }

        if (op instanceof TransformOp) {
            invoke((TransformOp) op);
        } else if (op instanceof Accumulation) {
            invoke((Accumulation) op, (int[]) null);
        } else if (op instanceof ScalarOp) {
            invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
        } else if (op instanceof IndexAccumulation) {
            invoke((IndexAccumulation) op, (int[]) null);
        }
        return op;
    }

    /**
     * Runs {@code transformOp} through the private {@code invoke(TransformOp)} path and returns its
     * output array.
     *
     * @param transformOp the transform to execute
     * @return {@code transformOp.z()} after execution
     */
    public INDArray execAndReturn(TransformOp transformOp) {
        invoke(transformOp);
        final INDArray output = transformOp.z();
        return output;
    }

    /**
     * Launches {@code broadcastOp} on the device along {@code broadcastOp.getDimension()}.
     *
     * <p>{@code x}, {@code y} and {@code z} must all be non-null, since their device pointers and
     * shape info are read unconditionally.
     *
     * @param broadcastOp the op to run
     * @return the context the action was prepared and registered with (previously always {@code null})
     */
    private CudaContext invoke(BroadcastOp broadcastOp) {
        CudaContext context = allocator.getFlowController().prepareAction(broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long x = atomic.getPointer(broadcastOp.x(), context).address();
        long xShapeInfo = atomic.getPointer(broadcastOp.x().shapeInfoDataBuffer(), context).address();
        long hostYShapeInfo = broadcastOp.y() == null ? 0L : AddressRetriever.retrieveHostAddress(broadcastOp.y().shapeInfoDataBuffer());
        long hostZShapeInfo = broadcastOp.z() == null ? 0L : AddressRetriever.retrieveHostAddress(broadcastOp.z().shapeInfoDataBuffer());

        int[] dimension = broadcastOp.getDimension();
        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(broadcastOp.x(), dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        DataBuffer tadOffsets = tad.getSecond();
        long[] extras = {
                AddressRetriever.retrieveHostAddress(broadcastOp.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                AddressRetriever.retrieveHostAddress(tadShapeInfo),
                atomic.getPointer(tadShapeInfo, context).address(),
                // The second TAD buffer may be absent; pass a null pointer as the reduction paths do.
                tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address()
        };
        long y = atomic.getPointer(broadcastOp.y(), context).address();
        long yShapeInfo = atomic.getPointer(broadcastOp.y().shapeInfoDataBuffer(), context).address();
        long z = atomic.getPointer(broadcastOp.z(), context).address();
        long zShapeInfo = atomic.getPointer(broadcastOp.z().shapeInfoDataBuffer(), context).address();
        long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();

        if (broadcastOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            nativeOps.execBroadcastDouble(extras, broadcastOp.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            nativeOps.execBroadcastFloat(extras, broadcastOp.opNum(), x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
        }
        allocator.registerAction(context, broadcastOp.z(), broadcastOp.x(), broadcastOp.y());
        return context;
    }

    /**
     * Launches an index reduction on the device, writing into the existing {@code z}.
     *
     * <p>If {@code z} is not a scalar and real dimensions are given (first entry is not
     * {@code Integer.MAX_VALUE}), the dimensional kernel writes into {@code z}. Otherwise the scalar
     * entry point runs and its result is stored via {@code setFinalResult}. A {@code null} dimension
     * list always takes the scalar path; the TAD lookup then uses {@code {0}}.
     *
     * @param indexAccumulation the op to run; {@code z} must be non-null
     * @param iArr dimensions to reduce along, or {@code null}; not modified
     * @return the context the action was prepared and registered with (previously always {@code null})
     */
    private CudaContext invoke(IndexAccumulation indexAccumulation, int[] iArr) {
        CudaContext context = allocator.getFlowController().prepareAction(indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long x = atomic.getPointer(indexAccumulation.x(), context).address();
        long xShapeInfo = atomic.getPointer(indexAccumulation.x().shapeInfoDataBuffer(), context).address();
        long extraArgs = indexAccumulation.extraArgs() != null ? atomic.getPointer(indexAccumulation.extraArgsDataBuff(), context).address() : 0L;
        long hostYShapeInfo = indexAccumulation.y() == null ? 0L : AddressRetriever.retrieveHostAddress(indexAccumulation.y().shapeInfoDataBuffer());
        long hostZShapeInfo = indexAccumulation.z() == null ? 0L : AddressRetriever.retrieveHostAddress(indexAccumulation.z().shapeInfoDataBuffer());

        // Sort a private copy before the TAD lookup, so the TAD and the dimension buffer handed to the
        // kernel describe the same dimension order (the TAD used to be built from the unsorted array).
        int[] dimension = null;
        if (iArr != null) {
            dimension = iArr.clone();
            Arrays.sort(dimension);
        }
        int[] tadDimension = dimension == null ? new int[]{0} : dimension;
        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(indexAccumulation.x(), tadDimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        long hostTadShapeInfo = AddressRetriever.retrieveHostAddress(tadShapeInfo);
        long devTadShapeInfo = atomic.getPointer(tadShapeInfo, context).address();
        DataBuffer tadOffsets = tad.getSecond();
        long[] extras = {
                AddressRetriever.retrieveHostAddress(indexAccumulation.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address()
        };

        boolean isDouble = indexAccumulation.x().data().dataType() == DataBuffer.Type.DOUBLE;
        // The marker check reads the caller's (unsorted) first entry, as before.
        boolean alongDimensions = !indexAccumulation.z().isScalar() && iArr != null && iArr[0] != Integer.MAX_VALUE;
        if (alongDimensions) {
            long z = atomic.getPointer(indexAccumulation.z(), context).address();
            long zShapeInfo = atomic.getPointer(indexAccumulation.z().shapeInfoDataBuffer(), context).address();
            long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();
            if (isDouble) {
                nativeOps.execIndexReduceDouble(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execIndexReduceFloat(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else if (isDouble) {
            indexAccumulation.setFinalResult((int) nativeOps.execIndexReduceScalarDouble(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs));
        } else {
            indexAccumulation.setFinalResult((int) nativeOps.execIndexReduceScalarFloat(extras, indexAccumulation.opNum(), x, xShapeInfo, extraArgs));
        }
        allocator.registerAction(context, indexAccumulation.z(), indexAccumulation.x(), indexAccumulation.y());
        return context;
    }

    /**
     * Launches a reduction on the device after installing a freshly allocated result array as {@code z}.
     *
     * <p>A {@code null} dimension list means a whole-array reduction ({@code Integer.MAX_VALUE} marker).
     * A non-scalar result is written by the dimensional kernel (reduce3 when {@code y} is set, summary
     * stats for {@link Variance}, plain reduce otherwise). A scalar result is stored via
     * {@code setFinalResult}.
     *
     * @param accumulation the reduction to run
     * @param iArr dimensions to reduce along, or {@code null}; not modified
     * @return the prepared context, or {@code null} when {@code x} is a vector whose length equals the
     *     result length (nothing is launched or prepared in that case)
     */
    private CudaContext invoke(Accumulation accumulation, int[] iArr) {
        int[] dimension = iArr == null ? new int[]{Integer.MAX_VALUE} : iArr.clone();
        Arrays.sort(dimension);

        // Build and install the result array BEFORE preparing the action, so flow control and the host
        // z-shape slot in the pointer bundle refer to the array the kernel writes. This matches
        // exec(Accumulation, int...); previously they referred to the stale z.
        int[] retShape = Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(accumulation.x().shape(), dimension);
        if (retShape.length == 1) {
            retShape = dimension[0] == 0 ? new int[]{1, retShape[0]} : new int[]{retShape[0], 1};
        } else if (retShape.length == 0) {
            retShape = new int[]{1, 1};
        }
        if (accumulation.x().isVector() && accumulation.x().length() == ArrayUtil.prod(retShape)) {
            // Nothing to launch; return before an action is prepared that would never be registered.
            return null;
        }
        double zero = accumulation.zeroDouble();
        accumulation.setZ((zero <= -0.01f || zero >= 0.01f) ? Nd4j.valueArrayOf(retShape, zero) : Nd4j.zeros(retShape));

        CudaContext context = allocator.getFlowController().prepareAction(accumulation.z(), accumulation.x(), accumulation.y());
        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long hostYShapeInfo = accumulation.y() == null ? 0L : AddressRetriever.retrieveHostAddress(accumulation.y().shapeInfoDataBuffer());
        long hostZShapeInfo = accumulation.z() == null ? 0L : AddressRetriever.retrieveHostAddress(accumulation.z().shapeInfoDataBuffer());
        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(accumulation.x(), dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        long hostTadShapeInfo = AddressRetriever.retrieveHostAddress(tadShapeInfo);
        long devTadShapeInfo = atomic.getPointer(tadShapeInfo, context).address();
        DataBuffer tadOffsets = tad.getSecond();
        long[] extras = {
                AddressRetriever.retrieveHostAddress(accumulation.x().shapeInfoDataBuffer()),
                context.getOldStream().getNativePointer(),
                allocator.getDeviceId().intValue(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                tadOffsets == null ? 0L : atomic.getPointer(tadOffsets, context).address()
        };
        long x = atomic.getPointer(accumulation.x(), context).address();
        long xShapeInfo = atomic.getPointer(accumulation.x().shapeInfoDataBuffer(), context).address();
        long extraArgs = accumulation.extraArgs() != null ? atomic.getPointer(accumulation.extraArgsDataBuff(), context).address() : 0L;
        boolean isDouble = accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE;

        if (!accumulation.z().isScalar()) {
            long z = atomic.getPointer(accumulation.z(), context).address();
            long zShapeInfo = atomic.getPointer(accumulation.z().shapeInfoDataBuffer(), context).address();
            long dimensionPointer = atomic.getPointer(atomic.getConstantBuffer(dimension), context).address();
            if (accumulation.y() != null) {
                long y = atomic.getPointer(accumulation.y(), context).address();
                long yShapeInfo = atomic.getPointer(accumulation.y().shapeInfoDataBuffer(), context).address();
                if (isDouble) {
                    nativeOps.execReduce3Double(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
                } else {
                    nativeOps.execReduce3Float(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo, z, zShapeInfo, dimensionPointer, dimension.length);
                }
            } else if (accumulation instanceof Variance) {
                boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
                if (isDouble) {
                    nativeOps.execSummaryStatsDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
                } else {
                    nativeOps.execSummaryStatsFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
                }
            } else if (isDouble) {
                nativeOps.execReduceDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            } else {
                nativeOps.execReduceFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, z, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else if (accumulation instanceof Variance) {
            // Honour the op's bias-correction flag, exactly as the dimensional path above does
            // (this used to be hard-coded to true).
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            if (isDouble) {
                accumulation.setFinalResult(Double.valueOf(nativeOps.execSummaryStatsScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, biasCorrected)));
            } else {
                accumulation.setFinalResult(Float.valueOf(nativeOps.execSummaryStatsScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, biasCorrected)));
            }
        } else if (accumulation.y() != null) {
            long y = atomic.getPointer(accumulation.y(), context).address();
            long yShapeInfo = atomic.getPointer(accumulation.y().shapeInfoDataBuffer(), context).address();
            if (isDouble) {
                accumulation.setFinalResult(Double.valueOf(nativeOps.execReduce3ScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo)));
            } else {
                accumulation.setFinalResult(Float.valueOf(nativeOps.execReduce3ScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs, y, yShapeInfo)));
            }
        } else if (isDouble) {
            accumulation.setFinalResult(Double.valueOf(nativeOps.execReduceScalarDouble(extras, accumulation.opNum(), x, xShapeInfo, extraArgs)));
        } else {
            accumulation.setFinalResult(Float.valueOf(nativeOps.execReduceScalarFloat(extras, accumulation.opNum(), x, xShapeInfo, extraArgs)));
        }
        allocator.registerAction(context, accumulation.z(), accumulation.x(), accumulation.y());
        return context;
    }

    /**
     * Runs a scalar op on the device, choosing the double or float native entry point from the
     * data type of {@code x}.
     *
     * <p>A CUDA context is prepared for z/x/y. Device pointers are resolved for x, z, their shape
     * infos and (when present) the extra-argument buffer. The extra-pointers array is then
     * assembled and passed to the native call. Finally the action is registered with the
     * allocator.
     *
     * @param scalarOp the op to execute; {@code x} and {@code z} are dereferenced unconditionally
     * @return always {@code null}
     */
    private CudaContext invoke(ScalarOp scalarOp) {
        INDArray src = scalarOp.x();
        INDArray pairArg = scalarOp.y();
        INDArray dst = scalarOp.z();
        CudaContext context = allocator.getFlowController().prepareAction(dst, src, pairArg);

        // Host-side shape info addresses for y and z; zero when the operand is absent.
        long yHostShape = 0L;
        if (pairArg != null) {
            yHostShape = AddressRetriever.retrieveHostAddress(pairArg.shapeInfoDataBuffer());
        }
        long zHostShape = 0L;
        if (dst != null) {
            zHostShape = AddressRetriever.retrieveHostAddress(dst.shapeInfoDataBuffer());
        }

        AtomicAllocator atomic = AtomicAllocator.getInstance();
        long srcData = atomic.getPointer(src, context).address();
        long srcShape = atomic.getPointer(src.shapeInfoDataBuffer(), context).address();
        long extras = 0L;
        if (scalarOp.extraArgs() != null) {
            extras = atomic.getPointer(scalarOp.extraArgsDataBuff(), context).address();
        }
        long dstData = atomic.getPointer(dst, context).address();
        long dstShape = atomic.getPointer(dst.shapeInfoDataBuffer(), context).address();

        long[] extraPointers = new long[] {
            AddressRetriever.retrieveHostAddress(src.shapeInfoDataBuffer()),
            context.getOldStream().getNativePointer(),
            allocator.getDeviceId().intValue(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            yHostShape,
            zHostShape,
            0,
            0
        };

        boolean useDouble = src.data().dataType() == DataBuffer.Type.DOUBLE;
        if (useDouble) {
            nativeOps.execScalarDouble(extraPointers, scalarOp.opNum(), srcData, srcShape, dstData, dstShape, scalarOp.scalar().doubleValue(), extras);
        } else {
            nativeOps.execScalarFloat(extraPointers, scalarOp.opNum(), srcData, srcShape, dstData, dstShape, scalarOp.scalar().floatValue(), extras);
        }

        allocator.registerAction(context, dst, src, pairArg);
        return null;
    }

    /**
     * Runs a transform op on the device: unary when {@code y} is null, pairwise otherwise.
     *
     * <p>A CUDA context is prepared for z/x/y and device/host pointers are resolved. The 17-slot
     * extra-pointers array handed to the native kernels is then built. Dispatch goes to the double
     * or float native entry point, chosen by the data type of {@code x}.
     *
     * <p>Ops 38-41 also receive TAD shape/offset pointers. Op 41 with extra arguments also
     * receives a reduced-shape descriptor (slot 7) and a dimension buffer (slots 15-16).
     *
     * <p>The element-wise-stride entry point is used only when {@link #canUseStridedTransform}
     * allows it. Every other case goes through the entry point that receives full shape info.
     *
     * @param transformOp the op to execute; {@code x} and {@code z} are dereferenced unconditionally
     * @return always {@code null}
     */
    private CudaContext invoke(TransformOp transformOp) {
        CudaContext context = allocator.getFlowController().prepareAction(transformOp.z(), transformOp.x(), transformOp.y());
        long xPointer = AtomicAllocator.getInstance().getPointer(transformOp.x(), context).address();
        long xShapeInfoPointer = AtomicAllocator.getInstance().getPointer(transformOp.x().shapeInfoDataBuffer(), context).address();
        long extraArgsPointer = transformOp.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(transformOp.extraArgsDataBuff(), context).address() : 0L;
        // Slot 7 of the extra pointers: y's host shape info, overwritten below for op 41.
        long slot7 = transformOp.y() == null ? 0L : AddressRetriever.retrieveHostAddress(transformOp.y().shapeInfoDataBuffer());
        long zHostShapeInfo = transformOp.z() == null ? 0L : AddressRetriever.retrieveHostAddress(transformOp.z().shapeInfoDataBuffer());

        long dimensionDevicePointer = 0;
        long dimensionHostPointer = 0;
        int[] dimension = null;
        if (transformOp.opNum() == 41 && transformOp.extraArgs() != null) {
            // NOTE(review): only extraArgs()[1] is read as the (single) dimension; confirm against
            // the op's extra-argument layout.
            dimension = new int[] {((Integer) transformOp.extraArgs()[1]).intValue()};
            // Negative axes count from the end.
            for (int i = 0; i < dimension.length; i++) {
                if (dimension[i] < 0) {
                    dimension[i] += transformOp.x().rank();
                }
            }
            if (dimension.length == transformOp.x().rank()) {
                dimension = new int[] {Integer.MAX_VALUE};
            }
            int[] reducedShape = Shape.wholeArrayDimension(dimension) ? new int[] {1, 1} : ArrayUtil.removeIndex(transformOp.x().shape(), dimension);
            // Keep the reduced shape at rank 2 (row or column vector).
            if (reducedShape.length == 1) {
                reducedShape = dimension[0] == 0 ? new int[] {1, reducedShape[0]} : new int[] {reducedShape[0], 1};
            } else if (reducedShape.length == 0) {
                reducedShape = new int[] {1, 1};
            }
            slot7 = AtomicAllocator.getInstance().getPointer(Nd4j.zeros(reducedShape).shapeInfoDataBuffer(), context).address();
            DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dimension);
            dimensionDevicePointer = AtomicAllocator.getInstance().getPointer(dimensionBuffer, context).address();
            dimensionHostPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer).address();
        }

        long[] tad = resolveTransformTadPointers(transformOp, dimension, context);

        long zPointer = AtomicAllocator.getInstance().getPointer(transformOp.z(), context).address();
        long zShapeInfoPointer = AtomicAllocator.getInstance().getPointer(transformOp.z().shapeInfoDataBuffer(), context).address();
        long[] extraPointers = {
            AddressRetriever.retrieveHostAddress(transformOp.x().shapeInfoDataBuffer()),
            context.getOldStream().getNativePointer(),
            allocator.getDeviceId().intValue(),
            context.getBufferAllocation(),
            context.getBufferReduction(),
            context.getBufferScalar(),
            context.getBufferSpecial(),
            slot7,
            zHostShapeInfo,
            tad[0], tad[1], tad[2], tad[3], tad[4], tad[5],
            dimensionDevicePointer,
            dimensionHostPointer
        };

        if (transformOp.y() != null) {
            long yPointer = AtomicAllocator.getInstance().getPointer(transformOp.y(), context).address();
            long yShapeInfoPointer = AtomicAllocator.getInstance().getPointer(transformOp.y().shapeInfoDataBuffer(), context).address();
            boolean strided = canUseStridedTransform(transformOp);
            if (transformOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (strided) {
                    nativeOps.execPairwiseTransformDouble(extraPointers, transformOp.opNum(), xPointer, transformOp.x().elementWiseStride(), yPointer, transformOp.y().elementWiseStride(), zPointer, transformOp.z().elementWiseStride(), extraArgsPointer, transformOp.n());
                } else {
                    nativeOps.execPairwiseTransformDouble(extraPointers, transformOp.opNum(), xPointer, xShapeInfoPointer, yPointer, yShapeInfoPointer, zPointer, zShapeInfoPointer, extraArgsPointer);
                }
            } else if (strided) {
                nativeOps.execPairwiseTransformFloat(extraPointers, transformOp.opNum(), xPointer, transformOp.x().elementWiseStride(), yPointer, transformOp.y().elementWiseStride(), zPointer, transformOp.z().elementWiseStride(), extraArgsPointer, transformOp.n());
            } else {
                nativeOps.execPairwiseTransformFloat(extraPointers, transformOp.opNum(), xPointer, xShapeInfoPointer, yPointer, yShapeInfoPointer, zPointer, zShapeInfoPointer, extraArgsPointer);
            }
        } else {
            boolean strided = canUseStridedTransform(transformOp);
            if (transformOp.x().data().dataType() == DataBuffer.Type.DOUBLE) {
                if (strided) {
                    nativeOps.execTransformDouble(extraPointers, transformOp.opNum(), xPointer, transformOp.x().elementWiseStride(), zPointer, transformOp.z().elementWiseStride(), extraArgsPointer, transformOp.n());
                } else {
                    nativeOps.execTransformDouble(extraPointers, transformOp.opNum(), xPointer, xShapeInfoPointer, zPointer, zShapeInfoPointer, extraArgsPointer);
                }
            } else if (strided) {
                nativeOps.execTransformFloat(extraPointers, transformOp.opNum(), xPointer, transformOp.x().elementWiseStride(), zPointer, transformOp.z().elementWiseStride(), extraArgsPointer, transformOp.n());
            } else {
                nativeOps.execTransformFloat(extraPointers, transformOp.opNum(), xPointer, xShapeInfoPointer, zPointer, zShapeInfoPointer, extraArgsPointer);
            }
        }
        allocator.registerAction(context, transformOp.z(), transformOp.x(), transformOp.y());
        return null;
    }

    /**
     * Decides whether a transform may use the element-wise-stride native entry point.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>the op is not "exec special";</li>
     *   <li>x and z have an element-wise stride of at least 1;</li>
     *   <li>z has x's ordering;</li>
     *   <li>for pairwise ops, y also has a stride of at least 1, the same stride as x, and x's
     *       ordering.</li>
     * </ul>
     *
     * <p>The z-stride check keeps an invalid (&lt; 1) z stride away from the strided kernel. The
     * equal-stride rule is applied to both data types; previously only float required it. The
     * shape-info path is already the one used for mismatched orderings and exec-special ops, so
     * sending more cases to it is the conservative choice.
     */
    private static boolean canUseStridedTransform(TransformOp transformOp) {
        INDArray x = transformOp.x();
        INDArray y = transformOp.y();
        INDArray z = transformOp.z();
        if (transformOp.isExecSpecial() || x.elementWiseStride() < 1 || z.elementWiseStride() < 1 || x.ordering() != z.ordering()) {
            return false;
        }
        return y == null || (y.elementWiseStride() >= 1 && y.elementWiseStride() == x.elementWiseStride() && y.ordering() == x.ordering());
    }

    /**
     * Resolves TAD shape-info and offset pointers for transform ops 38-41.
     *
     * <p>Ops 38-40 get TADs of x along dimensions {0} and {1}. Op 41 gets a TAD of z along
     * {@code dimension}.
     *
     * <p>NOTE(review): {@code dimension} is null for op 41 without extra arguments; confirm the
     * TAD manager accepts that.
     *
     * @return six pointers in extra-pointer order: {@code [tad0 host shape, tad0 device shape,
     *     tad0 device offsets, tad1 host shape, tad1 device shape, tad1 device offsets]}. All are
     *     zero for other ops; only the first three are filled for op 41. Offsets are zero when
     *     the TAD manager returns none.
     */
    private static long[] resolveTransformTadPointers(TransformOp transformOp, int[] dimension, CudaContext context) {
        long[] pointers = new long[6];
        int opNum = transformOp.opNum();
        if (opNum < 38 || opNum > 41) {
            return pointers;
        }
        if (opNum != 41) {
            Pair<DataBuffer, DataBuffer> first = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[] {0});
            Pair<DataBuffer, DataBuffer> second = tadManager.getTADOnlyShapeInfo(transformOp.x(), new int[] {1});
            pointers[0] = AddressRetriever.retrieveHostAddress(first.getFirst());
            pointers[1] = AtomicAllocator.getInstance().getPointer(first.getFirst(), context).address();
            pointers[3] = AddressRetriever.retrieveHostAddress(second.getFirst());
            pointers[4] = AtomicAllocator.getInstance().getPointer(second.getFirst(), context).address();
            DataBuffer firstOffsets = first.getSecond();
            pointers[2] = firstOffsets == null ? 0L : AtomicAllocator.getInstance().getPointer(firstOffsets, context).address();
            DataBuffer secondOffsets = second.getSecond();
            pointers[5] = secondOffsets == null ? 0L : AtomicAllocator.getInstance().getPointer(secondOffsets, context).address();
        } else {
            Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(transformOp.z(), dimension);
            pointers[0] = AddressRetriever.retrieveHostAddress(tad.getFirst());
            pointers[1] = AtomicAllocator.getInstance().getPointer(tad.getFirst(), context).address();
            DataBuffer offsets = tad.getSecond();
            pointers[2] = offsets == null ? 0L : AtomicAllocator.getInstance().getPointer(offsets, context).address();
        }
        return pointers;
    }

    /**
     * Exposes the TAD manager shared by every executioner instance.
     *
     * @return the class-wide {@link TADManager}, created once when the class is initialized
     */
    public static TADManager getTadManager() {
        return JCudaExecutioner.tadManager;
    }
}
