package com.facebook.presto.druid;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.druid.TestDruidQueryBase;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.MatchResult;
import com.facebook.presto.sql.planner.assertions.Matcher;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.testng.annotations.Test;

/**
 * Tests for {@link DruidPlanOptimizer}: a UNION ALL over two single-step {@code sum} aggregations
 * on Druid tables is expected to be rewritten into a union of Druid table scans, each carrying the
 * pushed-down DQL for its aggregation.
 */
public class TestDruidPlanOptimizer extends TestDruidQueryBase {
    private final DruidTableHandle druidTableOne = realtimeOnlyTable;
    private final DruidTableHandle druidTableTwo = hybridTable;
    private final TestDruidQueryBase.SessionHolder defaultSessionHolder = new TestDruidQueryBase.SessionHolder();

    /**
     * Plan matcher that accepts a {@link TableScanNode} whose Druid table handle refers to the
     * expected table and carries a generated DQL equal (case-insensitively) to the expected one.
     */
    static final class DruidTableScanMatcher implements Matcher {
        private final String tableName;
        private final String expectedDql;

        /**
         * Builds a pattern matching a table scan over {@code tableName} with pushed-down {@code expectedDql}.
         *
         * @param tableName Druid table name the scan's handle must report
         * @param expectedDql DQL the handle must carry (compared ignoring case)
         * @return a pattern usable inside other {@link PlanMatchPattern}s
         */
        static PlanMatchPattern match(String tableName, String expectedDql) {
            return PlanMatchPattern.node(TableScanNode.class).with(new DruidTableScanMatcher(tableName, expectedDql));
        }

        private DruidTableScanMatcher(String tableName, String expectedDql) {
            this.tableName = tableName;
            this.expectedDql = expectedDql;
        }

        @Override
        public boolean shapeMatches(PlanNode planNode) {
            return planNode instanceof TableScanNode;
        }

        @Override
        public MatchResult detailMatches(PlanNode planNode, StatsProvider statsProvider, Session session, Metadata metadata, SymbolAliases symbolAliases) {
            Preconditions.checkState(shapeMatches(planNode), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", getClass().getName());
            // getConnectorHandle() is declared with the generic connector handle type; the scans
            // under test are Druid scans, so the downcast is required (it was lost in decompilation).
            DruidTableHandle druidHandle = (DruidTableHandle) ((TableScanNode) planNode).getTable().getConnectorHandle();
            if (!druidHandle.getTableName().equals(tableName)) {
                return MatchResult.NO_MATCH;
            }
            Optional<String> actualDql = druidHandle.getDql().map(generated -> generated.getDql());
            // A scan with no pushed-down DQL never matches.
            return actualDql.filter(expectedDql::equalsIgnoreCase).isPresent() ? MatchResult.match() : MatchResult.NO_MATCH;
        }

        @Override
        public String toString() {
            return String.format("DruidTableScanMatcher{tableName=%s, expectedDql=%s}", tableName, expectedDql);
        }
    }

    /**
     * A UNION ALL of {@code SELECT city, sum(fare) ... GROUP BY city} over two Druid tables should
     * be optimized into a union of two Druid scans, each with its aggregation pushed into DQL.
     */
    @Test
    public void testUnionAll() {
        PlanBuilder planBuilder = createPlanBuilder(defaultSessionHolder);
        PlanVariableAllocator variableAllocator = new PlanVariableAllocator();

        AggregationNode aggregationOne = simpleAggregationSum(planBuilder, tableScan(planBuilder, druidTableOne, city, fare), variableAllocator, ImmutableList.of(city), fare);
        AggregationNode aggregationTwo = simpleAggregationSum(planBuilder, tableScan(planBuilder, druidTableTwo, city, fare), variableAllocator, ImmutableList.of(city), fare);

        // Output variables of the union: the group-by column and the summed column.
        VariableReferenceExpression cityOutput = variableAllocator.newVariable(city.getColumnName(), city.getColumnType());
        VariableReferenceExpression fareOutput = variableAllocator.newVariable(fare.getColumnName(), fare.getColumnType());

        PlanNode unionNode = new UnionNode(
                Optional.empty(),
                planBuilder.getIdAllocator().getNextId(),
                ImmutableList.of(aggregationOne, aggregationTwo),
                ImmutableList.of(cityOutput, fareOutput),
                ImmutableMap.of(
                        cityOutput,
                        Stream.concat(aggregationOne.getGroupingKeys().stream(), aggregationTwo.getGroupingKeys().stream())
                                .collect(ImmutableList.toImmutableList()),
                        fareOutput,
                        ImmutableList.of(
                                Iterables.getOnlyElement(aggregationOne.getAggregations().keySet()),
                                Iterables.getOnlyElement(aggregationTwo.getAggregations().keySet()))));

        PlanNode optimized = getOptimizedPlan(planBuilder, unionNode);
        PlanMatchPattern expected = PlanMatchPattern.union(
                DruidTableScanMatcher.match(druidTableOne.getTableName(), "SELECT \"city\", sum(fare) FROM \"realtimeOnly\" GROUP BY \"city\""),
                DruidTableScanMatcher.match(druidTableTwo.getTableName(), "SELECT \"city\", sum(fare) FROM \"hybrid\" GROUP BY \"city\""));
        assertPlanMatch(optimized, expected, typeProvider);
    }

    /**
     * Asserts that {@code planNode} matches {@code planMatchPattern} using the default session.
     * Statistics are irrelevant to these tests, so the stats calculator always returns
     * {@link PlanNodeStatsEstimate#unknown()}.
     */
    private void assertPlanMatch(PlanNode planNode, PlanMatchPattern planMatchPattern, TypeProvider typeProvider) {
        PlanAssert.assertPlan(
                defaultSessionHolder.getSession(),
                metadata,
                (node, sourceStats, lookup, session, types) -> PlanNodeStatsEstimate.unknown(),
                new Plan(planNode, typeProvider, StatsAndCosts.empty()),
                planMatchPattern);
    }

    /**
     * Runs {@link DruidPlanOptimizer} over {@code planNode}, reusing the builder's id allocator so
     * newly created nodes get ids distinct from those already in the plan.
     */
    private PlanNode getOptimizedPlan(PlanBuilder planBuilder, PlanNode planNode) {
        DruidQueryGenerator queryGenerator = new DruidQueryGenerator(functionAndTypeManager, functionAndTypeManager, standardFunctionResolution);
        DruidPlanOptimizer optimizer = new DruidPlanOptimizer(
                queryGenerator,
                functionAndTypeManager,
                new RowExpressionDeterminismEvaluator(functionAndTypeManager),
                functionAndTypeManager,
                standardFunctionResolution);
        // NOTE(review): a fresh variable allocator is used here rather than the one that built the
        // plan; confirm the optimizer cannot produce variable names colliding with existing ones.
        return optimizer.optimize(planNode, defaultSessionHolder.getConnectorSession(), new PlanVariableAllocator(), planBuilder.getIdAllocator());
    }

    /**
     * Builds a single-step aggregation computing {@code sum(sumColumn)} over {@code source},
     * grouped by {@code groupByColumns}.
     *
     * <p>Source output variables are selected by name prefix (a variable belongs to a column when
     * its name starts with the column name), presumably so allocator-suffixed names still match.
     * NOTE(review): columns whose names are prefixes of other columns would both be selected.
     *
     * @param planBuilder supplies the plan node id for the aggregation
     * @param source input plan node
     * @param planVariableAllocator allocates the {@code sum} output variable
     * @param groupByColumns columns whose matching outputs become the grouping keys
     * @param sumColumn column whose matching outputs become the {@code sum} arguments
     * @return the aggregation node
     */
    private AggregationNode simpleAggregationSum(PlanBuilder planBuilder, PlanNode source, PlanVariableAllocator planVariableAllocator, List<DruidColumnHandle> groupByColumns, DruidColumnHandle sumColumn) {
        Type sumType = sumColumn.getColumnType();
        CallExpression sumCall = new CallExpression(
                "sum",
                functionAndTypeManager.lookupFunction("sum", TypeSignatureProvider.fromTypes(sumType)),
                sumType,
                source.getOutputVariables().stream()
                        .filter(variable -> variable.getName().startsWith(sumColumn.getColumnName()))
                        .collect(ImmutableList.toImmutableList()));
        List<VariableReferenceExpression> groupingKeys = source.getOutputVariables().stream()
                .filter(variable -> groupByColumns.stream().anyMatch(column -> variable.getName().startsWith(column.getColumnName())))
                .collect(ImmutableList.toImmutableList());
        return new AggregationNode(
                source.getSourceLocation(),
                planBuilder.getIdAllocator().getNextId(),
                source,
                ImmutableMap.of(
                        planVariableAllocator.newVariable("sum", sumType),
                        new AggregationNode.Aggregation(sumCall, Optional.empty(), Optional.empty(), false, Optional.empty())),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }
}
