package com.komputation.cuda.optimization;

import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.layers.Resourceful;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: BaseCudaUpdateRule.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��,\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0010\b\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0002\n\u0002\b\f\b&\u0018��2\u00020\u00012\u00020\u0002B\u0005¢\u0006\u0002\u0010\u0003J\u0010\u0010\f\u001a\u00020\r2\u0006\u0010\u000e\u001a\u00020\u0006H\u0016J \u0010\u000f\u001a\u00020\r2\u0006\u0010\u0010\u001a\u00020\u00062\u0006\u0010\u0011\u001a\u00020\u00072\u0006\u0010\u0012\u001a\u00020\u0007H\u0016J0\u0010\u0013\u001a\u00020\u00062\u0006\u0010\u0014\u001a\u00020\u00062\u0006\u0010\u0015\u001a\u00020\u00072\u0006\u0010\u0016\u001a\u00020\u00072\u0006\u0010\u0011\u001a\u00020\u00072\u0006\u0010\u0012\u001a\u00020\u0007H&J\b\u0010\u0017\u001a\u00020\rH\u0016J0\u0010\u0018\u001a\u00020\r2\u0006\u0010\u0014\u001a\u00020\u00062\u0006\u0010\u0015\u001a\u00020\u00072\u0006\u0010\u0016\u001a\u00020\u00072\u0006\u0010\u0011\u001a\u00020\u00072\u0006\u0010\u0012\u001a\u00020\u0007H\u0016R*\u0010\u0004\u001a\u001e\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u00070\u0005j\u000e\u0012\u0004\u0012\u00020\u0006\u0012\u0004\u0012\u00020\u0007`\bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\n\u001a\n \u000b*\u0004\u0018\u00010\u00070\u0007X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u0019"}, d2 = {"Lcom/komputation/cuda/optimization/BaseCudaUpdateRule;", "Lcom/komputation/cuda/optimization/CudaUpdateRule;", "Lcom/komputation/layers/Resourceful;", "()V", "deviceCountMap", "Ljava/util/HashMap;", "", "Ljcuda/Pointer;", "Lkotlin/collections/HashMap;", "deviceZero", "pointerToZero", "kotlin.jvm.PlatformType", "acquire", "", "maximumBatchSize", "denseUpdate", "count", "pointerToParameters", "pointerToGradient", "launchKernel", "maximumParameters", "pointerToIndices", "pointerToCounts", "release", "sparseUpdate", "komputation"})
/* loaded from: input_file:com/komputation/cuda/optimization/BaseCudaUpdateRule.class */
public abstract class BaseCudaUpdateRule implements CudaUpdateRule, Resourceful {

    /**
     * Buffer holding the single int {@code 0}. It is filled by {@link #acquire(int)} and released
     * with {@code JCuda.cudaFree} in {@link #release()}. It serves as the index array for dense
     * updates.
     */
    private final Pointer deviceZero = new Pointer();

    /**
     * Host-side pointer to {@link #deviceZero}. It is built once; the {@code Pointer} object it
     * wraps is the same instance that {@link #acquire(int)} fills in.
     */
    private final Pointer pointerToZero = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.deviceZero});

    /**
     * Lazily populated cache of single-int buffers, keyed by the count each buffer holds.
     * Entries are freed and removed in {@link #release()}.
     */
    private final HashMap<Integer, Pointer> deviceCountMap = new HashMap<>();

    /**
     * Writes the constant {@code 0} into {@link #deviceZero} via {@code CudaIntArrayKt.setIntArray}.
     *
     * @param maximumBatchSize unused by this base implementation
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int maximumBatchSize) {
        CudaIntArrayKt.setIntArray(new int[]{0}, 1, this.deviceZero);
    }

    /**
     * Frees {@link #deviceZero} and every cached count buffer, then empties the cache.
     *
     * <p>Clearing the cache matters. Without it, a later {@code acquire()}/{@code denseUpdate()}
     * cycle would find the stale entries and hand already-freed memory to the kernel.
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        JCuda.cudaFree(this.deviceZero);
        for (Pointer deviceCount : this.deviceCountMap.values()) {
            JCuda.cudaFree(deviceCount);
        }
        this.deviceCountMap.clear();
    }

    /**
     * Applies a dense update as a one-entry sparse update. The index array is {@link #deviceZero}
     * and the count array is a cached buffer holding {@code count}.
     *
     * @param count               value stored in the count buffer passed to the kernel
     * @param pointerToParameters pointer to the parameters being updated
     * @param pointerToGradient   pointer to the gradient
     */
    @Override // com.komputation.cuda.optimization.CudaUpdateRule
    public void denseUpdate(int count, @NotNull Pointer pointerToParameters, @NotNull Pointer pointerToGradient) {
        Intrinsics.checkParameterIsNotNull(pointerToParameters, "pointerToParameters");
        Intrinsics.checkParameterIsNotNull(pointerToGradient, "pointerToGradient");

        Pointer deviceCount = this.deviceCountMap.get(count);
        if (deviceCount == null) {
            // First time this count is seen since the last release(): upload it once and cache it.
            deviceCount = new Pointer();
            CudaIntArrayKt.setIntArray(new int[]{count}, 1, deviceCount);
            this.deviceCountMap.put(count, deviceCount);
        }

        Pointer pointerToCount = Pointer.to(new NativePointerObject[]{(NativePointerObject) deviceCount});
        launchKernel(1, this.pointerToZero, pointerToCount, pointerToParameters, pointerToGradient);
    }

    /**
     * Forwards a sparse update to {@link #launchKernel} unchanged, after null-checking the pointers.
     *
     * @param maximumParameters   passed straight through to {@link #launchKernel}
     * @param pointerToIndices    pointer to the index array
     * @param pointerToCounts     pointer to the count array
     * @param pointerToParameters pointer to the parameters being updated
     * @param pointerToGradient   pointer to the gradient
     */
    @Override // com.komputation.cuda.optimization.CudaUpdateRule
    public void sparseUpdate(int maximumParameters, @NotNull Pointer pointerToIndices, @NotNull Pointer pointerToCounts, @NotNull Pointer pointerToParameters, @NotNull Pointer pointerToGradient) {
        Intrinsics.checkParameterIsNotNull(pointerToIndices, "pointerToIndices");
        Intrinsics.checkParameterIsNotNull(pointerToCounts, "pointerToCounts");
        Intrinsics.checkParameterIsNotNull(pointerToParameters, "pointerToParameters");
        Intrinsics.checkParameterIsNotNull(pointerToGradient, "pointerToGradient");
        launchKernel(maximumParameters, pointerToIndices, pointerToCounts, pointerToParameters, pointerToGradient);
    }

    /**
     * Launches the concrete update kernel. Subclasses implement the actual rule.
     *
     * <p>NOTE(review): the returned int is ignored by both callers in this class. It is presumably
     * a CUDA status code; confirm against the implementations.
     *
     * @param maximumParameters   first argument forwarded by the update methods
     * @param pointerToIndices    pointer to the index array
     * @param pointerToCounts     pointer to the count array
     * @param pointerToParameters pointer to the parameters being updated
     * @param pointerToGradient   pointer to the gradient
     * @return implementation-defined result
     */
    public abstract int launchKernel(int maximumParameters, @NotNull Pointer pointerToIndices, @NotNull Pointer pointerToCounts, @NotNull Pointer pointerToParameters, @NotNull Pointer pointerToGradient);
}
