package org.nd4j.jita.allocator.tad;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.apache.commons.math3.util.Pair;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/tad/DeviceTADManager.class */
/**
 * {@link BasicTADManager} that caches the buffer pair returned by
 * {@link BasicTADManager#getTADOnlyShapeInfo(INDArray, int[])} separately for each device id
 * reported by {@link AtomicAllocator}. Freshly computed buffers are handed to
 * {@link AtomicAllocator#moveToConstant(DataBuffer)} before being cached.
 */
public class DeviceTADManager extends BasicTADManager {
    /** Cache keyed by device id, then by TAD descriptor. */
    protected Map<Integer, Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> tadCache =
            new ConcurrentHashMap<>();

    /** Serializes creation of the per-device inner maps so each device gets exactly one map. */
    private final Semaphore lock = new Semaphore(1);

    private static final Logger logger = LoggerFactory.getLogger(DeviceTADManager.class);

    /**
     * Returns the cached buffer pair for {@code array}/{@code dimension} on the current device,
     * computing it via the superclass on a cache miss.
     *
     * <p>On a miss the first buffer is passed to {@code moveToConstant} unless it is the very same
     * object as {@code array.shapeInfoDataBuffer()} (reference comparison), and the second buffer is
     * passed to {@code moveToConstant} when it is non-null.
     *
     * <p>Two threads missing on the same descriptor concurrently may both compute a pair; the last
     * {@code put} stays in the cache and each caller receives a pair it computed itself.
     *
     * @param array     array whose TAD info is requested
     * @param dimension dimension argument forwarded to the superclass
     * @return the cached or freshly computed buffer pair
     * @throws RuntimeException if interrupted while creating the per-device cache
     */
    @Override
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        TadDescriptor descriptor = new TadDescriptor(array, dimension);

        Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCache = deviceCacheFor(deviceId);

        Pair<DataBuffer, DataBuffer> cached = deviceCache.get(descriptor);
        if (cached != null) {
            return cached;
        }

        Pair<DataBuffer, DataBuffer> computed = super.getTADOnlyShapeInfo(array, dimension);

        // Identity check on purpose: the array's own shape buffer is not moved again.
        if (computed.getFirst() != array.shapeInfoDataBuffer()) {
            AtomicAllocator.getInstance().moveToConstant(computed.getFirst());
        }
        if (computed.getSecond() != null) {
            AtomicAllocator.getInstance().moveToConstant(computed.getSecond());
        }

        deviceCache.put(descriptor, computed);
        return computed;
    }

    /**
     * Returns the inner cache for {@code deviceId}, creating it under {@link #lock} if absent.
     * The map is re-checked after acquiring the lock so that only one inner map is ever installed
     * per device.
     */
    private Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCacheFor(Integer deviceId) {
        Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCache = tadCache.get(deviceId);
        if (deviceCache != null) {
            return deviceCache;
        }

        // Acquire outside the try: if acquire() fails we must NOT release a permit we never took,
        // otherwise the semaphore would gain an extra permit and stop excluding other threads.
        try {
            lock.acquire();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        try {
            deviceCache = tadCache.get(deviceId);
            if (deviceCache == null) {
                deviceCache = new ConcurrentHashMap<>();
                tadCache.put(deviceId, deviceCache);
            }
            return deviceCache;
        } finally {
            lock.release();
        }
    }
}
