package bz.turtle.readable;

import bz.turtle.readable.input.Feature;
import bz.turtle.readable.input.FeatureInterface;
import bz.turtle.readable.input.Namespace;
import bz.turtle.readable.input.PredictionRequest;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.DoubleUnaryOperator;
import java.util.zip.GZIPInputStream;

/* loaded from: input_file:bz/turtle/readable/ReadableModel.class */
/**
 * Scores examples against a Vowpal Wabbit model saved in its human-readable text format.
 *
 * <p>Loading parses the header (bits, label range, options line) and the {@code index:weight}
 * table. Prediction hashes every feature into a bucket of that table and sums the weighted terms.
 * Supported terms are linear features, {@code --quadratic} pairs, {@code --cubic} triples, the
 * intercept, and {@code --oaa} multi-class layouts. The result is then clipped to the label range
 * and passed through the configured link function. Models using ngrams or skips are rejected.
 *
 * <p>NOTE(review): {@code predict} mutates the request it receives. It sorts namespaces and
 * features, and it caches hashes on them. The class uses no synchronization, so treat instances as
 * not thread-safe unless verified otherwise.
 */
public class ReadableModel {
    /** Hash used for the bias ("Constant") feature; presumably VW's {@code constant} value — confirm. */
    private static final int INTERCEPT = 11650396;

    /** FNV prime used to combine feature hashes when building interaction terms. */
    private static final int FNV_PRIME = 16777619;

    /** Default feature comparator: every pair compares equal, so a stable sort keeps the order. */
    private static final Comparator<FeatureInterface> NOOP_COMPARATOR = (left, right) -> 0;

    private static final Set<Character> EMPTY_SET = Collections.emptySet();

    /** Orders namespaces char-by-char by name, shorter names first on a shared prefix. */
    private static final Comparator<Namespace> NAMESPACE_ORDER =
            (left, right) -> compareChars(left.namespace, right.namespace);

    private static final DoubleUnaryOperator IDENTITY = DoubleUnaryOperator.identity();
    private static final DoubleUnaryOperator LOGISTIC = d -> 1.0d / (1.0d + Math.exp(-d));
    private static final DoubleUnaryOperator GLF1 = d -> (2.0d / (1.0d + Math.exp(-d))) - 1.0d;
    private static final DoubleUnaryOperator POISSON = Math::exp;

    private boolean hasIntercept = true;
    private float[] weights;
    private int bits;
    private int oaa = 1;
    private int mask;
    private int multiClassBits;
    private int seed;
    private boolean hashAll;
    private float minLabel;
    private float maxLabel;
    private int ngram;
    private int skip;
    /** First namespace char -> set of second namespace chars for {@code --quadratic}. */
    private final Map<Character, Set<Character>> quadratic = new HashMap<>();
    /** First char -> second char -> set of third chars for {@code --cubic}. */
    private final Map<Character, Map<Character, Set<Character>>> cubic = new HashMap<>();
    private boolean quadraticAnyToAny;
    private DoubleUnaryOperator link = IDENTITY;

    // ------------------------------------------------------------------ construction

    /**
     * Loads a model from a file, or from a directory containing {@code readable_model.txt[.gz]}.
     *
     * <p>If the directory also holds {@code test.txt[.gz]} and {@code predictions.txt[.gz]}, every
     * test line is scored and compared with the expected prediction (see {@link #makeSureItWorks}).
     *
     * @param file model file or model directory
     * @param hasIntercept whether the bias weight is added to every prediction
     * @param verifyWithProbabilities value of {@code probabilities} used for the self-check
     * @throws IOException if a file cannot be read or is malformed
     * @throws UnsupportedOperationException if the model uses unsupported options
     * @throws IllegalStateException if the self-check finds a mismatching prediction
     */
    public ReadableModel(File file, boolean hasIntercept, boolean verifyWithProbabilities)
            throws IOException, UnsupportedOperationException {
        this.hasIntercept = hasIntercept;
        if (!file.isDirectory()) {
            loadReadableModel(file);
            return;
        }
        File modelFile = findFileWithExt(file, "readable_model.txt");
        File testFile = findFileWithExt(file, "test.txt");
        File predictionFile = findFileWithExt(file, "predictions.txt");
        loadReadableModel(modelFile);
        if (testFile.exists() && predictionFile.exists()) {
            makeSureItWorks(testFile, predictionFile, verifyWithProbabilities);
        }
    }

    /** Loads a model from a stream, with the intercept enabled. The stream is closed. */
    public ReadableModel(InputStream inputStream) throws IOException, UnsupportedOperationException {
        this(inputStream, true);
    }

    /**
     * Loads a model from a stream. The stream is closed when loading finishes.
     *
     * @param hasIntercept whether the bias weight is added to every prediction
     */
    public ReadableModel(InputStream inputStream, boolean hasIntercept)
            throws IOException, UnsupportedOperationException {
        this.hasIntercept = hasIntercept;
        loadReadableModel(inputStream);
    }

    /** Loads a model from a {@code file:} URL (see {@link #ReadableModel(File, boolean, boolean)}). */
    public ReadableModel(URL url, boolean hasIntercept) throws IOException, UnsupportedOperationException {
        this(toFile(url), hasIntercept);
    }

    /** Loads a model file or directory without the prediction self-check's probability mode. */
    public ReadableModel(File file, boolean hasIntercept) throws IOException, UnsupportedOperationException {
        this(file, hasIntercept, false);
    }

    /** Loads a model from a {@code file:} URL with the intercept enabled. */
    public ReadableModel(URL url) throws IOException, UnsupportedOperationException {
        this(toFile(url), true, false);
    }

    /** Loads a model file or directory with the intercept enabled. */
    public ReadableModel(File file) throws IOException, UnsupportedOperationException {
        this(file, true, false);
    }

    /**
     * Converts a URL to a File, decoding percent-escapes (e.g. {@code %20}) for {@code file:} URLs.
     * Falls back to the raw {@link URL#getFile()} path when the URL cannot be turned into a file URI.
     */
    private static File toFile(URL url) {
        if ("file".equalsIgnoreCase(url.getProtocol())) {
            try {
                return new File(url.toURI());
            } catch (URISyntaxException | IllegalArgumentException ignored) {
                // Not a well-formed hierarchical file URI; use the raw path as before.
            }
        }
        return new File(url.getFile());
    }

    // ------------------------------------------------------------------ loading

    /** Loads a model file; files ending in {@code .gz} are decompressed transparently. */
    public void loadReadableModel(File file) throws IOException, UnsupportedOperationException {
        try (InputStream in = getReaderForExt(file)) {
            loadReadableModel(in);
        }
    }

    /**
     * Parses a readable model from {@code inputStream} (UTF-8) and closes the stream.
     *
     * <p>Lines before the {@code ":0"} line are header lines. Lines after it are
     * {@code index:weight} entries. Any previously loaded model state is discarded first.
     *
     * @throws IOException on read failure or a malformed weight line
     * @throws UnsupportedOperationException if no {@code bits:} header was seen, or an unsupported
     *     option (ngram, skip, any-to-any cubic, unknown link, ...) is present
     */
    public void loadReadableModel(InputStream inputStream) throws IOException, UnsupportedOperationException {
        resetModelState();
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            boolean inHeader = true;
            String line;
            while ((line = reader.readLine()) != null) {
                if (!inHeader) {
                    parseWeightLine(line);
                } else if (line.equals(":0")) {
                    inHeader = false;
                } else {
                    parseHeaderLine(line);
                }
            }
        }
        this.mask = (1 << this.bits) - 1;
        if (this.weights == null) {
            throw new UnsupportedOperationException("failed to load the model, did not see 'bits:' line");
        }
    }

    /** Restores every model-derived field to its default so a reload starts from a clean slate. */
    private void resetModelState() {
        this.weights = null;
        this.bits = 0;
        this.mask = 0;
        this.oaa = 1;
        this.multiClassBits = 0;
        this.seed = 0;
        this.hashAll = false;
        this.minLabel = 0.0f;
        this.maxLabel = 0.0f;
        this.ngram = 0;
        this.skip = 0;
        this.quadratic.clear();
        this.cubic.clear();
        this.quadraticAnyToAny = false;
        this.link = IDENTITY;
    }

    /** Interprets one header line (before the {@code ":0"} separator). */
    private void parseHeaderLine(String line) {
        if (line.contains("options")) {
            // Handled exclusively: option text such as "--ngram 2" must not reach the
            // "ngram"/"skip" checks below, which expect a bare integer after the colon.
            int colon = line.indexOf(':');
            if (colon >= 0) {
                extractOptions(line.substring(colon + 1), this::applyOption);
            }
            return;
        }
        if (line.contains("bits:")) {
            this.bits = Integer.parseInt(getSecondValue(line));
            this.weights = new float[1 << this.bits];
        }
        if (line.contains("Min label")) {
            this.minLabel = Float.parseFloat(getSecondValue(line));
        }
        if (line.contains("Max label")) {
            this.maxLabel = Float.parseFloat(getSecondValue(line));
        }
        if (line.contains("ngram")) {
            this.ngram = intOrZero(getSecondValue(line));
            if (this.ngram != 0) {
                throw new UnsupportedOperationException("ngrams are not supported yet");
            }
        }
        if (line.contains("skip")) {
            this.skip = intOrZero(getSecondValue(line));
            if (this.skip != 0) {
                throw new UnsupportedOperationException("skip is not supported yet");
            }
        }
    }

    /** Stores one {@code index:weight} entry; blank lines are ignored. */
    private void parseWeightLine(String line) throws IOException {
        if (line.trim().isEmpty()) {
            return;
        }
        if (this.weights == null) {
            throw new UnsupportedOperationException("failed to load the model, did not see 'bits:' line");
        }
        String[] parts = line.split(":");
        if (parts.length < 2) {
            throw new IOException("malformed weight line: " + line);
        }
        this.weights[Integer.parseInt(parts[0].trim())] = Float.parseFloat(parts[1]);
    }

    /** Applies one {@code name value} pair from the header's options line; unknown names are ignored. */
    private void applyOption(String name, String value) {
        switch (name) {
            case "--oaa":
                setOaa(value);
                break;
            case "--cubic":
                addCubic(value);
                break;
            case "--link":
                this.link = linkFor(value);
                break;
            case "--hash_seed":
                this.seed = Integer.parseInt(value);
                break;
            case "--hash":
                if (value.equals("all")) {
                    this.hashAll = true;
                }
                break;
            case "--quadratic":
                addQuadratic(value);
                break;
            default:
                break;
        }
    }

    /** Sets the class count and the number of low bucket bits reserved for the class index. */
    private void setOaa(String value) {
        int classes = Integer.parseInt(value);
        if (classes < 1) {
            throw new UnsupportedOperationException("--oaa must be at least 1, got " + value);
        }
        this.oaa = classes;
        // Bits needed to represent (oaa - 1): 0 for a single class.
        this.multiClassBits = 32 - Integer.numberOfLeadingZeros(classes - 1);
    }

    /** Registers a {@code --cubic abc} interaction (first three chars are used). */
    private void addCubic(String value) {
        if (value.contains(":")) {
            throw new UnsupportedOperationException("Any to any cubic interactions are not yet supported");
        }
        if (value.length() < 3) {
            throw new UnsupportedOperationException("--cubic expects three namespace characters, got " + value);
        }
        char first = value.charAt(0);
        char second = value.charAt(1);
        char third = value.charAt(2);
        if (first == second || first == third || second == third) {
            throw new UnsupportedOperationException("Cubic interactions within the same namespace are not yet supported");
        }
        this.cubic.computeIfAbsent(first, k -> new HashMap<>())
                .computeIfAbsent(second, k -> new HashSet<>())
                .add(third);
    }

    /** Registers a {@code --quadratic ab} interaction, or any-to-any for {@code "::"}. */
    private void addQuadratic(String value) {
        if (value.equals("::")) {
            this.quadraticAnyToAny = true;
            return;
        }
        if (value.length() < 2) {
            throw new UnsupportedOperationException("--quadratic expects two namespace characters, got " + value);
        }
        this.quadratic.computeIfAbsent(value.charAt(0), k -> new HashSet<>()).add(value.charAt(1));
    }

    /** Maps a {@code --link} name to its function. */
    private static DoubleUnaryOperator linkFor(String value) {
        switch (value) {
            case "logistic":
                return LOGISTIC;
            case "identity":
                return IDENTITY;
            case "poisson":
                return POISSON;
            case "glf1":
                return GLF1;
            default:
                throw new UnsupportedOperationException("only --link identity, logistic, glf1, or poisson are supported " + value);
        }
    }

    /**
     * Splits an options string into (name, value) pairs and feeds them to {@code biConsumer}.
     *
     * <p>Accepts {@code name=value} and {@code name value}. A name followed by another option
     * name, or by nothing, is a value-less flag and is reported with an empty value.
     */
    private void extractOptions(String str, BiConsumer<String, String> biConsumer) {
        String trimmed = str.trim();
        if (trimmed.isEmpty()) {
            return;
        }
        String[] tokens = trimmed.split("\\s+");
        int i = 0;
        while (i < tokens.length) {
            String token = tokens[i];
            int eq = token.indexOf('=');
            if (eq >= 0) {
                biConsumer.accept(token.substring(0, eq), token.substring(eq + 1));
                i++;
            } else if (i + 1 < tokens.length && !looksLikeOptionName(tokens[i + 1])) {
                biConsumer.accept(token, tokens[i + 1]);
                i += 2;
            } else {
                biConsumer.accept(token, "");
                i++;
            }
        }
    }

    /** True for tokens such as {@code --oaa} or {@code -b}, but not negative numbers like {@code -5}. */
    private static boolean looksLikeOptionName(String token) {
        if (token.length() < 2 || token.charAt(0) != '-') {
            return false;
        }
        char second = token.charAt(1);
        return !(Character.isDigit(second) || second == '.');
    }

    /** Returns the trimmed text between the first and second ':' of {@code str}, or "" if none. */
    private String getSecondValue(String str) {
        String[] split = str.split(":");
        return split.length == 1 ? "" : split[1].trim();
    }

    /** Parses {@code str} as an int, treating the empty string as 0. */
    private int intOrZero(String str) {
        return str.isEmpty() ? 0 : Integer.parseInt(str);
    }

    /** Opens {@code file}, gunzipping it when its path ends in {@code .gz}. */
    private InputStream getReaderForExt(File file) throws IOException {
        InputStream raw = new FileInputStream(file);
        if (!file.toString().endsWith(".gz")) {
            return raw;
        }
        try {
            return new GZIPInputStream(raw);
        } catch (IOException e) {
            // A bad gzip header would otherwise leak the underlying file handle.
            try {
                raw.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /** Returns {@code dir/name.gz} when it exists, otherwise {@code dir/name}. */
    private File findFileWithExt(File file, String str) {
        File gzipped = Paths.get(file.toString(), str + ".gz").toFile();
        return gzipped.exists() ? gzipped : Paths.get(file.toString(), str).toFile();
    }

    // ------------------------------------------------------------------ self-check

    /**
     * Scores each line of a VW-format test stream and compares it with the matching line of a
     * predictions stream. Both streams are read as UTF-8 and closed.
     *
     * <p>Prediction lines are either a single value or space-separated {@code class:value} pairs
     * (1-based class). Values must agree within 0.01. Comparison stops at the shorter stream.
     *
     * @param probabilities value of {@code PredictionRequest.probabilities} for each scored line
     * @throws IllegalStateException on the first mismatch, NaN, or out-of-range class index
     */
    public void makeSureItWorks(InputStream testStream, InputStream predictionStream, boolean probabilities)
            throws IOException, IllegalStateException {
        try (BufferedReader tests =
                        new BufferedReader(new InputStreamReader(testStream, StandardCharsets.UTF_8));
                BufferedReader predictions =
                        new BufferedReader(new InputStreamReader(predictionStream, StandardCharsets.UTF_8))) {
            int lineNumber = 0;
            while (true) {
                String testLine = tests.readLine();
                if (testLine == null) {
                    break;
                }
                String predictionLine = predictions.readLine();
                if (predictionLine == null) {
                    break;
                }
                float[] ours = predict(parseTestLine(testLine, probabilities));
                checkPrediction(lineNumber, testLine, predictionLine, ours);
                lineNumber++;
            }
        }
    }

    /** File variant of {@link #makeSureItWorks(InputStream, InputStream, boolean)}; {@code .gz} aware. */
    public void makeSureItWorks(File file, File file2, boolean z) throws IOException, IllegalStateException {
        try (InputStream tests = getReaderForExt(file);
                InputStream predictions = getReaderForExt(file2)) {
            makeSureItWorks(tests, predictions, z);
        }
    }

    /**
     * Builds a request from a VW text line: tokens before the first {@code |} are ignored,
     * {@code |name} starts a namespace, and {@code feature[:value]} tokens (default value 1) are
     * added to the most recent namespace.
     */
    private static PredictionRequest parseTestLine(String line, boolean probabilities) {
        PredictionRequest request = new PredictionRequest();
        request.probabilities = probabilities;
        Namespace current = null;
        for (String token : line.split("\\s+")) {
            if (token.startsWith("|")) {
                current = new Namespace(token.substring(1));
                request.namespaces.add(current);
            } else if (current != null) {
                String[] parts = token.split(":");
                String name = parts.length == 0 ? token : parts[0];
                float value = parts.length > 1 ? Float.parseFloat(parts[1]) : 1.0f;
                current.features.add(new Feature(name, value));
            }
        }
        return request;
    }

    /** Compares our scores with one line of the expected-predictions file. */
    private static void checkPrediction(int lineNumber, String testLine, String predictionLine, float[] ours) {
        if (predictionLine.contains(":")) {
            for (String pair : predictionLine.trim().split("\\s+")) {
                String[] kv = pair.split(":");
                int index = Integer.parseInt(kv[0]) - 1;
                if (index < 0 || index >= ours.length) {
                    throw new IllegalStateException(String.format(
                            "line: %d index %d is outside the %d predicted classes\npred line: %s\ntest line: %s",
                            lineNumber, index, ours.length, predictionLine, testLine));
                }
                assertClose(lineNumber, index, Float.parseFloat(kv[1]), ours, predictionLine, testLine);
            }
        } else {
            // Only the first token is the prediction; VW may append a tag after it.
            float expected = Float.parseFloat(predictionLine.trim().split("\\s+")[0]);
            assertClose(lineNumber, 0, expected, ours, predictionLine, testLine);
        }
    }

    /** Throws unless {@code ours[index]} is within 0.01 of {@code expected}; NaN never matches. */
    private static void assertClose(int lineNumber, int index, float expected, float[] ours,
            String predictionLine, String testLine) {
        if (!(Math.abs(expected - ours[index]) <= 0.01d)) {
            throw new IllegalStateException(String.format(
                    "line: %d index %d, prediction: %f, ourPrediction: %f \noaa %s,\npred line: %s\ntest line: %s",
                    lineNumber, index, expected, ours[index], Arrays.toString(ours), predictionLine, testLine));
        }
    }

    // ------------------------------------------------------------------ hashing

    /** Weight-table index for feature hash {@code i} and class offset {@code i2}. */
    private int getBucket(int i, int i2) {
        return ((i << this.multiClassBits) | i2) & this.mask;
    }

    /**
     * Hashes a feature within a namespace whose hash is {@code i}. Integer-named features map to
     * {@code name + i} unless {@code --hash all} is set; others use murmur seeded with {@code i}.
     */
    public int featureHashOf(int i, FeatureInterface featureInterface) {
        if (!this.hashAll && featureInterface.hasIntegerName()) {
            return featureInterface.getIntegerName() + i;
        }
        return VWMurmur.hash(featureInterface.getBytes(), i);
    }

    /**
     * Hashes a namespace name with seed {@code i}. All-digit names (without {@code --hash all})
     * hash to their integer value plus the seed, mirroring {@link #featureHashOf}. Other names use
     * murmur.
     */
    public int namespaceHashOf(Namespace namespace, int i) {
        StringBuilder name = namespace.namespace;
        if (this.hashAll) {
            return VWMurmur.hash(name, i);
        }
        int value = 0;
        for (int k = 0; k < name.length(); k++) {
            char c = name.charAt(k);
            if (c < '0' || c > '9') {
                return VWMurmur.hash(name, i);
            }
            value = 10 * value + (c - '0');
        }
        // VW's hashstring returns (integer value + seed) for all-digit strings.
        return value + i;
    }

    // ------------------------------------------------------------------ prediction

    /** Allocates a score array sized for this model's class count. */
    public float[] getReusableFloatArray() {
        return new float[this.oaa];
    }

    /** Scores {@code predictionRequest} into a new array. */
    public float[] predict(PredictionRequest predictionRequest) {
        return predict(predictionRequest, null);
    }

    /** Scores into a new array, recording per-term details into {@code explanation} when non-null. */
    public float[] predict(PredictionRequest predictionRequest, Explanation explanation) {
        return predict(predictionRequest, explanation, NOOP_COMPARATOR);
    }

    /** Scores into a new array, sorting each namespace's features with {@code comparator} first. */
    public float[] predict(PredictionRequest predictionRequest, Explanation explanation,
            Comparator<FeatureInterface> comparator) {
        float[] scores = getReusableFloatArray();
        predict(scores, predictionRequest, explanation, comparator);
        return scores;
    }

    /** Scores into {@code fArr}, which must hold at least {@code oaa} entries. */
    public void predict(float[] fArr, PredictionRequest predictionRequest, Explanation explanation) {
        predict(fArr, predictionRequest, explanation, NOOP_COMPARATOR);
    }

    /**
     * Scores {@code predictionRequest} into the first {@code oaa} slots of {@code fArr}.
     *
     * <p>Computes linear, quadratic, cubic and intercept terms. Raw sums are recorded into
     * {@code explanation.predictions} when given. Scores are then clipped to the label range and
     * passed through the link function. With {@code probabilities}, the logistic link is used
     * instead, and multi-class outputs are normalized to sum to one.
     */
    public void predict(float[] fArr, PredictionRequest predictionRequest, Explanation explanation,
            Comparator<FeatureInterface> comparator) {
        Arrays.fill(fArr, 0, this.oaa, 0.0f);
        addLinearTerms(fArr, predictionRequest, explanation, comparator);
        if (this.quadraticAnyToAny || !this.quadratic.isEmpty()) {
            addQuadraticTerms(fArr, predictionRequest, explanation);
        }
        if (!this.cubic.isEmpty()) {
            addCubicTerms(fArr, predictionRequest, explanation);
        }
        if (this.hasIntercept) {
            addWeightedTerm(fArr, INTERCEPT, 1.0f, explanation, "Constant");
        }
        if (explanation != null) {
            for (int k = 0; k < this.oaa; k++) {
                explanation.predictions.add(fArr[k]);
            }
        }
        clip(fArr);
        if (!predictionRequest.probabilities) {
            linkWith(fArr, this.link);
            return;
        }
        linkWith(fArr, LOGISTIC);
        if (this.oaa > 1) {
            normalize(fArr);
        }
    }

    /** Adds {@code value * weight} for every class bucket of {@code hash}, logging when explaining. */
    private void addWeightedTerm(float[] out, int hash, float value, Explanation explanation, String label) {
        for (int k = 0; k < this.oaa; k++) {
            int bucket = getBucket(hash, k);
            float weight = this.weights[bucket];
            if (explanation != null) {
                explanation.add(String.format("%s:%d:%d:%f", label, bucket, k + 1, weight));
                if (weight == 0.0f) {
                    explanation.missingFeatures.add(1.0f);
                }
                explanation.featuresLookedUp.add(1.0f);
            }
            out[k] += value * weight;
        }
    }

    /** Sorts features, caches namespace/feature hashes on the request, and adds linear terms. */
    private void addLinearTerms(float[] out, PredictionRequest request, Explanation explanation,
            Comparator<FeatureInterface> comparator) {
        for (Namespace namespace : request.namespaces) {
            namespace.features.sort(comparator);
            if (!namespace.hashIsComputed) {
                namespace.computedHashValue =
                        namespace.namespace.length() == 0 ? 0 : namespaceHashOf(namespace, this.seed);
                namespace.hashIsComputed = true;
            }
            for (FeatureInterface feature : namespace.features) {
                if (!feature.isHashComputed()) {
                    feature.setComputedHash(featureHashOf(namespace.computedHashValue, feature));
                }
                String label = explanation == null ? null : namespace.namespace + "^" + feature.getStringName();
                addWeightedTerm(out, feature.getComputedHash(), feature.getValue(), explanation, label);
            }
        }
    }

    /** Adds pairwise interaction terms, either any-to-any or per configured namespace pair. */
    private void addQuadraticTerms(float[] out, PredictionRequest request, Explanation explanation) {
        if (this.quadraticAnyToAny) {
            for (Namespace left : request.namespaces) {
                for (Namespace right : request.namespaces) {
                    for (FeatureInterface a : left.features) {
                        for (FeatureInterface b : right.features) {
                            interact(out, left, a, right, b, explanation);
                        }
                    }
                }
            }
            return;
        }
        request.namespaces.sort(NAMESPACE_ORDER);
        int count = request.namespaces.size();
        for (int i = 0; i < count; i++) {
            Namespace left = request.namespaces.get(i);
            if (left.namespace.length() == 0) {
                continue; // unnamed namespace has no first char to match against
            }
            char leftChar = left.namespace.charAt(0);
            Set<Character> partners = this.quadratic.get(leftChar);
            if (partners == null) {
                continue;
            }
            for (Character partner : partners) {
                char rightChar = partner;
                int start = i;
                if (leftChar > rightChar) {
                    // The reverse pair, when configured, was already handled from the other side.
                    if (this.quadratic.getOrDefault(partner, EMPTY_SET).contains(leftChar)) {
                        continue;
                    }
                    start = 0;
                }
                for (int j = start; j < count; j++) {
                    Namespace right = request.namespaces.get(j);
                    if (right.namespace.length() == 0 || right.namespace.charAt(0) != rightChar) {
                        continue;
                    }
                    if (leftChar == rightChar) {
                        // Self-interaction: only feature pairs (a, b) with b >= a.
                        // NOTE(review): also applied across distinct namespaces sharing a first char.
                        for (int a = 0; a < left.features.size(); a++) {
                            for (int b = a; b < right.features.size(); b++) {
                                interact(out, left, left.features.get(a), right, right.features.get(b), explanation);
                            }
                        }
                    } else {
                        for (FeatureInterface a : left.features) {
                            for (FeatureInterface b : right.features) {
                                interact(out, left, a, right, b, explanation);
                            }
                        }
                    }
                }
            }
        }
    }

    /** Adds three-way interaction terms for each configured {@code --cubic} triple. */
    private void addCubicTerms(float[] out, PredictionRequest request, Explanation explanation) {
        // NOTE(review): if several namespaces share a first char, the last one in request order wins.
        Map<Character, Namespace> byFirstChar = new HashMap<>();
        for (Namespace namespace : request.namespaces) {
            if (namespace.namespace.length() > 0) {
                byFirstChar.put(namespace.namespace.charAt(0), namespace);
            }
        }
        for (Map.Entry<Character, Map<Character, Set<Character>>> firstEntry : this.cubic.entrySet()) {
            Namespace first = byFirstChar.get(firstEntry.getKey());
            if (first == null) {
                continue;
            }
            for (Map.Entry<Character, Set<Character>> secondEntry : firstEntry.getValue().entrySet()) {
                Namespace second = byFirstChar.get(secondEntry.getKey());
                if (second == null) {
                    continue;
                }
                for (Character thirdChar : secondEntry.getValue()) {
                    Namespace third = byFirstChar.get(thirdChar);
                    if (third == null) {
                        continue;
                    }
                    for (FeatureInterface a : first.features) {
                        for (FeatureInterface b : second.features) {
                            for (FeatureInterface c : third.features) {
                                interact3(out, first, a, second, b, third, c, explanation);
                            }
                        }
                    }
                }
            }
        }
    }

    /** Adds one quadratic term: FNV-combined hash, product of the two feature values. */
    private void interact(float[] fArr, Namespace namespace, FeatureInterface featureInterface,
            Namespace namespace2, FeatureInterface featureInterface2, Explanation explanation) {
        int hash = (featureInterface.getComputedHash() * FNV_PRIME) ^ featureInterface2.getComputedHash();
        String label = explanation == null ? null
                : namespace.namespace + "^" + featureInterface.getStringName()
                        + "*" + namespace2.namespace + "^" + featureInterface2.getStringName();
        addWeightedTerm(fArr, hash, featureInterface.getValue() * featureInterface2.getValue(), explanation, label);
    }

    /** Adds one cubic term: FNV-combined hash, product of the three feature values. */
    private void interact3(float[] fArr, Namespace namespace, FeatureInterface featureInterface,
            Namespace namespace2, FeatureInterface featureInterface2, Namespace namespace3,
            FeatureInterface featureInterface3, Explanation explanation) {
        int hash = (((featureInterface.getComputedHash() * FNV_PRIME) ^ featureInterface2.getComputedHash())
                * FNV_PRIME) ^ featureInterface3.getComputedHash();
        String label = explanation == null ? null
                : namespace.namespace + "^" + featureInterface.getStringName()
                        + "*" + namespace2.namespace + "^" + featureInterface2.getStringName()
                        + "*" + namespace3.namespace + "^" + featureInterface3.getStringName();
        float value = featureInterface.getValue() * featureInterface2.getValue() * featureInterface3.getValue();
        addWeightedTerm(fArr, hash, value, explanation, label);
    }

    /** Char-by-char comparison of two names; a proper prefix sorts first. */
    private static int compareChars(CharSequence left, CharSequence right) {
        int shared = Math.min(left.length(), right.length());
        for (int i = 0; i < shared; i++) {
            char a = left.charAt(i);
            char b = right.charAt(i);
            if (a != b) {
                return a - b; // chars are 16-bit, so this cannot overflow
            }
        }
        return left.length() - right.length();
    }

    // ------------------------------------------------------------------ post-processing

    /** Clips the first {@code oaa} scores into [minLabel, maxLabel]. */
    protected void clip(float[] fArr) {
        for (int i = 0; i < this.oaa; i++) {
            fArr[i] = clip(fArr[i]);
        }
    }

    /** Clips one score into [minLabel, maxLabel]. */
    protected float clip(float f) {
        return Math.max(Math.min(f, this.maxLabel), this.minLabel);
    }

    /** Applies {@code doubleUnaryOperator} to the first {@code oaa} scores in place. */
    protected void linkWith(float[] fArr, DoubleUnaryOperator doubleUnaryOperator) {
        for (int i = 0; i < this.oaa; i++) {
            fArr[i] = (float) doubleUnaryOperator.applyAsDouble(fArr[i]);
        }
    }

    /** Divides the first {@code oaa} scores by their sum so they add up to one. */
    protected void normalize(float[] fArr) {
        float sum = 0.0f;
        for (int i = 0; i < this.oaa; i++) {
            sum += fArr[i];
        }
        for (int i = 0; i < this.oaa; i++) {
            fArr[i] = fArr[i] / sum;
        }
    }
}
