package pitt.search.semanticvectors;

import cern.colt.matrix.AbstractFormatter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Hashtable;
import org.apache.lucene.analysis.pattern.PatternReplaceCharFilter;
import org.apache.lucene.index.DocsAndPositionsEnum;
import org.apache.lucene.index.LogDocMergePolicy;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.packed.PackedInts;
import pitt.search.semanticvectors.orthography.NumberRepresentation;
import pitt.search.semanticvectors.utils.VerbatimLogger;
import pitt.search.semanticvectors.vectors.PermutationUtils;
import pitt.search.semanticvectors.vectors.Vector;
import pitt.search.semanticvectors.vectors.VectorFactory;

/* loaded from: input_file:pitt/search/semanticvectors/TermTermVectorsFromLucene.class */
/**
 * Trains term-term semantic vectors from a Lucene index that stores term
 * position vectors.
 *
 * <p>For every term accepted by {@code LuceneUtils.termFilter} a zero semantic
 * vector is created; then, for every document and every configured contents
 * field, the elemental vectors of terms occurring within
 * {@code flagConfig.windowradius()} positions of a focus term are superposed
 * onto the focus term's semantic vector. How the relative position is encoded
 * depends on {@link PositionalMethod}.
 *
 * <p>All training work is performed by the constructor.
 */
public class TermTermVectorsFromLucene {
    static final short NONEXISTENT = -1;

    /** Progress is logged every this many documents. */
    private static final int PROGRESS_INTERVAL = 10000;
    /** Below {@link #PROGRESS_INTERVAL} documents, progress is logged at this finer interval. */
    private static final int EARLY_PROGRESS_INTERVAL = 1000;
    /** File the signed positional number vectors are written to. */
    private static final String NUMBER_VECTORS_FILE = "numbervectors.bin";

    private final FlagConfig flagConfig;
    /** True when the caller supplied pre-existing elemental vectors. */
    private final boolean retraining;
    private final VectorStore elementalTermVectors;
    private VectorStoreRAM semanticTermVectors;
    private LuceneUtils luceneUtils;
    /** Only populated for {@link PositionalMethod#PROXIMITY}; keyed by signed position offset. */
    private VectorStoreRAM positionalNumberVectors;
    /** Permutations indexed per positional method; see the {@code initialize*} methods. */
    private int[][] permutationCache;

    /** Strategies for encoding the relative position of a neighbouring term. */
    public enum PositionalMethod {
        BASIC,
        DIRECTIONAL,
        PERMUTATION,
        PERMUTATIONPLUSBASIC,
        PROXIMITY
    }

    /**
     * Creates the object and immediately trains the semantic term vectors.
     *
     * @param flagConfig  configuration (index path, dimension, window radius, positional method, ...)
     * @param vectorStore elemental term vectors to reuse, or {@code null} to have a new
     *                    {@code ElementalVectorStore} created from {@code flagConfig}
     * @throws IOException if the index has no term vectors or an I/O error occurs while training
     */
    public TermTermVectorsFromLucene(FlagConfig flagConfig, VectorStore vectorStore) throws IOException {
        this.flagConfig = flagConfig;
        this.retraining = vectorStore != null;
        if (this.retraining) {
            this.elementalTermVectors = vectorStore;
            VerbatimLogger.info("Reusing basic term vectors; number of terms: " + vectorStore.getNumVectors() + "\n");
        } else {
            this.elementalTermVectors = new ElementalVectorStore(flagConfig);
        }

        PositionalMethod method = flagConfig.positionalmethod();
        if (method == PositionalMethod.PERMUTATION || method == PositionalMethod.PERMUTATIONPLUSBASIC) {
            initializePermutations();
        } else if (method == PositionalMethod.DIRECTIONAL) {
            initializeDirectionalPermutations();
        } else if (method == PositionalMethod.PROXIMITY) {
            initializeNumberRepresentations();
        }
        trainTermTermVectors();
    }

    /** @return the semantic term vectors produced by training. */
    public VectorStore getSemanticTermVectors() {
        return this.semanticTermVectors;
    }

    /**
     * Fills {@link #permutationCache} with one shift permutation per window slot:
     * index {@code i} holds the shift by {@code i - windowradius}, so offsets
     * {@code -radius..+radius} map to indices {@code 0..2*radius}.
     */
    private void initializePermutations() {
        int radius = this.flagConfig.windowradius();
        int slots = (2 * radius) + 1;
        int length = PermutationUtils.getPermutationLength(this.flagConfig.vectortype(), this.flagConfig.dimension());
        this.permutationCache = new int[slots][length];
        for (int slot = 0; slot < slots; slot++) {
            this.permutationCache[slot] = PermutationUtils.getShiftPermutation(
                    this.flagConfig.vectortype(), this.flagConfig.dimension(), slot - radius);
        }
    }

    /**
     * Builds the signed positional number vectors used by {@link PositionalMethod#PROXIMITY}.
     *
     * <p>Number vectors are requested for the range {@code 1..windowradius + 2}. For each number
     * {@code n > 1}, two permuted, normalized copies are stored: key {@code -(n - 1)} uses the
     * shift-by-minus-one permutation and key {@code n - 1} the shift-by-plus-one permutation.
     * The resulting store is also written to {@code numbervectors.bin} on a best-effort basis.
     */
    private void initializeNumberRepresentations() {
        this.positionalNumberVectors =
                new NumberRepresentation(this.flagConfig).getNumberVectors(1, this.flagConfig.windowradius() + 2);
        initializeDirectionalPermutations();

        VectorStoreRAM signedNumberVectors = new VectorStoreRAM(this.flagConfig);
        Enumeration<ObjectVector> numberVectors = this.positionalNumberVectors.getAllVectors();
        while (numberVectors.hasMoreElements()) {
            ObjectVector numberVector = numberVectors.nextElement();
            int number = ((Integer) numberVector.getObject()).intValue();
            if (number > 1) {
                signedNumberVectors.putVector(
                        Integer.valueOf(-(number - 1)), permutedCopy(numberVector.getVector(), this.permutationCache[0]));
                signedNumberVectors.putVector(
                        Integer.valueOf(number - 1), permutedCopy(numberVector.getVector(), this.permutationCache[1]));
            }
            VerbatimLogger.finest("\nInitialized number representation: " + numberVector.getObject());
        }

        try {
            VectorStoreWriter.writeVectorsInLuceneFormat(NUMBER_VECTORS_FILE, this.flagConfig, signedNumberVectors);
        } catch (IOException e) {
            // Writing the file is best-effort; training only needs the in-memory store.
            VerbatimLogger.severe("Could not write " + NUMBER_VECTORS_FILE + ": " + e + "\n");
        }
        // Always switch to the signed store: training looks vectors up by signed offset,
        // and only this store contains negative keys.
        this.positionalNumberVectors = signedNumberVectors;
    }

    /** Returns a new normalized vector holding {@code source} superposed under {@code permutation}. */
    private Vector permutedCopy(Vector source, int[] permutation) {
        Vector result = VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension());
        result.superpose(source, 1.0d, permutation);
        result.normalize();
        return result;
    }

    /**
     * Fills {@link #permutationCache} with two permutations: index 0 is the shift by -1
     * (used for preceding terms), index 1 the shift by +1 (used for following terms).
     */
    private void initializeDirectionalPermutations() {
        int length = PermutationUtils.getPermutationLength(this.flagConfig.vectortype(), this.flagConfig.dimension());
        this.permutationCache = new int[2][length];
        this.permutationCache[0] = PermutationUtils.getShiftPermutation(this.flagConfig.vectortype(), this.flagConfig.dimension(), -1);
        this.permutationCache[1] = PermutationUtils.getShiftPermutation(this.flagConfig.vectortype(), this.flagConfig.dimension(), 1);
    }

    /**
     * Runs the full training pass: prepares the index, creates zero semantic vectors for all
     * accepted terms, accumulates co-occurrence information document by document, normalizes
     * the result and, for the permutation-based methods when not retraining, writes the
     * elemental vectors to {@code flagConfig.elementalvectorfile()}.
     *
     * @throws IOException if the index lacks term vectors or reading/writing fails
     */
    private void trainTermTermVectors() throws IOException {
        LuceneUtils.compressIndex(this.flagConfig.luceneindexpath());
        this.luceneUtils = new LuceneUtils(this.flagConfig);
        if (!this.luceneUtils.getFieldInfos().hasVectors()) {
            throw new IOException("Term-term indexing requires a Lucene index containing TermPositionVectors.\nTry rebuilding Lucene index using pitt.search.lucene.IndexFilePositions");
        }
        this.semanticTermVectors = new VectorStoreRAM(this.flagConfig);

        int acceptedTerms = 0;
        for (String field : this.flagConfig.contentsfields()) {
            TermsEnum termsEnum = this.luceneUtils.getTermsForField(field).iterator(null);
            BytesRef bytes;
            // TermsEnum.next() returns null once the terms are exhausted; stop there.
            while ((bytes = termsEnum.next()) != null) {
                Term term = new Term(field, bytes);
                if (!this.luceneUtils.termFilter(term)) {
                    continue;
                }
                acceptedTerms++;
                this.semanticTermVectors.putVector(term.text(),
                        VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension()));
                if (!this.retraining) {
                    // Result intentionally unused; presumably this makes the elemental store
                    // create the term's vector up front - TODO confirm against ElementalVectorStore.
                    this.elementalTermVectors.getVector(term.text());
                }
            }
        }
        VerbatimLogger.info("There are now elemental term vectors for " + acceptedTerms + " terms (and " + this.luceneUtils.getNumDocs() + " docs).\n");

        int numDocs = this.luceneUtils.getNumDocs();
        for (int docId = 0; docId < numDocs; docId++) {
            boolean coarseTick = docId % PROGRESS_INTERVAL == 0;
            boolean fineTick = docId < PROGRESS_INTERVAL && docId % EARLY_PROGRESS_INTERVAL == 0;
            if (coarseTick || fineTick) {
                VerbatimLogger.info("Processed " + docId + " documents ... ");
            }
            for (String field : this.flagConfig.contentsfields()) {
                Terms termVector = this.luceneUtils.getTermVector(docId, field);
                if (termVector == null) {
                    VerbatimLogger.severe("No term vector for document " + docId);
                } else {
                    processTermPositionVector(termVector, field);
                }
            }
        }

        VerbatimLogger.info("Created " + this.semanticTermVectors.getNumVectors() + " term vectors ...\n");
        VerbatimLogger.info("Normalizing term vectors.\n");
        normalizeAll(this.semanticTermVectors.getAllVectors());

        PositionalMethod method = this.flagConfig.positionalmethod();
        boolean permutationBased =
                method == PositionalMethod.PERMUTATION || method == PositionalMethod.PERMUTATIONPLUSBASIC;
        if (permutationBased && !this.retraining) {
            VerbatimLogger.info("Normalizing and writing elemental vectors to " + this.flagConfig.elementalvectorfile() + "\n");
            normalizeAll(this.elementalTermVectors.getAllVectors());
            VectorStoreWriter.writeVectors(this.flagConfig.elementalvectorfile(), this.flagConfig, this.elementalTermVectors);
        }
    }

    /** Normalizes every vector yielded by {@code vectors} in place. */
    private static void normalizeAll(Enumeration<ObjectVector> vectors) {
        while (vectors.hasMoreElements()) {
            vectors.nextElement().getVector().normalize();
        }
    }

    /**
     * Updates the semantic vectors of the terms in one document field from the terms that
     * occur within the sliding window around each of their positions.
     *
     * @param terms the term vector of one document field; {@code null} is ignored
     * @param field the name of the field the term vector belongs to
     * @throws IOException if reading positions from the term vector fails
     */
    private void processTermPositionVector(Terms terms, String field) throws IOException {
        if (terms == null) {
            return;
        }

        // Terms of this document that have a semantic vector, and a map from token
        // position to the index of the term (in localTerms) occurring at that position.
        ArrayList<String> localTerms = new ArrayList<String>();
        Hashtable<Integer, Integer> termIndexByPosition = new Hashtable<Integer, Integer>();

        TermsEnum termsEnum = terms.iterator(null);
        BytesRef bytes;
        while ((bytes = termsEnum.next()) != null) {
            String termText = bytes.utf8ToString();
            if (!this.semanticTermVectors.containsVector(termText)) {
                continue;
            }
            DocsAndPositionsEnum postings = termsEnum.docsAndPositions(null, null);
            if (postings == null) {
                // No positional information available: the whole field is skipped.
                return;
            }
            postings.nextDoc();
            int frequency = postings.freq();
            Integer termIndex = Integer.valueOf(localTerms.size());
            localTerms.add(termText);
            for (int occurrence = 0; occurrence < frequency; occurrence++) {
                termIndexByPosition.put(Integer.valueOf(postings.nextPosition()), termIndex);
            }
        }

        int radius = this.flagConfig.windowradius();
        // NOTE(review): the scan is bounded by the number of recorded positions, not by the
        // largest position. If positions are sparse (terms without a semantic vector leave
        // gaps), positions >= this count are never visited. Kept as-is because changing it
        // would alter the trained vectors - confirm intent before touching.
        int positionCount = termIndexByPosition.size();
        for (int focus = 0; focus < positionCount; focus++) {
            Integer focusIndex = termIndexByPosition.get(Integer.valueOf(focus));
            if (focusIndex == null) {
                continue;
            }
            String focusTerm = localTerms.get(focusIndex.intValue());
            int windowStart = Math.max(0, focus - radius);
            int windowEnd = Math.min(focus + radius, positionCount - 1);
            for (int cursor = windowStart; cursor <= windowEnd; cursor++) {
                if (cursor == focus) {
                    continue;
                }
                Integer neighbourIndex = termIndexByPosition.get(Integer.valueOf(cursor));
                if (neighbourIndex == null) {
                    continue;
                }
                String neighbourTerm = localTerms.get(neighbourIndex.intValue());
                if (neighbourTerm != null) {
                    superposeNeighbour(field, focusTerm, neighbourTerm, cursor - focus);
                }
            }
        }
    }

    /**
     * Adds the weighted elemental vector of {@code neighbourTerm} to the semantic vector of
     * {@code focusTerm}, encoding {@code offset} according to the configured positional method.
     *
     * @param field         field both terms were read from (used for the global term weight)
     * @param focusTerm     term whose semantic vector is updated
     * @param neighbourTerm term found inside the focus term's window
     * @param offset        neighbour position minus focus position; never zero
     * @throws IOException declared so I/O failures from the underlying index can propagate
     */
    private void superposeNeighbour(String field, String focusTerm, String neighbourTerm, int offset)
            throws IOException {
        PositionalMethod method = this.flagConfig.positionalmethod();
        Vector neighbourVector = this.elementalTermVectors.getVector(neighbourTerm);
        float weight = this.luceneUtils.getGlobalTermWeight(new Term(field, neighbourTerm));

        if (method == PositionalMethod.PROXIMITY) {
            // Bind a copy so the stored elemental vector itself is not modified.
            neighbourVector = neighbourVector.copy();
            neighbourVector.bind(this.positionalNumberVectors.getVector(Integer.valueOf(offset)));
        }

        Vector focusVector = this.semanticTermVectors.getVector(focusTerm);
        if (method == PositionalMethod.BASIC
                || method == PositionalMethod.PERMUTATIONPLUSBASIC
                || method == PositionalMethod.PROXIMITY) {
            focusVector.superpose(neighbourVector, weight, null);
        }
        if (method == PositionalMethod.PERMUTATION || method == PositionalMethod.PERMUTATIONPLUSBASIC) {
            focusVector.superpose(neighbourVector, weight, this.permutationCache[offset + this.flagConfig.windowradius()]);
        } else if (method == PositionalMethod.DIRECTIONAL) {
            // Index 0 for preceding neighbours, index 1 for following ones.
            focusVector.superpose(neighbourVector, weight, this.permutationCache[offset > 0 ? 1 : 0]);
        }
    }
}
