package com.facebook.presto.ml;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionExtractor;
import com.facebook.presto.spi.type.ParametricType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.TestingSession;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tpch.TpchConnectorFactory;
import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import org.testng.annotations.Test;

/**
 * End-to-end query tests for the ML plugin's {@code learn_classifier} / {@code classify} functions,
 * run against a {@link LocalQueryRunner} that has the ML plugin's types and functions registered.
 */
public class TestMLQueries extends AbstractTestQueryFramework {
    /** Catalog name used both for the session default and for the registered TPCH catalog. */
    private static final String CATALOG = "local";
    /** Default schema for the test session. */
    private static final String SCHEMA = "tiny";

    public TestMLQueries() {
        super(createLocalQueryRunner());
    }

    /**
     * Trains a classifier on a single row with integer label {@code 1} and features {@code (1, 2)},
     * then checks that classifying the same features yields {@code 1}.
     */
    @Test
    public void testPrediction() throws Exception {
        assertQuery("SELECT classify(features(1, 2), model) FROM (SELECT learn_classifier(labels, features) AS model FROM (VALUES (1, features(1, 2))) t(labels, features)) t2", "SELECT 1");
    }

    /**
     * Same as {@link #testPrediction()}, but with a varchar label {@code 'cat'}, checking that
     * classification returns the varchar label.
     */
    @Test
    public void testVarcharPrediction() throws Exception {
        assertQuery("SELECT classify(features(1, 2), model) FROM (SELECT learn_classifier(labels, features) AS model FROM (VALUES ('cat', features(1, 2))) t(labels, features)) t2", "SELECT 'cat'");
    }

    /**
     * Builds a local query runner with a TPCH catalog named {@link #CATALOG} and registers
     * every type, parametric type and function exposed by {@link MLPlugin}.
     *
     * @return the configured query runner
     */
    private static LocalQueryRunner createLocalQueryRunner() {
        Session session = TestingSession.testSessionBuilder()
                .setCatalog(CATALOG)
                .setSchema(SCHEMA)
                .build();
        LocalQueryRunner queryRunner = new LocalQueryRunner(session);

        // Use the constant rather than session.getCatalog().get(): it is the same value we just set,
        // and it avoids an unchecked Optional.get().
        queryRunner.createCatalog(CATALOG, new TpchConnectorFactory(1), ImmutableMap.of());

        // A single plugin instance supplies types, parametric types and functions consistently.
        MLPlugin plugin = new MLPlugin();
        for (Type type : plugin.getTypes()) {
            queryRunner.getTypeManager().addType(type);
        }
        for (ParametricType parametricType : plugin.getParametricTypes()) {
            queryRunner.getTypeManager().addParametricType(parametricType);
        }
        queryRunner.getMetadata().addFunctions(FunctionExtractor.extractFunctions(plugin.getFunctions()));
        return queryRunner;
    }
}
