package de.datexis.retrieval.index;

import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.preprocess.IdentityPreprocessor;
import de.datexis.retrieval.encoder.LSTMSentenceEncoder;
import de.datexis.retrieval.index.IVectorIndex;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.apache.commons.lang.Validate;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/retrieval/index/InMemoryIndex.class */
public class InMemoryIndex extends Encoder implements IEncoder, IVocabulary, IVectorIndex {
    protected static final Logger log = LoggerFactory.getLogger(InMemoryIndex.class);
    protected IEncoder encoder;
    protected AbstractCache<VocabWord> keyVocabulary;
    protected InMemoryLookupTable<VocabWord> lookupVectors;
    protected TokenPreProcess keyPreprocessor;

    protected InMemoryIndex() {
    }

    public InMemoryIndex(IEncoder iEncoder) {
        this(new IdentityPreprocessor(), iEncoder);
    }

    public InMemoryIndex(TokenPreProcess tokenPreProcess, IEncoder iEncoder) {
        super("KNN");
        this.encoder = iEncoder;
        this.keyPreprocessor = tokenPreProcess;
        this.keyVocabulary = new AbstractCache.Builder().hugeModelExpected(false).minElementFrequency(0).build();
        this.lookupVectors = new InMemoryLookupTable<>(this.keyVocabulary, (int) getEmbeddingVectorSize(), true, 0.01d, Nd4j.getRandom(), 0.0d, true);
    }

    public void trainModel(Collection<Document> collection) {
        throw new UnsupportedOperationException("not implemented");
    }

    public void buildKeyIndex(Iterable<String> iterable) {
        buildKeyIndex(iterable, true);
    }

    public void buildKeyIndex(Iterable<String> iterable, boolean z) {
        log.info("Building key index...");
        int i = 0;
        Iterator<String> it = iterable.iterator();
        while (it.hasNext()) {
            String next = it.next();
            if (z) {
                next = this.keyPreprocessor.preProcess(next);
            }
            if (this.keyVocabulary.containsWord(next)) {
                this.keyVocabulary.incrementWordCount(next);
            } else {
                int i2 = i;
                i++;
                VocabWord vocabWord = new VocabWord(1.0d, next, i2);
                vocabWord.setSpecial(false);
                vocabWord.markAsLabel(true);
                vocabWord.setIndex(this.keyVocabulary.numWords());
                this.keyVocabulary.addToken(vocabWord);
                this.keyVocabulary.addWordToIndex(vocabWord.getIndex(), vocabWord.getLabel());
            }
        }
    }

    public void buildVectorIndex(Map<String, INDArray> map, boolean z) {
        log.info("Building vector index for {} entries...", Integer.valueOf(map.size()));
        if (size() <= 0) {
            throw new IllegalStateException("Cannot insert vectors into empty index. Please insert keys first.");
        }
        this.lookupVectors.resetWeights();
        long j = 0;
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            this.lookupVectors.putVector(z ? this.keyPreprocessor.preProcess(entry.getKey()) : entry.getKey(), entry.getValue());
            long j2 = j + 1;
            j = j2;
            if (j2 % 100000 == 0) {
                log.info("inserted {} vectors into vector index", Long.valueOf(j));
            }
        }
        INDArray syn0 = this.lookupVectors.getSyn0();
        syn0.diviColumnVector(syn0.norm2(new int[]{1}));
    }

    public void encodeAndBuildVectorIndex(Map<String, String> map, boolean z) {
        log.info("Building vector index for {} entries...", Integer.valueOf(map.size()));
        if (size() <= 0) {
            throw new IllegalStateException("Cannot insert vectors into empty index. Please insert keys first.");
        }
        this.lookupVectors.resetWeights();
        long j = 0;
        for (Map.Entry<String, String> entry : map.entrySet()) {
            this.lookupVectors.putVector(this.keyPreprocessor.preProcess(z ? this.keyPreprocessor.preProcess(entry.getKey()) : entry.getKey()), this.encoder.encode(entry.getValue()));
            long j2 = j + 1;
            j = j2;
            if (j2 % 100000 == 0) {
                log.info("inserted {} vectors into vector index", Long.valueOf(j));
            }
        }
        INDArray syn0 = this.lookupVectors.getSyn0();
        syn0.diviColumnVector(syn0.norm2(new int[]{1}));
    }

    /* JADX WARN: Type inference failed for: r0v78, types: [org.nd4j.linalg.api.ndarray.INDArray, java.lang.Object] */
    /* JADX WARN: Type inference failed for: r20v0, types: [org.nd4j.linalg.api.ndarray.INDArray] */
    public void encodeAndBuildVectorIndex(Multimap<String, ? extends Span> multimap, boolean z) {
        INDArray lookupBatchMatrix;
        log.info("Building vector index for {} entries...", Integer.valueOf(multimap.keySet().size()));
        if (size() <= 0) {
            throw new IllegalStateException("Cannot insert vectors into empty index. Please insert keys first.");
        }
        this.lookupVectors.resetWeights();
        long j = 0;
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        HashMap hashMap = new HashMap(size());
        for (List list : Lists.partition(Lists.newArrayList(multimap.entries()), 128)) {
            if (this.encoder instanceof LSTMSentenceEncoder) {
                lookupBatchMatrix = this.encoder.getTagger().encodeBatchMatrix((List) list.stream().map(entry -> {
                    return (Sentence) entry.getValue();
                }).collect(Collectors.toList()));
            } else {
                List list2 = (List) list.stream().map(entry2 -> {
                    return (Span) entry2.getValue();
                }).collect(Collectors.toList());
                lookupBatchMatrix = (z && (this.encoder instanceof InMemoryIndex)) ? lookupBatchMatrix(list2, (InMemoryIndex) this.encoder) : EncodingHelpers.encodeBatchMatrix(list2, this.encoder);
            }
            for (int i = 0; i < list.size(); i++) {
                String str = (String) ((Map.Entry) list.get(i)).getKey();
                ?? r0 = (INDArray) hashMap.getOrDefault(str, Nd4j.zeros(DataType.FLOAT, new long[]{this.encoder.getEmbeddingVectorSize(), 1}));
                r0.addi(lookupBatchMatrix.getRow(i).reshape(this.encoder.getEmbeddingVectorSize(), 1L));
                hashMap.put(str, r0);
                long j2 = j + 1;
                j = r0;
                if (j2 % 10000 == 0) {
                    log.info("encoded {} vectors", Long.valueOf(j));
                }
            }
        }
        long j3 = 0;
        for (String str2 : multimap.keySet()) {
            INDArray iNDArray = (INDArray) hashMap.get(str2);
            int size = multimap.get(str2).size();
            if (iNDArray != null) {
                ?? divi = size > 1 ? iNDArray.divi(Integer.valueOf(size)) : iNDArray;
                this.lookupVectors.putVector(str2, (INDArray) divi);
                long j4 = j3 + 1;
                j3 = divi;
                if (j4 % 100000 == 0) {
                    log.info("inserted {} vectors into vector index", Long.valueOf(j3));
                }
            }
        }
        INDArray syn0 = this.lookupVectors.getSyn0();
        syn0.diviColumnVector(syn0.norm2(new int[]{1}));
    }

    public static INDArray lookupBatchMatrix(List<? extends Span> list, InMemoryIndex inMemoryIndex) {
        INDArray zeros = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), inMemoryIndex.getEmbeddingVectorSize()});
        for (int i = 0; i < list.size(); i++) {
            Span span = list.get(i);
            INDArray lookup = inMemoryIndex.lookup(span.getText());
            if (lookup == null) {
                log.debug("fallback encoding '{}' during lookup", span.getText());
                lookup = inMemoryIndex.encode(span.getText());
            }
            zeros.get(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all()}).assign(lookup);
        }
        return zeros;
    }

    public void saveModel(Resource resource, String str) throws IOException {
        Resource resolve = resource.resolve(str + ".bin");
        writeBinaryModel(resolve.getOutputStream());
        setModel(resolve);
    }

    public void loadModel(Resource resource) throws IOException {
        loadBinaryModel(resource.getInputStream());
        setModel(resource);
        setModelAvailable(true);
    }

    private void writeBinaryModel(OutputStream outputStream) throws IOException {
        int i = 0;
        BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(outputStream);
        Throwable th = null;
        try {
            DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
            Throwable th2 = null;
            try {
                int numWords = this.keyVocabulary.numWords();
                dataOutputStream.writeLong(numWords);
                dataOutputStream.writeLong(this.keyVocabulary.totalNumberOfDocs());
                dataOutputStream.writeLong(this.lookupVectors.layerSize());
                for (int i2 = 0; i2 < numWords; i2++) {
                    VocabWord elementAtIndex = this.keyVocabulary.elementAtIndex(i2);
                    dataOutputStream.writeUTF(elementAtIndex.getLabel());
                    dataOutputStream.writeDouble(elementAtIndex.getElementFrequency());
                    i++;
                }
                for (int i3 = 0; i3 < numWords; i3++) {
                    Nd4j.write(this.lookupVectors.vector(this.keyVocabulary.elementAtIndex(i3).getLabel()), dataOutputStream);
                }
                dataOutputStream.flush();
                if (dataOutputStream != null) {
                    if (0 != 0) {
                        try {
                            dataOutputStream.close();
                        } catch (Throwable th3) {
                            th2.addSuppressed(th3);
                        }
                    } else {
                        dataOutputStream.close();
                    }
                }
                log.info("Wrote {} entries with vector size {}", Integer.valueOf(i), Integer.valueOf(this.lookupVectors.layerSize()));
            } catch (Throwable th4) {
                if (dataOutputStream != null) {
                    if (0 != 0) {
                        try {
                            dataOutputStream.close();
                        } catch (Throwable th5) {
                            th2.addSuppressed(th5);
                        }
                    } else {
                        dataOutputStream.close();
                    }
                }
                throw th4;
            }
        } finally {
            if (bufferedOutputStream != null) {
                if (0 != 0) {
                    try {
                        bufferedOutputStream.close();
                    } catch (Throwable th6) {
                        th.addSuppressed(th6);
                    }
                } else {
                    bufferedOutputStream.close();
                }
            }
        }
    }

    /* JADX WARN: Finally extract failed */
    private void loadBinaryModel(InputStream inputStream) throws IOException {
        int i = 0;
        BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream);
        Throwable th = null;
        try {
            DataInputStream dataInputStream = new DataInputStream(bufferedInputStream);
            Throwable th2 = null;
            try {
                long readLong = dataInputStream.readLong();
                dataInputStream.readLong();
                long readLong2 = dataInputStream.readLong();
                this.keyVocabulary = new AbstractCache.Builder().hugeModelExpected(false).minElementFrequency(0).build();
                this.lookupVectors = new InMemoryLookupTable<>(this.keyVocabulary, (int) readLong2, true, 0.01d, Nd4j.getRandom(), 0.0d, true);
                for (int i2 = 0; i2 < readLong; i2++) {
                    if (dataInputStream.available() <= 0) {
                        throw new IOException("binary file truncated");
                    }
                    int i3 = i;
                    i++;
                    VocabWord vocabWord = new VocabWord(dataInputStream.readDouble(), dataInputStream.readUTF(), i3);
                    vocabWord.setSpecial(false);
                    vocabWord.markAsLabel(true);
                    vocabWord.setIndex(this.keyVocabulary.numWords());
                    this.keyVocabulary.addToken(vocabWord);
                    this.keyVocabulary.addWordToIndex(vocabWord.getIndex(), vocabWord.getLabel());
                    if (i % 100000 == 0) {
                        log.info("loaded {} keys into word index", Integer.valueOf(i2));
                    }
                }
                this.keyVocabulary.updateWordsOccurrences();
                int i4 = 0;
                this.lookupVectors.resetWeights();
                for (int i5 = 0; i5 < readLong; i5++) {
                    if (dataInputStream.available() <= 0) {
                        throw new IOException("binary file truncated");
                    }
                    this.lookupVectors.putVector(this.keyVocabulary.wordAtIndex(i5), Nd4j.read(dataInputStream));
                    i4++;
                    if (i4 % 100000 == 0) {
                        log.info("loaded {} vectors into vector index", Integer.valueOf(i4));
                    }
                }
                if (dataInputStream != null) {
                    if (0 != 0) {
                        try {
                            dataInputStream.close();
                        } catch (Throwable th3) {
                            th2.addSuppressed(th3);
                        }
                    } else {
                        dataInputStream.close();
                    }
                }
                log.info("Read {} entries with vector size {}", Integer.valueOf(i4), Integer.valueOf(this.lookupVectors.layerSize()));
            } catch (Throwable th4) {
                if (dataInputStream != null) {
                    if (0 != 0) {
                        try {
                            dataInputStream.close();
                        } catch (Throwable th5) {
                            th2.addSuppressed(th5);
                        }
                    } else {
                        dataInputStream.close();
                    }
                }
                throw th4;
            }
        } finally {
            if (bufferedInputStream != null) {
                if (0 != 0) {
                    try {
                        bufferedInputStream.close();
                    } catch (Throwable th6) {
                        th.addSuppressed(th6);
                    }
                } else {
                    bufferedInputStream.close();
                }
            }
        }
    }

    @JsonIgnore
    public IEncoder getEncoder() {
        return this.encoder;
    }

    public void setEncoder(IEncoder iEncoder) {
        this.encoder = iEncoder;
    }

    @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "class")
    public TokenPreProcess getKeyPreprocessor() {
        return this.keyPreprocessor;
    }

    public void setKeyPreprocessor(TokenPreProcess tokenPreProcess) {
        this.keyPreprocessor = tokenPreProcess;
    }

    public void setEncoders(List<Encoder> list) {
        if (list.size() != 1) {
            throw new IllegalArgumentException("wrong number of encoders given (expected=1, actual=" + list.size() + ")");
        }
        this.encoder = list.get(0);
    }

    public List<Encoder> getEncoders() {
        return Lists.newArrayList(new Encoder[]{(Encoder) this.encoder});
    }

    public long getEmbeddingVectorSize() {
        return this.encoder.getEmbeddingVectorSize();
    }

    public INDArray encode(String str) {
        return Transforms.unitVec(this.encoder.encode(str));
    }

    public INDArray encode(Span span) {
        return Transforms.unitVec(this.encoder.encode(span));
    }

    public INDArray encode(Iterable<? extends Span> iterable) {
        return Transforms.unitVec(this.encoder.encode(iterable));
    }

    @Override // de.datexis.retrieval.index.IVocabulary
    public INDArray lookup(String str) {
        INDArray vector;
        if (str == null || (vector = this.lookupVectors.vector(this.keyPreprocessor.preProcess(str))) == null) {
            return null;
        }
        return vector.transpose();
    }

    @Override // de.datexis.retrieval.index.IVocabulary
    public int index(String str) {
        int indexOf;
        if (str != null && (indexOf = this.keyVocabulary.indexOf(this.keyPreprocessor.preProcess(str))) >= 0) {
            return indexOf;
        }
        return -1;
    }

    @Override // de.datexis.retrieval.index.IVocabulary
    public String key(int i) {
        VocabWord elementAtIndex = this.keyVocabulary.elementAtIndex(i);
        if (elementAtIndex != null) {
            return elementAtIndex.getWord();
        }
        return null;
    }

    public List<String> keys() {
        ArrayList arrayList = new ArrayList(size());
        for (int i = 0; i < size(); i++) {
            arrayList.add(this.keyVocabulary.wordAtIndex(i));
        }
        return arrayList;
    }

    @Override // de.datexis.retrieval.index.IVocabulary
    public int size() {
        return this.keyVocabulary.numWords();
    }

    public long totalInstances() {
        return this.keyVocabulary.totalWordOccurrences();
    }

    public double frequency(String str) {
        return this.keyVocabulary.wordFrequency(this.keyPreprocessor.preProcess(str));
    }

    public double frequency(int i) {
        if (key(i) != null) {
            return this.keyVocabulary.wordFrequency(r0);
        }
        return 0.0d;
    }

    public double probability(String str) {
        return frequency(str) / totalInstances();
    }

    public double probability(int i) {
        return frequency(i) / totalInstances();
    }

    public INDArray decode(String str) {
        int index = index(str);
        Validate.isTrue(index >= 0, "key is not contained in index");
        return decode(index);
    }

    public INDArray decode(int i) {
        Validate.isTrue(i >= 0 && i < size(), "index out of bounds");
        return Nd4j.zeros(DataType.FLOAT, new long[]{size(), 1}).putScalarUnsafe(i, 1.0d);
    }

    public INDArray similarity(INDArray iNDArray) {
        Validate.isTrue(iNDArray.isColumnVector(), "column vector expected");
        Validate.isTrue(iNDArray.length() == getEmbeddingVectorSize(), "invalid vector size");
        return Transforms.unitVec(iNDArray.transpose()).mmul(this.lookupVectors.getSyn0().transpose()).transpose();
    }

    @Override // de.datexis.retrieval.index.IVectorIndex
    public List<IVectorIndex.IndexEntry> find(INDArray iNDArray, int i) {
        INDArray similarity = similarity(iNDArray);
        List<Double> topN = getTopN(similarity, i);
        ArrayList arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < topN.size(); i2++) {
            IVectorIndex.IndexEntry indexEntry = new IVectorIndex.IndexEntry();
            indexEntry.index = topN.get(i2).intValue();
            indexEntry.key = key(indexEntry.index);
            indexEntry.similarity = similarity.getDouble(indexEntry.index);
            if (indexEntry.similarity != 0.0d) {
                arrayList.add(indexEntry);
            }
        }
        return arrayList;
    }

    private List<Double> getTopN(INDArray iNDArray, int i) {
        BasicModelUtils.ArrayComparator arrayComparator = new BasicModelUtils.ArrayComparator();
        PriorityQueue priorityQueue = new PriorityQueue(iNDArray.rows(), arrayComparator);
        for (int i2 = 0; i2 < iNDArray.length(); i2++) {
            Double[] dArr = {Double.valueOf(iNDArray.getDouble(i2)), Double.valueOf(i2)};
            if (priorityQueue.size() < i) {
                priorityQueue.add(dArr);
            } else if (arrayComparator.compare(dArr, (Double[]) priorityQueue.peek()) > 0) {
                priorityQueue.poll();
                priorityQueue.add(dArr);
            }
        }
        ArrayList arrayList = new ArrayList();
        while (!priorityQueue.isEmpty()) {
            arrayList.add(Double.valueOf(((Double[]) priorityQueue.poll())[1].doubleValue()));
        }
        return Lists.reverse(arrayList);
    }

    @Override // de.datexis.retrieval.index.IVectorIndex
    public IVectorIndex.IndexEntry find(INDArray iNDArray) {
        return find(iNDArray, 1).get(0);
    }

    public void writeVectors(Resource resource, String str) throws IOException {
        writeVectors(resource, str, null);
    }

    public void writeVectors(Resource resource, String str, Map<String, String> map) throws IOException {
        Resource resolve = resource.resolve(str + ".vectors.tsv");
        Resource resolve2 = resource.resolve(str + ".meta.tsv");
        Resource resolve3 = resource.resolve(str + ".glove.txt");
        PrintWriter printWriter = new PrintWriter(new OutputStreamWriter(resolve.getOutputStream(), StandardCharsets.UTF_8));
        PrintWriter printWriter2 = new PrintWriter(new OutputStreamWriter(resolve2.getOutputStream(), StandardCharsets.UTF_8));
        PrintWriter printWriter3 = new PrintWriter(new OutputStreamWriter(resolve3.getOutputStream(), StandardCharsets.UTF_8));
        try {
            int numWords = this.keyVocabulary.numWords();
            printWriter2.println("Key\tFreq");
            for (int i = 0; i < numWords; i++) {
                VocabWord elementAtIndex = this.keyVocabulary.elementAtIndex(i);
                String label = elementAtIndex.getLabel();
                INDArray vector = this.lookupVectors.vector(label);
                StringBuilder sb = new StringBuilder();
                StringBuilder sb2 = new StringBuilder();
                StringBuilder sb3 = new StringBuilder();
                sb3.append(map != null ? map.getOrDefault(label, label) : label).append("\t").append((int) elementAtIndex.getElementFrequency());
                sb.append(label.replaceAll("\\s+", "_")).append(" ");
                for (int i2 = 0; i2 < vector.length(); i2++) {
                    String fDbl8 = fDbl8(vector.getDouble(i2));
                    sb.append(fDbl8);
                    sb2.append(fDbl8);
                    if (i2 < vector.length() - 1) {
                        sb.append(" ");
                        sb2.append("\t");
                    }
                }
                printWriter.println(sb2.toString());
                printWriter2.println(sb3.toString());
                printWriter3.println(sb.toString());
            }
        } finally {
            printWriter.flush();
            printWriter2.flush();
            printWriter3.flush();
            printWriter.close();
            printWriter2.close();
            printWriter3.close();
        }
    }

    protected static String fDbl8(double d) {
        return String.format(Locale.ENGLISH, "%.8f", Double.valueOf(d));
    }
}
