package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.metadata.TableMetadata;
import com.facebook.presto.spi.ConnectorTableMetadata;
import com.facebook.presto.sql.analyzer.Session;
import com.facebook.presto.sql.analyzer.Type;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LimitNode;
import com.facebook.presto.sql.planner.plan.MarkDistinctNode;
import com.facebook.presto.sql.planner.plan.MaterializeSampleNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanNodeRewriter;
import com.facebook.presto.sql.planner.plan.PlanRewriter;
import com.facebook.presto.sql.planner.plan.ProjectNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SortNode;
import com.facebook.presto.sql.planner.plan.TableWriterNode;
import com.facebook.presto.sql.planner.plan.TopNNode;
import com.facebook.presto.sql.planner.plan.UnionNode;
import com.facebook.presto.sql.tree.ArithmeticExpression;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.util.IterableTransformer;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/MaterializeSamplePullUp.class */
/**
 * Plan optimizer that pulls {@link MaterializeSampleNode}s up through the plan.
 *
 * <p>Whenever a node's rewritten source is a {@code MaterializeSampleNode}, the rewriter does one of three things:
 * <ul>
 *   <li>rebuilds the node on the sample node's source and re-wraps the result in a
 *       {@code MaterializeSampleNode} (deterministic Filter/Project, TopN, Sort, Limit, MarkDistinct,
 *       the source side of a SemiJoin, Join, Union), threading the sample weight symbol through;</li>
 *   <li>hands the sample weight symbol to the node itself (TableWriter, Aggregation with aggregates);</li>
 *   <li>drops the weight (DistinctLimit, group-by-only Aggregation, the filtering side of a SemiJoin).</li>
 * </ul>
 */
public class MaterializeSamplePullUp extends PlanOptimizer {

    /**
     * Rewriter implementing the pull-up rules described on {@link MaterializeSamplePullUp}.
     *
     * <p>NOTE(review): the TableWriter, TopN, Limit and MarkDistinct fallbacks still use
     * {@link PlanRewriter#defaultRewrite}, which rewrites the source subtree a second time. Filter,
     * Project and Sort avoid this by rebuilding on the already-rewritten source.
     */
    private static class Rewriter extends PlanNodeRewriter<Void> {
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = Preconditions.checkNotNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Preconditions.checkNotNull(symbolAllocator, "symbolAllocator is null");
        }

        /** Nodes without a dedicated rule fall back to {@link PlanRewriter#defaultRewrite}. */
        @Override
        public PlanNode rewriteNode(PlanNode node, Void context, PlanRewriter<Void> planRewriter) {
            return planRewriter.defaultRewrite(node, context);
        }

        /**
         * Writes sampled data directly. The {@code MaterializeSampleNode} is removed, and its weight
         * symbol is passed to the {@code TableWriterNode}, whose table metadata is rebuilt.
         *
         * @throws IllegalArgumentException if the source is sampled but the writer does not support sample weights
         */
        @Override
        public PlanNode rewriteTableWriter(TableWriterNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (!(source instanceof MaterializeSampleNode)) {
                return planRewriter.defaultRewrite(node, context);
            }
            Preconditions.checkArgument(node.isSampleWeightSupported(), "Cannot write sampled data to a store that doesn't support sampling");
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            ConnectorTableMetadata metadata = node.getTableMetadata().getMetadata();
            // NOTE(review): the trailing `true` presumably marks the table as sampled — confirm against ConnectorTableMetadata.
            TableMetadata sampledMetadata = new TableMetadata(
                    node.getTableMetadata().getConnectorId(),
                    new ConnectorTableMetadata(metadata.getTable(), metadata.getColumns(), metadata.getOwner(), true));
            return new TableWriterNode(
                    node.getId(),
                    sample.getSource(),
                    node.getTarget(),
                    node.getColumns(),
                    node.getColumnNames(),
                    node.getOutputSymbols(),
                    Optional.of(sample.getSampleWeightSymbol()),
                    node.getCatalog(),
                    sampledMetadata,
                    node.isSampleWeightSupported());
        }

        /**
         * Moves a filter with a deterministic predicate below the sample node. A non-deterministic
         * predicate stays above it.
         */
        @Override
        public PlanNode rewriteFilter(FilterNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (source instanceof MaterializeSampleNode && DeterminismEvaluator.isDeterministic(node.getPredicate())) {
                MaterializeSampleNode sample = (MaterializeSampleNode) source;
                FilterNode filter = new FilterNode(node.getId(), sample.getSource(), node.getPredicate());
                return new MaterializeSampleNode(sample.getId(), filter, sample.getSampleWeightSymbol());
            }
            // Rebuild on the already-rewritten source. Calling defaultRewrite here would rewrite the
            // whole source subtree again, which compounds exponentially along chains of such nodes.
            if (source == node.getSource()) {
                return node;
            }
            return new FilterNode(node.getId(), source, node.getPredicate());
        }

        /**
         * Moves a projection with only deterministic expressions below the sample node. The weight
         * symbol is added as an extra identity projection so it survives the projection.
         */
        @Override
        public PlanNode rewriteProject(ProjectNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (source instanceof MaterializeSampleNode && Iterables.all(node.getExpressions(), DeterminismEvaluator.deterministic())) {
                MaterializeSampleNode sample = (MaterializeSampleNode) source;
                Symbol weight = sample.getSampleWeightSymbol();
                ImmutableMap<Symbol, Expression> projections = ImmutableMap.<Symbol, Expression>builder()
                        .putAll(node.getOutputMap())
                        .put(weight, new QualifiedNameReference(weight.toQualifiedName()))
                        .build();
                return new MaterializeSampleNode(sample.getId(), new ProjectNode(node.getId(), sample.getSource(), projections), weight);
            }
            // Rebuild on the already-rewritten source instead of re-rewriting it via defaultRewrite.
            if (source == node.getSource()) {
                return node;
            }
            return new ProjectNode(node.getId(), source, node.getOutputMap());
        }

        /** Moves TopN below the sample node and gives it the weight symbol. */
        @Override
        public PlanNode rewriteTopN(TopNNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (!(source instanceof MaterializeSampleNode)) {
                return planRewriter.defaultRewrite(node, context);
            }
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            Symbol weight = sample.getSampleWeightSymbol();
            TopNNode topN = new TopNNode(node.getId(), sample.getSource(), node.getCount(), node.getOrderBy(), node.getOrderings(), node.isPartial(), Optional.of(weight));
            return new MaterializeSampleNode(sample.getId(), topN, weight);
        }

        /** Moves Sort below the sample node; sorting does not touch the weight. */
        @Override
        public PlanNode rewriteSort(SortNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (source instanceof MaterializeSampleNode) {
                MaterializeSampleNode sample = (MaterializeSampleNode) source;
                SortNode sort = new SortNode(node.getId(), sample.getSource(), node.getOrderBy(), node.getOrderings());
                return new MaterializeSampleNode(sample.getId(), sort, sample.getSampleWeightSymbol());
            }
            // Rebuild on the already-rewritten source instead of re-rewriting it via defaultRewrite.
            if (source == node.getSource()) {
                return node;
            }
            return new SortNode(node.getId(), source, node.getOrderBy(), node.getOrderings());
        }

        /** Moves Limit below the sample node and gives it the weight symbol. */
        @Override
        public PlanNode rewriteLimit(LimitNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (!(source instanceof MaterializeSampleNode)) {
                return planRewriter.defaultRewrite(node, context);
            }
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            Symbol weight = sample.getSampleWeightSymbol();
            return new MaterializeSampleNode(sample.getId(), new LimitNode(node.getId(), sample.getSource(), node.getCount(), Optional.of(weight)), weight);
        }

        /**
         * DistinctLimit does not need row multiplicity. When the source is sampled, the sample node is
         * replaced by a projection of its output symbols over its source, and nothing is re-wrapped.
         */
        @Override
        public PlanNode rewriteDistinctLimit(DistinctLimitNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (source instanceof MaterializeSampleNode) {
                MaterializeSampleNode sample = (MaterializeSampleNode) source;
                source = new ProjectNode(idAllocator.getNextId(), sample.getSource(), identityProjections(sample.getOutputSymbols()).build());
            }
            return new DistinctLimitNode(node.getId(), source, node.getLimit());
        }

        /** Moves MarkDistinct below the sample node and gives it the weight symbol. */
        @Override
        public PlanNode rewriteMarkDistinct(MarkDistinctNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (!(source instanceof MaterializeSampleNode)) {
                return planRewriter.defaultRewrite(node, context);
            }
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            Symbol weight = sample.getSampleWeightSymbol();
            MarkDistinctNode markDistinct = new MarkDistinctNode(node.getId(), sample.getSource(), node.getMarkerSymbol(), node.getDistinctSymbols(), Optional.of(weight));
            return new MaterializeSampleNode(sample.getId(), markDistinct, weight);
        }

        /**
         * Discards sampling on the filtering side, which only decides membership. Sampling on the
         * source side is pulled above the semi join.
         */
        @Override
        public PlanNode rewriteSemiJoin(SemiJoinNode node, Void context, PlanRewriter<Void> planRewriter) {
            // The filtering side is rewritten first to keep the original symbol/ID allocation order.
            PlanNode filteringSource = planRewriter.rewrite(node.getFilteringSource(), context);
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (filteringSource instanceof MaterializeSampleNode) {
                filteringSource = ((MaterializeSampleNode) filteringSource).getSource();
            }
            if (!(source instanceof MaterializeSampleNode)) {
                return new SemiJoinNode(node.getId(), source, filteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput());
            }
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            SemiJoinNode semiJoin = new SemiJoinNode(node.getId(), sample.getSource(), filteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput());
            return new MaterializeSampleNode(sample.getId(), semiJoin, sample.getSampleWeightSymbol());
        }

        /**
         * Handles an aggregation over a sampled source in one of two ways. With aggregates, the
         * weight symbol is handed to the {@code AggregationNode}, which absorbs the sample node. A
         * group-by-only (distinct) aggregation instead gets a projection of the sample node's outputs
         * and no weight.
         */
        @Override
        public PlanNode rewriteAggregation(AggregationNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode source = planRewriter.rewrite(node.getSource(), context);
            if (!(source instanceof MaterializeSampleNode)) {
                return new AggregationNode(node.getId(), source, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getSampleWeight(), node.getConfidence());
            }
            MaterializeSampleNode sample = (MaterializeSampleNode) source;
            boolean distinctOnly = node.getAggregations().isEmpty()
                    && node.getOutputSymbols().size() == node.getGroupBy().size()
                    && node.getOutputSymbols().containsAll(node.getGroupBy());
            if (!distinctOnly) {
                return new AggregationNode(node.getId(), sample.getSource(), node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), Optional.of(sample.getSampleWeightSymbol()), node.getConfidence());
            }
            PlanNode unweighted = new ProjectNode(idAllocator.getNextId(), sample.getSource(), identityProjections(sample.getOutputSymbols()).build());
            return new AggregationNode(node.getId(), unweighted, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), Optional.<Symbol>absent(), node.getConfidence());
        }

        /**
         * Pulls sampling above a join.
         *
         * <p>If both sides are sampled, the output weight is the product of the two weights. On the
         * null-producing side of an outer join, that side's weight is first coalesced to 1.
         *
         * <p>If only one side is sampled, its weight is reused. The exception is when that side is
         * the null-producing side of an outer join: its weight is then coalesced to 1 into a fresh
         * symbol.
         */
        @Override
        public PlanNode rewriteJoin(JoinNode node, Void context, PlanRewriter<Void> planRewriter) {
            PlanNode left = planRewriter.rewrite(node.getLeft(), context);
            PlanNode right = planRewriter.rewrite(node.getRight(), context);
            if (!(left instanceof MaterializeSampleNode) && !(right instanceof MaterializeSampleNode)) {
                return new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria());
            }

            // Unwrap the sampled side(s); a null weight means that side is not sampled.
            Symbol leftWeight = null;
            Symbol rightWeight = null;
            if (left instanceof MaterializeSampleNode) {
                leftWeight = ((MaterializeSampleNode) left).getSampleWeightSymbol();
                left = ((MaterializeSampleNode) left).getSource();
            }
            if (right instanceof MaterializeSampleNode) {
                rightWeight = ((MaterializeSampleNode) right).getSampleWeightSymbol();
                right = ((MaterializeSampleNode) right).getSource();
            }

            PlanNode result = new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria());
            Symbol outputWeight;
            if (leftWeight != null && rightWeight != null) {
                Expression combinedWeight;
                switch (node.getType()) {
                    case INNER:
                    case CROSS:
                        combinedWeight = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, new QualifiedNameReference(leftWeight.toQualifiedName()), new QualifiedNameReference(rightWeight.toQualifiedName()));
                        break;
                    case LEFT:
                        combinedWeight = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, new QualifiedNameReference(leftWeight.toQualifiedName()), oneIfNull(rightWeight));
                        break;
                    case RIGHT:
                        combinedWeight = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, oneIfNull(leftWeight), new QualifiedNameReference(rightWeight.toQualifiedName()));
                        break;
                    default:
                        throw new AssertionError(String.format("Unknown join type: %s", node.getType()));
                }
                outputWeight = symbolAllocator.newSymbol(combinedWeight, Type.BIGINT);
                // The combined weight goes first, then every join output except the two input weights.
                ImmutableMap.Builder<Symbol, Expression> projections = ImmutableMap.builder();
                projections.put(outputWeight, combinedWeight);
                projections.putAll(identityProjections(Iterables.filter(node.getOutputSymbols(), Predicates.not(Predicates.in(ImmutableSet.of(leftWeight, rightWeight))))).build());
                result = new ProjectNode(idAllocator.getNextId(), result, projections.build());
            } else {
                outputWeight = (leftWeight != null) ? leftWeight : rightWeight;
                boolean weightMayBeNull = (node.getType() == JoinNode.Type.LEFT && leftWeight == null)
                        || (node.getType() == JoinNode.Type.RIGHT && rightWeight == null);
                if (weightMayBeNull) {
                    // The sampled side is the null-producing side of the outer join, so unmatched rows
                    // would carry a NULL weight; replace it with 1 under a fresh symbol.
                    ImmutableMap.Builder<Symbol, Expression> projections = identityProjections(
                            Iterables.filter(node.getOutputSymbols(), Predicates.not(Predicates.equalTo(outputWeight))));
                    Expression coalescedWeight = oneIfNull(outputWeight);
                    outputWeight = symbolAllocator.newSymbol(coalescedWeight, Type.BIGINT);
                    projections.put(outputWeight, coalescedWeight);
                    result = new ProjectNode(idAllocator.getNextId(), result, projections.build());
                }
            }
            return new MaterializeSampleNode(idAllocator.getNextId(), result, outputWeight);
        }

        /** Builds {@code COALESCE(symbol, 1)}. */
        private Expression oneIfNull(Symbol symbol) {
            return new CoalesceExpression(new Expression[]{new QualifiedNameReference(symbol.toQualifiedName()), new LongLiteral("1")});
        }

        /**
         * Pulls sampling above a union. Every branch is mapped to a fresh output weight symbol.
         * Sampled branches contribute their own weight; unsampled branches get a constant weight of 1.
         */
        @Override
        public PlanNode rewriteUnion(UnionNode node, Void context, final PlanRewriter<Void> planRewriter) {
            List<PlanNode> rewrittenSources = IterableTransformer.on(node.getSources()).transform(new Function<PlanNode, PlanNode>() {
                @Override
                public PlanNode apply(PlanNode source) {
                    return planRewriter.rewrite(source, null);
                }
            }).list();
            if (Iterables.all(rewrittenSources, Predicates.not(Predicates.instanceOf(MaterializeSampleNode.class)))) {
                return new UnionNode(node.getId(), rewrittenSources, node.getSymbolMapping());
            }

            ImmutableListMultimap.Builder<Symbol, Symbol> symbolMapping = ImmutableListMultimap.<Symbol, Symbol>builder().putAll(node.getSymbolMapping());
            ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
            Symbol outputWeight = symbolAllocator.newSymbol("$sampleWeight", Type.BIGINT);
            for (PlanNode source : rewrittenSources) {
                if (source instanceof MaterializeSampleNode) {
                    MaterializeSampleNode sample = (MaterializeSampleNode) source;
                    symbolMapping.put(outputWeight, sample.getSampleWeightSymbol());
                    sources.add(sample.getSource());
                } else {
                    Symbol constantWeight = symbolAllocator.newSymbol("$sampleWeight", Type.BIGINT);
                    symbolMapping.put(outputWeight, constantWeight);
                    sources.add(addSampleWeight(source, constantWeight));
                }
            }
            return new MaterializeSampleNode(idAllocator.getNextId(), new UnionNode(node.getId(), sources.build(), symbolMapping.build()), outputWeight);
        }

        /** Wraps {@code node} in a projection that passes all its outputs through and adds {@code weight = 1}. */
        private PlanNode addSampleWeight(PlanNode node, Symbol weight) {
            ImmutableMap.Builder<Symbol, Expression> projections = identityProjections(node.getOutputSymbols());
            projections.put(weight, new LongLiteral("1"));
            return new ProjectNode(idAllocator.getNextId(), node, projections.build());
        }

        /** Returns a builder pre-populated with a {@code symbol -> symbol} reference for each given symbol, in iteration order. */
        private static ImmutableMap.Builder<Symbol, Expression> identityProjections(Iterable<Symbol> symbols) {
            ImmutableMap.Builder<Symbol, Expression> projections = ImmutableMap.builder();
            for (Symbol symbol : symbols) {
                projections.put(symbol, new QualifiedNameReference(symbol.toQualifiedName()));
            }
            return projections;
        }
    }

    /**
     * Rewrites {@code plan} with the pull-up rules. {@code session} and {@code types} are only
     * null-checked here.
     *
     * @throws NullPointerException if any argument is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        Preconditions.checkNotNull(plan, "plan is null");
        Preconditions.checkNotNull(session, "session is null");
        Preconditions.checkNotNull(types, "types is null");
        Preconditions.checkNotNull(symbolAllocator, "symbolAllocator is null");
        Preconditions.checkNotNull(idAllocator, "idAllocator is null");
        return PlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator), plan, null);
    }
}
