package com.facebook.presto.druid;

import com.facebook.presto.druid.DruidQueryGenerator;
import com.facebook.presto.druid.TestDruidQueryBase;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/druid/TestDruidQueryGenerator.class */
/**
 * Tests that {@link DruidQueryGenerator} renders plan fragments into the expected DQL text.
 *
 * <p>Each test builds a small plan with {@link PlanBuilder}, runs it through the generator and
 * compares the generated DQL string with a literal expectation.
 */
public class TestDruidQueryGenerator extends TestDruidQueryBase {
    /** Placeholder in an expected DQL string that is replaced by the mapped output expressions. */
    private static final String EXPRESSIONS_PLACEHOLDER = "__expressions__";

    private static final TestDruidQueryBase.SessionHolder defaultSessionHolder = new TestDruidQueryBase.SessionHolder();
    private static final DruidTableHandle druidTable = realtimeOnlyTable;

    /**
     * Builds the plan with a fresh {@link PlanBuilder} for the given session and delegates to the
     * {@link PlanNode}-based overload.
     *
     * @param druidConfig connector configuration, forwarded unchanged
     * @param planBuilderConsumer factory producing the plan under test
     * @param expectedDql DQL the generator is expected to emit
     * @param sessionHolder session used both to build the plan and to generate the query
     * @param outputVariables mapping from output variable name to its expected DQL expression
     */
    private void testDQL(DruidConfig druidConfig, Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDql, TestDruidQueryBase.SessionHolder sessionHolder, Map<String, String> outputVariables) {
        PlanNode planNode = planBuilderConsumer.apply(createPlanBuilder(sessionHolder));
        testDQL(druidConfig, planNode, expectedDql, sessionHolder, outputVariables);
    }

    /**
     * Generates DQL for {@code planNode} and asserts that it equals {@code expectedDql}.
     *
     * <p>If the expectation contains {@code __expressions__}, that placeholder is first replaced
     * by the {@code ", "}-joined values of {@code outputVariables} looked up by the name of each
     * of the plan's output variables, in output order. Variables without a mapping are skipped.
     *
     * @param druidConfig connector configuration; not read by this overload
     * @param planNode plan to translate
     * @param expectedDql expected DQL, optionally containing the placeholder
     * @param sessionHolder supplies the connector session passed to the generator
     * @param outputVariables mapping from output variable name to its expected DQL expression
     */
    private void testDQL(DruidConfig druidConfig, PlanNode planNode, String expectedDql, TestDruidQueryBase.SessionHolder sessionHolder, Map<String, String> outputVariables) {
        DruidQueryGenerator generator = new DruidQueryGenerator(typeManager, functionMetadataManager, standardFunctionResolution);
        // get() is intentional: a plan the generator cannot translate should fail the test here.
        DruidQueryGenerator.DruidQueryGeneratorResult result = generator.generate(planNode, sessionHolder.getConnectorSession()).get();

        String resolvedExpectation = expectedDql;
        if (resolvedExpectation.contains(EXPRESSIONS_PLACEHOLDER)) {
            String expressions = planNode.getOutputVariables().stream()
                    .map(variable -> outputVariables.get(variable.getName()))
                    .filter(expression -> expression != null)
                    .collect(Collectors.joining(", "));
            resolvedExpectation = resolvedExpectation.replace(EXPRESSIONS_PLACEHOLDER, expressions);
        }
        Assert.assertEquals(result.getGeneratedDql().getDql(), resolvedExpectation);
    }

    /** Same as the full overload, using this test's {@code druidConfig}. */
    private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDql, TestDruidQueryBase.SessionHolder sessionHolder, Map<String, String> outputVariables) {
        testDQL(druidConfig, planBuilderConsumer, expectedDql, sessionHolder, outputVariables);
    }

    /** Same as above with an empty output-variable mapping. */
    private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDql, TestDruidQueryBase.SessionHolder sessionHolder) {
        testDQL(planBuilderConsumer, expectedDql, sessionHolder, ImmutableMap.of());
    }

    /** Same as above, using the shared default session. */
    private void testDQL(Function<PlanBuilder, PlanNode> planBuilderConsumer, String expectedDql) {
        testDQL(planBuilderConsumer, expectedDql, defaultSessionHolder);
    }

    /** Builds a plan with a fresh {@link PlanBuilder} bound to the default session. */
    private PlanNode buildPlan(Function<PlanBuilder, PlanNode> planBuilderConsumer) {
        return planBuilderConsumer.apply(createPlanBuilder(defaultSessionHolder));
    }

    /** A limit over a plain table scan selects exactly the scanned columns. */
    @Test
    public void testSimpleSelectStar() {
        testDQL(
                pb -> limit(pb, 50L, tableScan(pb, druidTable, regionId, city, fare, secondsSinceEpoch)),
                "SELECT regionId, city, fare, secondsSinceEpoch FROM realtimeOnly LIMIT 50");
        testDQL(
                pb -> limit(pb, 10L, tableScan(pb, druidTable, regionId, secondsSinceEpoch)),
                "SELECT regionId, secondsSinceEpoch FROM realtimeOnly LIMIT 10");
    }

    /** Limit over project over filter over scan yields SELECT ... WHERE ... LIMIT. */
    @Test
    public void testSimpleSelectWithFilterLimit() {
        testDQL(
                pb -> limit(
                        pb,
                        30L,
                        project(
                                pb,
                                filter(
                                        pb,
                                        tableScan(pb, druidTable, regionId, city, fare, secondsSinceEpoch),
                                        getRowExpression("secondssinceepoch > 20", defaultSessionHolder)),
                                ImmutableList.of("city", "secondssinceepoch"))),
                "SELECT city, secondsSinceEpoch FROM realtimeOnly WHERE (secondsSinceEpoch > 20) LIMIT 30");
    }

    /** count(*) with and without filters, under global grouping and under GROUP BY. */
    @Test
    public void testCountStar() {
        // Adds a single count(*) aggregate, bound to the variable "agg", to an aggregation builder.
        BiConsumer<PlanBuilder, PlanBuilder.AggregationBuilder> addCountStar = (pb, agg) ->
                agg.addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder));

        PlanNode scanOnly = buildPlan(pb ->
                tableScan(pb, druidTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode fareFiltered = buildPlan(pb ->
                filter(
                        pb,
                        tableScan(pb, druidTable, regionId, secondsSinceEpoch, city, fare),
                        getRowExpression("fare > 3", defaultSessionHolder)));
        PlanNode rangeFiltered = buildPlan(pb ->
                filter(
                        pb,
                        tableScan(pb, druidTable, regionId, secondsSinceEpoch, city, fare),
                        getRowExpression("secondssinceepoch between 200 and 300 and regionid >= 40", defaultSessionHolder)));

        testDQL(
                pb -> pb.aggregation(agg -> addCountStar.accept(pb, agg.source(scanOnly).globalGrouping())),
                "SELECT count(*) FROM realtimeOnly");
        testDQL(
                pb -> pb.aggregation(agg -> addCountStar.accept(pb, agg.source(fareFiltered).globalGrouping())),
                "SELECT count(*) FROM realtimeOnly WHERE (fare > 3)");
        testDQL(
                pb -> pb.aggregation(agg -> addCountStar.accept(pb, agg.source(fareFiltered).singleGroupingSet(v("regionid")))),
                "SELECT regionId, count(*) FROM realtimeOnly WHERE (fare > 3) GROUP BY regionId");
        testDQL(
                pb -> pb.aggregation(agg -> addCountStar.accept(pb, agg.source(scanOnly).singleGroupingSet(v("regionid")))),
                "SELECT regionId, count(*) FROM realtimeOnly GROUP BY regionId");
        testDQL(
                pb -> pb.aggregation(agg -> addCountStar.accept(pb, agg.source(rangeFiltered).singleGroupingSet(v("regionid"), v("city")))),
                "SELECT regionId, city, count(*) FROM realtimeOnly WHERE ((secondsSinceEpoch BETWEEN 200 AND 300) AND (regionId >= 40)) GROUP BY regionId, city");
    }

    /**
     * A grouping with no aggregate functions.
     *
     * <p>Note that the plan adds no aggregation, yet the expected DQL still contains
     * {@code count(*)}; this test pins that output.
     */
    @Test
    public void testDistinctSelection() {
        PlanNode scanOnly = buildPlan(pb ->
                tableScan(pb, druidTable, regionId, secondsSinceEpoch, city, fare));
        testDQL(
                pb -> pb.aggregation(agg -> agg.source(scanOnly).singleGroupingSet(v("regionid"))),
                "SELECT regionId, count(*) FROM realtimeOnly GROUP BY regionId");
    }
}
