package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.ExpressionTreeUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ExistsPredicate;
import com.facebook.presto.sql.tree.LongLiteral;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites an {@link ApplyNode} whose only subquery assignment is an {@code EXISTS} predicate
 * into a {@link LateralJoinNode}.
 *
 * <p>Two rewrites are tried, in order:
 * <ol>
 *   <li>{@code rewriteToNonDefaultAggregation}: caps the subquery at one row with a FINAL
 *       {@code LIMIT 1}, projects a constant {@code TRUE} marker, LEFT-joins it laterally to the
 *       input, and computes the EXISTS result as {@code COALESCE(marker, FALSE)}. This form is
 *       used only when {@code PlanNodeDecorrelator.decorrelateFilters} returns a result for the
 *       limited subquery.</li>
 *   <li>{@code rewriteToDefaultAggregation}: otherwise, runs a global {@code count()} aggregation
 *       over the subquery, INNER-joins it laterally to the input, and computes the EXISTS result
 *       as {@code count > CAST(0 AS BIGINT)}.</li>
 * </ol>
 */
public class TransformExistsApplyToLateralNode implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();

    /** Used only to look up the {@code count} function for the fallback rewrite. */
    private final StandardFunctionResolution functionResolution;

    /**
     * Creates the rule.
     *
     * @param functionAndTypeManager source of function metadata; must not be null
     * @throws NullPointerException if {@code functionAndTypeManager} is null
     */
    public TransformExistsApplyToLateralNode(FunctionAndTypeManager functionAndTypeManager) {
        Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager);
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code applyNode} if it carries exactly one subquery assignment and that assignment
     * is an {@code EXISTS} predicate; otherwise leaves the plan unchanged.
     *
     * @return the LIMIT-based rewrite when it applies, else the count-based rewrite, or
     *     {@link Result#empty()} when the node is not a single EXISTS apply
     */
    @Override
    public Result apply(ApplyNode applyNode, Captures captures, Context context) {
        if (applyNode.getSubqueryAssignments().size() != 1) {
            return Result.empty();
        }
        RowExpression subqueryExpression = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getExpressions());
        if (!(OriginalExpressionUtils.castToExpression(subqueryExpression) instanceof ExistsPredicate)) {
            return Result.empty();
        }
        return rewriteToNonDefaultAggregation(applyNode, context)
                .map(Result::ofPlanNode)
                .orElseGet(() -> Result.ofPlanNode(rewriteToDefaultAggregation(applyNode, context)));
    }

    /**
     * Attempts the LIMIT-1 rewrite of the EXISTS subquery. The resulting shape is
     * {@code Project(input columns..., exists := COALESCE(subqueryTrue, FALSE))} over a LEFT
     * {@link LateralJoinNode} of the input with
     * {@code Project(subqueryTrue := TRUE)} over {@code Limit(1, FINAL)} over the subquery.
     *
     * @param applyNode the EXISTS apply node; its subquery must expose no output variables
     * @param context rule context providing the id/variable allocators and the lookup
     * @return the rewritten plan, or {@link Optional#empty()} when
     *     {@code PlanNodeDecorrelator.decorrelateFilters} yields no result for the limited subquery
     * @throws IllegalStateException if the subquery still exposes output variables
     */
    private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, Context context) {
        Preconditions.checkState(applyNode.getSubquery().getOutputVariables().isEmpty(), "Expected subquery output variables to be pruned");

        VariableReferenceExpression exists = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getVariables());
        VariableReferenceExpression subqueryTrue = context.getVariableAllocator().newVariable(exists.getSourceLocation(), "subqueryTrue", BooleanType.BOOLEAN);

        // Every input column passes through unchanged; the subquery contributes only the EXISTS flag.
        Assignments.Builder outerAssignments = Assignments.builder();
        outerAssignments.putAll(AssignmentUtils.identitiesAsSymbolReferences(applyNode.getInput().getOutputVariables()));
        // Input rows with no subquery match get a NULL subqueryTrue from the LEFT join; COALESCE maps that to FALSE.
        outerAssignments.put(exists, OriginalExpressionUtils.castToRowExpression(new CoalesceExpression(
                ImmutableList.of(ExpressionTreeUtils.createSymbolReference(subqueryTrue), BooleanLiteral.FALSE_LITERAL))));

        // A single subquery row is enough to decide EXISTS, so cap it and emit a constant TRUE marker.
        PlanNode limitedSubquery = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new LimitNode(
                        applyNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        applyNode.getSubquery(),
                        1L,
                        LimitNode.Step.FINAL),
                Assignments.of(subqueryTrue, OriginalExpressionUtils.castToRowExpression(BooleanLiteral.TRUE_LITERAL)));

        PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getVariableAllocator(), context.getLookup());
        // Only the presence of a result is consulted; the plan below still uses limitedSubquery as built above.
        if (!decorrelator.decorrelateFilters(limitedSubquery, applyNode.getCorrelation()).isPresent()) {
            return Optional.empty();
        }

        return Optional.of(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new LateralJoinNode(
                        applyNode.getSourceLocation(),
                        applyNode.getId(),
                        applyNode.getInput(),
                        limitedSubquery,
                        applyNode.getCorrelation(),
                        LateralJoinNode.Type.LEFT,
                        applyNode.getOriginSubqueryError()),
                outerAssignments.build()));
    }

    /**
     * Fallback rewrite of the EXISTS subquery. The resulting shape is an INNER
     * {@link LateralJoinNode} of the input with
     * {@code Project(exists := count > CAST(0 AS BIGINT))} over a SINGLE-step global
     * {@code count()} aggregation of the subquery.
     *
     * @param applyNode the EXISTS apply node
     * @param context rule context providing the id/variable allocators
     * @return the rewritten plan
     */
    private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context context) {
        VariableReferenceExpression exists = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getVariables());
        // Carry the EXISTS source location, as the subqueryTrue variable and the count call already do.
        VariableReferenceExpression count = context.getVariableAllocator().newVariable(exists.getSourceLocation(), "count", BigintType.BIGINT);

        // count() with no arguments over the subquery's rows.
        AggregationNode.Aggregation countAggregation = new AggregationNode.Aggregation(
                new CallExpression(exists.getSourceLocation(), "count", this.functionResolution.countFunction(), BigintType.BIGINT, ImmutableList.of()),
                Optional.empty(),
                Optional.empty(),
                false,
                Optional.empty());

        PlanNode countSubquery = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new AggregationNode(
                        applyNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        applyNode.getSubquery(),
                        ImmutableMap.of(count, countAggregation),
                        AggregationNode.globalAggregation(),
                        ImmutableList.of(),
                        AggregationNode.Step.SINGLE,
                        Optional.empty(),
                        Optional.empty()),
                Assignments.of(exists, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(
                        ComparisonExpression.Operator.GREATER_THAN,
                        OriginalExpressionUtils.asSymbolReference(count),
                        new Cast(new LongLiteral("0"), BigintType.BIGINT.toString())))));

        // NOTE: a global (no grouping keys) aggregation is expected to emit exactly one row,
        // which is why an INNER lateral join is used here rather than LEFT — confirm if changing.
        return new LateralJoinNode(
                applyNode.getSourceLocation(),
                applyNode.getId(),
                applyNode.getInput(),
                countSubquery,
                applyNode.getCorrelation(),
                LateralJoinNode.Type.INNER,
                applyNode.getOriginSubqueryError());
    }
}
