package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.Plugin;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.optimizations.joins.JoinGraph;
import com.facebook.presto.sql.planner.plan.Assignments;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link EliminateCrossJoins}: the rule as applied through {@link BaseRuleTest#tester()},
 * plus its static helpers {@code getJoinOrder} and {@code isOriginalOrder}.
 *
 * <p>All tests share the single {@link PlanNodeIdAllocator} held in {@link #idAllocator}; the class
 * is declared single-threaded.
 */
@Test(singleThreaded = true)
public class TestEliminateCrossJoins extends BaseRuleTest {
    /** Session property that selects the join-reordering strategy. */
    private static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy";
    /** Strategy value under which the rule is expected to act. */
    private static final String ELIMINATE_CROSS_JOINS = "ELIMINATE_CROSS_JOINS";

    /** Source of plan node ids for the hand-built plans in this class. */
    private final PlanNodeIdAllocator idAllocator;

    public TestEliminateCrossJoins() {
        super(new Plugin[0]);
        this.idAllocator = new PlanNodeIdAllocator();
    }

    @Test
    public void testEliminateCrossJoin() {
        // Expected shape: (ax JOIN cx ON ax = cx) JOIN by ON cy = by
        tester().assertThat(new EliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.INNER))
                .matches(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))),
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))),
                                        PlanMatchPattern.any(),
                                        PlanMatchPattern.any()),
                                PlanMatchPattern.any()));
    }

    @Test
    public void testRetainOutgoingGroupReferences() {
        // The rewritten joins must still point at the original GroupReference leaves.
        tester().assertThat(new EliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.INNER))
                .matches(
                        PlanMatchPattern.node(JoinNode.class,
                                PlanMatchPattern.node(JoinNode.class,
                                        PlanMatchPattern.node(GroupReference.class),
                                        PlanMatchPattern.node(GroupReference.class)),
                                PlanMatchPattern.node(GroupReference.class)));
    }

    @Test
    public void testDoNotReorderOuterJoin() {
        tester().assertThat(new EliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.LEFT))
                .doesNotFire();
    }

    @Test
    public void testIsOriginalOrder() {
        Assert.assertTrue(EliminateCrossJoins.isOriginalOrder(ImmutableList.of(0, 1, 2, 3, 4)));
        Assert.assertFalse(EliminateCrossJoins.isOriginalOrder(ImmutableList.of(0, 2, 1, 3, 4)));
    }

    @Test
    public void testJoinOrder() {
        // (a CROSS JOIN b) JOIN c ON a = c AND c = b  ->  a, c, b
        PlanNode plan = joinNode(
                joinNode(values(symbol("a")), values(symbol("b"))),
                values(symbol("c")),
                symbol("a"), symbol("c"),
                symbol("c"), symbol("b"));
        JoinGraph graph = Iterables.getOnlyElement(JoinGraph.buildFrom(plan));

        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(graph), ImmutableList.of(0, 2, 1));
    }

    @Test
    public void testJoinOrderWithRealCrossJoin() {
        // Two independent sub-graphs joined by a genuine cross join: each side is reordered
        // on its own, and the cross join between them is kept.
        PlanNode leftSide = joinNode(
                joinNode(values(symbol("a")), values(symbol("b"))),
                values(symbol("c")),
                symbol("a"), symbol("c"),
                symbol("c"), symbol("b"));
        PlanNode rightSide = joinNode(
                joinNode(values(symbol("x")), values(symbol("y"))),
                values(symbol("z")),
                symbol("x"), symbol("z"),
                symbol("z"), symbol("y"));
        JoinGraph graph = Iterables.getOnlyElement(JoinGraph.buildFrom(joinNode(leftSide, rightSide)));

        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(graph), ImmutableList.of(0, 2, 1, 3, 5, 4));
    }

    @Test
    public void testJoinOrderWithMultipleEdgesBetweenNodes() {
        PlanNode plan = joinNode(
                joinNode(values(symbol("a")), values(symbol("b1"), symbol("b2"))),
                values(symbol("c1"), symbol("c2")),
                symbol("a"), symbol("c1"),
                symbol("c1"), symbol("b1"),
                symbol("c2"), symbol("b2"));
        JoinGraph graph = Iterables.getOnlyElement(JoinGraph.buildFrom(plan));

        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(graph), ImmutableList.of(0, 2, 1));
    }

    @Test
    public void testDonNotChangeOrderWithoutCrossJoin() {
        PlanNode plan = joinNode(
                joinNode(values(symbol("a")), values(symbol("b")), symbol("a"), symbol("b")),
                values(symbol("c")),
                symbol("c"), symbol("b"));
        JoinGraph graph = Iterables.getOnlyElement(JoinGraph.buildFrom(plan));

        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(graph), ImmutableList.of(0, 1, 2));
    }

    @Test
    public void testDoNotReorderCrossJoins() {
        // a is not connected to anything, so no order removes the cross join; keep the original.
        PlanNode plan = joinNode(
                joinNode(values(symbol("a")), values(symbol("b"))),
                values(symbol("c")),
                symbol("c"), symbol("b"));
        JoinGraph graph = Iterables.getOnlyElement(JoinGraph.buildFrom(plan));

        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(graph), ImmutableList.of(0, 1, 2));
    }

    @Test
    public void testGiveUpOnNonIdentityProjections() {
        // a2 := -a1 is not an identity projection, so the project node splits the plan
        // into two separate join graphs.
        PlanNode projected = projectNode(
                joinNode(values(symbol("a1")), values(symbol("b"))),
                symbol("a2"),
                new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("a1")));
        PlanNode plan = joinNode(
                projected,
                values(symbol("c")),
                symbol("a2"), symbol("c"),
                symbol("c"), symbol("b"));

        Assert.assertEquals(JoinGraph.buildFrom(plan).size(), 2);
    }

    /**
     * Builds {@code (ax <type> by) INNER JOIN (cx, cy) ON cx = ax AND cy = by}. The inner join,
     * of the given type, has no criteria, so with {@code INNER} it is a cross join.
     *
     * @param type join type of the inner (criteria-less) join
     * @return a plan factory for use with the rule tester
     */
    private Function<PlanBuilder, PlanNode> crossJoinAndJoin(JoinNode.Type type) {
        return planBuilder -> {
            Symbol ax = planBuilder.symbol("axSymbol");
            Symbol by = planBuilder.symbol("bySymbol");
            Symbol cx = planBuilder.symbol("cxSymbol");
            Symbol cy = planBuilder.symbol("cySymbol");

            PlanNode crossJoin = planBuilder.join(type, planBuilder.values(ax), planBuilder.values(by));
            return planBuilder.join(
                    JoinNode.Type.INNER,
                    crossJoin,
                    planBuilder.values(cx, cy),
                    new JoinNode.EquiJoinClause(cx, ax),
                    new JoinNode.EquiJoinClause(cy, by));
        };
    }

    /** Wraps {@code source} in a project node with the single assignment {@code symbol := expression}. */
    private PlanNode projectNode(PlanNode source, String symbol, Expression expression) {
        return new ProjectNode(idAllocator.getNextId(), source, Assignments.of(new Symbol(symbol), expression));
    }

    /** Identity helper that makes symbol-name arguments stand out in the plan builders above. */
    private String symbol(String name) {
        return name;
    }

    /**
     * Builds an INNER join of {@code left} and {@code right} whose outputs are the left outputs
     * followed by the right outputs.
     *
     * @param symbols equi-join criteria as flattened (left, right) symbol-name pairs; empty for a cross join
     * @throws IllegalArgumentException if an odd number of symbol names is given
     */
    private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) {
        Preconditions.checkArgument(symbols.length % 2 == 0, "join symbols must come in (left, right) pairs, got %s names", symbols.length);

        ImmutableList.Builder<JoinNode.EquiJoinClause> criteria = ImmutableList.builder();
        for (int i = 0; i < symbols.length; i += 2) {
            criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbols[i]), new Symbol(symbols[i + 1])));
        }

        // An explicit type witness is required: without it the chained builder is inferred as Builder<Object>.
        List<Symbol> outputs = ImmutableList.<Symbol>builder()
                .addAll(left.getOutputSymbols())
                .addAll(right.getOutputSymbols())
                .build();

        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.INNER,
                left,
                right,
                criteria.build(),
                outputs,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /** Builds a values node that has the given output symbols and no rows. */
    private ValuesNode values(String... symbols) {
        List<Symbol> outputs = Arrays.stream(symbols)
                .map(Symbol::new)
                .collect(ImmutableList.toImmutableList());
        return new ValuesNode(idAllocator.getNextId(), outputs, ImmutableList.of());
    }
}
