package com.facebook.presto.ml;

import com.facebook.presto.operator.GroupByIdBlock;
import com.facebook.presto.operator.Page;
import com.facebook.presto.operator.aggregation.Accumulator;
import com.facebook.presto.operator.aggregation.AggregationFunction;
import com.facebook.presto.operator.aggregation.GroupedAccumulator;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.util.List;
import java.util.Locale;

/**
 * Aggregation that evaluates a binary classifier by comparing a BIGINT label column against a
 * BIGINT prediction column. Both columns must contain only {@code 0} or {@code 1}.
 *
 * <p>For each aggregation group it counts true/false positives and negatives and renders
 * accuracy, precision and recall as a VARCHAR (see {@link #formatResults}). The intermediate
 * state is a VARCHAR slice holding the four counts (see {@link #createIntermediate}).
 *
 * <p>Rows where either the label or the prediction is NULL are ignored.
 */
public class EvaluateClassifierPredictionsAggregation implements AggregationFunction {

    private static final String NON_BINARY_MESSAGE = "evaluate_predictions only supports binary classifiers";

    // Layout of the intermediate slice produced by createIntermediate: four longs, 8 bytes each.
    private static final int TRUE_POSITIVES_OFFSET = 0;
    private static final int FALSE_POSITIVES_OFFSET = 8;
    private static final int TRUE_NEGATIVES_OFFSET = 16;
    private static final int FALSE_NEGATIVES_OFFSET = 24;
    private static final int INTERMEDIATE_SIZE = 32;

    /**
     * Accumulator for the ungrouped (global) form of the aggregation.
     */
    public static class EvaluatePredictionsAccumulator implements Accumulator {
        private final int labelChannel;
        private final int predictionChannel;
        private long truePositives;
        private long falsePositives;
        private long trueNegatives;
        private long falseNegatives;

        /**
         * @param labelChannel channel of the label column in input pages
         * @param predictionChannel channel of the prediction column in input pages
         *        (this class passes -1 for both when only intermediate input is expected)
         */
        public EvaluatePredictionsAccumulator(int labelChannel, int predictionChannel) {
            this.labelChannel = labelChannel;
            this.predictionChannel = predictionChannel;
        }

        @Override
        public long getEstimatedSize() {
            return 0L;
        }

        @Override
        public Type getFinalType() {
            return VarcharType.VARCHAR;
        }

        @Override
        public Type getIntermediateType() {
            return VarcharType.VARCHAR;
        }

        /**
         * Classifies every non-NULL (label, prediction) row of {@code page} into one of the four
         * confusion-matrix counters.
         *
         * @throws IllegalArgumentException if a label or prediction is neither 0 nor 1
         */
        @Override
        public void addInput(Page page) {
            Block labels = page.getBlock(labelChannel);
            Block predictions = page.getBlock(predictionChannel);
            for (int position = 0; position < labels.getPositionCount(); position++) {
                // A NULL label or prediction carries no evidence; reading it with getLong would
                // silently count it as some class, so skip the row as SQL aggregates do.
                if (labels.isNull(position) || predictions.isNull(position)) {
                    continue;
                }
                long prediction = predictions.getLong(position);
                long label = labels.getLong(position);
                checkBinary(prediction);
                checkBinary(label);
                if (label == 1) {
                    if (prediction == 1) {
                        truePositives++;
                    }
                    else {
                        falseNegatives++;
                    }
                }
                else if (prediction == 0) {
                    trueNegatives++;
                }
                else {
                    falsePositives++;
                }
            }
        }

        /**
         * Merges a single-position intermediate block (as produced by {@link #evaluateIntermediate()})
         * into this accumulator's counters.
         */
        @Override
        public void addIntermediate(Block block) {
            Preconditions.checkState(block.getPositionCount() == 1);
            Slice state = block.getSlice(0);
            truePositives += state.getLong(TRUE_POSITIVES_OFFSET);
            falsePositives += state.getLong(FALSE_POSITIVES_OFFSET);
            trueNegatives += state.getLong(TRUE_NEGATIVES_OFFSET);
            falseNegatives += state.getLong(FALSE_NEGATIVES_OFFSET);
        }

        /**
         * Returns a single-position VARCHAR block holding the encoded counters.
         */
        @Override
        public Block evaluateIntermediate() {
            BlockBuilder builder = getIntermediateType().createBlockBuilder(new BlockBuilderStatus());
            builder.appendSlice(createIntermediate(truePositives, falsePositives, trueNegatives, falseNegatives));
            return builder.build();
        }

        /**
         * Returns a single-position VARCHAR block with the human-readable report. It uses the same
         * {@link #formatResults} rendering as the grouped accumulator, so grouped and ungrouped
         * queries produce identical text.
         */
        @Override
        public Block evaluateFinal() {
            BlockBuilder builder = getFinalType().createBlockBuilder(new BlockBuilderStatus());
            builder.appendSlice(Slices.utf8Slice(formatResults(truePositives, falsePositives, trueNegatives, falseNegatives)));
            return builder.build();
        }
    }

    /**
     * Accumulator for the grouped (GROUP BY) form of the aggregation; keeps one set of counters
     * per group id.
     */
    public static class EvaluatePredictionsGroupedAccumulator implements GroupedAccumulator {
        private final int labelChannel;
        private final int predictionChannel;
        private final LongBigArray truePositives = new LongBigArray();
        private final LongBigArray falsePositives = new LongBigArray();
        private final LongBigArray trueNegatives = new LongBigArray();
        private final LongBigArray falseNegatives = new LongBigArray();

        /**
         * @param labelChannel channel of the label column in input pages
         * @param predictionChannel channel of the prediction column in input pages
         *        (this class passes -1 for both when only intermediate input is expected)
         */
        public EvaluatePredictionsGroupedAccumulator(int labelChannel, int predictionChannel) {
            this.labelChannel = labelChannel;
            this.predictionChannel = predictionChannel;
        }

        @Override
        public Type getFinalType() {
            return VarcharType.VARCHAR;
        }

        @Override
        public Type getIntermediateType() {
            return VarcharType.VARCHAR;
        }

        @Override
        public long getEstimatedSize() {
            return truePositives.sizeOf() + falsePositives.sizeOf() + trueNegatives.sizeOf() + falseNegatives.sizeOf();
        }

        /**
         * Classifies every non-NULL (label, prediction) row of {@code page} into the counters of
         * the row's group.
         *
         * @throws IllegalArgumentException if a label or prediction is neither 0 nor 1
         */
        @Override
        public void addInput(GroupByIdBlock groupByIdBlock, Page page) {
            ensureCapacity(groupByIdBlock.getGroupCount());
            Block labels = page.getBlock(labelChannel);
            Block predictions = page.getBlock(predictionChannel);
            for (int position = 0; position < groupByIdBlock.getPositionCount(); position++) {
                // See the ungrouped accumulator: NULL rows are ignored rather than miscounted.
                if (labels.isNull(position) || predictions.isNull(position)) {
                    continue;
                }
                long groupId = groupByIdBlock.getGroupId(position);
                long prediction = predictions.getLong(position);
                long label = labels.getLong(position);
                checkBinary(prediction);
                checkBinary(label);
                if (label == 1) {
                    if (prediction == 1) {
                        truePositives.increment(groupId);
                    }
                    else {
                        falseNegatives.increment(groupId);
                    }
                }
                else if (prediction == 0) {
                    trueNegatives.increment(groupId);
                }
                else {
                    falsePositives.increment(groupId);
                }
            }
        }

        /**
         * Merges one intermediate slice per position of {@code block} into the counters of the
         * corresponding group.
         */
        @Override
        public void addIntermediate(GroupByIdBlock groupByIdBlock, Block block) {
            ensureCapacity(groupByIdBlock.getGroupCount());
            for (int position = 0; position < groupByIdBlock.getPositionCount(); position++) {
                long groupId = groupByIdBlock.getGroupId(position);
                Slice state = block.getSlice(position);
                truePositives.add(groupId, state.getLong(TRUE_POSITIVES_OFFSET));
                falsePositives.add(groupId, state.getLong(FALSE_POSITIVES_OFFSET));
                trueNegatives.add(groupId, state.getLong(TRUE_NEGATIVES_OFFSET));
                falseNegatives.add(groupId, state.getLong(FALSE_NEGATIVES_OFFSET));
            }
        }

        /**
         * Appends the encoded counters of {@code groupId} to {@code blockBuilder}. The builder is
         * owned by the caller, so it is not built here.
         */
        @Override
        public void evaluateIntermediate(int groupId, BlockBuilder blockBuilder) {
            blockBuilder.appendSlice(createIntermediate(
                    truePositives.get(groupId),
                    falsePositives.get(groupId),
                    trueNegatives.get(groupId),
                    falseNegatives.get(groupId)));
        }

        /**
         * Appends the human-readable report of {@code groupId} to {@code blockBuilder}.
         */
        @Override
        public void evaluateFinal(int groupId, BlockBuilder blockBuilder) {
            blockBuilder.appendSlice(Slices.utf8Slice(formatResults(
                    truePositives.get(groupId),
                    falsePositives.get(groupId),
                    trueNegatives.get(groupId),
                    falseNegatives.get(groupId))));
        }

        // Grows all four counter arrays so every group id below groupCount is addressable.
        private void ensureCapacity(long groupCount) {
            truePositives.ensureCapacity(groupCount);
            falsePositives.ensureCapacity(groupCount);
            trueNegatives.ensureCapacity(groupCount);
            falseNegatives.ensureCapacity(groupCount);
        }
    }

    @Override
    public List<Type> getParameterTypes() {
        return ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT);
    }

    @Override
    public Type getFinalType() {
        return VarcharType.VARCHAR;
    }

    @Override
    public Type getIntermediateType() {
        return VarcharType.VARCHAR;
    }

    @Override
    public boolean isDecomposable() {
        return true;
    }

    /**
     * Creates an ungrouped accumulator reading the label from {@code argumentChannels[0]} and the
     * prediction from {@code argumentChannels[1]}.
     *
     * @throws IllegalArgumentException if a mask or sample weight is requested, or if
     *         {@code confidence} is not exactly 1.0 (approximation unsupported)
     */
    @Override
    public Accumulator createAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int... argumentChannels) {
        Preconditions.checkArgument(!maskChannel.isPresent(), "masking is not supported");
        Preconditions.checkArgument(confidence == 1.0d, "approximation is not supported");
        Preconditions.checkArgument(!sampleWeightChannel.isPresent(), "sample weight is not supported");
        return new EvaluatePredictionsAccumulator(argumentChannels[0], argumentChannels[1]);
    }

    /**
     * Creates an ungrouped accumulator that only consumes intermediate state.
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0
     */
    @Override
    public Accumulator createIntermediateAggregation(double confidence) {
        Preconditions.checkArgument(confidence == 1.0d, "approximation is not supported");
        return new EvaluatePredictionsAccumulator(-1, -1);
    }

    /**
     * Creates a grouped accumulator reading the label from {@code argumentChannels[0]} and the
     * prediction from {@code argumentChannels[1]}.
     *
     * @throws IllegalArgumentException if a mask or sample weight is requested, or if
     *         {@code confidence} is not exactly 1.0 (approximation unsupported)
     */
    @Override
    public GroupedAccumulator createGroupedAggregation(Optional<Integer> maskChannel, Optional<Integer> sampleWeightChannel, double confidence, int... argumentChannels) {
        Preconditions.checkArgument(!maskChannel.isPresent(), "masking is not supported");
        Preconditions.checkArgument(confidence == 1.0d, "approximation is not supported");
        Preconditions.checkArgument(!sampleWeightChannel.isPresent(), "sample weight is not supported");
        return new EvaluatePredictionsGroupedAccumulator(argumentChannels[0], argumentChannels[1]);
    }

    /**
     * Creates a grouped accumulator that only consumes intermediate state.
     *
     * @throws IllegalArgumentException if {@code confidence} is not exactly 1.0
     */
    @Override
    public GroupedAccumulator createGroupedIntermediateAggregation(double confidence) {
        Preconditions.checkArgument(confidence == 1.0d, "approximation is not supported");
        return new EvaluatePredictionsGroupedAccumulator(-1, -1);
    }

    /**
     * Renders accuracy, precision and recall as
     * {@code "Accuracy: c/t (p%), Precision: c/t (p%), Recall: c/t (p%)"}.
     *
     * <p>Formatting uses {@link Locale#ROOT} so the decimal separator does not depend on the
     * server's default locale. A ratio whose denominator is zero (e.g. no positive predictions)
     * evaluates to NaN and is rendered as {@code NaN}.
     *
     * @param truePositives rows with label 1 and prediction 1
     * @param falsePositives rows with label 0 and prediction 1
     * @param trueNegatives rows with label 0 and prediction 0
     * @param falseNegatives rows with label 1 and prediction 0
     * @return the formatted report
     */
    public static String formatResults(long truePositives, long falsePositives, long trueNegatives, long falseNegatives) {
        long correct = trueNegatives + truePositives;
        long total = truePositives + trueNegatives + falsePositives + falseNegatives;
        long predictedPositive = truePositives + falsePositives;
        long actualPositive = truePositives + falseNegatives;
        return String.format(Locale.ROOT, "Accuracy: %d/%d (%.2f%%), ", correct, total, percent(correct, total))
                + String.format(Locale.ROOT, "Precision: %d/%d (%.2f%%), ", truePositives, predictedPositive, percent(truePositives, predictedPositive))
                + String.format(Locale.ROOT, "Recall: %d/%d (%.2f%%)", truePositives, actualPositive, percent(truePositives, actualPositive));
    }

    /**
     * Encodes the four counters into a 32-byte slice: true positives, false positives,
     * true negatives, false negatives, each as a long at offsets 0, 8, 16, 24.
     */
    public static Slice createIntermediate(long truePositives, long falsePositives, long trueNegatives, long falseNegatives) {
        Slice state = Slices.allocate(INTERMEDIATE_SIZE);
        state.setLong(TRUE_POSITIVES_OFFSET, truePositives);
        state.setLong(FALSE_POSITIVES_OFFSET, falsePositives);
        state.setLong(TRUE_NEGATIVES_OFFSET, trueNegatives);
        state.setLong(FALSE_NEGATIVES_OFFSET, falseNegatives);
        return state;
    }

    // Percentage numerator/denominator; yields NaN (not an exception) when both are zero.
    private static double percent(long numerator, long denominator) {
        return (100.0d * numerator) / denominator;
    }

    // Rejects any label or prediction value other than 0 or 1.
    private static void checkBinary(long value) {
        Preconditions.checkArgument(value == 1 || value == 0, NON_BINARY_MESSAGE);
    }
}
