package org.tribuo.classification.evaluation;

import java.util.List;
import java.util.Objects;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricContext;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.sequence.SequenceModel;

/* loaded from: input_file:org/tribuo/classification/evaluation/LabelMetric.class */
/**
 * An {@link EvaluationMetric} over {@link Label} outputs whose value is produced by a
 * user-supplied function of a {@link MetricTarget} and a {@link LabelMetric.Context}.
 */
// The nested type must be qualified here: a member class's simple name is only in scope
// inside the class body, not in its implements clause.
public class LabelMetric implements EvaluationMetric<Label, LabelMetric.Context> {

    /** The target passed to {@link #impl} on every {@link #compute(Context)} call. */
    private final MetricTarget<Label> tgt;

    /** Name reported by {@link #getName()}. */
    private final String name;

    /** Function that computes the metric value from the target and a context. */
    private final ToDoubleBiFunction<MetricTarget<Label>, Context> impl;

    /**
     * The evaluation context for a {@link LabelMetric}. It holds what {@link MetricContext}
     * holds, plus a {@link ConfusionMatrix} that is built once, at construction time, from
     * the supplied predictions.
     */
    public static final class Context extends MetricContext<Label> {

        /** Confusion matrix computed from the predictions when this context is created. */
        private final ConfusionMatrix<Label> cm;

        /**
         * Creates a context for predictions produced by a {@link Model}.
         *
         * @param model the model whose output domain is used to build the confusion matrix
         * @param predictions the predictions to evaluate
         */
        public Context(Model<Label> model, List<Prediction<Label>> predictions) {
            super(model, predictions);
            this.cm = new LabelConfusionMatrix((ImmutableOutputInfo<Label>) model.getOutputIDInfo(), predictions);
        }

        /**
         * Creates a context for predictions produced by a {@link SequenceModel}.
         *
         * @param model the sequence model whose output domain is used to build the confusion matrix
         * @param predictions the predictions to evaluate
         */
        public Context(SequenceModel<Label> model, List<Prediction<Label>> predictions) {
            super(model, predictions);
            this.cm = new LabelConfusionMatrix((ImmutableOutputInfo<Label>) model.getOutputIDInfo(), predictions);
        }

        /**
         * Returns the confusion matrix computed when this context was created.
         *
         * @return the confusion matrix
         */
        public ConfusionMatrix<Label> getCM() {
            return cm;
        }
    }

    /**
     * Creates a label metric.
     *
     * @param tgt the target this metric is computed for
     * @param name the metric's name
     * @param impl the function that computes the metric value from the target and a context
     */
    public LabelMetric(MetricTarget<Label> tgt, String name, ToDoubleBiFunction<MetricTarget<Label>, Context> impl) {
        this.tgt = tgt;
        this.name = name;
        this.impl = impl;
    }

    /**
     * Computes this metric by applying the stored function to the target and the given context.
     *
     * @param context the context to evaluate against
     * @return the metric value
     */
    @Override
    public double compute(Context context) {
        return impl.applyAsDouble(tgt, context);
    }

    /**
     * Returns the target this metric is computed for.
     *
     * @return the metric target
     */
    @Override
    public MetricTarget<Label> getTarget() {
        return tgt;
    }

    /**
     * Returns the metric's name.
     *
     * @return the name
     */
    @Override
    public String getName() {
        return name;
    }

    /**
     * Two label metrics are equal when they have the same class, an equal target, an equal
     * name and an equal implementation function.
     *
     * <p>NOTE(review): {@code impl} is compared with {@link Objects#equals}. For lambdas this
     * is effectively an identity comparison, so two metrics built from distinct but
     * equivalent lambdas are not equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        LabelMetric other = (LabelMetric) obj;
        return Objects.equals(tgt, other.tgt)
                && Objects.equals(name, other.name)
                && Objects.equals(impl, other.impl);
    }

    /** Hash code consistent with {@link #equals(Object)}: target, name and implementation. */
    @Override
    public int hashCode() {
        return Objects.hash(tgt, name, impl);
    }

    /**
     * Returns a debug representation of this metric.
     *
     * @return a string of the form {@code LabelMetric{target=..., name='...'}}
     */
    @Override
    public String toString() {
        // Close the quote opened before the name (it was previously left unbalanced).
        return "LabelMetric{target=" + tgt + ", name='" + name + "'}";
    }

    /**
     * Creates a {@link Context} for the supplied model and predictions.
     *
     * @param model the model which produced the predictions
     * @param predictions the predictions to evaluate
     * @return a new context, including its confusion matrix
     */
    @Override
    public Context createContext(Model<Label> model, List<Prediction<Label>> predictions) {
        return new Context(model, predictions);
    }

    /**
     * Raw-typed delegate to {@link #createContext(Model, List)}.
     *
     * <p>NOTE(review): this appears to be a decompiler rendering of the compiler-generated
     * bridge method, renamed to avoid a collision. It is kept only so the public surface is
     * unchanged; confirm no caller depends on it and remove.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public MetricContext m26createContext(Model model, List list) {
        return createContext((Model<Label>) model, (List<Prediction<Label>>) list);
    }
}
