package com.komputation.cuda.memory;

import com.komputation.cuda.CudaFloatArrayKt;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: TargetMemory.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��:\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0010\u0015\n��\n\u0002\u0010\u0011\n\u0002\u0010\u0014\n\u0002\b\u0002\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0006\u0010\t\u001a\u00020\nJ9\u0010\u000b\u001a\n \f*\u0004\u0018\u00010\u00070\u00072\u0006\u0010\r\u001a\u00020\u00032\u0006\u0010\u000e\u001a\u00020\u00032\u0006\u0010\u000f\u001a\u00020\u00102\f\u0010\u0011\u001a\b\u0012\u0004\u0012\u00020\u00130\u0012¢\u0006\u0002\u0010\u0014R*\u0010\u0005\u001a\u001e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u00070\u0006j\u000e\u0012\u0004\u0012\u00020\u0003\u0012\u0004\u0012\u00020\u0007`\bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0015"}, d2 = {"Lcom/komputation/cuda/memory/TargetMemory;", "", "targetSize", "", "(I)V", "memory", "Ljava/util/HashMap;", "Ljcuda/Pointer;", "Lkotlin/collections/HashMap;", "free", "", "get", "kotlin.jvm.PlatformType", "batchId", "batchSize", "batch", "", "targets", "", "", "(II[I[[F)Ljcuda/Pointer;", "komputation"})
/* loaded from: input_file:com/komputation/cuda/memory/TargetMemory.class */
public final class TargetMemory {
    /**
     * Caches the kernel-parameter pointers ({@code Pointer.to(deviceTargets)}) handed out by {@link #get}, keyed by batch id.
     * These are host-side wrappers and must NOT be passed to {@code cudaFree}.
     */
    private final HashMap<Integer, Pointer> memory = new HashMap<>();
    /** The pointers filled by {@code setFloatArray}, keyed by batch id; these are what {@link #free} releases. */
    private final HashMap<Integer, Pointer> deviceMemory = new HashMap<>();
    /** Number of floats per single target. */
    private final int targetSize;

    /**
     * Returns a kernel-parameter pointer to the targets of the given batch.
     * On the first request for a batch id, the targets are packed into one contiguous array and uploaded.
     * Later requests with the same id return the cached pointer.
     *
     * <p>NOTE(review): the cache is keyed by {@code batchId} only; callers presumably always pass the
     * same {@code batch} contents for a given id — confirm.
     *
     * @param batchId   key under which the uploaded targets are cached
     * @param batchSize number of targets the packed array has room for; must be at least {@code batch.length}
     * @param batch     indices into {@code targets}, in the order they are packed
     * @param targets   all targets; each row is expected to hold at least {@code targetSize} floats
     * @return {@code Pointer.to(deviceTargets)} for the batch
     * @throws ArrayIndexOutOfBoundsException if {@code batch} references a missing target or exceeds {@code batchSize}
     */
    public final Pointer get(int batchId, int batchSize, @NotNull int[] batch, @NotNull float[][] targets) {
        Intrinsics.checkParameterIsNotNull(batch, "batch");
        Intrinsics.checkParameterIsNotNull(targets, "targets");

        Pointer cached = this.memory.get(batchId);
        if (cached != null) {
            return cached;
        }

        // Pack the selected targets back-to-back: the target at batch position p occupies
        // [p * targetSize, (p + 1) * targetSize).
        int batchTargetSize = batchSize * this.targetSize;
        float[] batchTargets = new float[batchTargetSize];
        for (int position = 0; position < batch.length; position++) {
            System.arraycopy(targets[batch[position]], 0, batchTargets, position * this.targetSize, this.targetSize);
        }

        Pointer deviceTargets = new Pointer();
        CudaFloatArrayKt.setFloatArray(batchTargets, batchTargetSize, deviceTargets);
        Pointer pointerToDeviceTargets = Pointer.to(deviceTargets);

        // Remember the pointer filled by setFloatArray separately: free() must release it, not the
        // Pointer.to(...) wrapper that is returned to callers.
        this.deviceMemory.put(batchId, deviceTargets);
        this.memory.put(batchId, pointerToDeviceTargets);
        return pointerToDeviceTargets;
    }

    /**
     * Releases every pointer obtained via {@code setFloatArray} with {@code cudaFree}, then empties both caches.
     * Clearing the caches ensures a later {@link #get} re-uploads the data instead of returning a freed pointer.
     */
    public final void free() {
        for (Pointer deviceTargets : this.deviceMemory.values()) {
            JCuda.cudaFree(deviceTargets);
        }
        this.deviceMemory.clear();
        this.memory.clear();
    }

    /**
     * @param targetSize number of floats per single target
     */
    public TargetMemory(int targetSize) {
        this.targetSize = targetSize;
    }
}
