package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.tree.SortItem;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link PushAggregationThroughOuterJoin}.
 *
 * <p>Each test builds a plan of the shape {@code Aggregation -> Join -> (Values, Values)} with
 * {@link PlanBuilder}, applies the rule through the {@code tester()} harness, and then either
 * matches the rewritten plan against a {@link PlanMatchPattern} or asserts that the rule does
 * not fire.
 */
public class TestPushAggregationThroughOuterJoin extends BaseRuleTest {
    public TestPushAggregationThroughOuterJoin() {
        super(new Plugin[0]);
    }

    /** Creates a fresh instance of the rule under test. */
    private PushAggregationThroughOuterJoin newRule() {
        return new PushAggregationThroughOuterJoin(getFunctionManager());
    }

    /**
     * {@code avg(COL2)} grouped by {@code COL1} over a LEFT join whose left side is a single-row
     * VALUES. Expects a projection of {@code coalesce(AVG, AVG_NULL)} on top of an INNER join
     * without criteria between (a) the LEFT join with the aggregation moved onto the COL2 side
     * and (b) a global aggregation over a {@code null_literal} column.
     */
    @Test
    public void testPushesAggregationThroughLeftJoin() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.variable("COL1")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 10L))),
                                p.values(p.variable("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL2)"))
                        .singleGroupingSet(p.variable("COL1"))))
                .matches(
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "COL1", PlanMatchPattern.expression("COL1"),
                                        "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")),
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(),
                                        PlanMatchPattern.join(
                                                JoinNode.Type.LEFT,
                                                ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                                PlanMatchPattern.values(ImmutableMap.of("COL1", 0)),
                                                PlanMatchPattern.aggregation(
                                                        PlanMatchPattern.singleGroupingSet("COL2"),
                                                        ImmutableMap.of(
                                                                Optional.of("AVG"),
                                                                PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                                                        ImmutableMap.of(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        PlanMatchPattern.values(ImmutableMap.of("COL2", 0)))),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.globalAggregation(),
                                                ImmutableMap.of(
                                                        Optional.of("AVG_NULL"),
                                                        PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                                                ImmutableMap.of(),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("null_literal", 0))))));
    }

    /**
     * Same shape as the LEFT join test, but the aggregation carries an ordering scheme on
     * {@code COL4} (a right-side column) and groups by {@code COL1, COL3}. The expected plan keeps
     * the ordering on the pushed-down aggregation and orders the global aggregation by a second
     * null column, {@code null_literal2}.
     */
    @Test
    public void testPushesAggregationThroughLeftJoinWithOrderByFromRightSideColumn() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.variable("COL1"), p.variable("COL3")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 10L, 20L))),
                                p.values(p.variable("COL2"), p.variable("COL4")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(
                                p.variable("AVG", DoubleType.DOUBLE),
                                p.rowExpression("avg(COL2)"),
                                Optional.empty(),
                                Optional.of(new OrderingScheme(ImmutableList.of(
                                        new Ordering(p.variable("COL4"), SortOrder.ASC_NULLS_LAST)))),
                                false,
                                Optional.empty())
                        .singleGroupingSet(p.variable("COL1"), p.variable("COL3"))))
                .matches(
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "COL1", PlanMatchPattern.expression("COL1"),
                                        "COL3", PlanMatchPattern.expression("COL3"),
                                        "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")),
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(),
                                        PlanMatchPattern.join(
                                                JoinNode.Type.LEFT,
                                                ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                                PlanMatchPattern.values(ImmutableMap.of("COL1", 0, "COL3", 0)),
                                                PlanMatchPattern.aggregation(
                                                        PlanMatchPattern.singleGroupingSet("COL2"),
                                                        ImmutableMap.of(
                                                                Optional.of("AVG"),
                                                                PlanMatchPattern.functionCall(
                                                                        "avg",
                                                                        ImmutableList.of("COL2"),
                                                                        ImmutableList.of(PlanMatchPattern.sort(
                                                                                "COL4",
                                                                                SortItem.Ordering.ASCENDING,
                                                                                SortItem.NullOrdering.LAST)))),
                                                        ImmutableMap.of(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        PlanMatchPattern.values(ImmutableList.of("COL2", "COL4")))),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.globalAggregation(),
                                                ImmutableMap.of(
                                                        Optional.of("AVG_NULL"),
                                                        PlanMatchPattern.functionCall(
                                                                "avg",
                                                                ImmutableList.of("null_literal"),
                                                                ImmutableList.of(PlanMatchPattern.sort(
                                                                        "null_literal2",
                                                                        SortItem.Ordering.ASCENDING,
                                                                        SortItem.NullOrdering.LAST)))),
                                                ImmutableMap.of(),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableList.of("null_literal", "null_literal2"))))));
    }

    /**
     * Mirror image of the LEFT join test: a RIGHT join whose right side is the single-row VALUES
     * producing {@code COL1}. The aggregation is expected to end up on the left (COL2) input.
     */
    @Test
    public void testPushesAggregationThroughRightJoin() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.RIGHT,
                                p.values(p.variable("COL2")),
                                p.values(
                                        ImmutableList.of(p.variable("COL1")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 10L))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL2"), p.variable("COL1"))),
                                ImmutableList.of(p.variable("COL2"), p.variable("COL1")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL2)"))
                        .singleGroupingSet(p.variable("COL1"))))
                .matches(
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)"),
                                        "COL1", PlanMatchPattern.expression("COL1")),
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(),
                                        PlanMatchPattern.join(
                                                JoinNode.Type.RIGHT,
                                                ImmutableList.of(PlanMatchPattern.equiJoinClause("COL2", "COL1")),
                                                PlanMatchPattern.aggregation(
                                                        PlanMatchPattern.singleGroupingSet("COL2"),
                                                        ImmutableMap.of(
                                                                Optional.of("AVG"),
                                                                PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                                                        ImmutableMap.of(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        PlanMatchPattern.values(ImmutableMap.of("COL2", 0))),
                                                PlanMatchPattern.values(ImmutableMap.of("COL1", 0))),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.globalAggregation(),
                                                ImmutableMap.of(
                                                        Optional.of("AVG_NULL"),
                                                        PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                                                ImmutableMap.of(),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.values(ImmutableMap.of("null_literal", 0))))));
    }

    /**
     * Two plans for which the rule must not fire:
     * <ol>
     *   <li>the left input is a two-row VALUES (10 and 11);</li>
     *   <li>the left input projects only {@code COL1} from an aggregation grouped by
     *       {@code (COL1, unused)}.</li>
     * </ol>
     * NOTE(review): judging by the test name, the rule presumably cannot prove the outer side is
     * distinct on the grouping key in either plan - confirm against the rule implementation.
     */
    @Test
    public void testDoesNotFireWhenNotDistinct() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.variable("COL1")),
                                        ImmutableList.of(
                                                PlanBuilder.constantExpressions(BigintType.BIGINT, 10L),
                                                PlanBuilder.constantExpressions(BigintType.BIGINT, 11L))),
                                p.values(p.variable("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL2)"))
                        .singleGroupingSet(p.variable("COL1"))))
                .doesNotFire();

        // The inner aggregation's lambda parameter needs its own name: a lambda parameter may not
        // redeclare the parameter of an enclosing lambda.
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.project(
                                        AssignmentUtils.identityAssignmentsAsSymbolReferences(
                                                new VariableReferenceExpression[]{p.variable("COL1", BigintType.BIGINT)}),
                                        p.aggregation(innerAb -> innerAb
                                                .singleGroupingSet(p.variable("COL1"), p.variable("unused"))
                                                .source(p.values(
                                                        ImmutableList.of(p.variable("COL1"), p.variable("unused")),
                                                        ImmutableList.of(
                                                                PlanBuilder.constantExpressions(BigintType.BIGINT, 10L, 1L),
                                                                PlanBuilder.constantExpressions(BigintType.BIGINT, 10L, 2L)))))),
                                p.values(p.variable("COL2")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(
                                p.variable("AVG", DoubleType.DOUBLE),
                                PlanBuilder.expression("avg(COL2)"),
                                ImmutableList.of(DoubleType.DOUBLE))
                        .singleGroupingSet(p.variable("COL1"))))
                .doesNotFire();
    }

    /**
     * The grouping set contains {@code COL3}, which comes from the right (non-preserved) side of
     * the LEFT join; the rule must not fire.
     */
    @Test
    public void testDoesNotFireWhenGroupingOnInner() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.variable("COL1")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 10L))),
                                p.values(p.variable("COL2"), p.variable("COL3")),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.variable("AVG", DoubleType.DOUBLE), p.rowExpression("avg(COL2)"))
                        .singleGroupingSet(p.variable("COL1"), p.variable("COL3"))))
                .doesNotFire();
    }

    /**
     * The aggregation is {@code sum(COL1)}, i.e. it references only the left-side column and no
     * column of the right input of the LEFT join; the rule must not fire.
     */
    @Test
    public void testDoesNotFireWhenAggregationDoesNotHaveSymbols() {
        tester().assertThat(newRule())
                .on(p -> p.aggregation(ab -> ab
                        .source(p.join(
                                JoinNode.Type.LEFT,
                                p.values(
                                        ImmutableList.of(p.variable("COL1")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 10L))),
                                p.values(
                                        ImmutableList.of(p.variable("COL2")),
                                        ImmutableList.of(PlanBuilder.constantExpressions(BigintType.BIGINT, 20L))),
                                ImmutableList.of(new JoinNode.EquiJoinClause(p.variable("COL1"), p.variable("COL2"))),
                                ImmutableList.of(p.variable("COL1"), p.variable("COL2")),
                                Optional.empty(),
                                Optional.empty(),
                                Optional.empty()))
                        .addAggregation(p.variable("SUM", DoubleType.DOUBLE), p.rowExpression("sum(COL1)"))
                        .singleGroupingSet(p.variable("COL1"))))
                .doesNotFire();
    }
}
