package org.tribuo.classification.evaluation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.ToDoubleFunction;
import java.util.logging.Logger;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseMatrix;

/* loaded from: input_file:org/tribuo/classification/evaluation/LabelConfusionMatrix.class */
/**
 * A confusion matrix over {@link Label}s built from a list of predictions.
 * <p>
 * Counts are stored in a square {@link DenseMatrix} sized by the output domain, indexed as
 * {@code cm[predictedId][trueId]}: each row collects the predictions made for one label and
 * each column collects the ground-truth occurrences of one label.
 * <p>
 * Statistics requested for a label that is not in the domain (its id is negative) are
 * reported as zero and logged at {@code FINE} level.
 */
public final class LabelConfusionMatrix implements ConfusionMatrix<Label> {
    private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName());

    /** The output domain used to map labels to matrix indices. */
    private final ImmutableOutputInfo<Label> domain;
    /** The number of predictions supplied at construction. */
    private final int total;
    /** The number of times each label appeared as the ground truth. */
    private final Map<Label, Double> occurrences;
    /** Every label seen either as ground truth or as a predicted output. */
    private final Set<Label> observed;
    /** The count matrix, indexed [predicted id][true id]. */
    private final DenseMatrix cm;
    /** Optional display order for {@link #toString()} and {@link #toHTML()}; null means domain order. */
    private List<Label> labelOrder;

    /**
     * Builds a confusion matrix using the output domain of the supplied model.
     *
     * @param model The model whose output id mapping defines the matrix indices.
     * @param list The predictions to tabulate.
     * @throws IllegalArgumentException If a prediction has an unknown ground truth, or a label
     *                                  which is not in the model's domain.
     */
    public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> list) {
        this(model.getOutputIDInfo(), list);
    }

    /**
     * Builds a confusion matrix over the supplied output domain.
     *
     * @param immutableOutputInfo The output domain defining the matrix indices.
     * @param list The predictions to tabulate.
     * @throws IllegalArgumentException If a prediction has an unknown ground truth, or a label
     *                                  which is not in the domain.
     */
    public LabelConfusionMatrix(ImmutableOutputInfo<Label> immutableOutputInfo, List<Prediction<Label>> list) {
        this.domain = immutableOutputInfo;
        this.total = list.size();
        this.cm = new DenseMatrix(immutableOutputInfo.size(), immutableOutputInfo.size());
        this.occurrences = new HashMap<>();
        this.observed = new HashSet<>();
        tabulate(list);
    }

    /**
     * Accumulates every prediction into the count matrix, the per-label support map and the
     * set of observed labels.
     *
     * @param list The predictions to tabulate.
     * @throws IllegalArgumentException If the ground truth is the unknown label, or either the
     *                                  ground truth or the predicted label is not in the domain.
     */
    private void tabulate(List<Prediction<Label>> list) {
        for (Prediction<Label> prediction : list) {
            Label truth = prediction.getExample().getOutput();
            Label predicted = prediction.getOutput();
            if (truth.getLabel().equals(Label.UNKNOWN)) {
                throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate.");
            }
            occurrences.merge(truth, 1.0d, Double::sum);
            observed.add(truth);
            observed.add(predicted);
            int truthId = getIDOrThrow(truth);
            int predictedId = getIDOrThrow(predicted);
            // Row = predicted label, column = true label.
            cm.add(predictedId, truthId, 1.0d);
        }
    }

    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public ImmutableOutputInfo<Label> getDomain() {
        return this.domain;
    }

    /**
     * Returns the total number of predictions tabulated.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double support() {
        return this.total;
    }

    /**
     * Returns the number of times the label appeared as the ground truth, or zero if it never did.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double support(Label label) {
        return occurrences.getOrDefault(label, 0.0d);
    }

    /**
     * Returns the diagonal entry for the label: predicted as the label and truly the label.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double tp(Label label) {
        return compute(label, id -> cm.get(id, id));
    }

    /**
     * Returns the row sum for the label minus its diagonal entry: predicted as the label but
     * truly something else.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double fp(Label label) {
        return compute(label, id -> cm.rowSum(id) - cm.get(id, id));
    }

    /**
     * Returns the column sum for the label minus its diagonal entry: truly the label but
     * predicted as something else.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double fn(Label label) {
        return compute(label, id -> cm.columnSum(id) - cm.get(id, id));
    }

    /**
     * Returns the sum of every entry outside the label's row and column. Like the other
     * per-label statistics this is zero for a label which is not in the domain.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double tn(Label label) {
        return compute(label, id -> {
            int size = domain.size();
            double sum = 0.0d;
            for (int row = 0; row < size; row++) {
                if (row == id) {
                    continue;
                }
                for (int col = 0; col < size; col++) {
                    if (col != id) {
                        sum += cm.get(row, col);
                    }
                }
            }
            return sum;
        });
    }

    /**
     * Returns the number of times {@code label} was predicted when {@code label2} was the
     * ground truth. Returns zero if either label is not in the domain.
     *
     * @param label The predicted label.
     * @param label2 The ground truth label.
     */
    @Override // org.tribuo.classification.evaluation.ConfusionMatrix
    public double confusion(Label label, Label label2) {
        int predictedId = domain.getID(label);
        int truthId = domain.getID(label2);
        if (predictedId < 0 || truthId < 0) {
            // A negative id would index outside the matrix; treat it like the other statistics do.
            logger.fine("Unknown Label " + label + " or unknown Label " + label2);
            return 0.0d;
        }
        return cm.get(predictedId, truthId);
    }

    /**
     * Applies the supplied function to the label's id, or returns zero (after logging) when the
     * label is not in the domain.
     */
    private double compute(Label label, ToDoubleFunction<Integer> toDoubleFunction) {
        int id = domain.getID(label);
        if (id < 0) {
            logger.fine("Unknown Label " + label);
            return 0.0d;
        }
        return toDoubleFunction.applyAsDouble(id);
    }

    /**
     * Looks up the label's id in the domain.
     *
     * @throws IllegalArgumentException If the label is not in the domain.
     */
    private int getIDOrThrow(Label label) {
        int id = this.domain.getID(label);
        if (id < 0) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        return id;
    }

    /**
     * Sets the order in which labels are displayed by {@link #toString()} and {@link #toHTML()}.
     * The list is copied, so it is never modified by this class and later changes to it are not
     * reflected here. Passing null restores the domain's own ordering.
     *
     * @param list The display order, or null.
     */
    public void setLabelOrder(List<Label> list) {
        this.labelOrder = (list == null) ? null : new ArrayList<>(list);
    }

    /**
     * Returns a fresh list containing the display order (or the domain's labels when no order
     * has been set) restricted to the labels which were actually observed.
     */
    private List<Label> observedLabelOrder() {
        List<Label> order;
        if (labelOrder == null) {
            order = new ArrayList<>(domain.getDomain());
        } else {
            order = new ArrayList<>(labelOrder);
        }
        order.retainAll(observed);
        return order;
    }

    /**
     * Formats the matrix as a text table with one row per true label and one column per
     * predicted label, covering only the observed labels.
     */
    @Override
    public String toString() {
        List<Label> order = observedLabelOrder();

        // Start from zero so a matrix with no observed labels still yields a valid format width.
        int width = 0;
        for (Label label : order) {
            double labelSupport = occurrences.getOrDefault(label, 0.0d);
            int countWidth = String.format(" %,d", (int) labelSupport).length();
            width = Math.max(countWidth, Math.max(label.getLabel().length(), width));
        }

        String rowHeaderFormat = String.format("%%-%ds", width + 2);
        String columnHeaderFormat = String.format("%%%ds", width + 2);
        String cellFormat = String.format("%%,%dd", width + 2);

        StringBuilder sb = new StringBuilder();
        sb.append(String.format(rowHeaderFormat, ""));
        for (Label predicted : order) {
            sb.append(String.format(columnHeaderFormat, predicted.getLabel()));
        }
        sb.append('\n');
        for (Label truth : order) {
            sb.append(String.format(rowHeaderFormat, truth.getLabel()));
            for (Label predicted : order) {
                sb.append(String.format(cellFormat, (int) confusion(predicted, truth)));
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Formats the matrix as an HTML table with one row per true label, one column per predicted
     * label and a trailing "Total" column holding the row label's support. Each cell shows the
     * count and its percentage of the row label's support (0.0% when the support is zero).
     */
    public String toHTML() {
        // De-duplicate while keeping the display order.
        Set<Label> labels = new LinkedHashSet<>(observedLabelOrder());

        StringBuilder sb = new StringBuilder();
        sb.append("<table>\n");
        // The header spans every rendered label column plus the "Total" column.
        sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", labels.size() + 1));
        sb.append("<tr><th></th>");
        for (Label predicted : labels) {
            sb.append("<th style=\"text-align:right\">").append(predicted).append("</th>");
        }
        sb.append("<th style=\"text-align:right\">Total</th>");
        sb.append("</tr>\n");
        for (Label truth : labels) {
            sb.append("<tr><th>").append(truth).append("</th>");
            double labelSupport = occurrences.getOrDefault(truth, 0.0d);
            for (Label predicted : labels) {
                double count = confusion(predicted, truth);
                // A label which was only ever predicted has zero support; avoid printing NaN.
                double percentage = labelSupport > 0.0d ? (count / labelSupport) * 100.0d : 0.0d;
                sb.append("<td style=\"text-align:right\">").append(String.format("%,d (%.1f%%)", (int) count, percentage)).append("</td>");
            }
            sb.append("<td style=\"text-align:right\">").append(labelSupport).append("</td>");
            sb.append("</tr>\n");
        }
        sb.append("</table>");
        return sb.toString();
    }
}
