package com.komputation.cpu.optimization;

import com.komputation.matrix.IntMatrixKt;
import java.util.Arrays;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;

/* compiled from: SparseAccumulator.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��6\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0010\b\n\u0002\b\u0005\n\u0002\u0010\u0014\n��\n\u0002\u0010\u0015\n\u0002\b\u0003\n\u0002\u0010\u0011\n\u0002\b\u0002\n\u0002\u0010\u0018\n��\n\u0002\u0010\u0002\n\u0002\b\u000f\u0018��2\u00020\u0001B%\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0003\u0012\u0006\u0010\u0005\u001a\u00020\u0003\u0012\u0006\u0010\u0006\u001a\u00020\u0003¢\u0006\u0002\u0010\u0007J\u001e\u0010\u0013\u001a\u00020\u00142\u0006\u0010\u0015\u001a\u00020\u000b2\u0006\u0010\u0016\u001a\u00020\u00032\u0006\u0010\u0017\u001a\u00020\tJ \u0010\u0018\u001a\u00020\u00142\u0006\u0010\u0019\u001a\u00020\u00032\u0006\u0010\u001a\u001a\u00020\u00032\u0006\u0010\u0017\u001a\u00020\tH\u0002J\u0006\u0010\u001b\u001a\u00020\tJ\u0006\u0010\u001c\u001a\u00020\u000bJ\u0006\u0010\u001d\u001a\u00020\u0003J\u0011\u0010\u001e\u001a\b\u0012\u0004\u0012\u00020\t0\u000f¢\u0006\u0002\u0010\u001fJ\u0010\u0010 \u001a\u00020\u00032\u0006\u0010!\u001a\u00020\u0003H\u0002J\u0006\u0010\"\u001a\u00020\u0014R\u000e\u0010\b\u001a\u00020\tX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0006\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\n\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\f\u001a\u00020\u0003X\u0082\u000e¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000bX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u000e\u001a\b\u0012\u0004\u0012\u00020\t0\u000fX\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0010R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��¨\u0006#"}, d2 = {"Lcom/komputation/cpu/optimization/SparseAccumulator;", "", "numberVectors", "", "maximumBatchSize", "maximumLength", "dimension", "(IIII)V", "counts", "", "hashTable", "", "lastNewId", "reverseHashTable", "sums", "", "[[F", "visited", "", "accumulate", "", "ids", "numberIds", "gradient", "addToSum", "indexId", "hashedId", "getCounts", "getIds", "getSize", "getSums", 
"()[[F", "hashId", "id", "reset", "komputation"})
/* loaded from: input_file:com/komputation/cpu/optimization/SparseAccumulator.class */
/**
 * Accumulates per-id sums of gradient slices and per-id occurrence counts.
 *
 * <p>Each distinct id passed to {@link #accumulate} is assigned a dense slot index in order of
 * first appearance. For every slot the accumulator keeps a running sum vector of length
 * {@code dimension} and a count of how many {@code accumulate} calls contained that id.
 * {@link #reset()} clears all slots that are currently in use.
 */
public final class SparseAccumulator {
    /** Maps an id to its slot index, or -1 when the id has no slot yet. */
    private final int[] hashTable;
    /** Scratch flags, indexed by id, marking ids already counted in the current accumulate call. */
    private final boolean[] visited;
    /** Maps a slot index back to the id that owns it; -1 for unused slots. */
    private final int[] reverseHashTable;
    /** Per slot: number of accumulate calls in which the slot's id appeared at least once. */
    private final float[] counts;
    /** Per slot: element-wise sum of all gradient slices accumulated for the slot's id. */
    private final float[][] sums;
    /** Highest slot index handed out so far; -1 when no slot is in use. */
    private int lastNewId;
    /** Length of each gradient slice and of each row of {@link #sums}. */
    private final int dimension;

    /**
     * Creates an empty accumulator.
     *
     * @param numberVectors    size of the id space; valid ids are {@code 0 .. numberVectors - 1}
     * @param maximumBatchSize together with {@code maximumLength}, determines the number of slots
     *                         ({@code maximumBatchSize * maximumLength})
     * @param maximumLength    see {@code maximumBatchSize}
     * @param dimension        length of each gradient slice / sum vector
     */
    public SparseAccumulator(int numberVectors, int maximumBatchSize, int maximumLength, int dimension) {
        this.dimension = dimension;
        this.hashTable = IntMatrixKt.constantIntArray(numberVectors, -1);
        this.visited = new boolean[numberVectors];
        this.reverseHashTable = IntMatrixKt.constantIntArray(maximumBatchSize * maximumLength, -1);
        this.counts = new float[maximumBatchSize * maximumLength];
        // One sum vector per slot. Declared as float[][] so the row assignment type-checks.
        float[][] slotSums = new float[maximumBatchSize * maximumLength][];
        for (int slot = 0; slot < slotSums.length; slot++) {
            slotSums[slot] = new float[this.dimension];
        }
        this.sums = slotSums;
        this.lastNewId = -1;
    }

    /**
     * Adds the gradient slices of one instance to the running sums.
     *
     * <p>The slice for {@code ids[k]} is read from
     * {@code gradient[k * dimension .. (k + 1) * dimension - 1]}. An id that occurs several times
     * within the first {@code numberIds} entries has all of its slices summed, but its count is
     * incremented only once per call.
     *
     * @param ids       ids of the instance; only the first {@code numberIds} entries are read
     * @param numberIds number of leading entries of {@code ids} to process
     * @param gradient  concatenated gradient slices, one per processed id
     */
    public final void accumulate(@NotNull int[] ids, int numberIds, @NotNull float[] gradient) {
        Intrinsics.checkParameterIsNotNull(ids, "ids");
        Intrinsics.checkParameterIsNotNull(gradient, "gradient");
        for (int indexId = 0; indexId < numberIds; indexId++) {
            int id = ids[indexId];
            int hashedId = hashId(id);
            addToSum(indexId, hashedId, gradient);
            if (!this.visited[id]) {
                this.counts[hashedId] += 1.0f;
                this.visited[id] = true;
            }
        }
        // Clear only the flags that could have been set above. Entries of ids beyond numberIds
        // were never processed and may not be valid indices into visited.
        for (int indexId = 0; indexId < numberIds; indexId++) {
            this.visited[ids[indexId]] = false;
        }
    }

    /**
     * Returns the slot index for {@code id}, allocating the next free slot on first use.
     *
     * @param id the id to look up
     * @return the slot index assigned to {@code id}
     */
    private final int hashId(int id) {
        int existingSlot = this.hashTable[id];
        if (existingSlot != -1) {
            return existingSlot;
        }
        this.lastNewId++;
        int newSlot = this.lastNewId;
        this.hashTable[id] = newSlot;
        this.reverseHashTable[newSlot] = id;
        return newSlot;
    }

    /**
     * Adds the {@code indexId}-th slice of {@code gradient} element-wise to the sum vector of
     * slot {@code hashedId}.
     *
     * @param indexId  position of the slice within {@code gradient}, in units of {@code dimension}
     * @param hashedId slot whose sum vector is updated
     * @param gradient concatenated gradient slices
     */
    private final void addToSum(int indexId, int hashedId, float[] gradient) {
        float[] sum = this.sums[hashedId];
        int sliceStart = indexId * this.dimension;
        int length = this.dimension;
        for (int offset = 0; offset < length; offset++) {
            sum[offset] += gradient[sliceStart + offset];
        }
    }

    /**
     * @return the number of slots currently in use
     */
    public final int getSize() {
        return this.lastNewId + 1;
    }

    /**
     * @return the internal slot-to-id array (not a copy); entries at indices
     *         {@code 0 .. getSize() - 1} hold the ids that own those slots
     */
    @NotNull
    public final int[] getIds() {
        return this.reverseHashTable;
    }

    /**
     * @return the internal per-slot count array (not a copy)
     */
    @NotNull
    public final float[] getCounts() {
        return this.counts;
    }

    /**
     * @return the internal per-slot sum vectors (not a copy)
     */
    @NotNull
    public final float[][] getSums() {
        return this.sums;
    }

    /**
     * Clears every slot currently in use: unlinks the id from its slot in both lookup tables,
     * zeroes its count and sum vector, and marks the accumulator as empty.
     */
    public final void reset() {
        int lastUsedSlot = this.lastNewId;
        for (int slot = 0; slot <= lastUsedSlot; slot++) {
            int id = this.reverseHashTable[slot];
            this.reverseHashTable[slot] = -1;
            this.hashTable[id] = -1;
            this.counts[slot] = 0.0f;
            Arrays.fill(this.sums[slot], 0.0f);
        }
        this.lastNewId = -1;
    }
}
