package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.Access;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.CompilerUtils;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.bytecode.expression.BytecodeExpressions;
import com.facebook.presto.bytecode.instruction.JumpInstruction;
import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.operator.JoinProbe;
import com.facebook.presto.operator.JoinProbeFactory;
import com.facebook.presto.operator.LookupJoinOperator;
import com.facebook.presto.operator.LookupJoinOperatorFactory;
import com.facebook.presto.operator.LookupJoinOperators;
import com.facebook.presto.operator.LookupSource;
import com.facebook.presto.operator.LookupSourceFactory;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.SimpleJoinProbe;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spiller.PartitioningSpillerFactory;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.concurrent.ExecutionException;
import java.util.stream.Stream;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/* loaded from: input_file:com/facebook/presto/sql/gen/JoinProbeCompiler.class */
public class JoinProbeCompiler {
    /**
     * Cache of compiled {@link HashJoinOperatorFactoryFactory} instances, keyed by
     * {@link JoinOperatorCacheKey}. The cache keeps at most 1000 entries and records
     * statistics. Missing entries are produced by
     * {@link #internalCompileJoinOperatorFactory}; note that the key's join type takes
     * part in key equality but is not passed to the loader.
     */
    private final LoadingCache<JoinOperatorCacheKey, HashJoinOperatorFactoryFactory> joinProbeFactories = CacheBuilder.newBuilder()
            .recordStats()
            .maximumSize(1000)
            .build(CacheLoader.from(key -> internalCompileJoinOperatorFactory(
                    key.getTypes(),
                    key.getProbeOutputChannels(),
                    key.getProbeChannels(),
                    key.getProbeHashChannel())));

    /**
     * Binds a {@link JoinProbeFactory} to the reflectively resolved constructor of an
     * {@link OperatorFactory} class, so that operator factories can be instantiated
     * with that probe factory already supplied.
     */
    public static class HashJoinOperatorFactoryFactory {
        private final JoinProbeFactory joinProbeFactory;
        private final Constructor<? extends OperatorFactory> constructor;

        /**
         * Looks up the eleven-argument public constructor on {@code cls}.
         *
         * @param joinProbeFactory probe factory passed to every created operator factory
         * @param cls operator factory class that must expose the expected constructor
         * @throws RuntimeException wrapping {@link NoSuchMethodException} when the
         *         constructor is absent
         */
        private HashJoinOperatorFactoryFactory(JoinProbeFactory joinProbeFactory, Class<? extends OperatorFactory> cls) {
            this.joinProbeFactory = joinProbeFactory;
            // Parameter types of the constructor expected on the operator factory class.
            Class<?>[] signature = {
                    int.class,
                    PlanNodeId.class,
                    LookupSourceFactory.class,
                    List.class,
                    List.class,
                    LookupJoinOperators.JoinType.class,
                    JoinProbeFactory.class,
                    OptionalInt.class,
                    List.class,
                    OptionalInt.class,
                    PartitioningSpillerFactory.class
            };
            try {
                this.constructor = cls.getConstructor(signature);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
        }

        /**
         * Instantiates the operator factory through the cached constructor. The
         * arguments are forwarded in order, with the stored probe factory inserted
         * right after {@code joinType}.
         *
         * @return the newly constructed operator factory
         * @throws RuntimeException if reflective construction fails; checked
         *         failures are wrapped by {@code Throwables.propagate}
         */
        public OperatorFactory createHashJoinOperatorFactory(int i, PlanNodeId planNodeId, LookupSourceFactory lookupSourceFactory, List<? extends Type> list, List<? extends Type> list2, LookupJoinOperators.JoinType joinType, OptionalInt optionalInt, List<Integer> list3, OptionalInt optionalInt2, PartitioningSpillerFactory partitioningSpillerFactory) {
            try {
                return constructor.newInstance(
                        i,
                        planNodeId,
                        lookupSourceFactory,
                        list,
                        list2,
                        joinType,
                        joinProbeFactory,
                        optionalInt,
                        list3,
                        optionalInt2,
                        partitioningSpillerFactory);
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }

    /**
     * Immutable cache key describing a probe layout: the channel types, the probe
     * output channels, the probe join channels, the optional probe hash channel and
     * the join type. All five components take part in both {@link #equals} and
     * {@link #hashCode}.
     */
    private static final class JoinOperatorCacheKey {
        private final List<Type> types;
        private final List<Integer> probeOutputChannels;
        private final List<Integer> probeChannels;
        private final LookupJoinOperators.JoinType joinType;
        private final OptionalInt probeHashChannel;

        /**
         * Creates a key, taking immutable copies of the three lists.
         *
         * @param list channel types
         * @param list2 probe output channels
         * @param list3 probe join channels
         * @param optionalInt probe hash channel, if any
         * @param joinType join type
         */
        private JoinOperatorCacheKey(List<? extends Type> list, List<Integer> list2, List<Integer> list3, OptionalInt optionalInt, LookupJoinOperators.JoinType joinType) {
            this.probeHashChannel = optionalInt;
            this.types = ImmutableList.copyOf(list);
            this.probeOutputChannels = ImmutableList.copyOf(list2);
            this.probeChannels = ImmutableList.copyOf(list3);
            this.joinType = joinType;
        }

        /** Returns the immutable list of channel types. */
        public List<Type> getTypes() {
            return this.types;
        }

        /** Returns the immutable list of probe output channels. */
        public List<Integer> getProbeOutputChannels() {
            return this.probeOutputChannels;
        }

        /** Returns the immutable list of probe join channels. */
        public List<Integer> getProbeChannels() {
            return this.probeChannels;
        }

        /** Returns the probe hash channel, which may be empty. */
        public OptionalInt getProbeHashChannel() {
            return this.probeHashChannel;
        }

        @Override
        public int hashCode() {
            // probeHashChannel is compared in equals(), so it is hashed as well;
            // otherwise keys differing only in the hash channel would always collide.
            return Objects.hash(this.types, this.probeOutputChannels, this.probeChannels, this.probeHashChannel, this.joinType);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof JoinOperatorCacheKey)) {
                return false;
            }
            JoinOperatorCacheKey other = (JoinOperatorCacheKey) obj;
            return Objects.equals(this.types, other.types)
                    && Objects.equals(this.probeOutputChannels, other.probeOutputChannels)
                    && Objects.equals(this.probeChannels, other.probeChannels)
                    && Objects.equals(this.probeHashChannel, other.probeHashChannel)
                    && Objects.equals(this.joinType, other.joinType);
        }
    }

    /**
     * {@link JoinProbeFactory} that creates probes by reflectively invoking the
     * public single-{@link Page} constructor of a given {@link JoinProbe} class.
     */
    public static class ReflectionJoinProbeFactory implements JoinProbeFactory {
        private final Constructor<? extends JoinProbe> constructor;

        /**
         * Resolves the {@code (Page)} constructor of {@code cls}.
         *
         * @param cls probe class to instantiate later
         * @throws RuntimeException wrapping {@link NoSuchMethodException} when the
         *         class has no public {@code (Page)} constructor
         */
        public ReflectionJoinProbeFactory(Class<? extends JoinProbe> cls) {
            Constructor<? extends JoinProbe> pageConstructor;
            try {
                pageConstructor = cls.getConstructor(Page.class);
            }
            catch (NoSuchMethodException e) {
                throw Throwables.propagate(e);
            }
            this.constructor = pageConstructor;
        }

        /**
         * Creates a probe for {@code page} through the resolved constructor.
         *
         * @throws RuntimeException if reflective construction fails; checked
         *         failures are wrapped by {@code Throwables.propagate}
         */
        @Override // com.facebook.presto.operator.JoinProbeFactory
        public JoinProbe createJoinProbe(Page page) {
            try {
                JoinProbe probe = constructor.newInstance(page);
                return probe;
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }

    /**
     * Exposes the join probe factory cache as a nested managed bean.
     *
     * @return a new {@code CacheStatsMBean} wrapping the cache on every call
     */
    @Managed
    @Nested
    public CacheStatsMBean getJoinProbeFactoriesStats() {
        CacheStatsMBean stats = new CacheStatsMBean(joinProbeFactories);
        return stats;
    }

    /**
     * Returns an operator factory for a lookup join, using a cached (or freshly
     * compiled) {@link HashJoinOperatorFactoryFactory} for the given probe layout.
     *
     * @param i operator id forwarded to the operator factory
     * @param planNodeId plan node id forwarded to the operator factory
     * @param lookupSourceFactory lookup source factory forwarded to the operator factory
     * @param list channel types, indexed by channel
     * @param list2 probe join channels
     * @param optionalInt probe hash channel, if any
     * @param list3 probe output channels; each must be a valid index into {@code list}
     * @param joinType join type; part of the cache key and forwarded to the operator factory
     * @param optionalInt2 forwarded unchanged to the operator factory
     * @param partitioningSpillerFactory forwarded unchanged to the operator factory
     * @return the operator factory built by the cached factory-factory
     * @throws RuntimeException the cause of any cache loading failure, propagated
     */
    public OperatorFactory compileJoinOperatorFactory(int i, PlanNodeId planNodeId, LookupSourceFactory lookupSourceFactory, List<? extends Type> list, List<Integer> list2, OptionalInt optionalInt, List<Integer> list3, LookupJoinOperators.JoinType joinType, OptionalInt optionalInt2, PartitioningSpillerFactory partitioningSpillerFactory) {
        try {
            // Types of the probe output channels, in output order.
            List<Type> probeOutputTypes = list3.stream()
                    .<Type>map(list::get)
                    .collect(ImmutableList.toImmutableList());

            HashJoinOperatorFactoryFactory factoryFactory = this.joinProbeFactories.get(
                    new JoinOperatorCacheKey(list, list3, list2, optionalInt, joinType));

            return factoryFactory.createHashJoinOperatorFactory(
                    i,
                    planNodeId,
                    lookupSourceFactory,
                    list,
                    probeOutputTypes,
                    joinType,
                    optionalInt2,
                    list2,
                    optionalInt,
                    partitioningSpillerFactory);
        }
        catch (ExecutionException | UncheckedExecutionException | ExecutionError e) {
            // The cache wraps loader failures; surface the underlying cause instead.
            throw Throwables.propagate(e.getCause());
        }
    }

    /**
     * Compiles the join probe class for the given layout and wraps it, together with
     * an isolated copy of the lookup join operator factory class, into a
     * {@link HashJoinOperatorFactoryFactory}.
     *
     * <p>A generated {@link JoinProbeFactory} class is defined and instantiated only
     * when {@code list3} is non-empty; otherwise a
     * {@code SimpleJoinProbe.SimpleJoinProbeFactory} is used.
     *
     * @param list channel types
     * @param list2 probe output channels
     * @param list3 probe join channels
     * @param optionalInt probe hash channel, if any
     */
    @VisibleForTesting
    public HashJoinOperatorFactoryFactory internalCompileJoinOperatorFactory(List<Type> list, List<Integer> list2, List<Integer> list3, OptionalInt optionalInt) {
        Class<? extends JoinProbe> joinProbeClass = compileJoinProbe(list, list2, list3, optionalInt);

        // Generated class: public final <name> implements JoinProbeFactory.
        ClassDefinition factoryDefinition = new ClassDefinition(
                Access.a(new Access[]{Access.PUBLIC, Access.FINAL}),
                CompilerUtils.makeClassName("JoinProbeFactory"),
                ParameterizedType.type(Object.class),
                new ParameterizedType[]{ParameterizedType.type(JoinProbeFactory.class)});
        factoryDefinition.declareDefaultConstructor(Access.a(new Access[]{Access.PUBLIC}));

        // createJoinProbe(page) { return new <joinProbeClass>(page); }
        Parameter page = Parameter.arg("page", Page.class);
        MethodDefinition createJoinProbe = factoryDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "createJoinProbe",
                ParameterizedType.type(JoinProbe.class),
                new Parameter[]{page});
        createJoinProbe.getBody()
                .newObject(joinProbeClass)
                .dup()
                .append(page)
                .invokeConstructor(joinProbeClass, new Class[]{Page.class})
                .retObject();

        DynamicClassLoader classLoader = new DynamicClassLoader(joinProbeClass.getClassLoader());

        JoinProbeFactory joinProbeFactory;
        if (!list3.isEmpty()) {
            try {
                joinProbeFactory = (JoinProbeFactory) CompilerUtils.defineClass(factoryDefinition, JoinProbeFactory.class, classLoader).newInstance();
            }
            catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
        else {
            // No probe join channels: use the non-generated probe factory.
            joinProbeFactory = new SimpleJoinProbe.SimpleJoinProbeFactory(list, list2, list3, optionalInt);
        }

        return new HashJoinOperatorFactoryFactory(
                joinProbeFactory,
                IsolatedClass.isolateClass(classLoader, OperatorFactory.class, LookupJoinOperatorFactory.class, LookupJoinOperator.class));
    }

    /**
     * Compiles a join probe class for the given layout and returns a factory that
     * instantiates it reflectively.
     *
     * @param list channel types
     * @param list2 probe output channels
     * @param list3 probe join channels
     * @param optionalInt probe hash channel, if any
     */
    @VisibleForTesting
    public JoinProbeFactory internalCompileJoinProbe(List<Type> list, List<Integer> list2, List<Integer> list3, OptionalInt optionalInt) {
        Class<? extends JoinProbe> joinProbeClass = compileJoinProbe(list, list2, list3, optionalInt);
        return new ReflectionJoinProbeFactory(joinProbeClass);
    }

    /**
     * Generates and defines a {@link JoinProbe} implementation specialised for the
     * given layout. The class gets one {@code block_N} field per channel type, one
     * {@code probeBlock_N} field per probe join channel, plus bookkeeping fields,
     * and its methods are emitted by the {@code generate*} helpers.
     *
     * @param list channel types
     * @param list2 probe output channels
     * @param list3 probe join channels
     * @param optionalInt probe hash channel, if any
     * @return the class defined in this compiler's class loader
     */
    private Class<? extends JoinProbe> compileJoinProbe(List<Type> list, List<Integer> list2, List<Integer> list3, OptionalInt optionalInt) {
        CallSiteBinder binder = new CallSiteBinder();

        // Generated class: public final <name> implements JoinProbe.
        ClassDefinition probeClass = new ClassDefinition(
                Access.a(new Access[]{Access.PUBLIC, Access.FINAL}),
                CompilerUtils.makeClassName("JoinProbe"),
                ParameterizedType.type(Object.class),
                new ParameterizedType[]{ParameterizedType.type(JoinProbe.class)});

        // Fields are declared in a fixed order: positionCount, block_*, probeBlock_*,
        // probeBlocks, probePage, page, position, probeHashBlock.
        FieldDefinition positionCountField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "positionCount", int.class);

        List<FieldDefinition> blockFields = new ArrayList<>();
        for (int channel = 0; channel < list.size(); channel++) {
            blockFields.add(probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "block_" + channel, Block.class));
        }

        List<FieldDefinition> probeBlockFields = new ArrayList<>();
        for (int index = 0; index < list3.size(); index++) {
            probeBlockFields.add(probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "probeBlock_" + index, Block.class));
        }

        FieldDefinition probeBlocksArrayField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "probeBlocks", Block[].class);
        FieldDefinition probePageField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "probePage", Page.class);
        FieldDefinition pageField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "page", Page.class);
        FieldDefinition positionField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE}), "position", int.class);
        FieldDefinition probeHashBlockField = probeClass.declareField(Access.a(new Access[]{Access.PRIVATE, Access.FINAL}), "probeHashBlock", Block.class);

        generateConstructor(probeClass, list3, optionalInt, blockFields, probeBlockFields, probeBlocksArrayField, probePageField, pageField, probeHashBlockField, positionField, positionCountField);
        generateGetChannelCountMethod(probeClass, list2.size());
        generateAppendToMethod(probeClass, binder, list, list2, blockFields, positionField);
        generateAdvanceNextPosition(probeClass, positionField, positionCountField);
        generateGetCurrentJoinPosition(probeClass, binder, probePageField, pageField, optionalInt, probeHashBlockField, positionField);
        generateCurrentRowContainsNull(probeClass, probeBlockFields, positionField);
        generateGetPosition(probeClass, positionField);
        generateGetPage(probeClass, pageField);

        return CompilerUtils.defineClass(probeClass, JoinProbe.class, binder.getBindings(), getClass().getClassLoader());
    }

    /**
     * Emits the public {@code (Page page)} constructor of the generated probe class.
     * In order, the constructor calls {@code super()}, stores the page's position
     * count, copies every page block into its {@code block_N} field, aliases the
     * probe join channels into the probe block fields, fills the probe block array,
     * stores the page, wraps the array in a new {@link Page}, optionally stores the
     * hash block, and finally sets the position to -1.
     *
     * @param classDefinition class receiving the constructor
     * @param list probe join channels (indexes into {@code list2})
     * @param optionalInt probe hash channel, if any (index into {@code list2})
     * @param list2 per-channel block fields
     * @param list3 per-probe-channel block fields
     * @param fieldDefinition the {@code Block[]} field holding the probe blocks
     * @param fieldDefinition2 the probe page field
     * @param fieldDefinition3 the page field
     * @param fieldDefinition4 the probe hash block field
     * @param fieldDefinition5 the position field
     * @param fieldDefinition6 the position count field
     */
    private static void generateConstructor(ClassDefinition classDefinition, List<Integer> list, OptionalInt optionalInt, List<FieldDefinition> list2, List<FieldDefinition> list3, FieldDefinition fieldDefinition, FieldDefinition fieldDefinition2, FieldDefinition fieldDefinition3, FieldDefinition fieldDefinition4, FieldDefinition fieldDefinition5, FieldDefinition fieldDefinition6) {
        // Readable aliases for the positional parameters.
        List<FieldDefinition> blockFields = list2;
        List<FieldDefinition> probeBlockFields = list3;
        FieldDefinition probeBlocksArrayField = fieldDefinition;
        FieldDefinition probePageField = fieldDefinition2;
        FieldDefinition pageField = fieldDefinition3;
        FieldDefinition probeHashBlockField = fieldDefinition4;
        FieldDefinition positionField = fieldDefinition5;
        FieldDefinition positionCountField = fieldDefinition6;

        Parameter page = Parameter.arg("page", Page.class);
        MethodDefinition constructor = classDefinition.declareConstructor(Access.a(new Access[]{Access.PUBLIC}), new Parameter[]{page});
        Variable self = constructor.getThis();

        BytecodeBlock body = constructor.getBody()
                .comment("super();")
                .append(self)
                .invokeConstructor(Object.class, new Class[0]);

        body.comment("this.positionCount = page.getPositionCount();")
                .append(self.setField(positionCountField, page.invoke("getPositionCount", int.class, new BytecodeExpression[0])));

        body.comment("Set block fields");
        for (int channel = 0; channel < blockFields.size(); channel++) {
            body.append(self.setField(
                    blockFields.get(channel),
                    page.invoke("getBlock", Block.class, new BytecodeExpression[]{BytecodeExpressions.constantInt(channel)})));
        }

        body.comment("Set probe channel fields");
        for (int index = 0; index < probeBlockFields.size(); index++) {
            // Each probe block field aliases the block field of its join channel.
            int sourceChannel = list.get(index);
            body.append(self.setField(probeBlockFields.get(index), self.getField(blockFields.get(sourceChannel))));
        }

        body.comment("this.probeBlocks = new Block[<probeChannelCount>];");
        body.append(self)
                .push(probeBlockFields.size())
                .newArray(Block.class)
                .putField(probeBlocksArrayField);
        for (int index = 0; index < probeBlockFields.size(); index++) {
            body.append(self)
                    .getField(probeBlocksArrayField)
                    .push(index)
                    .append(self)
                    .getField(probeBlockFields.get(index))
                    .putObjectArrayElement();
        }

        body.comment("this.page = page")
                .append(self.setField(pageField, page));
        body.comment("this.probePage = new Page(probeBlocks)")
                .append(self.setField(probePageField, BytecodeExpressions.newInstance(Page.class, new BytecodeExpression[]{self.getField(probeBlocksArrayField)})));

        if (optionalInt.isPresent()) {
            body.comment("this.probeHashBlock = blocks[hashChannel.get()]")
                    .append(self.setField(probeHashBlockField, self.getField(blockFields.get(optionalInt.getAsInt()))));
        }

        body.comment("this.position = -1;")
                .append(self.setField(positionField, BytecodeExpressions.constantInt(-1)));
        body.ret();
    }

    /**
     * Emits a public {@code int getOutputChannelCount()} method that returns the
     * constant {@code i}.
     *
     * @param classDefinition class receiving the method
     * @param i value the generated method returns
     */
    private static void generateGetChannelCountMethod(ClassDefinition classDefinition, int i) {
        MethodDefinition getOutputChannelCount = classDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "getOutputChannelCount",
                ParameterizedType.type(int.class),
                new Parameter[0]);
        getOutputChannelCount.getBody()
                .push(i)
                .retInt();
    }

    /**
     * Emits a public {@code void appendTo(PageBuilder pageBuilder)} method. For each
     * probe output channel, in order, the generated code invokes {@code appendTo} on
     * the channel's type constant with that channel's block field, the current
     * position, and the page builder's block builder at the running output index
     * (0, 1, 2, ...).
     *
     * @param classDefinition class receiving the method
     * @param callSiteBinder binder used to embed the type constants
     * @param list channel types, indexed by channel
     * @param list2 probe output channels
     * @param list3 per-channel block fields, indexed by channel
     * @param fieldDefinition the position field
     */
    private static void generateAppendToMethod(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, List<Type> list, List<Integer> list2, List<FieldDefinition> list3, FieldDefinition fieldDefinition) {
        Parameter pageBuilder = Parameter.arg("pageBuilder", PageBuilder.class);
        MethodDefinition appendTo = classDefinition.declareMethod(Access.a(new Access[]{Access.PUBLIC}), "appendTo", ParameterizedType.type(Void.TYPE), new Parameter[]{pageBuilder});
        Variable self = appendTo.getThis();

        int builderIndex = 0;
        for (int outputChannel : list2) {
            Type type = list.get(outputChannel);
            // The descriptive comment and the emitted constant share the same
            // builder index, so the comment matches the generated call.
            appendTo.getBody()
                    .comment("%s.appendTo(block_%s, position, pageBuilder.getBlockBuilder(%s));", new Object[]{type.getClass(), outputChannel, builderIndex})
                    .append(SqlTypeBytecodeExpression.constantType(callSiteBinder, type).invoke("appendTo", Void.TYPE, new BytecodeExpression[]{
                            self.getField(list3.get(outputChannel)),
                            self.getField(fieldDefinition),
                            pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, new BytecodeExpression[]{BytecodeExpressions.constantInt(builderIndex)})}));
            builderIndex++;
        }
        appendTo.getBody().ret();
    }

    /**
     * Emits a public {@code boolean advanceNextPosition()} method that increments
     * the position field by one and returns whether the new position is less than
     * the position count.
     *
     * @param classDefinition class receiving the method
     * @param fieldDefinition the position field
     * @param fieldDefinition2 the position count field
     */
    private static void generateAdvanceNextPosition(ClassDefinition classDefinition, FieldDefinition fieldDefinition, FieldDefinition fieldDefinition2) {
        MethodDefinition advanceNextPosition = classDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "advanceNextPosition",
                ParameterizedType.type(boolean.class),
                new Parameter[0]);
        Variable self = advanceNextPosition.getThis();

        // The first "this" is the putField target, the second feeds getField.
        advanceNextPosition.getBody()
                .comment("this.position = this.position + 1;")
                .append(self)
                .append(self)
                .getField(fieldDefinition)
                .push(1)
                .intAdd()
                .putField(fieldDefinition);

        LabelNode lessThan = new LabelNode("lessThan");
        LabelNode end = new LabelNode("end");
        advanceNextPosition.getBody()
                .comment("return position < positionCount;")
                .append(self)
                .getField(fieldDefinition)
                .append(self)
                .getField(fieldDefinition2)
                .append(JumpInstruction.jumpIfIntLessThan(lessThan))
                .push(false)
                .gotoLabel(end)
                .visitLabel(lessThan)
                .push(true)
                .visitLabel(end)
                .retBoolean();
    }

    /**
     * Emits a public {@code long getCurrentJoinPosition(LookupSource lookupSource)}
     * method. The generated code returns -1 when {@code currentRowContainsNull()} is
     * true; otherwise it returns {@code lookupSource.getJoinPosition(position,
     * probePage, page)}, adding a fourth argument read with {@code BIGINT.getLong}
     * from the probe hash block at the current position when a hash channel is present.
     *
     * @param classDefinition class receiving the method
     * @param callSiteBinder binder used to embed the {@code BIGINT} type constant
     * @param fieldDefinition the probe page field
     * @param fieldDefinition2 the page field
     * @param optionalInt probe hash channel; only its presence is consulted here
     * @param fieldDefinition3 the probe hash block field
     * @param fieldDefinition4 the position field
     */
    private static void generateGetCurrentJoinPosition(ClassDefinition classDefinition, CallSiteBinder callSiteBinder, FieldDefinition fieldDefinition, FieldDefinition fieldDefinition2, OptionalInt optionalInt, FieldDefinition fieldDefinition3, FieldDefinition fieldDefinition4) {
        Parameter lookupSource = Parameter.arg("lookupSource", LookupSource.class);
        MethodDefinition getCurrentJoinPosition = classDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "getCurrentJoinPosition",
                ParameterizedType.type(long.class),
                new Parameter[]{lookupSource});
        Variable self = getCurrentJoinPosition.getThis();

        // if (currentRowContainsNull()) return -1;
        BytecodeBlock body = getCurrentJoinPosition.getBody().append(
                new IfStatement()
                        .condition(self.invoke("currentRowContainsNull", boolean.class, new BytecodeExpression[0]))
                        .ifTrue(BytecodeExpressions.constantLong(-1L).ret()));

        BytecodeExpression position = self.getField(fieldDefinition4);
        BytecodeExpression probePage = self.getField(fieldDefinition);
        BytecodeExpression page = self.getField(fieldDefinition2);
        BytecodeExpression probeHashBlock = self.getField(fieldDefinition3);

        if (!optionalInt.isPresent()) {
            body.append(lookupSource.invoke("getJoinPosition", long.class, new BytecodeExpression[]{position, probePage, page}))
                    .retLong();
            return;
        }

        // With a hash channel, the hash value is read from the hash block and passed along.
        body.append(lookupSource.invoke("getJoinPosition", long.class, new BytecodeExpression[]{
                        position,
                        probePage,
                        page,
                        SqlTypeBytecodeExpression.constantType(callSiteBinder, BigintType.BIGINT).invoke("getLong", long.class, new BytecodeExpression[]{probeHashBlock, position})}))
                .retLong();
    }

    /**
     * Emits a private {@code boolean currentRowContainsNull()} method. For each probe
     * block field the generated code calls {@code isNull(position)} and returns
     * {@code true} on the first hit; if no field reports null it returns {@code false}.
     *
     * @param classDefinition class receiving the method
     * @param list probe block fields to check, in order
     * @param fieldDefinition the position field
     */
    private static void generateCurrentRowContainsNull(ClassDefinition classDefinition, List<FieldDefinition> list, FieldDefinition fieldDefinition) {
        MethodDefinition currentRowContainsNull = classDefinition.declareMethod(Access.a(new Access[]{Access.PRIVATE}), "currentRowContainsNull", ParameterizedType.type(Boolean.TYPE), new Parameter[0]);
        Variable self = currentRowContainsNull.getThis();

        for (FieldDefinition probeBlockField : list) {
            // Skip to the next field's check when this block is not null at position.
            LabelNode checkNextField = new LabelNode("checkNextField");
            currentRowContainsNull.getBody()
                    .append(self.getField(probeBlockField).invoke("isNull", Boolean.TYPE, new BytecodeExpression[]{self.getField(fieldDefinition)}))
                    .ifFalseGoto(checkNextField)
                    .push(true)
                    .retBoolean()
                    .visitLabel(checkNextField);
        }

        // The method is declared boolean, so the fall-through return uses the same
        // boolean return helper as the early exit above.
        currentRowContainsNull.getBody()
                .push(false)
                .retBoolean();
    }

    /**
     * Emits a public {@code int getPosition()} method returning the given field.
     *
     * @param classDefinition class receiving the method
     * @param fieldDefinition the position field
     */
    private static void generateGetPosition(ClassDefinition classDefinition, FieldDefinition fieldDefinition) {
        MethodDefinition getPosition = classDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "getPosition",
                ParameterizedType.type(int.class),
                new Parameter[0]);
        Variable self = getPosition.getThis();
        getPosition.getBody()
                .append(self.getField(fieldDefinition))
                .retInt();
    }

    /**
     * Emits a public {@code Page getPage()} method returning the given field.
     *
     * @param classDefinition class receiving the method
     * @param fieldDefinition the page field
     */
    private static void generateGetPage(ClassDefinition classDefinition, FieldDefinition fieldDefinition) {
        MethodDefinition getPage = classDefinition.declareMethod(
                Access.a(new Access[]{Access.PUBLIC}),
                "getPage",
                ParameterizedType.type(Page.class),
                new Parameter[0]);
        Variable self = getPage.getThis();
        getPage.getBody()
                .append(self.getField(fieldDefinition))
                .ret(Page.class);
    }

    /**
     * Verifies that two boolean values agree.
     *
     * @param z first value
     * @param z2 second value
     * @throws IllegalStateException (without a message) when the values differ
     */
    public static void checkState(boolean z, boolean z2) {
        if (z == z2) {
            return;
        }
        throw new IllegalStateException();
    }
}
