package com.komputation.cuda.layers.forward.projection;

import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.cuda.functions.CublasBackwardProjectionKt;
import com.komputation.cuda.functions.CublasProjectionKt;
import com.komputation.cuda.layers.BaseCudaForwardLayer;
import com.komputation.cuda.optimization.BaseCudaUpdateRule;
import com.komputation.layers.Resourceful;
import com.komputation.optimization.Optimizable;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CublasWeightingLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��L\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0010\u0014\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u001a\n\u0002\u0010\u0002\n\u0002\b\u0005\n\u0002\u0010\u000b\n\u0002\b\u0004\u0018��2\u00020\u00012\u00020\u00022\u00020\u0003BE\b��\u0012\b\u0010\u0004\u001a\u0004\u0018\u00010\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t\u0012\u0006\u0010\n\u001a\u00020\t\u0012\u0006\u0010\u000b\u001a\u00020\t\u0012\u0006\u0010\f\u001a\u00020\r\u0012\n\b\u0002\u0010\u000e\u001a\u0004\u0018\u00010\u000f¢\u0006\u0002\u0010\u0010J\u0010\u0010,\u001a\u00020-2\u0006\u0010\u001b\u001a\u00020\tH\u0016J\u0018\u0010.\u001a\u00020\u00122\u0006\u0010/\u001a\u00020\t2\u0006\u00100\u001a\u00020\u0012H\u0016J \u00101\u001a\u00020\u00122\u0006\u0010/\u001a\u00020\t2\u0006\u0010\u0018\u001a\u00020\u00122\u0006\u00102\u001a\u000203H\u0016J\u0006\u00104\u001a\u00020\u0012J\u0010\u00105\u001a\u00020-2\u0006\u0010/\u001a\u00020\tH\u0016J\b\u00106\u001a\u00020-H\u0016R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0011\u001a\u00020\u0012X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0015\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0016\u001a\u00020\u0012X\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u0017\u0010\u0014R\u000e\u0010\u0018\u001a\u00020\u0012X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0019\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001a\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u001b\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u0014\u0010\n\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR\u0014\u0010\u001e\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u
001a\u0004\b\u001f\u0010\u001dR\u000e\u0010 \u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010!\u001a\u00020\tX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\b\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b#\u0010\u001dR\u000e\u0010$\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\tX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b%\u0010\u001dR\u000e\u0010&\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010'\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010(\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010)\u001a\n **\u0004\u0018\u00010\u00120\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010+\u001a\n **\u0004\u0018\u00010\u00120\u0012X\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u000e\u001a\u0004\u0018\u00010\u000fX\u0082\u0004¢\u0006\u0002\n��¨\u00067"}, d2 = {"Lcom/komputation/cuda/layers/forward/projection/CublasWeightingLayer;", "Lcom/komputation/cuda/layers/BaseCudaForwardLayer;", "Lcom/komputation/optimization/Optimizable;", "Lcom/komputation/layers/Resourceful;", "name", "", "cublasHandle", "Ljcuda/jcublas/cublasHandle;", "numberInputRows", "", "maximumInputColumns", "numberOutputRows", "initialWeights", "", "weightUpdateRule", "Lcom/komputation/cuda/optimization/BaseCudaUpdateRule;", "(Ljava/lang/String;Ljcuda/jcublas/cublasHandle;III[FLcom/komputation/cuda/optimization/BaseCudaUpdateRule;)V", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceBackwardWrtWeights", "deviceForwardResult", "getDeviceForwardResult", "deviceInput", "deviceWeights", "lastBatchSize", "maximumBatchSize", "getMaximumInputColumns", "()I", "maximumOutputColumns", "getMaximumOutputColumns", "numberBatchInputColumns", "numberBatchOutputColumns", "numberInputEntries", "getNumberInputRows", "numberOutputEntries", "getNumberOutputRows", "numberWeightColumns", "numberWeightEntries", "numberWeightRows", 
"pointerToDeviceBackwardWrtWeights", "kotlin.jvm.PlatformType", "pointerToDeviceWeights", "acquire", "", "backward", "batchSize", "chain", "forward", "isTraining", "", "getDeviceWeights", "optimize", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/projection/CublasWeightingLayer.class */
public final class CublasWeightingLayer extends BaseCudaForwardLayer implements Optimizable, Resourceful {
    // Dimensions derived from the constructor arguments; fixed for the lifetime of the layer.
    private final int numberInputEntries;
    private final int numberWeightRows;
    private final int numberWeightColumns;
    private final int numberWeightEntries;
    private final int maximumOutputColumns;
    private final int numberOutputEntries;

    // Input of the most recent forward call. It is supplied by the caller and never freed by this layer.
    private Pointer deviceInput;
    private final Pointer deviceWeights;
    // Pointer-to-pointer wrapper around deviceWeights, handed to the update rule in optimize().
    private final Pointer pointerToDeviceWeights;

    @NotNull
    private final Pointer deviceForwardResult;
    // Gradient of the loss w.r.t. the weights, filled by backward() and consumed by optimize().
    private final Pointer deviceBackwardWrtWeights;
    private final Pointer pointerToDeviceBackwardWrtWeights;

    @NotNull
    private final Pointer deviceBackwardResult;

    // Batch-dependent sizes; -1 while no device resources are acquired.
    private int maximumBatchSize;
    private int numberBatchInputColumns;
    private int numberBatchOutputColumns;
    private int lastBatchSize;

    private final cublasHandle cublasHandle;
    private final int numberInputRows;
    private final int maximumInputColumns;
    private final int numberOutputRows;
    private final float[] initialWeights;
    private final BaseCudaUpdateRule weightUpdateRule;

    @Override // com.komputation.cuda.CudaForwardState
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    @Override // com.komputation.cuda.CudaForwardState
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    @NotNull
    public Pointer getDeviceBackwardResult() {
        return this.deviceBackwardResult;
    }

    /**
     * Returns the device pointer that holds the weight matrix.
     *
     * @return the device weight buffer (populated by {@link #acquire(int)})
     */
    @NotNull
    public final Pointer getDeviceWeights() {
        return this.deviceWeights;
    }

    /**
     * Acquires all device resources needed for batches of up to {@code maximumBatchSize} instances.
     *
     * <p>Copies the initial weights to the device and allocates the device buffers for:
     * <ul>
     *   <li>the weight gradient;</li>
     *   <li>the forward result;</li>
     *   <li>the backward result.</li>
     * </ul>
     * It also acquires the weight update rule, if one is present.
     *
     * @param maximumBatchSize largest batch size that will be passed to forward/backward
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int maximumBatchSize) {
        this.maximumBatchSize = maximumBatchSize;
        this.numberBatchInputColumns = maximumBatchSize * getMaximumInputColumns();
        this.numberBatchOutputColumns = maximumBatchSize * getMaximumOutputColumns();

        CudaFloatArrayKt.setFloatArray(this.initialWeights, this.numberWeightEntries, this.deviceWeights);
        CudaFloatArrayKt.allocateDeviceFloatMemory(this.deviceBackwardWrtWeights, this.numberWeightEntries);

        BaseCudaUpdateRule updateRule = this.weightUpdateRule;
        if (updateRule != null) {
            updateRule.acquire(maximumBatchSize);
        }

        CudaFloatArrayKt.allocateDeviceFloatMemory(getDeviceForwardResult(), maximumBatchSize * this.numberOutputEntries);
        CudaFloatArrayKt.allocateDeviceFloatMemory(getDeviceBackwardResult(), maximumBatchSize * this.numberInputEntries);
    }

    /**
     * Multiplies the weight matrix with the input and stores the product in the forward result buffer.
     *
     * <p>A single-column input with batch size 1 uses the matrix-vector routine. Every other case
     * uses the matrix-matrix routine over all {@code numberBatchInputColumns} columns.
     *
     * @param batchSize   number of instances in the current batch
     * @param deviceInput device pointer to the input; remembered for the following backward pass
     * @param isTraining  unused by this layer
     * @return the device forward result buffer
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        this.deviceInput = deviceInput;

        boolean singleVector = batchSize == 1 && getMaximumInputColumns() == 1;
        if (singleVector) {
            CublasProjectionKt.cublasMatrixVectorMultiplication(this.cublasHandle, this.deviceWeights, this.numberWeightRows, this.numberWeightColumns, this.deviceInput, getDeviceForwardResult());
        } else {
            CublasProjectionKt.cublasMatrixMatrixMultiplication(this.cublasHandle, this.deviceWeights, this.numberWeightRows, this.numberWeightColumns, this.deviceInput, getNumberInputRows(), this.numberBatchInputColumns, getDeviceForwardResult());
        }
        return getDeviceForwardResult();
    }

    /**
     * Computes the gradients w.r.t. the input (returned) and w.r.t. the weights (kept for {@link #optimize(int)}).
     *
     * <p>For a batch smaller than the acquired maximum, only {@code batchSize * maximumOutputColumns}
     * chain columns are used. Both result buffers are zeroed first, so the unused tail does not hold
     * stale values.
     *
     * @param batchSize number of instances in the current batch
     * @param chain     device pointer to the gradient arriving from the next layer
     * @return the device backward result buffer (gradient w.r.t. the input)
     */
    @Override // com.komputation.cuda.layers.CudaForwardLayer
    @NotNull
    public Pointer backward(int batchSize, @NotNull Pointer chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");
        this.lastBatchSize = batchSize;

        boolean partialBatch = batchSize < this.maximumBatchSize;
        // The chain has one column per output column, so its width is always an output-column count.
        int chainColumns = partialBatch ? batchSize * getMaximumOutputColumns() : this.numberBatchOutputColumns;

        if (partialBatch) {
            // NOTE(review): an int-array helper is used to zero float buffers; this relies on
            // int and float having the same size. Confirm there is no float-specific variant.
            CudaIntArrayKt.setArrayToZero(getDeviceBackwardResult(), this.maximumBatchSize * this.numberInputEntries);
        }
        CublasBackwardProjectionKt.cublasBackwardProjectionWrtInput(this.cublasHandle, this.deviceWeights, this.numberWeightRows, this.numberWeightColumns, chain, getNumberOutputRows(), chainColumns, getDeviceBackwardResult());

        if (partialBatch) {
            CudaIntArrayKt.setArrayToZero(this.deviceBackwardWrtWeights, this.numberWeightEntries);
        }
        CublasBackwardProjectionKt.cublasBackwardProjectionWrtWeights(this.cublasHandle, chain, getNumberOutputRows(), chainColumns, this.deviceInput, getNumberInputRows(), this.deviceBackwardWrtWeights, this.numberWeightEntries);

        return getDeviceBackwardResult();
    }

    /**
     * Applies the weight update rule (if any) using the weight gradient from the last backward pass.
     *
     * @param batchSize number of instances in the batch that produced the gradient
     */
    @Override // com.komputation.optimization.Optimizable
    public void optimize(int batchSize) {
        BaseCudaUpdateRule updateRule = this.weightUpdateRule;
        if (updateRule == null) {
            return;
        }
        Pointer weights = this.pointerToDeviceWeights;
        Intrinsics.checkExpressionValueIsNotNull(weights, "this.pointerToDeviceWeights");
        Pointer gradient = this.pointerToDeviceBackwardWrtWeights;
        Intrinsics.checkExpressionValueIsNotNull(gradient, "this.pointerToDeviceBackwardWrtWeights");
        updateRule.denseUpdate(batchSize, weights, gradient);
    }

    /**
     * Frees the device buffers owned by this layer and resets all batch-dependent sizes to -1.
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        JCuda.cudaFree(this.deviceWeights);
        JCuda.cudaFree(getDeviceForwardResult());
        JCuda.cudaFree(this.deviceBackwardWrtWeights);
        JCuda.cudaFree(getDeviceBackwardResult());
        // TODO(review): weightUpdateRule is acquired in acquire() but never released here; call its
        // release method if BaseCudaUpdateRule exposes one.
        this.maximumBatchSize = -1;
        this.numberBatchInputColumns = -1;
        // Reset the remaining batch-dependent state too, matching the constructor's initial values.
        this.numberBatchOutputColumns = -1;
        this.lastBatchSize = -1;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    @Override // com.komputation.cuda.CudaBackwardState
    public int getMaximumInputColumns() {
        return this.maximumInputColumns;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    /**
     * Creates the layer. No device memory is allocated until {@link #acquire(int)} is called.
     *
     * @param str                 optional layer name
     * @param cublashandle        cuBLAS handle used for all matrix operations
     * @param i                   number of input rows (also the number of weight columns)
     * @param i2                  maximum number of input columns per instance (also the output column count)
     * @param i3                  number of output rows (also the number of weight rows)
     * @param fArr                initial weights; must contain at least {@code i3 * i} entries
     * @param baseCudaUpdateRule  optional update rule applied in {@link #optimize(int)}
     * @throws IllegalArgumentException if {@code fArr} holds fewer entries than the weight matrix
     */
    public CublasWeightingLayer(@Nullable String str, @NotNull cublasHandle cublashandle, int i, int i2, int i3, @NotNull float[] fArr, @Nullable BaseCudaUpdateRule baseCudaUpdateRule) {
        super(str);
        Intrinsics.checkParameterIsNotNull(cublashandle, "cublasHandle");
        Intrinsics.checkParameterIsNotNull(fArr, "initialWeights");
        this.cublasHandle = cublashandle;
        this.numberInputRows = i;
        this.maximumInputColumns = i2;
        this.numberOutputRows = i3;
        this.initialWeights = fArr;
        this.weightUpdateRule = baseCudaUpdateRule;

        this.numberInputEntries = getNumberInputRows() * getMaximumInputColumns();
        this.numberWeightRows = getNumberOutputRows();
        this.numberWeightColumns = getNumberInputRows();
        this.numberWeightEntries = this.numberWeightRows * this.numberWeightColumns;
        // acquire() copies numberWeightEntries floats from initialWeights; fail fast on a short array.
        if (fArr.length < this.numberWeightEntries) {
            throw new IllegalArgumentException("initialWeights has " + fArr.length + " entries, but the weight matrix needs " + this.numberWeightEntries);
        }
        // A weighting layer keeps the column structure of its input.
        this.maximumOutputColumns = getMaximumInputColumns();
        this.numberOutputEntries = getNumberOutputRows() * getMaximumOutputColumns();

        this.deviceInput = new Pointer();
        this.deviceWeights = new Pointer();
        this.pointerToDeviceWeights = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.deviceWeights});
        this.deviceForwardResult = new Pointer();
        this.deviceBackwardWrtWeights = new Pointer();
        this.pointerToDeviceBackwardWrtWeights = Pointer.to(new NativePointerObject[]{(NativePointerObject) this.deviceBackwardWrtWeights});
        this.deviceBackwardResult = new Pointer();

        this.maximumBatchSize = -1;
        this.numberBatchInputColumns = -1;
        this.numberBatchOutputColumns = -1;
        this.lastBatchSize = -1;
    }

    /**
     * Kotlin default-argument bridge: bit 6 of {@code i4} selects a {@code null} update rule.
     */
    public /* synthetic */ CublasWeightingLayer(String str, cublasHandle cublashandle, int i, int i2, int i3, float[] fArr, BaseCudaUpdateRule baseCudaUpdateRule, int i4, DefaultConstructorMarker defaultConstructorMarker) {
        this(str, cublashandle, i, i2, i3, fArr, (i4 & 64) != 0 ? (BaseCudaUpdateRule) null : baseCudaUpdateRule);
    }
}
