package de.datexis.index.impl;

import de.datexis.common.Resource;
import de.datexis.index.ArticleRef;
import de.datexis.index.encoder.EntityEncoder;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.apache.lucene.index.IndexReader;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/index/impl/KNNArticleIndex.class */
/**
 * Article index that, in addition to the inherited Lucene lookup, keeps an in-memory
 * table of entity vectors so that articles can be retrieved by nearest-neighbour search.
 *
 * <p>The table is built once in the constructor from all documents of the index reader
 * and can be written to disk with {@link #saveModel(Resource, String)}.
 */
public class KNNArticleIndex extends LuceneArticleIndex {
    protected static final Logger log = LoggerFactory.getLogger(KNNArticleIndex.class);

    /** Number of inserted vectors between two progress log messages. */
    private static final int PROGRESS_LOG_INTERVAL = 100000;

    protected ParagraphVectors parvec;
    EntityEncoder encoder;
    /** Vocabulary of article ids; one entry per distinct id found in the index. */
    protected VocabCache<VocabWord> vocabCache;
    /** Maps article ids to their encoded entity vectors; {@code null} until built. */
    protected WeightLookupTable<VocabWord> lookupVectors = null;
    /** Nearest-neighbour search helper operating on {@link #lookupVectors}. */
    protected ModelUtils<VocabWord> lookupUtils;

    /**
     * Creates the entity encoder from the given resource (using the {@code NAME} strategy)
     * and immediately builds the lookup table.
     *
     * <p>NOTE(review): {@link #generateLookupCache()} reads {@code this.searcher}, so the
     * superclass constructor is presumably expected to have opened the index - confirm.
     *
     * @param resource location the {@link EntityEncoder} is created from
     * @throws IOException declared by this constructor's signature
     */
    public KNNArticleIndex(Resource resource) throws IOException {
        this.encoder = new EntityEncoder(resource, EntityEncoder.Strategy.NAME);
        generateLookupCache();
    }

    /**
     * Builds the vocabulary of article ids and fills the lookup table with one encoded
     * vector per distinct id.
     *
     * <p>An {@link IOException} raised while reading the index is logged and not rethrown;
     * in that case the lookup structures may be left uninitialised or incomplete.
     */
    protected void generateLookupCache() {
        log.debug("building entity list....");
        VocabularyHolder vocabulary = new VocabularyHolder.Builder().build();
        this.vocabCache = new InMemoryLookupCache();
        this.lookupUtils = new BasicModelUtils<>();
        try {
            IndexReader reader = this.searcher.getIndexReader();
            int maxDoc = reader.maxDoc();
            List<ArticleRef> refs = new ArrayList<>(maxDoc);
            for (int docId = 0; docId < maxDoc; docId++) {
                ArticleRef ref = createWikidataArticleRef(reader.document(docId));
                String id = ref.getId();
                if (vocabulary.containsWord(id)) {
                    vocabulary.incrementWordCounter(id);
                } else {
                    vocabulary.addWord(id);
                    // Remember the article so its vector is inserted below. Previously the
                    // list was never filled and the table kept only its initial weights.
                    refs.add(ref);
                }
            }
            vocabulary.updateHuffmanCodes();
            vocabulary.transferBackToVocabCache(this.vocabCache, true);

            this.lookupVectors = new InMemoryLookupTable<>(this.vocabCache,
                    (int) this.encoder.getEmbeddingVectorSize(), true, 0.01d, Nd4j.getRandom(), 0.0d, true);
            this.lookupVectors.resetWeights();

            int inserted = 0;
            for (ArticleRef ref : refs) {
                this.lookupVectors.putVector(ref.getId(), this.encoder.encodeEntity(ref));
                inserted++;
                if (inserted % PROGRESS_LOG_INTERVAL == 0) {
                    log.info("inserted {} vectors into lookup table", inserted);
                }
            }
            log.info("generated {} entity vectors", refs.size());
            this.lookupUtils.init(this.lookupVectors);
            log.info("initialized lookup tables");
        } catch (IOException e) {
            // Best-effort: keep the index usable, but log the full stack trace.
            log.error(e.toString(), e);
        }
    }

    /**
     * Writes the lookup table to {@code <str>_lookup.bin} resolved against the given resource.
     *
     * @param resource base location the file name is resolved against
     * @param str      file name prefix
     * @throws IOException           if writing fails
     * @throws IllegalStateException if the lookup table has not been built
     */
    public void saveModel(Resource resource, String str) throws IOException {
        if (this.lookupVectors == null) {
            // Fail before the output stream is opened so no empty file is left behind.
            throw new IllegalStateException("lookup table has not been generated");
        }
        writeBinaryModel(this.lookupVectors, resource.resolve(str + "_lookup.bin").getOutputStream());
    }

    /**
     * Serialises every non-null vocabulary word followed by its vector to the stream.
     * The stream is closed when this method returns, whether or not writing succeeded.
     *
     * @param table  lookup table to serialise
     * @param target stream to write to
     * @throws IOException if writing or closing the stream fails
     */
    private static void writeBinaryModel(WeightLookupTable<VocabWord> table, OutputStream target) throws IOException {
        int written = 0;
        // Closing the outermost stream also closes the wrapped buffered and target streams.
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(target))) {
            for (String word : table.getVocabCache().words()) {
                if (word == null) {
                    continue;
                }
                INDArray vector = table.vector(word);
                log.trace("Write: {} (size {})", word, vector.length());
                out.writeUTF(word);
                Nd4j.write(vector, out);
                written++;
            }
            out.flush();
        }
        log.info("Wrote {} words with size {}", written, table.layerSize());
    }

    /**
     * Returns the articles whose vectors are nearest to the vector stored under the given label.
     *
     * @param str label in the lookup table to search neighbours for
     * @param i   number of neighbours requested from the lookup table
     * @return the neighbours that could be resolved via {@code queryWikidataID}, in search order
     */
    public List<ArticleRef> querySimilarArticles(String str, int i) {
        return resolveArticles(this.lookupUtils.wordsNearest(str, i), i);
    }

    /**
     * Encodes both strings, stacks the two vectors horizontally and returns the articles
     * nearest to the combined vector.
     *
     * @param str  first string to encode
     * @param str2 second string to encode
     * @param i    number of neighbours requested from the lookup table
     * @return the neighbours that could be resolved via {@code queryWikidataID}, in search order
     */
    public List<ArticleRef> querySimilarArticles(String str, String str2, int i) {
        INDArray query = Nd4j.hstack(new INDArray[]{this.encoder.encode(str), this.encoder.encode(str2)});
        return resolveArticles(this.lookupUtils.wordsNearest(query, i), i);
    }

    /**
     * Returns the articles whose vectors are nearest to the given query vector.
     *
     * @param iNDArray query vector
     * @param i        number of neighbours requested from the lookup table
     * @return the neighbours that could be resolved via {@code queryWikidataID}, in search order
     */
    public List<ArticleRef> querySimilarArticles(INDArray iNDArray, int i) {
        return resolveArticles(this.lookupUtils.wordsNearest(iNDArray, i), i);
    }

    /** Looks up each id with {@code queryWikidataID} and collects the articles that were found. */
    private List<ArticleRef> resolveArticles(Iterable<String> ids, int capacity) {
        List<ArticleRef> articles = new ArrayList<>(capacity);
        for (String id : ids) {
            queryWikidataID(id).ifPresent(articles::add);
        }
        return articles;
    }
}
