package org.nd4j.jita.allocator.concurrency;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/concurrency/DeviceAllocationsTracker.class */
/**
 * Tracks, per device id, an allocated-memory counter and a reserved-space counter.
 *
 * <p>Counters are {@link AtomicLong}s stored in {@link ConcurrentHashMap}s and are created lazily
 * the first time a device id is seen. Each device also has a {@link ReentrantReadWriteLock}:
 * {@link #addToAllocation} runs under the read lock (several adds may proceed concurrently because
 * the counters are atomic), while {@link #subFromAllocation} and {@link #reserveAllocationIfPossible}
 * take the write lock and are therefore exclusive with adds on the same device.
 *
 * <p>NOTE(review): the unit of the tracked amounts is not visible here; presumably bytes — confirm
 * with callers.
 */
public class DeviceAllocationsTracker {
    private static final Logger log = LoggerFactory.getLogger(DeviceAllocationsTracker.class);

    private final Configuration configuration;
    private final Map<Integer, ReentrantReadWriteLock> deviceLocks = new ConcurrentHashMap<>();
    private final Map<Integer, AtomicLong> memoryTackled = new ConcurrentHashMap<>();
    private final Map<Integer, AtomicLong> reservedSpace = new ConcurrentHashMap<>();

    /**
     * Creates a tracker and pre-creates one lock per device reported by the affinity manager.
     *
     * @param configuration allocator configuration; must not be null
     * @throws NullPointerException if {@code configuration} is null
     */
    public DeviceAllocationsTracker(@NonNull Configuration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        this.configuration = configuration;
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int i = 0; i < numberOfDevices; i++) {
            this.deviceLocks.put(i, new ReentrantReadWriteLock());
        }
    }

    /**
     * Returns the lock guarding {@code deviceId}, creating it on first use.
     *
     * <p>Lazy creation means device ids not known at construction time do not NPE.
     */
    private ReentrantReadWriteLock lockFor(Integer deviceId) {
        return deviceLocks.computeIfAbsent(deviceId, k -> new ReentrantReadWriteLock());
    }

    /**
     * Ensures both the allocation and reserved-space counters exist for {@code deviceId}.
     *
     * <p>{@code computeIfAbsent} on a {@link ConcurrentHashMap} is atomic, so no extra locking is
     * required. The {@code threadId} argument is not used by this implementation.
     *
     * @param threadId id of the calling thread (unused here)
     * @param deviceId device whose counters should exist
     */
    protected void ensureThreadRegistered(Long threadId, Integer deviceId) {
        memoryTackled.computeIfAbsent(deviceId, k -> new AtomicLong(0L));
        reservedSpace.computeIfAbsent(deviceId, k -> new AtomicLong(0L));
    }

    /**
     * Adds {@code memorySize} to the device's allocation and subtracts it from reserved space.
     *
     * @param threadId id of the calling thread; must not be null
     * @param deviceId target device
     * @param memorySize amount to add
     * @return the device's allocation total after the addition
     * @throws NullPointerException if {@code threadId} is null
     */
    public long addToAllocation(@NonNull Long threadId, Integer deviceId, long memorySize) {
        if (threadId == null) {
            throw new NullPointerException("threadId");
        }
        ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.ReadLock lock = lockFor(deviceId).readLock();
        // Acquire before try: if lock() fails we must not call unlock() on a lock we do not hold.
        lock.lock();
        try {
            long total = memoryTackled.get(deviceId).addAndGet(memorySize);
            subFromReservedSpace(deviceId, memorySize);
            return total;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Subtracts {@code memorySize} from the device's allocation.
     *
     * @param threadId id of the calling thread
     * @param deviceId target device
     * @param memorySize amount to subtract
     * @return the device's allocation total after the subtraction
     */
    public long subFromAllocation(Long threadId, Integer deviceId, long memorySize) {
        ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.WriteLock lock = lockFor(deviceId).writeLock();
        lock.lock();
        try {
            return memoryTackled.get(deviceId).addAndGet(-memorySize);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Adds {@code memorySize} to the device's reserved space.
     *
     * <p>This implementation performs no capacity check and always returns {@code true}.
     *
     * @param threadId id of the calling thread
     * @param deviceId target device
     * @param memorySize amount to reserve
     * @return always {@code true}
     */
    public boolean reserveAllocationIfPossible(Long threadId, Integer deviceId, long memorySize) {
        ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.WriteLock lock = lockFor(deviceId).writeLock();
        lock.lock();
        try {
            addToReservedSpace(deviceId, memorySize);
            return true;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns the device's allocation total, registering the device's counters first.
     *
     * @param threadId id of the calling thread
     * @param deviceId target device
     * @return current allocation total for the device
     */
    public long getAllocatedSize(Long threadId, Integer deviceId) {
        ensureThreadRegistered(threadId, deviceId);
        ReentrantReadWriteLock.ReadLock lock = lockFor(deviceId).readLock();
        lock.lock();
        try {
            // Read locks are reentrant, so the nested acquisition in the overload is safe.
            return getAllocatedSize(deviceId);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns the device's allocation total, or 0 if the device has never been registered.
     *
     * @param deviceId target device
     * @return current allocation total, or 0 for an unknown device
     */
    public long getAllocatedSize(Integer deviceId) {
        AtomicLong counter = memoryTackled.get(deviceId);
        if (counter == null) {
            return 0L;
        }
        ReentrantReadWriteLock.ReadLock lock = lockFor(deviceId).readLock();
        lock.lock();
        try {
            return counter.get();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns the device's reserved space, or 0 if the device has never been registered.
     *
     * <p>This mirrors {@link #getAllocatedSize(Integer)}; an unknown device used to throw an NPE here.
     *
     * @param deviceId target device
     * @return current reserved space, or 0 for an unknown device
     */
    public long getReservedSpace(Integer deviceId) {
        AtomicLong reserved = reservedSpace.get(deviceId);
        return reserved == null ? 0L : reserved.get();
    }

    /** Adds {@code memorySize} to the device's reserved-space counter, registering it if needed. */
    protected void addToReservedSpace(Integer deviceId, long memorySize) {
        ensureThreadRegistered(Thread.currentThread().getId(), deviceId);
        reservedSpace.get(deviceId).addAndGet(memorySize);
    }

    /** Subtracts {@code memorySize} from the device's reserved-space counter, registering it if needed. */
    protected void subFromReservedSpace(Integer deviceId, long memorySize) {
        ensureThreadRegistered(Thread.currentThread().getId(), deviceId);
        reservedSpace.get(deviceId).addAndGet(-memorySize);
    }
}
