package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionKind;
import com.facebook.presto.metadata.FunctionRegistry;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Signature;
import com.facebook.presto.spi.function.OperatorType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.TypeSignature;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.SymbolToInputRewriter;
import com.facebook.presto.sql.relational.CallExpression;
import com.facebook.presto.sql.relational.ConstantExpression;
import com.facebook.presto.sql.relational.InputReferenceExpression;
import com.facebook.presto.sql.relational.LambdaDefinitionExpression;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.relational.RowExpressionVisitor;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.relational.VariableReferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import io.airlift.slice.Slice;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.class */
/**
 * Decides whether two expressions are semantically equivalent.
 *
 * <p>Both expressions are translated to {@link RowExpression} form against a shared input
 * layout, canonicalized, and the canonical forms are compared with {@code equals}.
 * Canonicalization:
 * <ul>
 *     <li>flattens nested AND/OR calls, removes duplicate operands and sorts the remaining
 *     operands;</li>
 *     <li>sorts the operands of {@code =}, {@code <>} and {@code IS DISTINCT FROM};</li>
 *     <li>rewrites {@code a > b} / {@code a >= b} as {@code b < a} / {@code b <= a}.</li>
 * </ul>
 */
public class ExpressionEquivalence {
    private static final Ordering<RowExpression> ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator());
    private static final CanonicalizationVisitor CANONICALIZATION_VISITOR = new CanonicalizationVisitor();

    private final Metadata metadata;
    private final SqlParser sqlParser;

    public ExpressionEquivalence(Metadata metadata, SqlParser sqlParser) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.sqlParser = Objects.requireNonNull(sqlParser, "sqlParser is null");
    }

    /**
     * Returns whether {@code leftExpression} and {@code rightExpression} have the same canonical form.
     *
     * @param session session used for type analysis and translation
     * @param leftExpression first expression to compare
     * @param rightExpression second expression to compare
     * @param types types of every symbol either expression may reference
     * @return {@code true} if the canonicalized row expressions are equal
     */
    public boolean areExpressionsEquivalent(Session session, Expression leftExpression, Expression rightExpression, Map<Symbol, Type> types) {
        // Give every symbol a dense input channel so both sides are translated against the same layout.
        Map<Symbol, Integer> symbolInputs = new HashMap<>();
        Map<Integer, Type> inputTypes = new HashMap<>();
        int channel = 0;
        for (Map.Entry<Symbol, Type> entry : types.entrySet()) {
            symbolInputs.put(entry.getKey(), channel);
            inputTypes.put(channel, entry.getValue());
            channel++;
        }

        RowExpression canonicalLeft = toRowExpression(session, leftExpression, symbolInputs, inputTypes).accept(CANONICALIZATION_VISITOR, null);
        RowExpression canonicalRight = toRowExpression(session, rightExpression, symbolInputs, inputTypes).accept(CANONICALIZATION_VISITOR, null);
        return canonicalLeft.equals(canonicalRight);
    }

    /**
     * Rewrites symbols in {@code expression} to input references and translates the result into a
     * {@link RowExpression}.
     */
    private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInputs, Map<Integer, Type> inputTypes) {
        Expression rewritten = new SymbolToInputRewriter(symbolInputs).rewrite(expression);
        return SqlToRowExpressionTranslator.translate(
                rewritten,
                FunctionKind.SCALAR,
                ExpressionAnalyzer.getExpressionTypesFromInput(session, this.metadata, this.sqlParser, inputTypes, rewritten, Collections.<Expression>emptyList()),
                this.metadata.getFunctionRegistry(),
                this.metadata.getTypeManager(),
                session,
                false);
    }

    /**
     * Returns a new immutable list holding the two elements of {@code pair} in reverse order.
     *
     * @throws NullPointerException if {@code pair} is null
     * @throws IllegalArgumentException if {@code pair} does not have exactly two elements
     */
    public static <T> List<T> swapPair(List<T> pair) {
        Objects.requireNonNull(pair, "pair is null");
        Preconditions.checkArgument(pair.size() == 2, "Expected pair to have two elements");
        return ImmutableList.of(pair.get(1), pair.get(0));
    }

    /**
     * Rewrites a {@link RowExpression} tree into a canonical form (see the class documentation).
     * Children are canonicalized before their parent call is rewritten.
     */
    private static class CanonicalizationVisitor implements RowExpressionVisitor<RowExpression, Void> {
        private static final String EQUAL = FunctionRegistry.mangleOperatorName(OperatorType.EQUAL);
        private static final String NOT_EQUAL = FunctionRegistry.mangleOperatorName(OperatorType.NOT_EQUAL);
        private static final String IS_DISTINCT_FROM = FunctionRegistry.mangleOperatorName(OperatorType.IS_DISTINCT_FROM);
        private static final String GREATER_THAN = FunctionRegistry.mangleOperatorName(OperatorType.GREATER_THAN);
        private static final String GREATER_THAN_OR_EQUAL = FunctionRegistry.mangleOperatorName(OperatorType.GREATER_THAN_OR_EQUAL);
        private static final String LESS_THAN = FunctionRegistry.mangleOperatorName(OperatorType.LESS_THAN);
        private static final String LESS_THAN_OR_EQUAL = FunctionRegistry.mangleOperatorName(OperatorType.LESS_THAN_OR_EQUAL);

        private CanonicalizationVisitor() {
        }

        @Override
        public RowExpression visitCall(CallExpression call, Void context) {
            CallExpression canonicalCall = new CallExpression(
                    call.getSignature(),
                    call.getType(),
                    call.getArguments().stream()
                            .map(argument -> argument.accept(this, context))
                            .collect(ImmutableList.toImmutableList()));

            String name = canonicalCall.getSignature().getName();

            if (name.equals("AND") || name.equals("OR")) {
                // Flatten (a AND (b AND c)) into one n-ary call and drop duplicate operands.
                ImmutableSet<RowExpression> uniqueArguments = ImmutableSet.copyOf(flattenNestedCallArgs(canonicalCall));
                if (uniqueArguments.size() == 1) {
                    return Iterables.getOnlyElement(uniqueArguments);
                }
                List<TypeSignature> argumentTypes = uniqueArguments.stream()
                        .map(RowExpression::getType)
                        .map(Type::getTypeSignature)
                        .collect(ImmutableList.toImmutableList());
                return new CallExpression(
                        Signature.internalScalarFunction(name, BooleanType.BOOLEAN.getTypeSignature(), argumentTypes),
                        BooleanType.BOOLEAN,
                        ROW_EXPRESSION_ORDERING.sortedCopy(uniqueArguments));
            }

            if (name.equals(EQUAL) || name.equals(NOT_EQUAL) || name.equals(IS_DISTINCT_FROM)) {
                // Operand order of these operators does not matter, so put the operands in a fixed order.
                return new CallExpression(canonicalCall.getSignature(), canonicalCall.getType(), ROW_EXPRESSION_ORDERING.sortedCopy(canonicalCall.getArguments()));
            }

            if (name.equals(GREATER_THAN) || name.equals(GREATER_THAN_OR_EQUAL)) {
                // a > b  ==>  b < a ;  a >= b  ==>  b <= a
                String flippedName = name.equals(GREATER_THAN) ? LESS_THAN : LESS_THAN_OR_EQUAL;
                Signature signature = canonicalCall.getSignature();
                Signature flippedSignature = new Signature(
                        flippedName,
                        FunctionKind.SCALAR,
                        signature.getTypeVariableConstraints(),
                        signature.getLongVariableConstraints(),
                        signature.getReturnType(),
                        swapPair(signature.getArgumentTypes()),
                        false);
                return new CallExpression(flippedSignature, canonicalCall.getType(), swapPair(canonicalCall.getArguments()));
            }

            return canonicalCall;
        }

        /**
         * Returns the arguments of {@code call}. Any argument that is itself a call to a function with
         * the same name is replaced, recursively, by that call's own arguments.
         */
        public static List<RowExpression> flattenNestedCallArgs(CallExpression call) {
            String name = call.getSignature().getName();
            ImmutableList.Builder<RowExpression> flattened = ImmutableList.builder();
            for (RowExpression argument : call.getArguments()) {
                if (argument instanceof CallExpression && name.equals(((CallExpression) argument).getSignature().getName())) {
                    flattened.addAll(flattenNestedCallArgs((CallExpression) argument));
                }
                else {
                    flattened.add(argument);
                }
            }
            return flattened.build();
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, Void context) {
            return literal;
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Void context) {
            return reference;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) {
            return new LambdaDefinitionExpression(lambda.getArgumentTypes(), lambda.getArguments(), lambda.getBody().accept(this, context));
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context) {
            return reference;
        }
    }

    /**
     * Orders lists lexicographically using {@code elementComparator}. When one list is a prefix of
     * the other, the shorter list comes first.
     */
    public static class ListComparator<T> implements Comparator<List<T>> {
        private final Comparator<? super T> elementComparator;

        public ListComparator(Comparator<? super T> elementComparator) {
            this.elementComparator = Objects.requireNonNull(elementComparator, "elementComparator is null");
        }

        @Override
        public int compare(List<T> left, List<T> right) {
            int commonLength = Math.min(left.size(), right.size());
            for (int i = 0; i < commonLength; i++) {
                int result = elementComparator.compare(left.get(i), right.get(i));
                if (result != 0) {
                    return result;
                }
            }
            return Integer.compare(left.size(), right.size());
        }
    }

    /**
     * Total order over {@link RowExpression}s, used only to put operands into a canonical order.
     * Expressions of different classes are ordered by an arbitrary but VM-stable class order. Within
     * a class they are ordered structurally.
     */
    private static class RowExpressionComparator implements Comparator<RowExpression> {
        // Consistent for the life of the VM; compare(a, b) == 0 only when a == b.
        private final Comparator<Object> arbitraryOrdering = Ordering.arbitrary();
        private final ListComparator<RowExpression> argumentComparator = new ListComparator<RowExpression>(this);
        private final ListComparator<Type> typeListComparator =
                new ListComparator<Type>(Comparator.comparing((Type type) -> type.getTypeSignature().toString()));
        private final ListComparator<String> nameListComparator = new ListComparator<String>(Comparator.<String>naturalOrder());

        private RowExpressionComparator() {
        }

        @Override
        public int compare(RowExpression left, RowExpression right) {
            int result = arbitraryOrdering.compare(left.getClass(), right.getClass());
            if (result != 0) {
                return result;
            }
            // From here on both sides are of the same class, so the casts below are safe.

            if (left instanceof CallExpression) {
                CallExpression leftCall = (CallExpression) left;
                CallExpression rightCall = (CallExpression) right;
                return ComparisonChain.start()
                        .compare(leftCall.getSignature().toString(), rightCall.getSignature().toString())
                        .compare(leftCall.getArguments(), rightCall.getArguments(), argumentComparator)
                        .result();
            }

            if (left instanceof ConstantExpression) {
                return compareConstants((ConstantExpression) left, (ConstantExpression) right);
            }

            if (left instanceof InputReferenceExpression) {
                return Integer.compare(((InputReferenceExpression) left).getField(), ((InputReferenceExpression) right).getField());
            }

            if (left instanceof LambdaDefinitionExpression) {
                // Lambdas are produced by CanonicalizationVisitor.visitLambda and may appear as call arguments.
                LambdaDefinitionExpression leftLambda = (LambdaDefinitionExpression) left;
                LambdaDefinitionExpression rightLambda = (LambdaDefinitionExpression) right;
                return ComparisonChain.start()
                        .compare(leftLambda.getArgumentTypes(), rightLambda.getArgumentTypes(), typeListComparator)
                        .compare(leftLambda.getArguments(), rightLambda.getArguments(), nameListComparator)
                        .compare(leftLambda.getBody(), rightLambda.getBody(), this)
                        .result();
            }

            if (left instanceof VariableReferenceExpression) {
                VariableReferenceExpression leftVariable = (VariableReferenceExpression) left;
                VariableReferenceExpression rightVariable = (VariableReferenceExpression) right;
                return ComparisonChain.start()
                        .compare(leftVariable.getName(), rightVariable.getName())
                        .compare(leftVariable.getType().getTypeSignature().toString(), rightVariable.getType().getTypeSignature().toString())
                        .result();
            }

            throw new IllegalArgumentException("Unsupported RowExpression type " + left.getClass().getSimpleName());
        }

        /**
         * Orders constants first by type signature and then by value. A null (SQL NULL) value sorts
         * before any non-null value.
         */
        private int compareConstants(ConstantExpression left, ConstantExpression right) {
            int result = left.getType().getTypeSignature().toString().compareTo(right.getType().getTypeSignature().toString());
            if (result != 0) {
                return result;
            }

            Object leftValue = left.getValue();
            Object rightValue = right.getValue();
            if (leftValue == null || rightValue == null) {
                if (leftValue == rightValue) {
                    return 0;
                }
                return leftValue == null ? -1 : 1;
            }

            Class<?> javaType = left.getType().getJavaType();
            if (javaType == boolean.class) {
                return ((Boolean) leftValue).compareTo((Boolean) rightValue);
            }
            if (javaType == byte.class || javaType == short.class || javaType == int.class || javaType == long.class) {
                return Long.compare(((Number) leftValue).longValue(), ((Number) rightValue).longValue());
            }
            if (javaType == float.class || javaType == double.class) {
                return Double.compare(((Number) leftValue).doubleValue(), ((Number) rightValue).doubleValue());
            }
            if (javaType == Slice.class) {
                return ((Slice) leftValue).compareTo((Slice) rightValue);
            }

            // The value representation is not specifically supported. Use an arbitrary but consistent
            // identity order rather than a constant, which would violate the Comparator contract.
            return arbitraryOrdering.compare(leftValue, rightValue);
        }
    }
}
