package com.facebook.presto.ml;

import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.CombineFunction;
import com.facebook.presto.operator.aggregation.InputFunction;
import com.facebook.presto.operator.aggregation.OutputFunction;
import com.facebook.presto.operator.aggregation.state.AccumulatorState;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.SqlType;
import com.google.common.base.Preconditions;
import java.util.Locale;

/**
 * Aggregation registered as {@code evaluate_classifier_predictions}.
 *
 * <p>Tallies the confusion matrix of a binary classifier from pairs of
 * (actual label, predicted label), where each label must be {@code 0} or {@code 1},
 * and renders accuracy, precision and recall as a three-line text report.
 */
@AggregationFunction("evaluate_classifier_predictions")
public final class EvaluateClassifierPredictionsAggregation {

    /** Error text for labels other than 0 or 1; uses the name the function is registered under. */
    private static final String BINARY_ONLY_MESSAGE = "evaluate_classifier_predictions only supports binary classifiers";

    /** Accumulator state: the four counters of a binary confusion matrix. */
    public interface EvaluateClassifierPredictionsState extends AccumulatorState {
        long getTruePositives();

        void setTruePositives(long truePositives);

        long getFalsePositives();

        void setFalsePositives(long falsePositives);

        long getTrueNegatives();

        void setTrueNegatives(long trueNegatives);

        long getFalseNegatives();

        void setFalseNegatives(long falseNegatives);
    }

    private EvaluateClassifierPredictionsAggregation() {
    }

    /**
     * Records one (actual, predicted) label pair by incrementing the matching counter.
     *
     * @param state counters to update
     * @param truth actual label, {@code 0} or {@code 1}
     * @param predicted label produced by the classifier, {@code 0} or {@code 1}
     * @throws IllegalArgumentException if either label is neither {@code 0} nor {@code 1}
     */
    @InputFunction
    public static void input(EvaluateClassifierPredictionsState state, @SqlType("bigint") long truth, @SqlType("bigint") long predicted) {
        Preconditions.checkArgument(isBinaryLabel(predicted), BINARY_ONLY_MESSAGE);
        Preconditions.checkArgument(isBinaryLabel(truth), BINARY_ONLY_MESSAGE);

        boolean actualPositive = truth == 1;
        boolean predictedPositive = predicted == 1;
        if (actualPositive && predictedPositive) {
            state.setTruePositives(state.getTruePositives() + 1);
        } else if (actualPositive) {
            state.setFalseNegatives(state.getFalseNegatives() + 1);
        } else if (predictedPositive) {
            state.setFalsePositives(state.getFalsePositives() + 1);
        } else {
            state.setTrueNegatives(state.getTrueNegatives() + 1);
        }
    }

    /**
     * Merges two partial tallies by adding every counter of {@code otherState} into {@code state}.
     *
     * @param state tally that receives the sums
     * @param otherState tally whose counters are added; it is only read
     */
    @CombineFunction
    public static void combine(EvaluateClassifierPredictionsState state, EvaluateClassifierPredictionsState otherState) {
        state.setTruePositives(state.getTruePositives() + otherState.getTruePositives());
        state.setFalsePositives(state.getFalsePositives() + otherState.getFalsePositives());
        state.setTrueNegatives(state.getTrueNegatives() + otherState.getTrueNegatives());
        state.setFalseNegatives(state.getFalseNegatives() + otherState.getFalseNegatives());
    }

    /**
     * Writes the report as a single varchar value with three newline-separated lines:
     * accuracy, precision and recall, each as {@code "<name>: <hits>/<total> (<percent>%)"}.
     *
     * <p>A metric whose total is zero (for example precision when nothing was predicted
     * positive) is computed as {@code 0.0 / 0} and therefore prints {@code NaN} as its
     * percentage.
     *
     * @param state accumulated counters
     * @param blockBuilder destination for the report
     */
    @OutputFunction("varchar")
    public static void output(EvaluateClassifierPredictionsState state, BlockBuilder blockBuilder) {
        long truePositives = state.getTruePositives();
        long falsePositives = state.getFalsePositives();
        long trueNegatives = state.getTrueNegatives();
        long falseNegatives = state.getFalseNegatives();

        long correct = trueNegatives + truePositives;
        long total = truePositives + trueNegatives + falsePositives + falseNegatives;

        StringBuilder report = new StringBuilder();
        report.append(formatMetric("Accuracy", correct, total)).append('\n');
        report.append(formatMetric("Precision", truePositives, truePositives + falsePositives)).append('\n');
        // No trailing newline after the last metric.
        report.append(formatMetric("Recall", truePositives, truePositives + falseNegatives));
        VarcharType.VARCHAR.writeString(blockBuilder, report.toString());
    }

    /** Returns whether {@code label} is one of the two supported class labels, 0 or 1. */
    private static boolean isBinaryLabel(long label) {
        return label == 0 || label == 1;
    }

    /** Formats one report line: {@code "<name>: <hits>/<total> (<percent with two decimals>%)"}. */
    private static String formatMetric(String name, long hits, long total) {
        // Multiplying by 100.0 first promotes to double, so the division is never integer division.
        double percent = (100.0 * hits) / total;
        return String.format(Locale.US, "%s: %d/%d (%.2f%%)", name, hits, total, percent);
    }
}
