package edu.umass.cs.mallet.projects.seg_plus_coref.condclust.tui;

import com.wcohen.secondstring.AbstractStatisticalTokenDistance;
import com.wcohen.secondstring.TFIDF;
import com.wcohen.secondstring.tokens.NGramTokenizer;
import com.wcohen.secondstring.tokens.SimpleTokenizer;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.MaxEnt;
import edu.umass.cs.mallet.base.classify.MaxEntTrainer;
import edu.umass.cs.mallet.base.classify.Trial;
import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.PrintInputAndTarget;
import edu.umass.cs.mallet.base.pipe.SerialPipes;
import edu.umass.cs.mallet.base.pipe.Target2Label;
import edu.umass.cs.mallet.base.pipe.iterator.FileIterator;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.util.CommandOption;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.RegexFileFilter;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.cluster.ConditionalClusterer;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.cluster.ConditionalClustererTrainer;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.AllLinks;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.ClusterHomogeneity;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.ClusterSize;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.ForAll;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.NodeClusterPair2FeatureVector;
import edu.umass.cs.mallet.projects.seg_plus_coref.condclust.pipe.ThereExists;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.AuthorPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.Citation;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.CitationUtils;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.ExactFieldMatchPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.FieldStringDistancePipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.GlobalPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.HeuristicPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.InterFieldPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.NodePair2FeatureVector;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.PageMatchPipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.coreference.YearsWithinFivePipe;
import edu.umass.cs.mallet.projects.seg_plus_coref.ie.IEInterface;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/umass/cs/mallet/projects/seg_plus_coref/condclust/tui/ConditionalClustererTUI.class */
/**
 * Command-line driver that trains a {@link ConditionalClusterer} on citations read from the
 * training directories and evaluates its clustering of the citations in the testing directories.
 *
 * <p>Per the option help text, each input directory holds one file per cluster. When
 * {@code use-pairwise-classifier} is set, a pairwise MaxEnt classifier is trained first and
 * passed into the node/cluster feature pipe.
 */
public class ConditionalClustererTUI {
    private static final Logger logger = MalletLogger.getLogger(ConditionalClustererTUI.class.getName());

    /** Accepts every file name; all files under an input directory are read. Compiled once. */
    private static final Pattern ALL_FILES = Pattern.compile(".*");

    static CommandOption.SpacedStrings trainingDirs = new CommandOption.SpacedStrings(ConditionalClustererTUI.class,
            "training-dirs", "DIR...", true, null,
            "The directories containing the citations to be clustered at training time. One file per cluster.", null);
    static CommandOption.SpacedStrings testingDirs = new CommandOption.SpacedStrings(ConditionalClustererTUI.class,
            "testing-dirs", "DIR...", true, null,
            "The directories containing the citations to be clustered at test time. One file per cluster.", null);
    static CommandOption.Boolean randomOrderClustering = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "random-order-clustering", "BOOL", false, false,
            "At test time, choose the nodes to consider at random", null);
    static CommandOption.Boolean sampleTrainingInstances = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "sample-training-instances", "BOOL", false, true,
            "Generate instances by sampling from true clusters", null);
    static CommandOption.Integer numberTrainingInstances = new CommandOption.Integer(ConditionalClustererTUI.class,
            "number-training-instances", "INTEGER", true, 5000,
            "The number of training instances to sample", null);
    static CommandOption.Integer randomSeed = new CommandOption.Integer(ConditionalClustererTUI.class,
            "random-seed", "INTEGER", true, 1,
            "Seed for random number in random order clustering", null);
    static CommandOption.Integer numRandomTrials = new CommandOption.Integer(ConditionalClustererTUI.class,
            "num-random-trials", "INTEGER", true, 5,
            "number of random trials to run", null);
    static CommandOption.Boolean errorAnalysis = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "error-analysis", "BOOL", false, false,
            "Print errors (False positives)", null);
    static CommandOption.Boolean useCRF = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-crf", "BOOL", false, false,
            "Use CRF or not.", null);
    static CommandOption.Boolean useFeatureInduction = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-feature-induction", "BOOL", false, false,
            "Use Feature Induction or Not.", null);
    static CommandOption.Boolean useClusterSize = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-cluster-size", "BOOL", true, true,
            "add feature that is cluster's size", null);
    static CommandOption.Boolean useThereExists = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-there-exists", "BOOL", true, true,
            "Use thereExists pipe.", null);
    static CommandOption.Boolean usePairwiseClassifier = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-pairwise-classifier", "BOOL", true, true,
            "Use pairwise classifier to weight edges.", null);
    static CommandOption.Boolean useClusterHomogeneity = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "use-cluster-homogeneity", "BOOL", true, true,
            "add feature that is within-cluster similarity.", null);
    static CommandOption.Boolean printInputAndTarget = new CommandOption.Boolean(ConditionalClustererTUI.class,
            "print-input-and-target", "BOOL", false, false,
            "Print features and target.", null);
    static CommandOption.String crfInputFile = new CommandOption.String(ConditionalClustererTUI.class,
            "crf-input-file", "FILENAME", true, null,
            "The name of the file to read the trained CRF for testing.", null);
    static CommandOption.Integer numNBest = new CommandOption.Integer(ConditionalClustererTUI.class,
            "num-n-best", "INTEGER", true, 3,
            "Number of n-best candidates to store.", null);
    static CommandOption.Integer nthViterbi = new CommandOption.Integer(ConditionalClustererTUI.class,
            "nth-viterbi", "INTEGER", true, 0,
            "Number of n-best candidates to use .", null);
    // NOTE(review): the default is a CRF cost constant (likely a decompiler substitution for a
    // literal that happened to match it); confirm the intended default threshold.
    static CommandOption.Double negativeClusterThreshold = new CommandOption.Double(ConditionalClustererTUI.class,
            "negative-cluster-threshold", "DECIMAL", true, Transducer.ZERO_COST,
            "Decision threshold to place a node in a cluster. Takes opposite of input because CommandOptions seem to have trouble with negative inputs", null);
    static CommandOption.Double positiveInstanceRatio = new CommandOption.Double(ConditionalClustererTUI.class,
            "positive-instance-ratio", "DECIMAL", true, 0.1d,
            "Ratio of positive to negative training instances", null);
    static final CommandOption.List commandOptions = new CommandOption.List(
            "Training and testing a conditional clusterer.",
            new CommandOption[]{trainingDirs, testingDirs, sampleTrainingInstances, numberTrainingInstances,
                    errorAnalysis, useCRF, useFeatureInduction, crfInputFile, numNBest, nthViterbi,
                    negativeClusterThreshold, randomOrderClustering, randomSeed, numRandomTrials,
                    usePairwiseClassifier, useThereExists, useClusterSize, useClusterHomogeneity,
                    printInputAndTarget, positiveInstanceRatio});

    /**
     * Parses the command line, trains the conditional clusterer and evaluates it on the test set.
     *
     * <p>Evaluation is either one greedy clustering pass or {@code num-random-trials} random-order
     * passes, depending on {@code random-order-clustering}.
     *
     * @param strArr command-line arguments, parsed by {@link #commandOptions}
     */
    public static void main(String[] strArr) {
        commandOptions.process(strArr);
        commandOptions.logOptions(logger);

        IEInterface ieInterface = loadIEInterface();
        ArrayList[] trainingNodesByDir = createNodesFromFiles(trainingDirs.value(), ieInterface, CitationUtils.PAPER);
        ArrayList[] testingNodesByDir = createNodesFromFiles(testingDirs.value(), ieInterface, CitationUtils.PAPER);
        ArrayList trainingNodes = concatenate(trainingNodesByDir);
        ArrayList testingNodes = concatenate(testingNodesByDir);

        Collection trueTrainingClusters = CitationUtils.makeCollections(trainingNodes);
        Collection trueTestingClusters = CitationUtils.makeCollections(testingNodes);

        // The paper pipe and pairwise classifier are only built when the option asks for them.
        Classifier pairwiseClassifier = usePairwiseClassifier.value()
                ? trainPairwiseClassifier(trainingNodesByDir, getPaperPipe(trainingNodes))
                : null;
        // The option holds the negated threshold (see its help text), so negate it back here.
        ConditionalClustererTrainer trainer =
                new ConditionalClustererTrainer(getPipe(pairwiseClassifier), -negativeClusterThreshold.value());
        ConditionalClusterer clusterer = trainer.train(trueTrainingClusters, useFeatureInduction.value(),
                sampleTrainingInstances.value(), positiveInstanceRatio.value(), numberTrainingInstances.value());
        System.err.println("FINISHED TRAINING. BEGIN CLUSTERING.");

        // With error-analysis on, the clusterer also receives the true test clusters (presumably
        // so it can report false positives); otherwise it receives null.
        Collection errorAnalysisClusters = errorAnalysis.value() ? trueTestingClusters : null;

        if (!randomOrderClustering.value()) {
            CitationUtils.evaluateClustering(trueTestingClusters,
                    clusterer.cluster(testingNodes, errorAnalysisClusters), "GREEDY COREFERENCE RESULTS");
            return;
        }
        for (int trial = 0; trial < numRandomTrials.value(); trial++) {
            // Each trial gets its own seed, offset by 10 per trial from random-seed.
            Random rng = new Random(randomSeed.value() + (trial * 10));
            Collection predicted = clusterer.clusterRandom(testingNodes, errorAnalysisClusters, rng);
            System.err.println("FINISHED CLUSTERING. BEGIN EVALUATION.");
            CitationUtils.evaluateClustering(trueTestingClusters, predicted,
                    "RANDOM TRIAL " + trial + " COREFERENCE RESULTS");
        }
    }

    /**
     * Concatenates the per-directory node lists into one list, keeping directory order.
     *
     * @param nodesByDir one node list per input directory
     * @return a new list that holds every node
     */
    private static ArrayList concatenate(ArrayList[] nodesByDir) {
        ArrayList all = new ArrayList();
        for (ArrayList nodes : nodesByDir) {
            all.addAll(nodes);
        }
        return all;
    }

    /**
     * Trains a MaxEnt classifier on the node pairs built from each directory's node list. Prints
     * the training-set F1 for the "yes" label to stdout.
     *
     * @param arrayListArr node lists, one per training directory; pairs are made within each list
     * @param pipe         pipe that turns a node pair into a labelled feature vector
     * @return the trained classifier
     */
    private static Classifier trainPairwiseClassifier(ArrayList[] arrayListArr, Pipe pipe) {
        InstanceList pairs = new InstanceList(pipe);
        for (ArrayList nodes : arrayListArr) {
            pairs.add(CitationUtils.makePairs(pipe, nodes));
        }
        MaxEnt maxEnt = (MaxEnt) new MaxEntTrainer().train(pairs, null, null, null, null);
        // Freeze the feature alphabet once training is done. Presumably this stops later uses of
        // the pipe from adding features the classifier never saw.
        pairs.getDataAlphabet().stopGrowth();
        System.out.println("Pairwise classifier: -> Training F1 on \"yes\" is: " + new Trial(maxEnt, pairs).labelF1("yes"));
        return maxEnt;
    }

    /**
     * Builds the pipe that turns a citation pair into a feature vector for the pairwise classifier.
     *
     * @param arrayList all training nodes; the string-distance metrics are computed from them
     * @return the serial pipe, ending in {@link NodePair2FeatureVector} and {@link Target2Label}
     */
    private static Pipe getPaperPipe(ArrayList arrayList) {
        AbstractStatisticalTokenDistance distanceMetric = CitationUtils.computeDistanceMetric(arrayList);
        TFIDF tfidf = new TFIDF();
        TFIDF trigramTfidf = new TFIDF(new NGramTokenizer(3, 3, false, new SimpleTokenizer(true, true)));
        // NOTE(review): tfidf is passed to makeDistMetric but no pipe below uses it; confirm it is needed.
        CitationUtils.makeDistMetric(arrayList, tfidf, trigramTfidf);
        return new SerialPipes(new Pipe[]{
                new ExactFieldMatchPipe(Citation.corefFields),
                new PageMatchPipe(),
                new YearsWithinFivePipe(),
                new FieldStringDistancePipe(trigramTfidf, Citation.corefFields, "trigramTFIDF"),
                new GlobalPipe(distanceMetric),
                new AuthorPipe(distanceMetric),
                new HeuristicPipe(Citation.corefFields),
                new InterFieldPipe(),
                new NodePair2FeatureVector(),
                new Target2Label()});
    }

    /**
     * Builds the pipe that turns a node/cluster pair into a feature vector for the conditional
     * clusterer. Which components are included depends on the command-line options.
     *
     * @param classifier pairwise classifier for the {@link AllLinks} and {@link ClusterHomogeneity}
     *                   features, or {@code null} to leave those features out
     * @return the serial pipe, always ending in {@link Target2Label}
     */
    private static Pipe getPipe(Classifier classifier) {
        ArrayList pipes = new ArrayList();
        pipes.add(new ForAll(Citation.corefFields));
        if (useThereExists.value()) {
            pipes.add(new ThereExists(Citation.corefFields));
        }
        if (classifier != null) {
            pipes.add(new AllLinks(classifier));
            if (useClusterHomogeneity.value()) {
                pipes.add(new ClusterHomogeneity(classifier));
            }
        }
        if (useClusterSize.value()) {
            pipes.add(new ClusterSize());
        }
        pipes.add(new NodeClusterPair2FeatureVector());
        if (printInputAndTarget.value()) {
            pipes.add(new PrintInputAndTarget());
        }
        pipes.add(new Target2Label());
        return new SerialPipes((Pipe[]) pipes.toArray(new Pipe[0]));
    }

    /**
     * Loads the trained CRF when {@code use-crf} is set.
     *
     * @return the loaded interface, or {@code null} when {@code use-crf} is off
     * @throws IllegalArgumentException if {@code use-crf} is set but {@code crf-input-file} is not
     */
    private static IEInterface loadIEInterface() {
        if (!useCRF.value()) {
            return null;
        }
        String crfPath = crfInputFile.value();
        if (crfPath == null) {
            // Without this check, new File(null) fails with an unexplained NullPointerException.
            throw new IllegalArgumentException("Option use-crf requires option crf-input-file to name a trained CRF file");
        }
        File crfFile = new File(crfPath);
        IEInterface ieInterface = new IEInterface(crfFile);
        ieInterface.loadCRF(crfFile);
        return ieInterface;
    }

    /**
     * Reads the citation nodes from every file in each directory.
     *
     * @param strArr      input directories
     * @param iEInterface CRF interface used when {@code use-crf} is set; may be {@code null}
     * @param str         citation type passed through to {@link CitationUtils#computeNodes}
     * @return one node list per directory, in the same order as {@code strArr}
     * @throws IllegalArgumentException if {@code strArr} is {@code null} (the directory option was not given)
     */
    private static ArrayList[] createNodesFromFiles(String[] strArr, IEInterface iEInterface, String str) {
        if (strArr == null) {
            throw new IllegalArgumentException("No input directories given; set the training-dirs and testing-dirs options");
        }
        ArrayList[] nodesByDir = new ArrayList[strArr.length];
        for (int i = 0; i < strArr.length; i++) {
            FileIterator files = new FileIterator(new File(strArr[i]), new RegexFileFilter(ALL_FILES));
            nodesByDir[i] = CitationUtils.computeNodes(files.getFileArray(), iEInterface, useCRF.value(),
                    numNBest.value(), nthViterbi.value(), str);
        }
        return nodesByDir;
    }
}
