package io.prestosql.orc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Longs;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.prestosql.orc.metadata.ColumnMetadata;
import io.prestosql.orc.metadata.CompressedMetadataWriter;
import io.prestosql.orc.metadata.CompressionKind;
import io.prestosql.orc.metadata.OrcColumnId;
import io.prestosql.orc.metadata.OrcMetadataReader;
import io.prestosql.orc.metadata.OrcMetadataWriter;
import io.prestosql.orc.metadata.statistics.BinaryStatistics;
import io.prestosql.orc.metadata.statistics.BloomFilter;
import io.prestosql.orc.metadata.statistics.BooleanStatistics;
import io.prestosql.orc.metadata.statistics.ColumnStatistics;
import io.prestosql.orc.metadata.statistics.DateStatistics;
import io.prestosql.orc.metadata.statistics.DecimalStatistics;
import io.prestosql.orc.metadata.statistics.DoubleStatistics;
import io.prestosql.orc.metadata.statistics.IntegerStatistics;
import io.prestosql.orc.metadata.statistics.StringStatistics;
import io.prestosql.orc.metadata.statistics.TimestampStatistics;
import io.prestosql.orc.metadata.statistics.Utf8BloomFilterBuilder;
import io.prestosql.orc.proto.OrcProto;
import io.prestosql.orc.protobuf.CodedInputStream;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.DateType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.RealType;
import io.prestosql.spi.type.SmallintType;
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.TinyintType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarbinaryType;
import io.prestosql.spi.type.VarcharType;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.orc.util.Murmur3;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/orc/TestOrcBloomFilters.class */
public class TestOrcBloomFilters {
    private static final int TEST_INTEGER = 12345;
    private static final byte[] TEST_STRING = "ORC_STRING".getBytes(StandardCharsets.UTF_8);
    private static final byte[] TEST_STRING_NOT_WRITTEN = "ORC_STRING_not".getBytes(StandardCharsets.UTF_8);

    /**
     * Sample predicate values keyed by their Java representation and mapped to the Presto type they stand for.
     * REAL values are stored as the float's int bits widened to a {@code long}.
     */
    private static final Map<Object, Type> TEST_VALUES = ImmutableMap.<Object, Type>builder()
            .put(Slices.wrappedBuffer(TEST_STRING), VarcharType.VARCHAR)
            .put(Slices.wrappedBuffer(new byte[]{12, 34, 56}), VarbinaryType.VARBINARY)
            .put(4312L, BigintType.BIGINT)
            .put(123, IntegerType.INTEGER)
            .put(789, SmallintType.SMALLINT)
            .put(77, TinyintType.TINYINT)
            .put(901, DateType.DATE)
            .put(987654L, TimestampType.TIMESTAMP_MILLIS)
            .put(234.567d, DoubleType.DOUBLE)
            .put((long) Float.floatToIntBits(987.654f), RealType.REAL)
            .build();

    /**
     * Verifies membership tests for strings and longs, both on the original filter and on a filter
     * rebuilt from its bit set and hash-function count.
     */
    @Test
    public void testHiveBloomFilterSerde() {
        BloomFilter bloomFilter = new BloomFilter(1_000_000L, 0.05d);

        // String values
        bloomFilter.add(TEST_STRING);
        Assert.assertTrue(bloomFilter.test(TEST_STRING));
        Assert.assertTrue(bloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING)));
        Assert.assertFalse(bloomFilter.test(TEST_STRING_NOT_WRITTEN));
        Assert.assertFalse(bloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN)));

        // Integer values
        bloomFilter.addLong(TEST_INTEGER);
        Assert.assertTrue(bloomFilter.testLong(TEST_INTEGER));
        Assert.assertFalse(bloomFilter.testLong(TEST_INTEGER + 1));

        // A filter rebuilt from the raw bits must answer identically
        BloomFilter rebuilt = new BloomFilter(bloomFilter.getBitSet(), bloomFilter.getNumHashFunctions());
        Assert.assertTrue(rebuilt.test(TEST_STRING));
        Assert.assertTrue(rebuilt.testSlice(Slices.wrappedBuffer(TEST_STRING)));
        Assert.assertFalse(rebuilt.test(TEST_STRING_NOT_WRITTEN));
        Assert.assertFalse(rebuilt.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN)));
        Assert.assertTrue(rebuilt.testLong(TEST_INTEGER));
        Assert.assertFalse(rebuilt.testLong(TEST_INTEGER + 1));
    }

    /**
     * Writes a bloom filter through {@link CompressedMetadataWriter}, then reads it back both with
     * {@link OrcMetadataReader} and directly via the generated ORC protobuf classes, and checks
     * that bits and hash-function count survive the round trip.
     */
    @Test
    public void testOrcHiveBloomFilterSerde() throws Exception {
        BloomFilter written = new BloomFilter(1000L, 0.05d);
        written.add(TEST_STRING);
        Assert.assertTrue(written.test(TEST_STRING));
        Assert.assertTrue(written.testSlice(Slices.wrappedBuffer(TEST_STRING)));

        Slice serialized = new CompressedMetadataWriter(new OrcMetadataWriter(true), CompressionKind.NONE, 1024)
                .writeBloomFilters(ImmutableList.of(written));

        // Round trip through the Presto metadata reader
        List<BloomFilter> readFilters = new OrcMetadataReader().readBloomFilterIndexes(serialized.getInput());
        Assert.assertEquals(readFilters.size(), 1);
        BloomFilter read = readFilters.get(0);
        Assert.assertTrue(read.test(TEST_STRING));
        Assert.assertTrue(read.testSlice(Slices.wrappedBuffer(TEST_STRING)));
        Assert.assertFalse(read.test(TEST_STRING_NOT_WRITTEN));
        Assert.assertFalse(read.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN)));
        Assert.assertEquals(written.getNumBits(), read.getNumBits());
        Assert.assertEquals(written.getNumHashFunctions(), read.getNumHashFunctions());
        Assert.assertTrue(Arrays.equals(read.getBitSet(), written.getBitSet()));

        // Decode the same bytes with the raw protobuf classes
        List<OrcProto.BloomFilter> protoFilters = OrcProto.BloomFilterIndex
                .parseFrom(CodedInputStream.newInstance(serialized.getBytes()))
                .getBloomFilterList();
        Assert.assertEquals(protoFilters.size(), 1);
        OrcProto.BloomFilter protoFilter = protoFilters.get(0);
        Assert.assertTrue(Arrays.equals(Longs.toArray(protoFilter.getBitsetList()), written.getBitSet()));
        Assert.assertEquals(written.getNumHashFunctions(), protoFilter.getNumHashFunctions());
        Assert.assertEquals(written.getBitSet().length, protoFilter.getBitsetCount());
    }

    /**
     * Adds every {@link #TEST_VALUES} key to a filter, using the add method matching its Java type,
     * and checks that {@code TupleDomainOrcPredicate.checkInBloomFilter} reports each one present.
     */
    @Test
    public void testBloomFilterPredicateValuesExisting() {
        BloomFilter bloomFilter = new BloomFilter(TEST_VALUES.size() * 10, 0.01d);

        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Object value = entry.getKey();
            if (value instanceof Long) {
                if (entry.getValue() instanceof RealType) {
                    // REAL keys hold float bits; the filter is populated with the decoded float value
                    bloomFilter.addDouble(Float.intBitsToFloat(((Number) value).intValue()));
                } else {
                    bloomFilter.addLong((Long) value);
                }
            } else if (value instanceof Integer) {
                bloomFilter.addLong((Integer) value);
            } else if (value instanceof String) {
                bloomFilter.add(((String) value).getBytes(StandardCharsets.UTF_8));
            } else if (value instanceof BigDecimal) {
                bloomFilter.add(value.toString().getBytes(StandardCharsets.UTF_8));
            } else if (value instanceof Slice) {
                bloomFilter.add(((Slice) value).getBytes());
            } else if (value instanceof Timestamp) {
                bloomFilter.addLong(((Timestamp) value).getTime());
            } else if (value instanceof Double) {
                bloomFilter.addDouble((Double) value);
            } else {
                Assert.fail("Unsupported type " + value.getClass());
            }
        }

        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Assert.assertTrue(
                    TupleDomainOrcPredicate.checkInBloomFilter(bloomFilter, entry.getKey(), entry.getValue()),
                    "type " + entry.getKey().getClass());
        }
    }

    /** Checks that an empty filter reports none of the {@link #TEST_VALUES} keys as present. */
    @Test
    public void testBloomFilterPredicateValuesNonExisting() {
        BloomFilter bloomFilter = new BloomFilter(TEST_VALUES.size() * 10, 0.01d);

        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Assert.assertFalse(
                    TupleDomainOrcPredicate.checkInBloomFilter(bloomFilter, entry.getKey(), entry.getValue()),
                    "type " + entry.getKey().getClass());
        }
    }

    /**
     * For each type, builds a single-value {@link Domain} and checks that
     * {@code extractDiscreteValues} returns exactly that one value.
     */
    @Test
    public void testExtractValuesFromSingleDomain() {
        Map<Type, Object> testValues = ImmutableMap.<Type, Object>builder()
                .put(BooleanType.BOOLEAN, true)
                .put(IntegerType.INTEGER, 1234L)
                .put(SmallintType.SMALLINT, 789L)
                .put(TinyintType.TINYINT, 77L)
                .put(DateType.DATE, 901L)
                .put(TimestampType.TIMESTAMP_MILLIS, 987654L)
                .put(BigintType.BIGINT, 4321L)
                .put(DoubleType.DOUBLE, 0.123d)
                .put(RealType.REAL, (long) Float.floatToIntBits(0.456f))
                .put(VarcharType.VARCHAR, Slices.wrappedBuffer(TEST_STRING))
                .build();

        for (Map.Entry<Type, Object> entry : testValues.entrySet()) {
            Optional<? extends Collection<?>> discreteValues = TupleDomainOrcPredicate.extractDiscreteValues(
                    Domain.singleValue(entry.getKey(), entry.getValue()).getValues());
            Assert.assertTrue(discreteValues.isPresent());
            Collection<?> values = discreteValues.get();
            Assert.assertEquals(values.size(), 1);
            Assert.assertEquals(values.iterator().next(), entry.getValue());
        }
    }

    /**
     * Checks predicate matching against column statistics whose integer range [10, 2000] admits 1234.
     * The bloom-filter-enabled predicate matches when the filter contains 1234 or when there is no
     * filter, and rejects an empty filter. An empty predicate always matches.
     */
    @Test
    public void testMatches() {
        TupleDomainOrcPredicate predicate = TupleDomainOrcPredicate.builder()
                .setBloomFiltersEnabled(true)
                .addColumn(OrcColumnId.ROOT_COLUMN, Domain.singleValue(BigintType.BIGINT, 1234L))
                .build();
        TupleDomainOrcPredicate emptyPredicate = TupleDomainOrcPredicate.builder().build();

        ColumnMetadata<ColumnStatistics> matchingFilter =
                integerColumnMetadata(new Utf8BloomFilterBuilder(1000, 0.01d).addLong(1234L).buildBloomFilter());
        ColumnMetadata<ColumnStatistics> emptyFilter =
                integerColumnMetadata(new Utf8BloomFilterBuilder(1000, 0.01d).buildBloomFilter());
        ColumnMetadata<ColumnStatistics> noFilter = integerColumnMetadata(null);

        Assert.assertTrue(predicate.matches(1L, matchingFilter));
        Assert.assertTrue(predicate.matches(1L, noFilter));
        Assert.assertFalse(predicate.matches(1L, emptyFilter));
        Assert.assertTrue(emptyPredicate.matches(1L, matchingFilter));
    }

    /**
     * Randomized cross-check against {@code org.apache.orc.util.BloomFilter}. With the same size and
     * false-positive rate and the same inserted values, both implementations must produce the same
     * membership answers and identical bit sets. The Apache filter has no float-specific
     * add/test methods, so floats go through its {@code double} overloads.
     */
    @Test
    public void testBloomFilterCompatibility() {
        for (int iteration = 0; iteration < 200; iteration++) {
            double fpp = ThreadLocalRandom.current().nextDouble(0.01d, 0.1d);
            int expectedEntries = ThreadLocalRandom.current().nextInt(100, TestingOrcPredicate.ORC_ROW_GROUP_SIZE);
            int entries = ThreadLocalRandom.current().nextInt(expectedEntries / 2, expectedEntries);

            BloomFilter prestoFilter = new BloomFilter(expectedEntries, fpp);
            org.apache.orc.util.BloomFilter orcFilter = new org.apache.orc.util.BloomFilter(expectedEntries, fpp);

            Assert.assertFalse(prestoFilter.test((byte[]) null));
            Assert.assertFalse(orcFilter.test((byte[]) null));

            byte[][] binaryValues = new byte[entries][];
            long[] longValues = new long[entries];
            double[] doubleValues = new double[entries];
            float[] floatValues = new float[entries];
            for (int i = 0; i < entries; i++) {
                binaryValues[i] = randomBytes(ThreadLocalRandom.current().nextInt(100));
                longValues[i] = ThreadLocalRandom.current().nextLong();
                doubleValues[i] = ThreadLocalRandom.current().nextDouble();
                floatValues[i] = ThreadLocalRandom.current().nextFloat();
            }

            // Nothing has been added yet
            for (int i = 0; i < entries; i++) {
                Assert.assertFalse(prestoFilter.test(binaryValues[i]));
                Assert.assertFalse(prestoFilter.testSlice(Slices.wrappedBuffer(binaryValues[i])));
                Assert.assertFalse(prestoFilter.testLong(longValues[i]));
                Assert.assertFalse(prestoFilter.testDouble(doubleValues[i]));
                Assert.assertFalse(prestoFilter.testFloat(floatValues[i]));

                Assert.assertFalse(orcFilter.test(binaryValues[i]));
                Assert.assertFalse(orcFilter.testLong(longValues[i]));
                Assert.assertFalse(orcFilter.testDouble(doubleValues[i]));
                Assert.assertFalse(orcFilter.testDouble(floatValues[i]));
            }

            for (int i = 0; i < entries; i++) {
                prestoFilter.add(binaryValues[i]);
                prestoFilter.addLong(longValues[i]);
                prestoFilter.addDouble(doubleValues[i]);
                prestoFilter.addFloat(floatValues[i]);

                orcFilter.add(binaryValues[i]);
                orcFilter.addLong(longValues[i]);
                orcFilter.addDouble(doubleValues[i]);
                orcFilter.addDouble(floatValues[i]);
            }

            // Every added value must now be reported present
            for (int i = 0; i < entries; i++) {
                Assert.assertTrue(prestoFilter.test(binaryValues[i]));
                Assert.assertTrue(prestoFilter.testSlice(Slices.wrappedBuffer(binaryValues[i])));
                Assert.assertTrue(prestoFilter.testLong(longValues[i]));
                Assert.assertTrue(prestoFilter.testDouble(doubleValues[i]));
                Assert.assertTrue(prestoFilter.testFloat(floatValues[i]));

                Assert.assertTrue(orcFilter.test(binaryValues[i]));
                Assert.assertTrue(orcFilter.testLong(longValues[i]));
                Assert.assertTrue(orcFilter.testDouble(doubleValues[i]));
                Assert.assertTrue(orcFilter.testDouble(floatValues[i]));
            }

            // null is a distinct, addable value
            prestoFilter.add((byte[]) null);
            orcFilter.add((byte[]) null);
            Assert.assertTrue(prestoFilter.test((byte[]) null));
            Assert.assertTrue(prestoFilter.testSlice((Slice) null));
            Assert.assertTrue(orcFilter.test((byte[]) null));

            Assert.assertEquals(prestoFilter.getBitSet(), orcFilter.getBitSet());
        }
    }

    /** Checks that the bundled Murmur3 64-bit hash agrees with Apache ORC's for random inputs of length 0..999. */
    @Test
    public void testHashCompatibility() {
        for (int length = 0; length < 1000; length++) {
            for (int attempt = 0; attempt < 100; attempt++) {
                byte[] bytes = randomBytes(length);
                Assert.assertEquals(BloomFilter.OrcMurmur3.hash64(bytes), Murmur3.hash64(bytes));
            }
        }
    }

    /**
     * Builds single-column metadata with integer statistics spanning [10, 2000] and the given bloom filter.
     *
     * @param bloomFilter the column's bloom filter; may be {@code null} to model a column without one
     */
    private static ColumnMetadata<ColumnStatistics> integerColumnMetadata(BloomFilter bloomFilter) {
        return new ColumnMetadata<>(ImmutableList.of(new ColumnStatistics(
                (Long) null,
                0L,
                (BooleanStatistics) null,
                new IntegerStatistics(10L, 2000L, (Long) null),
                (DoubleStatistics) null,
                (StringStatistics) null,
                (DateStatistics) null,
                (TimestampStatistics) null,
                (DecimalStatistics) null,
                (BinaryStatistics) null,
                bloomFilter)));
    }

    /** Returns {@code length} random bytes. */
    private static byte[] randomBytes(int length) {
        byte[] bytes = new byte[length];
        ThreadLocalRandom.current().nextBytes(bytes);
        return bytes;
    }
}
