package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.TableWriterUtils;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.testing.MaterializedResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;
import java.lang.invoke.MethodHandle;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/AggregationMetadata.class */
/**
 * Describes an aggregation function: its name, the metadata for each data parameter of the input
 * function, the input/combine/output method handles, the accumulator state descriptors, the output
 * type and the lambda interfaces taken by the input and combine functions.
 *
 * <p>The constructor validates that the signatures of the three method handles agree with the
 * declared parameter metadata, state descriptors and lambda interfaces, and throws
 * {@link IllegalArgumentException} when they do not.
 */
public class AggregationMetadata {
    /** Java types accepted for plain {@link ParameterMetadata.ParameterType#INPUT_CHANNEL} parameters. */
    public static final Set<Class<?>> SUPPORTED_PARAMETER_TYPES = ImmutableSet.of(Block.class, Long.TYPE, Double.TYPE, Boolean.TYPE, Slice.class);

    private final String name;
    private final List<ParameterMetadata> valueInputMetadata;
    private final List<Class> lambdaInterfaces;
    private final MethodHandle inputFunction;
    private final MethodHandle combineFunction;
    private final MethodHandle outputFunction;
    private final List<AccumulatorStateDescriptor> accumulatorStateDescriptors;
    private final Type outputType;

    /** Pairs an accumulator state interface with its serializer and factory. */
    public static class AccumulatorStateDescriptor {
        private final Class<?> stateInterface;
        private final AccumulatorStateSerializer<?> serializer;
        private final AccumulatorStateFactory<?> factory;

        /**
         * @param stateInterface the state class that input/combine/output functions take as parameters
         * @param serializer the serializer for this state
         * @param factory the factory for this state
         * @throws NullPointerException if any argument is null
         */
        public AccumulatorStateDescriptor(Class<?> stateInterface, AccumulatorStateSerializer<?> serializer, AccumulatorStateFactory<?> factory) {
            this.stateInterface = Objects.requireNonNull(stateInterface, "stateInterface is null");
            this.serializer = Objects.requireNonNull(serializer, "serializer is null");
            this.factory = Objects.requireNonNull(factory, "factory is null");
        }

        public Class<?> getStateInterface() {
            return this.stateInterface;
        }

        public AccumulatorStateSerializer<?> getSerializer() {
            return this.serializer;
        }

        public AccumulatorStateFactory<?> getFactory() {
            return this.factory;
        }
    }

    /** Describes one data parameter of the input function: its kind and, for input channels, its SQL type. */
    public static class ParameterMetadata {
        private final ParameterType parameterType;
        private final Type sqlType;

        /** The kinds of data parameters an input function may declare. */
        public enum ParameterType {
            INPUT_CHANNEL,
            BLOCK_INPUT_CHANNEL,
            NULLABLE_BLOCK_INPUT_CHANNEL,
            BLOCK_INDEX,
            STATE;

            /**
             * Chooses the input-channel kind for a parameter.
             *
             * @param isNullable whether the parameter is marked {@code @NullablePosition}
             * @param isBlock whether the parameter is marked {@code @BlockPosition}
             * @param methodName name of the method, used in the error message
             * @return {@code NULLABLE_BLOCK_INPUT_CHANNEL} or {@code BLOCK_INPUT_CHANNEL} for block
             *     parameters, otherwise {@code INPUT_CHANNEL}
             * @throws IllegalArgumentException if the parameter is nullable but not a block position
             */
            public static ParameterType inputChannelParameterType(boolean isNullable, boolean isBlock, String methodName) {
                if (isBlock) {
                    return isNullable ? NULLABLE_BLOCK_INPUT_CHANNEL : BLOCK_INPUT_CHANNEL;
                }
                if (isNullable) {
                    throw new IllegalArgumentException(methodName + " contains a parameter with @NullablePosition that is not @BlockPosition");
                }
                return INPUT_CHANNEL;
            }
        }

        /** Creates metadata for a parameter that carries no SQL type ({@code BLOCK_INDEX} or {@code STATE}). */
        public ParameterMetadata(ParameterType parameterType) {
            this(parameterType, null);
        }

        /**
         * @param parameterType the parameter kind
         * @param sqlType must be null exactly when {@code parameterType} is {@code BLOCK_INDEX} or {@code STATE}
         * @throws NullPointerException if {@code parameterType} is null
         * @throws IllegalArgumentException if {@code sqlType} presence does not match {@code parameterType}
         */
        public ParameterMetadata(ParameterType parameterType, Type sqlType) {
            Objects.requireNonNull(parameterType, "parameterType is null");
            boolean typeless = parameterType == ParameterType.BLOCK_INDEX || parameterType == ParameterType.STATE;
            Preconditions.checkArgument((sqlType == null) == typeless, "sqlType must be provided only for input channels");
            this.parameterType = parameterType;
            this.sqlType = sqlType;
        }

        /**
         * Creates input-channel metadata for a parameter of the given SQL type.
         *
         * @param isBlock whether the parameter is marked {@code @BlockPosition}
         * @param isNullable whether the parameter is marked {@code @NullablePosition}
         * @param methodName name of the method, used in the error message
         */
        public static ParameterMetadata fromSqlType(Type sqlType, boolean isBlock, boolean isNullable, String methodName) {
            return new ParameterMetadata(ParameterType.inputChannelParameterType(isNullable, isBlock, methodName), sqlType);
        }

        public static ParameterMetadata forBlockIndexParameter() {
            return new ParameterMetadata(ParameterType.BLOCK_INDEX);
        }

        public static ParameterMetadata forStateParameter() {
            return new ParameterMetadata(ParameterType.STATE);
        }

        public ParameterType getParameterType() {
            return this.parameterType;
        }

        /** Returns the SQL type, or null for {@code BLOCK_INDEX} and {@code STATE} parameters. */
        public Type getSqlType() {
            return this.sqlType;
        }
    }

    /** Creates metadata for an aggregation whose functions take no lambda parameters. */
    public AggregationMetadata(String name, List<ParameterMetadata> valueInputMetadata, MethodHandle inputFunction, MethodHandle combineFunction, MethodHandle outputFunction, List<AccumulatorStateDescriptor> accumulatorStateDescriptors, Type outputType) {
        this(name, valueInputMetadata, inputFunction, combineFunction, outputFunction, accumulatorStateDescriptors, outputType, ImmutableList.of());
    }

    /**
     * Creates and validates aggregation metadata.
     *
     * @param name the aggregation name
     * @param valueInputMetadata metadata for each data parameter of {@code inputFunction}, in order
     * @param inputFunction handle whose parameters are the data parameters followed by the lambda parameters
     * @param combineFunction handle taking two sets of states followed by the lambda parameters
     * @param outputFunction handle taking the states followed by a {@link BlockBuilder}
     * @param accumulatorStateDescriptors one descriptor per accumulator state, in order
     * @param outputType the aggregation's output type
     * @param lambdaInterfaces the interface classes of the trailing lambda parameters
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if a method handle signature is inconsistent with the metadata
     */
    public AggregationMetadata(String name, List<ParameterMetadata> valueInputMetadata, MethodHandle inputFunction, MethodHandle combineFunction, MethodHandle outputFunction, List<AccumulatorStateDescriptor> accumulatorStateDescriptors, Type outputType, List<Class> lambdaInterfaces) {
        this.outputType = Objects.requireNonNull(outputType, "outputType is null");
        this.valueInputMetadata = ImmutableList.copyOf(Objects.requireNonNull(valueInputMetadata, "valueInputMetadata is null"));
        this.name = Objects.requireNonNull(name, "name is null");
        this.inputFunction = Objects.requireNonNull(inputFunction, "inputFunction is null");
        this.combineFunction = Objects.requireNonNull(combineFunction, "combineFunction is null");
        this.outputFunction = Objects.requireNonNull(outputFunction, "outputFunction is null");
        // Defensive copy, matching the other list-valued fields.
        this.accumulatorStateDescriptors = ImmutableList.copyOf(Objects.requireNonNull(accumulatorStateDescriptors, "accumulatorStateDescriptors is null"));
        this.lambdaInterfaces = ImmutableList.copyOf(Objects.requireNonNull(lambdaInterfaces, "lambdaInterfaces is null"));

        verifyInputFunctionSignature(this.inputFunction, this.valueInputMetadata, this.lambdaInterfaces, this.accumulatorStateDescriptors);
        verifyCombineFunction(this.combineFunction, this.lambdaInterfaces, this.accumulatorStateDescriptors);
        verifyExactOutputFunction(this.outputFunction, this.accumulatorStateDescriptors);
    }

    public Type getOutputType() {
        return this.outputType;
    }

    public List<ParameterMetadata> getValueInputMetadata() {
        return this.valueInputMetadata;
    }

    public List<Class> getLambdaInterfaces() {
        return this.lambdaInterfaces;
    }

    public String getName() {
        return this.name;
    }

    public MethodHandle getInputFunction() {
        return this.inputFunction;
    }

    public MethodHandle getCombineFunction() {
        return this.combineFunction;
    }

    public MethodHandle getOutputFunction() {
        return this.outputFunction;
    }

    public List<AccumulatorStateDescriptor> getAccumulatorStateDescriptors() {
        return this.accumulatorStateDescriptors;
    }

    /**
     * Checks that the input function's parameters are the declared data parameters (first one a
     * state) followed by the lambda parameters, and that each parameter's Java type fits its kind.
     */
    private static void verifyInputFunctionSignature(MethodHandle method, List<ParameterMetadata> dataChannelMetadata, List<Class> lambdaInterfaces, List<AccumulatorStateDescriptor> stateDescriptors) {
        Class<?>[] parameters = method.type().parameterArray();
        int expectedParameterCount = dataChannelMetadata.size() + lambdaInterfaces.size();
        Preconditions.checkArgument(parameters.length > 0, "Aggregation input function must have at least one parameter");
        Preconditions.checkArgument(parameters.length == expectedParameterCount, "Number of parameters in input function (%s) must be the total number of data channels and lambda channels (%s)", parameters.length, expectedParameterCount);

        long declaredStates = dataChannelMetadata.stream()
                .filter(metadata -> metadata.getParameterType() == ParameterMetadata.ParameterType.STATE)
                .count();
        Preconditions.checkArgument(declaredStates == stateDescriptors.size(), "Number of state parameter in input function must be the same as size of stateDescriptors");
        // Guard the empty case explicitly: with only lambda parameters, get(0) would otherwise throw IndexOutOfBoundsException.
        Preconditions.checkArgument(!dataChannelMetadata.isEmpty() && dataChannelMetadata.get(0).getParameterType() == ParameterMetadata.ParameterType.STATE, "First parameter must be state");

        int stateIndex = 0;
        for (int i = 0; i < dataChannelMetadata.size(); i++) {
            ParameterMetadata metadata = dataChannelMetadata.get(i);
            Class<?> parameterType = parameters[i];
            switch (metadata.getParameterType()) {
                case STATE: {
                    // States must appear in the same order as the state descriptors.
                    Class<?> expectedState = stateDescriptors.get(stateIndex).getStateInterface();
                    Preconditions.checkArgument(expectedState == parameterType, "State argument must be of type %s", expectedState);
                    stateIndex++;
                    break;
                }
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    Preconditions.checkArgument(parameterType == Block.class, "Parameter must be Block if it has @BlockPosition");
                    break;
                case INPUT_CHANNEL:
                    Preconditions.checkArgument(SUPPORTED_PARAMETER_TYPES.contains(parameterType), "Unsupported type: %s", parameterType.getSimpleName());
                    verifyMethodParameterType(method, i, metadata.getSqlType().getJavaType(), metadata.getSqlType().getDisplayName());
                    break;
                case BLOCK_INDEX:
                    Preconditions.checkArgument(parameterType == int.class, "Block index parameter must be an int");
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported parameter: " + metadata.getParameterType());
            }
        }
        Preconditions.checkArgument(stateIndex == stateDescriptors.size(), "Input function only has %s states, expected: %s.", stateIndex, stateDescriptors.size());

        // Lambda parameters follow the data parameters.
        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(method, dataChannelMetadata.size() + i, lambdaInterfaces.get(i), "function");
        }
    }

    /**
     * Checks that the combine function takes the states twice (state1..stateN, otherState1..otherStateN)
     * followed by the lambda parameters.
     */
    private static void verifyCombineFunction(MethodHandle method, List<Class> lambdaInterfaces, List<AccumulatorStateDescriptor> stateDescriptors) {
        Class<?>[] parameters = method.type().parameterArray();
        int stateCount = stateDescriptors.size();
        Preconditions.checkArgument(parameters.length == stateCount * 2 + lambdaInterfaces.size(), "Number of arguments for combine function must be 2 times the size of states plus number of lambda channels.");
        for (int i = 0; i < stateCount * 2; i++) {
            Class<?> expectedState = stateDescriptors.get(i % stateCount).getStateInterface();
            Preconditions.checkArgument(parameters[i].equals(expectedState), "Type for Parameter index %s is unexpected. Arguments for combine function must appear in the order of state1, state2, ..., otherState1, otherState2, ...", i);
        }
        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(method, stateCount * 2 + i, lambdaInterfaces.get(i), "function");
        }
    }

    /**
     * Checks that the output function takes exactly the states, in order, plus one more parameter, and
     * that exactly one parameter is a {@link BlockBuilder}. The caller guarantees {@code method} is non-null.
     */
    private static void verifyExactOutputFunction(MethodHandle method, List<AccumulatorStateDescriptor> stateDescriptors) {
        Class<?>[] parameters = method.type().parameterArray();
        Preconditions.checkArgument(parameters.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one more than number of states.");
        for (int i = 0; i < stateDescriptors.size(); i++) {
            Preconditions.checkArgument(parameters[i].equals(stateDescriptors.get(i).getStateInterface()), "Type for Parameter index %s is unexpected", i);
        }
        long blockBuilders = Arrays.stream(parameters).filter(BlockBuilder.class::equals).count();
        Preconditions.checkArgument(blockBuilders == 1, "Output function must take exactly one BlockBuilder parameter");
    }

    /** Checks that parameter {@code index} of {@code method} is exactly {@code javaType}; {@code description} appears in the error. */
    private static void verifyMethodParameterType(MethodHandle method, int index, Class<?> javaType, String description) {
        Preconditions.checkArgument(method.type().parameterType(index) == javaType, "Expected method %s parameter %s type to be %s (%s)", method, index, javaType.getName(), description);
    }

    /**
     * Counts the parameters of kind {@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL} or
     * {@code NULLABLE_BLOCK_INPUT_CHANNEL}.
     */
    public static int countInputChannels(List<ParameterMetadata> metadata) {
        int count = 0;
        for (ParameterMetadata parameter : metadata) {
            ParameterMetadata.ParameterType type = parameter.getParameterType();
            if (type == ParameterMetadata.ParameterType.INPUT_CHANNEL
                    || type == ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL
                    || type == ParameterMetadata.ParameterType.NULLABLE_BLOCK_INPUT_CHANNEL) {
                count++;
            }
        }
        return count;
    }
}
