package com.komputation.cuda.network;

import com.komputation.cuda.CudaContext;
import com.komputation.cuda.CudaContextKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.TestingKernels;
import com.komputation.cuda.layers.CudaEntryPoint;
import com.komputation.cuda.layers.CudaForwardLayer;
import com.komputation.cuda.memory.InputMemory;
import com.komputation.cuda.workflow.CudaBinaryClassificationTester;
import com.komputation.cuda.workflow.CudaMultiClassificationTester;
import com.komputation.cuda.workflow.CudaTester;
import com.komputation.cuda.workflow.CudaTrainer;
import com.komputation.layers.CudaEntryPointInstruction;
import com.komputation.layers.CudaForwardLayerInstruction;
import com.komputation.layers.Resourceful;
import com.komputation.layers.ResourcefulKt;
import com.komputation.loss.CudaLossFunctionInstruction;
import com.komputation.matrix.Matrix;
import com.komputation.optimization.Optimizable;
import java.util.List;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CudaNetwork.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��\u009a\u0001\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000b\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0014\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0007\n\u0002\b\u0003\u0018��2\u00020\u0001B)\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\u0012\u0010\u0006\u001a\n\u0012\u0006\b\u0001\u0012\u00020\b0\u0007\"\u00020\b¢\u0006\u0002\u0010\tJ\u0006\u0010\u001c\u001a\u00020\u001dJ\u0006\u0010\u001e\u001a\u00020\u0011J\u000e\u0010\u001f\u001a\u00020\u00172\u0006\u0010 \u001a\u00020\u0003J\u000e\u0010!\u001a\u00020\"2\u0006\u0010#\u001a\u00020$JA\u0010%\u001a\u00020&2\f\u0010'\u001a\b\u0012\u0004\u0012\u00020$0\u00072\f\u0010(\u001a\b\u0012\u0004\u0012\u00020)0\u00072\u0006\u0010*\u001a\u00020\u00032\u0006\u0010+\u001a\u00020\u00032\b\b\u0002\u0010,\u001a\u00020\u0003¢\u0006\u0002\u0010-Js\u0010.\u001a\u00020/2\f\u0010'\u001a\b\u0012\u0004\u0012\u00020$0\u00072\f\u0010(\u001a\b\u0012\u0004\u0012\u00020)0\u00072\u0006\u00100\u001a\u00020\u00032\u0006\u00101\u001a\u0002022:\b\u0002\u00103\u001a4\u0012\u0013\u0012\u00110\u0003¢\u0006\f\b5\u0012\b\b6\u0012\u0004\b\b( 
\u0012\u0013\u0012\u001107¢\u0006\f\b5\u0012\b\b6\u0012\u0004\b\b(8\u0012\u0004\u0012\u00020\u001d\u0018\u000104¢\u0006\u0002\u00109R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000e\u001a\u00020\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0014\u001a\u00020\u0015X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u0016\u001a\b\u0012\u0004\u0012\u00020\u00170\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0018R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u001a0\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001b¨\u0006:"}, d2 = {"Lcom/komputation/cuda/network/CudaNetwork;", "", "maximumBatchSize", "", "entryPointInstruction", "Lcom/komputation/layers/CudaEntryPointInstruction;", "forwardLayerInstructions", "", "Lcom/komputation/layers/CudaForwardLayerInstruction;", "(ILcom/komputation/layers/CudaEntryPointInstruction;[Lcom/komputation/layers/CudaForwardLayerInstruction;)V", "backwardPropagator", "Lcom/komputation/cuda/network/CudaBackwardPropagator;", "cublasHandle", "Ljcuda/jcublas/cublasHandle;", "cudaContext", "Lcom/komputation/cuda/CudaContext;", "entryPoint", "Lcom/komputation/cuda/layers/CudaEntryPoint;", "forwardPropagator", "Lcom/komputation/cuda/network/CudaForwardPropagator;", "hasFixedLengthInput", "", "layers", "Lcom/komputation/cuda/layers/CudaForwardLayer;", "[Lcom/komputation/cuda/layers/CudaForwardLayer;", "optimizables", "Lcom/komputation/optimization/Optimizable;", "[Lcom/komputation/optimization/Optimizable;", "free", "", "getEntryPoint", "getLayer", "index", "predict", "Ljcuda/Pointer;", "input", "Lcom/komputation/matrix/Matrix;", "test", "Lcom/komputation/cuda/workflow/CudaTester;", "inputs", "targets", "", "batchSize", "numberCategories", "length", 
"([Lcom/komputation/matrix/Matrix;[[FIII)Lcom/komputation/cuda/workflow/CudaTester;", "training", "Lcom/komputation/cuda/workflow/CudaTrainer;", "numberIterations", "lossFunction", "Lcom/komputation/loss/CudaLossFunctionInstruction;", "afterEachIteration", "Lkotlin/Function2;", "Lkotlin/ParameterName;", "name", "", "loss", "([Lcom/komputation/matrix/Matrix;[[FILcom/komputation/loss/CudaLossFunctionInstruction;Lkotlin/jvm/functions/Function2;)Lcom/komputation/cuda/workflow/CudaTrainer;", "komputation"})
/* loaded from: input_file:com/komputation/cuda/network/CudaNetwork.class */
/*
 * A network executed on the GPU: it builds an entry point and a chain of forward layers from
 * their instructions, owns the CUDA context and cuBLAS handle they share, and exposes
 * training, testing and single-input prediction workflows.
 */
public final class CudaNetwork {
    /** CUDA context created by the constructor; used to build layers, loss functions and kernels. */
    private final CudaContext cudaContext;
    /** cuBLAS handle handed to every forward layer; created in the constructor, destroyed by {@link #free()}. */
    private final cublasHandle cublasHandle;
    private final CudaEntryPoint entryPoint;
    /** Taken from the entry point; selects the fixed- or variable-length forward propagator. */
    private final boolean hasFixedLengthInput;
    private final CudaForwardLayer[] layers;
    /** Entry point and layers that implement {@link Optimizable}, ordered from the last layer to the first. */
    private final Optimizable[] optimizables;
    private final CudaForwardPropagator forwardPropagator;
    private final CudaBackwardPropagator backwardPropagator;
    /** Batch size used when acquiring layer resources and passed on to trainers. */
    private final int maximumBatchSize;

    /**
     * Builds the network and acquires its resources.
     *
     * @param maximumBatchSize largest batch size the acquired resources must accommodate
     * @param entryPointInstruction instruction used to build the entry point
     * @param forwardLayerInstructions instructions used to build the forward layers, in order
     */
    public CudaNetwork(int maximumBatchSize, @NotNull CudaEntryPointInstruction entryPointInstruction, @NotNull CudaForwardLayerInstruction... forwardLayerInstructions) {
        Intrinsics.checkParameterIsNotNull(entryPointInstruction, "entryPointInstruction");
        Intrinsics.checkParameterIsNotNull(forwardLayerInstructions, "forwardLayerInstructions");
        this.maximumBatchSize = maximumBatchSize;
        // Calls the Kotlin default-argument bridge, so setUpCudaContext's own default applies.
        this.cudaContext = CudaContextKt.setUpCudaContext$default(0, 1, null);
        this.cublasHandle = new cublasHandle();
        this.entryPoint = entryPointInstruction.buildForCuda(this.cudaContext);
        this.hasFixedLengthInput = this.entryPoint.getHasFixedLength();

        CudaForwardLayer[] builtLayers = new CudaForwardLayer[forwardLayerInstructions.length];
        for (int index = 0; index < builtLayers.length; index++) {
            // NOTE(review): the handle object is passed here before cublasCreate runs below;
            // this assumes buildForCuda only stores it. Confirm no layer uses it while being built.
            builtLayers[index] = forwardLayerInstructions[index].buildForCuda(this.cudaContext, this.cublasHandle);
        }
        this.layers = builtLayers;

        // [entryPoint, layer0, ..., layerN] filtered to Optimizable, then reversed (last layer first).
        List<Optimizable> reversedOptimizables = CollectionsKt.reversed(
            CollectionsKt.filterIsInstance(
                CollectionsKt.plus(CollectionsKt.listOf(this.entryPoint), this.layers),
                Optimizable.class));
        this.optimizables = reversedOptimizables.toArray(new Optimizable[reversedOptimizables.size()]);

        this.forwardPropagator = this.hasFixedLengthInput
            ? new CudaFixedLengthForwardPropagator(this.entryPoint, this.layers)
            : new CudaVariableLengthForwardPropagator(this.entryPoint, this.layers);
        this.backwardPropagator = new CudaBackwardPropagator(this.entryPoint, this.layers);

        JCublas2.cublasCreate(this.cublasHandle);
        ResourcefulKt.acquireRecursively(this.entryPoint, this.maximumBatchSize);
        for (CudaForwardLayer layer : this.layers) {
            ResourcefulKt.acquireRecursively(layer, this.maximumBatchSize);
        }
    }

    /** Returns the entry point built from the entry-point instruction. */
    @NotNull
    public final CudaEntryPoint getEntryPoint() {
        return this.entryPoint;
    }

    /**
     * Returns the forward layer at the given position.
     *
     * @param index zero-based position in the order the layer instructions were given
     * @throws ArrayIndexOutOfBoundsException if {@code index} is out of range
     */
    @NotNull
    public final CudaForwardLayer getLayer(int index) {
        return this.layers[index];
    }

    /**
     * Releases the resources of every {@link Resourceful} layer and entry point, then destroys
     * the cuBLAS handle.
     *
     * <p>Teardown mirrors construction in reverse: layers from last to first, then the entry
     * point, and finally the cuBLAS handle the layers were built with. The handle is destroyed
     * even if a release throws.
     *
     * <p>NOTE(review): resources are acquired with {@code ResourcefulKt.acquireRecursively} but
     * released here with a plain {@code release()}, and the CUDA context created by the
     * constructor is not torn down. Confirm both are intended.
     */
    public final void free() {
        try {
            for (int index = this.layers.length - 1; index >= 0; index--) {
                CudaForwardLayer layer = this.layers[index];
                if (layer instanceof Resourceful) {
                    ((Resourceful) layer).release();
                }
            }
            if (this.entryPoint instanceof Resourceful) {
                ((Resourceful) this.entryPoint).release();
            }
        } finally {
            JCublas2.cublasDestroy(this.cublasHandle);
        }
    }

    /**
     * Creates a trainer over this network's propagators and optimizable components.
     *
     * @param inputs training inputs
     * @param targets target values, passed to the trainer together with {@code inputs}
     * @param numberIterations number of iterations handed to the trainer
     * @param lossFunction loss-function instruction, built against this network's CUDA context
     * @param afterEachIteration optional {@code (Integer, Float)} callback handed to the trainer; may be null
     * @return a new {@link CudaTrainer}
     */
    @NotNull
    public final CudaTrainer training(@NotNull Matrix[] inputs, @NotNull float[][] targets, int numberIterations, @NotNull CudaLossFunctionInstruction lossFunction, @Nullable Function2<? super Integer, ? super Float, Unit> afterEachIteration) {
        Intrinsics.checkParameterIsNotNull(inputs, "inputs");
        Intrinsics.checkParameterIsNotNull(targets, "targets");
        Intrinsics.checkParameterIsNotNull(lossFunction, "lossFunction");
        return new CudaTrainer(
            this.forwardPropagator,
            this.backwardPropagator,
            this.optimizables,
            inputs,
            targets,
            numberIterations,
            this.maximumBatchSize,
            lossFunction.buildForCuda(this.cudaContext),
            afterEachIteration);
    }

    /**
     * Kotlin default-argument bridge for {@link #training}. When bit 16 of {@code defaultMask} is
     * set, {@code afterEachIteration} falls back to its default, {@code null}.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ CudaTrainer training$default(CudaNetwork network, Matrix[] inputs, float[][] targets, int numberIterations, CudaLossFunctionInstruction lossFunction, Function2 afterEachIteration, int defaultMask, Object unused) {
        if ((defaultMask & 16) != 0) {
            afterEachIteration = null;
        }
        return network.training(inputs, targets, numberIterations, lossFunction, afterEachIteration);
    }

    /**
     * Creates a tester that evaluates the network on the given data.
     *
     * <p>With exactly one category, a binary classification tester backed by the binary testing
     * kernel is used. Otherwise a multi-class tester backed by the multi-class testing kernel is
     * used, and it also receives {@code length}.
     *
     * @param inputs test inputs
     * @param targets expected outputs, passed to the tester together with {@code inputs}
     * @param batchSize batch size handed to the tester
     * @param numberCategories number of output categories; {@code 1} selects binary classification
     * @param length forwarded to the multi-class tester only
     * @return a new {@link CudaTester}
     */
    @NotNull
    public final CudaTester test(@NotNull Matrix[] inputs, @NotNull float[][] targets, int batchSize, int numberCategories, int length) {
        Intrinsics.checkParameterIsNotNull(inputs, "inputs");
        Intrinsics.checkParameterIsNotNull(targets, "targets");
        int numberInstances = inputs.length;
        if (numberCategories == 1) {
            // The kernel factory is a lambda, so the kernel is created only when the tester asks for it.
            CudaBinaryClassificationTester binaryTester = new CudaBinaryClassificationTester(
                numberInstances,
                numberCategories,
                () -> this.cudaContext.createKernel(TestingKernels.INSTANCE.binary()));
            return new CudaTester(this.forwardPropagator, binaryTester, inputs, targets, batchSize);
        }
        CudaMultiClassificationTester multiClassTester = new CudaMultiClassificationTester(
            numberInstances,
            numberCategories,
            length,
            () -> this.cudaContext.createKernel(TestingKernels.INSTANCE.multiClass()));
        return new CudaTester(this.forwardPropagator, multiClassTester, inputs, targets, batchSize);
    }

    /**
     * Kotlin default-argument bridge for {@link #test}. When bit 16 of {@code defaultMask} is set,
     * {@code length} falls back to its default, {@code 1}.
     */
    @NotNull
    public static /* bridge */ /* synthetic */ CudaTester test$default(CudaNetwork network, Matrix[] inputs, float[][] targets, int batchSize, int numberCategories, int length, int defaultMask, Object unused) {
        if ((defaultMask & 16) != 0) {
            length = 1;
        }
        return network.test(inputs, targets, batchSize, numberCategories, length);
    }

    /**
     * Runs one input through the forward propagator as a batch of one.
     *
     * <p>A temporary {@link InputMemory} is used for the pass. It is freed before this method
     * returns, including when the forward pass throws.
     *
     * @param input the single input to propagate
     * @return the pointer returned by the forward propagator
     */
    @NotNull
    public final Pointer predict(@NotNull Matrix input) {
        Intrinsics.checkParameterIsNotNull(input, "input");
        InputMemory inputMemory = new InputMemory();
        try {
            // Arguments mirror a one-element batch: id 0, size 1, index {0}, the input, and a
            // trailing false that presumably disables training-mode behavior (confirm in the propagator).
            return this.forwardPropagator.forward(0, 1, new int[]{0}, new Matrix[]{input}, inputMemory, false);
        } finally {
            inputMemory.free();
        }
    }
}
