package com.komputation.cuda.loss;

import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.EntrywiseKt;
import com.komputation.cuda.kernels.launch.KernelLaunchConfiguration;
import com.komputation.cuda.kernels.launch.RowwiseKt;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: CudaLogisticLoss.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��:\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u0015\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0011\n\u0002\u0010\u0007\n��\n\u0002\u0010\u0002\n\u0002\b\u0007\u0018��2\u00020\u0001BK\b��\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\f\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005\u0012\u0006\u0010\b\u001a\u00020\u0003\u0012\u0006\u0010\t\u001a\u00020\u0003\u0012\u0006\u0010\n\u001a\u00020\u0003\u0012\u0006\u0010\u000b\u001a\u00020\u0003¢\u0006\u0002\u0010\fJ\b\u0010%\u001a\u00020&H\u0016J \u0010'\u001a\u00020(2\u0006\u0010)\u001a\u00020\u00142\u0006\u0010*\u001a\u00020\u00142\u0006\u0010+\u001a\u00020\u0003H\u0016J\u0010\u0010,\u001a\u00020(2\u0006\u0010\u001c\u001a\u00020\u0003H\u0016J \u0010-\u001a\u00020\u00142\u0006\u0010)\u001a\u00020\u00142\u0006\u0010*\u001a\u00020\u00142\u0006\u0010+\u001a\u00020\u0003H\u0016J\b\u0010.\u001a\u00020(H\u0016R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u000f\u001a\u0004\u0018\u00010\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u0007\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0013\u001a\u00020\u0014X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0015\u001a\u00020\u0014X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0016\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u0017\u001a\u0004\u0018\u00010\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0018\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u00
19\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001a\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u001b\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u001c\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u000b\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u001d\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u001f\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010 \u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010!\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\"\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010#\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010$\u001a\n \u001e*\u0004\u0018\u00010\u00140\u0014X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006/"}, d2 = {"Lcom/komputation/cuda/loss/CudaLogisticLoss;", "Lcom/komputation/cuda/loss/CudaLossFunction;", "numberSteps", "", "createForwardKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "createBackwardKernel", "numberMultiprocessors", "numberResidentWarps", "warpSize", "maximumNumberThreadsPerBlock", "(ILkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function0;IIII)V", "backwardBatchSize", "", "backwardKernel", "backwardNumberBlocksInYDimension", "backwardNumberIterations", "backwardNumberThreadsPerBlock", "deviceBackwardResult", "Ljcuda/Pointer;", "deviceForwardResult", "forwardBatchSize", "forwardKernel", "forwardNumberBlocksInYDimension", "forwardNumberIterations", "forwardNumberThreadsPerBlock", 
"forwardSharedMemoryBytes", "maximumBatchSize", "pointerToBackwardBatchSize", "kotlin.jvm.PlatformType", "pointerToBackwardNumberIterations", "pointerToBackwardResult", "pointerToDeviceForwardResult", "pointerToForwardBatchSize", "pointerToForwardNumberIterations", "pointerToNumberSteps", "accessAccumulation", "", "accumulate", "", "pointerToPredictions", "pointerToTargets", "batchSize", "acquire", "backward", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/loss/CudaLogisticLoss.class */
/**
 * Loss function that delegates its forward (accumulation) and backward passes to two
 * {@link Kernel} instances obtained from the factories given to the constructor.
 *
 * <p>Lifecycle as implemented here: {@link #acquire(int)} allocates the two result buffers,
 * derives the launch configurations and creates both kernels; {@link #accumulate} and
 * {@link #backward} launch them; {@link #release()} destroys the kernels and frees the buffers.
 * Calling {@code accumulate}, {@code backward} or {@code release} while a kernel reference is
 * still {@code null} ends in {@link Intrinsics#throwNpe()}.
 */
public final class CudaLogisticLoss implements CudaLossFunction {

    /** Placeholder stored in launch-related state until {@link #acquire(int)} overwrites it. */
    private static final int UNSET = -1;

    // ---- Configuration captured by the constructor -------------------------------------------

    private final int numberSteps;
    private final Function0<Kernel> createForwardKernel;
    private final Function0<Kernel> createBackwardKernel;
    private final int numberMultiprocessors;
    private final int numberResidentWarps;
    private final int warpSize;
    private final int maximumNumberThreadsPerBlock;

    // ---- State shared by both passes ---------------------------------------------------------

    /** Pointer to a one-element host array holding {@link #numberSteps}. */
    private final Pointer pointerToNumberSteps;
    /** Batch size passed to the most recent {@link #acquire(int)} call. */
    private int maximumBatchSize;

    // ---- Forward pass ------------------------------------------------------------------------

    private Kernel forwardKernel;
    private final Pointer deviceForwardResult;
    private final Pointer pointerToDeviceForwardResult;
    /** One-element array rewritten on every {@link #accumulate} call. */
    private final int[] forwardBatchSize;
    /** Wraps the {@link #forwardBatchSize} array instance itself, so later writes are visible through it. */
    private final Pointer pointerToForwardBatchSize;
    private int forwardNumberBlocksInYDimension;
    private int forwardNumberThreadsPerBlock;
    private final int[] forwardNumberIterations;
    private final Pointer pointerToForwardNumberIterations;
    private int forwardSharedMemoryBytes;

    // ---- Backward pass -----------------------------------------------------------------------

    private Kernel backwardKernel;
    private final Pointer deviceBackwardResult;
    private final Pointer pointerToBackwardResult;
    /** One-element array rewritten on every {@link #backward} call. */
    private final int[] backwardBatchSize;
    /** Wraps the {@link #backwardBatchSize} array instance itself, so later writes are visible through it. */
    private final Pointer pointerToBackwardBatchSize;
    private int backwardNumberBlocksInYDimension;
    private int backwardNumberThreadsPerBlock;
    private final int[] backwardNumberIterations;
    private final Pointer pointerToBackwardNumberIterations;

    /**
     * Stores the configuration and prepares the host-side kernel arguments. Nothing is allocated
     * and no kernel is created here; that happens in {@link #acquire(int)}.
     *
     * @param i          number of steps ({@code numberSteps})
     * @param function0  factory for the forward kernel; must not be {@code null}
     * @param function02 factory for the backward kernel; must not be {@code null}
     * @param i2         number of multiprocessors
     * @param i3         number of resident warps
     * @param i4         warp size
     * @param i5         maximum number of threads per block
     */
    public CudaLogisticLoss(int i, @NotNull Function0<Kernel> function0, @NotNull Function0<Kernel> function02, int i2, int i3, int i4, int i5) {
        Intrinsics.checkParameterIsNotNull(function0, "createForwardKernel");
        Intrinsics.checkParameterIsNotNull(function02, "createBackwardKernel");

        this.numberSteps = i;
        this.createForwardKernel = function0;
        this.createBackwardKernel = function02;
        this.numberMultiprocessors = i2;
        this.numberResidentWarps = i3;
        this.warpSize = i4;
        this.maximumNumberThreadsPerBlock = i5;

        this.pointerToNumberSteps = Pointer.to(new int[]{i});
        this.maximumBatchSize = UNSET;

        // Forward pass: empty result pointer plus argument slots filled in later.
        this.deviceForwardResult = new Pointer();
        this.pointerToDeviceForwardResult = Pointer.to(new NativePointerObject[]{this.deviceForwardResult});
        this.forwardBatchSize = new int[]{UNSET};
        this.pointerToForwardBatchSize = Pointer.to(this.forwardBatchSize);
        this.forwardNumberBlocksInYDimension = UNSET;
        this.forwardNumberThreadsPerBlock = UNSET;
        this.forwardNumberIterations = new int[]{UNSET};
        this.pointerToForwardNumberIterations = Pointer.to(this.forwardNumberIterations);
        this.forwardSharedMemoryBytes = UNSET;

        // Backward pass: same arrangement, without a shared-memory size.
        this.deviceBackwardResult = new Pointer();
        this.pointerToBackwardResult = Pointer.to(new NativePointerObject[]{this.deviceBackwardResult});
        this.backwardBatchSize = new int[]{UNSET};
        this.pointerToBackwardBatchSize = Pointer.to(this.backwardBatchSize);
        this.backwardNumberBlocksInYDimension = UNSET;
        this.backwardNumberThreadsPerBlock = UNSET;
        this.backwardNumberIterations = new int[]{UNSET};
        this.pointerToBackwardNumberIterations = Pointer.to(this.backwardNumberIterations);
    }

    /**
     * Allocates both result buffers ({@code i * numberSteps} floats each), computes the launch
     * configuration of each pass and creates both kernels.
     *
     * @param i maximum batch size; remembered and reused by {@link #accessAccumulation()} and
     *          {@link #backward}
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int i) {
        this.maximumBatchSize = i;
        final int numberResultEntries = i * this.numberSteps;

        // Forward pass: row-wise configuration for a single row of numberSteps entries.
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceForwardResult, numberResultEntries);
        final KernelLaunchConfiguration forwardConfiguration =
                RowwiseKt.computeRowwiseLaunchConfiguration(1, this.numberSteps, this.warpSize, this.maximumNumberThreadsPerBlock);
        this.forwardNumberBlocksInYDimension = forwardConfiguration.getNumberBlocks();
        this.forwardNumberThreadsPerBlock = forwardConfiguration.getNumberThreadsPerBlock();
        this.forwardNumberIterations[0] = forwardConfiguration.getNumberIterations();
        this.forwardKernel = this.createForwardKernel.invoke();

        // Shared memory is sized for ceil((numberSteps / numberIterations) / warpSize) floats.
        final int entriesPerIteration = this.numberSteps / forwardConfiguration.getNumberIterations();
        final int sharedFloatCount = (entriesPerIteration + this.warpSize - 1) / this.warpSize;
        this.forwardSharedMemoryBytes = (int) CudaFloatArrayKt.computeDeviceFloatArraySize(sharedFloatCount);

        // Backward pass: entry-wise configuration over numberSteps entries.
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceBackwardResult, numberResultEntries);
        final KernelLaunchConfiguration backwardConfiguration =
                EntrywiseKt.computeEntrywiseLaunchConfiguration(this.numberSteps, this.numberMultiprocessors, this.numberResidentWarps, this.warpSize, this.maximumNumberThreadsPerBlock);
        this.backwardNumberBlocksInYDimension = backwardConfiguration.getNumberBlocks();
        this.backwardNumberThreadsPerBlock = backwardConfiguration.getNumberThreadsPerBlock();
        this.backwardNumberIterations[0] = backwardConfiguration.getNumberIterations();
        this.backwardKernel = this.createBackwardKernel.invoke();
    }

    /**
     * Launches the forward kernel, which receives the forward result buffer as its last argument.
     *
     * @param pointer  pointer to the predictions; must not be {@code null}
     * @param pointer2 pointer to the targets; must not be {@code null}
     * @param i        batch size; written to the kernel's batch-size argument and also used as
     *                 the second launch argument
     */
    @Override // com.komputation.cuda.loss.CudaLossFunction
    public void accumulate(@NotNull Pointer pointer, @NotNull Pointer pointer2, int i) {
        Intrinsics.checkParameterIsNotNull(pointer, "pointerToPredictions");
        Intrinsics.checkParameterIsNotNull(pointer2, "pointerToTargets");

        this.forwardBatchSize[0] = i;
        final Pointer parameters = Pointer.to(new NativePointerObject[]{
                this.pointerToForwardBatchSize,
                this.pointerToNumberSteps,
                this.pointerToForwardNumberIterations,
                pointer,
                pointer2,
                this.pointerToDeviceForwardResult
        });

        final Kernel kernel = requireKernel(this.forwardKernel);
        Intrinsics.checkExpressionValueIsNotNull(parameters, "parameters");
        kernel.launch(parameters, i, this.forwardNumberBlocksInYDimension, this.forwardNumberThreadsPerBlock, this.forwardSharedMemoryBytes);
    }

    /**
     * Reads {@code maximumBatchSize * numberSteps} floats from the forward result buffer and
     * returns their sum.
     *
     * <p>NOTE(review): the read always spans the maximum batch size, whereas
     * {@link #accumulate} passes the current batch size to the launch. Whether entries beyond
     * the current batch hold defined values depends on the kernel, which is not visible here.
     * TODO confirm.
     *
     * @return sum of all entries read from the forward result buffer
     */
    @Override // com.komputation.cuda.loss.CudaLossFunction
    public float accessAccumulation() {
        final int numberResultEntries = this.maximumBatchSize * this.numberSteps;
        final float[] forwardResult = CudaFloatArrayKt.getFloatArray(this.deviceForwardResult, numberResultEntries);
        return ArraysKt.sum(forwardResult);
    }

    /**
     * Launches the backward kernel, which receives the backward result buffer as its last
     * argument, and returns that buffer.
     *
     * <p>NOTE(review): unlike {@link #accumulate}, the second launch argument is
     * {@code maximumBatchSize} rather than {@code i}; only the kernel's batch-size argument
     * carries {@code i}. Presumably intentional. TODO confirm against the kernel.
     *
     * @param pointer  pointer to the predictions; must not be {@code null}
     * @param pointer2 pointer to the targets; must not be {@code null}
     * @param i        batch size written to the kernel's batch-size argument
     * @return the backward result pointer owned by this instance
     */
    @Override // com.komputation.cuda.loss.CudaLossFunction
    @NotNull
    public Pointer backward(@NotNull Pointer pointer, @NotNull Pointer pointer2, int i) {
        Intrinsics.checkParameterIsNotNull(pointer, "pointerToPredictions");
        Intrinsics.checkParameterIsNotNull(pointer2, "pointerToTargets");

        this.backwardBatchSize[0] = i;
        final Pointer parameters = Pointer.to(new NativePointerObject[]{
                this.pointerToBackwardBatchSize,
                this.pointerToNumberSteps,
                this.pointerToBackwardNumberIterations,
                pointer,
                pointer2,
                this.pointerToBackwardResult
        });

        final Kernel kernel = requireKernel(this.backwardKernel);
        Intrinsics.checkExpressionValueIsNotNull(parameters, "parameters");
        kernel.launch(parameters, this.maximumBatchSize, this.backwardNumberBlocksInYDimension, this.backwardNumberThreadsPerBlock, 0);
        return this.deviceBackwardResult;
    }

    /**
     * Destroys both kernels and frees both result buffers, forward pass first. The kernel fields
     * and pointers are not reset afterwards.
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        requireKernel(this.forwardKernel).destroy();
        JCuda.cudaFree(this.deviceForwardResult);

        requireKernel(this.backwardKernel).destroy();
        JCuda.cudaFree(this.deviceBackwardResult);
    }

    /**
     * Returns {@code kernel} unchanged, routing a {@code null} reference through
     * {@link Intrinsics#throwNpe()}.
     */
    private static Kernel requireKernel(Kernel kernel) {
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        return kernel;
    }
}
