package com.facebook.presto.pinot.query;

import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.type.TimestampType;
import com.facebook.presto.pinot.PinotConfig;
import com.facebook.presto.pinot.PinotTableHandle;
import com.facebook.presto.pinot.TestPinotQueryBase;
import com.facebook.presto.pinot.query.PinotQueryGenerator;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/pinot/query/TestPinotQueryGenerator.class */
public class TestPinotQueryGenerator extends TestPinotQueryBase {
    // Session holder used by every testPQL overload that does not take an explicit one.
    // NOTE(review): the meaning of the boolean constructor argument is not visible in this file — confirm in TestPinotQueryBase.SessionHolder.
    private static final TestPinotQueryBase.SessionHolder defaultSessionHolder = new TestPinotQueryBase.SessionHolder(false);
    // Table handle passed to every tableScan(...) call in this class; an alias for realtimeOnlyTable,
    // which is declared outside this file (presumably inherited from TestPinotQueryBase).
    private static final PinotTableHandle pinotTable = realtimeOnlyTable;

    /**
     * Builds a plan by applying {@code function} to a plan builder created for
     * {@code sessionHolder}, then hands the resulting node to the {@link PlanNode}-based overload.
     *
     * @param pinotConfig configuration passed to the query generator
     * @param function factory that produces the plan under test from a plan builder
     * @param str expected PQL, optionally containing the {@code __expressions__} placeholder
     * @param sessionHolder supplies the plan builder and the connector session
     * @param map output-variable name to expected expression, used to expand the placeholder
     */
    private void testPQL(PinotConfig pinotConfig, Function<PlanBuilder, PlanNode> function, String str, TestPinotQueryBase.SessionHolder sessionHolder, Map<String, String> map) {
        PlanBuilder freshBuilder = createPlanBuilder(sessionHolder);
        PlanNode planUnderTest = function.apply(freshBuilder);
        testPQL(pinotConfig, planUnderTest, str, sessionHolder, map);
    }

    /**
     * Generates PQL for {@code planNode} and asserts that it equals the expected text.
     *
     * <p>If {@code str} contains {@code __expressions__}, the placeholder is replaced by the
     * comma-separated {@code map} values looked up by the name of each output variable of
     * {@code planNode}, in output order; variables without a mapping are skipped.
     *
     * @param pinotConfig configuration passed to the query generator
     * @param planNode plan to translate
     * @param str expected PQL, optionally containing the {@code __expressions__} placeholder
     * @param sessionHolder supplies the connector session
     * @param map output-variable name to expected expression
     */
    private void testPQL(PinotConfig pinotConfig, PlanNode planNode, String str, TestPinotQueryBase.SessionHolder sessionHolder, Map<String, String> map) {
        PinotQueryGenerator generator = new PinotQueryGenerator(pinotConfig, typeManager, functionMetadataManager, standardFunctionResolution);
        // get() is intentionally unchecked: tests in this class that declare
        // expectedExceptions = NoSuchElementException depend on it throwing.
        PinotQueryGenerator.PinotQueryGeneratorResult generated = generator.generate(planNode, sessionHolder.getConnectorSession()).get();

        String expectedPql = str;
        if (expectedPql.contains("__expressions__")) {
            String selectList = planNode.getOutputVariables().stream()
                    .map(outputVariable -> map.get(outputVariable.getName()))
                    .filter(expression -> expression != null)
                    .collect(Collectors.joining(", "));
            expectedPql = expectedPql.replace("__expressions__", selectList);
        }
        Assert.assertEquals(generated.getGeneratedPql().getPql(), expectedPql);
    }

    /**
     * Same as the five-argument, function-based overload, using this instance's
     * {@code pinotConfig} field as the configuration.
     */
    private void testPQL(Function<PlanBuilder, PlanNode> function, String str, TestPinotQueryBase.SessionHolder sessionHolder, Map<String, String> map) {
        PinotConfig instanceConfig = this.pinotConfig;
        testPQL(instanceConfig, function, str, sessionHolder, map);
    }

    /**
     * Same as the four-argument overload, with an empty placeholder map.
     */
    private void testPQL(Function<PlanBuilder, PlanNode> function, String str, TestPinotQueryBase.SessionHolder sessionHolder) {
        Map<String, String> noExpressions = ImmutableMap.of();
        testPQL(function, str, sessionHolder, noExpressions);
    }

    /**
     * Checks the PQL generated under {@code pinotConfig} for the plan produced by
     * {@code function}, using the default session holder and an empty placeholder map.
     */
    private void testPQL(PinotConfig pinotConfig, Function<PlanBuilder, PlanNode> function, String str) {
        Map<String, String> noExpressions = ImmutableMap.of();
        testPQL(pinotConfig, function, str, defaultSessionHolder, noExpressions);
    }

    /**
     * Checks the PQL generated under {@code pinotConfig} for an already-built plan,
     * using the default session holder and an empty placeholder map.
     */
    private void testPQL(PinotConfig pinotConfig, PlanNode planNode, String str) {
        Map<String, String> noExpressions = ImmutableMap.of();
        testPQL(pinotConfig, planNode, str, defaultSessionHolder, noExpressions);
    }

    /**
     * Shortest form: checks the PQL for the plan produced by {@code function}
     * using the default session holder.
     */
    private void testPQL(Function<PlanBuilder, PlanNode> function, String str) {
        TestPinotQueryBase.SessionHolder session = defaultSessionHolder;
        testPQL(function, str, session);
    }

    /**
     * Applies {@code function} to a plan builder created for the default session holder.
     *
     * @param function factory that produces a plan node from a plan builder
     * @return the node returned by {@code function}
     */
    private PlanNode buildPlan(Function<PlanBuilder, PlanNode> function) {
        PlanBuilder freshBuilder = createPlanBuilder(defaultSessionHolder);
        return function.apply(freshBuilder);
    }

    /**
     * Runs one aggregation through five plan shapes and checks the PQL generated for each:
     * global grouping over a bare scan, global grouping over a {@code fare > 3} filter,
     * grouping by {@code regionid} over that filter, grouping by {@code regionid} over the
     * bare scan, and grouping by {@code regionid, city} over a time-range/region filter.
     *
     * @param biConsumer registers the aggregation under test on the aggregation builder it receives
     * @param str aggregate expression expected in the SELECT list of every generated query
     */
    private void testUnaryAggregationHelper(BiConsumer<PlanBuilder, PlanBuilder.AggregationBuilder> biConsumer, String str) {
        // The three sources shared by the aggregations below.
        PlanNode scanOnly = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode fareFiltered = buildPlan(pb -> filter(
                pb,
                tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare),
                getRowExpression("fare > 3", defaultSessionHolder)));
        PlanNode timeAndRegionFiltered = buildPlan(pb -> filter(
                pb,
                tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare),
                getRowExpression("secondssinceepoch between 200 and 300 and regionid >= 40", defaultSessionHolder)));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> biConsumer.accept(pb, aggBuilder.source(scanOnly).globalGrouping())),
                String.format("SELECT %s FROM realtimeOnly", str));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> biConsumer.accept(pb, aggBuilder.source(fareFiltered).globalGrouping())),
                String.format("SELECT %s FROM realtimeOnly WHERE (fare > 3)", str));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> biConsumer.accept(
                        pb,
                        aggBuilder.source(fareFiltered).singleGroupingSet(variable("regionid")))),
                String.format("SELECT %s FROM realtimeOnly WHERE (fare > 3) GROUP BY regionId TOP 10000", str));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> biConsumer.accept(
                        pb,
                        aggBuilder.source(scanOnly).singleGroupingSet(variable("regionid")))),
                String.format("SELECT %s FROM realtimeOnly GROUP BY regionId TOP 10000", str));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> biConsumer.accept(
                        pb,
                        aggBuilder.source(timeAndRegionFiltered).singleGroupingSet(variable("regionid"), variable("city")))),
                String.format("SELECT %s FROM realtimeOnly WHERE ((secondsSinceEpoch BETWEEN 200 AND 300) AND (regionId >= 40)) GROUP BY regionId, city TOP 10000", str));
    }

    /**
     * A limit directly over a table scan becomes a plain column selection with LIMIT,
     * for both a four-column and a two-column scan.
     */
    @Test
    public void testSimpleSelectStar() {
        testPQL(
                pb -> limit(pb, 50L, tableScan(pb, pinotTable, regionId, city, fare, secondsSinceEpoch)),
                "SELECT regionId, city, fare, secondsSinceEpoch FROM realtimeOnly LIMIT 50");
        testPQL(
                pb -> limit(pb, 50L, tableScan(pb, pinotTable, regionId, secondsSinceEpoch)),
                "SELECT regionId, secondsSinceEpoch FROM realtimeOnly LIMIT 50");
    }

    /**
     * limit(project(filter(scan))) is expected to collapse into one query with a
     * WHERE clause, the projected columns, and LIMIT.
     */
    @Test
    public void testSimpleSelectWithFilterLimit() {
        testPQL(
                pb -> limit(
                        pb,
                        50L,
                        project(
                                pb,
                                filter(
                                        pb,
                                        tableScan(pb, pinotTable, regionId, city, fare, secondsSinceEpoch),
                                        getRowExpression("secondssinceepoch > 20", defaultSessionHolder)),
                                ImmutableList.of("city", "secondssinceepoch"))),
                "SELECT city, secondsSinceEpoch FROM realtimeOnly WHERE (secondsSinceEpoch > 20) LIMIT 50");
    }

    /**
     * {@code count(*)} is expected to appear unchanged in every plan shape exercised by
     * {@code testUnaryAggregationHelper}.
     */
    @Test
    public void testCountStar() {
        testUnaryAggregationHelper(
                (pb, aggBuilder) -> aggBuilder.addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder)),
                "count(*)");
    }

    /**
     * A global {@code count(regionid)} over a group-by-{@code regionid} aggregation that has
     * no aggregates of its own is expected to become a single DISTINCTCOUNT.
     */
    @Test
    public void testDistinctCountPushdown() {
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode distinctRegions = buildPlan(pb -> pb.aggregation(
                aggBuilder -> aggBuilder.source(scan).singleGroupingSet(variable("regionid"))));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(distinctRegions)
                        .globalGrouping()
                        .addAggregation(variable("count_regionid"), getRowExpression("count(regionid)", defaultSessionHolder))),
                "SELECT DISTINCTCOUNT(regionId) FROM realtimeOnly");
    }

    /**
     * {@code count(regionid)} grouped by {@code city}, over an inner aggregation grouped by
     * {@code (city, regionid)}, is expected to become DISTINCTCOUNT with GROUP BY city.
     */
    @Test
    public void testDistinctCountGroupByPushdown() {
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode distinctCityRegions = buildPlan(pb -> pb.aggregation(
                aggBuilder -> aggBuilder.source(scan).singleGroupingSet(variable("city"), variable("regionid"))));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(distinctCityRegions)
                        .singleGroupingSet(variable("city"))
                        .addAggregation(variable("count_regionid"), getRowExpression("count(regionid)", defaultSessionHolder))),
                "SELECT DISTINCTCOUNT(regionId) FROM realtimeOnly GROUP BY city TOP 10000");
    }

    /**
     * With multiple aggregations allowed, a global aggregation over a mark-distinct node that
     * combines {@code count(*)} with a masked {@code count(regionid)} is expected to produce
     * both {@code count(*)} and DISTINCTCOUNT. The expected SELECT-list order follows the
     * position of the {@code count(regionid)} output variable of the aggregation node.
     */
    @Test
    public void testDistinctCountWithOtherAggregationPushdown() {
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode markedDistinct = buildPlan(pb -> markDistinct(
                pb,
                variable("regionid$distinct"),
                ImmutableList.of(variable("regionid")),
                scan));
        PlanNode aggregate = buildPlan(pb -> pb.aggregation(aggBuilder -> aggBuilder
                .source(markedDistinct)
                .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                .addAggregation(
                        pb.variable("count(regionid)"),
                        getRowExpression("count(regionid)", defaultSessionHolder),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.of(variable("regionid$distinct")))
                .globalGrouping()));

        PinotConfig multiAggConfig = new PinotConfig().setAllowMultipleAggregations(true);
        String expectedPql;
        if (aggregate.getOutputVariables().get(0).getName().equalsIgnoreCase("count(regionid)")) {
            expectedPql = "SELECT DISTINCTCOUNT(regionId), count(*) FROM realtimeOnly";
        }
        else {
            expectedPql = "SELECT count(*), DISTINCTCOUNT(regionId) FROM realtimeOnly";
        }
        testPQL(multiAggConfig, pb -> pb.limit(10L, aggregate), expectedPql);
    }

    /**
     * Group-by variant of the distinct-count-with-other-aggregation case: the aggregation is
     * grouped by {@code city}, and the expected SELECT-list order is chosen from the name of
     * the aggregation node's output variable at index 1.
     */
    @Test
    public void testDistinctCountWithOtherAggregationGroupByPushdown() {
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        PlanNode markedDistinct = buildPlan(pb -> markDistinct(
                pb,
                variable("regionid$distinct"),
                ImmutableList.of(variable("regionid")),
                scan));
        PlanNode aggregate = buildPlan(pb -> pb.aggregation(aggBuilder -> aggBuilder
                .source(markedDistinct)
                .singleGroupingSet(variable("city"))
                .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                .addAggregation(
                        pb.variable("count(regionid)"),
                        getRowExpression("count(regionid)", defaultSessionHolder),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        Optional.of(variable("regionid$distinct")))));

        PinotConfig multiAggConfig = new PinotConfig().setAllowMultipleAggregations(true);
        String expectedPql;
        if (aggregate.getOutputVariables().get(1).getName().equalsIgnoreCase("count(regionid)")) {
            expectedPql = "SELECT DISTINCTCOUNT(regionId), count(*) FROM realtimeOnly GROUP BY city TOP 10000";
        }
        else {
            expectedPql = "SELECT count(*), DISTINCTCOUNT(regionId) FROM realtimeOnly GROUP BY city TOP 10000";
        }
        testPQL(multiAggConfig, aggregate, expectedPql);
    }

    /**
     * A group-by on {@code regionid} with no aggregates is expected to be emitted as
     * {@code count(*)} grouped by that column.
     */
    @Test
    public void testDistinctSelection() {
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));
        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder.source(scan).singleGroupingSet(variable("regionid"))),
                "SELECT count(*) FROM realtimeOnly GROUP BY regionId TOP 10000");
    }

    /**
     * {@code approx_percentile(fare, 0.10)} is expected to be emitted as
     * {@code PERCENTILEEST10(fare)}.
     */
    @Test
    public void testPercentileAggregation() {
        testUnaryAggregationHelper(
                (pb, aggBuilder) -> aggBuilder.addAggregation(
                        pb.variable("agg"),
                        getRowExpression("approx_percentile(fare, 0.10)", defaultSessionHolder)),
                "PERCENTILEEST10(fare)");
    }

    /**
     * {@code approx_distinct(fare)} is expected to be emitted as {@code DISTINCTCOUNTHLL(fare)}.
     */
    @Test
    public void testApproxDistinct() {
        testUnaryAggregationHelper(
                (pb, aggBuilder) -> aggBuilder.addAggregation(
                        pb.variable("agg"),
                        getRowExpression("approx_distinct(fare)", defaultSessionHolder)),
                "DISTINCTCOUNTHLL(fare)");
    }

    /**
     * Grouping by a projected {@code date_trunc('day', ...)} expression is expected to be
     * emitted as a {@code dateTimeConvert(...)} call in GROUP BY, first alone and then
     * together with a pass-through {@code city} column.
     */
    @Test
    public void testAggWithUDFInGroupBy() {
        // Typed map instead of a raw LinkedHashMap; insertion order is the projection order.
        LinkedHashMap<String, String> projections = new LinkedHashMap<>();
        projections.put("date", "date_trunc('day', cast(from_unixtime(secondssinceepoch - 50) AS TIMESTAMP))");

        VariableReferenceExpression dateVariable = new VariableReferenceExpression("date", TimestampType.TIMESTAMP);
        String expectedDateBucket = "dateTimeConvert(SUB(secondsSinceEpoch, 50), '1:SECONDS:EPOCH', '1:MILLISECONDS:EPOCH', '1:DAYS')";

        // buildPlan applies the lambda immediately, so this plan is built before "city" is added below.
        PlanNode dateOnlyProjection = buildPlan(pb -> project(
                pb,
                tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare),
                projections,
                defaultSessionHolder));
        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(dateOnlyProjection)
                        .singleGroupingSet(dateVariable)
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))),
                "SELECT count(*) FROM realtimeOnly GROUP BY " + expectedDateBucket + " TOP 10000");

        projections.put("city", "city");
        PlanNode dateAndCityProjection = buildPlan(pb -> project(
                pb,
                tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare),
                projections,
                defaultSessionHolder));
        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(dateAndCityProjection)
                        .singleGroupingSet(dateVariable, variable("city"))
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))),
                "SELECT count(*) FROM realtimeOnly GROUP BY " + expectedDateBucket + ", city TOP 10000");
    }

    /**
     * Two global aggregates ({@code count(*)} and {@code min(fare)}) are expected in one
     * query with the default configuration, both without and with a surrounding limit node;
     * in both cases the expected text carries no LIMIT clause.
     */
    @Test
    public void testMultipleAggregatesWithOutGroupBy() {
        // Typed map instead of a raw ImmutableMap; keys are output-variable names.
        Map<String, String> expressionsByVariable = ImmutableMap.of("agg", "count(*)", "min", "min(fare)");
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));

        testPQL(
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(scan)
                        .globalGrouping()
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                        .addAggregation(pb.variable("min"), getRowExpression("min(fare)", defaultSessionHolder))),
                "SELECT __expressions__ FROM realtimeOnly",
                defaultSessionHolder,
                expressionsByVariable);

        testPQL(
                pb -> pb.limit(50L, pb.aggregation(aggBuilder -> aggBuilder
                        .source(scan)
                        .globalGrouping()
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                        .addAggregation(pb.variable("min"), getRowExpression("min(fare)", defaultSessionHolder)))),
                "SELECT __expressions__ FROM realtimeOnly",
                defaultSessionHolder,
                expressionsByVariable);
    }

    /**
     * Multiple aggregates with GROUP BY are expected to be generated when the configuration
     * enables {@code allowMultipleAggregations}.
     */
    @Test
    public void testMultipleAggregatesWhenAllowed() {
        PinotConfig multiAggConfig = new PinotConfig().setAllowMultipleAggregations(true);
        helperTestMultipleAggregatesWithGroupBy(multiAggConfig);
    }

    /**
     * The same multi-aggregate GROUP BY plan, run with this instance's {@code pinotConfig},
     * is expected to fail with {@link NoSuchElementException}.
     */
    @Test(expectedExceptions = {NoSuchElementException.class})
    public void testMultipleAggregatesNotAllowed() {
        PinotConfig instanceConfig = this.pinotConfig;
        helperTestMultipleAggregatesWithGroupBy(instanceConfig);
    }

    /**
     * Builds {@code count(*)} and {@code min(fare)} grouped by {@code city} and checks the
     * PQL generated under {@code pinotConfig}.
     *
     * @param pinotConfig configuration passed to the query generator
     */
    private void helperTestMultipleAggregatesWithGroupBy(PinotConfig pinotConfig) {
        // Typed map instead of a raw ImmutableMap, so the call below needs no unchecked cast.
        Map<String, String> expressionsByVariable = ImmutableMap.of("agg", "count(*)", "min", "min(fare)");
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));

        testPQL(
                pinotConfig,
                pb -> pb.aggregation(aggBuilder -> aggBuilder
                        .source(scan)
                        .singleGroupingSet(variable("city"))
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                        .addAggregation(pb.variable("min"), getRowExpression("min(fare)", defaultSessionHolder))),
                "SELECT __expressions__ FROM realtimeOnly GROUP BY city TOP 10000",
                defaultSessionHolder,
                expressionsByVariable);
    }

    /**
     * A limit over a two-aggregate GROUP BY, run with this instance's {@code pinotConfig},
     * is expected to fail with {@link NoSuchElementException}. The expected-PQL argument is
     * only compared if generation unexpectedly succeeds.
     */
    @Test(expectedExceptions = {NoSuchElementException.class})
    public void testMultipleAggregateGroupByWithLimitFails() {
        // Typed map instead of a raw ImmutableMap.
        Map<String, String> expressionsByVariable = ImmutableMap.of("agg", "count(*)", "min", "min(fare)");
        PlanNode scan = buildPlan(pb -> tableScan(pb, pinotTable, regionId, secondsSinceEpoch, city, fare));

        testPQL(
                pb -> pb.limit(50L, pb.aggregation(aggBuilder -> aggBuilder
                        .source(scan)
                        .singleGroupingSet(variable("city"))
                        .addAggregation(pb.variable("agg"), getRowExpression("count(*)", defaultSessionHolder))
                        .addAggregation(pb.variable("min"), getRowExpression("min(fare)", defaultSessionHolder)))),
                "SELECT __expressions__ FROM realtimeOnly GROUP BY city TOP 50",
                defaultSessionHolder,
                expressionsByVariable);
    }

    /**
     * A limit over a projection that applies {@code date_trunc('hour', from_unixtime(...))}
     * with no aggregation is expected to fail with {@link NoSuchElementException}.
     */
    @Test(expectedExceptions = {NoSuchElementException.class})
    public void testForbiddenProjectionOutsideOfAggregation() {
        // Typed copy instead of a raw LinkedHashMap built through a raw Map cast;
        // the copy keeps the insertion order of the source map.
        LinkedHashMap<String, String> projections = new LinkedHashMap<>(
                ImmutableMap.of("hour", "date_trunc('hour', from_unixtime(secondssinceepoch))", "regionid", "regionid"));
        PlanNode limitedProjection = buildPlan(pb -> limit(
                pb,
                10L,
                project(pb, tableScan(pb, pinotTable, secondsSinceEpoch, regionId), projections, defaultSessionHolder)));

        // The three-argument overload supplies the default session holder and an empty
        // placeholder map, the same values the original call spelled out.
        testPQL(this.pinotConfig, limitedProjection, "Should fail");
    }

    /**
     * A top-N directly over a table scan is expected to become ORDER BY ... LIMIT, for a
     * single descending key and for an ascending key followed by a descending key. Both
     * top-N nodes are built from the same plan builder and the same scan node.
     */
    @Test
    public void testSimpleSelectWithTopN() {
        PlanBuilder builder = createPlanBuilder(defaultSessionHolder);
        TableScanNode scan = tableScan(builder, pinotTable, regionId, city, fare);
        Map<String, String> noExpressions = ImmutableMap.of();

        testPQL(
                this.pinotConfig,
                topN(builder, 50L, ImmutableList.of("fare"), ImmutableList.of(false), scan),
                "SELECT regionId, city, fare FROM realtimeOnly ORDER BY fare DESC LIMIT 50",
                defaultSessionHolder,
                noExpressions);
        testPQL(
                this.pinotConfig,
                topN(builder, 50L, ImmutableList.of("fare", "city"), ImmutableList.of(true, false), scan),
                "SELECT regionId, city, fare FROM realtimeOnly ORDER BY fare, city DESC LIMIT 50",
                defaultSessionHolder,
                noExpressions);
    }

    /**
     * A FINAL top-N ordered by {@code city} over a {@code sum(fare)} GROUP BY city aggregation
     * is expected to fail with {@link NoSuchElementException}; the expected-PQL argument is
     * left empty.
     */
    @Test(expectedExceptions = {NoSuchElementException.class})
    public void testAggregationWithOrderByPushDownInTopN() {
        PlanBuilder builder = createPlanBuilder(defaultSessionHolder);
        TableScanNode scan = tableScan(builder, pinotTable, city, fare);
        OrderingScheme byCityDescending = new OrderingScheme(
                ImmutableList.of(new Ordering(variable("city"), SortOrder.DESC_NULLS_FIRST)));
        Map<String, String> noExpressions = ImmutableMap.of();

        // The node id is requested before the aggregation node is built, as in the original expression.
        TopNNode topNOverAggregation = new TopNNode(
                builder.getIdAllocator().getNextId(),
                builder.aggregation(aggBuilder -> aggBuilder
                        .source(scan)
                        .singleGroupingSet(variable("city"))
                        .addAggregation(builder.variable("agg"), getRowExpression("sum(fare)", defaultSessionHolder))),
                50L,
                byCityDescending,
                TopNNode.Step.FINAL);

        testPQL(this.pinotConfig, topNOverAggregation, "", defaultSessionHolder, noExpressions);
    }
}
