package com.komputation.cuda.demos.embeddings;

import com.komputation.cuda.network.CudaNetwork;
import com.komputation.demos.embeddings.EmbeddingData;
import com.komputation.initialization.InitializationKt;
import com.komputation.initialization.UniformInitialization;
import com.komputation.initialization.UniformInitializationKt;
import com.komputation.layers.entry.LookupLayerKt;
import com.komputation.layers.forward.activation.ActivationFunction;
import com.komputation.layers.forward.convolution.MaxPoolingLayerKt;
import com.komputation.layers.forward.dense.DenseLayerKt;
import com.komputation.loss.LossPrintKt;
import com.komputation.loss.SquaredLossKt;
import com.komputation.matrix.Matrix;
import com.komputation.optimization.historical.Momentum;
import com.komputation.optimization.historical.MomentumKt;
import java.util.Random;
import kotlin.Metadata;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: Embeddings.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 2, d1 = {"��\u0014\n��\n\u0002\u0010\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0010\u000e\n\u0002\b\u0002\u001a\u0019\u0010��\u001a\u00020\u00012\f\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\u0002\u0010\u0005¨\u0006\u0006"}, d2 = {"main", "", "args", "", "", "([Ljava/lang/String;)V", "komputation"})
/* loaded from: input_file:com/komputation/cuda/demos/embeddings/EmbeddingsKt.class */
public final class EmbeddingsKt {

    /**
     * Demo entry point that trains a small embedding-based classifier with {@link CudaNetwork}
     * on {@link EmbeddingData} and reports the loss via {@code LossPrintKt.getPrintLoss()}.
     *
     * <p>The network is a lookup layer over a randomly initialized embedding table, followed by
     * a max-pooling layer and a softmax dense layer. It is trained with momentum
     * ({@code momentum(0.01f, 0.9f)}; presumably learning rate and momentum coefficient).
     * All weights come from a uniform initializer in {@code [-0.05, 0.05]} seeded with a fixed
     * {@link Random}, so runs are reproducible.
     *
     * @param strArr command-line arguments; unused apart from the non-null check
     */
    public static final void main(@NotNull String[] strArr) {
        Intrinsics.checkParameterIsNotNull(strArr, "args");

        // Length of each embedding vector produced by initializeColumnVector.
        final int embeddingDimension = 2;
        // Number of rows in the embedding table.
        final int numberEmbeddings = 40;

        // Fixed seed so the initial weights (and hence the training run) are reproducible.
        final UniformInitialization uniformInitialization =
            UniformInitializationKt.uniformInitialization(new Random(1L), -0.05f, 0.05f);

        // Embedding table: one independently initialized vector per row. The decompiled code
        // allocated a flat float[] here and cast it to float[][], which does not type-check.
        final float[][] embeddings = new float[numberEmbeddings][];
        for (int row = 0; row < numberEmbeddings; row++) {
            embeddings[row] =
                InitializationKt.initializeColumnVector(uniformInitialization, embeddingDimension);
        }

        // The same optimizer instance is shared by the lookup and dense layers.
        final Momentum momentum = MomentumKt.momentum(0.01f, 0.9f);

        final Matrix[] inputs = EmbeddingData.INSTANCE.getInputs();
        final float[][] targets = EmbeddingData.INSTANCE.getTargets();
        final int numberClasses = EmbeddingData.INSTANCE.getNumberClasses();

        // NOTE(review): the literal 2s below are kept exactly as decompiled. Which of them means
        // embedding dimension vs. maximum document length cannot be confirmed from this file.
        final CudaNetwork network = new CudaNetwork(
            1,
            LookupLayerKt.lookupLayer(embeddings, 2, true, 2, momentum),
            MaxPoolingLayerKt.maxPoolingLayer(2, 2),
            DenseLayerKt.denseLayer(
                2,
                numberClasses,
                uniformInitialization,
                uniformInitialization,
                ActivationFunction.Softmax,
                momentum));

        // NOTE(review): squaredLoss$default is Kotlin's synthetic default-argument bridge
        // (mask 6 = second and third parameters defaulted). javac may refuse to call synthetic
        // members. If so, pass the real defaults from SquaredLoss.kt explicitly -- TODO confirm.
        network
            .training(
                inputs,
                targets,
                1000, // presumably the number of training iterations
                SquaredLossKt.squaredLoss$default(numberClasses, 0, false, 6, null),
                LossPrintKt.getPrintLoss())
            .run();
    }
}
