package com.facebook.presto.spark.planner;

import com.facebook.airlift.json.Codec;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.MutablePartitionId;
import com.facebook.presto.spark.classloader_interface.PrestoSparkMutableRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkShuffleStats;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskOutput;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskProcessor;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskRdd;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskSourceRdd;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskSource;
import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
import com.facebook.presto.spark.util.PrestoSparkUtils;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.split.CloseableSplitSourceProvider;
import com.facebook.presto.split.SplitManager;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SplitSourceFactory;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import com.google.common.collect.UnmodifiableIterator;
import io.airlift.units.DataSize;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.inject.Inject;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.apache.spark.util.CollectionAccumulator;
import scala.Tuple2;

/* loaded from: input_file:com/facebook/presto/spark/planner/PrestoSparkRddFactory.class */
/**
 * Creates the Spark RDD that executes a single Presto {@link PlanFragment}.
 *
 * <p>The fragment, the session and the table-write info are serialized into a
 * {@link SerializedPrestoSparkTaskDescriptor} and handed to a {@link PrestoSparkTaskProcessor}.
 * The processor is combined with the fragment's inputs into a {@link PrestoSparkTaskRdd}. Those
 * inputs are:
 * <ul>
 *   <li>shuffle RDDs of upstream fragments;</li>
 *   <li>broadcast variables;</li>
 *   <li>when the fragment contains table scans, a {@link PrestoSparkTaskSourceRdd} holding the
 *   serialized splits assigned to each partition.</li>
 * </ul>
 */
public class PrestoSparkRddFactory {
    private static final Logger log = Logger.get(PrestoSparkRddFactory.class);

    private final SplitManager splitManager;
    private final PartitioningProviderManager partitioningProviderManager;
    private final JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec;
    private final Codec<TaskSource> taskSourceCodec;

    @Inject
    public PrestoSparkRddFactory(
            SplitManager splitManager,
            PartitioningProviderManager partitioningProviderManager,
            JsonCodec<PrestoSparkTaskDescriptor> taskDescriptorJsonCodec,
            Codec<TaskSource> taskSourceCodec) {
        this.splitManager = Objects.requireNonNull(splitManager, "splitManager is null");
        this.partitioningProviderManager = Objects.requireNonNull(partitioningProviderManager, "partitioningProviderManager is null");
        this.taskDescriptorJsonCodec = Objects.requireNonNull(taskDescriptorJsonCodec, "taskDescriptorJsonCodec is null");
        this.taskSourceCodec = Objects.requireNonNull(taskSourceCodec, "taskSourceCodec is null");
    }

    /**
     * Validates that {@code planFragment} can run on Presto on Spark and builds its RDD.
     *
     * @param javaSparkContext the Spark context the RDD is created in
     * @param session the query session; serialized into the task descriptor
     * @param planFragment the fragment to execute; must not use grouped execution
     * @param rddInputs shuffle inputs, keyed by the id of the fragment that produced them
     * @param broadcastInputs broadcast inputs, keyed by the id of the fragment that produced them
     * @param executorFactoryProvider passed through to the task processor
     * @param taskInfoCollector accumulator passed through to the task processor
     * @param shuffleStatsCollector accumulator passed through to the task processor
     * @param tableWriteInfo serialized into the task descriptor and used to create split sources
     * @param outputType class of the task output values in the returned RDD
     * @return a pair RDD keyed by {@link MutablePartitionId}
     * @throws PrestoException with {@code NOT_SUPPORTED} for scaled-writer fragments and for
     *         order-sensitive remote sources
     * @throws IllegalArgumentException for grouped execution, unexpected fragment partitionings,
     *         or inputs that do not match the fragment's remote sources
     */
    public <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createSparkRdd(
            JavaSparkContext javaSparkContext,
            Session session,
            PlanFragment planFragment,
            Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs,
            Map<PlanFragmentId, Broadcast<?>> broadcastInputs,
            PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
            CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
            CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
            TableWriteInfo tableWriteInfo,
            Class<T> outputType) {
        Preconditions.checkArgument(!planFragment.getStageExecutionDescriptor().isStageGroupedExecution(), "unexpected grouped execution fragment: %s", planFragment.getId());

        PartitioningHandle partitioning = planFragment.getPartitioning();
        if (partitioning.equals(SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION)) {
            throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Automatic writers scaling is not supported by Presto on Spark");
        }
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION), "COORDINATOR_DISTRIBUTION fragment must be run on the driver");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION), "FIXED_BROADCAST_DISTRIBUTION can only be set as an output partitioning scheme, and not as a fragment distribution");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION), "FIXED_PASSTHROUGH_DISTRIBUTION can only be set as local exchange partitioning");
        Preconditions.checkArgument(!partitioning.equals(SystemPartitioningHandle.ARBITRARY_DISTRIBUTION), "ARBITRARY_DISTRIBUTION is not expected to be set as a fragment distribution");

        // Supported: the listed system partitionings, or any connector-provided partitioning.
        boolean supportedPartitioning = partitioning.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION)
                || partitioning.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)
                || partitioning.getConnectorId().isPresent();
        if (!supportedPartitioning) {
            throw new IllegalArgumentException(String.format("Unexpected fragment partitioning %s, fragmentId: %s", partitioning, planFragment.getId()));
        }

        for (RemoteSourceNode remoteSource : planFragment.getRemoteSourceNodes()) {
            if (remoteSource.isEnsureSourceOrdering() || remoteSource.getOrderingScheme().isPresent()) {
                throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, String.format("Order sensitive exchange is not supported by Presto on Spark. fragmentId: %s, sourceFragmentIds: %s", planFragment.getId(), remoteSource.getSourceFragmentIds()));
            }
        }

        return createRdd(javaSparkContext, session, planFragment, executorFactoryProvider, taskInfoCollector, shuffleStatsCollector, tableWriteInfo, rddInputs, broadcastInputs, outputType);
    }

    /**
     * Builds the {@link PrestoSparkTaskRdd} for an already validated fragment.
     *
     * <p>All shuffle inputs must have the same number of partitions. If the fragment contains
     * table scans, their splits are enumerated here, on the caller's side, and packaged into a
     * {@link PrestoSparkTaskSourceRdd}.
     */
    private <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> createRdd(
            JavaSparkContext javaSparkContext,
            Session session,
            PlanFragment planFragment,
            PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
            CollectionAccumulator<SerializedTaskInfo> taskInfoCollector,
            CollectionAccumulator<PrestoSparkShuffleStats> shuffleStatsCollector,
            TableWriteInfo tableWriteInfo,
            Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs,
            Map<PlanFragmentId, Broadcast<?>> broadcastInputs,
            Class<T> outputType) {
        checkInputs(planFragment.getRemoteSourceNodes(), rddInputs, broadcastInputs);

        PrestoSparkTaskDescriptor taskDescriptor = new PrestoSparkTaskDescriptor(
                session.toSessionRepresentation(),
                session.getIdentity().getExtraCredentials(),
                planFragment,
                tableWriteInfo);
        SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor =
                new SerializedPrestoSparkTaskDescriptor(taskDescriptorJsonCodec.toJsonBytes(taskDescriptor));

        // Shuffle inputs keyed by fragment id string; all of them must agree on the partition count.
        Map<String, RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>>> shuffleInputRdds = new HashMap<>();
        Optional<Integer> numberOfShufflePartitions = Optional.empty();
        for (Map.Entry<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> input : rddInputs.entrySet()) {
            RDD<Tuple2<MutablePartitionId, PrestoSparkMutableRow>> rdd = input.getValue().rdd();
            shuffleInputRdds.put(input.getKey().toString(), rdd);
            if (numberOfShufflePartitions.isPresent()) {
                Preconditions.checkArgument(numberOfShufflePartitions.get() == rdd.getNumPartitions(), "Incompatible number of input partitions: %s != %s", numberOfShufflePartitions.get(), rdd.getNumPartitions());
            }
            else {
                numberOfShufflePartitions = Optional.of(rdd.getNumPartitions());
            }
        }

        PrestoSparkTaskProcessor<T> taskProcessor = new PrestoSparkTaskProcessor<>(
                executorFactoryProvider,
                serializedTaskDescriptor,
                taskInfoCollector,
                shuffleStatsCollector,
                toTaskProcessorBroadcastInputs(broadcastInputs),
                outputType);

        List<TableScanNode> tableScans = findTableScanNodes(planFragment.getRoot());
        Optional<PrestoSparkTaskSourceRdd> taskSourceRdd;
        if (!tableScans.isEmpty()) {
            // The provider is closed once all splits have been enumerated into task sources.
            try (CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits)) {
                Map<PlanNodeId, SplitSource> splitSources = new SplitSourceFactory(splitSourceProvider, WarningCollector.NOOP)
                        .createSplitSources(planFragment, session, tableWriteInfo);
                taskSourceRdd = Optional.of(createTaskSourcesRdd(
                        planFragment.getId(),
                        javaSparkContext,
                        session,
                        planFragment.getPartitioning(),
                        tableScans,
                        splitSources,
                        numberOfShufflePartitions));
            }
        }
        else if (rddInputs.isEmpty()) {
            // With neither splits nor shuffle inputs there is nothing to derive the partition count
            // from: the fragment must be SINGLE_DISTRIBUTION and gets one partition with no task sources.
            Preconditions.checkArgument(planFragment.getPartitioning().equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION), "SINGLE_DISTRIBUTION partitioning is expected: %s", planFragment.getPartitioning());
            taskSourceRdd = Optional.of(new PrestoSparkTaskSourceRdd(javaSparkContext.sc(), ImmutableList.of(ImmutableList.of())));
        }
        else {
            // Only shuffle inputs: no task source RDD is needed.
            taskSourceRdd = Optional.empty();
        }

        return JavaPairRDD.fromRDD(
                PrestoSparkTaskRdd.create(javaSparkContext.sc(), taskSourceRdd, shuffleInputRdds, taskProcessor),
                PrestoSparkUtils.classTag(MutablePartitionId.class),
                PrestoSparkUtils.classTag(outputType));
    }

    /**
     * Enumerates the splits of every table scan and groups the serialized task sources by partition.
     *
     * @param numberOfShufflePartitions partition count dictated by shuffle inputs, if any. When
     *        present, the resulting RDD has exactly that many partitions (possibly empty ones), and
     *        every split must have been assigned to a partition id in {@code [0, count)}.
     * @throws IllegalArgumentException if splits were assigned to partition ids outside
     *         {@code [0, numberOfShufflePartitions)}
     * @throws NullPointerException if a table scan has no split source
     */
    private PrestoSparkTaskSourceRdd createTaskSourcesRdd(
            PlanFragmentId fragmentId,
            JavaSparkContext javaSparkContext,
            Session session,
            PartitioningHandle fragmentPartitioning,
            List<TableScanNode> tableScans,
            Map<PlanNodeId, SplitSource> splitSources,
            Optional<Integer> numberOfShufflePartitions) {
        ListMultimap<Integer, SerializedPrestoSparkTaskSource> taskSourcesByPartition = ArrayListMultimap.create();
        for (TableScanNode tableScan : tableScans) {
            PlanNodeId tableScanId = tableScan.getId();
            SplitSource splitSource = Objects.requireNonNull(splitSources.get(tableScanId), "split source is missing for table scan node with id: " + tableScanId);
            int totalNumberOfSplits = 0;
            try (PrestoSparkSplitAssigner splitAssigner = createSplitAssigner(session, tableScanId, splitSource, fragmentPartitioning)) {
                while (true) {
                    Optional<SetMultimap<Integer, ScheduledSplit>> batch = splitAssigner.getNextBatch();
                    if (!batch.isPresent()) {
                        break;
                    }
                    // Size is read before createTaskSources, which drains the batch.
                    int numberOfSplitsInBatch = batch.get().size();
                    log.info("Found %s splits for table scan node with id %s", numberOfSplitsInBatch, tableScanId);
                    totalNumberOfSplits += numberOfSplitsInBatch;
                    taskSourcesByPartition.putAll(createTaskSources(tableScanId, batch.get()));
                }
            }
            log.info("Total number of splits for table scan node with id %s: %s", tableScanId, totalNumberOfSplits);
        }

        long totalSerializedSize = taskSourcesByPartition.values().stream()
                .mapToLong(taskSource -> taskSource.getBytes().length)
                .sum();
        log.info("Total serialized size of all task sources for fragment %s: %s", fragmentId, DataSize.succinctBytes(totalSerializedSize));

        List<List<SerializedPrestoSparkTaskSource>> taskSourcesByPartitionId = new ArrayList<>();
        if (numberOfShufflePartitions.isPresent()) {
            int partitionCount = numberOfShufflePartitions.get();
            for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
                // removeAll yields an empty list for partitions that received no splits.
                taskSourcesByPartitionId.add(taskSourcesByPartition.removeAll(partitionId));
            }
            // Whatever is left was assigned to a partition that does not exist in the shuffle
            // inputs; dropping it would silently skip those splits.
            Preconditions.checkArgument(taskSourcesByPartition.isEmpty(), "Splits of fragment %s were assigned to partitions outside of [0, %s): %s", fragmentId, partitionCount, taskSourcesByPartition.keySet());
        }
        else {
            taskSourcesByPartitionId.addAll(Multimaps.asMap(taskSourcesByPartition).values());
        }
        return new PrestoSparkTaskSourceRdd(javaSparkContext.sc(), taskSourcesByPartitionId);
    }

    /**
     * Picks the split assigner for the fragment partitioning.
     *
     * <p>SOURCE_DISTRIBUTION uses {@link PrestoSparkSourceDistributionSplitAssigner}. Any other
     * partitioning uses {@link PrestoSparkPartitionedSplitAssigner}.
     */
    private PrestoSparkSplitAssigner createSplitAssigner(Session session, PlanNodeId tableScanId, SplitSource splitSource, PartitioningHandle fragmentPartitioning) {
        if (fragmentPartitioning.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)) {
            return PrestoSparkSourceDistributionSplitAssigner.create(session, tableScanId, splitSource);
        }
        return PrestoSparkPartitionedSplitAssigner.create(session, tableScanId, splitSource, fragmentPartitioning, partitioningProviderManager);
    }

    /**
     * Serializes one zstd-compressed {@link TaskSource} per partition of {@code assignedSplits}.
     *
     * <p>Entries are removed from {@code assignedSplits} as they are consumed, so the caller's
     * multimap is left empty. This presumably lets the splits be garbage-collected early; confirm
     * if relied upon.
     */
    private ListMultimap<Integer, SerializedPrestoSparkTaskSource> createTaskSources(PlanNodeId tableScanId, SetMultimap<Integer, ScheduledSplit> assignedSplits) {
        ListMultimap<Integer, SerializedPrestoSparkTaskSource> result = ArrayListMultimap.create();
        // Iterate over a snapshot of the keys because the loop removes entries from the multimap.
        for (int partitionId : ImmutableSet.copyOf(assignedSplits.keySet())) {
            Set<ScheduledSplit> splits = assignedSplits.removeAll(partitionId);
            TaskSource taskSource = new TaskSource(tableScanId, splits, true);
            result.put(partitionId, new SerializedPrestoSparkTaskSource(PrestoSparkUtils.serializeZstdCompressed(taskSourceCodec, taskSource)));
        }
        return result;
    }

    /** Returns every {@link TableScanNode} reachable from {@code root}. */
    private static List<TableScanNode> findTableScanNodes(PlanNode root) {
        return PlanNodeSearcher.searchFrom(root)
                .where(TableScanNode.class::isInstance)
                .findAll();
    }

    /** Re-keys the broadcast inputs by the string form of their fragment id. */
    private static Map<String, Broadcast<?>> toTaskProcessorBroadcastInputs(Map<PlanFragmentId, Broadcast<?>> broadcastInputs) {
        return broadcastInputs.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().toString(), Map.Entry::getValue));
    }

    /**
     * Verifies that the shuffle and broadcast inputs together cover exactly the fragments
     * referenced by {@code remoteSources}.
     *
     * @throws IllegalArgumentException listing the missing and the extra inputs on mismatch
     */
    private static void checkInputs(
            List<RemoteSourceNode> remoteSources,
            Map<PlanFragmentId, JavaPairRDD<MutablePartitionId, PrestoSparkMutableRow>> rddInputs,
            Map<PlanFragmentId, Broadcast<?>> broadcastInputs) {
        Set<PlanFragmentId> expectedInputs = remoteSources.stream()
                .map(RemoteSourceNode::getSourceFragmentIds)
                .flatMap(sourceFragmentIds -> sourceFragmentIds.stream())
                .collect(ImmutableSet.toImmutableSet());
        Set<PlanFragmentId> actualInputs = Sets.union(rddInputs.keySet(), broadcastInputs.keySet());
        Set<PlanFragmentId> missingInputs = Sets.difference(expectedInputs, actualInputs);
        Set<PlanFragmentId> extraInputs = Sets.difference(actualInputs, expectedInputs);
        Preconditions.checkArgument(
                missingInputs.isEmpty() && extraInputs.isEmpty(),
                "rddInputs mismatch discovered. expected inputs: %s, actual rdd inputs: %s, actual broadcast inputs: %s, missing inputs: %s, extra inputs: %s",
                expectedInputs,
                rddInputs.keySet(),
                broadcastInputs.keySet(),
                missingInputs,
                extraInputs);
    }
}
