package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.class */
/**
 * Planner rule that matches a {@link FilterNode} sitting on top of an identity
 * {@link ProjectNode} whose source is a {@link RowNumberNode}.
 *
 * <p>When the filter predicate puts an upper bound on the row number symbol, that bound is
 * pushed into the row number node as its maximum row count per partition. If the row number
 * part of the predicate is then fully covered by the node, that part is removed from the
 * filter (and the filter itself is dropped when nothing else remains).
 */
public class PushPredicateThroughProjectIntoRowNumber implements Rule<FilterNode> {
    private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
    private static final Capture<RowNumberNode> ROW_NUMBER = Capture.newCapture();
    // filter -> identity project (captured) -> row number (captured)
    private static final Pattern<FilterNode> PATTERN = Patterns.filter()
            .with(Patterns.source().matching(
                    Patterns.project()
                            .matching(ProjectNode::isIdentity)
                            .capturedAs(PROJECT)
                            .with(Patterns.source().matching(
                                    Patterns.rowNumber().capturedAs(ROW_NUMBER)))));

    private final PlannerContext plannerContext;

    /**
     * @param plannerContext planner context used for predicate extraction and rebuilding; must not be null
     */
    public PushPredicateThroughProjectIntoRowNumber(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    @Override
    public Pattern<FilterNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to absorb the row number upper bound of the filter predicate into the
     * row number node below the projection.
     *
     * @return an empty result when the rule does not apply, otherwise the rewritten plan node
     */
    @Override
    public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
        ProjectNode project = captures.get(PROJECT);
        RowNumberNode rowNumber = captures.get(ROW_NUMBER);

        // The predicate can only constrain the row number if the projection passes it through.
        Symbol rowNumberSymbol = rowNumber.getRowNumberSymbol();
        if (!project.getAssignments().getSymbols().contains(rowNumberSymbol)) {
            return Rule.Result.empty();
        }

        DomainTranslator.ExtractionResult extraction = DomainTranslator.getExtractionResult(
                this.plannerContext,
                context.getSession(),
                filterNode.getPredicate(),
                context.getSymbolAllocator().getTypes());
        TupleDomain<Symbol> predicateDomain = extraction.getTupleDomain();

        OptionalInt upperBound = extractUpperBound(predicateDomain, rowNumberSymbol);
        if (upperBound.isEmpty()) {
            return Rule.Result.empty();
        }
        int limit = upperBound.getAsInt();
        if (limit <= 0) {
            // A non-positive upper bound: replace the whole filter with an empty values node.
            return Rule.Result.ofPlanNode(
                    new ValuesNode(filterNode.getId(), filterNode.getOutputSymbols(), ImmutableList.of()));
        }

        // Tighten the per-partition limit when the node has none or a looser one.
        Optional<Integer> existingLimit = rowNumber.getMaxRowCountPerPartition();
        boolean limitTightened = existingLimit.isEmpty() || existingLimit.get() > limit;
        if (limitTightened) {
            rowNumber = new RowNumberNode(
                    rowNumber.getId(),
                    rowNumber.getSource(),
                    rowNumber.getPartitionBy(),
                    rowNumber.isOrderSensitive(),
                    rowNumber.getRowNumberSymbol(),
                    Optional.of(limit),
                    rowNumber.getHashSymbol());
            project = (ProjectNode) project.replaceChildren(ImmutableList.of(rowNumber));
        }

        if (!allRowNumberValuesInDomain(predicateDomain, rowNumberSymbol, rowNumber.getMaxRowCountPerPartition().get())) {
            // The predicate is not fully covered by the limit, so the filter has to stay as it is.
            if (limitTightened) {
                return Rule.Result.ofPlanNode(filterNode.replaceChildren(ImmutableList.of(project)));
            }
            return Rule.Result.empty();
        }

        // The row number constraint is covered by the node; rebuild the predicate without it.
        TupleDomain<Symbol> remainingDomain = predicateDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol));
        Expression remainingPredicate = ExpressionUtils.combineConjuncts(
                this.plannerContext.getMetadata(),
                extraction.getRemainingExpression(),
                new DomainTranslator(this.plannerContext).toPredicate(context.getSession(), remainingDomain));
        if (remainingPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
            return Rule.Result.ofPlanNode(project);
        }
        return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), project, remainingPredicate));
    }

    /**
     * Derives the largest value the given symbol may take according to the tuple domain.
     *
     * @return the inclusive upper bound, or empty when the domain is none, has no entry for the
     *         symbol, is unbounded above, or the bound does not fit into an {@code int}
     */
    private static OptionalInt extractUpperBound(TupleDomain<Symbol> tupleDomain, Symbol symbol) {
        if (tupleDomain.isNone()) {
            return OptionalInt.empty();
        }
        Domain symbolDomain = tupleDomain.getDomains().get().get(symbol);
        if (symbolDomain == null) {
            return OptionalInt.empty();
        }

        ValueSet values = symbolDomain.getValues();
        if (values.isAll() || values.isNone() || values.getRanges().getRangeCount() <= 0) {
            return OptionalInt.empty();
        }

        Range span = values.getRanges().getSpan();
        if (span.isHighUnbounded()) {
            return OptionalInt.empty();
        }

        long bound = (Long) span.getHighBoundedValue();
        if (!span.isHighInclusive()) {
            // Convert an exclusive bound into the equivalent inclusive one.
            bound--;
        }

        if (bound >= Integer.MIN_VALUE && bound <= Integer.MAX_VALUE) {
            return OptionalInt.of(Math.toIntExact(bound));
        }
        return OptionalInt.empty();
    }

    /**
     * Tells whether the domain of the given symbol contains every value of the closed
     * range {@code [0, j]}.
     *
     * @return {@code false} when the tuple domain is none, {@code true} when it has no entry
     *         for the symbol, otherwise the result of the containment check
     */
    private static boolean allRowNumberValuesInDomain(TupleDomain<Symbol> tupleDomain, Symbol symbol, long j) {
        if (tupleDomain.isNone()) {
            return false;
        }
        Domain symbolDomain = tupleDomain.getDomains().get().get(symbol);
        if (symbolDomain == null) {
            return true;
        }
        Range candidateValues = Range.range(symbolDomain.getType(), 0L, true, j, true);
        return symbolDomain.getValues().contains(ValueSet.ofRanges(candidateValues));
    }
}
