package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Pushes an aggregation below a LEFT or RIGHT outer join.
 *
 * <p>The rule fires only when all of the following hold:
 * <ul>
 *   <li>the join has no filter;</li>
 *   <li>the aggregation groups on exactly the output variables of the join's outer side;</li>
 *   <li>the outer side is known to produce distinct rows.</li>
 * </ul>
 * When it fires, the aggregation is evaluated on the join's inner side instead, grouped by the inner
 * side's join keys.
 *
 * <p>In the original plan, an outer row without a match forms its own group containing a single row
 * whose inner columns are NULL. After the rewrite, such a row would instead receive NULL aggregation
 * results. To compensate:
 * <ul>
 *   <li>the rewritten join is cross joined with the same aggregations evaluated over a single row of
 *       NULLs;</li>
 *   <li>every aggregation output is projected as {@code COALESCE(pushed_result, result_over_nulls)}.</li>
 * </ul>
 *
 * <p>The rule is enabled by {@link SystemSessionProperties#shouldPushAggregationThroughJoin(Session)}.
 */
public class PushAggregationThroughOuterJoin implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.join().capturedAs(JOIN)));

    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * Pairs an aggregation evaluated over a row of NULLs with a mapping. The mapping goes from each
     * output variable of the original aggregation to the corresponding output variable of the
     * over-NULL aggregation.
     */
    public static class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<VariableReferenceExpression, VariableReferenceExpression> variableMapping;

        public MappedAggregationInfo(AggregationNode aggregationNode, Map<VariableReferenceExpression, VariableReferenceExpression> map) {
            this.aggregationNode = aggregationNode;
            this.variableMapping = map;
        }

        /** Returns the mapping from original aggregation outputs to over-NULL aggregation outputs. */
        public Map<VariableReferenceExpression, VariableReferenceExpression> getVariableMapping() {
            return this.variableMapping;
        }

        /** Returns the aggregation node evaluated over the single row of NULLs. */
        public AggregationNode getAggregation() {
            return this.aggregationNode;
        }
    }

    /**
     * @param functionAndTypeManager used to look up aggregation function names when naming the
     *     over-NULL aggregation outputs; must not be null
     */
    public PushAggregationThroughOuterJoin(FunctionAndTypeManager functionAndTypeManager) {
        this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionManager is null");
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldPushAggregationThroughJoin(session);
    }

    /**
     * Rewrites {@code Aggregation(OuterJoin(outer, inner))} as
     * {@code Project(coalesce)(CrossJoin(OuterJoin(outer, Aggregation(inner)), Aggregation(nullRow)))}.
     *
     * @return the rewritten plan, or an empty result if any precondition does not hold
     */
    @Override
    public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
        JoinNode joinNode = captures.get(JOIN);
        // isOuterJoin is checked before getOuterTable, whose precondition would otherwise throw.
        if (joinNode.getFilter().isPresent()
                || !isOuterJoin(joinNode)
                || !groupsOnAllColumns(aggregationNode, getOuterTable(joinNode).getOutputVariables())) {
            return Rule.Result.empty();
        }

        Lookup lookup = context.getLookup();
        PlanNode outerTable = lookup.resolve(getOuterTable(joinNode));
        if (!DistinctOutputQueryUtil.isDistinct(outerTable, lookup::resolve)) {
            return Rule.Result.empty();
        }

        AggregationNode rewrittenAggregation = pushAggregationToInnerSide(aggregationNode, joinNode);
        JoinNode rewrittenJoin = replaceInnerSide(joinNode, rewrittenAggregation);
        Optional<PlanNode> result = coalesceWithNullAggregation(
                rewrittenAggregation,
                rewrittenJoin,
                context.getVariableAllocator(),
                context.getIdAllocator());
        return result.isPresent() ? Rule.Result.ofPlanNode(result.get()) : Rule.Result.empty();
    }

    /** Returns true for the join types this rule can rewrite: LEFT and RIGHT. */
    private static boolean isOuterJoin(JoinNode joinNode) {
        return joinNode.getType() == JoinNode.Type.LEFT || joinNode.getType() == JoinNode.Type.RIGHT;
    }

    /**
     * Builds a copy of {@code aggregationNode} that reads from the join's inner side.
     *
     * <p>The copy groups by that side's join keys: the left keys for a RIGHT join and the right keys
     * for a LEFT join. It keeps the original node id, aggregations, step, hash and group-id variables.
     */
    private static AggregationNode pushAggregationToInnerSide(AggregationNode aggregationNode, JoinNode joinNode) {
        boolean innerIsLeft = joinNode.getType() == JoinNode.Type.RIGHT;
        List<VariableReferenceExpression> groupingKeys = joinNode.getCriteria().stream()
                .map(clause -> innerIsLeft ? clause.getLeft() : clause.getRight())
                .collect(ImmutableList.toImmutableList());
        return new AggregationNode(
                aggregationNode.getSourceLocation(),
                aggregationNode.getId(),
                getInnerTable(joinNode),
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashVariable(),
                aggregationNode.getGroupIdVariable());
    }

    /**
     * Rebuilds {@code joinNode} with its inner side replaced by {@code innerAggregation}.
     *
     * <p>The outputs are the outer side's outputs followed by the aggregation outputs, kept in the
     * same left-to-right order as the join inputs. Criteria, filter, hash variables, distribution
     * type and dynamic filters are copied unchanged.
     */
    private static JoinNode replaceInnerSide(JoinNode joinNode, AggregationNode innerAggregation) {
        boolean leftJoin = joinNode.getType() == JoinNode.Type.LEFT;
        PlanNode left = leftJoin ? joinNode.getLeft() : innerAggregation;
        PlanNode right = leftJoin ? innerAggregation : joinNode.getRight();

        ImmutableList.Builder<VariableReferenceExpression> outputVariables = ImmutableList.builder();
        if (leftJoin) {
            outputVariables.addAll(joinNode.getLeft().getOutputVariables());
            outputVariables.addAll(innerAggregation.getAggregations().keySet());
        }
        else {
            outputVariables.addAll(innerAggregation.getAggregations().keySet());
            outputVariables.addAll(joinNode.getRight().getOutputVariables());
        }

        return new JoinNode(
                joinNode.getSourceLocation(),
                joinNode.getId(),
                joinNode.getType(),
                left,
                right,
                joinNode.getCriteria(),
                outputVariables.build(),
                joinNode.getFilter(),
                joinNode.getLeftHashVariable(),
                joinNode.getRightHashVariable(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());
    }

    /** Returns the non-preserved side of a LEFT/RIGHT join (right for LEFT, left for RIGHT). */
    private static PlanNode getInnerTable(JoinNode joinNode) {
        Preconditions.checkState(isOuterJoin(joinNode), "expected LEFT or RIGHT JOIN");
        return joinNode.getType().equals(JoinNode.Type.LEFT) ? joinNode.getRight() : joinNode.getLeft();
    }

    /** Returns the preserved side of a LEFT/RIGHT join (left for LEFT, right for RIGHT). */
    private static PlanNode getOuterTable(JoinNode joinNode) {
        Preconditions.checkState(isOuterJoin(joinNode), "expected LEFT or RIGHT JOIN");
        return joinNode.getType().equals(JoinNode.Type.LEFT) ? joinNode.getLeft() : joinNode.getRight();
    }

    /** Returns true if the aggregation's grouping keys are exactly (as a set) {@code list}. */
    private static boolean groupsOnAllColumns(AggregationNode aggregationNode, List<VariableReferenceExpression> list) {
        return new HashSet<>(aggregationNode.getGroupingKeys()).equals(new HashSet<>(list));
    }

    /**
     * Cross joins {@code planNode} with {@code aggregationNode}'s aggregations evaluated over a single
     * NULL row. It then projects every aggregation output as
     * {@code COALESCE(output, output_over_nulls)}, so that unmatched outer rows get the over-NULL value.
     * All other outputs of {@code planNode} pass through unchanged.
     *
     * @return the projection, or empty if the over-NULL aggregation cannot be built
     */
    private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode planNode, PlanVariableAllocator planVariableAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        Optional<MappedAggregationInfo> overNullInfo = createAggregationOverNull(aggregationNode, planVariableAllocator, planNodeIdAllocator);
        if (!overNullInfo.isPresent()) {
            return Optional.empty();
        }
        AggregationNode aggregationOverNull = overNullInfo.get().getAggregation();
        Map<VariableReferenceExpression, VariableReferenceExpression> overNullMapping = overNullInfo.get().getVariableMapping();

        // An INNER join with no criteria is a cross join. The right side is a global aggregation,
        // which produces exactly one row.
        JoinNode crossJoin = new JoinNode(
                planNode.getSourceLocation(),
                planNodeIdAllocator.getNextId(),
                JoinNode.Type.INNER,
                planNode,
                aggregationOverNull,
                ImmutableList.of(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(planNode.getOutputVariables())
                        .addAll(aggregationOverNull.getOutputVariables())
                        .build(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());

        Assignments.Builder assignments = Assignments.builder();
        for (VariableReferenceExpression variable : planNode.getOutputVariables()) {
            if (aggregationNode.getAggregations().containsKey(variable)) {
                assignments.put(variable, coalesce(ImmutableList.of(variable, overNullMapping.get(variable))));
            }
            else {
                assignments.put(variable, variable);
            }
        }
        return Optional.of(new ProjectNode(planNodeIdAllocator.getNextId(), crossJoin, assignments.build()));
    }

    /** Builds a COALESCE special form whose result type is the type of the first argument. */
    private static RowExpression coalesce(List<RowExpression> list) {
        return new SpecialFormExpression(SpecialFormExpression.Form.COALESCE, list.get(0).getType(), list);
    }

    /**
     * Builds a global aggregation over a VALUES node with a single row of NULLs.
     *
     * <p>The VALUES node has one NULL column per output of {@code aggregationNode}'s source. The
     * aggregation repeats each aggregation of {@code aggregationNode`}, with source variables replaced
     * by the corresponding NULL columns.
     *
     * @return the aggregation plus the output-variable mapping, or empty if some aggregation has no
     *     direct variable argument taken from the source. Presumably this excludes aggregations such
     *     as {@code count(*)}, whose value over one NULL row is not a neutral default; TODO confirm.
     */
    private Optional<MappedAggregationInfo> createAggregationOverNull(AggregationNode aggregationNode, PlanVariableAllocator planVariableAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        ImmutableList.Builder<VariableReferenceExpression> nullVariables = ImmutableList.builder();
        ImmutableList.Builder<RowExpression> nullLiterals = ImmutableList.builder();
        ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourceToNullBuilder = ImmutableMap.builder();
        for (VariableReferenceExpression sourceVariable : aggregationNode.getSource().getOutputVariables()) {
            RowExpression nullLiteral = Expressions.constantNull(sourceVariable.getSourceLocation(), sourceVariable.getType());
            nullLiterals.add(nullLiteral);
            VariableReferenceExpression nullVariable = planVariableAllocator.newVariable(nullLiteral);
            nullVariables.add(nullVariable);
            sourceToNullBuilder.put(sourceVariable, nullVariable);
        }
        ValuesNode nullRow = new ValuesNode(aggregationNode.getSourceLocation(), planNodeIdAllocator.getNextId(), nullVariables.build(), ImmutableList.of(nullLiterals.build()));
        Map<VariableReferenceExpression, VariableReferenceExpression> sourceToNull = sourceToNullBuilder.build();

        ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputMapping = ImmutableMap.builder();
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsOverNull = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            VariableReferenceExpression outputVariable = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (!isUsingVariables(aggregation, sourceToNull.keySet())) {
                return Optional.empty();
            }
            AggregationNode.Aggregation overNull = rewriteOverNullRow(aggregation, sourceToNull);
            VariableReferenceExpression overNullVariable = planVariableAllocator.newVariable(
                    aggregation.getCall().getSourceLocation(),
                    this.functionAndTypeManager.getFunctionMetadata(overNull.getFunctionHandle()).getName().getObjectName(),
                    outputVariable.getType());
            aggregationsOverNull.put(overNullVariable, overNull);
            outputMapping.put(outputVariable, overNullVariable);
        }

        AggregationNode aggregationOverNullRow = new AggregationNode(
                aggregationNode.getSourceLocation(),
                planNodeIdAllocator.getNextId(),
                nullRow,
                aggregationsOverNull.build(),
                AggregationNode.globalAggregation(),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
        return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, outputMapping.build()));
    }

    /**
     * Returns a copy of {@code aggregation} with every reference to a source variable replaced by its
     * NULL-column counterpart from {@code sourceToNull}. The replacement covers the arguments, the
     * filter, the order-by and the mask.
     */
    private static AggregationNode.Aggregation rewriteOverNullRow(AggregationNode.Aggregation aggregation, Map<VariableReferenceExpression, VariableReferenceExpression> sourceToNull) {
        CallExpression call = aggregation.getCall();
        List<RowExpression> arguments = aggregation.getArguments().stream()
                .map(argument -> RowExpressionVariableInliner.inlineVariables(sourceToNull, argument))
                .collect(ImmutableList.toImmutableList());
        return new AggregationNode.Aggregation(
                new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments),
                aggregation.getFilter().map(filter -> RowExpressionVariableInliner.inlineVariables(sourceToNull, filter)),
                aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourceToNull, orderBy)),
                aggregation.isDistinct(),
                aggregation.getMask().map(mask -> new VariableReferenceExpression(
                        sourceToNull.get(mask).getSourceLocation(),
                        sourceToNull.get(mask).getName(),
                        mask.getType())));
    }

    /**
     * Returns an ordering scheme in which every order-by variable is replaced by its image in
     * {@code map}. Order and sort direction are kept.
     *
     * @throws NullPointerException if an order-by variable has no entry in {@code map}
     */
    public static OrderingScheme inlineOrderByVariables(Map<VariableReferenceExpression, VariableReferenceExpression> map, OrderingScheme orderingScheme) {
        ImmutableList.Builder<Ordering> orderings = ImmutableList.builder();
        for (VariableReferenceExpression variable : orderingScheme.getOrderByVariables()) {
            VariableReferenceExpression translated = Objects.requireNonNull(map.get(variable), () -> "no mapping for order-by variable " + variable);
            orderings.add(new Ordering(translated, orderingScheme.getOrdering(variable)));
        }
        return new OrderingScheme(orderings.build());
    }

    /**
     * Returns true if at least one argument of {@code aggregation} is a bare variable contained in
     * {@code set}. Variables nested inside other argument expressions are not considered.
     */
    private static boolean isUsingVariables(AggregationNode.Aggregation aggregation, Set<VariableReferenceExpression> set) {
        Set<VariableReferenceExpression> argumentVariables = new HashSet<>();
        for (RowExpression argument : aggregation.getArguments()) {
            if (argument instanceof VariableReferenceExpression) {
                argumentVariables.add((VariableReferenceExpression) argument);
            }
        }
        return set.stream().anyMatch(argumentVariables::contains);
    }
}
