package org.nd4j.jita.allocator.tad;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/tad/BasicTADManager.class */
public class BasicTADManager implements TADManager {
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected AtomicAllocator allocator = AtomicAllocator.getInstance();
    private static final Logger logger = LoggerFactory.getLogger(BasicTADManager.class);

    /**
     * Builds TAD (tensor-along-dimension) shape info and per-TAD offsets for {@code iNDArray}
     * along the dimensions in {@code iArr}, by calling the native {@code tadOnlyShapeInfo} routine.
     *
     * <p>Side effect: {@code iArr} is sorted in place, so the caller observes the sorted order
     * afterwards. NOTE(review): callers may rely on this in-place sort (e.g. by reusing the same
     * array for a later native call); confirm before switching to a defensive copy.
     *
     * @param iNDArray the array to split into TADs
     * @param iArr the dimensions to take TADs along; {@code null}, or an array whose first element
     *     is {@link Integer#MAX_VALUE}, means "the whole array". Must be non-empty when non-null.
     * @return a pair of (TAD shape info buffer, TAD offsets buffer); for the whole-array case, the
     *     array's own shape info buffer paired with {@code null}
     */
    @Override
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray iNDArray, int[] iArr) {
        if (iArr == null || iArr[0] == Integer.MAX_VALUE) {
            // Whole-array request: no split, hence no offsets buffer. A plain null lets the diamond
            // infer Pair<DataBuffer, DataBuffer>; an (Object) cast would not compile here.
            return new Pair<>(iNDArray.shapeInfoDataBuffer(), null);
        }

        // Sorted in place (caller-visible); the native routine receives the sorted dimensions.
        Arrays.sort(iArr);

        int rank = iNDArray.rank();
        // Hoisted out of the loop: shape() may build a fresh array on each call.
        int[] shape = iNDArray.shape();

        // Number of elements in one TAD = product of the selected dimension sizes.
        int tadLength = 1;
        for (int dimension : iArr) {
            tadLength *= shape[dimension];
        }
        // NOTE(review): a zero-sized selected dimension makes tadLength 0 and this division throw.
        int numTads = iNDArray.length() / tadLength;

        // Shape info is sized from the source array's rank (2 * rank + 4 ints) rather than from the
        // number of selected dimensions; presumably a conservative upper bound - confirm.
        CudaIntDataBuffer tadShapeInfo = new CudaIntDataBuffer((rank * 2) + 4);
        CudaIntDataBuffer tadOffsets = new CudaIntDataBuffer(numTads);

        nativeOps.tadOnlyShapeInfo(
                AddressRetriever.retrieveHostAddress(iNDArray.shapeInfoDataBuffer()),
                allocator.getHostPointer(allocator.getConstantBuffer(iArr)).address(),
                iArr.length,
                AddressRetriever.retrieveHostAddress(tadShapeInfo),
                AddressRetriever.retrieveHostAddress(tadOffsets));

        // The native call filled both buffers through their host addresses; record a host write.
        allocator.getAllocationPoint(tadShapeInfo).tickHostWrite();
        allocator.getAllocationPoint(tadOffsets).tickHostWrite();
        return new Pair<>(tadShapeInfo, tadOffsets);
    }
}
