package org.nd4j.linalg.jcublas.ops.executioner;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.grid.GridPointers;
import org.nd4j.linalg.api.ops.grid.OpDescriptor;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.aggregates.AggregateDescriptor;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.Nd4jCuda;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.class */
public class CudaGridExecutioner extends CudaExecutioner implements GridExecutioner {
    /** Class-wide SLF4J logger. */
    private static final Logger logger = LoggerFactory.getLogger(CudaGridExecutioner.class);

    /** Per-thread slot for the single op that has been enqueued (see enqueueOp) but not executed yet. */
    private final ThreadLocal<OpDescriptor> lastOp = new ThreadLocal<>();

    /** Per-thread deque of descriptors; only initialised in the constructor, not read elsewhere in this class. */
    private final ThreadLocal<Deque<OpDescriptor>> deviceQueues = new ThreadLocal<>();

    /** Per-thread counter, lazily created in aggregate(Aggregate, long) and used to number aggregate descriptors. */
    private final ThreadLocal<AtomicLong> opCounter = new ThreadLocal<>();

    /** Number of MetaOp instances that went through pushToGrid. */
    private final AtomicLong metaCounter = new AtomicLong(0);

    /** Number of pushToGrid(OpDescriptor, boolean) invocations. */
    private final AtomicLong execCounter = new AtomicLong(0);

    /** Arrays registered through addToWatchdog; checked against every op in invokeWatchdog. */
    private final List<WatchdogPair> watchdog = new CopyOnWriteArrayList<>();

    /** One aggregate queue per available device, created in the constructor. */
    private final List<Queue<AggregateDescriptor>> aggregates = new ArrayList<>();

    /** Mirrors nativeOps.isExperimentalEnabled() as read at construction time. */
    private final AtomicBoolean experimental = new AtomicBoolean(false);

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner$MetaType.class */
    /**
     * Outcome of getMetaOpType: how (if at all) the pending op and the incoming op
     * can be combined into a single meta op by processAsGridOp.
     */
    public enum MetaType {
        /** No combination; both ops are executed (or enqueued) separately. */
        NOT_APPLICABLE,
        /** Combined through a PredicateMetaOp. */
        PREDICATE,
        /** Combined through an InvertedPredicateMetaOp. */
        INVERTED_PREDICATE,
        /** Combined through a PostulateMetaOp. */
        POSTULATE
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner$WatchdogPair.class */
    /**
     * Mutable value holder pairing a watched array with a human-readable tag.
     * Equality and hashing are value-based over both properties.
     */
    public static class WatchdogPair {
        private INDArray array;
        private String tag;

        /** Creates an empty pair; both properties start out as {@code null}. */
        public WatchdogPair() {
        }

        /**
         * @param iNDArray array to watch
         * @param str      tag describing the array
         */
        public WatchdogPair(INDArray iNDArray, String str) {
            this.array = iNDArray;
            this.tag = str;
        }

        /** @return the watched array, possibly {@code null} */
        public INDArray getArray() {
            return this.array;
        }

        /** @param iNDArray the array to watch */
        public void setArray(INDArray iNDArray) {
            this.array = iNDArray;
        }

        /** @return the tag, possibly {@code null} */
        public String getTag() {
            return this.tag;
        }

        /** @param str the new tag */
        public void setTag(String str) {
            this.tag = str;
        }

        /** Subclass hook used by {@link #equals(Object)} to keep equality symmetric. */
        protected boolean canEqual(Object obj) {
            return obj instanceof WatchdogPair;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof WatchdogPair)) {
                return false;
            }
            WatchdogPair other = (WatchdogPair) obj;
            if (!other.canEqual(this)) {
                return false;
            }

            INDArray mine = getArray();
            INDArray theirs = other.getArray();
            boolean sameArray = (mine == null) ? (theirs == null) : mine.equals(theirs);
            if (!sameArray) {
                return false;
            }

            String myTag = getTag();
            String theirTag = other.getTag();
            if (myTag == null) {
                return theirTag == null;
            }
            return myTag.equals(theirTag);
        }

        @Override
        public int hashCode() {
            INDArray currentArray = getArray();
            int result = 59 + (currentArray == null ? 43 : currentArray.hashCode());
            String currentTag = getTag();
            result = result * 59 + (currentTag == null ? 43 : currentTag.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "CudaGridExecutioner.WatchdogPair(array=" + getArray() + ", tag=" + getTag() + ")";
        }
    }

    /**
     * Creates the executioner: initialises the calling thread's descriptor deque,
     * allocates one aggregate queue per device reported by
     * {@code nativeOps.getAvailableDevices()}, and caches the native
     * "experimental" flag.
     */
    public CudaGridExecutioner() {
        // Explicit type arguments instead of raw types, so the collections are checked by the compiler.
        this.deviceQueues.set(new ArrayDeque<OpDescriptor>());
        int availableDevices = nativeOps.getAvailableDevices();
        for (int i = 0; i < availableDevices; i++) {
            this.aggregates.add(new ConcurrentLinkedQueue<AggregateDescriptor>());
        }
        this.experimental.set(nativeOps.isExperimentalEnabled());
    }

    /**
     * Entry point for a single op. After the compression check and the watchdog
     * check, the op is routed by its interface: accumulations and index
     * accumulations go to the dimension-aware overloads with
     * {@code Nd4jCuda.MAX_DIMENSION}, scalar/transform ops go through
     * {@link #processAsGridOp}, broadcast ops through {@code invoke}, and anything
     * else is pushed to the grid directly.
     *
     * @param op the op to execute
     * @return the same {@code op} instance
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public Op exec(Op op) {
        checkForCompression(op);
        invokeWatchdog(op);

        if (op instanceof Accumulation) {
            exec((Accumulation) op, Nd4jCuda.MAX_DIMENSION);
            return op;
        }
        if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op, Nd4jCuda.MAX_DIMENSION);
            return op;
        }

        boolean elementWise = op instanceof ScalarOp || op instanceof TransformOp;
        if (elementWise) {
            processAsGridOp(op, new int[0]);
            return op;
        }
        if (op instanceof BroadcastOp) {
            invoke((BroadcastOp) op);
            return op;
        }

        pushToGrid(new OpDescriptor(op));
        return op;
    }

    /**
     * Convenience overload: pushes the descriptor with the flush flag set to {@code true}.
     *
     * @param opDescriptor descriptor of the op to execute
     */
    protected void pushToGrid(OpDescriptor opDescriptor) {
        pushToGrid(opDescriptor, true);
    }

    /**
     * Checks the op against every array registered via addToWatchdog.
     * <p>
     * An op that references the watched array instance itself is accepted. An op
     * whose operands merely share a device pointer with the watched array is
     * rejected with an exception.
     *
     * @param op the op about to be executed
     * @throws RuntimeException if an operand aliases a watched array's device pointer
     *                          without being the same array instance
     */
    protected void invokeWatchdog(Op op) {
        if (this.watchdog.isEmpty()) {
            return;
        }
        for (WatchdogPair pair : this.watchdog) {
            INDArray watched = pair.getArray();
            if (compareArrays(watched, op)) {
                // Same JVM object: nothing suspicious.
                continue;
            }
            if (compareDevicePointers(watched, op)) {
                throw new RuntimeException("Watchdog: op [" + op.getClass().getSimpleName()
                        + "] uses the device pointer of watched array [" + pair.getTag()
                        + "] through a different array instance");
            }
            // The host-pointer comparison is still evaluated, as before, but its result is ignored.
            compareHostPointers(watched, op);
        }
    }

    /**
     * Tells whether any operand of {@code op} resolves to the same pointer address
     * as {@code iNDArray} in the current device context. Operands are checked in
     * the order z, y, x; a {@code null} y is treated as address 0.
     *
     * @param iNDArray the array to compare against
     * @param op       the op whose operands are inspected
     * @return {@code true} on the first matching address
     */
    protected boolean compareDevicePointers(INDArray iNDArray, Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
        Pointer watched = allocator.getPointer(iNDArray, context);

        if (allocator.getPointer(op.z(), context).address() == watched.address()) {
            return true;
        }

        long yAddress = 0L;
        if (op.y() != null) {
            yAddress = allocator.getPointer(op.y(), context).address();
        }
        if (yAddress == watched.address()) {
            return true;
        }

        return allocator.getPointer(op.x(), context).address() == watched.address();
    }

    /**
     * Tells whether any operand of {@code op} shares its host pointer address with
     * {@code iNDArray}.
     * <p>
     * The reference address is the array's host pointer. The previous code took
     * the context-bound pointer from {@code getPointer(array, context)} and
     * compared it with host pointers of the operands. It also treated a missing
     * {@code y} as address 0; a {@code null} y is now simply skipped.
     *
     * @param iNDArray the array to compare against
     * @param op       the op whose operands (z, y, x) are inspected
     * @return {@code true} if z, y or x has the same host address as {@code iNDArray}
     */
    protected boolean compareHostPointers(INDArray iNDArray, Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        long watched = allocator.getHostPointer(iNDArray).address();

        if (allocator.getHostPointer(op.z()).address() == watched) {
            return true;
        }
        if (op.y() != null && allocator.getHostPointer(op.y()).address() == watched) {
            return true;
        }
        return allocator.getHostPointer(op.x()).address() == watched;
    }

    /**
     * Identity check: is {@code iNDArray} the very same object as one of the op's
     * x, y or z operands?
     *
     * @param iNDArray the array to look for
     * @param op       the op whose operands are inspected
     * @return {@code true} if any operand is reference-equal to {@code iNDArray}
     */
    protected boolean compareArrays(INDArray iNDArray, Op op) {
        if (op.x() == iNDArray) {
            return true;
        }
        if (op.y() == iNDArray) {
            return true;
        }
        return op.z() == iNDArray;
    }

    /**
     * Executes the described op immediately by delegating to the matching
     * superclass routine, chosen by the op's type. Every call increments the
     * execution counter; meta ops additionally increment the meta counter.
     * <p>
     * The type checks are ordered: {@code Variance} must be tested before the
     * general {@code Accumulation} case. Meta and grid ops never trigger a flush.
     * Ops matching none of the known types are ignored.
     *
     * @param opDescriptor op plus optional dimensions
     * @param z            when {@code true}, the pending queued op is flushed before
     *                     a regular (non-meta, non-grid) op is executed
     */
    protected void pushToGrid(OpDescriptor opDescriptor, boolean z) {
        this.execCounter.incrementAndGet();

        // The op is held as the general Op type and narrowed per branch.
        Op op = opDescriptor.getOp();
        int[] dimensions = opDescriptor.getDimensions();

        if (op instanceof TransformOp) {
            if (z) {
                flushQueue();
            }
            super.invoke((TransformOp) op);
        } else if (op instanceof Variance) {
            if (z) {
                flushQueue();
            }
            super.naiveExec((Variance) op, dimensions);
        } else if (op instanceof Accumulation) {
            if (z) {
                flushQueue();
            }
            super.naiveExec((Accumulation) op, dimensions);
        } else if (op instanceof ScalarOp) {
            if (z) {
                flushQueue();
            }
            super.invoke((ScalarOp) op);
        } else if (op instanceof BroadcastOp) {
            if (z) {
                flushQueue();
            }
            BroadcastOp broadcastOp = (BroadcastOp) op;
            if (dimensions != null) {
                super.exec(broadcastOp, dimensions);
            } else {
                super.invoke(broadcastOp);
            }
        } else if (op instanceof IndexAccumulation) {
            if (z) {
                flushQueue();
            }
            super.exec((IndexAccumulation) op, dimensions);
        } else if (op instanceof MetaOp) {
            this.metaCounter.incrementAndGet();
            exec((MetaOp) op);
        } else if (op instanceof GridOp) {
            exec((GridOp) op);
        }
    }

    /**
     * @return how many meta ops have been pushed through pushToGrid so far
     */
    public long getMetaCounter() {
        return this.metaCounter.get();
    }

    /**
     * @return how many times pushToGrid(OpDescriptor, boolean) has been invoked so far
     */
    public long getExecutionCounter() {
        return this.execCounter.get();
    }

    /**
     * Core scheduling step. With no op pending on this thread, the incoming op is
     * either enqueued (pairwise transform whose operands are on the current
     * device) or executed right away. With an op pending, getMetaOpType decides
     * whether the two are combined into a meta op or executed separately.
     *
     * @param op   the incoming op
     * @param iArr dimensions for the op, may be {@code null} or empty
     * @throws UnsupportedOperationException for a meta type this method does not handle
     */
    protected void processAsGridOp(Op op, int... iArr) {
        OpDescriptor pending = this.lastOp.get();

        if (pending == null) {
            boolean deferrable = (op instanceof TransformOp) && op.y() != null && onCurrentDeviceXYZ(op);
            if (deferrable) {
                enqueueOp(new OpDescriptor(op, iArr));
            } else {
                pushToGrid(new OpDescriptor(op, iArr), false);
            }
            return;
        }

        // The meta type must be resolved while the pending op is still registered.
        MetaType metaOpType = getMetaOpType(op, iArr);
        this.lastOp.remove();

        switch (metaOpType) {
            case PREDICATE: {
                OpDescriptor current = new OpDescriptor(op, iArr);
                PredicateMetaOp fused = new PredicateMetaOp(pending, current);
                pushToGrid(new OpDescriptor(fused), false);
                break;
            }
            case INVERTED_PREDICATE: {
                OpDescriptor current = new OpDescriptor(op, iArr);
                dequeueOp(pending);
                dequeueOp(current);
                InvertedPredicateMetaOp fused = new InvertedPredicateMetaOp(pending, current);
                pushToGrid(new OpDescriptor(fused), false);
                break;
            }
            case POSTULATE: {
                OpDescriptor current = new OpDescriptor(op, iArr);
                PostulateMetaOp fused = new PostulateMetaOp(pending, current);
                pushToGrid(new OpDescriptor(fused), false);
                break;
            }
            case NOT_APPLICABLE: {
                // Run the pending op first, then treat the new op as if nothing was pending.
                dequeueOp(pending);
                pushToGrid(pending, false);
                boolean deferrable = (op instanceof TransformOp) && op.y() != null && onCurrentDeviceXYZ(op);
                if (deferrable) {
                    enqueueOp(new OpDescriptor(op, iArr));
                } else {
                    pushToGrid(new OpDescriptor(op, iArr), false);
                }
                break;
            }
            default:
                throw new UnsupportedOperationException("Not supported MetaType: [" + metaOpType + "]");
        }
    }

    /**
     * Tells whether all three operands of {@code op} are on the device of the
     * calling thread.
     * <p>
     * The "Z" device id used to be read from {@code op.y()}, so z's location was
     * never checked. It is now read from {@code op.z()}.
     *
     * @param op op with non-null x, y and z operands
     * @return {@code true} if x, y and z all report the current thread's device id
     */
    protected boolean onCurrentDeviceXYZ(Op op) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        int currentDevice = allocator.getDeviceId().intValue();
        int deviceX = allocator.getDeviceId(op.x()).intValue();
        int deviceY = allocator.getDeviceId(op.y()).intValue();
        int deviceZ = allocator.getDeviceId(op.z()).intValue();
        return currentDevice == deviceX && deviceY == deviceZ && deviceZ == deviceX;
    }

    /**
     * Defers the op: marks the allocation points of x, z and (when present) y as
     * enqueued, then stores the descriptor as this thread's pending op.
     *
     * @param opDescriptor descriptor of the op to defer
     */
    protected void enqueueOp(OpDescriptor opDescriptor) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        Op deferred = opDescriptor.getOp();

        allocator.getAllocationPoint(deferred.x()).markEnqueued(true);
        allocator.getAllocationPoint(deferred.z()).markEnqueued(true);
        if (deferred.y() != null) {
            allocator.getAllocationPoint(deferred.y()).markEnqueued(true);
        }

        this.lastOp.set(opDescriptor);
    }

    /**
     * Clears the "enqueued" mark on the allocation points of the op's x, z and
     * (when present) y operands. Does not touch the pending-op slot.
     *
     * @param opDescriptor descriptor of the op being released
     */
    protected void dequeueOp(OpDescriptor opDescriptor) {
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        Op released = opDescriptor.getOp();

        allocator.getAllocationPoint(released.x()).markEnqueued(false);
        allocator.getAllocationPoint(released.z()).markEnqueued(false);
        if (released.y() != null) {
            allocator.getAllocationPoint(released.y()).markEnqueued(false);
        }
    }

    /**
     * Decides how the pending op and the incoming {@code op} can be combined.
     * <p>
     * Experimental mode: a pending scalar/transform op yields PREDICATE; a pending
     * accumulation followed by a scalar/transform op without y yields POSTULATE.
     * Regular mode: a pending pairwise transform followed by a dimension-less
     * scalar op (other than ScalarMax/ScalarMin and a set of excluded op numbers)
     * yields INVERTED_PREDICATE. In every case the operands must also satisfy
     * {@link #isMatchingZX}; otherwise the result is NOT_APPLICABLE.
     *
     * @param op   the incoming op
     * @param iArr dimensions of the incoming op (not consulted here)
     * @return the applicable meta type, never {@code null}
     */
    protected MetaType getMetaOpType(Op op, int... iArr) {
        OpDescriptor pending = this.lastOp.get();
        if (pending == null) {
            return MetaType.NOT_APPLICABLE;
        }
        Op pendingOp = pending.getOp();

        if (this.experimental.get()) {
            logger.info("Experimental hook");
            MetaType candidate;
            if ((pendingOp instanceof ScalarOp) || (pendingOp instanceof TransformOp)) {
                candidate = MetaType.PREDICATE;
            } else if ((pendingOp instanceof Accumulation)
                    && ((op instanceof ScalarOp) || (op instanceof TransformOp))
                    && op.y() == null) {
                candidate = MetaType.POSTULATE;
            } else {
                return MetaType.NOT_APPLICABLE;
            }
            return isMatchingZX(pendingOp, op) ? candidate : MetaType.NOT_APPLICABLE;
        }

        boolean pendingIsPairwiseTransform = (pendingOp instanceof TransformOp) && pendingOp.y() != null;
        if (!pendingIsPairwiseTransform || !(op instanceof ScalarOp)) {
            return MetaType.NOT_APPLICABLE;
        }
        if (((ScalarOp) op).getDimension() != null || (op instanceof ScalarMax) || (op instanceof ScalarMin)) {
            return MetaType.NOT_APPLICABLE;
        }

        // Scalar op numbers that are never fused: 7..11, 13, 16 and 56..59.
        int opNum = op.opNum();
        boolean excluded = (opNum >= 7 && opNum <= 11) || opNum == 16 || opNum == 13 || (opNum >= 56 && opNum <= 59);
        if (excluded) {
            return MetaType.NOT_APPLICABLE;
        }

        return isMatchingZX(pendingOp, op) ? MetaType.INVERTED_PREDICATE : MetaType.NOT_APPLICABLE;
    }

    /**
     * Checks, by reference, that both ops share the same x, the same z, and that
     * the first op's x is also the second op's z (i.e. both work in place on one
     * array instance).
     *
     * @param op  the earlier op
     * @param op2 the later op
     * @return {@code true} when all three identity comparisons hold
     */
    protected boolean isMatchingZX(Op op, Op op2) {
        if (op.x() != op2.x()) {
            return false;
        }
        if (op.z() != op2.z()) {
            return false;
        }
        return op.x() == op2.z();
    }

    /**
     * Checks, by reference, whether the first op's z feeds the second op as its
     * x or its y operand.
     *
     * @param op  the earlier op
     * @param op2 the later op
     * @return {@code true} if {@code op.z()} is {@code op2.x()} or {@code op2.y()}
     */
    protected boolean isMatchingZXY(Op op, Op op2) {
        if (op.z() == op2.x()) {
            return true;
        }
        return op.z() == op2.y();
    }

    /**
     * Convenience overload that unpacks the descriptor into its op and dimensions.
     *
     * @param opDescriptor descriptor to resolve
     * @return pointers for the described op
     */
    protected GridPointers pointerizeOp(OpDescriptor opDescriptor) {
        return pointerizeOp(opDescriptor.getOp(), opDescriptor.getDimensions());
    }

    /**
     * Resolves the pointers an op needs in the current device context: data and
     * shape-info pointers for x and z (plus y when present), and, when dimensions
     * are given, the dimensions buffer together with the TAD shape and offsets
     * obtained from the TAD manager.
     *
     * @param op   the op to resolve
     * @param iArr dimensions, may be {@code null} or empty
     * @return a populated {@link GridPointers}
     */
    protected GridPointers pointerizeOp(Op op, int... iArr) {
        GridPointers result = new GridPointers(op, iArr);
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();

        result.setX(allocator.getPointer(op.x(), context));
        result.setXShapeInfo(allocator.getPointer(op.x().shapeInfoDataBuffer(), context));
        result.setZ(allocator.getPointer(op.z(), context));
        result.setZShapeInfo(allocator.getPointer(op.z().shapeInfoDataBuffer(), context));
        result.setZLength(op.z().length());

        if (op.y() != null) {
            result.setY(allocator.getPointer(op.y(), context));
            result.setYShapeInfo(allocator.getPointer(op.y().shapeInfoDataBuffer(), context));
        }

        boolean hasDimensions = iArr != null && iArr.length > 0;
        if (!hasDimensions) {
            return result;
        }

        result.setDimensions(allocator.getPointer(Nd4j.getConstantHandler().getConstantBuffer(iArr), context));
        result.setDimensionsLength(iArr.length);

        Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), iArr);
        Pointer tadShape = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffers.getFirst(), context);
        Pointer tadOffsets = null;
        if (tadBuffers.getSecond() != null) {
            tadOffsets = AtomicAllocator.getInstance().getPointer((DataBuffer) tadBuffers.getSecond(), context);
        }
        result.setTadShape(tadShape);
        result.setTadOffsets(tadOffsets);
        return result;
    }

    /**
     * @return 1 when the calling thread has a pending (enqueued) op, 0 otherwise
     */
    public int getQueueLength() {
        if (this.lastOp.get() != null) {
            return 1;
        }
        return 0;
    }

    /**
     * Deprecated stub; ignores its argument.
     *
     * @param i unused
     * @return always -1
     */
    @Deprecated
    protected int getQueueLength(int i) {
        return -1;
    }

    /**
     * Placeholder with no implementation.
     *
     * @return always {@code null}
     */
    protected GridOp buildGrid() {
        return null;
    }

    /**
     * Prepares the result array of an index accumulation.
     * <p>
     * The dimensions array is sorted and its negative entries are shifted by the
     * rank of x <em>in place</em>, so the caller sees the normalised values. If z
     * is missing or aliases x, a new array of the expected shape is created;
     * otherwise the existing z must have exactly the expected shape.
     *
     * @param indexAccumulation the op whose z is prepared
     * @param iArr              reduction dimensions (mutated in place)
     * @throws IllegalStateException if a user-supplied z has an unexpected shape
     */
    protected void buildZ(IndexAccumulation indexAccumulation, int... iArr) {
        Arrays.sort(iArr);
        for (int idx = 0; idx < iArr.length; idx++) {
            if (iArr[idx] < 0) {
                iArr[idx] += indexAccumulation.x().rank();
            }
        }

        // Reducing over every dimension equals reducing over the whole array.
        int[] dims = iArr;
        if (dims.length == indexAccumulation.x().rank()) {
            dims = new int[]{Nd4jCuda.MAX_DIMENSION};
        }

        int[] expectedShape;
        if (Shape.wholeArrayDimension(dims)) {
            expectedShape = new int[]{1, 1};
        } else {
            expectedShape = ArrayUtil.removeIndex(indexAccumulation.x().shape(), dims);
        }

        if (expectedShape.length == 1) {
            if (dims[0] == 0) {
                expectedShape = new int[]{1, expectedShape[0]};
            } else {
                expectedShape = new int[]{expectedShape[0], 1};
            }
        } else if (expectedShape.length == 0) {
            expectedShape = new int[]{1, 1};
        }

        boolean needsFreshZ = indexAccumulation.z() == null || indexAccumulation.z() == indexAccumulation.x();
        if (needsFreshZ) {
            INDArray fresh;
            if (Math.abs(indexAccumulation.zeroDouble()) < Nd4j.EPS_THRESHOLD) {
                fresh = Nd4j.zeros(expectedShape);
            } else {
                fresh = Nd4j.valueArrayOf(expectedShape, indexAccumulation.zeroDouble());
            }
            indexAccumulation.setZ(fresh);
            return;
        }

        if (!Arrays.equals(expectedShape, indexAccumulation.z().shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + indexAccumulation + ": expected shape " + Arrays.toString(expectedShape) + ", z.shape()=" + Arrays.toString(indexAccumulation.z().shape()));
        }
    }

    /**
     * Prepares the result array of an accumulation.
     * <p>
     * The dimensions array is sorted and its negative entries are shifted by the
     * rank of x <em>in place</em>. If z is missing or aliases x, a new result
     * array is created. Otherwise the existing z must have the expected number
     * of elements and is reset to the op's zero value for DOUBLE, FLOAT and HALF
     * inputs.
     *
     * @param accumulation the op whose z is prepared
     * @param iArr         reduction dimensions (mutated in place)
     * @throws ND4JIllegalStateException if a user-supplied z has the wrong length
     */
    protected void buildZ(Accumulation accumulation, int... iArr) {
        Arrays.sort(iArr);
        for (int idx = 0; idx < iArr.length; idx++) {
            if (iArr[idx] < 0) {
                iArr[idx] += accumulation.x().rank();
            }
        }

        // Reducing over every dimension equals reducing over the whole array.
        int[] dims = iArr;
        if (dims.length == accumulation.x().rank()) {
            dims = new int[]{Nd4jCuda.MAX_DIMENSION};
        }

        int[] expectedShape;
        if (Shape.wholeArrayDimension(dims)) {
            expectedShape = new int[]{1, 1};
        } else {
            expectedShape = ArrayUtil.removeIndex(accumulation.x().shape(), dims);
        }

        if (expectedShape.length == 1) {
            if (dims[0] == 0) {
                expectedShape = new int[]{1, expectedShape[0]};
            } else {
                expectedShape = new int[]{expectedShape[0], 1};
            }
        } else if (expectedShape.length == 0) {
            expectedShape = new int[]{1, 1};
        }

        boolean needsFreshZ = accumulation.z() == null || accumulation.z() == accumulation.x();
        if (needsFreshZ) {
            INDArray fresh;
            if (accumulation.isComplexAccumulation()) {
                fresh = Nd4j.create(accumulation.x().tensorssAlongDimension(dims), accumulation.y().tensorssAlongDimension(dims));
            } else if (Math.abs(accumulation.zeroDouble()) < Nd4j.EPS_THRESHOLD) {
                fresh = Nd4j.zeros(expectedShape);
            } else {
                fresh = Nd4j.valueArrayOf(expectedShape, accumulation.zeroDouble());
            }
            accumulation.setZ(fresh);
            return;
        }

        if (accumulation.z().lengthLong() != ArrayUtil.prodLong(expectedShape)) {
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(accumulation.z().shape()) + "] doesn't match expected [" + Arrays.toString(expectedShape) + "]");
        }

        // Reset the caller-supplied target to the op's neutral element.
        if (accumulation.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            accumulation.z().assign(Double.valueOf(accumulation.zeroDouble()));
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.FLOAT) {
            accumulation.z().assign(Float.valueOf(accumulation.zeroFloat()));
        } else if (accumulation.x().data().dataType() == DataBuffer.Type.HALF) {
            accumulation.z().assign(Float.valueOf(accumulation.zeroHalf()));
        }

        // Kept from the original: z is read once more and the result is discarded.
        accumulation.z();
    }

    /**
     * Flushes this thread's pending op, then delegates to the superclass.
     *
     * @param op   the op to execute
     * @param iArr dimensions forwarded unchanged
     * @return whatever the superclass implementation returns
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public Op exec(Op op, int... iArr) {
        flushQueue();
        return super.exec(op, iArr);
    }

    /**
     * Executes an accumulation. Without usable dimensions (null, empty, or first
     * entry equal to {@code Integer.MAX_VALUE}) the pending op is flushed and the
     * superclass runs the op with {@code Nd4jCuda.MAX_DIMENSION}. Otherwise z is
     * prepared and the op is scheduled through processAsGridOp.
     *
     * @param accumulation the op to execute
     * @param iArr         reduction dimensions
     * @return the op's z array
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public INDArray exec(Accumulation accumulation, int... iArr) {
        boolean alongDimensions = iArr != null && iArr.length != 0 && iArr[0] != Integer.MAX_VALUE;
        if (alongDimensions) {
            buildZ(accumulation, iArr);
            processAsGridOp(accumulation, iArr);
            return accumulation.z();
        }

        flushQueue();
        super.exec(accumulation, Nd4jCuda.MAX_DIMENSION);
        return accumulation.z();
    }

    /**
     * Executes an index accumulation. Without usable dimensions (null, empty, or
     * first entry equal to {@code Integer.MAX_VALUE}) the pending op is flushed, z
     * is prepared for {@code Nd4jCuda.MAX_DIMENSION} and the superclass invokes
     * the op directly. Otherwise z is prepared and the op is scheduled through
     * processAsGridOp.
     *
     * @param indexAccumulation the op to execute
     * @param iArr              reduction dimensions
     * @return the op's z array
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        boolean alongDimensions = iArr != null && iArr.length != 0 && iArr[0] != Integer.MAX_VALUE;
        if (alongDimensions) {
            buildZ(indexAccumulation, iArr);
            processAsGridOp(indexAccumulation, iArr);
            return indexAccumulation.z();
        }

        flushQueue();
        buildZ(indexAccumulation, Nd4jCuda.MAX_DIMENSION);
        super.invoke(indexAccumulation, new int[]{Nd4jCuda.MAX_DIMENSION});
        return indexAccumulation.z();
    }

    /**
     * Schedules a broadcast op through processAsGridOp.
     *
     * @param broadcastOp the op to execute
     * @param iArr        broadcast dimensions
     * @return the op's z array
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        processAsGridOp(broadcastOp, iArr);
        return broadcastOp.z();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Schedules a broadcast op through processAsGridOp using the op's own dimensions.
     *
     * @param broadcastOp the op to execute
     * @return always {@code null}
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public CudaContext invoke(BroadcastOp broadcastOp) {
        processAsGridOp(broadcastOp, broadcastOp.getDimension());
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Schedules a scalar op through processAsGridOp with a {@code null} dimensions array.
     *
     * @param scalarOp the op to execute
     * @return always {@code null}
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public CudaContext invoke(ScalarOp scalarOp) {
        processAsGridOp(scalarOp, null);
        return null;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Handles a transform op. Ops flagged as "exec special" bypass the grid: the
     * pending op is flushed and the superclass runs the transform. All other
     * transforms are scheduled through processAsGridOp with a {@code null}
     * dimensions array.
     *
     * @param transformOp the op to execute
     * @return always {@code null}
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public CudaContext invoke(TransformOp transformOp) {
        if (transformOp.isExecSpecial()) {
            flushQueue();
            super.invoke(transformOp);
        } else {
            processAsGridOp(transformOp, null);
        }
        return null;
    }

    /**
     * Resolves the pointers of both constituent ops of a meta op and attaches
     * them to it. Both descriptors are resolved before either result is stored.
     *
     * @param metaOp the meta op to prepare
     */
    protected void prepareGrid(MetaOp metaOp) {
        GridPointers firstPointers = pointerizeOp(metaOp.getFirstOpDescriptor());
        GridPointers secondPointers = pointerizeOp(metaOp.getSecondOpDescriptor());

        metaOp.setFirstPointers(firstPointers);
        metaOp.setSecondPointers(secondPointers);
    }

    /**
     * Executes a meta op (two ops fused into one native call).
     * <p>
     * Local naming: {@code gridPointers} belongs to the first op,
     * {@code gridPointers2} to the second, and {@code gridPointers3} is whichever
     * of the two supplies the y operand (the second op if it has a y, else the
     * first). Predicate and inverted-predicate meta ops are dispatched by the
     * first op's data type (FLOAT, DOUBLE, otherwise the Half entry points) and by
     * layout: the "Shape" variant when x/y ordering differs or a stride is below
     * 1, the "Strided" variant otherwise. Reduce meta ops are dispatched for FLOAT
     * only; any other combination results in no native call.
     *
     * @param metaOp the meta op to execute
     */
    public void exec(MetaOp metaOp) {
        // Lazily create this thread's extras holder.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        prepareGrid(metaOp);
        GridPointers gridPointers = (GridPointers) metaOp.getGridDescriptor().getGridPointers().get(0);
        GridPointers gridPointers2 = (GridPointers) metaOp.getGridDescriptor().getGridPointers().get(1);
        CudaContext prepareAction = AtomicAllocator.getInstance().getFlowController().prepareAction(gridPointers.getOpZ(), gridPointers.getOpY());
        PointerPointer put = this.extraz.get().put(new Pointer[]{null, prepareAction.getOldStream()});

        // Scalar arguments default to 0 for non-scalar ops. scalar() is declared on
        // ScalarOp, so the ops are narrowed explicitly after the instanceof checks.
        double d = 0.0d;
        double d2 = 0.0d;
        if (metaOp.getFirstOp() instanceof ScalarOp) {
            d = ((ScalarOp) metaOp.getFirstOp()).scalar().doubleValue();
        }
        if (metaOp.getSecondOp() instanceof ScalarOp) {
            d2 = ((ScalarOp) metaOp.getSecondOp()).scalar().doubleValue();
        }

        // Pick the op that provides the y operand.
        GridPointers gridPointers3 = gridPointers;
        if (metaOp.getSecondOp().y() != null) {
            gridPointers3 = gridPointers2;
        }

        if ((metaOp instanceof PredicateMetaOp) || (metaOp instanceof InvertedPredicateMetaOp)) {
            if (gridPointers.getDtype() == DataBuffer.Type.FLOAT) {
                if (gridPointers3.getYOrder() != gridPointers3.getXOrder() || gridPointers3.getXStride() < 1 || gridPointers3.getYStride() < 1) {
                    nativeOps.execMetaPredicateShapeFloat(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXShapeInfo(), gridPointers3.getY(), gridPointers3.getYShapeInfo(), gridPointers2.getZ(), gridPointers2.getZShapeInfo(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), (float) d, (float) d2);
                } else {
                    nativeOps.execMetaPredicateStridedFloat(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXStride(), gridPointers3.getY(), gridPointers3.getYStride(), gridPointers2.getZ(), gridPointers2.getZStride(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), (float) d, (float) d2);
                }
            } else if (gridPointers.getDtype() == DataBuffer.Type.DOUBLE) {
                if (gridPointers3.getYOrder() != gridPointers3.getXOrder() || gridPointers3.getXStride() < 1 || gridPointers3.getYStride() < 1) {
                    nativeOps.execMetaPredicateShapeDouble(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXShapeInfo(), gridPointers3.getY(), gridPointers3.getYShapeInfo(), gridPointers2.getZ(), gridPointers2.getZShapeInfo(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), d, d2);
                } else {
                    nativeOps.execMetaPredicateStridedDouble(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXStride(), gridPointers3.getY(), gridPointers3.getYStride(), gridPointers2.getZ(), gridPointers2.getZStride(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), d, d2);
                }
            } else if (gridPointers3.getYOrder() != gridPointers3.getXOrder() || gridPointers3.getXStride() < 1 || gridPointers3.getYStride() < 1) {
                nativeOps.execMetaPredicateShapeHalf(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXShapeInfo(), gridPointers3.getY(), gridPointers3.getYShapeInfo(), gridPointers2.getZ(), gridPointers2.getZShapeInfo(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), (float) d, (float) d2);
            } else {
                nativeOps.execMetaPredicateStridedHalf(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getXLength(), gridPointers.getX(), gridPointers.getXStride(), gridPointers3.getY(), gridPointers3.getYStride(), gridPointers2.getZ(), gridPointers2.getZStride(), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), (float) d, (float) d2);
            }
        } else if ((metaOp instanceof ReduceMetaOp) && gridPointers.getDtype() == DataBuffer.Type.FLOAT) {
            nativeOps.execMetaPredicateReduceFloat(put, gridPointers.getType().ordinal(), gridPointers.getOpNum(), gridPointers2.getType().ordinal(), gridPointers2.getOpNum(), gridPointers.getX(), gridPointers.getXShapeInfo(), gridPointers2.getY(), gridPointers2.getYShapeInfo(), gridPointers2.getZ(), gridPointers2.getZShapeInfo(), gridPointers2.getDimensions(), gridPointers2.getDimensionsLength(), gridPointers2.getTadShape(), new LongPointerWrapper(gridPointers2.getTadOffsets()), gridPointers.getExtraArgs(), gridPointers2.getExtraArgs(), (float) d, 0.0f, false);
        }

        // Pair the prepareAction call above with its registration.
        AtomicAllocator.getInstance().getFlowController().registerAction(prepareAction, gridPointers.getOpZ(), gridPointers.getOpY());
    }

    /**
     * Grid op execution is not implemented; this method does nothing.
     *
     * @param gridOp ignored
     */
    public void exec(GridOp gridOp) {
    }

    /**
     * Drops this thread's pending op without executing it. Note that dequeueOp is
     * not called here, so the operands' "enqueued" marks are left untouched.
     */
    protected void purgeQueue() {
        this.lastOp.remove();
    }

    /**
     * Executes this thread's pending op, if there is one: the slot is cleared,
     * the operands are un-marked via dequeueOp and the op is pushed to the grid
     * without a further flush.
     *
     * @throws UnsupportedOperationException if an op is pending while the
     *                                       experimental mode is enabled
     */
    public void flushQueue() {
        OpDescriptor pending = this.lastOp.get();
        if (pending == null) {
            return;
        }
        if (this.experimental.get()) {
            throw new UnsupportedOperationException("Experimental flush isn't supported yet");
        }

        this.lastOp.remove();
        dequeueOp(pending);
        pushToGrid(pending, false);
    }

    /**
     * Flushes the pending op, then synchronises the "old" stream and the
     * "special" stream of the current device context, in that order.
     */
    public void flushQueueBlocking() {
        flushQueue();

        CudaContext contextForOldStream = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        contextForOldStream.syncOldStream();

        // The context is looked up again for the second sync, exactly as before.
        CudaContext contextForSpecialStream = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        contextForSpecialStream.syncSpecialStream();
    }

    /**
     * Registers an array with the watchdog; every subsequent op passed to
     * exec(Op) is checked against it in invokeWatchdog.
     *
     * @param iNDArray the array to watch
     * @param str      tag identifying the array
     */
    public void addToWatchdog(INDArray iNDArray, String str) {
        this.watchdog.add(new WatchdogPair(iNDArray, str));
    }

    /**
     * Executes a random op using the generator returned by {@code Nd4j.getRandom()}.
     *
     * @param randomOp the op to execute
     * @return the result of exec(RandomOp, Random)
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public INDArray exec(RandomOp randomOp) {
        return exec(randomOp, Nd4j.getRandom());
    }

    /**
     * Flushes this thread's pending op, then lets the superclass execute the batch.
     *
     * @param list aggregates to execute
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public void exec(List<Aggregate> list) {
        flushQueue();
        super.exec(list);
    }

    /**
     * Flushes this thread's pending op, then lets the superclass execute the aggregate.
     *
     * @param aggregate the aggregate to execute
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public void exec(Aggregate aggregate) {
        flushQueue();
        super.exec(aggregate);
    }

    /**
     * Queues an aggregate, using the current thread's id as the key.
     *
     * @param aggregate the aggregate to queue
     */
    public void aggregate(Aggregate aggregate) {
        aggregate(aggregate, Thread.currentThread().getId());
    }

    /**
     * Queues an aggregate on the queue of the device assigned to the calling
     * thread. Each descriptor carries the given key and the next value of this
     * thread's op counter, which is created on first use.
     *
     * @param aggregate the aggregate to queue
     * @param j         key stored in the descriptor
     */
    public void aggregate(Aggregate aggregate, long j) {
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();

        if (this.opCounter.get() == null) {
            this.opCounter.set(new AtomicLong(0L));
        }

        // Look up the queue before the counter is advanced, as the original expression did.
        Queue<AggregateDescriptor> deviceQueue = this.aggregates.get(deviceId);
        long sequence = this.opCounter.get().getAndIncrement();
        deviceQueue.add(new AggregateDescriptor(aggregate, j, sequence));
    }

    /**
     * Flushes this thread's pending op, then lets the superclass execute the random op.
     *
     * @param randomOp the op to execute
     * @param random   the generator to use
     * @return whatever the superclass implementation returns
     */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public INDArray exec(RandomOp randomOp, Random random) {
        flushQueue();
        return super.exec(randomOp, random);
    }

    /** Placeholder with no implementation; does nothing. */
    protected void buildAggregation() {
    }

    /** Executes this thread's pending op, if any; delegates to flushQueue() without waiting on streams. */
    public void push() {
        flushQueue();
    }

    /** Executes this thread's pending op and waits for the streams; delegates to flushQueueBlocking(). */
    @Override // org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner
    public void commit() {
        flushQueueBlocking();
    }
}
