package com.facebook.presto.orc.writer;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.BlockBuilderStatus;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.orc.ColumnWriterOptions;
import com.facebook.presto.orc.OrcCorruptionException;
import com.facebook.presto.orc.OrcDataSourceId;
import com.facebook.presto.orc.OrcDecompressor;
import com.facebook.presto.orc.OrcEncoding;
import com.facebook.presto.orc.TestingHiveOrcAggregatedMemoryContext;
import com.facebook.presto.orc.TestingOrcPredicate;
import com.facebook.presto.orc.metadata.ColumnEncoding;
import com.facebook.presto.orc.metadata.CompressionKind;
import com.facebook.presto.orc.metadata.Stream;
import com.facebook.presto.orc.stream.ByteArrayInputStream;
import com.facebook.presto.orc.stream.LongInputStream;
import com.facebook.presto.orc.stream.LongInputStreamV1;
import com.facebook.presto.orc.stream.LongInputStreamV2;
import com.facebook.presto.orc.stream.OrcInputStream;
import com.facebook.presto.orc.stream.SharedBuffer;
import com.facebook.presto.orc.stream.StreamDataOutput;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.MoreCollectors;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link SliceDictionaryColumnWriter}: checks the order in which dictionary keys are
 * written (sorted, or insertion order when sorting is disabled) and that a conversion to direct
 * encoding is refused for very large data.
 */
public class TestSliceDictionaryColumnWriter {
    private static final int COLUMN_ID = 1;
    private static final OrcDataSourceId ORC_DATA_SOURCE_ID = new OrcDataSourceId("test");

    /**
     * Returns the single stream of the given kind.
     *
     * @throws java.util.NoSuchElementException if no stream of that kind exists
     * @throws IllegalArgumentException if more than one stream of that kind exists
     */
    private StreamDataOutput getStream(List<StreamDataOutput> streams, Stream.StreamKind streamKind) {
        return streams.stream()
                .filter(stream -> stream.getStream().getStreamKind() == streamKind)
                .collect(MoreCollectors.onlyElement());
    }

    /**
     * Creates a SNAPPY decompressor, matching the compression kind the writer under test is
     * configured with, using a 256kB size argument (presumably the max buffer size — TODO confirm).
     */
    private Optional<OrcDecompressor> getOrcDecompressor() {
        return OrcDecompressor.createOrcDecompressor(
                ORC_DATA_SOURCE_ID,
                CompressionKind.SNAPPY,
                Math.toIntExact(new DataSize(256.0d, DataSize.Unit.KILOBYTE).toBytes()));
    }

    /** Wraps the (compressed) bytes of {@code slice} in an {@link OrcInputStream} that decompresses them. */
    private OrcInputStream convertSliceToInputStream(Slice slice) {
        TestingHiveOrcAggregatedMemoryContext memoryContext = new TestingHiveOrcAggregatedMemoryContext();
        return new OrcInputStream(
                ORC_DATA_SOURCE_ID,
                new SharedBuffer(memoryContext.newOrcLocalMemoryContext("sharedDecompressionBuffer")),
                slice.getInput(),
                getOrcDecompressor(),
                Optional.empty(),
                memoryContext,
                slice.getRetainedSize());
    }

    /** Copies the bytes written by {@code streamDataOutput} into a new slice. */
    private Slice convertStreamToSlice(StreamDataOutput streamDataOutput) throws OrcCorruptionException {
        DynamicSliceOutput output = new DynamicSliceOutput(Math.toIntExact(streamDataOutput.size()));
        streamDataOutput.writeData(output);
        return output.slice();
    }

    /** Returns a readable input stream over the single stream of {@code streamKind}. */
    private OrcInputStream getOrcInputStream(List<StreamDataOutput> streams, Stream.StreamKind streamKind) throws OrcCorruptionException {
        return convertSliceToInputStream(convertStreamToSlice(getStream(streams, streamKind)));
    }

    /**
     * Returns a reader for the LENGTH stream: the V1 long reader for DWRF, the V2 reader otherwise.
     *
     * @throws OrcCorruptionException propagated from {@link #getOrcInputStream}
     */
    private LongInputStream getDictionaryLengthStream(List<StreamDataOutput> streams, OrcEncoding orcEncoding) throws OrcCorruptionException {
        OrcInputStream lengthInput = getOrcInputStream(streams, Stream.StreamKind.LENGTH);
        if (orcEncoding == OrcEncoding.DWRF) {
            return new LongInputStreamV1(lengthInput, false);
        }
        return new LongInputStreamV2(lengthInput, false, false);
    }

    /**
     * Writes {@code values} through a dictionary writer and reads back the dictionary keys in
     * the order they were written to the DICTIONARY_DATA stream.
     *
     * @param values the strings to write, split into row groups of at most
     *     {@code TestingOrcPredicate.ORC_ROW_GROUP_SIZE} rows
     * @param orcEncoding the file encoding to write with
     * @param sortDictionaryKeys whether string dictionary sorting is enabled on the writer
     * @return the dictionary keys as decoded from the written streams
     */
    private List<String> getDictionaryKeys(List<String> values, OrcEncoding orcEncoding, boolean sortDictionaryKeys) throws IOException {
        DictionaryColumnWriter writer = getDictionaryColumnWriter(orcEncoding, sortDictionaryKeys);

        for (int start = 0; start < values.size(); start += TestingOrcPredicate.ORC_ROW_GROUP_SIZE) {
            int end = Math.min(start + TestingOrcPredicate.ORC_ROW_GROUP_SIZE, values.size());
            BlockBuilder blockBuilder = VarcharType.VARCHAR.createBlockBuilder((BlockBuilderStatus) null, TestingOrcPredicate.ORC_ROW_GROUP_SIZE);
            for (int i = start; i < end; i++) {
                VarcharType.VARCHAR.writeSlice(blockBuilder, Slices.utf8Slice(values.get(i)));
            }
            writer.beginRowGroup();
            writer.writeBlock(blockBuilder);
            writer.finishRowGroup();
        }
        writer.close();

        List<StreamDataOutput> streams = writer.getDataStreams();
        ColumnEncoding columnEncoding = writer.getColumnEncodings().get(COLUMN_ID);
        int dictionarySize = columnEncoding.getDictionarySize();
        ByteArrayInputStream dictionaryData = new ByteArrayInputStream(getOrcInputStream(streams, Stream.StreamKind.DICTIONARY_DATA));
        LongInputStream dictionaryLengths = getDictionaryLengthStream(streams, orcEncoding);

        // Each key's byte length comes from the LENGTH stream; its bytes come from DICTIONARY_DATA.
        List<String> keys = new ArrayList<>(dictionarySize);
        for (int i = 0; i < dictionarySize; i++) {
            int length = Math.toIntExact(dictionaryLengths.next());
            keys.add(new String(dictionaryData.next(length), StandardCharsets.UTF_8));
        }
        return keys;
    }

    /**
     * Creates a SNAPPY-compressed VARCHAR dictionary writer for {@link #COLUMN_ID}.
     * For the ORC encoding, disabling sorting makes construction throw IllegalStateException
     * (see {@link #testOrcStringSortingDisabledThrows}).
     */
    private DictionaryColumnWriter getDictionaryColumnWriter(OrcEncoding orcEncoding, boolean sortDictionaryKeys) {
        ColumnWriterOptions options = ColumnWriterOptions.builder()
                .setCompressionKind(CompressionKind.SNAPPY)
                .setStringDictionarySortingEnabled(sortDictionaryKeys)
                .build();
        // NOTE(review): the second argument (0) is passed through unchanged; its meaning is defined
        // by SliceDictionaryColumnWriter's constructor — confirm there.
        return new SliceDictionaryColumnWriter(COLUMN_ID, 0, VarcharType.VARCHAR, options, Optional.empty(), orcEncoding, orcEncoding.createMetadataWriter());
    }

    @Test
    public void testSortedDictionaryKeys() throws IOException {
        for (OrcEncoding orcEncoding : OrcEncoding.values()) {
            Assert.assertEquals(getDictionaryKeys(ImmutableList.of("b", "a", "c"), orcEncoding, true), ImmutableList.of("a", "b", "c"));
            // duplicates collapse to a single dictionary entry
            Assert.assertEquals(getDictionaryKeys(ImmutableList.of("b", "b", "a"), orcEncoding, true), ImmutableList.of("a", "b"));
        }
    }

    // Only DWRF is tested: the ORC encoding rejects unsorted dictionaries (see testOrcStringSortingDisabledThrows).
    @Test
    public void testUnsortedDictionaryKeys() throws IOException {
        Assert.assertEquals(getDictionaryKeys(ImmutableList.of("b", "a", "c"), OrcEncoding.DWRF, false), ImmutableList.of("b", "a", "c"));
        Assert.assertEquals(getDictionaryKeys(ImmutableList.of("b", "b", "a"), OrcEncoding.DWRF, false), ImmutableList.of("b", "a"));
    }

    @Test(expectedExceptions = {IllegalStateException.class})
    public void testOrcStringSortingDisabledThrows() {
        getDictionaryColumnWriter(OrcEncoding.ORC, false);
    }

    /**
     * A single row group of 3000 copies of one random 1 MB value (about 3 GB once expanded) must
     * not be converted to direct encoding when {@code tryConvertToDirect} is given a 64 MB argument.
     */
    @Test
    public void testStringDirectConversion() {
        byte[] value = new byte[megabytes(1)];
        ThreadLocalRandom.current().nextBytes(value);
        Block data = RunLengthEncodedBlock.create(VarcharType.VARCHAR, Slices.wrappedBuffer(value), 3000);
        for (OrcEncoding orcEncoding : OrcEncoding.values()) {
            DictionaryColumnWriter writer = getDictionaryColumnWriter(orcEncoding, true);
            writer.beginRowGroup();
            writer.writeBlock(data);
            writer.finishRowGroup();
            Assert.assertFalse(writer.tryConvertToDirect(megabytes(64)).isPresent());
        }
    }

    /** Returns {@code count} megabytes as a byte count, failing on int overflow. */
    private static int megabytes(int count) {
        return Math.toIntExact(new DataSize(count, DataSize.Unit.MEGABYTE).toBytes());
    }
}
