package com.komputation.cuda.layers.entry;

import com.komputation.cpu.functions.ConcatenationKt;
import com.komputation.cpu.functions.PaddingKt;
import com.komputation.cuda.CudaFloatArrayKt;
import com.komputation.cuda.CudaForwardState;
import com.komputation.cuda.CudaIntArrayKt;
import com.komputation.cuda.kernels.Kernel;
import com.komputation.cuda.kernels.launch.ColumnwiseKt;
import com.komputation.cuda.kernels.launch.KernelLaunchConfiguration;
import com.komputation.cuda.layers.BaseCudaEntryPoint;
import com.komputation.cuda.memory.InputMemory;
import com.komputation.cuda.optimization.BaseCudaUpdateRule;
import com.komputation.layers.Resourceful;
import com.komputation.matrix.IntMatrix;
import com.komputation.matrix.Matrix;
import com.komputation.optimization.Optimizable;
import java.util.Arrays;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import kotlin.Metadata;
import kotlin.TypeCastException;
import kotlin.collections.ArraysKt;
import kotlin.collections.IndexedValue;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CudaLookupLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��\u0080\u0001\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u0014\n��\n\u0002\u0010\b\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0011\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\n\n\u0002\u0018\u0002\n\u0002\b\u0014\n\u0002\u0010\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0006\u0018��2\u00020\u00012\u00020\u00022\u00020\u00032\u00020\u0004Ba\b��\u0012\b\u0010\u0005\u001a\u0004\u0018\u00010\u0006\u0012\u0006\u0010\u0007\u001a\u00020\b\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\u0006\u0010\r\u001a\u00020\n\u0012\b\u0010\u000e\u001a\u0004\u0018\u00010\u000f\u0012\f\u0010\u0010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011\u0012\u0006\u0010\u0013\u001a\u00020\u0014\u0012\u0006\u0010\u0015\u001a\u00020\u0016\u0012\u0006\u0010\u0017\u001a\u00020\n¢\u0006\u0002\u0010\u0018J\u0010\u0010=\u001a\u00020>2\u0006\u0010*\u001a\u00020\nH\u0016J\u0010\u0010?\u001a\u00020\u001e2\u0006\u0010@\u001a\u00020\u001eH\u0016J;\u0010A\u001a\u00020\u001e2\u0006\u0010B\u001a\u00020\n2\u0006\u0010C\u001a\u00020\n2\u0006\u0010D\u001a\u00020\u001b2\f\u0010E\u001a\b\u0012\u0004\u0012\u00020F0\u001a2\u0006\u0010G\u001a\u00020HH\u0016¢\u0006\u0002\u0010IJC\u0010J\u001a\u00020\u001e2\u0006\u0010G\u001a\u00020H2\u0006\u0010B\u001a\u00020\n2\u0006\u0010D\u001a\u00020\u001b2\f\u0010E\u001a\b\u0012\u0004\u0012\u00020F0\u001a2\u0006\u0010C\u001a\u00020\n2\u0006\u0010*\u001a\u00020\nH\u0002¢\u0006\u0002\u0010KJ\u0010\u0010L\u001a\u00020>2\u0006\u0010C\u001a\u00020\nH\u0016J\b\u0010M\u001a\u00020>H\u0016R\u0016\u0010\u0019\u001a\b\u0012\u0004\u0012\u00020\u001b0\u001aX\u0082\u000e¢\u0006\u0004\n\u0002\u0010\u001cR\u0014\u0010\u0
010\u001a\b\u0012\u0004\u0012\u00020\u00120\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u001d\u001a\u00020\u001eX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 R\u000e\u0010!\u001a\u00020\u001eX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010#\u001a\u0004\u0018\u00010\u0012X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010$\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0015\u001a\u00020\u0016X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u000b\u001a\u00020\fX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b%\u0010&R\u000e\u0010\u0013\u001a\u00020\u0014X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010'\u001a\u00020\u001bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010(\u001a\u00020)X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010*\u001a\u00020\u001bX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\u0017\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010+\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b,\u0010-R\u000e\u0010.\u001a\u00020\nX\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010/\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u00100\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u0014\u00101\u001a\u00020\nX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b2\u0010-R\u000e\u00103\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u00104\u001a\u00020\u001bX\u0082\u000e¢\u0006\u0002\n��R\u0016\u00105\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00107\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00108\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u00109\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u000e¢\u0006\u0002\n��R\u0016\u0010:\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010;\u001a\n 6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010<\u001a\n 
6*\u0004\u0018\u00010\u001e0\u001eX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u000e\u001a\u0004\u0018\u00010\u000fX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\bX\u0082\u000e¢\u0006\u0002\n��¨\u0006N"}, d2 = {"Lcom/komputation/cuda/layers/entry/CudaLookupLayer;", "Lcom/komputation/cuda/layers/BaseCudaEntryPoint;", "Lcom/komputation/cuda/CudaForwardState;", "Lcom/komputation/layers/Resourceful;", "Lcom/komputation/optimization/Optimizable;", "name", "", "vectors", "", "maximumLength", "", "hasFixedLength", "", "dimension", "updateRule", "Lcom/komputation/cuda/optimization/BaseCudaUpdateRule;", "createForwardKernel", "Lkotlin/Function0;", "Lcom/komputation/cuda/kernels/Kernel;", "hashing", "Lcom/komputation/cuda/layers/entry/CudaHashing;", "groupSum", "Lcom/komputation/cuda/layers/entry/CudaGroupSum;", "maximumNumberThreadsPerBlock", "(Ljava/lang/String;[FIZILcom/komputation/cuda/optimization/BaseCudaUpdateRule;Lkotlin/jvm/functions/Function0;Lcom/komputation/cuda/layers/entry/CudaHashing;Lcom/komputation/cuda/layers/entry/CudaGroupSum;I)V", "batchInputs", "", "", "[[I", "deviceForwardResult", "Ljcuda/Pointer;", "getDeviceForwardResult", "()Ljcuda/Pointer;", "deviceIndices", "deviceVectors", "forwardKernel", "forwardResult", "getHasFixedLength", "()Z", "indices", "launchConfiguration", "Lcom/komputation/cuda/kernels/launch/KernelLaunchConfiguration;", "maximumBatchSize", "maximumOutputColumns", "getMaximumOutputColumns", "()I", "maximumParameters", "numberEntries", "numberIterations", "numberOutputRows", "getNumberOutputRows", "numberVectorEntries", "numbersOfColumns", "pointerToDeviceVectors", "kotlin.jvm.PlatformType", "pointerToDimension", "pointerToForwardResult", "pointerToIndices", "pointerToMaximumBatchSize", "pointerToMaximumLength", "pointerToNumberIterations", "acquire", "", "backward", "chain", "forward", "batchId", "batchSize", "batch", "inputs", "Lcom/komputation/matrix/Matrix;", "memory", "Lcom/komputation/cuda/memory/InputMemory;", 
"(II[I[Lcom/komputation/matrix/Matrix;Lcom/komputation/cuda/memory/InputMemory;)Ljcuda/Pointer;", "getIndices", "(Lcom/komputation/cuda/memory/InputMemory;I[I[Lcom/komputation/matrix/Matrix;II)Ljcuda/Pointer;", "optimize", "release", "komputation"})
/* loaded from: input_file:com/komputation/cuda/layers/entry/CudaLookupLayer.class */
/**
 * CUDA entry point that feeds batches of integer index inputs ({@link IntMatrix}) through a forward
 * lookup kernel operating on a device copy of {@code vectors}.
 *
 * <p>Lifecycle:
 * <ul>
 *   <li>{@link #acquire(int)} uploads {@code vectors} and allocates the forward result.</li>
 *   <li>{@link #forward} uploads (or reuses cached) index data and launches the kernel.</li>
 *   <li>{@link #backward} hashes the last indices and group-sums the incoming chain.</li>
 *   <li>{@link #optimize(int)} applies a sparse update through the optional update rule.</li>
 *   <li>{@link #release()} copies {@code vectors} back to the host and frees device memory.</li>
 * </ul>
 */
public final class CudaLookupLayer extends BaseCudaEntryPoint implements CudaForwardState, Resourceful, Optimizable {
    /** Number of floats in {@link #vectors}; fixed at construction. */
    private final int numberVectorEntries;

    @NotNull
    private final Pointer deviceForwardResult;
    private final Pointer pointerToForwardResult;
    private final int numberOutputRows;
    private final int maximumOutputColumns;
    /** {@code numberOutputRows * maximumOutputColumns}: forward-result floats per batch slot. */
    private final int numberEntries;
    private Pointer deviceIndices;
    private Pointer pointerToIndices;
    private int[] numbersOfColumns;
    /** Host staging buffer: {@code maximumBatchSize * maximumOutputColumns} indices, padded with -1. */
    private int[] indices;
    private float[] forwardResult;
    /** One-element array that backs {@link #pointerToMaximumBatchSize}; -1 while released. */
    private int[] maximumBatchSize;
    private Kernel forwardKernel;
    private final Pointer deviceVectors;
    private final Pointer pointerToDeviceVectors;
    private final KernelLaunchConfiguration launchConfiguration;
    private final Pointer pointerToMaximumBatchSize;
    private final Pointer pointerToMaximumLength;
    private final Pointer pointerToDimension;
    private final int numberIterations;
    private final Pointer pointerToNumberIterations;
    private int[][] batchInputs;
    /** Maximum number of hash keys reported by {@link #hashing}; -1 while released. */
    private int maximumParameters;
    /** Host copy of the parameters; refreshed from the device in {@link #release()}. */
    private float[] vectors;
    private final boolean hasFixedLength;
    private final BaseCudaUpdateRule updateRule;
    private final Function0<Kernel> createForwardKernel;
    private final CudaHashing hashing;
    private final CudaGroupSum groupSum;
    private final int maximumNumberThreadsPerBlock;

    @Override // com.komputation.cuda.CudaForwardState
    @NotNull
    public Pointer getDeviceForwardResult() {
        return this.deviceForwardResult;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    @Override // com.komputation.cuda.CudaForwardState
    public int getMaximumOutputColumns() {
        return this.maximumOutputColumns;
    }

    /**
     * Allocates host and device resources for batches of up to {@code maximumBatchSize} inputs.
     *
     * <p>This method creates the forward kernel, uploads {@code vectors} to the device, allocates
     * the forward result, and acquires the hashing and group-sum helpers.
     *
     * @param maximumBatchSize number of batch slots to allocate
     */
    @Override // com.komputation.layers.Resourceful
    public void acquire(int maximumBatchSize) {
        this.maximumBatchSize[0] = maximumBatchSize;
        this.numbersOfColumns = new int[maximumBatchSize];
        this.indices = new int[maximumBatchSize * getMaximumOutputColumns()];
        this.forwardKernel = this.createForwardKernel.invoke();
        CudaFloatArrayKt.setFloatArray(this.vectors, this.numberVectorEntries, this.deviceVectors);
        CudaFloatArrayKt.allocateDeviceFloatMemory(getDeviceForwardResult(), maximumBatchSize * this.numberEntries);
        // One initially empty input row per batch slot (each slot gets its own empty array).
        this.batchInputs = new int[maximumBatchSize][0];
        this.hashing.acquire(maximumBatchSize);
        this.maximumParameters = this.hashing.getMaximumKeys();
        this.groupSum.acquire(maximumBatchSize);
    }

    /**
     * Releases the resources obtained in {@link #acquire(int)}.
     *
     * <p>The device vectors, which {@link #optimize(int)} may have updated, are copied back into
     * {@link #vectors} before the device copy is freed. The forward result buffer is freed as well.
     *
     * @throws NullPointerException if called before {@link #acquire(int)} created the kernel
     */
    @Override // com.komputation.layers.Resourceful
    public void release() {
        this.numbersOfColumns = new int[0];
        this.indices = new int[0];
        this.forwardResult = new float[0];
        this.maximumBatchSize[0] = -1;
        Kernel kernel = this.forwardKernel;
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        kernel.destroy();
        this.vectors = CudaFloatArrayKt.getFloatArray(this.deviceVectors, this.numberVectorEntries);
        JCuda.cudaFree(this.deviceVectors);
        // Allocated in acquire(); previously never freed, leaking on every acquire/release cycle.
        JCuda.cudaFree(getDeviceForwardResult());
        this.hashing.release();
        this.maximumParameters = -1;
        this.groupSum.release();
    }

    /**
     * Runs the forward lookup kernel for one batch.
     *
     * @param batchId   identifier used to cache the uploaded indices in {@code memory}
     * @param batchSize number of valid entries in {@code batch}
     * @param batch     positions into {@code inputs}, one per batch slot
     * @param inputs    all inputs; the referenced entries must be {@link IntMatrix} instances
     * @param memory    per-batch device cache for indices and lengths
     * @return the device pointer holding the forward result
     */
    @Override // com.komputation.cuda.layers.CudaEntryPoint
    @NotNull
    public Pointer forward(int batchId, int batchSize, @NotNull int[] batch, @NotNull Matrix[] inputs, @NotNull InputMemory memory) {
        Intrinsics.checkParameterIsNotNull(batch, "batch");
        Intrinsics.checkParameterIsNotNull(inputs, "inputs");
        Intrinsics.checkParameterIsNotNull(memory, "memory");
        int numberSlots = this.maximumBatchSize[0];
        this.deviceIndices = getIndices(memory, batchId, batch, inputs, batchSize, numberSlots);
        this.pointerToIndices = Pointer.to(this.deviceIndices);
        Kernel kernel = this.forwardKernel;
        if (kernel == null) {
            Intrinsics.throwNpe();
        }
        Pointer parameters = Pointer.to(
            this.pointerToDeviceVectors,
            this.pointerToIndices,
            this.pointerToForwardResult,
            this.pointerToMaximumBatchSize,
            this.pointerToMaximumLength,
            this.pointerToDimension,
            this.pointerToNumberIterations);
        kernel.launch(parameters, numberSlots, this.launchConfiguration.getNumberBlocks(), this.launchConfiguration.getNumberThreadsPerBlock(), 0);
        return getDeviceForwardResult();
    }

    /**
     * Returns the device pointer with the padded index matrix for {@code batchId}, uploading it on
     * first use.
     *
     * <p>If {@code memory} already holds data for this batch, that data is returned unchanged.
     * Otherwise the method does the following:
     * <ul>
     *   <li>Each referenced input is cast to {@link IntMatrix}.</li>
     *   <li>Unless the layer has fixed length, each input is padded with -1 to
     *       {@code maximumOutputColumns}.</li>
     *   <li>Each row is written into its slot of {@link #indices}.</li>
     *   <li>Slots from {@code batchSize} up to the maximum batch size are filled with -1.</li>
     * </ul>
     * The indices are uploaded and cached in {@code memory}. For variable-length layers, the
     * per-slot lengths (0 for unused slots) are uploaded and cached too.
     */
    private Pointer getIndices(InputMemory memory, int batchId, int[] batch, Matrix[] inputs, int batchSize, int numberSlots) {
        Pointer cached = memory.tryToGetData(batchId);
        if (cached != null) {
            return cached;
        }
        int columns = getMaximumOutputColumns();
        for (int slot = 0; slot < batch.length; slot++) {
            Matrix matrix = inputs[batch[slot]];
            if (matrix == null) {
                throw new TypeCastException("null cannot be cast to non-null type com.komputation.matrix.IntMatrix");
            }
            IntMatrix intMatrix = (IntMatrix) matrix;
            int[] entries = intMatrix.getEntries();
            int length = intMatrix.getNumberEntries();
            this.batchInputs[slot] = entries;
            this.numbersOfColumns[slot] = length;
            int[] row;
            if (getHasFixedLength()) {
                row = entries;
            } else {
                row = new int[columns];
                PaddingKt.pad(entries, length, columns, -1, row);
            }
            ConcatenationKt.concatenate(row, slot * columns, columns, this.indices);
        }
        if (batchSize < this.maximumBatchSize[0]) {
            // Slots not used by a smaller-than-maximum batch are marked empty with -1.
            Arrays.fill(this.indices, batchSize * columns, this.maximumBatchSize[0] * columns, -1);
        }
        // Only the indices are uploaded. The old extra upload of numbersOfColumns went to a
        // discarded pointer and leaked device memory on every call.
        Pointer indicesPointer = new Pointer();
        CudaIntArrayKt.setIntArray(this.indices, this.indices.length, indicesPointer);
        memory.setData(batchId, indicesPointer);
        if (!getHasFixedLength()) {
            int[] lengths = new int[numberSlots];
            for (int slot = 0; slot < lengths.length; slot++) {
                lengths[slot] = slot < batchSize ? inputs[batch[slot]].getNumberEntries() : 0;
            }
            Pointer lengthsPointer = new Pointer();
            CudaIntArrayKt.setIntArray(lengths, numberSlots, lengthsPointer);
            memory.setLengths(batchId, lengthsPointer);
        }
        return indicesPointer;
    }

    /**
     * Prepares the sparse gradient for {@link #optimize(int)}.
     *
     * <p>The indices of the last forward pass are hashed, and {@code chain} is group-summed
     * according to the resulting mapping.
     *
     * @param chain device pointer to the incoming backward chain
     * @return {@code chain}, unchanged
     */
    @Override // com.komputation.cuda.layers.CudaEntryPoint
    @NotNull
    public Pointer backward(@NotNull Pointer chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");
        this.hashing.reset();
        this.hashing.hash(Pointer.to(this.deviceIndices));
        this.groupSum.reset();
        Pointer pointerToMapping = this.hashing.getPointerToMapping();
        Intrinsics.checkExpressionValueIsNotNull(pointerToMapping, "this.hashing.getPointerToMapping()");
        this.groupSum.sum(pointerToMapping, Pointer.to(chain));
        return chain;
    }

    /**
     * Applies a sparse update to the device vectors when an update rule was supplied; otherwise
     * does nothing.
     */
    @Override // com.komputation.optimization.Optimizable
    public void optimize(int batchSize) {
        BaseCudaUpdateRule rule = this.updateRule;
        if (rule == null) {
            return;
        }
        Pointer pointerToHashTable = this.hashing.getPointerToHashTable();
        Intrinsics.checkExpressionValueIsNotNull(pointerToHashTable, "this.hashing.getPointerToHashTable()");
        Pointer pointerToCounts = this.hashing.getPointerToCounts();
        Intrinsics.checkExpressionValueIsNotNull(pointerToCounts, "this.hashing.getPointerToCounts()");
        Pointer vectorsPointer = this.pointerToDeviceVectors;
        Intrinsics.checkExpressionValueIsNotNull(vectorsPointer, "this.pointerToDeviceVectors");
        Pointer pointerToSum = this.groupSum.getPointerToSum();
        Intrinsics.checkExpressionValueIsNotNull(pointerToSum, "this.groupSum.getPointerToSum()");
        rule.sparseUpdate(this.maximumParameters, pointerToHashTable, pointerToCounts, vectorsPointer, pointerToSum);
    }

    @Override // com.komputation.cuda.layers.CudaEntryPoint
    public boolean getHasFixedLength() {
        return this.hasFixedLength;
    }

    /**
     * Creates the layer.
     *
     * <p>No device memory is allocated here; that happens in {@link #acquire(int)}.
     *
     * @param name                         optional layer name
     * @param vectors                      host parameters, uploaded on acquire
     * @param maximumLength                maximum number of columns (indices) per input
     * @param hasFixedLength               whether inputs always have {@code maximumLength} entries (no padding)
     * @param dimension                    number of output rows
     * @param updateRule                   optional sparse update rule used by {@link #optimize(int)}
     * @param createForwardKernel          factory for the forward kernel, invoked on acquire
     * @param hashing                      helper used in {@link #backward}
     * @param groupSum                     helper used in {@link #backward}
     * @param maximumNumberThreadsPerBlock upper bound for the launch configuration
     */
    public CudaLookupLayer(@Nullable String name, @NotNull float[] vectors, int maximumLength, boolean hasFixedLength, int dimension, @Nullable BaseCudaUpdateRule updateRule, @NotNull Function0<Kernel> createForwardKernel, @NotNull CudaHashing hashing, @NotNull CudaGroupSum groupSum, int maximumNumberThreadsPerBlock) {
        super(name);
        Intrinsics.checkParameterIsNotNull(vectors, "vectors");
        Intrinsics.checkParameterIsNotNull(createForwardKernel, "createForwardKernel");
        Intrinsics.checkParameterIsNotNull(hashing, "hashing");
        Intrinsics.checkParameterIsNotNull(groupSum, "groupSum");
        this.vectors = vectors;
        this.hasFixedLength = hasFixedLength;
        this.updateRule = updateRule;
        this.createForwardKernel = createForwardKernel;
        this.hashing = hashing;
        this.groupSum = groupSum;
        this.maximumNumberThreadsPerBlock = maximumNumberThreadsPerBlock;
        this.numberVectorEntries = this.vectors.length;
        this.deviceForwardResult = new Pointer();
        this.pointerToForwardResult = Pointer.to(getDeviceForwardResult());
        this.numberOutputRows = dimension;
        this.maximumOutputColumns = maximumLength;
        this.numberEntries = getNumberOutputRows() * getMaximumOutputColumns();
        this.deviceIndices = new Pointer();
        this.pointerToIndices = Pointer.to(this.deviceIndices);
        this.numbersOfColumns = new int[0];
        this.indices = new int[0];
        this.forwardResult = new float[0];
        this.maximumBatchSize = new int[]{-1};
        this.deviceVectors = new Pointer();
        this.pointerToDeviceVectors = Pointer.to(this.deviceVectors);
        this.launchConfiguration = ColumnwiseKt.computeColumnwiseLaunchConfiguration(getNumberOutputRows(), getMaximumOutputColumns(), this.maximumNumberThreadsPerBlock);
        this.pointerToMaximumBatchSize = Pointer.to(this.maximumBatchSize);
        this.pointerToMaximumLength = Pointer.to(new int[]{getMaximumOutputColumns()});
        this.pointerToDimension = Pointer.to(new int[]{getNumberOutputRows()});
        this.numberIterations = this.launchConfiguration.getNumberIterations();
        this.pointerToNumberIterations = Pointer.to(new int[]{this.numberIterations});
        // An empty array of int rows; "(int[][]) new int[0]" was an invalid cast.
        this.batchInputs = new int[0][];
        this.maximumParameters = -1;
    }
}
