package com.aliasi.lm;

import com.aliasi.lm.LanguageModel;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;

/* loaded from: input_file:com/aliasi/lm/CompiledTokenizedLM.class */
/**
 * A read-only tokenized n-gram language model that is reconstructed from a
 * serialized form supplied through an {@link ObjectInput}.
 *
 * <p>The token n-grams are stored as a trie flattened into parallel arrays:
 * node {@code 0} is the root, {@code mTokens[n]} and {@code mLogProbs[n]} hold
 * the token id and log probability of node {@code n}, and the children of an
 * internal node {@code n} occupy the index range
 * {@code mFirstChild[n] .. mFirstChild[n + 1] - 1}. Lookups binary-search that
 * range, so each sibling range is expected to be sorted by token id
 * (NOTE(review): the ordering is produced by the writer, which is not visible
 * here — confirm).
 *
 * <p>All fields are final and no method mutates the arrays after construction.
 */
public class CompiledTokenizedLM implements LanguageModel.Sequence, LanguageModel.Tokenized {

    /** Value returned by the trie lookups when no matching node exists. */
    private static final int NOT_FOUND = -1;

    /**
     * Token id treated as "unknown token" by the back-off computation.
     * NOTE(review): presumably the value {@code SymbolTable.symbolToID} returns
     * for out-of-vocabulary tokens — confirm against the symbol table contract.
     */
    private static final int UNKNOWN_TOKEN_ID = -1;

    /** Token id placed before the first and after the last token of a sequence. */
    private static final int BOUNDARY_TOKEN_ID = -2;

    private final TokenizerFactory mTokenizerFactory;
    private final SymbolTable mSymbolTable;
    private final LanguageModel.Sequence mUnknownTokenModel;
    private final LanguageModel.Sequence mWhitespaceModel;
    private final int mMaxNGram;

    /** Token id of each trie node. */
    private final int[] mTokens;
    /** Log probability of each trie node given its parent context. */
    private final float[] mLogProbs;
    /** Back-off penalty, log(1 - lambda), for each internal node; NaN marks "no daughters". */
    private final float[] mLogLambdas;
    /** Index of the first child of each internal node, plus one trailing sentinel. */
    private final int[] mFirstChild;

    /**
     * Reads a compiled model from the given input.
     *
     * <p>The stream layout consumed is: a UTF string naming the tokenizer
     * factory class (or the empty string, in which case the factory itself is
     * read as an object), the symbol table, the unknown-token model, the
     * whitespace model, the maximum n-gram length, the node count, the index of
     * the last internal node, and then one record per node (token id and log
     * probability, followed by log lambda and first-child index for internal
     * nodes only).
     *
     * <p>The decompiler note in the original source states that this
     * constructor was package-private before access modifiers were changed.
     *
     * @param objectInput source of the serialized model
     * @throws IOException if the underlying input fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved,
     *     or the named tokenizer factory cannot be instantiated reflectively
     */
    public CompiledTokenizedLM(ObjectInput objectInput) throws IOException, ClassNotFoundException {
        String tokenizerClassName = objectInput.readUTF();
        this.mTokenizerFactory = tokenizerClassName.equals(Strings.EMPTY_STRING)
                ? (TokenizerFactory) objectInput.readObject()
                : newTokenizerFactory(tokenizerClassName);

        this.mSymbolTable = (SymbolTable) objectInput.readObject();
        this.mUnknownTokenModel = (LanguageModel.Sequence) objectInput.readObject();
        this.mWhitespaceModel = (LanguageModel.Sequence) objectInput.readObject();
        this.mMaxNGram = objectInput.readInt();

        int numNodes = objectInput.readInt();
        int lastInternalNodeIndex = objectInput.readInt();

        this.mTokens = new int[numNodes];
        this.mLogProbs = new float[numNodes];
        this.mLogLambdas = new float[lastInternalNodeIndex + 1];
        // One extra slot so that mFirstChild[n + 1] is valid for the last
        // internal node; it marks the end of that node's child range.
        this.mFirstChild = new int[lastInternalNodeIndex + 2];
        this.mFirstChild[this.mFirstChild.length - 1] = numNodes;

        for (int node = 0; node < numNodes; node++) {
            this.mTokens[node] = objectInput.readInt();
            this.mLogProbs[node] = objectInput.readFloat();
            if (node <= lastInternalNodeIndex) {
                this.mLogLambdas[node] = objectInput.readFloat();
                this.mFirstChild[node] = objectInput.readInt();
            }
        }
    }

    /**
     * Instantiates the named tokenizer factory through its public no-argument
     * constructor.
     *
     * @param className fully qualified class name read from the stream
     * @return a new factory instance
     * @throws ClassNotFoundException if the class is missing or cannot be
     *     constructed; reflective failures are wrapped with their cause
     */
    private static TokenizerFactory newTokenizerFactory(String className)
            throws ClassNotFoundException {
        String message = "Constructing " + className;
        try {
            return (TokenizerFactory) Class.forName(className).getConstructor().newInstance();
        } catch (IllegalAccessException e) {
            throw new ClassNotFoundException(message, e);
        } catch (InstantiationException e) {
            throw new ClassNotFoundException(message, e);
        } catch (NoSuchMethodException e) {
            throw new ClassNotFoundException(message, e);
        } catch (InvocationTargetException e) {
            throw new ClassNotFoundException(message, e);
        }
    }

    /**
     * Returns a multi-line debugging dump of the model: its components followed
     * by one tab-separated row per trie node.
     *
     * @return textual dump of the model
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Tokenizer Class Name=").append(this.mTokenizerFactory).append('\n');
        sb.append("Symbol Table=").append(this.mSymbolTable).append('\n');
        sb.append("Unknown Token Model=").append(this.mUnknownTokenModel).append('\n');
        sb.append("Whitespace Model=").append(this.mWhitespaceModel).append('\n');
        sb.append("Token Trie").append('\n');
        sb.append("Nodes=").append(this.mTokens.length)
                .append(" Internal=").append(this.mLogLambdas.length).append('\n');
        sb.append("Index Tok logP firstDtr log(1-L)").append('\n');
        for (int node = 0; node < this.mTokens.length; node++) {
            sb.append(node).append('\t');
            sb.append(this.mTokens[node]).append('\t');
            sb.append(this.mLogProbs[node]);
            // Only internal nodes carry first-child and lambda columns.
            if (node < this.mFirstChild.length) {
                sb.append('\t').append(this.mFirstChild[node]);
                if (node < this.mLogLambdas.length) {
                    sb.append('\t').append(this.mLogLambdas[node]);
                }
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Returns the log (base 2) estimate of the whole character sequence.
     *
     * @param charSequence text to estimate
     * @return log2 estimate of the text
     */
    @Override // com.aliasi.lm.LanguageModel
    public double log2Estimate(CharSequence charSequence) {
        char[] cs = Strings.toCharArray(charSequence);
        return log2Estimate(cs, 0, cs.length);
    }

    /**
     * Returns the log (base 2) estimate of the specified character slice.
     *
     * <p>The estimate is the sum of: the whitespace model's estimate for every
     * whitespace run the tokenizer reports, the unknown-token model's estimate
     * for every token whose symbol id is negative, and the conditional n-gram
     * estimate of every token and of the closing boundary marker.
     *
     * @param cs underlying characters
     * @param start index of the first character of the slice
     * @param end index one past the last character of the slice
     * @return log2 estimate of the slice
     */
    @Override // com.aliasi.lm.LanguageModel
    public double log2Estimate(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        double estimate = 0.0d;

        // Collect tokens; whitespace is scored as it is encountered, including
        // the whitespace that follows the final token.
        Tokenizer tokenizer = this.mTokenizerFactory.tokenizer(cs, start, end - start);
        ArrayList<String> tokens = new ArrayList<String>();
        while (true) {
            estimate += this.mWhitespaceModel.log2Estimate(tokenizer.nextWhitespace());
            String token = tokenizer.nextToken();
            if (token == null) {
                break;
            }
            tokens.add(token);
        }

        // Map tokens to ids, framed by boundary markers on both sides.
        int[] tokenIds = new int[tokens.size() + 2];
        tokenIds[0] = BOUNDARY_TOKEN_ID;
        tokenIds[tokenIds.length - 1] = BOUNDARY_TOKEN_ID;
        int position = 1;
        for (String token : tokens) {
            int id = this.mSymbolTable.symbolToID(token);
            tokenIds[position++] = id;
            if (id < 0) {
                estimate += this.mUnknownTokenModel.log2Estimate(token);
            }
        }

        // Score every position after the opening boundary, which serves only
        // as context; the closing boundary is scored as the last outcome.
        for (int outcomeEnd = 2; outcomeEnd <= tokenIds.length; outcomeEnd++) {
            estimate += conditionalTokenEstimate(tokenIds, 0, outcomeEnd);
        }
        return estimate;
    }

    /**
     * Estimates the token at index {@code end - 1} given the preceding tokens,
     * backing off from the longest usable context to the empty context.
     *
     * <p>For each context length, longest first: contexts absent from the trie
     * are skipped; if the token is found under the context its log probability
     * is added and the total returned; otherwise the context's back-off
     * penalty is added (when the context has daughters) and the next shorter
     * context is tried.
     *
     * @param tokenIds token ids of the sequence
     * @param start lowest index that may be used as context
     * @param end one past the index of the token being estimated
     * @return accumulated log2 estimate for that token
     */
    private double conditionalTokenEstimate(int[] tokenIds, int start, int end) {
        double estimate = 0.0d;
        int contextEnd = end - 1;
        int tokenId = tokenIds[contextEnd];
        int maxContextLength = Math.min(contextEnd - start, this.mMaxNGram - 1);
        for (int contextLength = maxContextLength; contextLength >= 0; contextLength--) {
            int contextIndex = getIndex(tokenIds, contextEnd - contextLength, contextEnd);
            if (contextIndex == NOT_FOUND) {
                continue;
            }
            if (tokenId != UNKNOWN_TOKEN_ID) {
                int outcomeIndex = getIndex(contextIndex, tokenId);
                if (outcomeIndex != NOT_FOUND) {
                    return estimate + this.mLogProbs[outcomeIndex];
                }
            }
            if (hasDtrs(contextIndex)) {
                estimate += this.mLogLambdas[contextIndex];
            }
        }
        return estimate;
    }

    /**
     * Returns the log (base 2) probability of the tokens in the given slice,
     * using only the n-gram trie; no whitespace or unknown-token model is
     * consulted here.
     *
     * @param strArr tokens
     * @param i index of the first token to score
     * @param i2 index one past the last token to score
     * @return sum of the conditional log2 estimates of the tokens in the slice
     */
    @Override // com.aliasi.lm.LanguageModel.Tokenized
    public double tokenLog2Probability(String[] strArr, int i, int i2) {
        int[] tokenIds = new int[strArr.length];
        for (int k = 0; k < strArr.length; k++) {
            tokenIds[k] = this.mSymbolTable.symbolToID(strArr[k]);
        }
        double sum = 0.0d;
        for (int outcomeEnd = i + 1; outcomeEnd <= i2; outcomeEnd++) {
            sum += conditionalTokenEstimate(tokenIds, i, outcomeEnd);
        }
        return sum;
    }

    /**
     * Returns the probability of the tokens in the given slice, computed as two
     * raised to {@link #tokenLog2Probability(String[], int, int)}.
     *
     * @param strArr tokens
     * @param i index of the first token to score
     * @param i2 index one past the last token to score
     * @return probability of the token slice
     */
    @Override // com.aliasi.lm.LanguageModel.Tokenized
    public double tokenProbability(String[] strArr, int i, int i2) {
        return Math.pow(2.0d, tokenLog2Probability(strArr, i, i2));
    }

    /**
     * Reports whether the node is an internal node with a usable back-off
     * weight, i.e. it lies inside the lambda table and its entry is not NaN.
     *
     * @param i node index
     * @return {@code true} if the node has daughters
     */
    boolean hasDtrs(int i) {
        return i < this.mLogLambdas.length && !Float.isNaN(this.mLogLambdas[i]);
    }

    /**
     * Finds the child of a node that carries the given token id, by binary
     * search over the node's child range.
     *
     * @param parentIndex index of the parent node
     * @param tokenId token id to look for among the children
     * @return index of the matching child, or {@code -1} if the parent is not
     *     an internal node or has no such child
     */
    private int getIndex(int parentIndex, int tokenId) {
        if (parentIndex + 1 >= this.mFirstChild.length) {
            return NOT_FOUND;
        }
        int lo = this.mFirstChild[parentIndex];
        int hi = this.mFirstChild[parentIndex + 1] - 1;
        while (lo <= hi) {
            // Unsigned shift avoids overflow of lo + hi on very large tries.
            int mid = (lo + hi) >>> 1;
            int midToken = this.mTokens[mid];
            if (midToken == tokenId) {
                return mid;
            }
            if (midToken < tokenId) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return NOT_FOUND;
    }

    /**
     * Walks the trie from the root along {@code tokenIds[start..end-1]}.
     *
     * @param tokenIds token ids of the sequence
     * @param start index of the first token of the path
     * @param end one past the index of the last token of the path
     * @return index of the node reached (the root, {@code 0}, for an empty
     *     path), or {@code -1} if the path is not in the trie
     */
    private int getIndex(int[] tokenIds, int start, int end) {
        int nodeIndex = 0;
        for (int k = start; k < end; k++) {
            nodeIndex = getIndex(nodeIndex, tokenIds[k]);
            if (nodeIndex == NOT_FOUND) {
                return NOT_FOUND;
            }
        }
        return nodeIndex;
    }
}
