package com.komputation.cuda.layers.forward.activation;

import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.EntrywiseKt;
import com.komputation.layers.Resourceful;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: BaseCudaEntrywiseActivationLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\b\n\u0002\b\u0006\n\u0002\u0010\u0015\n��\n\u0002\u0018\u0002\n\u0002\b\u001b\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0002\b\u0002\b&\u0018��2\u00020\u00012\u00020\u0002BO\b��\u0012\n\b\u0002\u0010\u0003\u001a\u0004\u0018\u00010\u0004\u0012\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006\u0012\f\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\n\u0012\u0006\u0010\f\u001a\u00020\n\u0012\u0006\u0010\r\u001a\u00020\n¢\u0006\u0002\u0010\u000eJ\u0010\u0010.\u001a\u00020/2\u0006\u00100\u001a\u00020\nH\u0016J\u0018\u00101\u001a\u00020\u00132\u0006\u0010\u0010\u001a\u00020\n2\u0006\u00102\u001a\u00020\u0013H\u0016J \u00103\u001a\u00020\u00132\u0006\u0010\u0010\u001a\u00020\n2\u0006\u00104\u001a\u00020\u00132\u0006\u00105\u001a\u000206H\u0016J\b\u00107\u001a\u00020/H\u0016R\u0010\u0010\u000f\u001a\u0004\u0018\u00010\u0007X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\b\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00070\u0006X\u0082\u0004¢\u0006\u0002\n��R\u0011\u0010\u0012\u001a\u00020\u0013¢\u0006\b\n��\u001a\u0004\b\u0014\u0010\u0015R\u0011\u0010\u0016\u001a\u00020\u0013¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0015R\u0010\u0010\u0018\u001a\u0004\u0018\u00010\u0007X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\u0019\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001a\u0010\u001bR\u000e\u0010\r\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001c\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001d\u0010\u001bR\u000e\u0010\u001e\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u001
0\u001f\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u000b\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010 \u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010!\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\"\u0010\u001bR\u000e\u0010#\u001a\u00020\u0011X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010$\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b%\u0010\u001bR\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010&\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u0016\u0010'\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010)\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010*\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010+\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010,\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010-\u001a\n (*\u0004\u0018\u00010\u00130\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��¨\u00068"}, d2 = {"Lcom/komputation/cuda/layers/forward/activation/BaseCudaEntrywiseActivationLayer;", "Lcom/komputation/cuda/layers/forward/activation/BaseCudaActivationLayer;", "Lcom/komputation/layers/Resourceful;", "name", "", "createForwardKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "createBackwardKernel", "numberRows", "", "numberColumns", "warpSize", "maximumNumberThreadsPerBlock", "(Ljava/lang/String;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function0;IIII)V", "backwardKernel", "batchSize", "", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceForwardResult", "getDeviceForwardResult", "forwardKernel", "maximumInputColumns", "getMaximumInputColumns", "()I", "maximumOutputColumns", "getMaximumOutputColumns", "numberBlocksInXDimension", "numberBlocksInYDimension", "numberEntries", 
"numberInputRows", "getNumberInputRows", "numberIterations", "numberOutputRows", "getNumberOutputRows", "numberThreadsPerBlock", "pointerToBatchSize", "kotlin.jvm.PlatformType", "pointerToDeviceBackwardResult", "pointerToDeviceForwardResult", "pointerToNumberEntriesPerInstance", "pointerToNumberIterations", "pointerToNumberRows", "acquire", "", "maximumBatchSize", "backward", "chain", "forward", "deviceInput", "isTraining", "", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/activation/BaseCudaEntrywiseActivationLayer.class */
/**
 * Base class for CUDA activation layers whose forward and backward passes are each a single
 * kernel launch over inputs of {@code numberRows x numberColumns} entries per instance.
 *
 * <p>Lifecycle: {@link #acquire(int)} allocates both device result buffers, creates the forward
 * and backward kernels and computes the launch configuration. {@link #forward} and
 * {@link #backward} may only be used between {@code acquire} and {@link #release()}.
 * {@code release} is a no-op when nothing is currently acquired, so calling it twice is safe.
 *
 * <p>{@code forward} always returns the same device buffer ({@link #getDeviceForwardResult()}),
 * and {@code backward} always returns {@link #getDeviceBackwardResult()}; both are freed by
 * {@code release}. {@code backward} hands the forward result buffer to the backward kernel, so it
 * presumably expects a preceding {@code forward} call on the same batch.
 *
 * <p>NOTE(review): the batch size is passed to the kernels through a shared mutable array, so
 * concurrent use of one instance looks unsafe.
 */
public abstract class BaseCudaEntrywiseActivationLayer extends BaseCudaActivationLayer implements Resourceful {

    /** Sentinel for launch-configuration values while the layer is not acquired. */
    private static final int UNSET = -1;

    // Launch configuration, computed in acquire() and reset in release().
    // The X dimension equals the maximum batch size passed to acquire().
    private int numberBlocksInXDimension;
    private int numberBlocksInYDimension;
    private int numberThreadsPerBlock;
    // Single-element arrays: JCuda's Pointer.to(int[]) wraps the Java array, so updating element 0
    // changes the value seen by subsequent kernel launches without rebuilding the pointer.
    private final int[] numberIterations;
    private final Pointer pointerToNumberIterations;

    @Nullable
    private Kernel forwardKernel;
    private final int numberOutputRows;
    private final int maximumOutputColumns;

    @NotNull
    private final Pointer deviceForwardResult;
    private final Pointer pointerToDeviceForwardResult;

    @Nullable
    private Kernel backwardKernel;

    @NotNull
    private final Pointer deviceBackwardResult;
    private final int numberInputRows;
    private final int maximumInputColumns;
    private final Pointer pointerToDeviceBackwardResult;

    // Number of entries (rows * columns) of a single instance.
    private final int numberEntries;
    private final Pointer pointerToNumberEntriesPerInstance;
    private final Pointer pointerToNumberRows;
    private final int[] batchSize;
    private final Pointer pointerToBatchSize;

    private final Function0<Kernel> createForwardKernel;
    private final Function0<Kernel> createBackwardKernel;
    private final int numberRows;
    private final int numberColumns;
    private final int warpSize;
    private final int maximumNumberThreadsPerBlock;

    @Override
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    @Override
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    /** Returns the device buffer that {@link #forward} writes to and returns. */
    @Override
    @NotNull
    public final Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    /** Returns the device buffer that {@link #backward} writes to and returns. */
    @Override
    @NotNull
    public final Pointer getDeviceBackwardResult() {
        return this.deviceBackwardResult;
    }

    @Override
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    @Override
    public int getMaximumInputColumns() {
        return this.maximumInputColumns;
    }

    /**
     * Allocates the forward and backward result buffers, creates both kernels and computes the
     * launch configuration.
     *
     * <p>NOTE(review): calling this twice without an intervening {@link #release()} overwrites the
     * previously allocated buffers and kernels without freeing them.
     *
     * @param maximumBatchSize the largest batch size that {@link #forward}/{@link #backward} will
     *     accept; it sizes both result buffers ({@code maximumBatchSize * rows * columns} entries)
     *     and is used as the X dimension of the launch grid
     */
    @Override
    public void acquire(int maximumBatchSize) {
        final int resultSize = maximumBatchSize * this.numberEntries;

        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceForwardResult, resultSize);
        this.forwardKernel = this.createForwardKernel.invoke();

        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceBackwardResult, resultSize);
        this.backwardKernel = this.createBackwardKernel.invoke();

        this.numberBlocksInXDimension = maximumBatchSize;
        this.numberBlocksInYDimension = this.numberColumns;

        // first = iteration count handed to the kernels, second = threads per block.
        Pair<Integer, Integer> threadConfiguration =
            EntrywiseKt.computeNumberOfThreadsForRows(this.numberRows, this.warpSize, this.maximumNumberThreadsPerBlock);
        this.numberIterations[0] = threadConfiguration.getFirst();
        this.numberThreadsPerBlock = threadConfiguration.getSecond();
    }

    /**
     * Launches the forward kernel on {@code deviceInput}.
     *
     * <p>Kernel parameters, in order: batch size, number of rows, entries per instance, number of
     * iterations, input pointer, forward result pointer.
     *
     * @param batchSize number of instances in {@code deviceInput}; must not exceed the maximum
     *     batch size passed to {@link #acquire(int)}
     * @param deviceInput device pointer to the input
     * @param isTraining unused by this layer
     * @return {@link #getDeviceForwardResult()}
     * @throws IllegalStateException if the layer is not currently acquired
     * @throws IllegalArgumentException if {@code batchSize} exceeds the acquired maximum
     */
    @Override
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        Kernel kernel = requireAcquired(this.forwardKernel);
        updateBatchSize(batchSize);

        Pointer forwardParameters = Pointer.to(
            this.pointerToBatchSize,
            this.pointerToNumberRows,
            this.pointerToNumberEntriesPerInstance,
            this.pointerToNumberIterations,
            Pointer.to(deviceInput),
            this.pointerToDeviceForwardResult);

        kernel.launch(forwardParameters, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        return this.deviceForwardResult;
    }

    /**
     * Launches the backward kernel on {@code chain}.
     *
     * <p>Kernel parameters, in order: batch size, number of rows, entries per instance, number of
     * iterations, forward result pointer, chain pointer, backward result pointer.
     *
     * @param batchSize number of instances in {@code chain}; must not exceed the maximum batch
     *     size passed to {@link #acquire(int)}. (Previously this argument was ignored and the batch
     *     size of the last {@code forward} call was reused.)
     * @param chain device pointer to the incoming chain
     * @return {@link #getDeviceBackwardResult()}
     * @throws IllegalStateException if the layer is not currently acquired
     * @throws IllegalArgumentException if {@code batchSize} exceeds the acquired maximum
     */
    @Override
    @NotNull
    public Pointer backward(int batchSize, @NotNull Pointer chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");
        Kernel kernel = requireAcquired(this.backwardKernel);
        updateBatchSize(batchSize);

        Pointer backwardParameters = Pointer.to(
            this.pointerToBatchSize,
            this.pointerToNumberRows,
            this.pointerToNumberEntriesPerInstance,
            this.pointerToNumberIterations,
            this.pointerToDeviceForwardResult,
            Pointer.to(chain),
            this.pointerToDeviceBackwardResult);

        kernel.launch(backwardParameters, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        return this.deviceBackwardResult;
    }

    /**
     * Destroys both kernels, frees both device result buffers and resets the launch configuration.
     * Does nothing if the layer is not currently acquired, so repeated calls are safe.
     */
    @Override
    public void release() {
        if (this.forwardKernel == null && this.backwardKernel == null) {
            // Never acquired, or already released: freeing again would double-free device memory.
            return;
        }

        if (this.forwardKernel != null) {
            this.forwardKernel.destroy();
            this.forwardKernel = null;
        }
        if (this.backwardKernel != null) {
            this.backwardKernel.destroy();
            this.backwardKernel = null;
        }

        JCuda.cudaFree(this.deviceForwardResult);
        JCuda.cudaFree(this.deviceBackwardResult);

        this.numberBlocksInXDimension = UNSET;
        this.numberBlocksInYDimension = UNSET;
        this.numberThreadsPerBlock = UNSET;
        this.numberIterations[0] = UNSET;
        this.batchSize[0] = UNSET;
    }

    /** Returns {@code kernel}, failing with a descriptive error when the layer is not acquired. */
    @NotNull
    private static Kernel requireAcquired(@Nullable Kernel kernel) {
        if (kernel == null) {
            throw new IllegalStateException("The layer is not acquired. Call acquire() before forward() or backward().");
        }
        return kernel;
    }

    /** Validates {@code size} against the acquired maximum and publishes it to the kernels. */
    private void updateBatchSize(int size) {
        if (size > this.numberBlocksInXDimension) {
            throw new IllegalArgumentException(
                "The batch size " + size + " exceeds the maximum batch size " + this.numberBlocksInXDimension + " passed to acquire().");
        }
        this.batchSize[0] = size;
    }

    /**
     * @param name optional layer name, forwarded to the superclass
     * @param createForwardKernel factory invoked by {@link #acquire(int)} to create the forward kernel
     * @param createBackwardKernel factory invoked by {@link #acquire(int)} to create the backward kernel
     * @param numberRows number of rows per instance
     * @param numberColumns number of columns per instance (also the Y dimension of the launch grid)
     * @param warpSize device warp size, used to compute the threads per block
     * @param maximumNumberThreadsPerBlock device limit on threads per block
     */
    public BaseCudaEntrywiseActivationLayer(@Nullable String name, @NotNull Function0<Kernel> createForwardKernel, @NotNull Function0<Kernel> createBackwardKernel, int numberRows, int numberColumns, int warpSize, int maximumNumberThreadsPerBlock) {
        super(name);
        Intrinsics.checkParameterIsNotNull(createForwardKernel, "createForwardKernel");
        Intrinsics.checkParameterIsNotNull(createBackwardKernel, "createBackwardKernel");

        this.createForwardKernel = createForwardKernel;
        this.createBackwardKernel = createBackwardKernel;
        this.numberRows = numberRows;
        this.numberColumns = numberColumns;
        this.warpSize = warpSize;
        this.maximumNumberThreadsPerBlock = maximumNumberThreadsPerBlock;

        this.numberBlocksInXDimension = UNSET;
        this.numberBlocksInYDimension = UNSET;
        this.numberThreadsPerBlock = UNSET;
        this.numberIterations = new int[]{UNSET};
        this.pointerToNumberIterations = Pointer.to(this.numberIterations);

        // Entrywise activation: output and input have the same shape.
        this.numberOutputRows = numberRows;
        this.maximumOutputColumns = numberColumns;
        this.numberInputRows = numberRows;
        this.maximumInputColumns = numberColumns;

        this.deviceForwardResult = new Pointer();
        this.pointerToDeviceForwardResult = Pointer.to(this.deviceForwardResult);
        this.deviceBackwardResult = new Pointer();
        this.pointerToDeviceBackwardResult = Pointer.to(this.deviceBackwardResult);

        this.numberEntries = numberRows * numberColumns;
        this.pointerToNumberEntriesPerInstance = Pointer.to(new int[]{this.numberEntries});
        this.pointerToNumberRows = Pointer.to(new int[]{numberRows});
        this.batchSize = new int[]{UNSET};
        this.pointerToBatchSize = Pointer.to(this.batchSize);
    }

    /**
     * Kotlin default-argument bridge: bit 0 of {@code defaultMask} set means {@code name} was
     * omitted and defaults to {@code null}. Kept for binary compatibility with compiled Kotlin callers.
     */
    public BaseCudaEntrywiseActivationLayer(String name, Function0 createForwardKernel, Function0 createBackwardKernel, int numberRows, int numberColumns, int warpSize, int maximumNumberThreadsPerBlock, int defaultMask, DefaultConstructorMarker marker) {
        this((defaultMask & 1) != 0 ? null : name, createForwardKernel, createBackwardKernel, numberRows, numberColumns, warpSize, maximumNumberThreadsPerBlock);
    }
}
