package com.facebook.presto.ml;

import com.facebook.presto.RowPageBuilder;
import com.facebook.presto.block.BlockEncodingManager;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.ml.type.ClassifierParametricType;
import com.facebook.presto.ml.type.ClassifierType;
import com.facebook.presto.ml.type.ModelType;
import com.facebook.presto.ml.type.RegressorType;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.AggregationCompiler;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.BlockEncodingFactory;
import com.facebook.presto.spi.block.InterleavedBlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeManager;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.spi.type.TypeSignatureParameter;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.testing.AggregationTestUtils;
import com.facebook.presto.type.TypeJsonUtils;
import com.facebook.presto.type.TypeRegistry;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableList;
import java.util.Optional;
import java.util.Random;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/ml/TestLearnAggregations.class */
/**
 * Tests for the ML training aggregations {@code LearnClassifierAggregation} and
 * {@code LearnLibSvmClassifierAggregation}.
 *
 * <p>Each test builds an accumulator for the aggregation and feeds it one synthetic page. It then
 * deserializes the resulting model and checks that the model is a {@link Classifier}.
 */
public class TestLearnAggregations {
    private static final TypeManager typeManager;

    static {
        TypeRegistry typeRegistry = new TypeRegistry();
        typeRegistry.addParametricType(new ClassifierParametricType());
        typeRegistry.addType(ModelType.MODEL);
        typeRegistry.addType(RegressorType.REGRESSOR);
        // The FunctionRegistry instance is discarded. Presumably it is built for its side effects
        // on typeRegistry (registering function-dependent types); confirm before removing.
        new FunctionRegistry(typeRegistry, new BlockEncodingManager(typeRegistry, new BlockEncodingFactory[0]), new FeaturesConfig());
        typeManager = typeRegistry;
    }

    /** Trains the classifier aggregation on (label bigint, features map(bigint, double)). */
    @Test
    public void testLearn() throws Exception {
        Accumulator accumulator = AggregationTestUtils.generateInternalAggregationFunction(
                        LearnClassifierAggregation.class,
                        ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(),
                        ImmutableList.of(BigintType.BIGINT.getTypeSignature(), bigintDoubleMapType().getTypeSignature()),
                        typeManager)
                .bind(ImmutableList.of(0, 1), Optional.empty())
                .createAccumulator();
        assertLearnClassifier(accumulator);
    }

    /** Trains the libsvm classifier aggregation on (label, features, varchar(x) parameters). */
    @Test
    public void testLearnLibSvm() throws Exception {
        Accumulator accumulator = AggregationCompiler.generateAggregationBindableFunction(
                        LearnLibSvmClassifierAggregation.class,
                        ClassifierType.BIGINT_CLASSIFIER.getTypeSignature(),
                        ImmutableList.of(
                                BigintType.BIGINT.getTypeSignature(),
                                bigintDoubleMapType().getTypeSignature(),
                                VarcharType.getParametrizedVarcharSignature("x")))
                // Bind the varchar length variable "x" to Integer.MAX_VALUE; 3 is the number of input arguments.
                .specialize(BoundVariables.builder().setLongVariable("x", (long) Integer.MAX_VALUE).build(), 3, typeManager)
                .bind(ImmutableList.of(0, 1, 2), Optional.empty())
                .createAccumulator();
        assertLearnClassifier(accumulator);
    }

    /**
     * Feeds the test page to the accumulator and evaluates its final value into a one-entry block.
     * Asserts that this value deserializes to a non-null {@link Classifier}.
     *
     * @param accumulator a freshly created accumulator for a classifier-learning aggregation
     */
    private static void assertLearnClassifier(Accumulator accumulator) throws Exception {
        accumulator.addInput(getPage());
        Type finalType = accumulator.getFinalType();
        BlockBuilder finalBlockBuilder = finalType.createBlockBuilder(new BlockBuilderStatus(), 1);
        accumulator.evaluateFinal(finalBlockBuilder);
        Model model = ModelUtils.deserialize(finalType.getSlice(finalBlockBuilder.build(), 0));
        Assert.assertNotNull(model, "deserialization failed");
        Assert.assertTrue(model instanceof Classifier, "deserialized model is not a classifier");
    }

    /**
     * Builds a deterministic 100-row page with columns (bigint, map(bigint, double), varchar).
     *
     * <p>Each row has a random 0/1 label. Its single feature (key 0) is the label plus Gaussian
     * noise. Every row carries the constant "C=1" in the varchar column. A fixed seed keeps the
     * data reproducible.
     */
    private static Page getPage() throws JsonProcessingException {
        RowPageBuilder rowPageBuilder = RowPageBuilder.rowPageBuilder(new Type[]{BigintType.BIGINT, bigintDoubleMapType(), VarcharType.VARCHAR});
        Random random = new Random(0L);
        for (int i = 0; i < 100; i++) {
            // Draw order matters for reproducibility: nextDouble for the label, then nextGaussian for the feature.
            long label = random.nextDouble() < 0.5d ? 0L : 1L;
            double feature = label + random.nextGaussian();
            rowPageBuilder.row(new Object[]{label, mapSliceOf(BigintType.BIGINT, DoubleType.DOUBLE, 0, feature), "C=1"});
        }
        return rowPageBuilder.build();
    }

    /**
     * Resolves the parameterized type {@code map(bigint, double)} through the shared type manager.
     */
    private static Type bigintDoubleMapType() {
        return typeManager.getParameterizedType(
                "map",
                ImmutableList.of(
                        TypeSignatureParameter.of(TypeSignature.parseTypeSignature("bigint")),
                        TypeSignatureParameter.of(TypeSignature.parseTypeSignature("double"))));
    }

    /**
     * Builds a block holding one key followed by one value, interleaved by (keyType, valueType).
     * Here it is used as the cell value of the map(bigint, double) column.
     */
    private static Block mapSliceOf(Type keyType, Type valueType, Object key, Object value) {
        // 100 is passed through unchanged from the original; presumably an expected-entries hint.
        InterleavedBlockBuilder blockBuilder = new InterleavedBlockBuilder(ImmutableList.of(keyType, valueType), new BlockBuilderStatus(), 100);
        TypeJsonUtils.appendToBlockBuilder(keyType, key, blockBuilder);
        TypeJsonUtils.appendToBlockBuilder(valueType, value, blockBuilder);
        return blockBuilder.build();
    }
}
