package de.julielab.ml.embeddings.spi;

import de.julielab.ml.embeddings.spi.EmbeddingVectors;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:de/julielab/ml/embeddings/spi/SkipVectorIndexStream.class */
/**
 * A forward-only cursor over the components of the embedding vectors of a list of query words.
 *
 * <p>In {@code CONCATENATION} mode the stream has {@code queryWords.size() * embeddingDimensions}
 * positions. Position {@code p} belongs to word {@code p / embeddingDimensions} and to component
 * {@code p % embeddingDimensions} of that word's vector.
 *
 * <p>In {@code AGGREGATION} mode the stream has {@code embeddingDimensions} positions, which are
 * all read from row 0 of the vector matrix.
 *
 * <p>Positions that belong to a word without a vector (one whose index is not in
 * {@code foundIndices}) yield the surrogate value instead of a vector component.
 *
 * <p>The cursor starts before the first position (index -1); call {@link #forward()} to advance.
 */
public class SkipVectorIndexStream {
    /** The query words, in stream order. */
    private final List<String> queryWords;
    /** Number of components of a single embedding vector. */
    private final int embeddingDimensions;
    /** Vectors of the found words, one per row (see {@link #foundIndices}). */
    private final INDArray vectors;
    /** Value returned for positions of words that have no vector. */
    private final double surrogateValue;
    /** Ascending indices into {@link #queryWords} of the words that have a row in {@link #vectors}. */
    private final List<Integer> foundIndices;
    /** Total number of positions in the stream. */
    private final int streamLength;
    private final EmbeddingVectors.StreamType type;
    /** Current position; -1 before the first call to {@link #forward()}. */
    private int currentIndex;
    private String currentWord;
    /** Word index whose vector is being read (0 in aggregation mode), or -1 to emit the surrogate. */
    private int currentFoundIndex;
    /** Row of {@link #vectors} to read; in concatenation mode the raw binary-search result. */
    private int indexOfCurrentFoundIndex;

    /**
     * Creates a stream positioned before its first element.
     *
     * @param list the query words, in stream order
     * @param list2 indices into {@code list} of the words that have a vector. The list must be
     *     sorted ascending, because it is searched with {@link Collections#binarySearch(List, Object)}.
     *     In concatenation mode, row {@code k} of {@code iNDArray} is read for word
     *     {@code list2.get(k)}.
     * @param iNDArray the vectors, one per row
     * @param i the number of components of a single embedding vector
     * @param d the value returned for positions of words without a vector
     * @param streamType whether the word vectors are concatenated or a single aggregated row is read
     */
    public SkipVectorIndexStream(List<String> list, List<Integer> list2, INDArray iNDArray, int i, double d, EmbeddingVectors.StreamType streamType) {
        this.queryWords = list;
        this.foundIndices = list2;
        this.vectors = iNDArray;
        this.embeddingDimensions = i;
        this.surrogateValue = d;
        this.type = streamType;
        this.streamLength = streamType == EmbeddingVectors.StreamType.CONCATENATION ? list.size() * i : i;
        // Start before the first position; the first forward() moves to position 0.
        this.currentIndex = -1;
        setState();
    }

    /**
     * @return {@code true} if a call to {@link #forward()} would move to a valid position
     */
    public boolean hasNext() {
        return this.currentIndex + 1 < this.streamLength;
    }

    /**
     * Advances the cursor by one position.
     *
     * @return {@code true} if the new position lies within the stream, {@code false} otherwise
     */
    public boolean forward() {
        this.currentIndex++;
        if (this.currentIndex >= this.streamLength) {
            return false;
        }
        // A new word begins every embeddingDimensions positions; refresh the word and vector lookup.
        if (this.currentIndex % this.embeddingDimensions == 0) {
            setState();
        }
        return true;
    }

    /**
     * Recomputes the current word and the vector row for the current position.
     *
     * <p>This is called by the constructor and by {@link #forward()} whenever a new word begins.
     * For stream types other than CONCATENATION and AGGREGATION the state is left unchanged.
     */
    public void setState() {
        if (this.type == EmbeddingVectors.StreamType.CONCATENATION) {
            int wordIndex = currentWordIndex();
            if (!this.queryWords.isEmpty()) {
                this.currentWord = this.queryWords.get(wordIndex);
            }
            this.indexOfCurrentFoundIndex = Collections.binarySearch(this.foundIndices, wordIndex);
            this.currentFoundIndex = this.indexOfCurrentFoundIndex >= 0
                    ? this.foundIndices.get(this.indexOfCurrentFoundIndex)
                    : -1;
        } else if (this.type == EmbeddingVectors.StreamType.AGGREGATION) {
            this.currentWord = "<aggregation of " + String.join(", ", this.queryWords) + ">";
            // The aggregated vector is read from row 0. Set this explicitly rather than relying on
            // the field's default value.
            this.indexOfCurrentFoundIndex = 0;
            this.currentFoundIndex = this.foundIndices.isEmpty() ? -1 : 0;
        }
    }

    /**
     * Index of the query word that covers the current position.
     *
     * <p>The position is clamped to 0 so that the initial position -1 maps to word 0. Without the
     * clamp, {@code -1 / 1 == -1} would make the constructor fail when the vectors have a single
     * dimension. A non-positive dimension count maps to word 0 instead of dividing by zero.
     */
    private int currentWordIndex() {
        if (this.embeddingDimensions <= 0) {
            return 0;
        }
        return Math.max(this.currentIndex, 0) / this.embeddingDimensions;
    }

    /**
     * @return the current position; -1 before the first call to {@link #forward()}
     */
    public int getIndex() {
        return this.currentIndex;
    }

    /**
     * Returns the value at the current position.
     *
     * <p>This is only meaningful after {@link #forward()} has returned {@code true}.
     *
     * @return the matching component of the current word's vector, or the surrogate value if the
     *     current word has no vector
     */
    public double getValue() {
        if (this.currentFoundIndex == -1) {
            return this.surrogateValue;
        }
        // Offset of the current position within the span covered by the current word.
        int component = this.currentIndex - this.currentFoundIndex * this.embeddingDimensions;
        return this.vectors.getRow(this.indexOfCurrentFoundIndex).getDouble(component);
    }

    /**
     * Returns the word for the current position.
     *
     * @return the current query word. In aggregation mode this is a label listing all query words.
     *     It may be {@code null} if no word was ever assigned, for example when the query word list
     *     is empty.
     */
    public String getWord() {
        return this.currentWord;
    }
}
