package com.aliasi.hmm;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.tag.MarginalTagger;
import com.aliasi.tag.NBestTagger;
import com.aliasi.tag.ScoredTagging;
import com.aliasi.tag.TagLattice;
import com.aliasi.tag.Tagger;
import com.aliasi.tag.Tagging;
import com.aliasi.util.BoundedPriorityQueue;
import com.aliasi.util.Iterators;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;
import com.aliasi.util.Strings;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

/* loaded from: input_file:com/aliasi/hmm/HmmDecoder.class */
public class HmmDecoder implements Tagger<String>, NBestTagger<String>, MarginalTagger<String> {
    // The model supplying start, transition, emission and end probabilities.
    private final HiddenMarkovModel mHmm;
    // Optional token -> per-tag emission probability cache; null disables caching (see emitProbs).
    private Map<String, double[]> mEmissionCache;
    // Optional token -> per-tag log2 emission cache; null disables caching (see emitLog2Probs).
    // Cached arrays were pruned with the emission beam in force when they were computed.
    private Map<String, double[]> mEmissionLog2Cache;
    // Additive log2 beam applied to each token's emission log2 scores; +infinity disables it.
    private double mLog2EmissionBeam;
    // Additive log2 beam restricting which previous-position tags Viterbi considers as sources.
    private double mLog2Beam;

    /**
     * Iterator adapter that shifts every score of an underlying n-best iterator
     * down by a fixed log2 total.
     *
     * <p>{@code nBestConditional} supplies the lattice's {@code log2Total()} as
     * that total, presumably so joint log2 scores become conditional on the
     * input.
     *
     * <p>Originally a private nested class; the decompiler widened its access.
     */
    public static final class JointIterator extends Iterators.Modifier<ScoredObject<String[]>> {
        final double mLog2TotalProb;

        JointIterator(Iterator<ScoredObject<String[]>> jointResults, double log2TotalProb) {
            super(jointResults);
            mLog2TotalProb = log2TotalProb;
        }

        /**
         * Returns a new scored object with the same tags and a score reduced by
         * the stored log2 total.
         */
        @Override
        public ScoredObject<String[]> modify(ScoredObject<String[]> joint) {
            String[] tags = joint.getObject();
            double rescaled = joint.score() - mLog2TotalProb;
            return new ScoredObject<>(tags, rescaled);
        }
    }

    /**
     * Lazily enumerates complete tag sequences for the input of a {@link Viterbi}
     * lattice, using a bounded priority queue of {@link State}s.
     *
     * <p>States are expanded from the last token back towards the first. Each
     * state's score is the log2 score of its already-fixed suffix plus the
     * Viterbi lattice score of the best prefix ending in its tag. A state that
     * reaches token 0 is a complete tagging and is returned.
     *
     * <p>NOTE(review): ordering is defined by {@code ScoredObject.comparator()}
     * and {@code BoundedPriorityQueue.poll()}; presumably highest score first.
     * Confirm this against those classes.
     *
     * <p>Originally a private inner class; the decompiler widened its access.
     */
    public class NBestIterator extends Iterators.Buffered<ScoredObject<String[]>> {
        // Lattice supplying the prefix (continuation) scores for each position/tag.
        private final Viterbi mViterbi;
        // Search frontier; its maximum size bounds the number of retained candidates.
        private final BoundedPriorityQueue<State> mPQ;

        /**
         * Seeds the frontier with one state per tag at the last token position.
         *
         * @param viterbi decoded lattice for a non-empty input
         * @param i initial maximum size of the priority queue
         */
        NBestIterator(Viterbi viterbi, int i) {
            this.mViterbi = viterbi;
            this.mPQ = new BoundedPriorityQueue<>(ScoredObject.comparator(), i);
            String[] strArr = viterbi.mEmissions;
            int numSymbols = HmmDecoder.this.mHmm.stateSymbolTable().numSymbols();
            int length = strArr.length - 1;
            for (int i2 = 0; i2 < numSymbols; i2++) {
                // Last-column Viterbi score; the Viterbi constructor already added the end log2 prob.
                double d = viterbi.mLattice[length][i2];
                if (d > Double.NEGATIVE_INFINITY) {
                    // Nothing is fixed yet, so the suffix score is 0 and the whole score is the continuation.
                    this.mPQ.offer(new State(length, 0.0d, d, i2, null));
                }
            }
        }

        /**
         * Returns the next complete tagging with its log2 score. Returns null once
         * the frontier is empty; presumably that is the Buffered end-of-iteration
         * signal.
         */
        @Override // com.aliasi.util.Iterators.Buffered
        public ScoredObject<String[]> bufferNext() {
            int numSymbols = HmmDecoder.this.mHmm.stateSymbolTable().numSymbols();
            int length = this.mViterbi.mEmissions.length;
            int i = length - 1;
            while (!this.mPQ.isEmpty()) {
                State poll = this.mPQ.poll();
                int emissionIndex = poll.emissionIndex();
                if (emissionIndex == 0) {
                    // A complete tagging is produced. One fewer result remains to be
                    // found, so the queue bound shrinks by one.
                    this.mPQ.setMaxSize(this.mPQ.maxSize() - 1);
                    return poll.result(length);
                }
                String str = this.mViterbi.mEmissions[emissionIndex];
                int i2 = poll.mTagId;
                double d = poll.mScore;
                if (emissionIndex == i) {
                    // The end transition is counted once, when leaving the last position.
                    d += HmmDecoder.this.mHmm.endLog2Prob(i2);
                }
                int i3 = emissionIndex - 1;
                double emitLog2Prob = HmmDecoder.this.mHmm.emitLog2Prob(i2, str);
                // Extend the fixed suffix one token to the left with every predecessor tag i4.
                for (int i4 = 0; i4 < numSymbols; i4++) {
                    double transitLog2Prob = d + HmmDecoder.this.mHmm.transitLog2Prob(i4, i2) + emitLog2Prob;
                    double d2 = this.mViterbi.mLattice[i3][i4];
                    if (transitLog2Prob > Double.NEGATIVE_INFINITY && d2 > Double.NEGATIVE_INFINITY) {
                        this.mPQ.offer(new State(i3, transitLog2Prob, d2, i4, poll));
                    }
                }
            }
            return null;
        }
    }

    /**
     * One node of the n-best search: a tag at a single token position, linked to
     * the node it was expanded from.
     *
     * <p>Expansion runs from the last token towards the first. As a result,
     * {@code mPreviousState} is the node at position {@code mEmissionIndex + 1},
     * or null for a node seeded at the last position. The node's score is the
     * fixed-suffix log2 score ({@code mScore}) plus the Viterbi lattice score for
     * this position and tag ({@code mContScore}).
     *
     * <p>Originally a private inner class; the decompiler widened its access.
     */
    public final class State implements Scored {
        private final double mScore;
        private final double mContScore;
        private final int mTagId;
        private final State mPreviousState;
        private final int mEmissionIndex;

        State(int emissionIndex, double score, double contScore, int tagId, State previousState) {
            mEmissionIndex = emissionIndex;
            mScore = score;
            mContScore = contScore;
            mTagId = tagId;
            mPreviousState = previousState;
        }

        /** Returns the token position of this node. */
        public int emissionIndex() {
            return mEmissionIndex;
        }

        /** Returns the suffix score plus the continuation score. */
        @Override
        public double score() {
            return mScore + mContScore;
        }

        /** Wraps {@link #tags(int)} and {@link #score()} into a scored result. */
        ScoredObject<String[]> result(int numTokens) {
            String[] tagSymbols = tags(numTokens);
            return new ScoredObject<>(tagSymbols, score());
        }

        /**
         * Collects the tag symbols of this node and the following nodes in the
         * chain, in token order starting at this node.
         */
        String[] tags(int numTokens) {
            SymbolTable table = HmmDecoder.this.mHmm.stateSymbolTable();
            String[] symbols = new String[numTokens];
            State cursor = this;
            int pos = 0;
            while (pos < numTokens) {
                symbols[pos++] = table.idToSymbol(cursor.mTagId);
                cursor = cursor.mPreviousState;
            }
            return symbols;
        }
    }

    /**
     * Adapts an iterator over scored tag arrays into an iterator over
     * {@link ScoredTagging}s for a fixed token list. It yields at most
     * {@code mMaxResults} results.
     */
    static class TaggingIteratorAdapter implements Iterator<ScoredTagging<String>> {
        private final Iterator<ScoredObject<String[]>> mIt;
        private final List<String> mTokens;
        private final int mMaxResults;
        private int mResults = 0;

        /**
         * @param list tokens shared by every produced tagging
         * @param it underlying scored tag arrays, one per tagging
         * @param i maximum number of taggings to yield
         */
        TaggingIteratorAdapter(List<String> list, Iterator<ScoredObject<String[]>> it, int i) {
            this.mTokens = list;
            this.mIt = it;
            this.mMaxResults = i;
        }

        /**
         * Returns the next scored tagging.
         *
         * @throws NoSuchElementException if the result cap is reached, or if the
         *     underlying iterator signals exhaustion
         */
        @Override
        public ScoredTagging<String> next() {
            // hasNext() reports false once the cap is reached. next() must agree
            // with it instead of silently returning results beyond the cap.
            if (this.mResults >= this.mMaxResults) {
                throw new NoSuchElementException("Reached maximum results=" + this.mMaxResults);
            }
            ScoredObject<String[]> next = this.mIt.next();
            double score = next.score();
            List<String> tags = Arrays.asList(next.getObject());
            this.mResults++;
            return new ScoredTagging<>(this.mTokens, tags, score);
        }

        /** Returns true while under the cap and the underlying iterator has more. */
        @Override
        public boolean hasNext() {
            return this.mResults < this.mMaxResults && this.mIt.hasNext();
        }

        /** Delegates removal to the underlying iterator. */
        @Override
        public void remove() {
            this.mIt.remove();
        }
    }

    /**
     * Viterbi (max-product, log2 space) decoding of one token sequence.
     *
     * <p>{@code mLattice[t][tag]} holds the best log2 score of any tag sequence
     * over tokens {@code 0..t} that ends in {@code tag}. The last column also
     * includes the end log2 probability. {@code mBackPts[t][tag]} is the
     * predecessor tag achieving that score. Predecessors are limited to the
     * sources kept by {@code unprunedSources} under {@code mLog2Beam}.
     *
     * <p>Originally a private inner class; the decompiler widened its access.
     */
    public class Viterbi {
        private final String[] mEmissions;
        private final double[][] mLattice;
        private final int[][] mBackPts;

        Viterbi(String[] emissions) {
            mEmissions = emissions;
            HiddenMarkovModel hmm = HmmDecoder.this.mHmm;
            int numTags = hmm.stateSymbolTable().numSymbols();
            int numTokens = emissions.length;
            mLattice = new double[numTokens][numTags];
            mBackPts = new int[numTokens][numTags];
            if (numTokens == 0) {
                return;
            }

            // First column: start probability plus emission of the first token.
            double[] firstEmits = HmmDecoder.this.emitLog2Probs(emissions[0]);
            for (int tag = 0; tag < numTags; tag++) {
                mLattice[0][tag] = firstEmits[tag] + hmm.startLog2Prob(tag);
            }

            // Reused buffer of surviving source tags, terminated by -1.
            int[] sources = new int[numTags + 1];
            for (int pos = 1; pos < numTokens; pos++) {
                double[] prevColumn = mLattice[pos - 1];
                HmmDecoder.this.unprunedSources(prevColumn, sources, HmmDecoder.this.mLog2Beam);
                double[] emits = HmmDecoder.this.emitLog2Probs(emissions[pos]);
                double[] column = mLattice[pos];
                int[] backPts = mBackPts[pos];
                for (int tag = 0; tag < numTags; tag++) {
                    if (emits[tag] == Double.NEGATIVE_INFINITY) {
                        // An impossible emission makes the cell unreachable.
                        column[tag] = Double.NEGATIVE_INFINITY;
                        backPts[tag] = 0;
                        continue;
                    }
                    double best = Double.NEGATIVE_INFINITY;
                    int bestSource = 0;
                    for (int k = 0; sources[k] != -1; k++) {
                        int src = sources[k];
                        double candidate = prevColumn[src] + hmm.transitLog2Prob(src, tag);
                        // Strict comparison: ties keep the earliest source.
                        if (candidate > best) {
                            best = candidate;
                            bestSource = src;
                        }
                    }
                    column[tag] = best + emits[tag];
                    backPts[tag] = bestSource;
                }
            }

            // Fold the end transition into the final column.
            double[] lastColumn = mLattice[numTokens - 1];
            for (int tag = 0; tag < numTags; tag++) {
                lastColumn[tag] += hmm.endLog2Prob(tag);
            }
        }

        /**
         * Returns the single best tag sequence, recovered by following back
         * pointers from the highest-scoring cell of the final column (first
         * index wins ties). Returns an empty array for empty input.
         */
        String[] bestStates() {
            HiddenMarkovModel hmm = HmmDecoder.this.mHmm;
            int numTags = hmm.stateSymbolTable().numSymbols();
            int numTokens = mEmissions.length;
            if (numTokens == 0) {
                return Strings.EMPTY_STRING_ARRAY;
            }

            double[] lastColumn = mLattice[numTokens - 1];
            int bestLast = 0;
            for (int tag = 1; tag < numTags; tag++) {
                if (lastColumn[tag] > lastColumn[bestLast]) {
                    bestLast = tag;
                }
            }

            int[] path = new int[numTokens];
            path[numTokens - 1] = bestLast;
            for (int pos = numTokens - 1; pos > 0; pos--) {
                path[pos - 1] = mBackPts[pos][path[pos]];
            }

            SymbolTable table = hmm.stateSymbolTable();
            String[] tags = new String[numTokens];
            for (int pos = 0; pos < numTokens; pos++) {
                tags[pos] = table.idToSymbol(path[pos]);
            }
            return tags;
        }
    }

    /**
     * Constructs a decoder for the given model with no emission caches. Both
     * beams are set to positive infinity.
     *
     * @param hiddenMarkovModel the model to decode with
     */
    public HmmDecoder(HiddenMarkovModel hiddenMarkovModel) {
        this(hiddenMarkovModel, null, null, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
    }

    /**
     * Constructs a decoder with the given (possibly null) emission caches. Both
     * beams are set to positive infinity.
     *
     * @param hiddenMarkovModel the model to decode with
     * @param emissionCache token to per-tag emission probabilities, or null
     * @param emissionLog2Cache token to per-tag log2 emission scores, or null
     */
    public HmmDecoder(HiddenMarkovModel hiddenMarkovModel, Map<String, double[]> emissionCache,
                      Map<String, double[]> emissionLog2Cache) {
        this(hiddenMarkovModel, emissionCache, emissionLog2Cache,
             Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
    }

    /**
     * Constructs a decoder with the given caches and beams.
     *
     * @param hiddenMarkovModel the model to decode with
     * @param emissionCache token to per-tag emission probabilities, or null
     * @param emissionLog2Cache token to per-tag log2 emission scores, or null
     * @param log2Beam beam passed to {@code setLog2Beam}
     * @param log2EmissionBeam beam passed to {@code setLog2EmissionBeam}
     * @throws IllegalArgumentException if a beam is rejected by its setter
     */
    public HmmDecoder(HiddenMarkovModel hiddenMarkovModel, Map<String, double[]> emissionCache,
                      Map<String, double[]> emissionLog2Cache, double log2Beam, double log2EmissionBeam) {
        mHmm = hiddenMarkovModel;
        mEmissionCache = emissionCache;
        mEmissionLog2Cache = emissionLog2Cache;
        setLog2Beam(log2Beam);
        setLog2EmissionBeam(log2EmissionBeam);
    }

    /** Returns the model this decoder decodes with. */
    public HiddenMarkovModel getHmm() {
        HiddenMarkovModel hmm = mHmm;
        return hmm;
    }

    /** Returns the emission probability cache, or null if caching is disabled. */
    public Map<String, double[]> emissionCache() {
        Map<String, double[]> cache = mEmissionCache;
        return cache;
    }

    /** Returns the log2 emission cache, or null if caching is disabled. */
    public Map<String, double[]> emissionLog2Cache() {
        Map<String, double[]> cache = mEmissionLog2Cache;
        return cache;
    }

    /**
     * Replaces the emission probability cache.
     *
     * @param emissionCache new cache, or null to disable caching
     */
    public void setEmissionCache(Map<String, double[]> emissionCache) {
        mEmissionCache = emissionCache;
    }

    /**
     * Sets the additive log2 beam applied to each token's emission log2 scores.
     * Positive infinity disables emission pruning. Arrays already held in the
     * log2 emission cache keep the pruning applied when they were computed.
     *
     * @param d the beam width
     * @throws IllegalArgumentException if {@code d} is not positive or is NaN
     */
    public void setLog2EmissionBeam(double d) {
        checkBeam(d, "log2EmissionBeam");
        this.mLog2EmissionBeam = d;
    }

    /**
     * Sets the additive log2 beam that limits which previous-position tags the
     * Viterbi search considers as sources.
     *
     * @param d the beam width
     * @throws IllegalArgumentException if {@code d} is not positive or is NaN
     */
    public void setLog2Beam(double d) {
        // Report the correct parameter name; this previously said "log2EmissionBeam".
        checkBeam(d, "log2Beam");
        this.mLog2Beam = d;
    }

    /** Rejects beam widths that are non-positive or NaN, naming the offending parameter. */
    private static void checkBeam(double beam, String name) {
        if (beam <= 0.0d || Double.isNaN(beam)) {
            throw new IllegalArgumentException("Beam width must be a positive number. Found " + name + "=" + beam);
        }
    }

    /**
     * Replaces the log2 emission cache.
     *
     * @param emissionLog2Cache new cache, or null to disable caching
     */
    public void setEmissionLog2Cache(Map<String, double[]> emissionLog2Cache) {
        mEmissionLog2Cache = emissionLog2Cache;
    }

    /** Looks up per-tag emission probabilities in the cache, computing and storing them on a miss. */
    double[] cachedEmitProbs(String str) {
        double[] probs = mEmissionCache.get(str);
        if (probs == null) {
            probs = computeEmitProbs(str);
            mEmissionCache.put(str, probs);
        }
        return probs;
    }

    /** Computes the emission probability of {@code str} under every tag, indexed by tag id. */
    double[] computeEmitProbs(String str) {
        int numTags = mHmm.stateSymbolTable().numSymbols();
        double[] probs = new double[numTags];
        for (int tag = 0; tag < probs.length; tag++) {
            probs[tag] = mHmm.emitProb(tag, str);
        }
        return probs;
    }

    /** Returns per-tag emission probabilities, using the cache when one is configured. */
    double[] emitProbs(String str) {
        if (mEmissionCache != null) {
            return cachedEmitProbs(str);
        }
        return computeEmitProbs(str);
    }

    /** Looks up beam-pruned per-tag log2 emission scores in the cache, computing and storing them on a miss. */
    double[] cachedEmitLog2Probs(String str) {
        double[] log2Probs = mEmissionLog2Cache.get(str);
        if (log2Probs == null) {
            log2Probs = computeEmitLog2Probs(str);
            mEmissionLog2Cache.put(str, log2Probs);
        }
        return log2Probs;
    }

    /**
     * Computes the log2 emission score of {@code str} under every tag, indexed by
     * tag id. The scores are then pruned with the emission beam.
     */
    double[] computeEmitLog2Probs(String str) {
        int numTags = mHmm.stateSymbolTable().numSymbols();
        double[] log2Probs = new double[numTags];
        for (int tag = 0; tag < log2Probs.length; tag++) {
            log2Probs[tag] = mHmm.emitLog2Prob(tag, str);
        }
        additiveBeamPrune(log2Probs, mLog2EmissionBeam);
        return log2Probs;
    }

    /**
     * Prunes in place every entry that is more than {@code log2Beam} below the
     * maximum entry, by setting it to negative infinity.
     *
     * <p>A beam of positive infinity disables pruning. With a non-negative beam,
     * the maximum entry is never pruned. An empty array is left untouched.
     *
     * @param log2Scores log2 scores to prune in place
     * @param log2Beam additive log2 beam width
     */
    static void additiveBeamPrune(double[] log2Scores, double log2Beam) {
        if (log2Beam == Double.POSITIVE_INFINITY || log2Scores.length == 0) {
            return;
        }
        double best = log2Scores[0];
        for (int i = 1; i < log2Scores.length; i++) {
            if (log2Scores[i] > best) {
                best = log2Scores[i];
            }
        }
        // Every index, including 0, must be eligible for pruning.
        for (int i = 0; i < log2Scores.length; i++) {
            if (log2Scores[i] + log2Beam < best) {
                log2Scores[i] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /** Returns beam-pruned per-tag log2 emission scores, using the cache when one is configured. */
    double[] emitLog2Probs(String str) {
        if (mEmissionLog2Cache != null) {
            return cachedEmitLog2Probs(str);
        }
        return computeEmitLog2Probs(str);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v14, types: [double[][], double[][][]] */
    TagWordLattice lattice(String[] strArr) {
        int length = strArr.length;
        int numSymbols = this.mHmm.stateSymbolTable().numSymbols();
        if (length == 0) {
            return new TagWordLattice(strArr, this.mHmm.stateSymbolTable(), new double[numSymbols], new double[numSymbols], new double[0][numSymbols][numSymbols]);
        }
        double[] dArr = new double[numSymbols];
        double[] emitProbs = emitProbs(strArr[0]);
        for (int i = 0; i < numSymbols; i++) {
            dArr[i] = this.mHmm.startProb(i) * emitProbs[i];
        }
        ?? r0 = new double[length];
        for (int i2 = 1; i2 < length; i2++) {
            r0[i2] = new double[numSymbols];
            double[] emitProbs2 = emitProbs(strArr[i2]);
            for (int i3 = 0; i3 < numSymbols; i3++) {
                double[] dArr2 = new double[numSymbols];
                r0[i2][i3] = dArr2;
                for (int i4 = 0; i4 < numSymbols; i4++) {
                    dArr2[i4] = this.mHmm.transitProb(i3, i4) * emitProbs2[i4];
                }
            }
        }
        double[] dArr3 = new double[numSymbols];
        for (int i5 = 0; i5 < numSymbols; i5++) {
            dArr3[i5] = this.mHmm.endProb(i5);
        }
        return new TagWordLattice(strArr, this.mHmm.stateSymbolTable(), dArr, dArr3, r0);
    }

    /** Returns the single best tag sequence for the tokens, or an empty array for empty input. */
    String[] firstBest(String[] strArr) {
        if (strArr.length == 0) {
            return Strings.EMPTY_STRING_ARRAY;
        }
        Viterbi decoded = new Viterbi(strArr);
        return decoded.bestStates();
    }

    /**
     * Returns an unbounded best-first iterator over tag sequences with joint
     * log2 scores.
     *
     * @param strArr tokens to tag
     * @return n-best iterator; for empty input, a single empty sequence scored 0.0
     */
    Iterator<ScoredObject<String[]>> nBest(String[] strArr) {
        return nBest(strArr, Integer.MAX_VALUE);
    }

    /**
     * Returns a best-first iterator over tag sequences with joint log2 scores,
     * using a search queue bounded by {@code i}.
     *
     * @param strArr tokens to tag
     * @param i maximum size of the search priority queue
     * @return n-best iterator; for empty input, a single empty sequence scored 0.0
     */
    Iterator<ScoredObject<String[]>> nBest(String[] strArr, int i) {
        if (strArr.length == 0) {
            // Parameterized, not raw, so no unchecked conversion is needed.
            ScoredObject<String[]> emptyTagging =
                new ScoredObject<String[]>(Strings.EMPTY_STRING_ARRAY, 0.0d);
            return Iterators.singleton(emptyTagging);
        }
        return new NBestIterator(new Viterbi(strArr), i);
    }

    /**
     * Returns the unbounded n-best iterator with each score reduced by the
     * lattice's {@code log2Total()}.
     */
    Iterator<ScoredObject<String[]>> nBestConditional(String[] strArr) {
        // Same evaluation order as before: n-best search first, then the lattice.
        Iterator<ScoredObject<String[]>> joint = nBest(strArr);
        double log2Total = lattice(strArr).log2Total();
        return new JointIterator(joint, log2Total);
    }

    /** Returns the first-best tagging of the tokens. */
    @Override
    public Tagging<String> tag(List<String> list) {
        String[] tokens = list.toArray(Strings.EMPTY_STRING_ARRAY);
        List<String> tokenList = Arrays.asList(tokens);
        List<String> tagList = Arrays.asList(firstBest(tokens));
        return new Tagging<>(tokenList, tagList);
    }

    /** Returns up to {@code i} taggings in best-first order, scored by joint log2 probability. */
    @Override
    public Iterator<ScoredTagging<String>> tagNBest(List<String> list, int i) {
        String[] tokens = list.toArray(Strings.EMPTY_STRING_ARRAY);
        Iterator<ScoredObject<String[]>> results = nBest(tokens, i);
        return new TaggingIteratorAdapter(list, results, i);
    }

    /** Returns up to {@code i} taggings in best-first order, with scores shifted by the lattice log2 total. */
    @Override
    public Iterator<ScoredTagging<String>> tagNBestConditional(List<String> list, int i) {
        String[] tokens = list.toArray(Strings.EMPTY_STRING_ARRAY);
        Iterator<ScoredObject<String[]>> results = nBestConditional(tokens);
        return new TaggingIteratorAdapter(list, results, i);
    }

    /** Returns the marginal tag lattice for the tokens. */
    @Override
    public TagLattice<String> tagMarginal(List<String> list) {
        String[] tokens = list.toArray(Strings.EMPTY_STRING_ARRAY);
        return lattice(tokens);
    }

    /**
     * Writes into {@code iArr} the indices of entries of {@code dArr} that lie
     * within {@code d} of the maximum entry, in increasing order, followed by a
     * -1 terminator.
     *
     * <p>{@code iArr} must have room for every kept index plus the terminator.
     * When {@code d} is positive infinity, negative-infinity entries are
     * excluded, because their sum with the beam is NaN and fails the comparison.
     *
     * @param dArr log2 scores of candidate sources; must be non-empty
     * @param iArr output buffer of kept indices
     * @param d additive log2 beam width
     */
    void unprunedSources(double[] dArr, int[] iArr, double d) {
        double best = dArr[0];
        for (int k = 1; k < dArr.length; k++) {
            if (dArr[k] > best) {
                best = dArr[k];
            }
        }
        int kept = 0;
        for (int k = 0; k < dArr.length; k++) {
            boolean withinBeam = dArr[k] + d >= best;
            if (withinBeam) {
                iArr[kept++] = k;
            }
        }
        iArr[kept] = -1;
    }
}
