package com.gengoai.hermes.ml;

import com.gengoai.Validation;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.evaluation.SequenceLabelerEvaluation;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.collection.Sets;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.conversion.Cast;
import com.gengoai.string.TableFormatter;
import com.gengoai.tuple.Tuple3;
import java.io.PrintStream;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/ml/CoNLLEvaluation.class */
/**
 * Chunk-level (phrase-level) evaluation of a sequence labeler in the style of the CoNLL
 * shared-task {@code conlleval} script.
 *
 * <p>Gold and predicted label sequences are converted into chunks, each identified by a
 * {@code (start, end, type)} triple with an exclusive {@code end}. Labels are expected to be
 * {@code "O"} or a two-character prefix (e.g. {@code "B-"} / {@code "I-"}) followed by the chunk
 * type. A predicted chunk is counted as correct only when an identical gold chunk (same span and
 * type) exists. Gold chunks with no identical prediction are counted as missed, and predicted
 * chunks with no identical gold chunk are counted as incorrect.
 *
 * <p>Instances accumulate mutable state and perform no synchronization.
 */
public class CoNLLEvaluation implements SequenceLabelerEvaluation {
    /** Name of the observation holding the label sequence in both gold and predicted datums. */
    private final String outputName;
    /** Per-type count of predicted chunks with no identical gold chunk. */
    private final Counter<String> incorrect = Counters.newCounter();
    /** Per-type count of predicted chunks that exactly match a gold chunk. */
    private final Counter<String> correct = Counters.newCounter();
    /** Per-type count of gold chunks with no identical predicted chunk. */
    private final Counter<String> missed = Counters.newCounter();
    /** Every chunk type seen in either gold or predicted output. */
    private final Set<String> tags = new HashSet<>();
    private double totalPhrasesGold = 0.0d;
    private double totalPhrasesFound = 0.0d;

    /**
     * Creates an evaluation reading label sequences from the given observation name.
     *
     * @param outputName the observation name; must not be null or blank
     */
    public CoNLLEvaluation(String outputName) {
        this.outputName = Validation.notNullOrBlank(outputName);
    }

    /**
     * Computes correct / (correct + incorrect + missed) over all chunk types.
     *
     * @return the chunk accuracy; {@code NaN} when nothing has been counted yet (0/0)
     */
    public double accuracy() {
        return this.correct.sum() / ((this.correct.sum() + this.incorrect.sum()) + this.missed.sum());
    }

    /**
     * Records the comparison of one sentence's gold chunks against its predicted chunks.
     *
     * @param gold      chunks extracted from the gold labels
     * @param predicted chunks extracted from the model's labels
     */
    private void entry(Set<Tuple3<Integer, Integer, String>> gold,
                       Set<Tuple3<Integer, Integer, String>> predicted) {
        this.totalPhrasesFound += predicted.size();
        this.totalPhrasesGold += gold.size();
        Sets.union(gold, predicted).stream()
            .forEach(chunk -> this.tags.add(chunk.getV3()));
        // Exact (span + type) matches are correct.
        Sets.intersection(gold, predicted).stream()
            .forEach(chunk -> this.correct.increment(chunk.getV3()));
        // Gold chunks the model did not reproduce.
        Sets.difference(gold, predicted).stream()
            .forEach(chunk -> this.missed.increment(chunk.getV3()));
        // Predicted chunks that have no identical gold chunk.
        Sets.difference(predicted, gold).stream()
            .forEach(chunk -> this.incorrect.increment(chunk.getV3()));
    }

    /**
     * Runs the model over every datum and accumulates chunk statistics against the gold labels.
     *
     * @param model   the model whose output is evaluated
     * @param dataset the gold-labelled data
     */
    public void evaluate(@NonNull Model model, @NonNull DataSet dataset) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        dataset.forEach(datum -> {
            Sequence<?> goldLabels = ((Observation) datum.get(this.outputName)).asSequence();
            Sequence<?> predictedLabels =
                ((Observation) model.transform(datum).get(this.outputName)).asSequence();
            entry(extractChunks(goldLabels), extractChunks(predictedLabels));
        });
    }

    /**
     * Harmonic mean of precision and recall.
     *
     * @return the F1 score, or 0 when both inputs are 0
     */
    private double f1(double precision, double recall) {
        if (precision + recall == 0.0d) {
            return 0.0d;
        }
        return ((2.0d * precision) * recall) / (precision + recall);
    }

    /**
     * F1 score for a single chunk type.
     *
     * @param str the chunk type
     * @return the harmonic mean of {@link #precision(String)} and {@link #recall(String)}
     */
    public double f1(String str) {
        return f1(precision(str), recall(str));
    }

    /**
     * Unweighted mean of the per-type F1 scores.
     *
     * @return the macro F1; {@code NaN} when no chunk types have been seen
     */
    public double macroF1() {
        double total = 0.0d;
        for (String tag : this.tags) {
            total += f1(tag);
        }
        return total / this.tags.size();
    }

    /**
     * Unweighted mean of the per-type precision scores.
     *
     * @return the macro precision; {@code NaN} when no chunk types have been seen
     */
    public double macroPrecision() {
        double total = 0.0d;
        for (String tag : this.tags) {
            total += precision(tag);
        }
        return total / this.tags.size();
    }

    /**
     * Unweighted mean of the per-type recall scores.
     *
     * @return the macro recall; {@code NaN} when no chunk types have been seen
     */
    public double macroRecall() {
        double total = 0.0d;
        for (String tag : this.tags) {
            total += recall(tag);
        }
        return total / this.tags.size();
    }

    /**
     * Adds all statistics from another CoNLL evaluation into this one, including the gold and
     * predicted phrase totals.
     *
     * @param evaluation the evaluation to merge; must be a {@code CoNLLEvaluation}
     */
    public void merge(@NonNull SequenceLabelerEvaluation evaluation) {
        if (evaluation == null) {
            throw new NullPointerException("evaluation is marked non-null but is null");
        }
        Validation.checkArgument(evaluation instanceof CoNLLEvaluation);
        CoNLLEvaluation other = (CoNLLEvaluation) evaluation;
        this.incorrect.merge(other.incorrect);
        this.correct.merge(other.correct);
        this.missed.merge(other.missed);
        this.tags.addAll(other.tags);
        // Without these the totals printed by output() would ignore the merged evaluation.
        this.totalPhrasesGold += other.totalPhrasesGold;
        this.totalPhrasesFound += other.totalPhrasesFound;
    }

    /**
     * F1 computed from the micro-averaged precision and recall.
     *
     * @return the micro F1
     */
    public double microF1() {
        return f1(microPrecision(), microRecall());
    }

    /**
     * Precision pooled over all chunk types.
     *
     * @return correct / (correct + incorrect), or 1 when nothing has been predicted
     */
    public double microPrecision() {
        double numCorrect = this.correct.sum();
        double numIncorrect = this.incorrect.sum();
        if (numIncorrect + numCorrect <= 0.0d) {
            return 1.0d;
        }
        return numCorrect / (numCorrect + numIncorrect);
    }

    /**
     * Recall pooled over all chunk types.
     *
     * @return correct / (correct + missed), or 1 when there were no gold chunks
     */
    public double microRecall() {
        double numCorrect = this.correct.sum();
        double numMissed = this.missed.sum();
        if (numMissed + numCorrect <= 0.0d) {
            return 1.0d;
        }
        return numCorrect / (numCorrect + numMissed);
    }

    /**
     * Prints phrase totals followed by a per-type metrics table, with micro and macro footers.
     *
     * @param printStream the destination stream
     * @param unusedFlag  not consulted by this implementation
     */
    public void output(@NonNull PrintStream printStream, boolean unusedFlag) {
        if (printStream == null) {
            throw new NullPointerException("printStream is marked non-null but is null");
        }
        Set<String> sortedTags = new TreeSet<>(this.tags);
        printStream.println("Total Gold Phrases: " + this.totalPhrasesGold);
        printStream.println("Total Predicted Phrases: " + this.totalPhrasesFound);
        printStream.println("Total Correct: " + this.correct.sum());
        TableFormatter tableFormatter = new TableFormatter();
        tableFormatter.setMinCellWidth(5);
        tableFormatter.setNumberFormatter(new DecimalFormat("#,###"));
        DecimalFormat percent = new DecimalFormat("0.0%");
        tableFormatter.title("Tag Metrics")
                      .header(Arrays.asList("", "Precision", "Recall", "F1-Measure", "Correct", "Missed", "Incorrect"));
        for (String tag : sortedTags) {
            tableFormatter.content(Arrays.asList(tag,
                                                 percent.format(precision(tag)),
                                                 percent.format(recall(tag)),
                                                 percent.format(f1(tag)),
                                                 this.correct.get(tag),
                                                 this.missed.get(tag),
                                                 this.incorrect.get(tag)));
        }
        tableFormatter.footer(Arrays.asList("micro",
                                            percent.format(microPrecision()),
                                            percent.format(microRecall()),
                                            percent.format(microF1()),
                                            this.correct.sum(),
                                            this.missed.sum(),
                                            this.incorrect.sum()));
        tableFormatter.footer(Arrays.asList("macro",
                                            percent.format(macroPrecision()),
                                            percent.format(macroRecall()),
                                            percent.format(macroF1()),
                                            "-", "-", "-"));
        tableFormatter.print(printStream);
    }

    /**
     * Precision for a single chunk type.
     *
     * @param str the chunk type; must not be null or blank
     * @return correct / (correct + incorrect) for the type, or 1 when the type was never predicted
     */
    public double precision(String str) {
        Validation.notNullOrBlank(str);
        double numCorrect = this.correct.get(str);
        double numIncorrect = this.incorrect.get(str);
        if (numIncorrect + numCorrect <= 0.0d) {
            return 1.0d;
        }
        return numCorrect / (numCorrect + numIncorrect);
    }

    /**
     * Recall for a single chunk type.
     *
     * @param str the chunk type; must not be null or blank (validated like {@link #precision(String)})
     * @return correct / (correct + missed) for the type, or 1 when the type never appeared in gold
     */
    public double recall(String str) {
        Validation.notNullOrBlank(str);
        double numCorrect = this.correct.get(str);
        double numMissed = this.missed.get(str);
        if (numMissed + numCorrect <= 0.0d) {
            return 1.0d;
        }
        return numCorrect / (numCorrect + numMissed);
    }

    /**
     * Converts a label sequence into {@code (start, endExclusive, type)} chunks.
     *
     * <p>Any non-{@code "O"} label starts a chunk whose type is the label minus its first two
     * characters. The chunk continues only while subsequent labels are {@code "I-" + type}; a
     * different type, a new {@code B-} label or {@code "O"} ends it, as in conlleval. Labels other
     * than {@code "O"} are assumed to be at least two characters long.
     *
     * @param sequence the label sequence
     * @return the set of chunks found
     */
    private Set<Tuple3<Integer, Integer, String>> extractChunks(Sequence<?> sequence) {
        Set<Tuple3<Integer, Integer, String>> chunks = new HashSet<>();
        int i = 0;
        while (i < sequence.size()) {
            String label = labelAt(sequence, i);
            if (label.equals("O")) {
                i++;
                continue;
            }
            String type = label.substring(2);
            String continuation = "I-" + type;
            int start = i;
            i++;
            while (i < sequence.size() && labelAt(sequence, i).equals(continuation)) {
                i++;
            }
            chunks.add(Tuple3.of(start, i, type));
        }
        return chunks;
    }

    /** Returns the label name of the observation at {@code index} in the sequence. */
    private static String labelAt(Sequence<?> sequence, int index) {
        return ((Observation) sequence.get(index)).asVariable().getName();
    }
}
