package com.facebook.presto.sql.planner.iterative.rule.test;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CachingCostProvider;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsAndCosts;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.matching.Match;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.security.AccessControl;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Plan;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.RuleStatsRecorder;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.PlanAssert;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.IterativeOptimizer;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Memo;
import com.facebook.presto.sql.planner.iterative.PlanNodeMatcher;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.TranslateExpressions;
import com.facebook.presto.sql.planner.planPrinter.PlanPrinter;
import com.facebook.presto.transaction.TransactionBuilder;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;
import org.testng.Assert;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.class */
/**
 * Fluent test helper that applies a single optimizer {@link Rule} to a plan built with
 * {@link PlanBuilder} and asserts on the outcome.
 *
 * <p>Typical usage: construct the assert, optionally adjust the session or override node
 * statistics, supply the input plan via {@link #on(Function)}, and finish with one of
 * {@link #matches(PlanMatchPattern)}, {@link #doesNotFire()} or {@link #get()}.
 *
 * <p>Instances are mutable (session, plan and types are set through the fluent methods) and
 * the plan can be set only once.
 */
public class RuleAssert {
    private final Metadata metadata;
    private final TestingStatsCalculator statsCalculator;
    private final CostCalculator costCalculator;
    private final Rule<?> rule;
    private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
    private final TransactionManager transactionManager;
    private final AccessControl accessControl;
    private Session session;
    private TypeProvider types;
    private PlanNode plan;

    /**
     * Immutable record of one rule invocation: the lookup and stats provider that were
     * handed to the rule, the types known after the rule ran, and the rule's result.
     */
    public static class RuleApplication {
        private final Lookup lookup;
        private final StatsProvider statsProvider;
        private final TypeProvider types;
        private final Rule.Result result;

        /**
         * @param lookup lookup used while applying the rule; must not be null
         * @param statsProvider stats provider exposed to the rule; must not be null
         * @param typeProvider types taken from the rule context's variable allocator; must not be null
         * @param result result returned by the rule (or an empty result if it did not run); must not be null
         * @throws NullPointerException if any argument is null
         */
        public RuleApplication(Lookup lookup, StatsProvider statsProvider, TypeProvider typeProvider, Rule.Result result) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.statsProvider = Objects.requireNonNull(statsProvider, "statsProvider is null");
            this.types = Objects.requireNonNull(typeProvider, "types is null");
            this.result = Objects.requireNonNull(result, "result is null");
        }

        /**
         * @return {@code true} when the recorded result is not empty
         */
        public boolean wasRuleApplied() {
            return !this.result.isEmpty();
        }

        /**
         * @return the plan produced by the rule
         * @throws IllegalStateException if the result carries no transformed plan
         */
        public PlanNode getTransformedPlan() {
            return this.result.getTransformedPlan()
                    .orElseThrow(() -> new IllegalStateException("Rule did not produce transformed plan"));
        }
    }

    /**
     * {@link StatsCalculator} that returns a registered estimate for specific plan node ids
     * and defers to a delegate calculator for every other node.
     */
    public static class TestingStatsCalculator implements StatsCalculator {
        private final StatsCalculator delegate;
        private final Map<PlanNodeId, PlanNodeStatsEstimate> stats = new HashMap<>();

        /**
         * @param statsCalculator calculator used for nodes without an override; must not be null
         * @throws NullPointerException if {@code statsCalculator} is null
         */
        public TestingStatsCalculator(StatsCalculator statsCalculator) {
            this.delegate = Objects.requireNonNull(statsCalculator, "delegate is null");
        }

        /**
         * Returns the overridden estimate registered for the node's id if there is one,
         * otherwise whatever the delegate computes.
         */
        @Override
        public PlanNodeStatsEstimate calculateStats(PlanNode planNode, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider typeProvider) {
            PlanNodeId id = planNode.getId();
            // containsKey (rather than a null check on get) keeps an explicitly registered
            // null estimate from falling through to the delegate.
            if (this.stats.containsKey(id)) {
                return this.stats.get(id);
            }
            return this.delegate.calculateStats(planNode, statsProvider, lookup, session, typeProvider);
        }

        /**
         * Registers (or replaces) the estimate returned for the given plan node id.
         */
        public void setNodeStats(PlanNodeId planNodeId, PlanNodeStatsEstimate planNodeStatsEstimate) {
            this.stats.put(planNodeId, planNodeStatsEstimate);
        }
    }

    /**
     * Creates an assert for {@code rule}. The supplied stats calculator is wrapped in a
     * {@link TestingStatsCalculator} so that {@link #overrideStats} can take effect.
     *
     * @throws NullPointerException if any argument is null
     */
    public RuleAssert(Metadata metadata, StatsCalculator statsCalculator, CostCalculator costCalculator, Session session, Rule<?> rule, TransactionManager transactionManager, AccessControl accessControl) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.statsCalculator = new TestingStatsCalculator(Objects.requireNonNull(statsCalculator, "statsCalculator is null"));
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.session = Objects.requireNonNull(session, "session is null");
        this.rule = Objects.requireNonNull(rule, "rule is null");
        this.transactionManager = Objects.requireNonNull(transactionManager, "transactionManager is null");
        this.accessControl = Objects.requireNonNull(accessControl, "accessControl is null");
    }

    /**
     * Replaces the current session with a copy that has the given system property set.
     *
     * @return this assert, for chaining
     */
    public RuleAssert setSystemProperty(String str, String str2) {
        return withSession(Session.builder(this.session).setSystemProperty(str, str2).build());
    }

    /**
     * Replaces the session used for subsequent rule application and plan formatting.
     *
     * @param session the new session; must not be null (same contract as the constructor)
     * @return this assert, for chaining
     * @throws NullPointerException if {@code session} is null
     */
    public RuleAssert withSession(Session session) {
        this.session = Objects.requireNonNull(session, "session is null");
        return this;
    }

    /**
     * Forces the stats calculator to return {@code planNodeStatsEstimate} for the plan node
     * whose id is {@code str}.
     *
     * @return this assert, for chaining
     */
    public RuleAssert overrideStats(String str, PlanNodeStatsEstimate planNodeStatsEstimate) {
        this.statsCalculator.setNodeStats(new PlanNodeId(str), planNodeStatsEstimate);
        return this;
    }

    /**
     * Builds the input plan by handing a fresh {@link PlanBuilder} to {@code function} and
     * records the types the builder collected.
     *
     * @return this assert, for chaining
     * @throws IllegalStateException if a plan has already been set
     * @throws NullPointerException if {@code function} is null
     */
    public RuleAssert on(Function<PlanBuilder, PlanNode> function) {
        Preconditions.checkState(this.plan == null, "plan has already been set");
        Objects.requireNonNull(function, "function is null");
        PlanBuilder planBuilder = new PlanBuilder(this.session, this.idAllocator, this.metadata);
        this.plan = function.apply(planBuilder);
        this.types = planBuilder.getTypes();
        return this;
    }

    /**
     * Applies the rule and returns the transformed plan, failing the test if the rule did
     * not fire.
     */
    public PlanNode get() {
        RuleApplication application = applyRule();
        assertRuleFired(application);
        return application.getTransformedPlan();
    }

    /**
     * Applies the rule and fails the test if it produced a non-empty result.
     */
    public void doesNotFire() {
        RuleApplication application = applyRule();
        if (application.wasRuleApplied()) {
            Assert.fail(String.format(
                    "Expected %s to not fire for:\n%s",
                    this.rule.getClass().getName(),
                    inTransaction(transactionSession -> PlanPrinter.textLogicalPlan(
                            this.plan,
                            application.types,
                            this.metadata.getFunctionAndTypeManager(),
                            StatsAndCosts.empty(),
                            transactionSession,
                            2))));
        }
    }

    /**
     * Applies the rule and asserts that it fired, returned a plan instance different from
     * the input, kept the same set of output variables, and that the transformed plan
     * matches {@code planMatchPattern}.
     */
    public void matches(PlanMatchPattern planMatchPattern) {
        RuleApplication application = applyRule();
        TypeProvider appliedTypes = application.types;
        assertRuleFired(application);

        PlanNode transformedPlan = application.getTransformedPlan();
        // Identity comparison on purpose: a rule that fires must hand back a new node.
        if (transformedPlan == this.plan) {
            Assert.fail(String.format(
                    "%s: rule fired but return the original plan:\n%s",
                    this.rule.getClass().getName(),
                    formatPlan(this.plan, appliedTypes)));
        }
        // Output variables are compared as sets, so ordering differences are tolerated.
        if (!ImmutableSet.copyOf(this.plan.getOutputVariables()).equals(ImmutableSet.copyOf(transformedPlan.getOutputVariables()))) {
            Assert.fail(String.format(
                    "%s: output schema of transformed and original plans are not equivalent\n\texpected: %s\n\tactual:   %s",
                    this.rule.getClass().getName(),
                    this.plan.getOutputVariables(),
                    transformedPlan.getOutputVariables()));
        }
        inTransaction(transactionSession -> {
            PlanAssert.assertPlan(
                    transactionSession,
                    this.metadata,
                    application.statsProvider,
                    new Plan(transformedPlan, appliedTypes, StatsAndCosts.empty()),
                    application.lookup,
                    planMatchPattern,
                    planNode -> translateExpressions(planNode, appliedTypes));
            return null;
        });
    }

    /**
     * Fails the test with the formatted input plan when the rule did not fire.
     */
    private void assertRuleFired(RuleApplication application) {
        if (!application.wasRuleApplied()) {
            Assert.fail(String.format(
                    "%s did not fire for:\n%s",
                    this.rule.getClass().getName(),
                    formatPlan(this.plan, application.types)));
        }
    }

    /**
     * Wraps the input plan in a {@link Memo}, builds a rule context around it, and applies
     * the rule to the memo's root node inside a transaction.
     *
     * @throws IllegalStateException if {@link #on(Function)} has not supplied a plan yet
     */
    private RuleApplication applyRule() {
        // Without this guard a missing on(...) call surfaces as an opaque NullPointerException.
        Preconditions.checkState(this.plan != null && this.types != null, "plan has not been set; call on(...) first");
        PlanVariableAllocator variableAllocator = new PlanVariableAllocator(this.types.allVariables());
        Memo memo = new Memo(this.idAllocator, this.plan);
        Lookup lookup = Lookup.from(groupReference -> Stream.of(memo.resolve(groupReference)));
        PlanNode root = memo.getNode(memo.getRootGroup());
        return inTransaction(transactionSession -> applyRule(
                this.rule,
                root,
                ruleContext(this.statsCalculator, this.costCalculator, variableAllocator, memo, lookup, transactionSession)));
    }

    /**
     * Matches {@code planNode} against the rule's pattern and, if the rule is enabled for
     * the context session and the pattern matched, applies the rule. Otherwise an empty
     * result is recorded.
     */
    private static <T> RuleApplication applyRule(Rule<T> rule, PlanNode planNode, Rule.Context context) {
        PlanNodeMatcher matcher = new PlanNodeMatcher(context.getLookup());
        // Parameterized Match<T> keeps match.value() typed as T for rule.apply.
        Match<T> match = matcher.match(rule.getPattern(), planNode);

        Rule.Result result;
        if (!rule.isEnabled(context.getSession()) || match.isEmpty()) {
            result = Rule.Result.empty();
        }
        else {
            result = rule.apply(match.value(), match.captures(), context);
        }
        return new RuleApplication(context.getLookup(), context.getStatsProvider(), context.getVariableAllocator().getTypes(), result);
    }

    /**
     * Renders {@code planNode} as text (after expression translation) together with stats
     * and costs computed from this assert's calculators.
     */
    private String formatPlan(PlanNode planNode, TypeProvider typeProvider) {
        CachingStatsProvider statsProvider = new CachingStatsProvider(this.statsCalculator, this.session, typeProvider);
        CachingCostProvider costProvider = new CachingCostProvider(this.costCalculator, statsProvider, this.session);
        return inTransaction(transactionSession -> PlanPrinter.textLogicalPlan(
                translateExpressions(planNode, typeProvider),
                typeProvider,
                this.metadata.getFunctionAndTypeManager(),
                StatsAndCosts.create(planNode, statsProvider, costProvider),
                transactionSession,
                2,
                false));
    }

    /**
     * Runs {@code function} inside a single-statement transaction started from the current
     * session and returns its result.
     */
    private <T> T inTransaction(Function<Session, T> function) {
        return TransactionBuilder.transaction(this.transactionManager, this.accessControl)
                .singleStatement()
                .execute(this.session, transactionSession -> {
                    // The return value is discarded; NOTE(review): presumably called for the side
                    // effect of registering the catalog with the transaction - confirm.
                    transactionSession.getCatalog().ifPresent(catalog -> this.metadata.getCatalogHandle(transactionSession, catalog));
                    return function.apply(transactionSession);
                });
    }

    /**
     * Runs the {@link TranslateExpressions} rules over {@code planNode} with an
     * {@link IterativeOptimizer} and returns the resulting plan.
     */
    private PlanNode translateExpressions(PlanNode planNode, TypeProvider typeProvider) {
        IterativeOptimizer optimizer = new IterativeOptimizer(
                new RuleStatsRecorder(),
                this.statsCalculator,
                this.costCalculator,
                new TranslateExpressions(this.metadata, new SqlParser()).rules());
        return optimizer.optimize(
                planNode,
                this.session,
                typeProvider,
                new PlanVariableAllocator(typeProvider.allVariables()),
                this.idAllocator,
                WarningCollector.NOOP);
    }

    /**
     * Builds the {@link Rule.Context} handed to the rule: memo-backed caching stats and cost
     * providers, this assert's id allocator, a no-op timeout check and a no-op warning
     * collector.
     */
    private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, final PlanVariableAllocator planVariableAllocator, Memo memo, final Lookup lookup, final Session session) {
        final CachingStatsProvider cachingStatsProvider = new CachingStatsProvider(statsCalculator, Optional.of(memo), lookup, session, planVariableAllocator.getTypes());
        final CachingCostProvider cachingCostProvider = new CachingCostProvider(costCalculator, cachingStatsProvider, Optional.of(memo), session);
        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return lookup;
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return RuleAssert.this.idAllocator;
            }

            @Override
            public PlanVariableAllocator getVariableAllocator() {
                return planVariableAllocator;
            }

            @Override
            public Session getSession() {
                return session;
            }

            @Override
            public StatsProvider getStatsProvider() {
                return cachingStatsProvider;
            }

            @Override
            public CostProvider getCostProvider() {
                return cachingCostProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {
                // Intentionally empty: no timeout is enforced when applying a rule from this helper.
            }

            @Override
            public WarningCollector getWarningCollector() {
                return WarningCollector.NOOP;
            }
        };
    }
}
