package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.ExpectedValueProvider;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.PlanTestSymbol;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.tree.FunctionCall;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestAddIntermediateAggregations.class */
/**
 * Rule tests for {@link AddIntermediateAggregations}.
 *
 * <p>Each input plan is a FINAL global aggregation over a REMOTE gathering exchange over a
 * PARTIAL global aggregation over a single-column {@code values} node. The expected plans show
 * INTERMEDIATE aggregation steps, separated by LOCAL exchanges, being inserted around the remote
 * exchange. The rule is expected not to fire when the session property is disabled or when the
 * aggregations have grouping keys.
 */
public class TestAddIntermediateAggregations extends BaseRuleTest {
    // Session property names set by every test; declared once so the tests cannot drift apart.
    private static final String ENABLE_INTERMEDIATE_AGGREGATIONS = "enable_intermediate_aggregations";
    private static final String TASK_CONCURRENCY = "task_concurrency";

    public TestAddIntermediateAggregations() {
        super(new Plugin[0]);
    }

    @Test
    public void testBasic() {
        ExpectedValueProvider<FunctionCall> countOfSymbol = count(ImmutableList.of(PlanMatchPattern.anySymbol()));
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                // Nested lambdas need distinct parameter names: Java forbids a lambda parameter
                // from redeclaring a name already in scope from the enclosing lambda.
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.aggregation(partialAgg -> partialAgg
                                        .globalGrouping()
                                        .step(AggregationNode.Step.PARTIAL)
                                        .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                        .source(p.values(p.symbol("a"))))))))
                .matches(
                        expectedGlobalAggregation(AggregationNode.Step.FINAL, countOfSymbol,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                                                expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                                                                expectedPartialAggregation(countOfSymbol)))))))));
    }

    @Test
    public void testNoInputCount() {
        // The PARTIAL step is count(*), which has no argument symbol. Every step above it is
        // expected to match count(<symbol>).
        ExpectedValueProvider<FunctionCall> rawInputCount = count(ImmutableList.of());
        ExpectedValueProvider<FunctionCall> partialInputCount = count(ImmutableList.of(PlanMatchPattern.anySymbol()));
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.aggregation(partialAgg -> partialAgg
                                        .globalGrouping()
                                        .step(AggregationNode.Step.PARTIAL)
                                        .addAggregation(p.symbol("b"), PlanBuilder.expression("count(*)"), ImmutableList.of(BigintType.BIGINT))
                                        .source(p.values(p.symbol("a"))))))))
                .matches(
                        expectedGlobalAggregation(AggregationNode.Step.FINAL, partialInputCount,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, partialInputCount,
                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                                                expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, partialInputCount,
                                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                                                                expectedPartialAggregation(rawInputCount)))))))));
    }

    @Test
    public void testMultipleExchanges() {
        ExpectedValueProvider<FunctionCall> countOfSymbol = count(ImmutableList.of(PlanMatchPattern.anySymbol()));
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.gatheringExchange(
                                        ExchangeNode.Scope.REMOTE,
                                        p.aggregation(partialAgg -> partialAgg
                                                .globalGrouping()
                                                .step(AggregationNode.Step.PARTIAL)
                                                .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                                .source(p.values(p.symbol("a")))))))))
                .matches(
                        expectedGlobalAggregation(AggregationNode.Step.FINAL, countOfSymbol,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                                                PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                                                                        expectedPartialAggregation(countOfSymbol))))))))));
    }

    @Test
    public void testSessionDisable() {
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "false")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.aggregation(partialAgg -> partialAgg
                                        .globalGrouping()
                                        .step(AggregationNode.Step.PARTIAL)
                                        .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                        .source(p.values(p.symbol("a"))))))))
                .doesNotFire();
    }

    @Test
    public void testNoLocalParallel() {
        // With task_concurrency = 1 the expected plan keeps the REMOTE exchange directly under
        // FINAL. Only the INTERMEDIATE step over a LOCAL gather below it is expected.
        ExpectedValueProvider<FunctionCall> countOfSymbol = count(ImmutableList.of(PlanMatchPattern.anySymbol()));
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "1")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.aggregation(partialAgg -> partialAgg
                                        .globalGrouping()
                                        .step(AggregationNode.Step.PARTIAL)
                                        .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                        .source(p.values(p.symbol("a"))))))))
                .matches(
                        expectedGlobalAggregation(AggregationNode.Step.FINAL, countOfSymbol,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                                        expectedPartialAggregation(countOfSymbol))))));
    }

    @Test
    public void testWithGroups() {
        // Both aggregations have a grouping key, so the rule is expected to leave the plan alone.
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .singleGroupingSet(p.symbol("c"))
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                p.aggregation(partialAgg -> partialAgg
                                        .singleGroupingSet(p.symbol("b"))
                                        .step(AggregationNode.Step.PARTIAL)
                                        .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                        .source(p.values(p.symbol("a"))))))))
                .doesNotFire();
    }

    @Test
    public void testInterimProject() {
        ExpectedValueProvider<FunctionCall> countOfSymbol = count(ImmutableList.of(PlanMatchPattern.anySymbol()));
        tester().assertThat(new AddIntermediateAggregations())
                .setSystemProperty(ENABLE_INTERMEDIATE_AGGREGATIONS, "true")
                .setSystemProperty(TASK_CONCURRENCY, "4")
                .on(p -> p.aggregation(finalAgg -> finalAgg
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(p.symbol("c"), PlanBuilder.expression("count(b)"), ImmutableList.of(BigintType.BIGINT))
                        .source(p.gatheringExchange(
                                ExchangeNode.Scope.REMOTE,
                                // Identity projection of "b" between the exchange and the PARTIAL step.
                                p.project(
                                        Assignments.identity(new Symbol[] {p.symbol("b")}),
                                        p.aggregation(partialAgg -> partialAgg
                                                .globalGrouping()
                                                .step(AggregationNode.Step.PARTIAL)
                                                .addAggregation(p.symbol("b"), PlanBuilder.expression("count(a)"), ImmutableList.of(BigintType.BIGINT))
                                                .source(p.values(p.symbol("a")))))))))
                .matches(
                        expectedGlobalAggregation(AggregationNode.Step.FINAL, countOfSymbol,
                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.REPARTITION,
                                                        PlanMatchPattern.exchange(ExchangeNode.Scope.REMOTE, ExchangeNode.Type.GATHER,
                                                                PlanMatchPattern.project(
                                                                        expectedGlobalAggregation(AggregationNode.Step.INTERMEDIATE, countOfSymbol,
                                                                                PlanMatchPattern.exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER,
                                                                                        expectedPartialAggregation(countOfSymbol))))))))));
    }

    /**
     * Builds the expected non-distinct {@code count} call pattern with the given argument symbols.
     *
     * @param arguments expected argument symbols (empty for {@code count(*)})
     * @return the function-call matcher
     */
    private static ExpectedValueProvider<FunctionCall> count(List<PlanTestSymbol> arguments) {
        return PlanMatchPattern.functionCall("count", false, arguments);
    }

    /**
     * Matches a global-grouping aggregation at {@code step} over {@code source}.
     *
     * <p>The single aggregation is keyed by {@code Optional.empty()} and must match {@code call}.
     * The remaining matcher arguments are the empty map and empty Optional used by every test
     * in this class.
     *
     * @param step   the aggregation step expected on the node
     * @param call   matcher for the one aggregate function call
     * @param source pattern for the node's source
     * @return the aggregation pattern
     */
    private static PlanMatchPattern expectedGlobalAggregation(
            AggregationNode.Step step,
            ExpectedValueProvider<FunctionCall> call,
            PlanMatchPattern source) {
        return PlanMatchPattern.aggregation(
                PlanMatchPattern.globalAggregation(),
                ImmutableMap.of(Optional.empty(), call),
                ImmutableMap.of(),
                Optional.empty(),
                step,
                source);
    }

    /**
     * Matches the PARTIAL global aggregation that sits at the bottom of every expected plan.
     * It reads from a {@code values} node matched with alias "a" mapped to index 0.
     *
     * @param call matcher for the PARTIAL step's aggregate call
     * @return the PARTIAL aggregation pattern
     */
    private static PlanMatchPattern expectedPartialAggregation(ExpectedValueProvider<FunctionCall> call) {
        Map<String, Integer> valuesAliases = ImmutableMap.of("a", 0);
        return expectedGlobalAggregation(AggregationNode.Step.PARTIAL, call, PlanMatchPattern.values(valuesAliases));
    }
}
