package com.komputation.cpu.layers.forward;

import com.komputation.cpu.functions.RowSplitKt;
import com.komputation.cpu.functions.StackingKt;
import com.komputation.cpu.layers.BaseCpuForwardLayerKt;
import com.komputation.cpu.layers.CpuForwardLayer;
import com.komputation.cpu.layers.VariableLengthFloatArray;
import com.komputation.optimization.Optimizable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import kotlin.Metadata;
import kotlin.collections.ArraysKt;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CpuConcatenation.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��P\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u0011\n\u0002\b\u0002\n\u0002\u0010\u0014\n\u0002\b\u0005\n\u0002\u0018\u0002\n��\n\u0002\u0010 \n\u0002\b\t\n\u0002\u0010\b\n\u0002\b\u0010\n\u0002\u0010\u0015\n\u0002\b\u000b\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\u0018��2\u00020\u00012\u00020\u0002B!\b��\u0012\n\b\u0002\u0010\u0003\u001a\u0004\u0018\u00010\u0004\u0012\f\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00010\u0006¢\u0006\u0002\u0010\u0007J\u0018\u00102\u001a\u00020\t2\u0006\u00103\u001a\u00020\u001b2\u0006\u00104\u001a\u00020\tH\u0016J(\u00105\u001a\u00020\t2\u0006\u00103\u001a\u00020\u001b2\u0006\u0010\u001e\u001a\u00020\u001b2\u0006\u00106\u001a\u00020\t2\u0006\u00107\u001a\u000208H\u0016J\u0010\u00109\u001a\u00020:2\u0006\u0010;\u001a\u00020\u001bH\u0016R\u001a\u0010\b\u001a\u00020\tX\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b\n\u0010\u000b\"\u0004\b\f\u0010\rR\u000e\u0010\u000e\u001a\u00020\u000fX\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u0010\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020\t0\u00060\u0011X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0012\u001a\u00020\u0001X\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010\u0013\u001a\u00020\tX\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b\u0014\u0010\u000b\"\u0004\b\u0015\u0010\rR\u000e\u0010\u0016\u001a\u00020\u000fX\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u0017\u001a\b\u0012\u0004\u0012\u00020\t0\u0006X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0018R\u0016\u0010\u0005\u001a\b\u0012\u0004\u0012\u00020\u00010\u0006X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u0019R\u000e\u0010\u001a\u001a\u00020\u001bX\u0082\u0004¢\u0006\u0002\n��R\u0013\u0010\u0003\u001a\u0004\u0018\u00010\u0004¢\u0006\b\n��\u001a\u0004\b\u001c\u0010\u001dR\u001a\u0010\u001e\u001a\u00020\u001bX\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b\u001f\u0010 
\"\u0004\b!\u0010\"R\u0014\u0010#\u001a\u00020\u001bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b$\u0010 R\u000e\u0010%\u001a\u00020\u001bX\u0082\u0004¢\u0006\u0002\n��R\u001a\u0010&\u001a\u00020\u001bX\u0096\u000e¢\u0006\u000e\n��\u001a\u0004\b'\u0010 \"\u0004\b(\u0010\"R\u0014\u0010)\u001a\u00020\u001bX\u0096\u0004¢\u0006\b\n��\u001a\u0004\b*\u0010 R\u000e\u0010+\u001a\u00020,X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010-\u001a\u00020,8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b.\u0010/R\u0014\u00100\u001a\u00020,8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b1\u0010/¨\u0006<"}, d2 = {"Lcom/komputation/cpu/layers/forward/CpuConcatenation;", "Lcom/komputation/cpu/layers/CpuForwardLayer;", "Lcom/komputation/optimization/Optimizable;", "name", "", "layers", "", "(Ljava/lang/String;[Lcom/komputation/cpu/layers/CpuForwardLayer;)V", "backwardResult", "", "getBackwardResult", "()[F", "setBackwardResult", "([F)V", "backwardStore", "Lcom/komputation/cpu/layers/VariableLengthFloatArray;", "chainSplitsOverPossibleLengths", "", "firstLayer", "forwardResult", "getForwardResult", "setForwardResult", "forwardStore", "individualResults", "[[F", "[Lcom/komputation/cpu/layers/CpuForwardLayer;", "minimumOutputLength", "", "getName", "()Ljava/lang/String;", "numberInputColumns", "getNumberInputColumns", "()I", "setNumberInputColumns", "(I)V", "numberInputRows", "getNumberInputRows", "numberLayers", "numberOutputColumns", "getNumberOutputColumns", "setNumberOutputColumns", "numberOutputRows", "getNumberOutputRows", "numbersOfOutputRows", "", "possibleInputLengths", "getPossibleInputLengths", "()[I", "possibleOutputLengths", "getPossibleOutputLengths", "backward", "withinBatch", "chain", "forward", "input", "isTraining", "", "optimize", "", "batchSize", "komputation"})
/* loaded from: input_file:com/komputation/cpu/layers/forward/CpuConcatenation.class */
/*
 * Applies several forward layers to the same input and stacks their outputs row-wise
 * (via StackingKt.stackRows). In the backward pass the incoming chain is split back into
 * per-layer row blocks (via RowSplitKt.splitRows), each layer is back-propagated, and the
 * per-layer input gradients are summed element-wise into one result.
 */
public final class CpuConcatenation implements CpuForwardLayer, Optimizable {
    private final CpuForwardLayer firstLayer;
    private final int numberLayers;
    private final int numberInputRows;
    private int numberInputColumns;
    /** Number of output rows contributed by each layer, in layer order. */
    private final int[] numbersOfOutputRows;
    /** Sum of {@link #numbersOfOutputRows}. */
    private final int numberOutputRows;
    private int numberOutputColumns;
    /** Smallest of the first layer's possible output lengths; used to index the chain splits. */
    private final int minimumOutputLength;
    private final VariableLengthFloatArray forwardStore;

    @NotNull
    private float[] forwardResult;
    private final VariableLengthFloatArray backwardStore;

    @NotNull
    private float[] backwardResult;
    /** Most recent forward output of each layer, in layer order. */
    private final float[][] individualResults;
    /** For each possible output length, one preallocated buffer per layer for the split chain. */
    private final List<float[][]> chainSplitsOverPossibleLengths;

    @Nullable
    private final String name;
    private final CpuForwardLayer[] layers;

    @Override // com.komputation.cpu.layers.CpuBackwardState
    public int getNumberInputRows() {
        return this.numberInputRows;
    }

    @Override // com.komputation.cpu.layers.CpuBackwardState
    public int getNumberInputColumns() {
        return this.numberInputColumns;
    }

    public void setNumberInputColumns(int numberInputColumns) {
        this.numberInputColumns = numberInputColumns;
    }

    /** Delegates to the first layer; all layers receive the same input. */
    @Override // com.komputation.cpu.layers.CpuVariableLengthBackwardState
    @NotNull
    public int[] getPossibleInputLengths() {
        return this.firstLayer.getPossibleInputLengths();
    }

    @Override // com.komputation.cpu.layers.CpuForwardState
    public int getNumberOutputRows() {
        return this.numberOutputRows;
    }

    @Override // com.komputation.cpu.layers.CpuForwardState
    public int getNumberOutputColumns() {
        return this.numberOutputColumns;
    }

    public void setNumberOutputColumns(int numberOutputColumns) {
        this.numberOutputColumns = numberOutputColumns;
    }

    /** Delegates to the first layer. */
    @Override // com.komputation.cpu.layers.CpuVariableLengthForwardState
    @NotNull
    public int[] getPossibleOutputLengths() {
        return this.firstLayer.getPossibleOutputLengths();
    }

    @Override // com.komputation.cpu.layers.CpuForwardState
    @NotNull
    public float[] getForwardResult() {
        return this.forwardResult;
    }

    public void setForwardResult(@NotNull float[] forwardResult) {
        Intrinsics.checkParameterIsNotNull(forwardResult, "<set-?>");
        this.forwardResult = forwardResult;
    }

    @Override // com.komputation.cpu.layers.CpuBackwardState
    @NotNull
    public float[] getBackwardResult() {
        return this.backwardResult;
    }

    public void setBackwardResult(@NotNull float[] backwardResult) {
        Intrinsics.checkParameterIsNotNull(backwardResult, "<set-?>");
        this.backwardResult = backwardResult;
    }

    /**
     * Runs every layer on {@code input} and stacks their outputs row-wise into a buffer
     * taken from the forward store.
     *
     * <p>The input/output column counts are taken from the first layer after it has run.
     *
     * @param withinBatch index forwarded unchanged to every layer
     * @param numberInputColumns column count forwarded unchanged to every layer
     * @param input the shared input of all layers
     * @param isTraining flag forwarded unchanged to every layer
     * @return the stacked output (also stored as the forward result)
     */
    @Override // com.komputation.cpu.layers.CpuForwardLayer
    @NotNull
    public float[] forward(int withinBatch, int numberInputColumns, @NotNull float[] input, boolean isTraining) {
        Intrinsics.checkParameterIsNotNull(input, "input");

        // The first layer runs first so its column counts can size the result buffer.
        this.individualResults[0] = this.firstLayer.forward(withinBatch, numberInputColumns, input, isTraining);
        setNumberInputColumns(this.firstLayer.getNumberInputColumns());
        setNumberOutputColumns(this.firstLayer.getNumberOutputColumns());
        setForwardResult(this.forwardStore.get(this.numberOutputColumns));

        for (int layerIndex = 1; layerIndex < this.numberLayers; layerIndex++) {
            this.individualResults[layerIndex] =
                this.layers[layerIndex].forward(withinBatch, numberInputColumns, input, isTraining);
        }

        // Copy of the results array mirrors the Kotlin vararg spread semantics of the original.
        StackingKt.stackRows(
            this.numbersOfOutputRows,
            this.numberOutputRows,
            this.numberOutputColumns,
            this.forwardResult,
            Arrays.copyOf(this.individualResults, this.individualResults.length));
        return this.forwardResult;
    }

    /**
     * Splits {@code chain} into one row block per layer, back-propagates each block through its
     * layer, and sums the resulting input gradients element-wise.
     *
     * <p>Relies on the column counts recorded by the preceding {@link #forward} call.
     *
     * @param withinBatch index forwarded unchanged to every layer
     * @param chain the incoming gradient for the stacked output
     * @return the summed input gradient (also stored as the backward result)
     */
    @Override // com.komputation.cpu.layers.CpuForwardLayer
    @NotNull
    public float[] backward(int withinBatch, @NotNull float[] chain) {
        Intrinsics.checkParameterIsNotNull(chain, "chain");

        float[] accumulated = this.backwardStore.get(this.numberInputColumns);
        setBackwardResult(accumulated);

        // NOTE(review): assumes computeLengthIndex maps an output length to its position in
        // getPossibleOutputLengths(), matching the order used to build the split list.
        float[][] chainSplits = this.chainSplitsOverPossibleLengths.get(
            BaseCpuForwardLayerKt.computeLengthIndex(this.numberOutputColumns, this.minimumOutputLength));
        RowSplitKt.splitRows(
            this.numberOutputRows, this.numberOutputColumns, chain, this.numbersOfOutputRows, this.numberLayers, chainSplits);

        // The first layer's gradient initializes the accumulator; the others are added onto it.
        float[] firstGradient = this.firstLayer.backward(withinBatch, chainSplits[0]);
        System.arraycopy(firstGradient, 0, accumulated, 0, firstGradient.length);

        for (int layerIndex = 1; layerIndex < this.numberLayers; layerIndex++) {
            float[] layerGradient = this.layers[layerIndex].backward(withinBatch, chainSplits[layerIndex]);
            for (int k = 0; k < layerGradient.length; k++) {
                accumulated[k] += layerGradient[k];
            }
        }
        return accumulated;
    }

    /** Forwards the optimization step to every layer that is itself {@link Optimizable}. */
    @Override // com.komputation.optimization.Optimizable
    public void optimize(int batchSize) {
        for (CpuForwardLayer layer : this.layers) {
            if (layer instanceof Optimizable) {
                ((Optimizable) layer).optimize(batchSize);
            }
        }
    }

    @Nullable
    public final String getName() {
        return this.name;
    }

    /**
     * @param name optional layer name
     * @param layers the layers to concatenate; the first one defines input shape and the
     *     possible input/output lengths. Throws if the array is empty (via {@code ArraysKt.first}).
     */
    public CpuConcatenation(@Nullable String name, @NotNull CpuForwardLayer[] layers) {
        Intrinsics.checkParameterIsNotNull(layers, "layers");
        this.name = name;
        this.layers = layers;
        this.firstLayer = (CpuForwardLayer) ArraysKt.first(this.layers);
        this.numberLayers = this.layers.length;
        this.numberInputRows = this.firstLayer.getNumberInputRows();
        this.numberInputColumns = -1;

        int[] outputRows = new int[this.numberLayers];
        for (int layerIndex = 0; layerIndex < this.numberLayers; layerIndex++) {
            outputRows[layerIndex] = this.layers[layerIndex].getNumberOutputRows();
        }
        this.numbersOfOutputRows = outputRows;
        this.numberOutputRows = ArraysKt.sum(this.numbersOfOutputRows);
        this.numberOutputColumns = -1;

        int[] possibleOutputLengths = getPossibleOutputLengths();
        Integer min = ArraysKt.min(possibleOutputLengths);
        if (min == null) {
            Intrinsics.throwNpe();
        }
        this.minimumOutputLength = min.intValue();

        this.forwardStore = new VariableLengthFloatArray(this.numberOutputRows, possibleOutputLengths);
        this.forwardResult = new float[0];
        this.backwardStore = new VariableLengthFloatArray(this.numberInputRows, getPossibleInputLengths());
        this.backwardResult = new float[0];

        // Jagged arrays: one (initially empty) result slot per layer.
        float[][] emptyResults = new float[this.numberLayers][];
        for (int layerIndex = 0; layerIndex < this.numberLayers; layerIndex++) {
            emptyResults[layerIndex] = new float[0];
        }
        this.individualResults = emptyResults;

        // Preallocate, for every possible output length, one chain buffer per layer sized
        // (rows of that layer) * length, so backward() never allocates.
        List<float[][]> splits = new ArrayList<>(possibleOutputLengths.length);
        for (int length : possibleOutputLengths) {
            float[][] split = new float[this.numberLayers][];
            for (int layerIndex = 0; layerIndex < this.numberLayers; layerIndex++) {
                split[layerIndex] = new float[this.layers[layerIndex].getNumberOutputRows() * length];
            }
            splits.add(split);
        }
        this.chainSplitsOverPossibleLengths = splits;
    }

    /** Kotlin default-argument bridge: bit 0 of {@code mask} means "name was omitted". */
    public /* synthetic */ CpuConcatenation(String name, CpuForwardLayer[] layers, int mask, DefaultConstructorMarker marker) {
        this((mask & 1) != 0 ? null : name, layers);
    }
}
