package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.Access;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.expression.BytecodeExpressions;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.common.primitives.Primitives;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/facebook/presto/sql/gen/CommonSubExpressionRewriter.class */
public class CommonSubExpressionRewriter {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/facebook/presto/sql/gen/CommonSubExpressionRewriter$CommonSubExpressionCollector.class */
    public static class CommonSubExpressionCollector implements RowExpressionVisitor<Integer, Void> {
        private final Map<Integer, Set<RowExpression>> expressionsByLevel = new HashMap();
        private final Map<Integer, Set<RowExpression>> cseByLevel = new HashMap();
        private final Map<RowExpression, Integer> expressionCount = new HashMap();

        CommonSubExpressionCollector() {
        }

        private int addAtLevel(int i, RowExpression rowExpression) {
            Set<RowExpression> expressionsAtLevel = getExpressionsAtLevel(i, this.expressionsByLevel);
            this.expressionCount.putIfAbsent(rowExpression, 1);
            if (expressionsAtLevel.contains(rowExpression)) {
                getExpressionsAtLevel(i, this.cseByLevel).add(rowExpression);
                this.expressionCount.put(rowExpression, Integer.valueOf(this.expressionCount.get(rowExpression).intValue() + 1));
            }
            expressionsAtLevel.add(rowExpression);
            return i;
        }

        private static Set<RowExpression> getExpressionsAtLevel(int i, Map<Integer, Set<RowExpression>> map) {
            map.putIfAbsent(Integer.valueOf(i), new HashSet());
            return map.get(Integer.valueOf(i));
        }

        public Integer visitCall(CallExpression callExpression, Void r7) {
            if (callExpression.getArguments().isEmpty()) {
                return 0;
            }
            return Integer.valueOf(addAtLevel(((Integer) callExpression.getArguments().stream().map(rowExpression -> {
                return (Integer) rowExpression.accept(this, r7);
            }).reduce((v0, v1) -> {
                return Math.max(v0, v1);
            }).get()).intValue() + 1, callExpression));
        }

        public Integer visitInputReference(InputReferenceExpression inputReferenceExpression, Void r4) {
            return 0;
        }

        public Integer visitConstant(ConstantExpression constantExpression, Void r4) {
            return 0;
        }

        public Integer visitLambda(LambdaDefinitionExpression lambdaDefinitionExpression, Void r4) {
            return 0;
        }

        public Integer visitVariableReference(VariableReferenceExpression variableReferenceExpression, Void r4) {
            return 0;
        }

        public Integer visitSpecialForm(SpecialFormExpression specialFormExpression, Void r6) {
            int intValue = ((Integer) specialFormExpression.getArguments().stream().map(rowExpression -> {
                return (Integer) rowExpression.accept(this, (Object) null);
            }).reduce((v0, v1) -> {
                return Math.max(v0, v1);
            }).get()).intValue() + 1;
            if (specialFormExpression.getForm() != SpecialFormExpression.Form.WHEN && specialFormExpression.getForm() != SpecialFormExpression.Form.BIND) {
                addAtLevel(intValue, specialFormExpression);
            }
            return Integer.valueOf(intValue);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/facebook/presto/sql/gen/CommonSubExpressionRewriter$CommonSubExpressionFields.class */
    /**
     * Bundles what is declared on a generated class for one common sub-expression:
     * an "evaluated" flag field, a result field, the boxed result type and the
     * name of the accessor method.
     */
    public static class CommonSubExpressionFields {
        private final FieldDefinition evaluatedField;
        private final FieldDefinition resultField;
        private final Class<?> resultType;
        private final String methodName;

        /**
         * @param evaluatedField field holding the "already evaluated" flag
         * @param resultField field holding the cached result
         * @param resultType type of {@code resultField}
         * @param methodName name of the method associated with this sub-expression
         */
        public CommonSubExpressionFields(FieldDefinition evaluatedField, FieldDefinition resultField, Class<?> resultType, String methodName) {
            this.evaluatedField = evaluatedField;
            this.resultField = resultField;
            this.resultType = resultType;
            this.methodName = methodName;
        }

        public FieldDefinition getEvaluatedField() {
            return evaluatedField;
        }

        public FieldDefinition getResultField() {
            return resultField;
        }

        public String getMethodName() {
            return methodName;
        }

        public Class<?> getResultType() {
            return resultType;
        }

        /**
         * Declares, for every variable in {@code map}, a private boolean
         * {@code <name>Evaluated} field and a private {@code <name>Result} field of
         * the variable's boxed Java type on {@code classDefinition}.
         *
         * @param classDefinition class receiving the new fields
         * @param map common sub-expression variables grouped by level
         * @return the declared fields keyed by variable; the method name recorded for
         *         each entry is {@code "get" + <name>}
         */
        public static Map<VariableReferenceExpression, CommonSubExpressionFields> declareCommonSubExpressionFields(ClassDefinition classDefinition, Map<Integer, Map<RowExpression, VariableReferenceExpression>> map) {
            ImmutableMap.Builder<VariableReferenceExpression, CommonSubExpressionFields> fieldsByVariable = ImmutableMap.builder();
            for (Map<RowExpression, VariableReferenceExpression> variablesAtLevel : map.values()) {
                for (VariableReferenceExpression variable : variablesAtLevel.values()) {
                    String name = variable.getName();
                    // The result is stored boxed so that the field can hold null.
                    Class<?> boxedType = Primitives.wrap(variable.getType().getJavaType());
                    FieldDefinition evaluated = classDefinition.declareField(Access.a(Access.PRIVATE), name + "Evaluated", boolean.class);
                    FieldDefinition result = classDefinition.declareField(Access.a(Access.PRIVATE), name + "Result", boxedType);
                    fieldsByVariable.put(variable, new CommonSubExpressionFields(evaluated, result, boxedType, "get" + name));
                }
            }
            return fieldsByVariable.build();
        }

        /**
         * Appends bytecode to {@code bytecodeBlock} that resets every given
         * sub-expression on {@code variable}: the evaluated flag is set to
         * {@code false} and the result field to {@code null}.
         *
         * @param collection the sub-expression fields to reset
         * @param variable the object whose fields are assigned
         * @param bytecodeBlock block receiving the generated assignments
         */
        public static void initializeCommonSubExpressionFields(Collection<CommonSubExpressionFields> collection, Variable variable, BytecodeBlock bytecodeBlock) {
            for (CommonSubExpressionFields fields : collection) {
                bytecodeBlock.append(variable.setField(fields.getEvaluatedField(), BytecodeExpressions.constantBoolean(false)));
                bytecodeBlock.append(variable.setField(fields.getResultField(), BytecodeExpressions.constantNull(fields.getResultType())));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/facebook/presto/sql/gen/CommonSubExpressionRewriter$ExpressionRewriter.class */
    public static class ExpressionRewriter implements RowExpressionVisitor<RowExpression, Void> {
        private final Map<RowExpression, VariableReferenceExpression> expressionMap;

        public ExpressionRewriter(Map<RowExpression, VariableReferenceExpression> map) {
            this.expressionMap = ImmutableMap.copyOf(map);
        }

        public RowExpression visitCall(CallExpression callExpression, Void r11) {
            CallExpression callExpression2 = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), (List) callExpression.getArguments().stream().map(rowExpression -> {
                return (RowExpression) rowExpression.accept(this, (Object) null);
            }).collect(ImmutableList.toImmutableList()));
            return this.expressionMap.containsKey(callExpression2) ? this.expressionMap.get(callExpression2) : callExpression2;
        }

        public RowExpression visitInputReference(InputReferenceExpression inputReferenceExpression, Void r4) {
            return inputReferenceExpression;
        }

        public RowExpression visitConstant(ConstantExpression constantExpression, Void r4) {
            return constantExpression;
        }

        public RowExpression visitLambda(LambdaDefinitionExpression lambdaDefinitionExpression, Void r4) {
            return lambdaDefinitionExpression;
        }

        public RowExpression visitVariableReference(VariableReferenceExpression variableReferenceExpression, Void r4) {
            return variableReferenceExpression;
        }

        public RowExpression visitSpecialForm(SpecialFormExpression specialFormExpression, Void r9) {
            SpecialFormExpression specialFormExpression2 = new SpecialFormExpression(specialFormExpression.getForm(), specialFormExpression.getType(), (List) specialFormExpression.getArguments().stream().map(rowExpression -> {
                return (RowExpression) rowExpression.accept(this, (Object) null);
            }).collect(ImmutableList.toImmutableList()));
            return this.expressionMap.containsKey(specialFormExpression2) ? this.expressionMap.get(specialFormExpression2) : specialFormExpression2;
        }
    }

    /* loaded from: input_file:com/facebook/presto/sql/gen/CommonSubExpressionRewriter$SubExpressionChecker.class */
    static class SubExpressionChecker implements RowExpressionVisitor<Boolean, Void> {
        private final Set<RowExpression> subExpressions;

        SubExpressionChecker(Set<RowExpression> set) {
            this.subExpressions = set;
        }

        public Boolean visitCall(CallExpression callExpression, Void r5) {
            if (this.subExpressions.contains(callExpression)) {
                return true;
            }
            if (callExpression.getArguments().isEmpty()) {
                return false;
            }
            return Boolean.valueOf(callExpression.getArguments().stream().anyMatch(rowExpression -> {
                return ((Boolean) rowExpression.accept(this, (Object) null)).booleanValue();
            }));
        }

        public Boolean visitInputReference(InputReferenceExpression inputReferenceExpression, Void r5) {
            return Boolean.valueOf(this.subExpressions.contains(inputReferenceExpression));
        }

        public Boolean visitConstant(ConstantExpression constantExpression, Void r5) {
            return Boolean.valueOf(this.subExpressions.contains(constantExpression));
        }

        public Boolean visitLambda(LambdaDefinitionExpression lambdaDefinitionExpression, Void r4) {
            return false;
        }

        public Boolean visitVariableReference(VariableReferenceExpression variableReferenceExpression, Void r5) {
            return Boolean.valueOf(this.subExpressions.contains(variableReferenceExpression));
        }

        public Boolean visitSpecialForm(SpecialFormExpression specialFormExpression, Void r5) {
            if (this.subExpressions.contains(specialFormExpression)) {
                return true;
            }
            if (specialFormExpression.getArguments().isEmpty()) {
                return false;
            }
            return Boolean.valueOf(specialFormExpression.getArguments().stream().anyMatch(rowExpression -> {
                return ((Boolean) rowExpression.accept(this, (Object) null)).booleanValue();
            }));
        }
    }

    /** Not instantiable: this class only exposes static members. */
    private CommonSubExpressionRewriter() {}

    /**
     * Finds the sub-expressions that occur more than once across {@code list} and
     * assigns each a fresh variable (name hint {@code "cse"}).
     *
     * <p>Levels are processed from lowest to highest. Before an expression is
     * registered it is rewritten with the variables already assigned at lower
     * levels, so the keys at higher levels refer to lower-level common
     * sub-expressions by variable.
     *
     * @param list the expressions to analyse
     * @return for each level that has common sub-expressions, a map from the
     *         (rewritten) expression to its variable; empty when there are none
     */
    public static Map<Integer, Map<RowExpression, VariableReferenceExpression>> collectCSEByLevel(List<? extends RowExpression> list) {
        if (list.isEmpty()) {
            return ImmutableMap.of();
        }
        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
        // Every expression must actually be visited; otherwise the collector stays
        // empty and no common sub-expression is ever reported.
        for (RowExpression expression : list) {
            expression.accept(collector, null);
        }
        if (collector.cseByLevel.isEmpty()) {
            return ImmutableMap.of();
        }

        Map<Integer, Map<RowExpression, Integer>> cseByLevel = removeRedundantCSE(collector.cseByLevel, collector.expressionCount);
        PlanVariableAllocator variableAllocator = new PlanVariableAllocator();
        ImmutableMap.Builder<Integer, Map<RowExpression, VariableReferenceExpression>> result = ImmutableMap.builder();
        Map<RowExpression, VariableReferenceExpression> rewriteWith = new HashMap<>();

        // Ascending level order; an empty level map simply yields an empty result.
        List<Integer> levels = cseByLevel.keySet().stream().sorted().collect(Collectors.toList());
        for (Integer level : levels) {
            // The rewriter snapshots the variables assigned at all lower levels.
            ExpressionRewriter rewriter = new ExpressionRewriter(rewriteWith);
            ImmutableMap.Builder<RowExpression, VariableReferenceExpression> variablesAtLevel = ImmutableMap.builder();
            for (RowExpression expression : cseByLevel.get(level).keySet()) {
                RowExpression rewritten = expression.accept(rewriter, null);
                variablesAtLevel.put(rewritten, variableAllocator.newVariable(rewritten, "cse"));
            }
            Map<RowExpression, VariableReferenceExpression> assigned = variablesAtLevel.build();
            result.put(level, assigned);
            rewriteWith.putAll(assigned);
        }
        return result.build();
    }

    /**
     * Single-expression convenience overload: wraps {@code rowExpression} in a
     * one-element list and delegates to the list-based overload.
     *
     * @param rowExpression the expression to analyse
     * @return whatever the list-based overload returns for that one-element list
     */
    public static Map<Integer, Map<RowExpression, VariableReferenceExpression>> collectCSEByLevel(RowExpression rowExpression) {
        // The List-typed local selects the list overload without needing a cast.
        List<RowExpression> singleExpression = ImmutableList.of(rowExpression);
        return collectCSEByLevel(singleExpression);
    }

    /**
     * Partitions {@code collection} into groups of expressions that share common
     * sub-expressions.
     *
     * <p>Expressions containing no common sub-expression each form a singleton
     * group mapped to {@code false}. The remaining expressions are merged greedily:
     * a group starts from the first unmerged expression and absorbs any unmerged
     * expression that shares at least one common sub-expression with the group so
     * far, until the group holds {@code i} expressions or no candidate is left.
     * Those groups are mapped to {@code true}.
     *
     * @param collection the expressions to partition
     * @param i maximum number of expressions merged into one group; the seed
     *          expression is always added, so a group is never empty
     * @return groups in insertion order (CSE-free singletons first), each flagged
     *         with whether it contains common sub-expressions
     * @throws IllegalArgumentException if two groups turn out equal (raised by
     *         {@code ImmutableMap.Builder.build()} on duplicate keys)
     */
    public static Map<List<RowExpression>, Boolean> getExpressionsPartitionedByCSE(Collection<? extends RowExpression> collection, int i) {
        if (collection.isEmpty()) {
            return ImmutableMap.of();
        }

        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
        // Every expression must actually be visited; otherwise no common
        // sub-expression is ever found and everything is reported as CSE-free.
        for (RowExpression expression : collection) {
            expression.accept(collector, null);
        }
        Set<RowExpression> commonSubExpressions = collector.cseByLevel.values().stream()
                .flatMap(Set::stream)
                .collect(ImmutableSet.toImmutableSet());

        ImmutableMap.Builder<List<RowExpression>, Boolean> partitions = ImmutableMap.builder();
        if (commonSubExpressions.isEmpty()) {
            for (RowExpression expression : collection) {
                partitions.put(ImmutableList.<RowExpression>of(expression), false);
            }
            return partitions.build();
        }

        // Split by "contains a common sub-expression"; CSE-free expressions are
        // emitted immediately as singleton groups.
        SubExpressionChecker checker = new SubExpressionChecker(commonSubExpressions);
        List<RowExpression> expressionsWithCse = new ArrayList<>();
        for (RowExpression expression : collection) {
            if (expression.accept(checker, null)) {
                expressionsWithCse.add(expression);
            }
            else {
                partitions.put(ImmutableList.<RowExpression>of(expression), false);
            }
        }

        if (expressionsWithCse.size() == 1) {
            partitions.put(ImmutableList.of(expressionsWithCse.get(0)), true);
            return partitions.build();
        }

        // For each expression, the common sub-expressions it contains.
        List<Set<RowExpression>> cseDependencies = new ArrayList<>(expressionsWithCse.size());
        for (RowExpression expression : expressionsWithCse) {
            Set<RowExpression> usedCse = Expressions.subExpressions(expression).stream()
                    .filter(commonSubExpressions::contains)
                    .collect(ImmutableSet.toImmutableSet());
            cseDependencies.add(usedCse);
        }

        boolean[] merged = new boolean[expressionsWithCse.size()];
        int seed = 0;
        while (seed < merged.length) {
            // Advance to the next expression that is not yet part of a group.
            while (seed < merged.length && merged[seed]) {
                seed++;
            }
            if (seed >= merged.length) {
                break;
            }
            merged[seed] = true;
            List<RowExpression> group = new ArrayList<>();
            group.add(expressionsWithCse.get(seed));
            Set<RowExpression> groupDependencies = new HashSet<>(cseDependencies.get(seed));

            int candidate = seed + 1;
            while (candidate < merged.length && group.size() < i) {
                while (candidate < merged.length && merged[candidate]) {
                    candidate++;
                }
                if (candidate >= merged.length) {
                    break;
                }
                Set<RowExpression> candidateDependencies = cseDependencies.get(candidate);
                if (Sets.intersection(groupDependencies, candidateDependencies).isEmpty()) {
                    candidate++;
                }
                else {
                    group.add(expressionsWithCse.get(candidate));
                    groupDependencies.addAll(candidateDependencies);
                    merged[candidate] = true;
                    // The group's dependency set grew, so expressions skipped earlier
                    // may overlap now: rescan from just after the seed.
                    candidate = seed + 1;
                }
            }
            partitions.put(ImmutableList.copyOf(group), true);
        }
        return partitions.build();
    }

    /**
     * Rewrites {@code rowExpression}, replacing every (rebuilt) call or special
     * form that is a key of {@code map} with the mapped variable reference.
     *
     * @param rowExpression the expression to rewrite
     * @param map common sub-expressions and the variables that stand in for them
     * @return the rewritten expression
     */
    public static RowExpression rewriteExpressionWithCSE(RowExpression rowExpression, Map<RowExpression, VariableReferenceExpression> map) {
        // The rewriter's context type is Void, so the only valid context is a plain null.
        return rowExpression.accept(new ExpressionRewriter(map), null);
    }

    /**
     * Filters the collected common sub-expressions, dropping those whose
     * occurrences are all accounted for by an enclosing common sub-expression.
     *
     * <p>Levels are walked from highest to lowest. Each surviving expression
     * subtracts its own count from the count of every tracked strict
     * sub-expression; an expression survives only if its remaining count is still
     * positive when its level is reached.
     *
     * <p>Side effect: {@code expressionCount} is modified in place.
     *
     * @param cseByLevel common sub-expression candidates grouped by level
     * @param expressionCount occurrence count per tracked expression; must contain
     *        every expression of {@code cseByLevel}
     * @return surviving expressions with their remaining count, grouped by level;
     *         levels with no survivor are omitted
     */
    private static Map<Integer, Map<RowExpression, Integer>> removeRedundantCSE(Map<Integer, Set<RowExpression>> cseByLevel, Map<RowExpression, Integer> expressionCount) {
        Map<Integer, Map<RowExpression, Integer>> results = new HashMap<>();
        if (cseByLevel.isEmpty()) {
            // No levels at all: avoid calling get() on an empty reduction.
            return results;
        }
        int highestLevel = cseByLevel.keySet().stream().reduce(Math::max).get();
        int lowestLevel = cseByLevel.keySet().stream().reduce(Math::min).get();

        for (int level = highestLevel; level > lowestLevel; level--) {
            if (!cseByLevel.containsKey(level)) {
                continue;
            }
            Map<RowExpression, Integer> survivors = cseByLevel.get(level).stream()
                    .filter(expression -> expressionCount.get(expression) > 0)
                    .collect(ImmutableMap.toImmutableMap(Function.identity(), expressionCount::get));
            if (!survivors.isEmpty()) {
                results.put(level, survivors);
            }
            for (RowExpression survivor : survivors.keySet()) {
                int occurrences = expressionCount.get(survivor);
                // Occurrences of a sub-expression inside this survivor are covered by
                // the survivor itself, so discount them.
                Expressions.subExpressions(survivor).stream()
                        .filter(subExpression -> !subExpression.equals(survivor))
                        .forEach(subExpression -> {
                            if (expressionCount.containsKey(subExpression)) {
                                expressionCount.put(subExpression, expressionCount.get(subExpression) - occurrences);
                            }
                        });
            }
        }

        // The lowest level has nothing below it to discount.
        // NOTE(review): the recorded count here is remaining count + 1, unlike the
        // higher levels - kept as found, confirm the intent.
        Map<RowExpression, Integer> lowestSurvivors = cseByLevel.get(lowestLevel).stream()
                .filter(expression -> expressionCount.get(expression) > 0)
                .collect(ImmutableMap.toImmutableMap(Function.identity(), expression -> expressionCount.get(expression) + 1));
        if (!lowestSurvivors.isEmpty()) {
            results.put(lowestLevel, lowestSurvivors);
        }
        return results;
    }
}
