package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.BlockBuilderStatus;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.google.common.collect.ImmutableList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Shared test base for precision/recall style aggregation functions.
 * <p>
 * Subclasses supply the name of the aggregation under test. The validation tests in this class
 * resolve the four-argument overload {@code (bigint, boolean, double, double)}. The
 * sequence-based tests inherited from {@link AbstractTestAggregationFunction} are fed three
 * generated columns: bucket count, outcome and prediction.
 */
public abstract class TestPrecisionRecallAggregation extends AbstractTestAggregationFunction {
    /** Bucket count written into every generated row; also used to derive expected bucket bounds. */
    private static final int NUM_BINS = 3;
    /**
     * Generated samples whose prediction lies in the closed interval
     * [{@code MIN_FALSE_PRED}, {@code MAX_FALSE_PRED}] get a {@code false} outcome; all others get {@code true}.
     */
    private static final double MIN_FALSE_PRED = 0.2d;
    private static final double MAX_FALSE_PRED = 0.5d;

    private final String functionName;
    private InternalAggregationFunction precisionRecallFunction;

    /** Expected totals for a single bucket, as produced by {@link #getResultsIterator(int, int)}. */
    public static class BucketResult {
        /** Lower bound of the bucket ({@code bucketIndex / NUM_BINS}). */
        public final Double left;
        /** Upper bound of the bucket ({@code (bucketIndex + 1) / NUM_BINS}). */
        public final Double right;
        /** Weight of all samples with a {@code true} outcome. */
        public final Double totalTrueWeight;
        /** Weight of all samples with a {@code false} outcome. */
        public final Double totalFalseWeight;
        /** Weight of {@code true}-outcome samples whose prediction is at least {@link #left}. */
        public final Double remainingTrueWeight;
        /** Weight of {@code false}-outcome samples whose prediction is at least {@link #left}. */
        public final Double remainingFalseWeight;

        public BucketResult(
                Double left,
                Double right,
                Double totalTrueWeight,
                Double totalFalseWeight,
                Double remainingTrueWeight,
                Double remainingFalseWeight) {
            this.left = left;
            this.right = right;
            this.totalTrueWeight = totalTrueWeight;
            this.totalFalseWeight = totalFalseWeight;
            this.remainingTrueWeight = remainingTrueWeight;
            this.remainingFalseWeight = remainingFalseWeight;
        }
    }

    /** One generated sample: its label ({@code outcome}) and the prediction associated with it. */
    public static class Result {
        public final Boolean outcome;
        public final Double prediction;

        public Result(Boolean outcome, Double prediction) {
            this.outcome = outcome;
            this.prediction = prediction;
        }
    }

    /**
     * @param functionName name of the aggregation function under test; used both for the
     *                     explicit four-argument lookup in {@link #setUp()} and by {@link #getFunctionName()}
     */
    protected TestPrecisionRecallAggregation(String functionName) {
        this.functionName = functionName;
    }

    /** Resolves the {@code (bigint, boolean, double, double)} overload of the function under test. */
    @BeforeClass
    public void setUp() {
        FunctionAndTypeManager functionAndTypeManager = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
        this.precisionRecallFunction = functionAndTypeManager.getAggregateFunctionImplementation(
                functionAndTypeManager.lookupFunction(
                        this.functionName,
                        TypeSignatureProvider.fromTypes(new Type[] {BigintType.BIGINT, BooleanType.BOOLEAN, DoubleType.DOUBLE, DoubleType.DOUBLE})));
    }

    @Test
    public void testNegativeWeight() {
        assertAggregationFails(
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(Double.valueOf(0.2d)),
                BlockAssertions.createDoublesBlock(Double.valueOf(-0.2d)),
                "weight", "negative");
    }

    @Test
    public void testTooHighPrediction() {
        assertAggregationFails(
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(Double.valueOf(1.2d)),
                BlockAssertions.createDoublesBlock(Double.valueOf(0.2d)),
                "prediction");
    }

    @Test
    public void testTooLowPrediction() {
        assertAggregationFails(
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(Double.valueOf(-1.2d)),
                BlockAssertions.createDoublesBlock(Double.valueOf(0.2d)),
                "prediction");
    }

    @Test
    public void testNonConstantBuckets() {
        assertAggregationFails(
                BlockAssertions.createLongsBlock(200L, 300L),
                BlockAssertions.createBooleansBlock(true, false),
                BlockAssertions.createDoublesBlock(Double.valueOf(0.2d), Double.valueOf(0.3d)),
                BlockAssertions.createDoublesBlock(Double.valueOf(1.0d), Double.valueOf(1.0d)),
                "bucket");
    }

    /**
     * Runs the four-argument aggregation over the given columns and asserts that it throws a
     * {@link PrestoException} whose message (lower-cased) contains every expected fragment.
     */
    private void assertAggregationFails(
            Block bucketCounts,
            Block outcomes,
            Block predictions,
            Block weights,
            String... expectedMessageFragments) {
        try {
            AggregationTestUtils.assertAggregation(this.precisionRecallFunction, Double.valueOf(0.0d), bucketCounts, outcomes, predictions, weights);
            // AssertionError is not a PrestoException, so this escapes the catch below.
            Assert.fail("Expected PrestoException");
        } catch (PrestoException e) {
            String message = e.getMessage();
            Assert.assertNotNull(message, "PrestoException has no message");
            String normalized = message.toLowerCase(Locale.ENGLISH);
            for (String fragment : expectedMessageFragments) {
                Assert.assertTrue(normalized.contains(fragment), "Expected message to contain '" + fragment + "' but was: " + message);
            }
        }
    }

    /**
     * Generates {@code length} rows of (bucket count, outcome, prediction) input.
     * Only {@code |start|} is used, and only to offset positions passed to {@link #getResult}.
     */
    @Override
    public Block[] getSequenceBlocks(int start, int length) {
        int effectiveStart = Math.abs(start);
        BlockBuilder bucketCounts = BigintType.BIGINT.createBlockBuilder((BlockBuilderStatus) null, length);
        BlockBuilder outcomes = BooleanType.BOOLEAN.createBlockBuilder((BlockBuilderStatus) null, length);
        BlockBuilder predictions = DoubleType.DOUBLE.createBlockBuilder((BlockBuilderStatus) null, length);
        for (int position = effectiveStart; position < effectiveStart + length; position++) {
            BigintType.BIGINT.writeLong(bucketCounts, NUM_BINS);
            Result result = getResult(effectiveStart, length, position);
            BooleanType.BOOLEAN.writeBoolean(outcomes, result.outcome);
            DoubleType.DOUBLE.writeDouble(predictions, result.prediction);
        }
        return new Block[] {bucketCounts.build(), outcomes.build(), predictions.build()};
    }

    /**
     * Returns an iterator over the expected per-bucket totals for the data produced by
     * {@link #getSequenceBlocks(int, int)} with the same arguments. Every sample contributes weight 1.
     * <p>
     * Iteration continues while some {@code true}-outcome sample has a prediction at or above the
     * next bucket's lower bound. {@code remove()} is unsupported.
     */
    protected static Iterator<BucketResult> getResultsIterator(final int start, final int length) {
        final int effectiveStart = Math.abs(start);
        return new Iterator<BucketResult>() {
            // Index of the next bucket to emit. Named distinctly so it cannot shadow the
            // enclosing parameters (the decompiled version shadowed `start` and scanned from the wrong position).
            private int bucket;

            @Override
            public boolean hasNext() {
                double left = bucketLowerBound(bucket);
                for (int position = effectiveStart; position < effectiveStart + length; position++) {
                    Result result = getResult(effectiveStart, length, position);
                    if (result.outcome && result.prediction >= left) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public BucketResult next() {
                double left = bucketLowerBound(bucket);
                double right = bucketLowerBound(bucket + 1);
                double totalTrueWeight = 0.0d;
                double totalFalseWeight = 0.0d;
                double remainingTrueWeight = 0.0d;
                double remainingFalseWeight = 0.0d;
                for (int position = effectiveStart; position < effectiveStart + length; position++) {
                    Result result = getResult(effectiveStart, length, position);
                    boolean atOrAboveLeft = result.prediction >= left;
                    if (result.outcome) {
                        totalTrueWeight += 1.0d;
                        if (atOrAboveLeft) {
                            remainingTrueWeight += 1.0d;
                        }
                    } else {
                        totalFalseWeight += 1.0d;
                        if (atOrAboveLeft) {
                            remainingFalseWeight += 1.0d;
                        }
                    }
                }
                bucket++;
                return new BucketResult(left, right, totalTrueWeight, totalFalseWeight, remainingTrueWeight, remainingFalseWeight);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /** Lower bound of bucket {@code bucket}; floating-point division so bounds are fractions of 1. */
    private static double bucketLowerBound(int bucket) {
        return (double) bucket / NUM_BINS;
    }

    @Override
    protected String getFunctionName() {
        return this.functionName;
    }

    // NOTE(review): the generated bucket column is BIGINT while this declares "integer";
    // presumably resolution coerces integer to bigint — confirm before changing either side.
    @Override
    protected List<String> getFunctionParameterTypes() {
        return ImmutableList.of("integer", "boolean", "double");
    }

    /**
     * Builds the sample at {@code position}.
     * <p>
     * The prediction is {@code (position - start) / (length + 1)}. The outcome is {@code false}
     * when the prediction lies in [{@code MIN_FALSE_PRED}, {@code MAX_FALSE_PRED}] and {@code true} otherwise.
     */
    protected static Result getResult(int start, int length, int position) {
        double prediction = (double) (position - start) / (length + 1);
        boolean outcome = prediction < MIN_FALSE_PRED || prediction > MAX_FALSE_PRED;
        return new Result(outcome, prediction);
    }
}
