package com.komputation.cuda.workflow;

import com.komputation.cuda.loss.CudaLossFunction;
import com.komputation.cuda.memory.InputMemory;
import com.komputation.cuda.memory.TargetMemory;
import com.komputation.cuda.network.CudaBackwardPropagator;
import com.komputation.cuda.network.CudaForwardPropagator;
import com.komputation.matrix.Matrix;
import com.komputation.matrix.PartitioningKt;
import com.komputation.optimization.Optimizable;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import kotlin.Metadata;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.FloatCompanionObject;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CudaTrainer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��r\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0014\n��\n\u0002\u0010\b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\t\n��\u0018��2\u00020\u0001B\u0093\u0001\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005\u0012\f\u0010\u0006\u001a\b\u0012\u0004\u0012\u00020\b0\u0007\u0012\f\u0010\t\u001a\b\u0012\u0004\u0012\u00020\n0\u0007\u0012\f\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\f0\u0007\u0012\u0006\u0010\r\u001a\u00020\u000e\u0012\u0006\u0010\u000f\u001a\u00020\u000e\u0012\u0006\u0010\u0010\u001a\u00020\u0011\u0012:\b\u0002\u0010\u0012\u001a4\u0012\u0013\u0012\u00110\u000e¢\u0006\f\b\u0014\u0012\b\b\u0015\u0012\u0004\b\b(\u0016\u0012\u0013\u0012\u00110\u0017¢\u0006\f\b\u0014\u0012\b\b\u0015\u0012\u0004\b\b(\u0018\u0012\u0004\u0012\u00020\u0019\u0018\u00010\u0013¢\u0006\u0002\u0010\u001aJ\u0006\u0010&\u001a\u00020\u0019J\u0006\u0010'\u001a\u00020(R@\u0010\u0012\u001a4\u0012\u0013\u0012\u00110\u000e¢\u0006\f\b\u0014\u0012\b\b\u0015\u0012\u0004\b\b(\u0016\u0012\u0013\u0012\u00110\u0017¢\u0006\f\b\u0014\u0012\b\b\u0015\u0012\u0004\b\b(\u0018\u0012\u0004\u0012\u00020\u0019\u0018\u00010\u0013X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0004\u001a\u00020\u0005X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u001b\u001a\b\u0012\u0004\u0012\u00020\u001c0\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001dR\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u001e\u001a\u00020\u001fX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\t\u001a\b\u0012\u0004\u0012\u00020\n0\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010 
R\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010!\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u0006\u001a\b\u0012\u0004\u0012\u00020\b0\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\"R\u000e\u0010#\u001a\u00020$X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u000b\u001a\b\u0012\u0004\u0012\u00020\f0\u0007X\u0082\u0004¢\u0006\u0004\n\u0002\u0010%¨\u0006)"}, d2 = {"Lcom/komputation/cuda/workflow/CudaTrainer;", "", "forwardPropagator", "Lcom/komputation/cuda/network/CudaForwardPropagator;", "backwardPropagator", "Lcom/komputation/cuda/network/CudaBackwardPropagator;", "optimizables", "", "Lcom/komputation/optimization/Optimizable;", "inputs", "Lcom/komputation/matrix/Matrix;", "targets", "", "numberIterations", "", "maximumBatchSize", "lossFunction", "Lcom/komputation/cuda/loss/CudaLossFunction;", "afterEachIteration", "Lkotlin/Function2;", "Lkotlin/ParameterName;", "name", "index", "", "loss", "", "(Lcom/komputation/cuda/network/CudaForwardPropagator;Lcom/komputation/cuda/network/CudaBackwardPropagator;[Lcom/komputation/optimization/Optimizable;[Lcom/komputation/matrix/Matrix;[[FIILcom/komputation/cuda/loss/CudaLossFunction;Lkotlin/jvm/functions/Function2;)V", "batches", "", "[[I", "inputMemory", "Lcom/komputation/cuda/memory/InputMemory;", "[Lcom/komputation/matrix/Matrix;", "numberExamples", "[Lcom/komputation/optimization/Optimizable;", "targetMemory", "Lcom/komputation/cuda/memory/TargetMemory;", "[[F", "free", "run", "", "komputation"})
/* loaded from: input_file:com/komputation/cuda/workflow/CudaTrainer.class */
public final class CudaTrainer {
    /** Number of training examples, taken from the length of {@code inputs}. */
    private final int numberExamples;
    /** Example indices grouped into batches of at most {@code maximumBatchSize} entries. */
    private final int[][] batches;
    private final InputMemory inputMemory;
    private final TargetMemory targetMemory;
    private final CudaForwardPropagator forwardPropagator;
    private final CudaBackwardPropagator backwardPropagator;
    private final Optimizable[] optimizables;
    private final Matrix[] inputs;
    private final float[][] targets;
    private final int numberIterations;
    private final int maximumBatchSize;
    private final CudaLossFunction lossFunction;
    /**
     * Optional callback receiving (iteration index, summed iteration loss) after each iteration;
     * {@code null} disables both the callback and loss tracking. Typed with {@code ? super} bounds so
     * the constructor argument can be stored without an illegal generic assignment.
     */
    private final Function2<? super Integer, ? super Float, Unit> afterEachIteration;

    /**
     * Releases the loss function's resources and frees the input and target memory.
     *
     * <p>Each release is attempted even if an earlier one throws, so a failure in one step does not
     * leak the remaining resources; the first exception thrown is propagated.
     */
    public final void free() {
        try {
            this.lossFunction.release();
        } finally {
            try {
                this.inputMemory.free();
            } finally {
                this.targetMemory.free();
            }
        }
    }

    /**
     * Runs {@code numberIterations} passes over all batches. For every batch it runs the forward pass,
     * fetches the targets, back-propagates the loss gradient and lets every optimizable update itself.
     *
     * <p>When an {@code afterEachIteration} callback was supplied, the loss of every batch is
     * accumulated and summed, and the callback is invoked once per iteration with the iteration index
     * and that sum.
     *
     * @return the wall-clock duration of the training run in milliseconds
     */
    public final long run() {
        final Function2<? super Integer, ? super Float, Unit> callback = this.afterEachIteration;
        // Loss is only computed when someone is listening for it.
        final boolean trackLoss = callback != null;
        final long start = System.currentTimeMillis();

        for (int iteration = 0; iteration < this.numberIterations; iteration++) {
            float iterationLoss = trackLoss ? 0.0f : Float.NaN;

            for (int batchId = 0; batchId < this.batches.length; batchId++) {
                final int[] batch = this.batches[batchId];
                final int batchSize = batch.length;

                final Pointer pointerToDevicePredictions = Pointer.to(new NativePointerObject[]{
                        this.forwardPropagator.forward(batchId, batchSize, batch, this.inputs, this.inputMemory, true)});
                final Pointer pointerToTargets = this.targetMemory.get(batchId, batchSize, batch, this.targets);
                Intrinsics.checkExpressionValueIsNotNull(pointerToDevicePredictions, "pointerToDevicePredictions");
                Intrinsics.checkExpressionValueIsNotNull(pointerToTargets, "pointerToTargets");

                if (trackLoss) {
                    this.lossFunction.accumulate(pointerToDevicePredictions, pointerToTargets, batchSize);
                }

                this.backwardPropagator.backward(
                        this.lossFunction.backward(pointerToDevicePredictions, pointerToTargets, batchSize),
                        batchSize);

                for (Optimizable optimizable : this.optimizables) {
                    optimizable.optimize(batchSize);
                }

                if (trackLoss) {
                    // NOTE(review): assumes accessAccumulation() yields the loss of the current batch only;
                    // if it returns a running total, this sum double-counts — confirm.
                    iterationLoss += this.lossFunction.accessAccumulation();
                }
            }

            // Previously this null-check had an empty body, so the callback was never called.
            if (callback != null) {
                callback.invoke(Integer.valueOf(iteration), Float.valueOf(iterationLoss));
            }
        }

        return System.currentTimeMillis() - start;
    }

    /**
     * Creates a trainer and acquires the resources it needs.
     *
     * <p>Partitions the indices of {@code inputs} into batches of at most {@code maximumBatchSize}
     * examples, creates the input memory and the target memory (sized by the length of the first
     * target vector), and calls {@code lossFunction.acquire(maximumBatchSize)}. Call {@link #free()}
     * when done.
     *
     * @param forwardPropagator computes predictions for a batch
     * @param backwardPropagator propagates the loss gradient back through the network
     * @param optimizables parameters updated after every batch
     * @param inputs training inputs; its length defines the number of examples
     * @param targets target vectors; must be non-empty because the length of its first entry is read
     * @param numberIterations number of passes over all batches performed by {@link #run()}
     * @param maximumBatchSize maximum number of examples per batch
     * @param lossFunction loss used for the gradient and, when tracked, the reported loss
     * @param afterEachIteration optional callback receiving the iteration index and summed loss; may be {@code null}
     */
    public CudaTrainer(@NotNull CudaForwardPropagator forwardPropagator, @NotNull CudaBackwardPropagator backwardPropagator, @NotNull Optimizable[] optimizables, @NotNull Matrix[] inputs, @NotNull float[][] targets, int numberIterations, int maximumBatchSize, @NotNull CudaLossFunction lossFunction, @Nullable Function2<? super Integer, ? super Float, Unit> afterEachIteration) {
        Intrinsics.checkParameterIsNotNull(forwardPropagator, "forwardPropagator");
        Intrinsics.checkParameterIsNotNull(backwardPropagator, "backwardPropagator");
        Intrinsics.checkParameterIsNotNull(optimizables, "optimizables");
        Intrinsics.checkParameterIsNotNull(inputs, "inputs");
        Intrinsics.checkParameterIsNotNull(targets, "targets");
        Intrinsics.checkParameterIsNotNull(lossFunction, "lossFunction");
        this.forwardPropagator = forwardPropagator;
        this.backwardPropagator = backwardPropagator;
        this.optimizables = optimizables;
        this.inputs = inputs;
        this.targets = targets;
        this.numberIterations = numberIterations;
        this.maximumBatchSize = maximumBatchSize;
        this.lossFunction = lossFunction;
        this.afterEachIteration = afterEachIteration;
        this.numberExamples = this.inputs.length;
        this.batches = PartitioningKt.partitionIndices(this.numberExamples, this.maximumBatchSize);
        this.inputMemory = new InputMemory();
        // ArraysKt.first throws on an empty array, so targets must contain at least one vector.
        this.targetMemory = new TargetMemory(((float[]) ArraysKt.first(this.targets)).length);
        this.lossFunction.acquire(this.maximumBatchSize);
    }

    /**
     * Kotlin default-argument bridge. When bit 8 (value 256, the ninth parameter) of
     * {@code defaultMask} is set, {@code afterEachIteration} defaults to {@code null}.
     */
    public /* synthetic */ CudaTrainer(CudaForwardPropagator forwardPropagator, CudaBackwardPropagator backwardPropagator, Optimizable[] optimizables, Matrix[] inputs, float[][] targets, int numberIterations, int maximumBatchSize, CudaLossFunction lossFunction, Function2 afterEachIteration, int defaultMask, DefaultConstructorMarker marker) {
        this(forwardPropagator, backwardPropagator, optimizables, inputs, targets, numberIterations, maximumBatchSize, lossFunction, (defaultMask & 256) != 0 ? (Function2) null : afterEachIteration);
    }
}
