package com.komputation.cpu.layers.recurrent;

import com.komputation.cpu.functions.ColumnSelectionKt;
import com.komputation.cpu.layers.BaseCpuForwardLayerKt;
import com.komputation.cpu.layers.BaseCpuHigherOrderLayer;
import com.komputation.cpu.layers.VariableLengthFloatArray;
import com.komputation.cpu.layers.combination.CpuAdditionCombination;
import com.komputation.cpu.layers.forward.projection.CpuWeightingLayer;
import com.komputation.cpu.layers.recurrent.extraction.ResultExtractionStrategy;
import com.komputation.cpu.layers.recurrent.series.ParameterizedSeries;
import com.komputation.cpu.layers.recurrent.series.Series;
import com.komputation.optimization.Optimizable;
import kotlin.Metadata;
import kotlin.NoWhenBranchMatchedException;
import kotlin.jvm.internal.Intrinsics;
import kotlin.ranges.IntProgression;
import kotlin.ranges.IntRange;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/* compiled from: CpuRecurrentLayer.kt */
@Metadata(mv = {1, 1, 9}, bv = {1, 0, 2}, k = 1, d1 = {"��|\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\b\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0014\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0015\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0010\u000b\n��\n\u0002\u0010\u0002\n\u0002\b\u0002\u0018��2\u00020\u00012\u00020\u0002Bo\u0012\b\u0010\u0003\u001a\u0004\u0018\u00010\u0004\u0012\u0006\u0010\u0005\u001a\u00020\u0006\u0012\u0006\u0010\u0007\u001a\u00020\u0006\u0012\u0006\u0010\b\u001a\u00020\u0006\u0012\u0006\u0010\t\u001a\u00020\n\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\u0006\u0010\r\u001a\u00020\u000e\u0012\u0006\u0010\u000f\u001a\u00020\u0010\u0012\u0006\u0010\u0011\u001a\u00020\u0012\u0012\f\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00150\u0014\u0012\b\u0010\u0016\u001a\u0004\u0018\u00010\u0012\u0012\u0006\u0010\u0017\u001a\u00020\u0018¢\u0006\u0002\u0010\u0019J\u0018\u0010)\u001a\u00020\u00102\u0006\u0010*\u001a\u00020\u00062\u0006\u0010+\u001a\u00020\u0010H\u0016J(\u0010,\u001a\u00020\u00102\u0006\u0010*\u001a\u00020\u00062\u0006\u0010-\u001a\u00020\u00062\u0006\u0010.\u001a\u00020\u00102\u0006\u0010/\u001a\u000200H\u0016J\u0010\u00101\u001a\u0002022\u0006\u00103\u001a\u00020\u0006H\u0016R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00150\u0014X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001aR\u0018\u0010\u001b\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u001c0\u0014X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001dR\u000e\u0010\u001e\u001a\u00020\u001fX\u0082\u0004¢\u0006\u0002\n��R\u0010\u0010\u0016\u001a\u0004\u0018\u00010\u0012X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000b\u001a\u00
020\fX\u0082\u0004¢\u0006\u0002\n��R\u0018\u0010 \u001a\n\u0012\u0006\b\u0001\u0012\u00020\u001c0\u0014X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001dR\u000e\u0010\b\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u000f\u001a\u00020\u0010X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\t\u001a\u00020\nX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0007\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\u0005\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010!\u001a\u00020\u0006X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\"\u001a\u00020#X\u0082\u0004¢\u0006\u0002\n��R\u0016\u0010$\u001a\b\u0012\u0004\u0012\u00020\u001c0\u0014X\u0082\u0004¢\u0006\u0004\n\u0002\u0010\u001dR\u0016\u0010%\u001a\b\u0012\u0004\u0012\u00020&0\u0014X\u0082\u0004¢\u0006\u0004\n\u0002\u0010'R\u000e\u0010\u0011\u001a\u00020\u0012X\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u000e\u0010(\u001a\u00020\u0010X\u0082\u0004¢\u0006\u0002\n��¨\u00064"}, d2 = {"Lcom/komputation/cpu/layers/recurrent/CpuRecurrentLayer;", "Lcom/komputation/cpu/layers/BaseCpuHigherOrderLayer;", "Lcom/komputation/optimization/Optimizable;", "name", "", "minimumSteps", "", "maximumSteps", "hiddenDimension", "inputWeighting", "Lcom/komputation/cpu/layers/forward/projection/CpuWeightingLayer;", "direction", "Lcom/komputation/cpu/layers/recurrent/Direction;", "resultExtraction", "Lcom/komputation/cpu/layers/recurrent/extraction/ResultExtractionStrategy;", "initialState", "", "previousHiddenStateWeighting", "Lcom/komputation/cpu/layers/recurrent/series/ParameterizedSeries;", "additions", "", "Lcom/komputation/cpu/layers/combination/CpuAdditionCombination;", "bias", "activation", "Lcom/komputation/cpu/layers/recurrent/series/Series;", 
"(Ljava/lang/String;IIILcom/komputation/cpu/layers/forward/projection/CpuWeightingLayer;Lcom/komputation/cpu/layers/recurrent/Direction;Lcom/komputation/cpu/layers/recurrent/extraction/ResultExtractionStrategy;[FLcom/komputation/cpu/layers/recurrent/series/ParameterizedSeries;[Lcom/komputation/cpu/layers/combination/CpuAdditionCombination;Lcom/komputation/cpu/layers/recurrent/series/ParameterizedSeries;Lcom/komputation/cpu/layers/recurrent/series/Series;)V", "[Lcom/komputation/cpu/layers/combination/CpuAdditionCombination;", "backwardStepsOverPossibleLengths", "Lkotlin/ranges/IntProgression;", "[Lkotlin/ranges/IntProgression;", "backwardStore", "Lcom/komputation/cpu/layers/VariableLengthFloatArray;", "forwardStepsOverPossibleLengths", "numberPossibleLengths", "possibleLengths", "", "possibleStepsBackward", "possibleStepsForward", "Lkotlin/ranges/IntRange;", "[Lkotlin/ranges/IntRange;", "stepWeightedInput", "backward", "withinBatch", "chain", "forward", "numberInputColumns", "input", "isTraining", "", "optimize", "", "batchSize", "komputation"})
/* loaded from: input_file:com/komputation/cpu/layers/recurrent/CpuRecurrentLayer.class */
/**
 * Recurrent layer that walks over the columns of a weighted input one step at a time.
 *
 * <p>For every step the layer selects the step's column from the weighted input, adds the
 * weighted previous hidden state to it, optionally applies the bias series and finally the
 * activation series. The result handed back to the caller is produced by the configured
 * {@link ResultExtractionStrategy}.
 *
 * <p>The order in which steps are visited is taken from precomputed progressions, one per
 * possible input length; which progression family is used for the forward and for the backward
 * pass depends on the {@link Direction} passed to the constructor.
 */
public final class CpuRecurrentLayer extends BaseCpuHigherOrderLayer implements Optimizable {
    // --- Values handed to the constructor ---
    private final int minimumSteps;
    private final int maximumSteps;
    private final int hiddenDimension;
    private final CpuWeightingLayer inputWeighting;
    private final Direction direction;
    private final ResultExtractionStrategy resultExtraction;
    private final float[] initialState;
    private final ParameterizedSeries previousHiddenStateWeighting;
    private final CpuAdditionCombination[] additions;
    /** Optional; every use is guarded by a null check. */
    private final ParameterizedSeries bias;
    private final Series activation;

    // --- Values derived in the constructor ---
    private final int numberPossibleLengths;
    private final int[] possibleLengths;
    /** Scratch buffer of {@code hiddenDimension} floats, overwritten on every forward step. */
    private final float[] stepWeightedInput;
    private final IntRange[] possibleStepsForward;
    private final IntProgression[] possibleStepsBackward;
    /** Step progressions used by {@link #forward}, indexed by length index. */
    private final IntProgression[] forwardStepsOverPossibleLengths;
    /** Step progressions used by {@link #backward}, indexed by length index. */
    private final IntProgression[] backwardStepsOverPossibleLengths;
    private final VariableLengthFloatArray backwardStore;

    /**
     * Creates the layer and precomputes the step progressions for every possible input length.
     *
     * <p>Note that the superclass constructor runs before the null checks on the arguments.
     *
     * @param str                       layer name, may be {@code null}
     * @param i                         minimum number of steps
     * @param i2                        maximum number of steps
     * @param i3                        hidden dimension
     * @param cpuWeightingLayer         input weighting
     * @param direction                 direction; selects which progressions drive which pass
     * @param resultExtractionStrategy  result extraction
     * @param fArr                      initial state
     * @param parameterizedSeries       previous hidden state weighting
     * @param cpuAdditionCombinationArr additions, indexed by step
     * @param parameterizedSeries2      bias, may be {@code null}
     * @param series                    activation
     * @throws NoWhenBranchMatchedException if {@code direction} is neither {@code Forward} nor
     *                                      {@code Backward}
     */
    public CpuRecurrentLayer(@Nullable String str, int i, int i2, int i3, @NotNull CpuWeightingLayer cpuWeightingLayer, @NotNull Direction direction, @NotNull ResultExtractionStrategy resultExtractionStrategy, @NotNull float[] fArr, @NotNull ParameterizedSeries parameterizedSeries, @NotNull CpuAdditionCombination[] cpuAdditionCombinationArr, @Nullable ParameterizedSeries parameterizedSeries2, @NotNull Series series) {
        super(str, cpuWeightingLayer, resultExtractionStrategy);

        Intrinsics.checkParameterIsNotNull(cpuWeightingLayer, "inputWeighting");
        Intrinsics.checkParameterIsNotNull(direction, "direction");
        Intrinsics.checkParameterIsNotNull(resultExtractionStrategy, "resultExtraction");
        Intrinsics.checkParameterIsNotNull(fArr, "initialState");
        Intrinsics.checkParameterIsNotNull(parameterizedSeries, "previousHiddenStateWeighting");
        Intrinsics.checkParameterIsNotNull(cpuAdditionCombinationArr, "additions");
        Intrinsics.checkParameterIsNotNull(series, "activation");

        this.minimumSteps = i;
        this.maximumSteps = i2;
        this.hiddenDimension = i3;
        this.inputWeighting = cpuWeightingLayer;
        this.direction = direction;
        this.resultExtraction = resultExtractionStrategy;
        this.initialState = fArr;
        this.previousHiddenStateWeighting = parameterizedSeries;
        this.additions = cpuAdditionCombinationArr;
        this.bias = parameterizedSeries2;
        this.activation = series;

        this.numberPossibleLengths = BaseCpuForwardLayerKt.computeNumberPossibleLengths(i, i2);
        this.possibleLengths = BaseCpuForwardLayerKt.computePossibleLengths(i, this.numberPossibleLengths);
        this.stepWeightedInput = new float[i3];
        this.possibleStepsForward = CpuRecurrentLayerKt.computePossibleStepsForward(i, this.numberPossibleLengths);
        this.possibleStepsBackward = CpuRecurrentLayerKt.computePossibleStepsBackward(i, this.numberPossibleLengths);

        // The backward pass always visits the steps in the opposite order of the forward pass.
        final IntProgression[] stepsForForwardPass;
        final IntProgression[] stepsForBackwardPass;
        if (direction == Direction.Forward) {
            stepsForForwardPass = this.possibleStepsForward;
            stepsForBackwardPass = this.possibleStepsBackward;
        } else if (direction == Direction.Backward) {
            stepsForForwardPass = this.possibleStepsBackward;
            stepsForBackwardPass = this.possibleStepsForward;
        } else {
            throw new NoWhenBranchMatchedException();
        }
        this.forwardStepsOverPossibleLengths = stepsForForwardPass;
        this.backwardStepsOverPossibleLengths = stepsForBackwardPass;

        this.backwardStore = new VariableLengthFloatArray(i3, this.possibleLengths);
    }

    /**
     * Tells whether a progression described by its bounds and stride contains at least one value.
     */
    private static boolean hasSteps(int first, int last, int stride) {
        return stride > 0 ? first <= last : first >= last;
    }

    /**
     * Weights the input, then runs every step of the recurrence in forward-pass order.
     *
     * @param i    index within the batch
     * @param i2   number of input columns
     * @param fArr input
     * @param z    whether the layer is being trained
     * @return whatever the result extraction strategy extracts for {@code i2} columns
     */
    @Override // com.komputation.cpu.layers.CpuForwardLayer
    @NotNull
    public float[] forward(int i, int i2, @NotNull float[] fArr, boolean z) {
        Intrinsics.checkParameterIsNotNull(fArr, "input");

        final float[] weightedInput = this.inputWeighting.forward(i, i2, fArr, z);
        float[] hiddenState = this.initialState;

        final int lengthIndex = BaseCpuForwardLayerKt.computeLengthIndex(getNumberInputColumns(), this.minimumSteps);
        final IntProgression steps = this.forwardStepsOverPossibleLengths[lengthIndex];
        int step = steps.getFirst();
        final int lastStep = steps.getLast();
        final int stride = steps.getStep();

        boolean pending = hasSteps(step, lastStep, stride);
        while (pending) {
            ColumnSelectionKt.getColumn(weightedInput, step, this.hiddenDimension, this.stepWeightedInput);

            // Look up the addition before weighting the previous state so that an invalid step
            // index fails before any series is touched.
            final CpuAdditionCombination addition = this.additions[step];
            final float[] weightedPreviousState = this.previousHiddenStateWeighting.forwardStep(i, step, 1, hiddenState, z);
            final float[] preActivation = addition.forward(this.stepWeightedInput, weightedPreviousState);

            final float[] activationInput;
            if (this.bias == null) {
                activationInput = preActivation;
            } else {
                activationInput = this.bias.forwardStep(i, step, 1, preActivation, z);
            }
            hiddenState = this.activation.forwardStep(i, step, 1, activationInput, z);

            // Stop on equality with the last element instead of comparing against a bound.
            pending = step != lastStep;
            step += stride;
        }

        return this.resultExtraction.extractResult(i2);
    }

    /**
     * Runs every step of the recurrence in backward-pass order and passes the collected
     * per-step results on to the input weighting.
     *
     * @param i    index within the batch
     * @param fArr chain
     * @return the result of the input weighting's backward pass
     */
    @Override // com.komputation.cpu.layers.CpuForwardLayer
    @NotNull
    public float[] backward(int i, @NotNull float[] fArr) {
        Intrinsics.checkParameterIsNotNull(fArr, "chain");

        final float[] stepResults = this.backwardStore.get(getNumberInputColumns());
        // Output of the previous-hidden-state weighting from the step handled before the
        // current one; null for the first step that is processed.
        float[] previousStateChain = null;

        final int lengthIndex = BaseCpuForwardLayerKt.computeLengthIndex(getNumberInputColumns(), this.minimumSteps);
        final IntProgression steps = this.backwardStepsOverPossibleLengths[lengthIndex];
        int step = steps.getFirst();
        final int lastStep = steps.getLast();
        final int stride = steps.getStep();

        boolean pending = hasSteps(step, lastStep, stride);
        while (pending) {
            final float[] extractedChain = this.resultExtraction.backwardStep(step, fArr, previousStateChain);
            final float[] activationResult = this.activation.backwardStep(i, step, extractedChain);

            previousStateChain = this.previousHiddenStateWeighting.backwardStep(i, step, activationResult);
            if (this.bias != null) {
                this.bias.backwardStep(i, step, activationResult);
            }

            ColumnSelectionKt.setColumn(activationResult, step, this.hiddenDimension, stepResults);

            pending = step != lastStep;
            step += stride;
        }

        this.previousHiddenStateWeighting.backwardSeries();
        if (this.bias != null) {
            this.bias.backwardSeries();
        }

        return this.inputWeighting.backward(i, stepResults);
    }

    /**
     * Forwards the optimization request to the input weighting, the previous hidden state
     * weighting and, when present, the bias.
     *
     * @param i batch size
     */
    @Override // com.komputation.optimization.Optimizable
    public void optimize(int i) {
        this.inputWeighting.optimize(i);
        this.previousHiddenStateWeighting.optimize(i);
        if (this.bias != null) {
            this.bias.optimize(i);
        }
    }
}
