package com.facebook.presto.pinot.query;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.pinot.PinotColumnHandle;
import com.facebook.presto.pinot.PinotConfig;
import com.facebook.presto.pinot.PinotPlanOptimizer;
import com.facebook.presto.pinot.PinotTableHandle;
import com.facebook.presto.pinot.TestPinotQueryBase;
import com.facebook.presto.pinot.TestPinotSplitManager;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.MatchResult;
import com.facebook.presto.sql.planner.assertions.Matcher;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/pinot/query/TestPinotPlanOptimizer.class */
/**
 * Plan-level tests for {@link PinotPlanOptimizer}. Each test builds a small logical plan over the
 * {@code hybrid} Pinot table with a {@link PlanBuilder}, runs the optimizer, and asserts which
 * operators were folded into the generated Pinot query (matched by regular expression) and which
 * remained in the Presto plan.
 *
 * <p>{@link #useSqlSyntax()} feeds the default session and every expected scan pattern; it is public
 * and non-final, presumably so subclasses can rerun the suite with a different setting.
 */
public class TestPinotPlanOptimizer extends TestPinotQueryBase {
    private final LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(new RowExpressionDeterminismEvaluator(functionAndTypeManager), new FunctionResolution(functionAndTypeManager), functionAndTypeManager);
    protected final PinotTableHandle pinotTable = TestPinotSplitManager.hybridTable;
    // NOTE(review): getDefaultSessionHolder() is overridable and runs during construction, before any
    // subclass field initializers; overrides must not depend on subclass instance state.
    protected final TestPinotQueryBase.SessionHolder defaultSessionHolder = getDefaultSessionHolder();

    /**
     * Plan matcher that accepts a {@link TableScanNode} over a given Pinot table whose generated Pinot
     * query fully matches a case-insensitive regular expression. On a match, the configured column names
     * are bound as symbol aliases.
     */
    public static final class PinotTableScanMatcher implements Matcher {
        private final ConnectorId connectorId;
        private final String tableName;
        private final Optional<String> pinotQueryRegex;
        // Recorded for diagnostics (toString) only; detailMatches does not check it.
        private final Optional<Boolean> scanParallelismExpected;
        private final String[] columns;
        // Recorded for diagnostics (toString) only; detailMatches does not check it.
        private final boolean useSqlSyntax;

        /**
         * Builds a pattern matching a scan of {@code tableName} on connector {@code connectorId}.
         *
         * @param connectorId connector id the scanned table handle must belong to
         * @param tableName expected Pinot table name
         * @param pinotQueryRegex case-insensitive regex the generated Pinot query must fully match;
         *     when empty, the scan must carry no generated query at all
         * @param scanParallelismExpected recorded for diagnostics only (not checked)
         * @param useSqlSyntax recorded for diagnostics only (not checked)
         * @param columns output column names bound as symbol aliases on a successful match
         * @return a pattern for {@link TableScanNode}s carrying this matcher
         */
        static PlanMatchPattern match(String connectorId, String tableName, Optional<String> pinotQueryRegex, Optional<Boolean> scanParallelismExpected, boolean useSqlSyntax, String... columns) {
            return PlanMatchPattern.node(TableScanNode.class)
                    .with(new PinotTableScanMatcher(new ConnectorId(connectorId), tableName, pinotQueryRegex, scanParallelismExpected, useSqlSyntax, columns));
        }

        /**
         * Convenience overload that takes the connector id and table name from {@code pinotTableHandle}
         * and the alias names from {@code outputVariables}.
         */
        public static PlanMatchPattern match(PinotTableHandle pinotTableHandle, Optional<String> pinotQueryRegex, Optional<Boolean> scanParallelismExpected, List<VariableReferenceExpression> outputVariables, boolean useSqlSyntax) {
            String[] columnNames = outputVariables.stream()
                    .map(VariableReferenceExpression::getName)
                    .toArray(String[]::new);
            return match(pinotTableHandle.getConnectorId(), pinotTableHandle.getTableName(), pinotQueryRegex, scanParallelismExpected, useSqlSyntax, columnNames);
        }

        private PinotTableScanMatcher(ConnectorId connectorId, String tableName, Optional<String> pinotQueryRegex, Optional<Boolean> scanParallelismExpected, boolean useSqlSyntax, String... columns) {
            this.connectorId = connectorId;
            this.tableName = tableName;
            this.pinotQueryRegex = pinotQueryRegex;
            this.scanParallelismExpected = scanParallelismExpected;
            this.useSqlSyntax = useSqlSyntax;
            this.columns = columns;
        }

        @Override
        public boolean shapeMatches(PlanNode planNode) {
            return planNode instanceof TableScanNode;
        }

        /**
         * Returns true when both the regex and the actual query are absent, or when both are present and
         * the whole query matches the regex case-insensitively. One present and one absent never matches.
         */
        private static boolean checkPinotQueryMatches(Optional<String> regex, Optional<String> actualQuery) {
            if (!regex.isPresent() || !actualQuery.isPresent()) {
                return !regex.isPresent() && !actualQuery.isPresent();
            }
            return Pattern.compile(regex.get(), Pattern.CASE_INSENSITIVE).matcher(actualQuery.get()).matches();
        }

        @Override
        public MatchResult detailMatches(PlanNode planNode, StatsProvider statsProvider, Session session, Metadata metadata, SymbolAliases symbolAliases) {
            Preconditions.checkState(shapeMatches(planNode), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", getClass().getName());
            TableScanNode tableScanNode = (TableScanNode) planNode;
            if (!connectorId.equals(tableScanNode.getTable().getConnectorId())) {
                return MatchResult.NO_MATCH;
            }
            // Presumably a Pinot handle once the connector id matched; the cast fails fast otherwise.
            PinotTableHandle pinotTableHandle = (PinotTableHandle) tableScanNode.getTable().getConnectorHandle();
            if (!pinotTableHandle.getTableName().equals(tableName)) {
                return MatchResult.NO_MATCH;
            }
            Optional<String> actualQuery = pinotTableHandle.getPinotQuery().map(query -> query.getQuery());
            if (!checkPinotQueryMatches(pinotQueryRegex, actualQuery)) {
                return MatchResult.NO_MATCH;
            }
            // Each expected column name becomes an alias that refers to a symbol of the same name.
            Map<String, SymbolReference> aliases = Arrays.stream(columns)
                    .collect(Collectors.toMap(Function.identity(), SymbolReference::new));
            return MatchResult.match(SymbolAliases.builder().putAll(aliases).build());
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("connectorId", connectorId)
                    .add("tableName", tableName)
                    .add("pinotQueryRegex", pinotQueryRegex)
                    .add("scanParallelismExpected", scanParallelismExpected)
                    .add("useSqlSyntax", useSqlSyntax)
                    .add("columns", columns)
                    .toString();
        }
    }

    /** Session holder used by tests that do not supply their own {@link PinotConfig}. */
    public TestPinotQueryBase.SessionHolder getDefaultSessionHolder() {
        return new TestPinotQueryBase.SessionHolder(false, useSqlSyntax());
    }

    /** Value passed to the default session and to every expected scan pattern; {@code false} here. */
    public boolean useSqlSyntax() {
        return false;
    }

    /**
     * Asserts that {@code planNode} matches {@code planMatchPattern}, using the default session and a
     * stats provider that reports unknown statistics for every node.
     */
    public void assertPlanMatch(PlanNode planNode, PlanMatchPattern planMatchPattern, TypeProvider typeProvider) {
        PlanAssert.assertPlan(
                defaultSessionHolder.getSession(),
                metadata,
                (node, sourceStats, lookup, session, types) -> PlanNodeStatsEstimate.unknown(),
                new Plan(planNode, typeProvider, StatsAndCosts.empty()),
                planMatchPattern);
    }

    /** A bare LIMIT over a full scan is pushed entirely into the Pinot query. */
    @Test
    public void testLimitPushdownWithStarSelection() {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        LimitNode limit = limit(planBuilder, 50L, tableScan(planBuilder, pinotTable, regionId, city, fare, secondsSinceEpoch));
        PlanMatchPattern expected = PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT \"regionId\", \"city\", \"fare\", \"secondsSinceEpoch\" FROM hybrid LIMIT 50"), Optional.of(false), limit.getOutputVariables(), useSqlSyntax());
        assertPlanMatch(getOptimizedPlan(planBuilder, limit), expected, typeProvider);
    }

    /**
     * Only the translatable conjunct ({@code fare > 100}) is pushed into Pinot; the
     * {@code lower(substr(...))} conjunct stays in a Presto filter beneath the limit.
     */
    @Test
    public void testPartialPredicatePushdown() {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        FilterNode filter = filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, city, fare, secondsSinceEpoch), getRowExpression("lower(substr(city, 0, 3)) = 'del' AND fare > 100", defaultSessionHolder));
        PlanMatchPattern pushedScan = PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT \"regionId\", \"city\", \"fare\", \"secondsSinceEpoch\" FROM hybrid__TABLE_NAME_SUFFIX_TEMPLATE__ WHERE \\(\"fare\" > 100\\).*"), Optional.of(true), filter.getOutputVariables(), useSqlSyntax());
        PlanMatchPattern expected = PlanMatchPattern.limit(50L, PlanMatchPattern.filter("lower(substr(city, 0, 3)) = 'del'", pushedScan));
        assertPlanMatch(getOptimizedPlan(planBuilder, limit(planBuilder, 50L, filter)), expected, typeProvider);
    }

    // In the date/time tests below, 2014-01-31 is day 16101 after the epoch, i.e. 1391126400000 ms.

    @Test
    public void testDatePredicatePushdown() {
        assertFilterWithLimitPushedDown(daysSinceEpoch, "dayssinceepoch < DATE '2014-01-31'",
                "SELECT \"regionId\", \"city\", \"fare\", \"daysSinceEpoch\" FROM hybrid WHERE \\(\"daysSinceEpoch\" < 16101\\) LIMIT 50");
    }

    @Test
    public void testDateCastingPredicatePushdown() {
        assertFilterWithLimitPushedDown(daysSinceEpoch, "cast(dayssinceepoch as timestamp) < TIMESTAMP '2014-01-31 00:00:00 UTC'",
                "SELECT \"regionId\", \"city\", \"fare\", \"daysSinceEpoch\" FROM hybrid WHERE \\(\"daysSinceEpoch\" < 16101\\) LIMIT 50");
    }

    @Test
    public void testTimestampPredicatePushdown() {
        assertFilterWithLimitPushedDown(millisSinceEpoch, "millissinceepoch < TIMESTAMP '2014-01-31 00:00:00 UTC'",
                "SELECT \"regionId\", \"city\", \"fare\", \"millisSinceEpoch\" FROM hybrid WHERE \\(\"millisSinceEpoch\" < 1391126400000\\) LIMIT 50");
    }

    @Test
    public void testTimestampCastingPredicatePushdown() {
        assertFilterWithLimitPushedDown(millisSinceEpoch, "cast(millissinceepoch as date) < DATE '2014-01-31'",
                "SELECT \"regionId\", \"city\", \"fare\", \"millisSinceEpoch\" FROM hybrid WHERE \\(\"millisSinceEpoch\" < 1391126400000\\) LIMIT 50");
    }

    @Test
    public void testDateFieldCompareToTimestampLiteralPredicatePushdown() {
        assertFilterWithLimitPushedDown(daysSinceEpoch, "dayssinceepoch <  TIMESTAMP '2014-01-31 00:00:00 UTC'",
                "SELECT \"regionId\", \"city\", \"fare\", \"daysSinceEpoch\" FROM hybrid WHERE \\(\"daysSinceEpoch\" < 16101\\) LIMIT 50");
    }

    @Test
    public void testTimestampFieldCompareToDateLiteralPredicatePushdown() {
        assertFilterWithLimitPushedDown(millisSinceEpoch, "millissinceepoch <  DATE '2014-01-31'",
                "SELECT \"regionId\", \"city\", \"fare\", \"millisSinceEpoch\" FROM hybrid WHERE \\(\"millisSinceEpoch\" < 1391126400000\\) LIMIT 50");
    }

    /**
     * Builds {@code LIMIT 50} over a filter on a scan of (regionId, city, fare, {@code timeColumn}),
     * optimizes it, and asserts the whole plan collapsed into one Pinot scan whose query matches
     * {@code expectedQueryRegex}.
     */
    private void assertFilterWithLimitPushedDown(PinotColumnHandle timeColumn, String predicate, String expectedQueryRegex) {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        LimitNode limit = limit(planBuilder, 50L, filter(planBuilder, tableScan(planBuilder, pinotTable, regionId, city, fare, timeColumn), getRowExpression(predicate, defaultSessionHolder)));
        PlanMatchPattern expected = PinotTableScanMatcher.match(pinotTable, Optional.of(expectedQueryRegex), Optional.of(false), limit.getOutputVariables(), useSqlSyntax());
        assertPlanMatch(getOptimizedPlan(planBuilder, limit), expected, typeProvider);
    }

    /**
     * A global {@code count(*)} over a LIMIT is not pushed down: the aggregation stays in Presto and
     * only the limited scan is sent to Pinot.
     */
    @Test
    public void testUnsupportedPredicatePushdown() {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        LimitNode limit = limit(planBuilder, 50L, tableScan(planBuilder, pinotTable, regionId, city, fare, secondsSinceEpoch));
        AggregationNode aggregation = planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .source(limit)
                .globalGrouping()
                .addAggregation(new VariableReferenceExpression(Optional.empty(), "count", BigintType.BIGINT), getRowExpression("count(*)", defaultSessionHolder)));
        PlanMatchPattern limitedScan = PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT \"regionId\", \"city\", \"fare\", \"secondsSinceEpoch\" FROM hybrid LIMIT 50"), Optional.of(false), aggregation.getOutputVariables(), useSqlSyntax());
        PlanMatchPattern expected = PlanMatchPattern.aggregation(
                ImmutableMap.of("count", PlanMatchPattern.functionCall("count", false, ImmutableList.of())),
                limitedScan);
        assertPlanMatch(getOptimizedPlan(planBuilder, aggregation), expected, typeProvider);
    }

    /** Optimizes {@code planNode} with a default {@link PinotConfig}. */
    public PlanNode getOptimizedPlan(PlanBuilder planBuilder, PlanNode planNode) {
        return getOptimizedPlan(new PinotConfig(), planBuilder, planNode);
    }

    /** Runs {@link PinotPlanOptimizer} over {@code planNode} using a session derived from {@code pinotConfig}. */
    protected PlanNode getOptimizedPlan(PinotConfig pinotConfig, PlanBuilder planBuilder, PlanNode planNode) {
        PinotQueryGenerator queryGenerator = new PinotQueryGenerator(pinotConfig, functionAndTypeManager, functionAndTypeManager, standardFunctionResolution);
        PinotPlanOptimizer optimizer = new PinotPlanOptimizer(queryGenerator, functionAndTypeManager, functionAndTypeManager, logicalRowExpressions, standardFunctionResolution);
        return optimizer.optimize(planNode, new TestPinotQueryBase.SessionHolder(pinotConfig).getConnectorSession(), new PlanVariableAllocator(), planBuilder.getIdAllocator());
    }

    /**
     * For each supported override, {@code count(DISTINCT regionid)} is pushed down as the configured
     * function, while {@code approx_distinct} and plain scans are unaffected by the override.
     */
    @Test
    public void testDistinctCountInSubQueryPushdown() {
        for (String distinctCountFunction : Arrays.asList("DISTINCTCOUNT", "DISTINCTCOUNTBITMAP", "SEGMENTPARTITIONEDDISTINCTCOUNT")) {
            PinotConfig pinotConfig = new PinotConfig().setOverrideDistinctCountFunction(distinctCountFunction);
            testDistinctCountInSubQueryPushdown(distinctCountFunction, pinotConfig);
            testDistinctCountPushdownNoOverride(pinotConfig);
        }
    }

    private void testDistinctCountInSubQueryPushdown(String distinctCountFunction, PinotConfig pinotConfig) {
        PlanBuilder planBuilder = createPlanBuilder(new TestPinotQueryBase.SessionHolder(pinotConfig));
        Optional<String> expectedQuery = Optional.of(String.format("SELECT %s\\(\"regionId\"\\) FROM hybrid", distinctCountFunction));

        AggregationNode aggregation = countDistinctRegionId(planBuilder, "regionid", "regionid$distinct");
        assertPlanMatch(getOptimizedPlan(pinotConfig, planBuilder, aggregation), PinotTableScanMatcher.match(pinotTable, expectedQuery, Optional.of(false), aggregation.getOutputVariables(), useSqlSyntax()), typeProvider);

        // Same query again, with suffixed variable names for the scanned column and the distinct marker.
        AggregationNode suffixedAggregation = countDistinctRegionId(planBuilder, "regionid_33", "regionid$distinct_62");
        assertPlanMatch(getOptimizedPlan(pinotConfig, planBuilder, suffixedAggregation), PinotTableScanMatcher.match(pinotTable, expectedQuery, Optional.of(false), suffixedAggregation.getOutputVariables(), useSqlSyntax()), typeProvider);
    }

    private void testDistinctCountPushdownNoOverride(PinotConfig pinotConfig) {
        PlanBuilder planBuilder = createPlanBuilder(new TestPinotQueryBase.SessionHolder(pinotConfig));
        Map<VariableReferenceExpression, PinotColumnHandle> assignments = ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), "regionid", regionId.getDataType()), regionId);
        TableScanNode tableScan = tableScan(planBuilder, pinotTable, assignments);
        AggregationNode aggregation = planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .source(tableScan)
                .addAggregation(planBuilder.variable("approx_distinct(regionid)"), getRowExpression("approx_distinct(regionid)", defaultSessionHolder), Optional.empty(), Optional.empty(), false, Optional.empty())
                .globalGrouping());
        // approx_distinct always maps to DISTINCTCOUNTHLL, regardless of the distinct-count override.
        assertPlanMatch(getOptimizedPlan(pinotConfig, planBuilder, aggregation), PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT DISTINCTCOUNTHLL\\(\"regionId\"\\) FROM hybrid"), Optional.of(false), aggregation.getOutputVariables(), useSqlSyntax()), typeProvider);

        // A plain limited scan must also be unaffected by the override, so optimize it with the same config.
        PlanNode optimizedPlan = getOptimizedPlan(pinotConfig, planBuilder, limit(planBuilder, 50L, tableScan(planBuilder, pinotTable, distinctCountDim)));
        assertPlanMatch(optimizedPlan, PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT \"distinctCountDim\" FROM hybrid LIMIT 50"), Optional.of(false), optimizedPlan.getOutputVariables(), useSqlSyntax()), typeProvider);
    }

    /** Each input of UNION / INTERSECT / EXCEPT over two count-distinct subqueries is pushed down independently. */
    @Test
    public void testSetOperationQueryWithSubQueriesPushdown() {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        AggregationNode aggregation = countDistinctRegionId(planBuilder, "regionid", "regionid$distinct");
        AggregationNode suffixedAggregation = countDistinctRegionId(planBuilder, "regionid_33", "regionid$distinct_62");
        validateSetOperationOptimizer(planBuilder, planBuilder.union(ArrayListMultimap.create(), ImmutableList.of(aggregation, suffixedAggregation)));
        validateSetOperationOptimizer(planBuilder, planBuilder.intersect(ArrayListMultimap.create(), ImmutableList.of(aggregation, suffixedAggregation)));
        validateSetOperationOptimizer(planBuilder, planBuilder.except(ArrayListMultimap.create(), ImmutableList.of(aggregation, suffixedAggregation)));
    }

    /**
     * Builds {@code count(DISTINCT regionid)} in MarkDistinct form: a scan exposing regionId as
     * {@code scanVariableName}, a MarkDistinctNode producing {@code distinctMarkerName}, and a global
     * aggregation of {@code count(scanVariableName)} masked by that marker.
     */
    private AggregationNode countDistinctRegionId(PlanBuilder planBuilder, String scanVariableName, String distinctMarkerName) {
        Map<VariableReferenceExpression, PinotColumnHandle> assignments = ImmutableMap.of(new VariableReferenceExpression(Optional.empty(), scanVariableName, regionId.getDataType()), regionId);
        MarkDistinctNode markDistinct = markDistinct(planBuilder, variable(distinctMarkerName), ImmutableList.of(variable("regionid")), tableScan(planBuilder, pinotTable, assignments));
        String countCall = "count(" + scanVariableName + ")";
        return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                .source(markDistinct)
                .addAggregation(planBuilder.variable(countCall), getRowExpression(countCall, defaultSessionHolder), Optional.empty(), Optional.empty(), false, Optional.of(variable(distinctMarkerName)))
                .globalGrouping());
    }

    /** Optimizes a set-operation plan and asserts every direct input collapsed into a DISTINCTCOUNT Pinot scan. */
    private void validateSetOperationOptimizer(PlanBuilder planBuilder, PlanNode planNode) {
        for (PlanNode source : getOptimizedPlan(planBuilder, planNode).getSources()) {
            assertPlanMatch(source, PinotTableScanMatcher.match(pinotTable, Optional.of("SELECT DISTINCTCOUNT\\(\"regionId\"\\) FROM hybrid"), Optional.of(false), source.getOutputVariables(), useSqlSyntax()), typeProvider);
        }
    }
}
