package com.facebook.presto.operator;

import com.facebook.presto.Session;
import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.DoubleType;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.testing.TestingSession;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/operator/TestGroupByHash.class */
public class TestGroupByHash {
    private static final int MAX_GROUP_ID = 500;
    private static final int[] CONTAINS_CHANNELS = {0};
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private static final JoinCompiler JOIN_COMPILER = new JoinCompiler();

    @Test
    public void testAddPage() throws Exception {
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(BigintType.BIGINT), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        int i = 0;
        while (i < 2) {
            for (int i2 = 0; i2 < MAX_GROUP_ID; i2++) {
                Block createLongsBlock = BlockAssertions.createLongsBlock(i2);
                Page page = new Page(new Block[]{createLongsBlock, TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongsBlock})});
                for (int i3 = 0; i3 < 10; i3++) {
                    createGroupByHash.addPage(page);
                    Assert.assertEquals(createGroupByHash.getGroupCount(), i == 0 ? i2 + 1 : MAX_GROUP_ID);
                    GroupByIdBlock groupIds = createGroupByHash.getGroupIds(page);
                    Assert.assertEquals(createGroupByHash.getGroupCount(), i == 0 ? i2 + 1 : MAX_GROUP_ID);
                    Assert.assertEquals(groupIds.getGroupCount(), i == 0 ? i2 + 1 : 500L);
                    Assert.assertEquals(groupIds.getPositionCount(), 1);
                    Assert.assertEquals(groupIds.getGroupId(0), i2);
                }
            }
            i++;
        }
    }

    @Test
    public void testNullGroup() throws Exception {
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(BigintType.BIGINT), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        Block createLongsBlock = BlockAssertions.createLongsBlock((Long) null);
        createGroupByHash.addPage(new Page(new Block[]{createLongsBlock, TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongsBlock})}));
        Block createLongSequenceBlock = BlockAssertions.createLongSequenceBlock(1, 132748);
        createGroupByHash.addPage(new Page(new Block[]{createLongSequenceBlock, TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongSequenceBlock})}));
        Block createLongsBlock2 = BlockAssertions.createLongsBlock(0);
        Assert.assertFalse(createGroupByHash.contains(0, new Page(new Block[]{createLongsBlock2, TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongsBlock2})}), CONTAINS_CHANNELS));
    }

    @Test
    public void testGetGroupIds() throws Exception {
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(BigintType.BIGINT), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        int i = 0;
        while (i < 2) {
            for (int i2 = 0; i2 < MAX_GROUP_ID; i2++) {
                Block createLongsBlock = BlockAssertions.createLongsBlock(i2);
                Page page = new Page(new Block[]{createLongsBlock, TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongsBlock})});
                for (int i3 = 0; i3 < 10; i3++) {
                    GroupByIdBlock groupIds = createGroupByHash.getGroupIds(page);
                    Assert.assertEquals(groupIds.getGroupCount(), i == 0 ? i2 + 1 : 500L);
                    Assert.assertEquals(groupIds.getPositionCount(), 1);
                    Assert.assertEquals(groupIds.getGroupId(0), i2);
                }
            }
            i++;
        }
    }

    @Test
    public void testTypes() throws Exception {
        Assert.assertEquals(GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(VarcharType.VARCHAR), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER).getTypes(), ImmutableList.of(VarcharType.VARCHAR, BigintType.BIGINT));
    }

    @Test
    public void testAppendTo() throws Exception {
        Block createStringSequenceBlock = BlockAssertions.createStringSequenceBlock(0, 100);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(VarcharType.VARCHAR), new Block[]{createStringSequenceBlock});
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(VarcharType.VARCHAR), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        GroupByIdBlock groupIds = createGroupByHash.getGroupIds(new Page(new Block[]{createStringSequenceBlock, hashBlock}));
        for (int i = 0; i < groupIds.getPositionCount(); i++) {
            Assert.assertEquals(groupIds.getGroupId(i), i);
        }
        Assert.assertEquals(createGroupByHash.getGroupCount(), 100);
        PageBuilder pageBuilder = new PageBuilder(createGroupByHash.getTypes());
        for (int i2 = 0; i2 < createGroupByHash.getGroupCount(); i2++) {
            pageBuilder.declarePosition();
            createGroupByHash.appendValuesTo(i2, pageBuilder, 0);
        }
        Page build = pageBuilder.build();
        for (int i3 = 0; i3 < createGroupByHash.getTypes().size(); i3++) {
            Assert.assertEquals(build.getBlock(i3).getPositionCount(), 100);
        }
        Assert.assertEquals(build.getPositionCount(), 100);
        BlockAssertions.assertBlockEquals(VarcharType.VARCHAR, build.getBlock(0), createStringSequenceBlock);
        BlockAssertions.assertBlockEquals(BigintType.BIGINT, build.getBlock(1), hashBlock);
    }

    @Test
    public void testAppendToMultipleTuplesPerGroup() throws Exception {
        ArrayList arrayList = new ArrayList();
        long j = 0;
        while (true) {
            long j2 = j;
            if (j2 >= 100) {
                break;
            }
            arrayList.add(Long.valueOf(j2 % 50));
            j = j2 + 1;
        }
        Block createLongsBlock = BlockAssertions.createLongsBlock(arrayList);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(BigintType.BIGINT), new Block[]{createLongsBlock});
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(BigintType.BIGINT), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        createGroupByHash.getGroupIds(new Page(new Block[]{createLongsBlock, hashBlock}));
        Assert.assertEquals(createGroupByHash.getGroupCount(), 50);
        PageBuilder pageBuilder = new PageBuilder(createGroupByHash.getTypes());
        for (int i = 0; i < createGroupByHash.getGroupCount(); i++) {
            pageBuilder.declarePosition();
            createGroupByHash.appendValuesTo(i, pageBuilder, 0);
        }
        Page build = pageBuilder.build();
        Assert.assertEquals(build.getPositionCount(), 50);
        BlockAssertions.assertBlockEquals(BigintType.BIGINT, build.getBlock(0), BlockAssertions.createLongSequenceBlock(0, 50));
    }

    @Test
    public void testContains() throws Exception {
        Block createDoubleSequenceBlock = BlockAssertions.createDoubleSequenceBlock(0, 10);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(DoubleType.DOUBLE), new Block[]{createDoubleSequenceBlock});
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(DoubleType.DOUBLE), new int[]{0}, Optional.of(1), 100, JOIN_COMPILER);
        createGroupByHash.getGroupIds(new Page(new Block[]{createDoubleSequenceBlock, hashBlock}));
        Block createDoublesBlock = BlockAssertions.createDoublesBlock(Double.valueOf(3.0d));
        Assert.assertTrue(createGroupByHash.contains(0, new Page(new Block[]{createDoublesBlock, TypeUtils.getHashBlock(ImmutableList.of(DoubleType.DOUBLE), new Block[]{createDoublesBlock})}), CONTAINS_CHANNELS));
        Block createDoublesBlock2 = BlockAssertions.createDoublesBlock(Double.valueOf(11.0d));
        Assert.assertFalse(createGroupByHash.contains(0, new Page(new Block[]{createDoublesBlock2, TypeUtils.getHashBlock(ImmutableList.of(DoubleType.DOUBLE), new Block[]{createDoublesBlock2})}), CONTAINS_CHANNELS));
    }

    @Test
    public void testContainsMultipleColumns() throws Exception {
        Block createDoubleSequenceBlock = BlockAssertions.createDoubleSequenceBlock(0, 10);
        Block createStringSequenceBlock = BlockAssertions.createStringSequenceBlock(0, 10);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(DoubleType.DOUBLE, VarcharType.VARCHAR), new Block[]{createDoubleSequenceBlock, createStringSequenceBlock});
        int[] iArr = {0, 1};
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(DoubleType.DOUBLE, VarcharType.VARCHAR), iArr, Optional.of(2), 100, JOIN_COMPILER);
        createGroupByHash.getGroupIds(new Page(new Block[]{createDoubleSequenceBlock, createStringSequenceBlock, hashBlock}));
        Block createDoublesBlock = BlockAssertions.createDoublesBlock(Double.valueOf(3.0d));
        Block createStringsBlock = BlockAssertions.createStringsBlock("3");
        Assert.assertTrue(createGroupByHash.contains(0, new Page(new Block[]{createDoublesBlock, createStringsBlock, TypeUtils.getHashBlock(ImmutableList.of(DoubleType.DOUBLE, VarcharType.VARCHAR), new Block[]{createDoublesBlock, createStringsBlock})}), iArr));
    }

    @Test
    public void testForceRehash() throws Exception {
        Block createStringSequenceBlock = BlockAssertions.createStringSequenceBlock(0, 100);
        Block hashBlock = TypeUtils.getHashBlock(ImmutableList.of(VarcharType.VARCHAR), new Block[]{createStringSequenceBlock});
        GroupByHash createGroupByHash = GroupByHash.createGroupByHash(TEST_SESSION, ImmutableList.of(VarcharType.VARCHAR), new int[]{0}, Optional.of(1), 4, JOIN_COMPILER);
        createGroupByHash.getGroupIds(new Page(new Block[]{createStringSequenceBlock, hashBlock}));
        for (int i = 0; i < createStringSequenceBlock.getPositionCount(); i++) {
            Assert.assertTrue(createGroupByHash.contains(i, new Page(new Block[]{createStringSequenceBlock, hashBlock}), CONTAINS_CHANNELS));
        }
    }
}
