package net.maizegenetics.analysis.imputation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.log4j.Logger;

/* loaded from: input_file:net/maizegenetics/analysis/imputation/BackwardForwardVariableStateNumber.class */
/**
 * Forward-backward algorithm for a hidden Markov model whose number of hidden states may differ
 * from node to node.
 *
 * <p>Configure the instance with the fluent setters ({@link #transition}, {@link #emission},
 * {@link #observations}, {@link #initialStateProbability}). Then call {@link #calculateAlpha()}
 * and {@link #calculateBeta()}, and finally {@link #gamma()} to obtain the per-node posterior
 * state probabilities.
 *
 * <p>Conventions used by this class:
 * <ul>
 *   <li>The number of states at node {@code i} is read with {@code setNode(i)} followed by
 *       {@code getNumberOfStates()}.</li>
 *   <li>The probability of moving from state {@code from} at node {@code i-1} to state {@code to}
 *       at node {@code i} is read with {@code setNode(i)} followed by
 *       {@code getTransitionProbability(from, to)}.</li>
 *   <li>Both calculate methods call {@code setNode} on the shared transition object, so its
 *       current node is changed as a side effect.</li>
 * </ul>
 *
 * <p>To avoid floating-point underflow on long sequences, a forward or backward row is multiplied
 * by {@value #RESCALE_FACTOR} whenever all of its values become very small (see
 * {@link #multiplyArrayByConstantIfSmall}). The values returned by {@link #alpha()} and
 * {@link #beta()} are therefore scaled and are not raw probabilities. The scale factors cancel
 * out in {@link #gamma()}, because each node's result is normalized to sum to one.
 */
public class BackwardForwardVariableStateNumber {
    private static final Logger myLogger = Logger.getLogger(BackwardForwardVariableStateNumber.class);

    /** A row is rescaled only if its minimum is below this value... */
    private static final double RESCALE_MIN_THRESHOLD = 1.0E-50d;
    /** ...and its maximum is also below this value. */
    private static final double RESCALE_MAX_THRESHOLD = 1.0E-25d;
    /** Factor applied to every element of a row that is rescaled. */
    private static final double RESCALE_FACTOR = 1.0E25d;

    private int[] myObservations;
    /** Stored via {@link #positions(int[])}; not read anywhere in this class. */
    private int[] myPositions;
    private TransitionProbability myTransitions;
    private EmissionProbability myEmissions;
    private double[] initialStateProbability;
    /** Scaled forward variables, one row per node; null until {@link #calculateAlpha()} runs. */
    private List<double[]> alpha;
    /** Scaled backward variables, one row per node; null until {@link #calculateBeta()} runs. */
    private List<double[]> beta;

    /**
     * Runs the forward pass and stores one (possibly rescaled) row of forward variables per node.
     *
     * <p>Row 0 is {@code initialStateProbability[s] * P(obs[0] | s, node 0)}. Row {@code n} is the
     * sum over previous states of {@code alpha[n-1][from] * P(from -> to)}, multiplied by the
     * emission probability of {@code obs[n]} in state {@code to}. An empty observation array
     * yields an empty list.
     *
     * @return this instance, for chaining
     * @throws IllegalStateException if transitions, emissions, observations or initial state
     *     probabilities have not been set
     */
    public BackwardForwardVariableStateNumber calculateAlpha() {
        requireConfigured(myTransitions, "transition");
        requireConfigured(myEmissions, "emission");
        requireConfigured(myObservations, "observations");
        requireConfigured(initialStateProbability, "initialStateProbability");

        int length = myObservations.length;
        alpha = new ArrayList<>(length);
        if (length == 0) {
            return this;
        }

        myTransitions.setNode(0);
        int numberOfStates = myTransitions.getNumberOfStates();
        double[] previous = new double[numberOfStates];
        for (int state = 0; state < numberOfStates; state++) {
            previous[state] = initialStateProbability[state]
                    * myEmissions.getProbObsGivenState(state, myObservations[0], 0);
        }
        alpha.add(previous);

        for (int node = 1; node < length; node++) {
            // setNode(node) selects transitions into this node from node - 1
            myTransitions.setNode(node);
            numberOfStates = myTransitions.getNumberOfStates();
            double[] current = new double[numberOfStates];
            for (int to = 0; to < numberOfStates; to++) {
                double sum = 0.0d;
                for (int from = 0; from < previous.length; from++) {
                    sum += previous[from] * myTransitions.getTransitionProbability(from, to);
                }
                current[to] = sum * myEmissions.getProbObsGivenState(to, myObservations[node], node);
            }
            previous = multiplyArrayByConstantIfSmall(current);
            alpha.add(previous);
        }
        return this;
    }

    /**
     * Returns {@code row} multiplied by {@link #RESCALE_FACTOR} when every value in it is tiny,
     * otherwise returns {@code row} itself.
     *
     * <p>A row is rescaled only when its minimum is below {@link #RESCALE_MIN_THRESHOLD} and its
     * maximum is below {@link #RESCALE_MAX_THRESHOLD}. If the row contains NaN, the min and max
     * are NaN, so the row is rescaled; this matches {@code DoubleStream.min/max} semantics. An
     * empty row is returned unchanged.
     *
     * @param row values for one node; not modified
     * @return {@code row} itself, or a new rescaled array
     */
    private double[] multiplyArrayByConstantIfSmall(double[] row) {
        if (row.length == 0) {
            return row;
        }
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double value : row) {
            // Math.min/max propagate NaN, matching DoubleStream.min()/max()
            min = Math.min(min, value);
            max = Math.max(max, value);
        }
        if (min >= RESCALE_MIN_THRESHOLD || max >= RESCALE_MAX_THRESHOLD) {
            return row;
        }
        double[] scaled = new double[row.length];
        for (int i = 0; i < row.length; i++) {
            scaled[i] = row[i] * RESCALE_FACTOR;
        }
        return scaled;
    }

    /**
     * Runs the backward pass and stores one (possibly rescaled) row of backward variables per node.
     *
     * <p>The last row is all ones. Row {@code n} is the sum over next-node states {@code to} of
     * {@code P(from -> to) * P(obs[n+1] | to) * beta[n+1][to]}. An empty observation array yields
     * an empty list.
     *
     * @return this instance, for chaining
     * @throws IllegalStateException if transitions, emissions or observations have not been set
     */
    public BackwardForwardVariableStateNumber calculateBeta() {
        requireConfigured(myTransitions, "transition");
        requireConfigured(myEmissions, "emission");
        requireConfigured(myObservations, "observations");

        int length = myObservations.length;
        if (length == 0) {
            beta = new ArrayList<>();
            return this;
        }

        // Rows are produced last-to-first; fill an array by index instead of prepending to a list
        double[][] rows = new double[length][];
        myTransitions.setNode(length - 1);
        double[] next = new double[myTransitions.getNumberOfStates()];
        Arrays.fill(next, 1.0d);
        rows[length - 1] = next;

        for (int node = length - 2; node >= 0; node--) {
            int nextNode = node + 1;
            myTransitions.setNode(node);
            int numberOfStates = myTransitions.getNumberOfStates();
            // Transitions from this node into nextNode are selected by setNode(nextNode)
            myTransitions.setNode(nextNode);

            // The emission term depends only on the next-node state, so compute it once per state.
            // Assumes getProbObsGivenState is a pure lookup (same args -> same value).
            double[] nextEmission = new double[next.length];
            for (int to = 0; to < next.length; to++) {
                nextEmission[to] = myEmissions.getProbObsGivenState(to, myObservations[nextNode], nextNode);
            }

            double[] current = new double[numberOfStates];
            for (int from = 0; from < numberOfStates; from++) {
                double sum = 0.0d;
                for (int to = 0; to < next.length; to++) {
                    // Same evaluation order as (transition * emission) * beta
                    sum += myTransitions.getTransitionProbability(from, to) * nextEmission[to] * next[to];
                }
                current[from] = sum;
            }
            next = multiplyArrayByConstantIfSmall(current);
            rows[node] = next;
        }
        beta = new ArrayList<>(Arrays.asList(rows));
        return this;
    }

    /**
     * Computes the posterior state probabilities for every node.
     *
     * <p>Each node's result is {@code alpha[n][s] * beta[n][s]}, normalized so the row sums to one.
     * NOTE(review): if that product is zero for every state of a node, the division yields NaN
     * for the whole row.
     *
     * @return a new list with one normalized row per node
     * @throws IllegalStateException if {@link #calculateAlpha()} or {@link #calculateBeta()} has
     *     not been run, or if they produced a different number of rows
     */
    public List<double[]> gamma() {
        if (alpha == null || beta == null) {
            throw new IllegalStateException("calculateAlpha() and calculateBeta() must be called before gamma()");
        }
        if (alpha.size() != beta.size()) {
            throw new IllegalStateException("alpha has " + alpha.size() + " rows but beta has " + beta.size());
        }
        List<double[]> result = new ArrayList<>(alpha.size());
        Iterator<double[]> betaRows = beta.iterator();
        for (double[] alphaRow : alpha) {
            double[] betaRow = betaRows.next();
            double[] posterior = new double[alphaRow.length];
            for (int state = 0; state < alphaRow.length; state++) {
                posterior[state] = alphaRow[state] * betaRow[state];
            }
            // DoubleStream.sum() uses compensated summation; kept for identical results
            double total = Arrays.stream(posterior).sum();
            for (int state = 0; state < posterior.length; state++) {
                posterior[state] /= total;
            }
            result.add(posterior);
        }
        return result;
    }

    /**
     * Sets the emission model.
     *
     * @param emissionProbability emission probabilities queried per state, observation and node
     * @return this instance, for chaining
     */
    public BackwardForwardVariableStateNumber emission(EmissionProbability emissionProbability) {
        this.myEmissions = emissionProbability;
        return this;
    }

    /**
     * Sets the transition model.
     *
     * @param transitionProbability transition probabilities and per-node state counts
     * @return this instance, for chaining
     */
    public BackwardForwardVariableStateNumber transition(TransitionProbability transitionProbability) {
        this.myTransitions = transitionProbability;
        return this;
    }

    /**
     * Sets the observation sequence, one entry per node. The array is stored, not copied.
     *
     * @param iArr observed values
     * @return this instance, for chaining
     */
    public BackwardForwardVariableStateNumber observations(int[] iArr) {
        this.myObservations = iArr;
        return this;
    }

    /**
     * Stores node positions. The array is stored, not copied, and this class does not read it.
     *
     * @param iArr positions
     * @return this instance, for chaining
     */
    public BackwardForwardVariableStateNumber positions(int[] iArr) {
        this.myPositions = iArr;
        return this;
    }

    /**
     * Sets the state probabilities used for node 0. The array is stored, not copied.
     *
     * @param dArr initial probability for each state at node 0
     * @return this instance, for chaining
     */
    public BackwardForwardVariableStateNumber initialStateProbability(double[] dArr) {
        this.initialStateProbability = dArr;
        return this;
    }

    /**
     * Returns the scaled forward variables.
     *
     * @return the internal list of forward rows, or null if {@link #calculateAlpha()} has not run
     */
    public List<double[]> alpha() {
        return this.alpha;
    }

    /**
     * Returns the scaled backward variables.
     *
     * @return the internal list of backward rows, or null if {@link #calculateBeta()} has not run
     */
    public List<double[]> beta() {
        return this.beta;
    }

    /** Throws an IllegalStateException naming {@code name} when {@code value} is null. */
    private static void requireConfigured(Object value, String name) {
        if (value == null) {
            throw new IllegalStateException(name + " must be set before running the forward-backward algorithm");
        }
    }
}
