package com.komputation.cuda.layers.forward.dropout;

import com.komputation.cpu.functions.DropoutKt;
import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.EntrywiseKt;
import com.komputation.cuda.layers.forward.activation.BaseCudaActivationLayer;
import com.komputation.layers.Resourceful;
import java.util.Random;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CudaDropoutLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��R\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\u0015\n��\n\u0002\u0018\u0002\n\u0002\b\"\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0002\b\u0002\u0018��2\u00020\u00012\u00020\u0002Bm\b��\u0012\n\b\u0002\u0010\u0003\u001a\u0004\u0018\u00010\u0004\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u0006\u0010\u0007\u001a\u00020\u0006\u0012\u0006\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\u000b\u0012\f\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000e0\r\u0012\f\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u000e0\r\u0012\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u000e0\r\u0012\u0006\u0010\u0011\u001a\u00020\u0006\u0012\u0006\u0010\u0012\u001a\u00020\u0006¢\u0006\u0002\u0010\u0013J\u0010\u0010:\u001a\u00020;2\u0006\u0010<\u001a\u00020\u0006H\u0016J\u0018\u0010=\u001a\u00020\u00182\u0006\u0010\u0015\u001a\u00020\u00062\u0006\u0010>\u001a\u00020\u0018H\u0016J 
\u0010?\u001a\u00020\u00182\u0006\u0010\u0015\u001a\u00020\u00062\u0006\u0010@\u001a\u00020\u00182\u0006\u0010A\u001a\u00020BH\u0016J\b\u0010C\u001a\u00020;H\u0016R\u0010\u0010\u0014\u001a\u0004\u0018\u00010\u000eX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0015\u001a\u00020\u0016X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\f\u001a\b\u0012\u0004\u0012\u00020\u000e0\rX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0017\u001a\u00020\u0018X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u001aR\u0014\u0010\u001b\u001a\u00020\u0018X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001aR\u000e\u0010\u001d\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001e\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001f\u001a\u00020\u0006X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b \u0010!R\u000e\u0010\u0012\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\"\u001a\u00020\u0006X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b#\u0010!R\u000e\u0010$\u001a\u00020\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010%\u001a\u00020\u0006X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010&\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010'\u001a\u00020\u0006X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b(\u0010!R\u000e\u0010)\u001a\u00020\u0016X\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010*\u001a\u00020\u0006X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b+\u0010!R\u000e\u0010,\u001a\u00020\u0006X\u0082\u000e¢\u0006\u0002\n��R\u0016\u0010-\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010/\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00100\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00101\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00102\u001a\n 
.*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00103\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00104\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00105\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00106\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00107\u001a\n .*\u0004\u0018\u00010\u00180\u0018X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0010\u00108\u001a\u0004\u0018\u00010\u000eX\u0082\u000e¢\u0006\u0002\n��R\u0010\u00109\u001a\u0004\u0018\u00010\u000eX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0011\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��¨\u0006D"}, d2 = {"Lcom/komputation/cuda/layers/forward/dropout/CudaDropoutLayer;", "Lcom/komputation/cuda/layers/forward/activation/BaseCudaActivationLayer;", "Lcom/komputation/layers/Resourceful;", "name", "", "numberRows", "", "numberColumns", "random", "Ljava/util/Random;", "keepProbability", "", "createTrainingKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "createRuntimeKernel", "createBackwardKernel", "warpSize", "maximumNumberThreadsPerBlock", "(Ljava/lang/String;IILjava/util/Random;FLkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function0;II)V", "backwardKernel", "batchSize", "", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceForwardResult", "getDeviceForwardResult", "deviceMasks", "deviceSeeds", "maximumInputColumns", "getMaximumInputColumns", "()I", "maximumOutputColumns", "getMaximumOutputColumns", "numberBlocksInXDimension", "numberBlocksInYDimension", "numberEntries", "numberInputRows", "getNumberInputRows", "numberIterations", "numberOutputRows", "getNumberOutputRows", "numberThreadsPerBlock", "pointerToBatchSize", "kotlin.jvm.PlatformType", "pointerToDeviceBackwardResults", 
"pointerToDeviceForwardResults", "pointerToDeviceMasks", "pointerToDeviceSeeds", "pointerToDropoutProbability", "pointerToKeepProbability", "pointerToNumberEntries", "pointerToNumberIterations", "pointerToNumberRows", "runtimeKernel", "trainingKernel", "acquire", "", "maximumBatchSize", "backward", "chain", "forward", "deviceInput", "isTraining", "", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/dropout/CudaDropoutLayer.class */
/**
 * Dropout layer whose forward and backward passes are executed by CUDA kernels.
 *
 * <p>Lifecycle: construct, {@link #acquire(int)}, any number of
 * {@link #forward(int, Pointer, boolean)} / {@link #backward(int, Pointer)} calls, then
 * {@link #release()}. The kernels and device buffers only exist between {@code acquire} and
 * {@code release}.
 *
 * <p>Kernel arguments are passed as pointers to small host arrays that are created once in the
 * constructor. Values that change later ({@code batchSize[0]}, {@code numberIterations[0]}) are
 * updated by writing into those arrays in place, so the arrays themselves must never be replaced.
 */
public final class CudaDropoutLayer extends BaseCudaActivationLayer implements Resourceful {
    /** Entries per instance: {@code numberRows * numberColumns}. */
    private final int numberEntries;
    private final Pointer pointerToNumberEntries;

    // Launch configuration. All hold -1 until acquire() has computed the real values.
    private int numberBlocksInXDimension;
    private int numberBlocksInYDimension;
    private int numberThreadsPerBlock;
    /** Single-element array aliased by {@link #pointerToNumberIterations}; only slot 0 is ever written. */
    private final int[] numberIterations;
    private final Pointer pointerToNumberIterations;

    private final Pointer pointerToKeepProbability;
    private final Pointer pointerToDropoutProbability;
    private final Pointer deviceSeeds;
    private final Pointer pointerToDeviceSeeds;
    private final Pointer deviceMasks;
    private final Pointer pointerToDeviceMasks;
    private Kernel trainingKernel;
    private Kernel runtimeKernel;

    @NotNull
    private final Pointer deviceForwardResult;
    private final int numberOutputRows;
    private final int maximumOutputColumns;
    private final Pointer pointerToDeviceForwardResults;
    private Kernel backwardKernel;

    @NotNull
    private final Pointer deviceBackwardResult;
    private final int numberInputRows;
    private final Pointer pointerToNumberRows;
    private final int maximumInputColumns;
    private final Pointer pointerToDeviceBackwardResults;
    /** Single-element array aliased by {@link #pointerToBatchSize}; written by forward(). */
    private final int[] batchSize;
    private final Pointer pointerToBatchSize;
    private final Random random;
    private final Function0<Kernel> createTrainingKernel;
    private final Function0<Kernel> createRuntimeKernel;
    private final Function0<Kernel> createBackwardKernel;
    private final int warpSize;
    private final int maximumNumberThreadsPerBlock;

    /** @return the pointer that the forward kernels write their result to */
    @Override // com.komputation.cuda.CudaForwardState
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    /** @return the number of output rows, equal to the row count given at construction */
    @Override // com.komputation.cuda.CudaForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    /** @return the maximum number of output columns, equal to the column count given at construction */
    @Override // com.komputation.cuda.CudaForwardState
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    /** @return the pointer that the backward kernel writes its result to */
    @Override // com.komputation.cuda.CudaBackwardState
    @NotNull
    public Pointer getDeviceBackwardResult() {
        return this.deviceBackwardResult;
    }

    /** @return the number of input rows, equal to the row count given at construction */
    @Override // com.komputation.cuda.CudaBackwardState
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    /** @return the maximum number of input columns, equal to the column count given at construction */
    @Override // com.komputation.cuda.CudaBackwardState
    public int getMaximumInputColumns() {
        return this.maximumInputColumns;
    }

    /**
     * Creates the three kernels, seeds the random state, allocates the device buffers and
     * computes the launch configuration.
     *
     * @param i the maximum batch size; every buffer is sized for {@code i * numberEntries} values
     * @throws ArithmeticException if {@code i * numberEntries} does not fit in an {@code int}
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int i) {
        // Overflow-checked and computed before anything is created: a wrapped product would
        // silently under-size the seed array and every device buffer allocated below.
        final int numberBatchEntries = Math.multiplyExact(i, this.numberEntries);

        this.trainingKernel = this.createTrainingKernel.invoke();
        this.runtimeKernel = this.createRuntimeKernel.invoke();

        // One seed per entry of the largest possible batch.
        final int[] seeds = new int[numberBatchEntries];
        DropoutKt.seed(this.random, seeds, numberBatchEntries);
        CudaIntArrayKt.setIntArray(seeds, numberBatchEntries, this.deviceSeeds);

        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceMasks, numberBatchEntries);
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceForwardResult, numberBatchEntries);

        this.backwardKernel = this.createBackwardKernel.invoke();
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceBackwardResult, numberBatchEntries);

        // Grid is (maximum batch size) x (maximum input columns). The pair returned by the helper
        // is stored as (iterations, threads per block), in that order.
        this.numberBlocksInXDimension = i;
        final Pair<Integer, Integer> launchConfiguration = EntrywiseKt.computeNumberOfThreadsForRows(this.numberInputRows, this.warpSize, this.maximumNumberThreadsPerBlock);
        final int iterations = launchConfiguration.component1();
        final int threadsPerBlock = launchConfiguration.component2();
        this.numberBlocksInYDimension = this.maximumInputColumns;
        this.numberIterations[0] = iterations;
        this.numberThreadsPerBlock = threadsPerBlock;
    }

    /**
     * Launches the training kernel or the runtime kernel on the given input.
     *
     * <p>The training kernel receives the dropout probability, the seeds and the masks; the
     * runtime kernel receives the keep probability instead and neither seeds nor masks.
     *
     * @param i       the batch size; stored in {@code batchSize[0]}, which {@link #backward} reuses
     * @param pointer the device input
     * @param z       {@code true} to launch the training kernel, {@code false} for the runtime kernel
     * @return the forward result pointer (always the same {@link Pointer} instance)
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer forward(int i, @NotNull Pointer pointer, boolean z) {
        Intrinsics.checkParameterIsNotNull(pointer, "deviceInput");
        this.batchSize[0] = i;
        final Pointer pointerToInput = Pointer.to(new NativePointerObject[]{pointer});
        if (z) {
            final Kernel training = this.trainingKernel;
            if (training == null) {
                Intrinsics.throwNpe();
            }
            final Pointer trainingParameters = Pointer.to(new NativePointerObject[]{
                this.pointerToBatchSize,
                this.pointerToNumberEntries,
                this.pointerToNumberRows,
                this.pointerToNumberIterations,
                this.pointerToDropoutProbability,
                pointerToInput,
                this.pointerToDeviceSeeds,
                this.pointerToDeviceMasks,
                this.pointerToDeviceForwardResults});
            Intrinsics.checkExpressionValueIsNotNull(trainingParameters, "Pointer.to(\n            …Results\n                )");
            // NOTE(review): the trailing 0 is presumably the shared-memory size — confirm against Kernel.launch.
            training.launch(trainingParameters, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        } else {
            final Kernel runtime = this.runtimeKernel;
            if (runtime == null) {
                Intrinsics.throwNpe();
            }
            final Pointer runtimeParameters = Pointer.to(new NativePointerObject[]{
                this.pointerToBatchSize,
                this.pointerToNumberEntries,
                this.pointerToNumberRows,
                this.pointerToNumberIterations,
                this.pointerToKeepProbability,
                pointerToInput,
                this.pointerToDeviceForwardResults});
            Intrinsics.checkExpressionValueIsNotNull(runtimeParameters, "Pointer.to(\n            …Results\n                )");
            runtime.launch(runtimeParameters, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        }
        return this.deviceForwardResult;
    }

    /**
     * Launches the backward kernel with the incoming chain and the masks.
     *
     * <p>The {@code i} argument is not read here: the kernel is given {@code batchSize[0]} as
     * written by the most recent {@link #forward} call.
     *
     * @param i       the batch size (unused by this method)
     * @param pointer the chain, i.e. the gradient coming from the following layer
     * @return the backward result pointer (always the same {@link Pointer} instance)
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer backward(int i, @NotNull Pointer pointer) {
        Intrinsics.checkParameterIsNotNull(pointer, "chain");
        final Kernel backward = this.backwardKernel;
        if (backward == null) {
            Intrinsics.throwNpe();
        }
        final Pointer pointerToChain = Pointer.to(new NativePointerObject[]{pointer});
        final Pointer backwardParameters = Pointer.to(new NativePointerObject[]{
            this.pointerToBatchSize,
            this.pointerToNumberEntries,
            this.pointerToNumberRows,
            this.pointerToNumberIterations,
            pointerToChain,
            this.pointerToDeviceMasks,
            this.pointerToDeviceBackwardResults});
        Intrinsics.checkExpressionValueIsNotNull(backwardParameters, "Pointer.to(\n            …wardResults\n            )");
        backward.launch(backwardParameters, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        return this.deviceBackwardResult;
    }

    /**
     * Destroys the kernels and frees the device buffers.
     *
     * <p>As before, a missing kernel is reported with the Kotlin null-pointer exception raised by
     * {@code Intrinsics.throwNpe()}. Unlike before, a missing backward kernel no longer prevents
     * the runtime kernel from being destroyed first: {@link #acquire(int)} creates the runtime
     * kernel before the backward kernel, so a failure in between used to leak it.
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        final Kernel training = this.trainingKernel;
        if (training == null) {
            // acquire() assigns the training kernel before any other resource, so there is
            // nothing else to clean up in this state.
            Intrinsics.throwNpe();
        }
        training.destroy();
        JCuda.cudaFree(this.deviceBackwardResult);
        JCuda.cudaFree(this.deviceForwardResult);
        JCuda.cudaFree(this.deviceMasks);
        JCuda.cudaFree(this.deviceSeeds);

        // Destroy whatever exists before reporting what is missing.
        final Kernel backward = this.backwardKernel;
        final Kernel runtime = this.runtimeKernel;
        if (backward != null) {
            backward.destroy();
        }
        if (runtime != null) {
            runtime.destroy();
        }
        if (backward == null || runtime == null) {
            Intrinsics.throwNpe();
        }
    }

    /**
     * Stores the configuration and builds the host-side kernel argument pointers. No kernel is
     * created and no device memory is allocated here; that happens in {@link #acquire(int)}.
     *
     * @param str        the layer name, may be {@code null}
     * @param i          the number of rows
     * @param i2         the number of columns
     * @param random     the source of randomness used to seed the device
     * @param f          the keep probability; the dropout probability is {@code 1.0f - f}
     * @param function0  factory for the training kernel
     * @param function02 factory for the runtime kernel
     * @param function03 factory for the backward kernel
     * @param i3         the warp size
     * @param i4         the maximum number of threads per block
     * @throws ArithmeticException if {@code i * i2} does not fit in an {@code int}
     */
    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public CudaDropoutLayer(@Nullable String str, int i, int i2, @NotNull Random random, float f, @NotNull Function0<Kernel> function0, @NotNull Function0<Kernel> function02, @NotNull Function0<Kernel> function03, int i3, int i4) {
        super(str);
        Intrinsics.checkParameterIsNotNull(random, "random");
        Intrinsics.checkParameterIsNotNull(function0, "createTrainingKernel");
        Intrinsics.checkParameterIsNotNull(function02, "createRuntimeKernel");
        Intrinsics.checkParameterIsNotNull(function03, "createBackwardKernel");
        this.random = random;
        this.createTrainingKernel = function0;
        this.createRuntimeKernel = function02;
        this.createBackwardKernel = function03;
        this.warpSize = i3;
        this.maximumNumberThreadsPerBlock = i4;

        // Overflow-checked: this value is handed to the kernels and scales every allocation.
        this.numberEntries = Math.multiplyExact(i, i2);
        this.pointerToNumberEntries = Pointer.to(new int[]{this.numberEntries});

        // -1 marks "not acquired yet"; acquire() overwrites all four values.
        this.numberBlocksInXDimension = -1;
        this.numberBlocksInYDimension = -1;
        this.numberThreadsPerBlock = -1;
        this.numberIterations = new int[]{-1};
        this.pointerToNumberIterations = Pointer.to(this.numberIterations);

        this.pointerToKeepProbability = Pointer.to(new float[]{f});
        this.pointerToDropoutProbability = Pointer.to(new float[]{1.0f - f});

        this.deviceSeeds = new Pointer();
        this.pointerToDeviceSeeds = Pointer.to(new NativePointerObject[]{this.deviceSeeds});
        this.deviceMasks = new Pointer();
        this.pointerToDeviceMasks = Pointer.to(new NativePointerObject[]{this.deviceMasks});

        this.deviceForwardResult = new Pointer();
        this.numberOutputRows = i;
        this.maximumOutputColumns = i2;
        this.pointerToDeviceForwardResults = Pointer.to(new NativePointerObject[]{this.deviceForwardResult});

        this.deviceBackwardResult = new Pointer();
        this.numberInputRows = i;
        this.pointerToNumberRows = Pointer.to(new int[]{this.numberInputRows});
        this.maximumInputColumns = i2;
        this.pointerToDeviceBackwardResults = Pointer.to(new NativePointerObject[]{this.deviceBackwardResult});

        this.batchSize = new int[]{-1};
        this.pointerToBatchSize = Pointer.to(this.batchSize);
    }

    /**
     * Compiler-generated bridge for the Kotlin default argument: when bit 0 of {@code i5} is set,
     * the name is replaced by {@code null}. The marker argument is ignored.
     */
    public /* synthetic */ CudaDropoutLayer(String str, int i, int i2, Random random, float f, Function0 function0, Function0 function02, Function0 function03, int i3, int i4, int i5, DefaultConstructorMarker defaultConstructorMarker) {
        this((i5 & 1) != 0 ? (String) null : str, i, i2, random, f, function0, function02, function03, i3, i4);
    }
}
