package com.facebook.presto.ml;

import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.CombineFunction;
import com.facebook.presto.operator.aggregation.InputFunction;
import com.facebook.presto.operator.aggregation.OutputFunction;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.SqlType;
import com.google.common.collect.Sets;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Map;

/**
 * Aggregation that compares a classifier's predicted labels with the true labels and renders a
 * plain-text report with overall accuracy plus per-class precision and recall.
 *
 * <p>Labels may be supplied as {@code bigint} or {@code varchar}. Bigint labels are converted to
 * their decimal string form, so both overloads update the same per-label counters.
 */
@AggregationFunction("evaluate_classifier_predictions")
public final class EvaluateClassifierPredictionsAggregation {
    // Fixed per-entry cost added to the memory estimate on top of the label's UTF-8 byte length.
    private static final int ENTRY_OVERHEAD_BYTES = 4;

    private EvaluateClassifierPredictionsAggregation() {
    }

    /**
     * Records one (truth, prediction) pair of bigint labels by delegating to the varchar overload
     * with the decimal string form of each label.
     */
    @InputFunction
    public static void input(EvaluateClassifierPredictionsState state, @SqlType("bigint") long truth, @SqlType("bigint") long prediction) {
        input(state, Slices.utf8Slice(String.valueOf(truth)), Slices.utf8Slice(String.valueOf(prediction)));
    }

    /**
     * Records one (truth, prediction) pair of varchar labels.
     *
     * <p>A match counts as a true positive for that label. A mismatch counts as a false positive
     * for the predicted label and as a false negative for the true label.
     */
    @InputFunction
    public static void input(EvaluateClassifierPredictionsState state, @SqlType("varchar") Slice truth, @SqlType("varchar") Slice prediction) {
        if (truth.equals(prediction)) {
            increment(state, state.getTruePositives(), truth.toStringUtf8(), truth.length());
            return;
        }
        String truthLabel = truth.toStringUtf8();
        String predictedLabel = prediction.toStringUtf8();
        increment(state, state.getFalsePositives(), predictedLabel, prediction.length());
        increment(state, state.getFalseNegatives(), truthLabel, truth.length());
    }

    /**
     * Adds one to {@code counts[label]}. If the label is new to this map, first charges its
     * estimated size ({@code labelBytes} plus a fixed overhead) to the state's memory usage.
     */
    private static void increment(EvaluateClassifierPredictionsState state, Map<String, Integer> counts, String label, int labelBytes) {
        if (!counts.containsKey(label)) {
            state.addMemoryUsage(labelBytes + ENTRY_OVERHEAD_BYTES);
        }
        counts.merge(label, 1, Integer::sum);
    }

    /**
     * Merges the counters of {@code otherState} into {@code state} and charges the memory used by
     * any labels that were new to {@code state}.
     */
    @CombineFunction
    public static void combine(EvaluateClassifierPredictionsState state, EvaluateClassifierPredictionsState otherState) {
        int addedBytes = mergeMaps(state.getTruePositives(), otherState.getTruePositives())
                + mergeMaps(state.getFalsePositives(), otherState.getFalsePositives())
                + mergeMaps(state.getFalseNegatives(), otherState.getFalseNegatives());
        state.addMemoryUsage(addedBytes);
    }

    /**
     * Adds every count in {@code source} into {@code target}.
     *
     * @return estimated bytes for labels that were not already present in {@code target}
     */
    private static int mergeMaps(Map<String, Integer> target, Map<String, Integer> source) {
        int addedBytes = 0;
        for (Map.Entry<String, Integer> entry : source.entrySet()) {
            String label = entry.getKey();
            if (!target.containsKey(label)) {
                // UTF-8 so this matches the Slice.length() charged in input(); the no-arg
                // String.getBytes() would use the platform default charset instead.
                addedBytes += label.getBytes(StandardCharsets.UTF_8).length + ENTRY_OVERHEAD_BYTES;
            }
            target.merge(label, entry.getValue(), Integer::sum);
        }
        return addedBytes;
    }

    /**
     * Writes the report as a single varchar value: one overall accuracy line, then a class header
     * with precision and recall lines for every label seen in any counter map.
     *
     * <p>Percentages whose denominator is zero are rendered as {@code NaN}.
     */
    @OutputFunction("varchar")
    public static void output(EvaluateClassifierPredictionsState state, BlockBuilder out) {
        Map<String, Integer> truePositives = state.getTruePositives();
        Map<String, Integer> falsePositives = state.getFalsePositives();
        Map<String, Integer> falseNegatives = state.getFalseNegatives();

        // Sum in long so that large inputs do not overflow int before widening.
        long correct = sum(truePositives);
        long total = correct + sum(falsePositives);

        StringBuilder report = new StringBuilder();
        report.append(String.format(Locale.US, "Accuracy: %d/%d (%.2f%%)\n", correct, total, percent(correct, total)));
        for (String label : Sets.union(Sets.union(truePositives.keySet(), falsePositives.keySet()), falseNegatives.keySet())) {
            long tp = truePositives.getOrDefault(label, 0);
            long fp = falsePositives.getOrDefault(label, 0);
            long fn = falseNegatives.getOrDefault(label, 0);
            report.append(String.format(Locale.US, "Class '%s'\n", label));
            report.append(String.format(Locale.US, "Precision: %d/%d (%.2f%%)\n", tp, tp + fp, percent(tp, tp + fp)));
            report.append(String.format(Locale.US, "Recall: %d/%d (%.2f%%)\n", tp, tp + fn, percent(tp, tp + fn)));
        }
        VarcharType.VARCHAR.writeString(out, report.toString());
    }

    /** Returns the sum of all counts in {@code counts}, accumulated as a long. */
    private static long sum(Map<String, Integer> counts) {
        return counts.values().stream().mapToLong(Integer::longValue).sum();
    }

    /** Returns {@code 100 * numerator / denominator}, which is NaN when both are zero. */
    private static double percent(long numerator, long denominator) {
        return (100.0d * numerator) / denominator;
    }
}
