package pitt.search.semanticvectors;

import cern.colt.matrix.AbstractFormatter;
import java.io.IOException;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.logging.Logger;
import org.apache.lucene.analysis.pattern.PatternReplaceCharFilter;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DocsEnum;
import org.apache.lucene.index.LogDocMergePolicy;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import pitt.search.semanticvectors.utils.VerbatimLogger;
import pitt.search.semanticvectors.vectors.Vector;
import pitt.search.semanticvectors.vectors.VectorFactory;

/* loaded from: input_file:pitt/search/semanticvectors/PSI.class */
/**
 * Builds a PSI model from a Lucene index and writes the resulting vector stores to disk.
 *
 * <p>The index is read through {@link LuceneUtils} and is expected to expose the fields
 * {@code subject}, {@code predicate}, {@code object} and {@code predication}. Three stores
 * are produced: elemental vectors for items (subjects and objects), elemental vectors for
 * predicates (each predicate also gets a {@code "-INV"} counterpart), and semantic vectors
 * for items, which are accumulated from the predications found in the index.
 *
 * <p>Instances are only created internally; use {@link #createIncrementalPSIVectors(FlagConfig)}
 * or {@link #main(String[])}.
 */
public class PSI {
    private static final Logger logger = Logger.getLogger(PSI.class.getCanonicalName());

    private static final String SUBJECT_FIELD = "subject";
    private static final String PREDICATE_FIELD = "predicate";
    private static final String OBJECT_FIELD = "object";
    private static final String PREDICATION_FIELD = "predication";

    /** Suffix appended to a predicate name to key its inverse vector. */
    private static final String INVERSE_SUFFIX = "-INV";

    // NOTE(review): the three aliases below point at constants from unrelated libraries.
    // This looks like a decompiler substituting equal-valued constants for plain literals
    // (presumably "\n", 10000 and 1000) - confirm against the original source, then inline.
    private static final String LINE_END = AbstractFormatter.DEFAULT_ROW_SEPARATOR;
    private static final int COARSE_PROGRESS_INTERVAL = PatternReplaceCharFilter.DEFAULT_MAX_BLOCK_CHARS;
    private static final int FINE_PROGRESS_INTERVAL = LogDocMergePolicy.DEFAULT_MIN_MERGE_DOCS;

    /** Below this many predications, progress is also reported at the fine interval. */
    private static final int FINE_PROGRESS_LIMIT = 10000;

    private FlagConfig flagConfig;
    private ElementalVectorStore elementalItemVectors;
    private ElementalVectorStore predicateVectors;
    private VectorStoreRAM semanticItemVectors;
    private final String[] itemFields = {SUBJECT_FIELD, OBJECT_FIELD};
    private LuceneUtils luceneUtils;

    private PSI() {
    }

    /**
     * Trains PSI vectors from the index described by {@code flagConfig} and writes the
     * elemental, semantic and predicate vector files named in that configuration.
     *
     * @param flagConfig configuration supplying the index location, vector settings and
     *     output file names
     * @throws IOException if reading the index or writing the vector files fails
     * @throws NullPointerException if one of the required fields has no terms in the index
     */
    public static void createIncrementalPSIVectors(FlagConfig flagConfig) throws IOException {
        PSI psi = new PSI();
        psi.flagConfig = flagConfig;
        psi.luceneUtils = new LuceneUtils(flagConfig);
        psi.trainIncrementalPSIVectors();
    }

    /**
     * Runs the whole pipeline: initialise stores, accumulate semantic vectors from the
     * predications, normalise them, and write all three stores.
     */
    private void trainIncrementalPSIVectors() throws IOException {
        this.elementalItemVectors = new ElementalVectorStore(this.flagConfig);
        this.semanticItemVectors = new VectorStoreRAM(this.flagConfig);
        this.predicateVectors = new ElementalVectorStore(this.flagConfig);
        this.flagConfig.setContentsfields(this.itemFields);

        initializeItemVectors();
        initializePredicateVectors();
        superposePredications();

        Enumeration<ObjectVector> allVectors = this.semanticItemVectors.getAllVectors();
        while (allVectors.hasMoreElements()) {
            allVectors.nextElement().getVector().normalize();
        }

        VectorStoreWriter.writeVectors(this.flagConfig.elementalvectorfile(), this.flagConfig, this.elementalItemVectors);
        VectorStoreWriter.writeVectors(this.flagConfig.semanticvectorfile(), this.flagConfig, this.semanticItemVectors);
        VectorStoreWriter.writeVectors(this.flagConfig.predicatevectorfile(), this.flagConfig, this.predicateVectors);
        VerbatimLogger.info("Finished writing vectors.\n");
    }

    /**
     * Returns the terms of {@code field}, failing with a descriptive message when the index
     * has none. The exception type matches what dereferencing the null result would raise.
     */
    private Terms requireTerms(String field) throws IOException {
        Terms terms = this.luceneUtils.getTermsForField(field);
        if (terms == null) {
            throw new NullPointerException(String.format("No terms for field '%s'. Please check that index at '%s' was built correctly for use with PSI.", field, this.flagConfig.luceneindexpath()));
        }
        return terms;
    }

    /**
     * Registers an elemental vector and a zero semantic vector for every distinct subject or
     * object term that passes the term filter.
     */
    private void initializeItemVectors() throws IOException {
        // The same text may occur as both a subject and an object; register it only once.
        HashSet<String> seenItems = new HashSet<String>();
        for (String field : this.itemFields) {
            TermsEnum termsEnum = requireTerms(field).iterator(null);
            BytesRef bytes;
            // next() yields null once the enumeration is exhausted; that ends the loop.
            while ((bytes = termsEnum.next()) != null) {
                Term term = new Term(field, bytes);
                if (!this.luceneUtils.termFilter(term)) {
                    VerbatimLogger.fine("Filtering out term: " + term + LINE_END);
                    continue;
                }
                String text = term.text();
                if (seenItems.add(text)) {
                    // The returned vector is not needed here; the call is made so the
                    // elemental store holds an entry for this item.
                    this.elementalItemVectors.getVector(text);
                    this.semanticItemVectors.putVector(text, VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension()));
                }
            }
        }
    }

    /**
     * Registers an elemental vector for every accepted predicate term and for its
     * {@code "-INV"} counterpart.
     */
    private void initializePredicateVectors() throws IOException {
        Terms predicateTerms = requireTerms(PREDICATE_FIELD);
        String[] predicateFields = {PREDICATE_FIELD};
        TermsEnum termsEnum = predicateTerms.iterator(null);
        BytesRef bytes;
        while ((bytes = termsEnum.next()) != null) {
            Term term = new Term(PREDICATE_FIELD, bytes);
            if (this.luceneUtils.termFilter(term, predicateFields, 0, Integer.MAX_VALUE, Integer.MAX_VALUE, 1)) {
                String predicate = term.text().trim();
                this.predicateVectors.getVector(predicate);
                this.predicateVectors.getVector(predicate + INVERSE_SUFFIX);
            }
        }
    }

    /**
     * Walks every predication term and adds its contribution to the semantic vectors of the
     * subject and object involved. Predications whose subject, object or predicate has no
     * registered vector are skipped and logged.
     */
    private void superposePredications() throws IOException {
        TermsEnum termsEnum = requireTerms(PREDICATION_FIELD).iterator(null);
        // Counts predications across the whole enumeration so progress can be reported.
        int processed = 0;
        BytesRef bytes;
        while ((bytes = termsEnum.next()) != null) {
            Term predication = new Term(PREDICATION_FIELD, bytes);
            processed++;
            if (processed % COARSE_PROGRESS_INTERVAL == 0
                    || (processed < FINE_PROGRESS_LIMIT && processed % FINE_PROGRESS_INTERVAL == 0)) {
                VerbatimLogger.info("Processed " + processed + " unique predications ... ");
            }

            // Only the first document listed for the predication term is read.
            DocsEnum docs = this.luceneUtils.getDocsForTerm(predication);
            docs.nextDoc();
            Document doc = this.luceneUtils.getDoc(docs.docID());
            String subject = doc.get(SUBJECT_FIELD);
            String predicate = doc.get(PREDICATE_FIELD);
            String object = doc.get(OBJECT_FIELD);

            boolean known = this.elementalItemVectors.containsVector(object)
                    && this.elementalItemVectors.containsVector(subject)
                    && this.predicateVectors.containsVector(predicate);
            if (!known) {
                logger.info("skipping predication " + subject + " " + predicate + " " + object);
                continue;
            }

            float subjectWeight = this.luceneUtils.getGlobalTermWeight(new Term(SUBJECT_FIELD, subject));
            float objectWeight = this.luceneUtils.getGlobalTermWeight(new Term(OBJECT_FIELD, object));
            float predicationWeight = this.luceneUtils.getLocalTermWeight(this.luceneUtils.getGlobalTermFreq(predication));

            Vector subjectSemantic = this.semanticItemVectors.getVector(subject);
            Vector objectSemantic = this.semanticItemVectors.getVector(object);
            Vector subjectElemental = this.elementalItemVectors.getVector(subject);
            Vector objectElemental = this.elementalItemVectors.getVector(object);
            Vector predicateVector = this.predicateVectors.getVector(predicate);
            Vector inversePredicateVector = this.predicateVectors.getVector(predicate + INVERSE_SUFFIX);

            // Subject accumulates (object elemental bound with predicate). Bind a copy so
            // the stored elemental vector is left untouched.
            Vector boundObject = objectElemental.copy();
            boundObject.bind(predicateVector);
            subjectSemantic.superpose(boundObject, predicationWeight * objectWeight, null);

            // Object accumulates (subject elemental bound with the inverse predicate).
            Vector boundSubject = subjectElemental.copy();
            boundSubject.bind(inversePredicateVector);
            objectSemantic.superpose(boundSubject, predicationWeight * subjectWeight, null);
        }
    }

    /**
     * Command-line entry point: parses flags, reports the key settings and builds the model.
     *
     * @param strArr command-line arguments understood by {@link FlagConfig#getFlagConfig}
     * @throws IllegalArgumentException if no Lucene index path was supplied
     * @throws IOException if reading the index or writing the vector files fails
     */
    public static void main(String[] strArr) throws IllegalArgumentException, IOException {
        FlagConfig flagConfig = FlagConfig.getFlagConfig(strArr);
        if (flagConfig.luceneindexpath().isEmpty()) {
            throw new IllegalArgumentException("-luceneindexpath argument must be provided.");
        }
        VerbatimLogger.info("Building PSI model from index in: " + flagConfig.luceneindexpath() + LINE_END);
        VerbatimLogger.info("Minimum frequency = " + flagConfig.minfrequency() + LINE_END);
        VerbatimLogger.info("Maximum frequency = " + flagConfig.maxfrequency() + LINE_END);
        VerbatimLogger.info("Number non-alphabet characters = " + flagConfig.maxnonalphabetchars() + LINE_END);
        createIncrementalPSIVectors(flagConfig);
    }
}
