package com.komputation.cpu.demos.trec;

import com.komputation.cpu.network.Network;
import com.komputation.cpu.workflow.CpuTester;
import com.komputation.demos.trec.NLP;
import com.komputation.demos.trec.TRECData;
import com.komputation.initialization.UniformInitialization;
import com.komputation.initialization.UniformInitializationKt;
import com.komputation.layers.CpuForwardLayerInstruction;
import com.komputation.layers.entry.LookupLayer;
import com.komputation.layers.entry.LookupLayerKt;
import com.komputation.layers.forward.ConcatenationKt;
import com.komputation.layers.forward.activation.ReluLayerKt;
import com.komputation.layers.forward.activation.SoftmaxLayerKt;
import com.komputation.layers.forward.convolution.ConvolutionLayer;
import com.komputation.layers.forward.convolution.ConvolutionLayerKt;
import com.komputation.layers.forward.dropout.DropoutLayerKt;
import com.komputation.layers.forward.projection.ProjectionLayerKt;
import com.komputation.loss.CrossEntropyLossKt;
import com.komputation.matrix.Matrix;
import com.komputation.optimization.historical.Nesterov;
import com.komputation.optimization.historical.NesterovKt;
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import kotlin.Metadata;
import kotlin.Pair;
import kotlin.TypeCastException;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.SetsKt;
import kotlin.jvm.functions.Function2;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: TRECWithTwoFilterWidths.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��\u001e\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n��\u0018��2\u00020\u0001B\u0005¢\u0006\u0002\u0010\u0002J\u0016\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\b¨\u0006\t"}, d2 = {"Lcom/komputation/cpu/demos/trec/TrecWithTwoFilterWidths;", "", "()V", "run", "", "embeddingFilePath", "", "embeddingDimension", "", "komputation"})
/* loaded from: input_file:com/komputation/cpu/demos/trec/TrecWithTwoFilterWidths.class */
public final class TrecWithTwoFilterWidths {

    /**
     * Trains a CPU convolutional text classifier on the TREC question data and prints test-set
     * results from the training callback.
     *
     * <p>Reads {@code training.data} and {@code test.data} from the {@code trec} directory on the
     * classpath. Only tokens that received an embedding from {@code embeddingFilePath} are kept, and
     * only the documents selected by {@code NLP.filterDocuments} for the widest filter are used. The
     * network is: lookup layer, one convolution per filter width (2 and 3), concatenation, ReLU,
     * dropout, projection and softmax. The callback handed to {@code Network.training} prints the
     * result of {@link CpuTester#run()} for the test set.
     *
     * @param embeddingFilePath path of the word-embedding file passed to {@code NLP.embedVocabulary}
     * @param embeddingDimension embedding size, passed to the lookup layer and to each convolution
     *     layer
     * @throws IllegalStateException if the {@code trec} resource directory cannot be resolved, if
     *     there are no training documents, or if a vocabulary word has no embedding
     */
    public final void run(@NotNull String embeddingFilePath, int embeddingDimension) {
        Intrinsics.checkParameterIsNotNull(embeddingFilePath, "embeddingFilePath");

        // A single seeded generator feeds initialization and dropout, so layer construction order
        // below matters for reproducibility.
        Random random = new Random(1L);
        UniformInitialization initialization = UniformInitializationKt.uniformInitialization(random, -0.1f, 0.1f);
        Nesterov optimization = NesterovKt.nesterov(0.001f, 0.85f);

        int[] filterWidths = {2, 3};
        int maximumFilterWidth = filterWidths[0];
        for (int width : filterWidths) {
            maximumFilterWidth = Math.max(maximumFilterWidth, width);
        }
        int numberFilters = 100;
        // Size of the concatenated convolution output consumed by ReLU, dropout and projection.
        int numberConcatenatedFilters = filterWidths.length * numberFilters;

        File trecDirectory = resourceDirectory("trec");
        Pair<List<String>, List<List<String>>> trainingExamples =
                TRECData.INSTANCE.readExamples(new File(trecDirectory, "training.data"));
        Pair<List<String>, List<List<String>>> testExamples =
                TRECData.INSTANCE.readExamples(new File(trecDirectory, "test.data"));
        List<String> trainingCategories = trainingExamples.getFirst();
        List<List<String>> trainingDocuments = trainingExamples.getSecond();
        List<String> testCategories = testExamples.getFirst();
        List<List<String>> testDocuments = testExamples.getSecond();

        // The same sorted word list orders both the embedding-matrix rows and the vectorization of
        // documents, keeping the two consistent.
        Set<String> vocabulary = NLP.INSTANCE.generateVocabulary(trainingDocuments);
        Map<String, float[]> embeddings = NLP.INSTANCE.embedVocabulary(vocabulary, new File(embeddingFilePath));
        List<String> embeddableVocabulary = CollectionsKt.sorted(embeddings.keySet());

        List<List<String>> trainingTokens = NLP.INSTANCE.filterTokens(trainingDocuments, embeddableVocabulary);
        // NOTE(review): the maximum length comes from the training documents only; confirm that no
        // test document is longer, since the same layers are used for testing.
        int maximumDocumentLength = maximumDocumentLength(trainingTokens);
        List<List<String>> testTokens = NLP.INSTANCE.filterTokens(testDocuments, embeddableVocabulary);

        List<Integer> trainingIndices = NLP.INSTANCE.filterDocuments(trainingTokens, maximumFilterWidth);
        List<Integer> testIndices = NLP.INSTANCE.filterDocuments(testTokens, maximumFilterWidth);

        List<List<String>> selectedTrainingTokens = CollectionsKt.slice(trainingTokens, trainingIndices);
        List<List<String>> selectedTestTokens = CollectionsKt.slice(testTokens, testIndices);
        Matrix[] trainingInputs = NLP.INSTANCE.vectorizeDocuments(selectedTrainingTokens, embeddableVocabulary);
        Matrix[] testInputs = NLP.INSTANCE.vectorizeDocuments(selectedTestTokens, embeddableVocabulary);

        List<String> selectedTrainingCategories = CollectionsKt.slice(trainingCategories, trainingIndices);
        List<String> selectedTestCategories = CollectionsKt.slice(testCategories, testIndices);
        Map<String, Integer> categoryIndices =
                NLP.INSTANCE.indexCategories(CollectionsKt.toSet(selectedTrainingCategories));
        int numberCategories = categoryIndices.size();
        float[][] trainingTargets = NLP.INSTANCE.createTargets(selectedTrainingCategories, categoryIndices);
        float[][] testTargets = NLP.INSTANCE.createTargets(selectedTestCategories, categoryIndices);

        LookupLayer lookupLayer = LookupLayerKt.lookupLayer(
                embeddingMatrix(embeddableVocabulary, embeddings),
                maximumDocumentLength,
                false,
                embeddingDimension,
                optimization);

        ConvolutionLayer[] convolutions = new ConvolutionLayer[filterWidths.length];
        for (int index = 0; index < filterWidths.length; index++) {
            convolutions[index] = ConvolutionLayerKt.convolutionalLayer(
                    embeddingDimension,
                    maximumDocumentLength,
                    false,
                    numberFilters,
                    filterWidths[index],
                    embeddingDimension,
                    initialization,
                    initialization,
                    optimization);
        }

        // NOTE(review): the "$default" calls invoke Kotlin's generated default-argument bridges
        // (mask 6 = use the declared defaults for the 2nd and 3rd parameters; 0/false are ignored
        // placeholders). Such bridges are usually synthetic and may not be callable from
        // javac-compiled source - confirm, and pass explicit arguments if needed.
        CpuForwardLayerInstruction[] forwardLayers = {
            ConcatenationKt.concatenation(convolutions),
            ReluLayerKt.reluLayer$default(numberConcatenatedFilters, 0, false, 6, null),
            DropoutLayerKt.dropoutLayer(numberConcatenatedFilters, 1, false, random, 0.8f),
            ProjectionLayerKt.projectionLayer(
                    numberConcatenatedFilters, numberCategories, initialization, initialization, optimization),
            SoftmaxLayerKt.softmaxLayer$default(numberCategories, 0, false, 6, null)
        };

        Network network = new Network(1, lookupLayer, forwardLayers);
        final CpuTester tester = network.test(testInputs, testTargets, 1, numberCategories, 1);

        network.training(
                trainingInputs,
                trainingTargets,
                20,
                CrossEntropyLossKt.crossEntropyLoss$default(numberCategories, 0, false, 6, null),
                new Function2<Integer, Float, Unit>() {
                    @Override
                    public Unit invoke(Integer ignoredInt, Float ignoredFloat) {
                        // Both callback arguments are unused; only the test-set result is printed.
                        System.out.println(tester.run());
                        return Unit.INSTANCE;
                    }
                })
            .run();
    }

    /**
     * Resolves a directory on the classpath to a {@link File}.
     *
     * @throws IllegalStateException if the resource is missing or its URL is not a valid URI
     */
    private File resourceDirectory(String name) {
        URL url = getClass().getClassLoader().getResource(name);
        if (url == null) {
            throw new IllegalStateException("Resource directory '" + name + "' not found on the classpath");
        }
        try {
            return new File(url.toURI());
        } catch (URISyntaxException e) {
            throw new IllegalStateException("Invalid URI for resource '" + name + "': " + url, e);
        }
    }

    /**
     * Returns the token count of the longest document.
     *
     * @throws IllegalStateException if {@code documents} is empty
     */
    private static int maximumDocumentLength(List<List<String>> documents) {
        if (documents.isEmpty()) {
            throw new IllegalStateException("Cannot determine the maximum length of an empty document list");
        }
        int maximum = 0;
        for (List<String> document : documents) {
            maximum = Math.max(maximum, document.size());
        }
        return maximum;
    }

    /**
     * Stacks the embedding of each word, in list order, into one row-per-word array.
     *
     * @throws IllegalStateException if a word has no entry in {@code embeddings}
     */
    private static float[][] embeddingMatrix(List<String> words, Map<String, float[]> embeddings) {
        float[][] matrix = new float[words.size()][];
        for (int index = 0; index < matrix.length; index++) {
            String word = words.get(index);
            float[] embedding = embeddings.get(word);
            if (embedding == null) {
                throw new IllegalStateException("No embedding for word: " + word);
            }
            matrix[index] = embedding;
        }
        return matrix;
    }
}
