package com.facebook.presto.ml;

import com.facebook.presto.RowPageBuilder;
import com.facebook.presto.metadata.FunctionExtractor;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/ml/TestEvaluateClassifierPredictions.class */
/**
 * Tests the {@code evaluate_classifier_predictions} aggregate function supplied by {@link MLPlugin}.
 */
public class TestEvaluateClassifierPredictions {
    private final MetadataManager metadata = MetadataManager.createTestMetadataManager();
    private final FunctionManager functionManager = this.metadata.getFunctionManager();

    /**
     * Registers the ML plugin's functions, runs {@code evaluate_classifier_predictions(bigint, bigint)}
     * over the two-row page from {@link #getPage()}, and checks the rendered VARCHAR report:
     * it must contain seven non-empty lines, the first being {@code "Accuracy: 1/2 (50.00%)"}.
     */
    @Test
    public void testEvaluateClassifierPredictions() {
        metadata.registerBuiltInFunctions(FunctionExtractor.extractFunctions(new MLPlugin().getFunctions()));

        Accumulator accumulator = functionManager.getAggregateFunctionImplementation(
                        functionManager.lookupFunction(
                                "evaluate_classifier_predictions",
                                TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT)))
                // Bind to page channels 0 and 1. Optional.empty() is presumably "no mask channel";
                // confirm against bind's signature.
                .bind(ImmutableList.of(0, 1), Optional.empty())
                .createAccumulator();
        accumulator.addInput(getPage());

        // Materialize the single final value into a one-entry block and decode it as UTF-8 text.
        BlockBuilder output = accumulator.getFinalType().createBlockBuilder((BlockBuilderStatus) null, 1);
        accumulator.evaluateFinal(output);
        String report = VarcharType.VARCHAR.getSlice(output.build(), 0).toStringUtf8();

        // Typed list (not raw) so elements are Strings without casting.
        ImmutableList<String> lines = ImmutableList.copyOf(Splitter.on('\n').omitEmptyStrings().split(report));
        // Pass the whole report as the failure message so a mismatch is diagnosable from the test log.
        Assert.assertEquals(lines.size(), 7, report);
        Assert.assertEquals(lines.get(0), "Accuracy: 1/2 (50.00%)", report);
    }

    /**
     * Builds the input page: two BIGINT columns and two rows, (1, 1) and (1, 0).
     * The columns agree in exactly one of the two rows, consistent with the expected
     * "Accuracy: 1/2" line.
     */
    private static Page getPage() {
        return RowPageBuilder.rowPageBuilder(BigintType.BIGINT, BigintType.BIGINT)
                .row(1L, 1L)
                .row(1L, 0L)
                .build();
    }
}
