package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestPushAggregationThroughOuterJoin.class */
/**
 * Tests for the {@link PushAggregationThroughOuterJoin} optimizer {@link Rule}.
 *
 * <p>Shape of the positive cases:
 * <ul>
 *   <li>The input is {@code Aggregation(GROUP BY COL1; AVG := avg(COL2))} over a LEFT or RIGHT
 *       outer join. The outer side of that join is a single-row VALUES producing {@code COL1}.</li>
 *   <li>The expected result has the aggregation pushed below the outer join, onto the side
 *       producing {@code COL2}.</li>
 *   <li>That join is then INNER-joined, with no criteria, against an aggregation computing
 *       {@code AVG_NULL := avg(null_literal)}.</li>
 *   <li>Finally the result is projected through {@code coalesce(AVG, AVG_NULL)}. This is
 *       presumably so that unmatched outer rows still get the aggregate of an empty/NULL
 *       input.</li>
 * </ul>
 *
 * <p>The negative cases check that the rule leaves the plan alone.
 */
public class TestPushAggregationThroughOuterJoin extends BaseRuleTest {
    @Test
    public void testPushesAggregationThroughLeftJoin() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(planBuilder -> {
                    Symbol col1 = planBuilder.symbol("COL1");
                    Symbol col2 = planBuilder.symbol("COL2");
                    return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                            .source(planBuilder.join(
                                    JoinNode.Type.LEFT,
                                    planBuilder.values(ImmutableList.of(col1), ImmutableList.of(PlanBuilder.expressions("10"))),
                                    planBuilder.values(col2),
                                    ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                                    ImmutableList.of(col1, col2),
                                    Optional.empty(),
                                    Optional.empty(),
                                    Optional.empty()))
                            .addAggregation(
                                    planBuilder.symbol("AVG", DoubleType.DOUBLE),
                                    PlanBuilder.expression("avg(COL2)"),
                                    ImmutableList.of(DoubleType.DOUBLE))
                            .addGroupingSet(col1));
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COL1", PlanMatchPattern.expression("COL1"),
                                "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.LEFT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL1", "COL2")),
                                        singleColumnValues("COL1"),
                                        pushedDownAvgOverCol2()),
                                avgOfNullLiteral())));
    }

    @Test
    public void testPushesAggregationThroughRightJoin() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(planBuilder -> {
                    Symbol col1 = planBuilder.symbol("COL1");
                    Symbol col2 = planBuilder.symbol("COL2");
                    // Mirror image of the LEFT case: the single-row outer side is now on the right.
                    return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                            .source(planBuilder.join(
                                    JoinNode.Type.RIGHT,
                                    planBuilder.values(col2),
                                    planBuilder.values(ImmutableList.of(col1), ImmutableList.of(PlanBuilder.expressions("10"))),
                                    ImmutableList.of(new JoinNode.EquiJoinClause(col2, col1)),
                                    ImmutableList.of(col2, col1),
                                    Optional.empty(),
                                    Optional.empty(),
                                    Optional.empty()))
                            .addAggregation(
                                    planBuilder.symbol("AVG", DoubleType.DOUBLE),
                                    PlanBuilder.expression("avg(COL2)"),
                                    ImmutableList.of(DoubleType.DOUBLE))
                            .addGroupingSet(col1));
                })
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of(
                                "COALESCE", PlanMatchPattern.expression("coalesce(AVG, AVG_NULL)"),
                                "COL1", PlanMatchPattern.expression("COL1")),
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(),
                                PlanMatchPattern.join(
                                        JoinNode.Type.RIGHT,
                                        ImmutableList.of(PlanMatchPattern.equiJoinClause("COL2", "COL1")),
                                        pushedDownAvgOverCol2(),
                                        singleColumnValues("COL1")),
                                avgOfNullLiteral())));
    }

    @Test
    public void testDoesNotFireWhenNotDistinct() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(planBuilder -> {
                    Symbol col1 = planBuilder.symbol("COL1");
                    Symbol col2 = planBuilder.symbol("COL2");
                    return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                            .source(planBuilder.join(
                                    JoinNode.Type.LEFT,
                                    // Two rows on the outer side. Presumably the rule cannot prove this
                                    // input distinct, unlike the single-row VALUES in the positive tests.
                                    planBuilder.values(
                                            ImmutableList.of(col1),
                                            ImmutableList.of(PlanBuilder.expressions("10"), PlanBuilder.expressions("11"))),
                                    planBuilder.values(col2),
                                    ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                                    ImmutableList.of(col1, col2),
                                    Optional.empty(),
                                    Optional.empty(),
                                    Optional.empty()))
                            .addAggregation(
                                    planBuilder.symbol("AVG", DoubleType.DOUBLE),
                                    PlanBuilder.expression("avg(COL2)"),
                                    ImmutableList.of(DoubleType.DOUBLE))
                            .addGroupingSet(col1));
                })
                .doesNotFire();
    }

    @Test
    public void testDoesNotFireWhenGroupingOnInner() {
        tester().assertThat(new PushAggregationThroughOuterJoin())
                .on(planBuilder -> {
                    Symbol col1 = planBuilder.symbol("COL1");
                    Symbol col2 = planBuilder.symbol("COL2");
                    Symbol col3 = planBuilder.symbol("COL3");
                    return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                            .source(planBuilder.join(
                                    JoinNode.Type.LEFT,
                                    planBuilder.values(ImmutableList.of(col1), ImmutableList.of(PlanBuilder.expressions("10"))),
                                    planBuilder.values(col2, col3),
                                    ImmutableList.of(new JoinNode.EquiJoinClause(col1, col2)),
                                    // COL3 (an inner-side column) must be output by the join, because the
                                    // aggregation below groups on it. Otherwise the plan is malformed.
                                    ImmutableList.of(col1, col2, col3),
                                    Optional.empty(),
                                    Optional.empty(),
                                    Optional.empty()))
                            .addAggregation(
                                    planBuilder.symbol("AVG", DoubleType.DOUBLE),
                                    PlanBuilder.expression("avg(COL2)"),
                                    ImmutableList.of(DoubleType.DOUBLE))
                            .addGroupingSet(col1, col3));
                })
                .doesNotFire();
    }

    /**
     * Expected aggregation pushed to the non-outer side of the join.
     * It computes {@code AVG := avg(COL2)}, grouped by {@code COL2}, in a single step,
     * over a VALUES producing {@code COL2}.
     */
    private static PlanMatchPattern pushedDownAvgOverCol2() {
        return PlanMatchPattern.aggregation(
                ImmutableList.of(ImmutableList.of("COL2")),
                ImmutableMap.of(Optional.of("AVG"), PlanMatchPattern.functionCall("avg", ImmutableList.of("COL2"))),
                ImmutableMap.of(),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                singleColumnValues("COL2"));
    }

    /**
     * Expected global aggregation (a single empty grouping set).
     * It computes {@code AVG_NULL := avg(null_literal)} over a VALUES producing {@code null_literal}.
     */
    private static PlanMatchPattern avgOfNullLiteral() {
        return PlanMatchPattern.aggregation(
                ImmutableList.of(ImmutableList.of()),
                ImmutableMap.of(Optional.of("AVG_NULL"), PlanMatchPattern.functionCall("avg", ImmutableList.of("null_literal"))),
                ImmutableMap.of(),
                Optional.empty(),
                AggregationNode.Step.SINGLE,
                singleColumnValues("null_literal"));
    }

    /**
     * Matches a VALUES node.
     * The given alias is bound to its output column at index 0.
     *
     * @param alias the match alias for the VALUES node's first output column
     * @return the VALUES match pattern
     */
    private static PlanMatchPattern singleColumnValues(String alias) {
        Map<String, Integer> aliasToIndex = ImmutableMap.of(alias, 0);
        return PlanMatchPattern.values(aliasToIndex);
    }
}
