package org.nd4j.jita.constant;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/constant/CudaConstantHandler.class */
/**
 * {@link ConstantHandler} that uploads small buffers through
 * {@code NativeOps.memcpyConstantAsync} into a per-device region whose base address is
 * obtained from {@code NativeOps.getConstantSpace()}, and caches the resulting buffers per
 * device, keyed by {@link ArrayDescriptor}.
 *
 * <p>Buffers that are too large, or that no longer fit into the region, are copied with a
 * regular {@code NativeOps.memcpyAsync} to the device pointer they already own instead.
 *
 * <p>Per-device state is created lazily by {@link #ensureMaps(Integer)}. All per-device maps
 * are concurrent maps because they are read without holding {@link #lock}.
 */
public class CudaConstantHandler implements ConstantHandler {
    private static final Logger logger = LoggerFactory.getLogger(CudaConstantHandler.class);

    /**
     * Exclusive upper bound for bytes handed out per device: 49152 bytes (48 KiB).
     * NOTE(review): presumably the size of the native constant region - confirm against the
     * native side.
     */
    private static final long CONSTANT_SPACE_BYTES = 49152L;

    /** Buffers requiring more than this many bytes are never placed into the constant region. */
    private static final long MAX_CONSTANT_BUFFER_BYTES = 272L;

    /**
     * Flags argument passed to the native copy calls.
     * NOTE(review): 1 presumably selects a host-to-device copy (cudaMemcpyHostToDevice == 1) -
     * confirm against NativeOps.
     */
    private static final int MEMCPY_FLAGS = 1;

    /** Sentinel returned by {@link #reserveConstantSpace(AtomicLong, long)} when nothing fits. */
    private static final long NO_SPACE = -1L;

    /** Next free offset inside the constant region, per device id. */
    protected Map<Integer, AtomicLong> constantOffsets = new ConcurrentHashMap<>();

    /** One semaphore per device id; created in {@link #ensureMaps(Integer)}, not used in this class. */
    protected Map<Integer, Semaphore> deviceLocks = new ConcurrentHashMap<>();

    /**
     * Cached buffers per device id. Presence of a device id in this map is also the
     * "per-device state is fully initialised" marker used by {@link #ensureMaps(Integer)}.
     */
    protected Map<Integer, Map<ArrayDescriptor, DataBuffer>> buffersCache = new ConcurrentHashMap<>();

    /** Base address returned by {@code NativeOps.getConstantSpace()}, per device id. */
    protected Map<Integer, Long> deviceAddresses = new ConcurrentHashMap<>();

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    /** Guards first-time creation of the per-device entries. */
    protected Semaphore lock = new Semaphore(1);

    /**
     * Copies the host contents of {@code dataBuffer} to the device and marks its allocation
     * point as constant.
     *
     * <p>If the buffer needs at most {@value #MAX_CONSTANT_BUFFER_BYTES} bytes and fits
     * completely below {@value #CONSTANT_SPACE_BYTES} bytes of the current device's region,
     * space is reserved there, the data is copied with {@code memcpyConstantAsync} and the
     * allocation point's device pointer is replaced by the new address. Otherwise the data
     * is copied with {@code memcpyAsync} to the buffer's existing device pointer.
     *
     * @param dataBuffer buffer whose allocation point is looked up via {@link AtomicAllocator}
     * @return the absolute address assigned inside the constant region, or {@code 0} when the
     *         regular device copy was used
     */
    @Override
    public long moveToConstantSpace(DataBuffer dataBuffer) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);

        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        long requiredMemory = AllocationUtils.getRequiredMemory(point.getShape());
        AtomicLong deviceOffset = this.constantOffsets.get(deviceId);
        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

        long offset = requiredMemory > MAX_CONSTANT_BUFFER_BYTES
                ? NO_SPACE
                : reserveConstantSpace(deviceOffset, requiredMemory);

        if (offset == NO_SPACE) {
            // Too large or region exhausted: copy to the device pointer the buffer already has.
            this.nativeOps.memcpyAsync(
                    point.getPointers().getDevicePointer().address(),
                    point.getPointers().getHostPointer().address(),
                    requiredMemory,
                    MEMCPY_FLAGS,
                    context.getOldStream().getNativePointer());
            markConstant(point);
            return 0L;
        }

        this.nativeOps.memcpyConstantAsync(
                offset,
                point.getPointers().getHostPointer().address(),
                requiredMemory,
                MEMCPY_FLAGS,
                context.getOldStream().getNativePointer());

        // The native call takes a region-relative offset; the pointer stored on the
        // allocation point is the absolute address (region base + offset).
        long constantAddress = this.deviceAddresses.get(deviceId).longValue() + offset;
        point.getPointers().setDevicePointer(new CudaPointer(constantAddress));
        markConstant(point);
        return constantAddress;
    }

    /**
     * Atomically reserves {@code size} bytes if, and only if, the whole range stays below
     * {@link #CONSTANT_SPACE_BYTES}. A rejected request does not advance the offset, so a
     * smaller later request can still succeed.
     *
     * @return the start offset of the reserved range, or {@link #NO_SPACE} if it does not fit
     */
    private static long reserveConstantSpace(AtomicLong deviceOffset, long size) {
        while (true) {
            long start = deviceOffset.get();
            if (start >= CONSTANT_SPACE_BYTES || start + size > CONSTANT_SPACE_BYTES) {
                return NO_SPACE;
            }
            if (deviceOffset.compareAndSet(start, start + size)) {
                return start;
            }
            // Another thread reserved space concurrently; re-read and try again.
        }
    }

    /** Flags the allocation point as constant and records a device write and a host read. */
    private static void markConstant(AllocationPoint point) {
        point.setConstant(true);
        point.tickDeviceWrite();
        point.tickHostRead();
    }

    /**
     * Lazily creates the per-device entries of all maps for {@code num}.
     *
     * <p>The check is repeated after acquiring {@link #lock}, and the permit is released only
     * if it was actually acquired. {@link #buffersCache} is populated last so that a thread
     * taking the lock-free fast path never observes partially initialised state.
     *
     * @throws RuntimeException if the thread is interrupted while waiting for the lock, or
     *                          if creating the per-device state fails
     */
    private void ensureMaps(Integer num) {
        if (this.buffersCache.containsKey(num)) {
            return;
        }
        try {
            this.lock.acquire();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        try {
            if (!this.buffersCache.containsKey(num)) {
                this.constantOffsets.put(num, new AtomicLong(0L));
                this.deviceLocks.put(num, new Semaphore(1));
                this.deviceAddresses.put(num, Long.valueOf(this.nativeOps.getConstantSpace()));
                // Published last: its presence is the "initialised" marker checked above.
                this.buffersCache.put(num, new ConcurrentHashMap<ArrayDescriptor, DataBuffer>());
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            this.lock.release();
        }
    }

    /**
     * Returns the cached buffer for {@code iArr} on the current device, creating it with
     * {@code Nd4j.createBuffer} and passing it to {@link #moveToConstantSpace(DataBuffer)}
     * on a cache miss.
     */
    @Override
    public DataBuffer getConstantBuffer(int[] iArr) {
        ArrayDescriptor descriptor = new ArrayDescriptor(iArr);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = this.buffersCache.get(deviceId).get(descriptor);
        return cached != null ? cached : cacheAsConstant(deviceId, descriptor, Nd4j.createBuffer(iArr));
    }

    /**
     * Returns the cached buffer for {@code fArr} on the current device, creating it with
     * {@code Nd4j.createBuffer} and passing it to {@link #moveToConstantSpace(DataBuffer)}
     * on a cache miss.
     */
    @Override
    public DataBuffer getConstantBuffer(float[] fArr) {
        ArrayDescriptor descriptor = new ArrayDescriptor(fArr);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = this.buffersCache.get(deviceId).get(descriptor);
        return cached != null ? cached : cacheAsConstant(deviceId, descriptor, Nd4j.createBuffer(fArr));
    }

    /**
     * Returns the cached buffer for {@code dArr} on the current device, creating it with
     * {@code Nd4j.createBuffer} and passing it to {@link #moveToConstantSpace(DataBuffer)}
     * on a cache miss.
     */
    @Override
    public DataBuffer getConstantBuffer(double[] dArr) {
        ArrayDescriptor descriptor = new ArrayDescriptor(dArr);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = this.buffersCache.get(deviceId).get(descriptor);
        return cached != null ? cached : cacheAsConstant(deviceId, descriptor, Nd4j.createBuffer(dArr));
    }

    /** Moves a freshly created buffer to the device and stores it in the device's cache. */
    private DataBuffer cacheAsConstant(Integer deviceId, ArrayDescriptor descriptor, DataBuffer buffer) {
        moveToConstantSpace(buffer);
        this.buffersCache.get(deviceId).put(descriptor, buffer);
        return buffer;
    }
}
