package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.Block;
import com.facebook.presto.block.BlockBuilder;
import com.facebook.presto.block.BlockCursor;
import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.operator.aggregation.SimpleAggregationFunction;
import com.facebook.presto.tuple.TupleInfo;
import com.facebook.presto.util.array.DoubleBigArray;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

/* loaded from: input_file:com/facebook/presto/operator/aggregation/VarianceAggregation.class */
/**
 * Aggregation computing the variance or standard deviation (population or sample flavour) of a
 * FIXED_INT_64 or DOUBLE column.
 *
 * <p>The partial state is the triple (count, mean, M2), where M2 is the running sum of squared
 * deviations from the mean. Raw values are folded in with Welford's online update and partial
 * states are combined with the pairwise formula of Chan et al. The intermediate representation
 * is a 24-byte slice: count ({@code long}, offset 0), mean ({@code double}, offset 8) and M2
 * ({@code double}, offset 16).
 */
public class VarianceAggregation extends SimpleAggregationFunction {
    /** Size in bytes of the serialized intermediate state. */
    private static final int INTERMEDIATE_SIZE = 24;
    private static final int COUNT_OFFSET = 0;
    private static final int MEAN_OFFSET = 8;
    private static final int M2_OFFSET = 16;

    /** Divide M2 by n (population) instead of n - 1 (sample). */
    protected final boolean population;
    /** True for FIXED_INT_64 input (read with getLong), false for DOUBLE input (read with getDouble). */
    protected final boolean inputIsLong;
    /** Emit the square root of the variance. */
    protected final boolean standardDeviation;

    /** Ungrouped accumulator holding a single (count, mean, M2) state in fields. */
    public static class VarianceAccumulator extends SimpleAggregationFunction.SimpleAccumulator {
        private final boolean inputIsLong;
        private final boolean population;
        private final boolean standardDeviation;
        private long currentCount;
        private double currentMean;
        private double currentM2;

        // NOTE(review): the int and both Optional<Integer> arguments are forwarded unchanged to
        // SimpleAccumulator; the cursors of the two optional blocks are later handed to
        // computeSampleWeight, so they are presumably the value, mask and sample-weight channels
        // -- confirm against SimpleAggregationFunction.
        private VarianceAccumulator(int valueChannel, boolean inputIsLong, boolean population, boolean standardDeviation, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel) {
            super(valueChannel, TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, maskChannel, sampleWeightChannel);
            this.inputIsLong = inputIsLong;
            this.population = population;
            this.standardDeviation = standardDeviation;
        }

        /**
         * Folds every non-null value with a positive sample weight into the running state,
         * applying Welford's update once per unit of weight.
         */
        @Override
        protected void processInput(Block values, Optional<Block> maskBlock, Optional<Block> sampleWeightBlock) {
            BlockCursor valueCursor = values.cursor();
            BlockCursor maskCursor = maskBlock.isPresent() ? maskBlock.get().cursor() : null;
            BlockCursor sampleWeightCursor = sampleWeightBlock.isPresent() ? sampleWeightBlock.get().cursor() : null;
            for (int position = 0; position < values.getPositionCount(); position++) {
                Preconditions.checkState(valueCursor.advanceNextPosition());
                Preconditions.checkState(maskCursor == null || maskCursor.advanceNextPosition());
                Preconditions.checkState(sampleWeightCursor == null || sampleWeightCursor.advanceNextPosition());
                long sampleWeight = SimpleAggregationFunction.computeSampleWeight(maskCursor, sampleWeightCursor);
                if (valueCursor.isNull() || sampleWeight <= 0) {
                    continue;
                }
                double value = inputIsLong ? valueCursor.getLong() : valueCursor.getDouble();
                // long counter: an int index would wrap (and never terminate) for weights > Integer.MAX_VALUE
                for (long copy = 0; copy < sampleWeight; copy++) {
                    currentCount++;
                    double delta = value - currentMean;
                    currentMean += delta / currentCount;
                    currentM2 += delta * (value - currentMean);
                }
            }
            Preconditions.checkState(!valueCursor.advanceNextPosition());
        }

        /** Merges each non-null intermediate state with a positive count into the running state. */
        @Override
        protected void processIntermediate(Block intermediates) {
            BlockCursor cursor = intermediates.cursor();
            for (int position = 0; position < intermediates.getPositionCount(); position++) {
                Preconditions.checkState(cursor.advanceNextPosition());
                if (cursor.isNull()) {
                    continue;
                }
                Slice state = cursor.getSlice();
                long count = VarianceAggregation.getCount(state);
                if (count <= 0) {
                    continue;
                }
                double mean = VarianceAggregation.getMean(state);
                double m2 = VarianceAggregation.getM2(state);
                // compute everything from the pre-merge state before overwriting any field
                double mergedMean = VarianceAggregation.mergeMean(currentCount, currentMean, count, mean);
                double mergedM2 = VarianceAggregation.mergeM2(currentCount, currentMean, currentM2, count, mean, m2);
                currentCount += count;
                currentMean = mergedMean;
                currentM2 = mergedM2;
            }
            Preconditions.checkState(!cursor.advanceNextPosition());
        }

        /** Appends the serialized (count, mean, M2) state. */
        @Override
        public void evaluateIntermediate(BlockBuilder output) {
            output.append(VarianceAggregation.createIntermediate(currentCount, currentMean, currentM2));
        }

        /** Appends the final variance / standard deviation, or null when there are too few values. */
        @Override
        public void evaluateFinal(BlockBuilder output) {
            VarianceAggregation.appendFinal(output, currentCount, currentM2, population, standardDeviation);
        }
    }

    /** Grouped accumulator holding one (count, mean, M2) state per group id in big arrays. */
    public static class VarianceGroupedAccumulator extends SimpleAggregationFunction.SimpleGroupedAccumulator {
        private final boolean inputIsLong;
        private final boolean population;
        private final boolean standardDeviation;
        private final LongBigArray counts;
        private final DoubleBigArray means;
        private final DoubleBigArray m2s;

        // NOTE(review): arguments are forwarded unchanged to SimpleGroupedAccumulator; the optional
        // channels are presumably mask and sample-weight -- confirm against SimpleAggregationFunction.
        private VarianceGroupedAccumulator(int valueChannel, boolean inputIsLong, boolean population, boolean standardDeviation, Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel) {
            super(valueChannel, TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, maskChannel, sampleWeightChannel);
            this.inputIsLong = inputIsLong;
            this.population = population;
            this.standardDeviation = standardDeviation;
            this.counts = new LongBigArray();
            this.means = new DoubleBigArray();
            this.m2s = new DoubleBigArray();
        }

        /** Returns the combined reported size of the three state arrays. */
        @Override
        public long getEstimatedSize() {
            return counts.sizeOf() + means.sizeOf() + m2s.sizeOf();
        }

        /**
         * Folds every non-null value with a positive sample weight into its group's state,
         * applying Welford's update once per unit of weight.
         */
        @Override
        protected void processInput(GroupByIdBlock groupIds, Block values, Optional<Block> maskBlock, Optional<Block> sampleWeightBlock) {
            ensureCapacity(groupIds.getGroupCount());
            BlockCursor valueCursor = values.cursor();
            BlockCursor maskCursor = maskBlock.isPresent() ? maskBlock.get().cursor() : null;
            BlockCursor sampleWeightCursor = sampleWeightBlock.isPresent() ? sampleWeightBlock.get().cursor() : null;
            for (int position = 0; position < groupIds.getPositionCount(); position++) {
                Preconditions.checkState(valueCursor.advanceNextPosition());
                Preconditions.checkState(maskCursor == null || maskCursor.advanceNextPosition());
                Preconditions.checkState(sampleWeightCursor == null || sampleWeightCursor.advanceNextPosition());
                long sampleWeight = SimpleAggregationFunction.computeSampleWeight(maskCursor, sampleWeightCursor);
                if (valueCursor.isNull() || sampleWeight <= 0) {
                    continue;
                }
                long groupId = groupIds.getGroupId(position);
                double value = inputIsLong ? valueCursor.getLong() : valueCursor.getDouble();
                long count = counts.get(groupId);
                double mean = means.get(groupId);
                double m2 = m2s.get(groupId);
                // long counter: an int index would wrap (and never terminate) for weights > Integer.MAX_VALUE
                for (long copy = 0; copy < sampleWeight; copy++) {
                    count++;
                    double delta = value - mean;
                    mean += delta / count;
                    m2 += delta * (value - mean);
                }
                counts.set(groupId, count);
                means.set(groupId, mean);
                m2s.set(groupId, m2);
            }
            Preconditions.checkState(!valueCursor.advanceNextPosition());
        }

        /** Merges each non-null intermediate state with a positive count into its group's state. */
        @Override
        protected void processIntermediate(GroupByIdBlock groupIds, Block intermediates) {
            ensureCapacity(groupIds.getGroupCount());
            BlockCursor cursor = intermediates.cursor();
            for (int position = 0; position < groupIds.getPositionCount(); position++) {
                Preconditions.checkState(cursor.advanceNextPosition());
                if (cursor.isNull()) {
                    continue;
                }
                Slice state = cursor.getSlice();
                long count = VarianceAggregation.getCount(state);
                if (count <= 0) {
                    continue;
                }
                double mean = VarianceAggregation.getMean(state);
                double m2 = VarianceAggregation.getM2(state);
                long groupId = groupIds.getGroupId(position);
                long groupCount = counts.get(groupId);
                double groupMean = means.get(groupId);
                double groupM2 = m2s.get(groupId);
                counts.set(groupId, groupCount + count);
                means.set(groupId, VarianceAggregation.mergeMean(groupCount, groupMean, count, mean));
                m2s.set(groupId, VarianceAggregation.mergeM2(groupCount, groupMean, groupM2, count, mean, m2));
            }
            Preconditions.checkState(!cursor.advanceNextPosition());
        }

        /** Appends the serialized (count, mean, M2) state of the given group. */
        @Override
        public void evaluateIntermediate(int groupId, BlockBuilder output) {
            output.append(VarianceAggregation.createIntermediate(counts.get(groupId), means.get(groupId), m2s.get(groupId)));
        }

        /** Appends the final variance / standard deviation of the given group, or null when it has too few values. */
        @Override
        public void evaluateFinal(int groupId, BlockBuilder output) {
            VarianceAggregation.appendFinal(output, counts.get(groupId), m2s.get(groupId), population, standardDeviation);
        }

        /** Grows all three state arrays so that the given number of groups fits. */
        private void ensureCapacity(long groupCount) {
            counts.ensureCapacity(groupCount);
            means.ensureCapacity(groupCount);
            m2s.ensureCapacity(groupCount);
        }
    }

    /**
     * @param parameterType input column type; must be FIXED_INT_64 or DOUBLE
     * @param population divide by n (population) rather than n - 1 (sample)
     * @param standardDeviation return the square root of the variance
     * @throws IllegalArgumentException if {@code parameterType} is neither FIXED_INT_64 nor DOUBLE
     */
    public VarianceAggregation(TupleInfo.Type parameterType, boolean population, boolean standardDeviation) {
        super(TupleInfo.SINGLE_DOUBLE, TupleInfo.SINGLE_VARBINARY, parameterType);
        if (parameterType == TupleInfo.Type.FIXED_INT_64) {
            this.inputIsLong = true;
        } else if (parameterType == TupleInfo.Type.DOUBLE) {
            this.inputIsLong = false;
        } else {
            throw new IllegalArgumentException("Expected parameter type to be FIXED_INT_64 or DOUBLE, but was " + parameterType);
        }
        this.population = population;
        this.standardDeviation = standardDeviation;
    }

    /** Creates a grouped accumulator; only exact evaluation ({@code confidence == 1.0}) is accepted. */
    @Override
    protected GroupedAccumulator createGroupedAccumulator(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int valueChannel) {
        Preconditions.checkArgument(confidence == 1.0d, "variance does not support approximate queries");
        return new VarianceGroupedAccumulator(valueChannel, inputIsLong, population, standardDeviation, maskChannel, sampleWeightChannel);
    }

    /** Creates an ungrouped accumulator; only exact evaluation ({@code confidence == 1.0}) is accepted. */
    @Override
    protected Accumulator createAccumulator(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int valueChannel) {
        Preconditions.checkArgument(confidence == 1.0d, "variance does not support approximate queries");
        return new VarianceAccumulator(valueChannel, inputIsLong, population, standardDeviation, maskChannel, sampleWeightChannel);
    }

    /** Reads the count from a serialized intermediate state. */
    public static long getCount(Slice slice) {
        return slice.getLong(COUNT_OFFSET);
    }

    /** Reads the mean from a serialized intermediate state. */
    public static double getMean(Slice slice) {
        return slice.getDouble(MEAN_OFFSET);
    }

    /** Reads M2 (sum of squared deviations from the mean) from a serialized intermediate state. */
    public static double getM2(Slice slice) {
        return slice.getDouble(M2_OFFSET);
    }

    /** Serializes a (count, mean, M2) state into a freshly allocated 24-byte slice. */
    public static Slice createIntermediate(long count, double mean, double m2) {
        Slice slice = Slices.allocate(INTERMEDIATE_SIZE);
        slice.setLong(COUNT_OFFSET, count);
        slice.setDouble(MEAN_OFFSET, mean);
        slice.setDouble(M2_OFFSET, m2);
        return slice;
    }

    /** Count-weighted mean of two partial states; callers guarantee the total count is positive. */
    private static double mergeMean(long leftCount, double leftMean, long rightCount, double rightMean) {
        return ((leftCount * leftMean) + (rightCount * rightMean)) / (leftCount + rightCount);
    }

    /**
     * M2 of the union of two partial states (Chan et al. pairwise combination):
     * {@code M2a + M2b + delta^2 * na * nb / (na + nb)}. Callers guarantee the total count is positive.
     */
    private static double mergeM2(long leftCount, double leftMean, double leftM2, long rightCount, double rightMean, double rightM2) {
        double delta = rightMean - leftMean;
        // multiply the counts as doubles: their long product silently wraps once it exceeds Long.MAX_VALUE
        double countProduct = (double) leftCount * rightCount;
        return leftM2 + rightM2 + (((delta * delta) * countProduct) / (leftCount + rightCount));
    }

    /**
     * Appends the final result for a (count, M2) state: null when there are fewer than 1
     * (population) or 2 (sample) values, otherwise M2 / n or M2 / (n - 1), square-rooted when
     * a standard deviation is requested.
     */
    private static void appendFinal(BlockBuilder output, long count, double m2, boolean population, boolean standardDeviation) {
        long minimumCount = population ? 1 : 2;
        if (count < minimumCount) {
            output.appendNull();
            return;
        }
        double variance = m2 / (population ? count : count - 1);
        output.append(standardDeviation ? Math.sqrt(variance) : variance);
    }
}
