package com.facebook.presto.operator;

import com.facebook.presto.spi.Page;
import com.facebook.presto.spi.PageBuilder;
import com.facebook.presto.spi.block.Block;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.gen.JoinCompiler;
import com.facebook.presto.util.array.LongBigArray;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.slice.SizeOf;
import it.unimi.dsi.fastutil.HashCommon;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/facebook/presto/operator/MultiChannelGroupByHash.class */
/**
 * {@link GroupByHash} that assigns dense, sequential group ids ({@code 0, 1, 2, ...}) to distinct
 * combinations of values found in one or more hash channels.
 *
 * <p>Storage layout:
 * <ul>
 *   <li>The key values of each new group are copied into the current {@link PageBuilder}; when it is
 *       full a new one is started. {@code channelBuilders} holds, per stored channel, that channel's
 *       block builder for every page started so far, and is what {@code hashStrategy} reads from.</li>
 *   <li>A stored row is identified by a synthetic address encoding (page index, position) via
 *       {@link SyntheticAddress}.</li>
 *   <li>{@code groupAddressByHash} / {@code groupIdsByHash} form an open-addressing table with linear
 *       probing; a slot holding {@code EMPTY_SLOT} is unused. The table is grown once the group count
 *       reaches {@code maxFill}.</li>
 *   <li>{@code groupAddressByGroupId} maps a group id back to its synthetic address.</li>
 * </ul>
 *
 * <p>When an input hash channel is supplied, the precomputed hash is stored as an extra trailing
 * BIGINT channel so it can be reused when rehashing instead of being recomputed.
 *
 * <p>No synchronization is performed. NOTE(review): presumably confined to a single thread; confirm.
 */
public class MultiChannelGroupByHash implements GroupByHash {
    private static final JoinCompiler JOIN_COMPILER = new JoinCompiler();
    private static final float FILL_RATIO = 0.9f;
    // Value stored in groupAddressByHash for a slot that holds no group.
    private static final long EMPTY_SLOT = -1L;
    // Initial capacity of each per-channel list of page blocks; the lists grow as pages are added.
    private static final int INITIAL_BLOCK_LIST_CAPACITY = 1024;

    // Types of the stored channels: the hash types, plus a trailing BIGINT when a hash is precomputed.
    private final List<Type> types;
    // Input-page channel indexes that make up the grouping key.
    private final int[] channels;
    private final PagesHashStrategy hashStrategy;
    // One list per stored channel, holding that channel's block builder for every page started so far.
    private final List<ObjectArrayList<Block>> channelBuilders;
    private final HashGenerator hashGenerator;
    // Index of the stored precomputed-hash channel, present only when the input supplies a hash.
    private final Optional<Integer> precomputedHashChannel;
    // Input-page channel of the BOOLEAN row mask, or -1 when no mask is used.
    private final int maskChannel;

    private PageBuilder currentPageBuilder;
    // Retained size of all page builders that have already been filled and replaced.
    private long completedPagesMemorySize;

    private int maxFill;
    private int mask;
    private long[] groupAddressByHash;
    private int[] groupIdsByHash;

    private final LongBigArray groupAddressByGroupId;
    private int nextGroupId;

    /**
     * Creates an empty hash.
     *
     * @param hashTypes types of the grouping-key channels
     * @param hashChannels input-page channel index for each entry of {@code hashTypes}
     * @param maskChannel optional BOOLEAN channel; when present, only positions whose mask value is
     *        true are grouped
     * @param inputHashChannel optional input channel carrying a precomputed hash of the key
     * @param expectedSize expected number of groups; used to size the initial hash table
     * @throws NullPointerException if any reference argument is null
     * @throws IllegalArgumentException if {@code hashTypes} and {@code hashChannels} differ in size,
     *         or {@code expectedSize} is not positive
     */
    public MultiChannelGroupByHash(List<? extends Type> hashTypes, int[] hashChannels, Optional<Integer> maskChannel, Optional<Integer> inputHashChannel, int expectedSize) {
        Objects.requireNonNull(hashTypes, "hashTypes is null");
        // Must be checked before hashChannels.length is read below.
        Objects.requireNonNull(hashChannels, "hashChannels is null");
        Preconditions.checkArgument(hashTypes.size() == hashChannels.length, "hashTypes and hashChannels have different sizes");
        Objects.requireNonNull(inputHashChannel, "inputHashChannel is null");
        Objects.requireNonNull(maskChannel, "maskChannel is null");
        Preconditions.checkArgument(expectedSize > 0, "expectedSize must be greater than zero");

        this.types = inputHashChannel.isPresent()
                ? ImmutableList.copyOf(Iterables.concat(hashTypes, ImmutableList.of(BigintType.BIGINT)))
                : ImmutableList.copyOf(hashTypes);
        this.channels = hashChannels.clone();
        this.maskChannel = maskChannel.orElse(-1);
        this.hashGenerator = inputHashChannel.isPresent()
                ? new PrecomputedHashGenerator(inputHashChannel.get())
                : new InterpretedHashGenerator(ImmutableList.copyOf(hashTypes), hashChannels);

        // Stored channels 0..n-1 hold the key values, in the same order as hashChannels.
        ImmutableList.Builder<Integer> keyChannels = ImmutableList.builder();
        ImmutableList.Builder<ObjectArrayList<Block>> blockLists = ImmutableList.builder();
        for (int channel = 0; channel < hashChannels.length; channel++) {
            keyChannels.add(channel);
            blockLists.add(ObjectArrayList.wrap(new Block[INITIAL_BLOCK_LIST_CAPACITY], 0));
        }
        if (inputHashChannel.isPresent()) {
            // The precomputed hash is stored right after the key channels.
            this.precomputedHashChannel = Optional.of(hashChannels.length);
            blockLists.add(ObjectArrayList.wrap(new Block[INITIAL_BLOCK_LIST_CAPACITY], 0));
        }
        else {
            this.precomputedHashChannel = Optional.empty();
        }
        this.channelBuilders = blockLists.build();
        this.hashStrategy = JOIN_COMPILER.compilePagesHashStrategyFactory(this.types, keyChannels.build())
                .createPagesHashStrategy(this.channelBuilders, this.precomputedHashChannel);

        startNewPage();

        int tableSize = HashCommon.arraySize(expectedSize, FILL_RATIO);
        this.maxFill = calculateMaxFill(tableSize);
        this.mask = tableSize - 1;
        this.groupAddressByHash = new long[tableSize];
        Arrays.fill(this.groupAddressByHash, EMPTY_SLOT);
        this.groupIdsByHash = new int[tableSize];
        this.groupAddressByGroupId = new LongBigArray();
        this.groupAddressByGroupId.ensureCapacity(this.maxFill);
    }

    /**
     * Returns an estimate of the memory retained by this hash: page storage, the per-channel block
     * lists, the hash table arrays, and the group-id-to-address array.
     */
    @Override // com.facebook.presto.operator.GroupByHash
    public long getEstimatedSize() {
        // All per-channel block lists grow in lockstep, so channel 0's backing array is representative.
        long blockListsSize = SizeOf.sizeOf(this.channelBuilders.get(0).elements()) * this.channelBuilders.size();
        return blockListsSize
                + this.completedPagesMemorySize
                + this.currentPageBuilder.getRetainedSizeInBytes()
                + SizeOf.sizeOf(this.groupAddressByHash)
                + SizeOf.sizeOf(this.groupIdsByHash)
                + this.groupAddressByGroupId.sizeOf();
    }

    /** Returns the stored channel types (key types, plus a trailing BIGINT if a hash is precomputed). */
    @Override // com.facebook.presto.operator.GroupByHash
    public List<Type> getTypes() {
        return this.types;
    }

    /** Returns the number of distinct groups seen so far. */
    @Override // com.facebook.presto.operator.GroupByHash
    public int getGroupCount() {
        return this.nextGroupId;
    }

    /**
     * Appends the stored key values of {@code groupId} to {@code pageBuilder}, starting at
     * {@code outputChannelOffset}.
     */
    @Override // com.facebook.presto.operator.GroupByHash
    public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset) {
        long address = this.groupAddressByGroupId.get(groupId);
        this.hashStrategy.appendTo(SyntheticAddress.decodeSliceIndex(address), SyntheticAddress.decodePosition(address), pageBuilder, outputChannelOffset);
    }

    /** Registers every selected (unmasked) position of {@code page} as a group, if not already present. */
    @Override // com.facebook.presto.operator.GroupByHash
    public void addPage(Page page) {
        Block[] hashColumns = extractHashColumns(page);
        Block maskBlock = getMaskBlock(page);
        int positionCount = page.getPositionCount();
        for (int position = 0; position < positionCount; position++) {
            if (isSelected(maskBlock, position)) {
                putIfAbsent(position, page, hashColumns);
            }
        }
    }

    /**
     * Returns the group id of every position of {@code page}, adding new groups as needed.
     * Positions excluded by the mask channel get a null entry.
     */
    @Override // com.facebook.presto.operator.GroupByHash
    public GroupByIdBlock getGroupIds(Page page) {
        int positionCount = page.getPositionCount();
        BlockBuilder groupIds = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        Block maskBlock = getMaskBlock(page);
        Block[] hashColumns = extractHashColumns(page);
        for (int position = 0; position < positionCount; position++) {
            if (isSelected(maskBlock, position)) {
                BigintType.BIGINT.writeLong(groupIds, putIfAbsent(position, page, hashColumns));
            }
            else {
                groupIds.appendNull();
            }
        }
        return new GroupByIdBlock(this.nextGroupId, groupIds.build());
    }

    /** Returns whether the row at {@code position} of {@code page} already belongs to a group. */
    @Override // com.facebook.presto.operator.GroupByHash
    public boolean contains(int position, Page page) {
        // NOTE(review): unlike putIfAbsent, this hashes and compares using ALL of the page's blocks
        // rather than the configured `channels`; presumably callers pass a page whose blocks are
        // exactly the key columns in order. Confirm against callers.
        Block[] blocks = page.getBlocks();
        int slot = getHashPosition(this.hashStrategy.hashRow(position, blocks), this.mask);
        while (this.groupAddressByHash[slot] != EMPTY_SLOT) {
            long address = this.groupAddressByHash[slot];
            if (this.hashStrategy.positionEqualsRow(SyntheticAddress.decodeSliceIndex(address), SyntheticAddress.decodePosition(address), position, blocks)) {
                return true;
            }
            // Linear probing; the mask wraps around the end of the table.
            slot = (slot + 1) & this.mask;
        }
        return false;
    }

    /** Returns the group id of the row at {@code position}, creating a new group if it is absent. */
    @Override // com.facebook.presto.operator.GroupByHash
    public int putIfAbsent(int position, Page page) {
        return putIfAbsent(position, page, extractHashColumns(page));
    }

    /**
     * Looks up the row at {@code position} and returns its group id, adding a new group if needed.
     *
     * @param hashColumns the page's key blocks, in {@code channels} order
     */
    private int putIfAbsent(int position, Page page, Block[] hashColumns) {
        int rawHash = this.hashGenerator.hashPosition(position, page);
        int slot = getHashPosition(rawHash, this.mask);
        while (this.groupAddressByHash[slot] != EMPTY_SLOT) {
            long address = this.groupAddressByHash[slot];
            if (positionEqualsCurrentRow(SyntheticAddress.decodeSliceIndex(address), SyntheticAddress.decodePosition(address), position, hashColumns)) {
                return this.groupIdsByHash[slot];
            }
            slot = (slot + 1) & this.mask;
        }
        // Probing stopped on an empty slot: the key is new.
        return addNewGroup(slot, position, page, rawHash);
    }

    /**
     * Copies the key of the row at {@code position} into page storage, records it in the empty hash
     * {@code slot}, and returns the newly assigned group id. May start a new page and/or grow the table.
     */
    private int addNewGroup(int slot, int position, Page page, int rawHash) {
        Block[] blocks = page.getBlocks();
        for (int channel = 0; channel < this.channels.length; channel++) {
            this.types.get(channel).appendTo(blocks[this.channels[channel]], position, this.currentPageBuilder.getBlockBuilder(channel));
        }
        if (this.precomputedHashChannel.isPresent()) {
            BigintType.BIGINT.writeLong(this.currentPageBuilder.getBlockBuilder(this.precomputedHashChannel.get()), rawHash);
        }
        this.currentPageBuilder.declarePosition();

        // The current page is always the last entry of each per-channel block list.
        int pageIndex = this.channelBuilders.get(0).size() - 1;
        long address = SyntheticAddress.encodeSyntheticAddress(pageIndex, this.currentPageBuilder.getPositionCount() - 1);

        int groupId = this.nextGroupId++;
        this.groupAddressByHash[slot] = address;
        this.groupIdsByHash[slot] = groupId;
        this.groupAddressByGroupId.set(groupId, address);

        if (this.currentPageBuilder.isFull()) {
            startNewPage();
        }
        if (this.nextGroupId >= this.maxFill) {
            rehash(this.maxFill * 2);
        }
        return groupId;
    }

    /** Retires the current page builder (if any) and registers a fresh one in every channel's block list. */
    private void startNewPage() {
        if (this.currentPageBuilder != null) {
            this.completedPagesMemorySize += this.currentPageBuilder.getRetainedSizeInBytes();
        }
        this.currentPageBuilder = new PageBuilder(this.types);
        for (int channel = 0; channel < this.types.size(); channel++) {
            this.channelBuilders.get(channel).add(this.currentPageBuilder.getBlockBuilder(channel));
        }
    }

    /**
     * Rebuilds the hash table so that it can hold at least {@code size} groups, re-inserting every
     * existing (address, group id) pair. Stored page data and group ids are not changed.
     */
    private void rehash(int size) {
        int newTableSize = HashCommon.arraySize(size + 1, FILL_RATIO);
        int newMask = newTableSize - 1;
        long[] newAddressByHash = new long[newTableSize];
        Arrays.fill(newAddressByHash, EMPTY_SLOT);
        int[] newGroupIdsByHash = new int[newTableSize];

        int oldSlot = 0;
        for (int groupId = 0; groupId < this.nextGroupId; groupId++) {
            // Skip to the next occupied slot; exactly nextGroupId slots are occupied.
            while (this.groupAddressByHash[oldSlot] == EMPTY_SLOT) {
                oldSlot++;
            }
            long address = this.groupAddressByHash[oldSlot];

            // Linear-probe the new table until an empty slot is found.
            int newSlot = getHashPosition(hashPosition(address), newMask);
            while (newAddressByHash[newSlot] != EMPTY_SLOT) {
                newSlot = (newSlot + 1) & newMask;
            }

            newAddressByHash[newSlot] = address;
            newGroupIdsByHash[newSlot] = this.groupIdsByHash[oldSlot];
            oldSlot++;
        }

        this.mask = newMask;
        this.maxFill = calculateMaxFill(newTableSize);
        this.groupAddressByHash = newAddressByHash;
        this.groupIdsByHash = newGroupIdsByHash;
        this.groupAddressByGroupId.ensureCapacity(this.maxFill);
    }

    /** Returns the page's key blocks, ordered as in {@code channels}. */
    private Block[] extractHashColumns(Page page) {
        Block[] hashColumns = new Block[this.channels.length];
        for (int i = 0; i < this.channels.length; i++) {
            hashColumns[i] = page.getBlock(this.channels[i]);
        }
        return hashColumns;
    }

    /** Returns the mask block of {@code page}, or {@code null} when no mask channel is configured. */
    private Block getMaskBlock(Page page) {
        return this.maskChannel >= 0 ? page.getBlock(this.maskChannel) : null;
    }

    /** Returns true when there is no mask, or the mask value at {@code position} is true. */
    private static boolean isSelected(Block maskBlock, int position) {
        return maskBlock == null || BooleanType.BOOLEAN.getBoolean(maskBlock, position);
    }

    /**
     * Returns the raw (unmasked) hash of the stored row at {@code address}: the stored precomputed
     * hash when available, otherwise one recomputed by {@code hashStrategy}.
     */
    private int hashPosition(long address) {
        int pageIndex = SyntheticAddress.decodeSliceIndex(address);
        int position = SyntheticAddress.decodePosition(address);
        if (this.precomputedHashChannel.isPresent()) {
            return getRawHash(pageIndex, position);
        }
        return this.hashStrategy.hashPosition(pageIndex, position);
    }

    /** Reads the stored precomputed hash (truncated to int) for the given page and position. */
    private int getRawHash(int pageIndex, int position) {
        return (int) this.channelBuilders.get(this.precomputedHashChannel.get()).get(pageIndex).getLong(position, 0);
    }

    /** Compares the stored row at (pageIndex, position) with row {@code rowPosition} of {@code hashColumns}. */
    private boolean positionEqualsCurrentRow(int pageIndex, int position, int rowPosition, Block[] hashColumns) {
        return this.hashStrategy.positionEqualsRow(pageIndex, position, rowPosition, hashColumns);
    }

    /** Scrambles {@code rawHash} with murmur3 finalization and masks it down to a table slot. */
    private static int getHashPosition(int rawHash, int mask) {
        return HashCommon.murmurHash3(rawHash) & mask;
    }

    /**
     * Returns the group count at which a table of {@code hashSize} slots must be grown:
     * {@code ceil(hashSize * FILL_RATIO)}, reduced by one if that would equal {@code hashSize}.
     * Keeping at least one slot empty guarantees the linear-probing loops terminate.
     */
    private static int calculateMaxFill(int hashSize) {
        Preconditions.checkArgument(hashSize > 0, "hashSize must greater than 0");
        int maxFill = (int) Math.ceil(hashSize * FILL_RATIO);
        if (maxFill == hashSize) {
            maxFill--;
        }
        Preconditions.checkArgument(hashSize > maxFill, "hashSize must be larger than maxFill");
        return maxFill;
    }
}
