package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.Session;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.class */
/**
 * Static helpers that push a {@link ProjectNode} below an INNER {@link JoinNode}, splitting the
 * projection's assignments between the join's left and right inputs.
 */
public final class PushProjectionThroughJoin {
    /**
     * Attempts to rewrite {@code Project(Join(left, right))} into
     * {@code Join(Project(left), Project(right))}.
     *
     * <p>The rewrite is attempted only when every assignment of {@code projectNode} is
     * deterministic, the resolved source of {@code projectNode} is a {@link JoinNode} of type
     * {@code INNER}, and every assignment references symbols from only one side of the join.
     * Symbols the join itself reads (criteria, filter, hash symbols) are added as identity
     * assignments on whichever side produces them. The new join outputs only symbols that the
     * original projection produced.
     *
     * @param plannerContext planner context; its metadata is used for determinism checks
     * @param projectNode the projection to push down
     * @param lookup used to resolve the projection's source and nested project sources
     * @param planNodeIdAllocator supplies ids for the two new {@link ProjectNode}s
     * @param session forwarded to {@link InlineProjections#inlineProjections}
     * @param typeAnalyzer forwarded to {@link InlineProjections#inlineProjections}
     * @param typeProvider forwarded to {@link InlineProjections#inlineProjections}
     * @return the rewritten join, or {@link Optional#empty()} if the projection cannot be pushed
     * @throws IllegalStateException if a symbol required by the join is output by neither input
     */
    public static Optional<PlanNode> pushProjectionThroughJoin(PlannerContext plannerContext, ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, Session session, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider) {
        // Only deterministic expressions are pushed. Below the join an expression is evaluated
        // per input row rather than per joined row, which presumably is only safe when the
        // result cannot vary between evaluations.
        boolean allDeterministic = projectNode.getAssignments().getExpressions().stream()
                .allMatch(expression -> DeterminismEvaluator.isDeterministic(expression, plannerContext.getMetadata()));
        if (!allDeterministic) {
            return Optional.empty();
        }

        PlanNode source = lookup.resolve(projectNode.getSource());
        if (!(source instanceof JoinNode)) {
            return Optional.empty();
        }
        JoinNode joinNode = (JoinNode) source;
        if (joinNode.getType() != JoinNode.Type.INNER) {
            return Optional.empty();
        }

        PlanNode left = joinNode.getLeft();
        PlanNode right = joinNode.getRight();
        // Hoisted so membership checks below do not rescan the output lists on every assignment.
        Set<Symbol> leftSymbols = ImmutableSet.copyOf(left.getOutputSymbols());
        Set<Symbol> rightSymbols = ImmutableSet.copyOf(right.getOutputSymbols());

        Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
        Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
        for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
            Expression expression = assignment.getValue();
            Set<Symbol> referencedSymbols = SymbolsExtractor.extractUnique(expression);
            if (leftSymbols.containsAll(referencedSymbols)) {
                // Satisfied by the left input alone; an expression referencing no symbols
                // also lands here because containsAll of an empty set is true.
                leftAssignmentsBuilder.put(assignment.getKey(), expression);
            } else if (rightSymbols.containsAll(referencedSymbols)) {
                rightAssignmentsBuilder.put(assignment.getKey(), expression);
            } else {
                // The expression mixes symbols from both sides, so neither input can compute it.
                return Optional.empty();
            }
        }

        // Symbols the join itself reads must stay available on the side that produces them.
        for (Symbol requiredSymbol : getJoinRequiredSymbols(joinNode)) {
            if (leftSymbols.contains(requiredSymbol)) {
                leftAssignmentsBuilder.putIdentity(requiredSymbol);
            } else {
                Preconditions.checkState(rightSymbols.contains(requiredSymbol), "join symbol %s is produced by neither join input", requiredSymbol);
                rightAssignmentsBuilder.putIdentity(requiredSymbol);
            }
        }

        Assignments leftAssignments = leftAssignmentsBuilder.build();
        Assignments rightAssignments = rightAssignmentsBuilder.build();

        // The join exposes only symbols the original projection produced. Identity assignments
        // added solely for the join's own use are not forwarded as outputs.
        Set<Symbol> projectOutputs = ImmutableSet.copyOf(projectNode.getOutputSymbols());
        List<Symbol> leftOutputSymbols = leftAssignments.getOutputs().stream()
                .filter(projectOutputs::contains)
                .collect(ImmutableList.toImmutableList());
        List<Symbol> rightOutputSymbols = rightAssignments.getOutputs().stream()
                .filter(projectOutputs::contains)
                .collect(ImmutableList.toImmutableList());

        // Left is built first so plan-node ids are allocated in the same order as before.
        PlanNode newLeft = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), left, leftAssignments), lookup, session, typeAnalyzer, typeProvider);
        PlanNode newRight = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), right, rightAssignments), lookup, session, typeAnalyzer, typeProvider);

        return Optional.of(new JoinNode(
                joinNode.getId(),
                joinNode.getType(),
                newLeft,
                newRight,
                joinNode.getCriteria(),
                leftOutputSymbols,
                rightOutputSymbols,
                joinNode.isMaySkipOutputDuplicates(),
                joinNode.getFilter(),
                joinNode.getLeftHashSymbol(),
                joinNode.getRightHashSymbol(),
                joinNode.getDistributionType(),
                joinNode.isSpillable(),
                joinNode.getDynamicFilters(),
                joinNode.getReorderJoinStatsAndCost()));
    }

    /**
     * Repeatedly merges {@code projectNode} into the {@link ProjectNode} beneath it, via
     * {@link InlineProjections#inlineProjections}. Stops when the resolved source is no longer a
     * {@link ProjectNode} or when inlining returns empty.
     *
     * <p>NOTE(review): the decompiler reported this method as originally {@code private}. It is
     * kept {@code public} here so any existing callers continue to compile.
     *
     * @return the fully inlined projection, or {@code projectNode} itself if nothing was inlined
     */
    public static PlanNode inlineProjections(PlannerContext plannerContext, ProjectNode projectNode, Lookup lookup, Session session, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider) {
        PlanNode source = lookup.resolve(projectNode.getSource());
        if (!(source instanceof ProjectNode)) {
            return projectNode;
        }
        return InlineProjections.inlineProjections(plannerContext, projectNode, (ProjectNode) source, session, typeAnalyzer, typeProvider)
                .map(inlined -> inlineProjections(plannerContext, inlined, lookup, session, typeAnalyzer, typeProvider))
                .orElse(projectNode);
    }

    /**
     * Collects every symbol the join itself reads, in a stable order.
     *
     * <p>The order is: left sides of all equi-criteria, right sides of all equi-criteria, symbols
     * of the optional filter, then the optional left and right hash symbols.
     */
    private static Set<Symbol> getJoinRequiredSymbols(JoinNode joinNode) {
        return Streams.concat(
                        joinNode.getCriteria().stream().map(clause -> clause.getLeft()),
                        joinNode.getCriteria().stream().map(clause -> clause.getRight()),
                        joinNode.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(),
                        joinNode.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(),
                        joinNode.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream())
                .collect(ImmutableSet.toImmutableSet());
    }

    private PushProjectionThroughJoin() {
    }
}
