package com.komputation.cuda.layers.forward.projection;

import com.komputation.cuda.layers.BaseCudaForwardLayer;
import com.komputation.cuda.layers.CudaVariableLengthForwardLayer;
import com.komputation.optimization.Optimizable;
import jcuda.Pointer;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CublasProjectionLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��D\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\b\n\u0002\b\u000f\n\u0002\u0010\u000b\n\u0002\b\u0003\n\u0002\u0010\u0002\n��\u0018��2\u00020\u00012\u00020\u00022\u00020\u0003B!\b��\u0012\b\u0010\u0004\u001a\u0004\u0018\u00010\u0005\u0012\u0006\u0010\u0006\u001a\u00020\u0007\u0012\u0006\u0010\b\u001a\u00020\t¢\u0006\u0002\u0010\nJ\u0018\u0010\u001b\u001a\u00020\f2\u0006\u0010\u001c\u001a\u00020\u00122\u0006\u0010\u001d\u001a\u00020\fH\u0016J(\u0010\u001e\u001a\u00020\f2\u0006\u0010\u001c\u001a\u00020\u00122\u0006\u0010\u001f\u001a\u00020\f2\u0006\u0010 \u001a\u00020\f2\u0006\u0010!\u001a\u00020\"H\u0016J \u0010\u001e\u001a\u00020\f2\u0006\u0010\u001c\u001a\u00020\u00122\u0006\u0010 \u001a\u00020\f2\u0006\u0010!\u001a\u00020\"H\u0016J\u0006\u0010#\u001a\u00020\fJ\u0006\u0010$\u001a\u00020\fJ\u0010\u0010%\u001a\u00020&2\u0006\u0010\u001c\u001a\u00020\u0012H\u0016R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\f8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\r\u0010\u000eR\u0014\u0010\u000f\u001a\u00020\f8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0010\u0010\u000eR\u0014\u0010\u0011\u001a\u00020\u00128VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0013\u0010\u0014R\u0014\u0010\u0015\u001a\u00020\u00128VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0016\u0010\u0014R\u0014\u0010\u0017\u001a\u00020\u00128VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0018\u0010\u0014R\u0014\u0010\u0019\u001a\u00020\u00128VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u001a\u0010\u0014R\u000e\u0010\u0006\u001a\u00020\u0007X\u0082\u0004¢\u0006\u0002\n��¨\u0006'"}, d2 = {"Lcom/komputation/cuda/layers/forward/projection/CublasProjectionLayer;", "Lcom/komputation/cuda/layers/BaseCudaForwardLayer;", 
"Lcom/komputation/cuda/layers/CudaVariableLengthForwardLayer;", "Lcom/komputation/optimization/Optimizable;", "name", "", "weightingLayer", "Lcom/komputation/cuda/layers/forward/projection/CublasWeightingLayer;", "biasLayer", "Lcom/komputation/cuda/layers/forward/projection/CublasBiasLayer;", "(Ljava/lang/String;Lcom/komputation/cuda/layers/forward/projection/CublasWeightingLayer;Lcom/komputation/cuda/layers/forward/projection/CublasBiasLayer;)V", "deviceBackwardResult", "Ljcuda/Pointer;", "getDeviceBackwardResult", "()Ljcuda/Pointer;", "deviceForwardResult", "getDeviceForwardResult", "maximumInputColumns", "", "getMaximumInputColumns", "()I", "maximumOutputColumns", "getMaximumOutputColumns", "numberInputRows", "getNumberInputRows", "numberOutputRows", "getNumberOutputRows", "backward", "batchSize", "chain", "forward", "deviceLengths", "deviceInput", "isTraining", "", "getDeviceBias", "getDeviceWeights", "optimize", "", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/forward/projection/CublasProjectionLayer.class */
/**
 * Projection layer built from two stages: a {@link CublasWeightingLayer} followed by a
 * {@link CublasBiasLayer}.
 *
 * <p>During the forward pass the input is processed by the weighting layer, and the weighting
 * layer's output is handed to the bias layer. Consequently, this layer's forward state (result
 * pointer, number of output rows, maximum output columns) is read from the bias layer. Its
 * backward state (result pointer, number of input rows, maximum input columns) is read from the
 * weighting layer.
 */
public final class CublasProjectionLayer extends BaseCudaForwardLayer implements CudaVariableLengthForwardLayer, Optimizable {

    /** First stage: receives the input of this layer. */
    private final CublasWeightingLayer weightingLayer;

    /** Second stage: receives the output of {@link #weightingLayer}. */
    private final CublasBiasLayer biasLayer;

    /**
     * Assembles a projection layer from its two stages.
     *
     * @param name           optional layer name, passed on to {@link BaseCudaForwardLayer}
     * @param weightingLayer first stage; must not be {@code null}
     * @param biasLayer      second stage; must not be {@code null}
     */
    public CublasProjectionLayer(@Nullable String name, @NotNull CublasWeightingLayer weightingLayer, @NotNull CublasBiasLayer biasLayer) {
        super(name);
        Intrinsics.checkParameterIsNotNull(weightingLayer, "weightingLayer");
        Intrinsics.checkParameterIsNotNull(biasLayer, "biasLayer");
        this.weightingLayer = weightingLayer;
        this.biasLayer = biasLayer;
    }

    // ---- forward state (CudaForwardState): owned by the bias stage -------------------------

    /** Returns the bias layer's forward result pointer. */
    @Override
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.biasLayer.getDeviceForwardResult();
    }

    /** Returns the bias layer's number of output rows. */
    @Override
    public int getNumberOutputRows() {
        return this.biasLayer.getNumberOutputRows();
    }

    /** Returns the bias layer's maximum number of output columns. */
    @Override
    public int getMaximumOutputColumns() {
        return this.biasLayer.getMaximumOutputColumns();
    }

    // ---- backward state (CudaBackwardState): owned by the weighting stage ------------------

    /** Returns the weighting layer's backward result pointer. */
    @Override
    @NotNull
    public Pointer getDeviceBackwardResult() {
        return this.weightingLayer.getDeviceBackwardResult();
    }

    /** Returns the weighting layer's number of input rows. */
    @Override
    public int getNumberInputRows() {
        return this.weightingLayer.getNumberInputRows();
    }

    /** Returns the weighting layer's maximum number of input columns. */
    @Override
    public int getMaximumInputColumns() {
        return this.weightingLayer.getMaximumInputColumns();
    }

    // ---- parameter accessors ---------------------------------------------------------------

    /** Returns the device pointer exposed by the weighting layer as its weights. */
    @NotNull
    public final Pointer getDeviceWeights() {
        return this.weightingLayer.getDeviceWeights();
    }

    /** Returns the device pointer exposed by the bias layer as its bias. */
    @NotNull
    public final Pointer getDeviceBias() {
        return this.biasLayer.getDeviceBias();
    }

    // ---- propagation -----------------------------------------------------------------------

    /**
     * Runs the weighting stage on {@code deviceInput}, then the bias stage on its output.
     *
     * @param batchSize   batch size, forwarded unchanged to both stages
     * @param deviceInput input pointer for the weighting stage; must not be {@code null}
     * @param isTraining  training flag, forwarded unchanged to both stages
     * @return the bias layer's forward result pointer
     */
    @Override
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        final Pointer deviceWeighted = this.weightingLayer.forward(batchSize, deviceInput, isTraining);
        this.biasLayer.forward(batchSize, deviceWeighted, isTraining);
        // The bias stage's own return value is not used; the result is read back from its state.
        return getDeviceForwardResult();
    }

    /**
     * Variable-length variant of {@link #forward(int, Pointer, boolean)}.
     *
     * <p>Only the bias stage receives {@code deviceLengths}. The weighting stage is invoked
     * through its fixed-length {@code forward} overload.
     *
     * @param batchSize     batch size, forwarded unchanged to both stages
     * @param deviceLengths lengths pointer passed to the bias stage; must not be {@code null}
     * @param deviceInput   input pointer for the weighting stage; must not be {@code null}
     * @param isTraining    training flag, forwarded unchanged to both stages
     * @return the bias layer's forward result pointer
     */
    @Override
    @NotNull
    public Pointer forward(int batchSize, @NotNull Pointer deviceLengths, @NotNull Pointer deviceInput, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(deviceLengths, "deviceLengths");
        Intrinsics.checkParameterIsNotNull(deviceInput, "deviceInput");
        final Pointer deviceWeighted = this.weightingLayer.forward(batchSize, deviceInput, isTraining);
        this.biasLayer.forward(batchSize, deviceLengths, deviceWeighted, isTraining);
        return getDeviceForwardResult();
    }

    /**
     * Propagates {@code chain} backwards through the bias stage, then the weighting stage.
     *
     * <p>Both stages receive the same, unmodified {@code chain}; whatever the bias stage returns
     * is discarded. NOTE(review): this presumably relies on the bias addition passing the
     * gradient through unchanged — confirm against {@code CublasBiasLayer.backward}.
     *
     * @param batchSize batch size, forwarded unchanged to both stages
     * @param chain     incoming backward signal; must not be {@code null}
     * @return the weighting layer's backward result pointer
     */
    @Override
    @NotNull
    public Pointer backward(int batchSize, @NotNull Pointer chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");
        this.biasLayer.backward(batchSize, chain);
        this.weightingLayer.backward(batchSize, chain);
        return getDeviceBackwardResult();
    }

    /**
     * Delegates the optimization step to the weighting stage, then the bias stage.
     *
     * @param batchSize batch size, forwarded unchanged to both stages
     */
    @Override
    public void optimize(int batchSize) {
        this.weightingLayer.optimize(batchSize);
        this.biasLayer.optimize(batchSize);
    }
}
