package de.julielab.ml.embeddings;

import de.julielab.ml.embeddings.spi.EmbeddingVectors;
import de.julielab.ml.embeddings.spi.WordEmbedding;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.models.word2vec.Word2Vec;

/* loaded from: input_file:de/julielab/ml/embeddings/Dl4jWordEmbedding.class */
/**
 * {@link WordEmbedding} implementation backed by a DeepLearning4J {@link Word2Vec} model.
 *
 * <p>All lookups are delegated to the wrapped model and its vocabulary.
 */
public class Dl4jWordEmbedding implements WordEmbedding {
    private static final long serialVersionUID = 4840675668587884047L;

    /** Shared zero-length result returned for words the model has no vector for. */
    private static final double[] EMPTY = new double[0];

    /** The wrapped model; never reassigned after construction. */
    private final Word2Vec word2vec;

    /**
     * Creates an embedding view over the given model.
     *
     * @param word2Vec the DL4J model to delegate to
     */
    public Dl4jWordEmbedding(Word2Vec word2Vec) {
        this.word2vec = word2Vec;
    }

    /**
     * Returns the embedding vector of a single word.
     *
     * @param str the word to look up
     * @return the model's vector for {@code str}, or a zero-length array when the model
     *     returns {@code null} for it
     */
    public double[] getWordVector(String str) {
        double[] wordVector = this.word2vec.getWordVector(str);
        return wordVector != null ? wordVector : EMPTY;
    }

    /**
     * Looks up the vectors of all in-vocabulary words in {@code list}.
     *
     * <p>Words for which {@link #hasWord(String)} is {@code false} are skipped. The
     * positions of the words that were found are passed on so the result can be related
     * back to the original input. The result is tagged with
     * {@link EmbeddingVectors.StreamType#CONCATENATION}.
     *
     * @param list the words to look up; the original list is passed through to the result
     * @return the vectors of the known words together with their positions in {@code list}
     */
    public EmbeddingVectors getWordVectors(List<String> list) {
        ArrayList<String> knownWords = new ArrayList<>();
        ArrayList<Integer> knownIndices = new ArrayList<>();
        collectKnownWords(list, knownWords, knownIndices);
        return new EmbeddingVectors(this.word2vec.getWordVectors(knownWords), list, knownIndices,
                getEmbeddingDimensions(), EmbeddingVectors.StreamType.CONCATENATION);
    }

    /**
     * Computes the mean vector of all in-vocabulary words in {@code list}.
     *
     * <p>Words for which {@link #hasWord(String)} is {@code false} are skipped before
     * averaging. The result is tagged with
     * {@link EmbeddingVectors.StreamType#AGGREGATION}.
     *
     * @param list the words to average; the original list is passed through to the result
     * @return the mean vector of the known words together with their positions in {@code list}
     */
    public EmbeddingVectors getWordVectorsMean(List<String> list) {
        ArrayList<String> knownWords = new ArrayList<>();
        ArrayList<Integer> knownIndices = new ArrayList<>();
        collectKnownWords(list, knownWords, knownIndices);
        return new EmbeddingVectors(this.word2vec.getWordVectorsMean(knownWords), list, knownIndices,
                getEmbeddingDimensions(), EmbeddingVectors.StreamType.AGGREGATION);
    }

    /**
     * Appends every word of {@code words} known to the vocabulary, and its position in
     * {@code words}, to the two output lists (in input order).
     *
     * <p>Iterates with an enhanced-for and a running index instead of {@code List.get(i)}.
     * This keeps the scan linear even for non-random-access lists such as {@code LinkedList}.
     */
    private void collectKnownWords(List<String> words, List<String> knownWords, List<Integer> knownIndices) {
        int index = 0;
        for (String word : words) {
            if (hasWord(word)) {
                knownWords.add(word);
                knownIndices.add(index);
            }
            index++;
        }
    }

    /**
     * @return the number of words in the model's vocabulary
     */
    public int getVocabularySize() {
        return this.word2vec.getVocab().numWords();
    }

    /**
     * @param i vocabulary index
     * @return the vocabulary word stored at index {@code i}
     */
    public String getWord(int i) {
        return this.word2vec.getVocab().wordAtIndex(i);
    }

    /**
     * @param str the word to check
     * @return whether the model's vocabulary contains {@code str} as a token
     */
    public boolean hasWord(String str) {
        return this.word2vec.getVocab().hasToken(str);
    }

    /**
     * @return the embedding dimensionality, i.e. the model's layer size
     */
    public int getEmbeddingDimensions() {
        return this.word2vec.getLayerSize();
    }
}
