package com.facebook.presto.execution.scheduler;

import com.facebook.presto.connector.ConnectorId;
import com.facebook.presto.metadata.TableHandle;
import com.facebook.presto.spi.predicate.TupleDomain;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.planner.Partitioning;
import com.facebook.presto.sql.planner.PartitioningScheme;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.testing.TestingMetadata;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.class */
/**
 * Tests for {@link PhasedExecutionSchedule#extractPhases}. Each test builds a small graph of
 * {@link PlanFragment}s wired together through {@link RemoteSourceNode}s and asserts the exact
 * ordered list of phases (sets of fragment ids) that {@code extractPhases} returns.
 */
public class TestPhasedExecutionSchedule {
    /** An exchange over three table scans: every fragment is expected in its own phase, exchange first. */
    @Test
    public void testExchange() {
        PlanFragment a = createTableScanPlanFragment("a");
        PlanFragment b = createTableScanPlanFragment("b");
        PlanFragment c = createTableScanPlanFragment("c");
        PlanFragment exchange = createExchangePlanFragment("exchange", a, b, c);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(a, b, c, exchange)),
                ImmutableList.of(
                        ImmutableSet.of(exchange.getId()),
                        ImmutableSet.of(a.getId()),
                        ImmutableSet.of(b.getId()),
                        ImmutableSet.of(c.getId())));
    }

    /** A union over three table scans: every fragment is expected in its own phase, union first. */
    @Test
    public void testUnion() {
        PlanFragment a = createTableScanPlanFragment("a");
        PlanFragment b = createTableScanPlanFragment("b");
        PlanFragment c = createTableScanPlanFragment("c");
        PlanFragment union = createUnionPlanFragment("union", a, b, c);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(a, b, c, union)),
                ImmutableList.of(
                        ImmutableSet.of(union.getId()),
                        ImmutableSet.of(a.getId()),
                        ImmutableSet.of(b.getId()),
                        ImmutableSet.of(c.getId())));
    }

    /** A partitioned INNER join: expected phases are join, then build, then probe. */
    @Test
    public void testJoin() {
        PlanFragment build = createTableScanPlanFragment("build");
        PlanFragment probe = createTableScanPlanFragment("probe");
        PlanFragment join = createJoinPlanFragment(JoinNode.Type.INNER, "join", build, probe);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(join, build, probe)),
                ImmutableList.of(
                        ImmutableSet.of(join.getId()),
                        ImmutableSet.of(build.getId()),
                        ImmutableSet.of(probe.getId())));
    }

    /** A partitioned RIGHT join: same expected phase order as the INNER case. */
    @Test
    public void testRightJoin() {
        PlanFragment build = createTableScanPlanFragment("build");
        PlanFragment probe = createTableScanPlanFragment("probe");
        PlanFragment join = createJoinPlanFragment(JoinNode.Type.RIGHT, "join", build, probe);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(join, build, probe)),
                ImmutableList.of(
                        ImmutableSet.of(join.getId()),
                        ImmutableSet.of(build.getId()),
                        ImmutableSet.of(probe.getId())));
    }

    /** A replicated (broadcast) join: join and build fragments are expected in one single phase. */
    @Test
    public void testBroadcastJoin() {
        PlanFragment build = createTableScanPlanFragment("build");
        PlanFragment join = createBroadcastJoinPlanFragment("join", build);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(join, build)),
                ImmutableList.of(ImmutableSet.of(join.getId(), build.getId())));
    }

    /**
     * A partitioned join whose build and probe sides are each a chain of two exchanges over a
     * table scan: every fragment is expected in its own phase, walking the build chain top-down
     * before the probe chain.
     */
    @Test
    public void testJoinWithDeepSources() {
        PlanFragment buildSource = createTableScanPlanFragment("buildSource");
        PlanFragment buildMiddle = createExchangePlanFragment("buildMiddle", buildSource);
        PlanFragment buildTop = createExchangePlanFragment("buildTop", buildMiddle);
        PlanFragment probeSource = createTableScanPlanFragment("probeSource");
        PlanFragment probeMiddle = createExchangePlanFragment("probeMiddle", probeSource);
        PlanFragment probeTop = createExchangePlanFragment("probeTop", probeMiddle);
        PlanFragment join = createJoinPlanFragment(JoinNode.Type.INNER, "join", buildTop, probeTop);

        Assert.assertEquals(
                PhasedExecutionSchedule.extractPhases(ImmutableList.of(
                        join, buildTop, buildMiddle, buildSource, probeTop, probeMiddle, probeSource)),
                ImmutableList.of(
                        ImmutableSet.of(join.getId()),
                        ImmutableSet.of(buildTop.getId()),
                        ImmutableSet.of(buildMiddle.getId()),
                        ImmutableSet.of(buildSource.getId()),
                        ImmutableSet.of(probeTop.getId()),
                        ImmutableSet.of(probeMiddle.getId()),
                        ImmutableSet.of(probeSource.getId())));
    }

    /**
     * Creates a fragment whose root is one {@link RemoteSourceNode} reading from all {@code sources}.
     * The node's output layout is copied from the first source's partitioning scheme.
     *
     * @throws IllegalArgumentException if no source fragments are given
     */
    private static PlanFragment createExchangePlanFragment(String name, PlanFragment... sources) {
        if (sources.length == 0) {
            // previously surfaced as an ArrayIndexOutOfBoundsException on sources[0]
            throw new IllegalArgumentException("an exchange needs at least one source fragment");
        }
        // Stream.of(sources) keeps the element type; casting to Object[] lost it and broke getId()
        List<PlanFragmentId> sourceIds = Stream.of(sources)
                .map(PlanFragment::getId)
                .collect(ImmutableList.toImmutableList());
        return createFragment(new RemoteSourceNode(
                new PlanNodeId(name + "_id"),
                sourceIds,
                sources[0].getPartitioningScheme().getOutputLayout()));
    }

    /**
     * Creates a fragment whose root is a {@link UnionNode} with one {@link RemoteSourceNode} child per
     * source fragment, and no output symbols of its own.
     */
    private static PlanFragment createUnionPlanFragment(String name, PlanFragment... sources) {
        List<PlanNode> remoteSources = Stream.of(sources)
                .map(source -> new RemoteSourceNode(
                        new PlanNodeId(source.getId().toString()),
                        source.getId(),
                        source.getPartitioningScheme().getOutputLayout()))
                .collect(ImmutableList.toImmutableList());
        return createFragment(new UnionNode(
                new PlanNodeId(name + "_id"),
                remoteSources,
                ImmutableListMultimap.of(),
                ImmutableList.of()));
    }

    /**
     * Creates a REPLICATED INNER join fragment. A local table scan named {@code name} is the left
     * input, and a remote source reading {@code buildFragment} is the right input.
     */
    private static PlanFragment createBroadcastJoinPlanFragment(String name, PlanFragment buildFragment) {
        TableScanNode probe = createTableScanNode(name);
        RemoteSourceNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of());
        return createFragment(createJoinNode(name, JoinNode.Type.INNER, probe, build, JoinNode.DistributionType.REPLICATED));
    }

    /**
     * Creates a PARTITIONED join fragment of the given type. A remote source over
     * {@code probeFragment} is the left input, and a remote source over {@code buildFragment} is the
     * right input.
     */
    private static PlanFragment createJoinPlanFragment(JoinNode.Type joinType, String name, PlanFragment buildFragment, PlanFragment probeFragment) {
        RemoteSourceNode probe = new RemoteSourceNode(new PlanNodeId("probe_id"), probeFragment.getId(), ImmutableList.of());
        RemoteSourceNode build = new RemoteSourceNode(new PlanNodeId("build_id"), buildFragment.getId(), ImmutableList.of());
        return createFragment(createJoinNode(name, joinType, probe, build, JoinNode.DistributionType.PARTITIONED));
    }

    /** Creates a fragment containing only a single-column table scan named {@code name}. */
    private static PlanFragment createTableScanPlanFragment(String name) {
        return createFragment(createTableScanNode(name));
    }

    /** Builds a table scan node with id {@code name} over one VARCHAR-typed test column called "column". */
    private static TableScanNode createTableScanNode(String name) {
        Symbol symbol = new Symbol("column");
        return new TableScanNode(
                new PlanNodeId(name),
                new TableHandle(new ConnectorId("test"), new TestingMetadata.TestingTableHandle()),
                ImmutableList.of(symbol),
                ImmutableMap.of(symbol, new TestingMetadata.TestingColumnHandle("column")),
                Optional.empty(),
                TupleDomain.all(),
                (Expression) null);
    }

    /**
     * Builds a join node with id {@code name + "_id"}, no equi-join criteria, no filter or hash
     * symbols, and outputs equal to the left input's outputs followed by the right input's outputs.
     */
    private static JoinNode createJoinNode(String name, JoinNode.Type type, PlanNode left, PlanNode right, JoinNode.DistributionType distributionType) {
        // explicit <Symbol> witness: a bare builder() used as a call receiver infers Object
        List<Symbol> outputs = ImmutableList.<Symbol>builder()
                .addAll(left.getOutputSymbols())
                .addAll(right.getOutputSymbols())
                .build();
        return new JoinNode(
                new PlanNodeId(name + "_id"),
                type,
                left,
                right,
                ImmutableList.of(),
                outputs,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.of(distributionType));
    }

    /**
     * Wraps {@code root} in a SOURCE_DISTRIBUTION fragment with id {@code root.getId() + "_fragment_id"}.
     * The fragment has the following properties:
     * <ul>
     *   <li>every output symbol is typed VARCHAR;</li>
     *   <li>the root itself is the only partitioned source;</li>
     *   <li>output is SINGLE_DISTRIBUTION with the root's output symbols as layout.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the root reports the same output symbol twice
     */
    private static PlanFragment createFragment(PlanNode root) {
        return new PlanFragment(
                new PlanFragmentId(root.getId() + "_fragment_id"),
                root,
                root.getOutputSymbols().stream()
                        .collect(ImmutableMap.toImmutableMap(symbol -> symbol, symbol -> VarcharType.VARCHAR)),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                ImmutableList.of(root.getId()),
                new PartitioningScheme(
                        Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                        root.getOutputSymbols()));
    }
}
