package com.facebook.presto.orc.stream;

import com.facebook.presto.memory.context.AggregatedMemoryContext;
import com.facebook.presto.orc.OrcCorruptionException;
import com.facebook.presto.orc.OrcDataSourceId;
import com.facebook.presto.spi.type.Decimals;
import com.facebook.presto.spi.type.UnscaledDecimal128Arithmetic;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.math.BigInteger;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Round-trip tests for {@code DecimalInputStream}.
 *
 * <p>Each value is encoded into a byte array by {@code writeBigInteger}, wrapped in an
 * {@code OrcInputStream}, and read back through {@code nextLong}, {@code nextLongDecimal}
 * or {@code skip}.
 */
public class TestDecimalStream {
    /** 2^127 - 1, i.e. a value with exactly bits 0 through 126 set. */
    private static final BigInteger BIG_INTEGER_127_BIT_SET =
            BigInteger.ONE.shiftLeft(127).subtract(BigInteger.ONE);

    @Test
    public void testShortDecimals() throws IOException {
        long[] samples = {0L, 1L, -1L, 256L, -256L, Long.MAX_VALUE, Long.MIN_VALUE};
        for (long sample : samples) {
            assertReadsShortValue(sample);
        }
    }

    /** 2^63 is one past {@link Long#MAX_VALUE}, so {@code nextLong} is expected to reject it. */
    @Test
    public void testShouldFailWhenShortDecimalDoesNotFit() {
        assertShortValueReadFails(BigInteger.ONE.shiftLeft(63));
    }

    /** 2^127 and -2^128 both lie outside the signed 128-bit range [-2^127, 2^127 - 1]. */
    @Test
    public void testShouldFailWhenExceeds128Bits() {
        assertLongValueReadFails(BigInteger.ONE.shiftLeft(127));
        assertLongValueReadFails(BigInteger.ONE.shiftLeft(128).negate());
    }

    @Test
    public void testLongDecimals() throws IOException {
        BigInteger[] samples = {
                BigInteger.ZERO,
                BigInteger.ONE,
                BigInteger.ONE.negate(),
                BigInteger.ONE.shiftLeft(126).negate(),
                BigInteger.ONE.shiftLeft(126),
                BIG_INTEGER_127_BIT_SET,
                BIG_INTEGER_127_BIT_SET.negate(),
                Decimals.MAX_DECIMAL_UNSCALED_VALUE,
                Decimals.MIN_DECIMAL_UNSCALED_VALUE,
        };
        for (BigInteger sample : samples) {
            assertReadsLongValue(sample);
        }
    }

    /** Skipping the first of two encoded values must leave the stream positioned on the second. */
    @Test
    public void testSkipsValue() throws IOException {
        ByteArrayOutputStream encoded = new ByteArrayOutputStream();
        writeBigInteger(encoded, BigInteger.valueOf(Long.MAX_VALUE));
        writeBigInteger(encoded, BigInteger.valueOf(Long.MIN_VALUE));

        DecimalInputStream stream = new DecimalInputStream(toOrcInputStream("skip test", encoded.toByteArray()));
        stream.skip(1L);
        Assert.assertEquals(stream.nextLong(), Long.MIN_VALUE);
    }

    /** Encodes {@code expected}, decodes it with {@code nextLong}, and checks the round trip. */
    private static void assertReadsShortValue(long expected) throws IOException {
        DecimalInputStream stream = new DecimalInputStream(encode(BigInteger.valueOf(expected)));
        long actual = stream.nextLong();
        Assert.assertEquals(actual, expected);
    }

    /** Encodes {@code expected}, decodes it into a 128-bit slice, and checks the round trip. */
    private static void assertReadsLongValue(BigInteger expected) throws IOException {
        Slice decoded = UnscaledDecimal128Arithmetic.unscaledDecimal();
        new DecimalInputStream(encode(expected)).nextLongDecimal(decoded);
        BigInteger actual = UnscaledDecimal128Arithmetic.unscaledDecimalToBigInteger(decoded);
        Assert.assertEquals(actual, expected);
    }

    /** Asserts that {@code nextLong} throws {@link OrcCorruptionException} for {@code value}. */
    private static void assertShortValueReadFails(BigInteger value) {
        Assert.assertThrows(OrcCorruptionException.class,
                () -> new DecimalInputStream(encode(value)).nextLong());
    }

    /** Asserts that {@code nextLongDecimal} throws {@link OrcCorruptionException} for {@code value}. */
    private static void assertLongValueReadFails(BigInteger value) {
        Slice target = UnscaledDecimal128Arithmetic.unscaledDecimal();
        Assert.assertThrows(OrcCorruptionException.class,
                () -> new DecimalInputStream(encode(value)).nextLongDecimal(target));
    }

    /** Encodes a single value and exposes it as an {@code OrcInputStream} named after the value. */
    private static OrcInputStream encode(BigInteger value) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        writeBigInteger(buffer, value);
        return toOrcInputStream(value.toString(), buffer.toByteArray());
    }

    /** Wraps raw bytes in an {@code OrcInputStream} with no decompressor and a fresh memory context. */
    private static OrcInputStream toOrcInputStream(String sourceName, byte[] bytes) {
        return new OrcInputStream(
                new OrcDataSourceId(sourceName),
                new BasicSliceInput(Slices.wrappedBuffer(bytes)),
                Optional.empty(),
                AggregatedMemoryContext.newSimpleAggregatedMemoryContext());
    }

    /**
     * Writes {@code value} as a zigzag-mapped, little-endian base-128 varint of unbounded length.
     *
     * <p>The value is first zigzag-mapped: {@code v << 1} when non-negative, {@code ~(v << 1)}
     * when negative, so the result is always non-negative. It is then emitted 7 bits per byte,
     * lowest bits first, with {@code 0x80} set on every byte except the last. Bits are pulled
     * from the BigInteger 63 at a time (at most 9 seven-bit groups per chunk).
     */
    private static void writeBigInteger(OutputStream output, BigInteger value) throws IOException {
        BigInteger remaining = value.shiftLeft(1);
        if (remaining.signum() < 0) {
            // ~x == -x - 1 for BigInteger, which maps negatives onto the odd non-negatives
            remaining = remaining.not();
        }
        int bitsLeft = remaining.bitLength();
        for (;;) {
            long chunk = remaining.longValue() & 0x7fff_ffff_ffff_ffffL;
            bitsLeft -= 63;
            for (int group = 0; group < 9; group++) {
                // the final byte is reached once no higher chunks remain and this one fits in 7 bits
                boolean lastByte = bitsLeft <= 0 && (chunk & ~0x7fL) == 0;
                if (lastByte) {
                    output.write((byte) chunk);
                    return;
                }
                output.write((byte) (0x80 | (chunk & 0x7f)));
                chunk >>>= 7;
            }
            remaining = remaining.shiftRight(63);
        }
    }
}
