package com.komputation.cuda.layers.forward.projection;

import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.cuda.functions.CublasBackwardProjectionKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.EntrywiseKt;
import com.komputation.cuda.layers.BaseCudaForwardLayer;
import com.komputation.cuda.layers.CudaVariableLengthForwardLayer;
import com.komputation.cuda.optimization.BaseCudaUpdateRule;
import com.komputation.layers.Resourceful;
import com.komputation.optimization.Optimizable;
import java.util.Arrays;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CublasBiasLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��`\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0010\u0014\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0015\n��\n\u0002\u0018\u0002\n\u0002\b\"\n\u0002\u0010\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0002\b\u0004\u0018��2\u00020\u00012\u00020\u00022\u00020\u00032\u00020\u0004BY\b��\u0012\b\u0010\u0005\u001a\u0004\u0018\u00010\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\n\u0012\u0006\u0010\f\u001a\u00020\r\u0012\b\u0010\u000e\u001a\u0004\u0018\u00010\u000f\u0012\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011\u0012\u0006\u0010\u0013\u001a\u00020\n\u0012\u0006\u0010\u0014\u001a\u00020\n¢\u0006\u0002\u0010\u0015J\u0010\u0010;\u001a\u00020<2\u0006\u0010\"\u001a\u00020\nH\u0016J\u0018\u0010=\u001a\u00020\u00192\u0006\u0010\u0016\u001a\u00020\n2\u0006\u0010>\u001a\u00020\u0019H\u0016J(\u0010?\u001a\u00020\u00192\u0006\u0010\u0016\u001a\u00020\n2\u0006\u0010@\u001a\u00020\u00192\u0006\u0010A\u001a\u00020\u00192\u0006\u0010B\u001a\u00020CH\u0016J 
\u0010?\u001a\u00020\u00192\u0006\u0010\u0016\u001a\u00020\n2\u0006\u0010A\u001a\u00020\u00192\u0006\u0010B\u001a\u00020CH\u0016J\u0006\u0010D\u001a\u00020\u0019J\u0010\u0010E\u001a\u00020<2\u0006\u0010\u0016\u001a\u00020\nH\u0016J\b\u0010F\u001a\u00020<H\u0016R\u000e\u0010\u0016\u001a\u00020\u0017X\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u000e\u001a\u0004\u0018\u00010\u000fX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0018\u001a\u00020\u0019X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001a\u0010\u001bR\u000e\u0010\u001c\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001d\u001a\u00020\u0019X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001e\u0010\u001bR\u000e\u0010\u001f\u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010 \u001a\u00020\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010!\u001a\u0004\u0018\u00010\u0012X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010#\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b$\u0010%R\u000e\u0010\u0014\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010&\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b'\u0010%R\u000e\u0010(\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010)\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010*\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010+\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010,\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b-\u0010%R\u000e\u0010.\u001a\u00020\u0017X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010/\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b0\u0010%R\u000e\u00101\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u0016\u00102\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00104\u001a\n 
3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00105\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00106\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00107\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00108\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u00109\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010:\u001a\n 3*\u0004\u0018\u00010\u00190\u0019X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0013\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��¨\u0006G"}, d2 = {"Lcom/komputation/cuda/layers/forward/projection/CublasBiasLayer;", "Lcom/komputation/cuda/layers/BaseCudaForwardLayer;", "Lcom/komputation/cuda/layers/CudaVariableLengthForwardLayer;", "Lcom/komputation/optimization/Optimizable;", "Lcom/komputation/layers/Resourceful;", "name", "", "cublasHandle", "Ljcuda/jcublas/cublasHandle;", "numberRows", "", "numberColumns", "initialBias", "", "biasUpdateRule", "Lcom/komputation/cuda/optimization/BaseCudaUpdateRule;", "createKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "warpSize", "maximumNumberThreadsPerBlock", "(Ljava/lang/String;Ljcuda/jcublas/cublasHandle;II[FLcom/komputation/cuda/optimization/BaseCudaUpdateRule;Lkotlin/jvm/functions/Function0;II)V", "batchSize", "", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceBias", "deviceForwardResult", "getDeviceForwardResult", "deviceMaximumInputColumns", "deviceOnes", "kernel", "maximumBatchSize", "maximumInputColumns", "getMaximumInputColumns", "()I", "maximumOutputColumns", "getMaximumOutputColumns", "numberBatchInputColumns", "numberBlocksInXDimension", "numberBlocksInYDimension", "numberEntries", "numberInputRows", "getNumberInputRows", "numberIterations", "numberOutputRows", "getNumberOutputRows", "numberThreadsPerBlock", "pointerToBatchSize", 
"kotlin.jvm.PlatformType", "pointerToDeviceBackwardWrtBias", "pointerToDeviceBias", "pointerToDeviceForwardResult", "pointerToMaximumInputColumns", "pointerToNumberEntries", "pointerToNumberInputRows", "pointerToNumberIterations", "acquire", "", "backward", "chain", "forward", "deviceLengths", "deviceInput", "isTraining", "", "getDeviceBias", "optimize", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/projection/CublasBiasLayer.class */
public final class CublasBiasLayer extends BaseCudaForwardLayer implements CudaVariableLengthForwardLayer, Optimizable, Resourceful {
    private final int numberEntries;
    private Kernel kernel;
    private int numberBlocksInXDimension;
    private int numberBlocksInYDimension;
    private int numberThreadsPerBlock;

    @NotNull
    private final Pointer deviceForwardResult;
    private final int numberOutputRows;
    private final int maximumOutputColumns;
    private final Pointer pointerToDeviceForwardResult;
    private final Pointer deviceBias;
    private final Pointer pointerToDeviceBias;

    @NotNull
    private final Pointer deviceBackwardResult;
    private final int numberInputRows;
    private final int maximumInputColumns;
    private final Pointer pointerToDeviceBackwardWrtBias;
    private final Pointer deviceOnes;
    private final Pointer pointerToNumberEntries;
    private final Pointer pointerToNumberInputRows;
    private final int[] batchSize;
    private final Pointer pointerToBatchSize;
    private final int[] numberIterations;
    private final Pointer pointerToNumberIterations;
    private int numberBatchInputColumns;
    private int maximumBatchSize;
    private final Pointer deviceMaximumInputColumns;
    private final Pointer pointerToMaximumInputColumns;
    private final cublasHandle cublasHandle;
    private final float[] initialBias;
    private final BaseCudaUpdateRule biasUpdateRule;
    private final Function0<Kernel> createKernel;
    private final int warpSize;
    private final int maximumNumberThreadsPerBlock;

    @Override // com.komputation.cuda.CudaForwardState
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    @NotNull
    public Pointer getDeviceBackwardResult() {
        return this.deviceBackwardResult;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getMaximumInputColumns() {
        return this.maximumInputColumns;
    }

    @NotNull
    public final Pointer getDeviceBias() {
        return this.deviceBias;
    }

    @Override // com.komputation.layers.Resourceful
    public void acquire(int i) {
        this.maximumBatchSize = i;
        this.numberBatchInputColumns = i * getMaximumInputColumns();
        this.kernel = (Kernel) this.createKernel.invoke();
        int[] iArr = new int[i];
        int length = iArr.length;
        for (int i2 = 0; i2 < length; i2++) {
            iArr[i2] = getMaximumInputColumns();
        }
        CudaIntArrayKt.setIntArray(iArr, this.maximumBatchSize, this.deviceMaximumInputColumns);
        CudaFloatArrayKt.setFloatArray(this.initialBias, this.numberEntries, this.deviceBias);
        CudaFloatArrayKt.allocateDeviceFloatMemory(getDeviceForwardResult(), i * this.numberEntries);
        CudaFloatArrayKt.allocateDeviceFloatMemory(getDeviceBackwardResult(), this.numberEntries);
        BaseCudaUpdateRule baseCudaUpdateRule = this.biasUpdateRule;
        if (baseCudaUpdateRule != null) {
            baseCudaUpdateRule.acquire(i);
        }
        float[] fArr = new float[this.numberBatchInputColumns];
        int length2 = fArr.length;
        for (int i3 = 0; i3 < length2; i3++) {
            fArr[i3] = 1.0f;
        }
        CudaFloatArrayKt.setFloatArray(fArr, this.numberBatchInputColumns, this.deviceOnes);
        this.numberBlocksInXDimension = i;
        this.numberBlocksInYDimension = getMaximumInputColumns();
        Pair<Integer, Integer> computeNumberOfThreadsForRows = EntrywiseKt.computeNumberOfThreadsForRows(getNumberInputRows(), this.warpSize, this.maximumNumberThreadsPerBlock);
        int intValue = ((Number) computeNumberOfThreadsForRows.component1()).intValue();
        int intValue2 = ((Number) computeNumberOfThreadsForRows.component2()).intValue();
        this.numberThreadsPerBlock = intValue;
        this.numberIterations[0] = intValue2;
    }

    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer forward(int i, @NotNull Pointer pointer, boolean z) {
        Intrinsics.checkParameterIsNotNull(pointer, "deviceInput");
        this.batchSize[0] = i;
        Kernel kernel = this.kernel;
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        Pointer pointer2 = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.pointerToBatchSize, (NativePointerObject) this.pointerToMaximumInputColumns, (NativePointerObject) this.pointerToNumberEntries, (NativePointerObject) this.pointerToNumberInputRows, (NativePointerObject) this.pointerToNumberIterations, (NativePointerObject) Pointer.to(new NativePointerObject[]{(NativePointerObject) pointer}), (NativePointerObject) this.pointerToDeviceBias, (NativePointerObject) this.pointerToDeviceForwardResult});
        Intrinsics.checkExpressionValueIsNotNull(pointer2, "Pointer.to(\n            …rwardResult\n            )");
        kernel.launch(pointer2, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        return getDeviceForwardResult();
    }

    @Override // com.komputation.cuda.layers.CudaVariableLengthForwardLayer
    @NotNull
    public Pointer forward(int i, @NotNull Pointer pointer, @NotNull Pointer pointer2, boolean z) {
        Intrinsics.checkParameterIsNotNull(pointer, "deviceLengths");
        Intrinsics.checkParameterIsNotNull(pointer2, "deviceInput");
        this.batchSize[0] = i;
        Kernel kernel = this.kernel;
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        Pointer pointer3 = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.pointerToBatchSize, (NativePointerObject) Pointer.to(new NativePointerObject[]{(NativePointerObject) pointer}), (NativePointerObject) this.pointerToNumberEntries, (NativePointerObject) this.pointerToNumberInputRows, (NativePointerObject) this.pointerToNumberIterations, (NativePointerObject) Pointer.to(new NativePointerObject[]{(NativePointerObject) pointer2}), (NativePointerObject) this.pointerToDeviceBias, (NativePointerObject) this.pointerToDeviceForwardResult});
        Intrinsics.checkExpressionValueIsNotNull(pointer3, "Pointer.to(\n            …rwardResult\n            )");
        kernel.launch(pointer3, this.numberBlocksInXDimension, this.numberBlocksInYDimension, this.numberThreadsPerBlock, 0);
        return getDeviceForwardResult();
    }

    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer backward(int i, @NotNull Pointer pointer) {
        Intrinsics.checkParameterIsNotNull(pointer, "chain");
        if (i < this.maximumBatchSize) {
            CudaIntArrayKt.setArrayToZero(getDeviceBackwardResult(), this.numberEntries);
            CublasBackwardProjectionKt.cublasBackwardProjectionWrtBias(this.cublasHandle, pointer, getNumberInputRows(), i * getMaximumInputColumns(), this.deviceOnes, getDeviceBackwardResult());
        } else {
            CublasBackwardProjectionKt.cublasBackwardProjectionWrtBias(this.cublasHandle, pointer, getNumberInputRows(), this.numberBatchInputColumns, this.deviceOnes, getDeviceBackwardResult());
        }
        return getDeviceBackwardResult();
    }

    @Override // com.komputation.optimization.Optimizable
    public void optimize(int i) {
        BaseCudaUpdateRule baseCudaUpdateRule = this.biasUpdateRule;
        if (baseCudaUpdateRule != null) {
            Pointer pointer = this.pointerToDeviceBias;
            Intrinsics.checkExpressionValueIsNotNull(pointer, "this.pointerToDeviceBias");
            Pointer pointer2 = this.pointerToDeviceBackwardWrtBias;
            Intrinsics.checkExpressionValueIsNotNull(pointer2, "this.pointerToDeviceBackwardWrtBias");
            baseCudaUpdateRule.denseUpdate(i, pointer, pointer2);
        }
    }

    @Override // com.komputation.layers.Resourceful
    public void release() {
        JCuda.cudaFree(getDeviceForwardResult());
        JCuda.cudaFree(getDeviceBackwardResult());
        JCuda.cudaFree(this.deviceBias);
        JCuda.cudaFree(this.deviceOnes);
        Kernel kernel = this.kernel;
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        kernel.destroy();
    }

    /* JADX WARN: 'super' call moved to the top of the method (can break code semantics) */
    public CublasBiasLayer(@Nullable String str, @NotNull cublasHandle cublashandle, int i, int i2, @NotNull float[] fArr, @Nullable BaseCudaUpdateRule baseCudaUpdateRule, @NotNull Function0<Kernel> function0, int i3, int i4) {
        super(str);
        Intrinsics.checkParameterIsNotNull(cublashandle, "cublasHandle");
        Intrinsics.checkParameterIsNotNull(fArr, "initialBias");
        Intrinsics.checkParameterIsNotNull(function0, "createKernel");
        this.cublasHandle = cublashandle;
        this.initialBias = fArr;
        this.biasUpdateRule = baseCudaUpdateRule;
        this.createKernel = function0;
        this.warpSize = i3;
        this.maximumNumberThreadsPerBlock = i4;
        this.numberEntries = i * i2;
        this.numberBlocksInXDimension = -1;
        this.numberBlocksInYDimension = -1;
        this.numberThreadsPerBlock = -1;
        this.deviceForwardResult = new Pointer();
        this.numberOutputRows = i;
        this.maximumOutputColumns = i2;
        this.pointerToDeviceForwardResult = Pointer.to(new NativePointerObject[]{(NativePointerObject) getDeviceForwardResult()});
        this.deviceBias = new Pointer();
        this.pointerToDeviceBias = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.deviceBias});
        this.deviceBackwardResult = new Pointer();
        this.numberInputRows = i;
        this.maximumInputColumns = i2;
        this.pointerToDeviceBackwardWrtBias = Pointer.to(new NativePointerObject[]{(NativePointerObject) getDeviceBackwardResult()});
        this.deviceOnes = new Pointer();
        this.pointerToNumberEntries = Pointer.to(new int[]{this.numberEntries});
        this.pointerToNumberInputRows = Pointer.to(new int[]{getNumberInputRows()});
        this.batchSize = new int[]{-1};
        this.pointerToBatchSize = Pointer.to(this.batchSize);
        this.numberIterations = new int[]{-1};
        this.pointerToNumberIterations = Pointer.to(this.numberIterations);
        this.numberBatchInputColumns = -1;
        this.maximumBatchSize = -1;
        this.deviceMaximumInputColumns = new Pointer();
        this.pointerToMaximumInputColumns = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.deviceMaximumInputColumns});
    }
}
