package uk.ac.cam.ch.wwmm.oscarMEMM.memm.rescorer;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import nu.xom.Attribute;
import nu.xom.Builder;
import nu.xom.Document;
import nu.xom.Element;
import nu.xom.Nodes;
import nu.xom.ParsingException;
import opennlp.maxent.GIS;
import opennlp.maxent.GISModel;
import opennlp.model.Event;
import opennlp.model.EventCollectorAsStream;
import opennlp.model.TwoPassDataIndexer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.cam.ch.wwmm.oscar.document.NamedEntity;
import uk.ac.cam.ch.wwmm.oscar.document.TokenSequence;
import uk.ac.cam.ch.wwmm.oscar.document.XOMBasedProcessingDocument;
import uk.ac.cam.ch.wwmm.oscar.document.XOMBasedProcessingDocumentFactory;
import uk.ac.cam.ch.wwmm.oscar.exceptions.DataFormatException;
import uk.ac.cam.ch.wwmm.oscar.types.NamedEntityType;
import uk.ac.cam.ch.wwmm.oscar.xmltools.XOMTools;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.MEMMModel;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.gis.SimpleEventCollector;
import uk.ac.cam.ch.wwmm.oscarMEMM.memm.gis.StringGISModelWriter;
import uk.ac.cam.ch.wwmm.oscartokeniser.Tokeniser;

/* loaded from: input_file:uk/ac/cam/ch/wwmm/oscarMEMM/memm/rescorer/MEMMOutputRescorerTrainer.class */
/**
 * Trains one maximum-entropy ({@link GIS}) model per {@link NamedEntityType} that rescores the
 * named entities proposed by a {@link MEMMModel}.
 *
 * <p>Typical use: call {@link #trainOnFile(File, MEMMModel)} once per annotated training file,
 * then {@link #finishTraining()}, then {@link #writeElement()} or
 * {@link #getMEMMOutputRescorer()}. Every MEMM prediction becomes one training event whose
 * outcome is {@code "T"} when it matches a gold-standard {@code <ne>} element of the file and
 * {@code "F"} otherwise.
 *
 * <p>Instances accumulate mutable state without any synchronization; confine each instance to a
 * single thread.
 */
public final class MEMMOutputRescorerTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(MEMMOutputRescorerTrainer.class);

    /** Experiment name embedded in the file names written by {@link #recPrec}. */
    static String experName = "rescoreHalf";

    private final MEMMModel memm;
    private final double confidenceThreshold;

    /** Trained models keyed by entity type; {@code null} until {@link #finishTraining()} succeeds. */
    Map<NamedEntityType, GISModel> modelsByNamedEntityType;

    /** Number of GIS iterations used by {@link #finishTraining()}. */
    int trainingCycles = 200;

    /** Training events accumulated by {@link #trainOnFile(File, MEMMModel)}, keyed by entity type. */
    Map<NamedEntityType, List<Event>> eventsByNamedEntityType = new HashMap<NamedEntityType, List<Event>>();

    // Evaluation statistics accumulated by runOnFile.
    double grandTotalGain = 0.0d;
    List<Double> goodProbsBefore = new ArrayList<Double>();
    List<Double> goodProbsAfter = new ArrayList<Double>();
    List<Double> badProbsBefore = new ArrayList<Double>();
    List<Double> badProbsAfter = new ArrayList<Double>();
    int totalRecall = 0;

    /**
     * Writes a precision/recall curve to a CSV file under a hard-coded developer directory. Each
     * line is {@code recall \t precision \t score}, one line per score in the merged,
     * descending-ordered lists.
     *
     * @param goodProbs scores of predictions that matched the gold standard; sorted in place
     * @param badProbs scores of predictions that did not match; sorted in place
     * @param goldTotal total number of gold-standard entities (the recall denominator)
     * @param label prefix of the output file name
     * @throws IOException if the output file cannot be opened
     * @deprecated debugging aid that writes to a machine-specific, hard-coded path
     */
    @Deprecated
    private static void recPrec(List<Double> goodProbs, List<Double> badProbs, int goldTotal, String label)
            throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(new File(
                "/home/ptc24/tmp/rpres/" + label + "_" + experName + ".csv")));
        try {
            Collections.sort(goodProbs, Collections.reverseOrder());
            Collections.sort(badProbs, Collections.reverseOrder());
            int good = 0;
            int bad = 0;
            // Merge both descending lists so every score becomes a cut-off point. Keep going until
            // BOTH lists are exhausted, otherwise the curve is truncated before full recall.
            // Ties are resolved in favour of the bad list (the pessimistic choice).
            while (good < goodProbs.size() || bad < badProbs.size()) {
                boolean takeGood = bad >= badProbs.size()
                        || (good < goodProbs.size()
                            && goodProbs.get(good).doubleValue() > badProbs.get(bad).doubleValue());
                double threshold;
                if (takeGood) {
                    threshold = goodProbs.get(good).doubleValue();
                    good++;
                } else {
                    threshold = badProbs.get(bad).doubleValue();
                    bad++;
                }
                out.println(((good * 1.0d) / goldTotal) + "\t" + ((good * 1.0d) / (good + bad)) + "\t" + threshold);
            }
        } finally {
            out.close();
        }
    }

    /**
     * @param mEMMModel the model whose predictions are collected and rescored
     * @param d confidence threshold passed to {@code MEMMModel.findNEs} when collecting predictions
     */
    public MEMMOutputRescorerTrainer(MEMMModel mEMMModel, double d) {
        this.memm = mEMMModel;
        this.confidenceThreshold = d;
    }

    /**
     * Trains on {@code file} using the model supplied to the constructor.
     *
     * @deprecated use {@link #trainOnFile(File, MEMMModel)}
     */
    @Deprecated
    public void trainOnFile(File file) throws Exception {
        trainOnFile(file, this.memm);
    }

    /**
     * Adds one training event per entity that {@code mEMMModel} finds in {@code file}.
     *
     * @param file annotated XML training file
     * @param mEMMModel model used both to propose entities and to supply dictionary names for
     *     feature extraction
     * @throws IOException if the file cannot be read
     * @throws DataFormatException if the file is not well-formed XML (the parser exception is
     *     attached as the cause)
     */
    public void trainOnFile(File file, MEMMModel mEMMModel) throws IOException, DataFormatException {
        File parent = file.getParentFile();
        LOG.debug(parent != null ? parent.getName() : file.getName());
        try {
            AnnotatedFile annotated = annotate(file, mEMMModel);
            FeatureExtractor featureExtractor = new FeatureExtractor(annotated.predictedEntities);
            for (NamedEntity entity : annotated.predictedEntities) {
                String outcome = annotated.goldKeys.contains(entity.toString()) ? "T" : "F";
                List<?> features = featureExtractor.getFeatures(entity, mEMMModel.getChemNameDictNames());
                NamedEntityType type = entity.getType();
                List<Event> events = this.eventsByNamedEntityType.get(type);
                if (events == null) {
                    events = new ArrayList<Event>();
                    this.eventsByNamedEntityType.put(type, events);
                }
                events.add(new Event(outcome, features.toArray(new String[0])));
            }
        } catch (ParsingException e) {
            DataFormatException failure =
                    new DataFormatException("incorrectly formatted training file: " + file.getName());
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Trains one GIS model per entity type from the accumulated events and publishes them in
     * {@link #modelsByNamedEntityType}. The field is replaced only after every model has trained,
     * so a failure part-way does not leave a half-filled map behind.
     *
     * @throws IOException if the data indexer fails
     */
    public void finishTraining() throws IOException {
        Map<NamedEntityType, GISModel> models = new HashMap<NamedEntityType, GISModel>();
        for (Map.Entry<NamedEntityType, List<Event>> entry : this.eventsByNamedEntityType.entrySet()) {
            List<Event> events = entry.getValue();
            if (events.size() == 1) {
                // A lone event is presented twice, presumably because GIS training needs more than
                // one event (TODO confirm). Work on a copy so the accumulated data is not altered.
                events = new ArrayList<Event>(events);
                events.add(events.get(0));
            }
            models.put(entry.getKey(), GIS.trainModel(this.trainingCycles,
                    new TwoPassDataIndexer(new EventCollectorAsStream(new SimpleEventCollector(events)), 1)));
        }
        this.modelsByNamedEntityType = models;
    }

    /**
     * Evaluates the trained models on {@code file}. For each prediction this records the score
     * before rescoring (MEMM confidence) and after rescoring (probability of {@code "T"}).
     *
     * <p>The gain added to {@link #grandTotalGain} is {@code log2(errorBefore / errorAfter)}.
     * Here "error" is the probability a score assigns to the wrong outcome: {@code 1 - p} for
     * gold matches and {@code p} otherwise. Positive values mean the rescorer moved probability
     * toward the correct outcome. Scores of exactly 0 or 1 yield infinite or NaN terms.
     */
    private void runOnFile(File file) throws IOException, ParsingException, DataFormatException {
        Map<NamedEntityType, GISModel> models = requireTrainedModels();
        AnnotatedFile annotated = annotate(file, this.memm);
        this.totalRecall += annotated.goldKeys.size();
        double fileGain = 0.0d;
        FeatureExtractor featureExtractor = new FeatureExtractor(annotated.predictedEntities);
        for (NamedEntity entity : annotated.predictedEntities) {
            GISModel model = models.get(entity.getType());
            // Skip types with no model, or whose training events all had the same outcome.
            if (model == null || model.getNumOutcomes() != 2) {
                continue;
            }
            List<?> features = featureExtractor.getFeatures(entity, this.memm.getChemNameDictNames());
            double rescored = model.eval(features.toArray(new String[0]))[model.getIndex("T")];
            double original = entity.getConfidence();
            boolean gold = annotated.goldKeys.contains(entity.toString());
            if (gold) {
                this.goodProbsBefore.add(Double.valueOf(original));
                this.goodProbsAfter.add(Double.valueOf(rescored));
            } else {
                this.badProbsBefore.add(Double.valueOf(original));
                this.badProbsAfter.add(Double.valueOf(rescored));
            }
            double errorBefore = gold ? 1.0d - original : original;
            double errorAfter = gold ? 1.0d - rescored : rescored;
            fileGain += (Math.log(errorBefore) - Math.log(errorAfter)) / Math.log(2.0d);
        }
        this.grandTotalGain += fileGain;
    }

    /**
     * Serialises every trained model as the text content of a {@code <maxent type="...">} child
     * of a {@code <rescorer>} element.
     *
     * @throws IOException if a model cannot be serialised
     * @throws IllegalStateException if {@link #finishTraining()} has not been called
     */
    public Element writeElement() throws IOException {
        Element rescorer = new Element("rescorer");
        for (Map.Entry<NamedEntityType, GISModel> entry : requireTrainedModels().entrySet()) {
            Element maxent = new Element("maxent");
            maxent.addAttribute(new Attribute("type", entry.getKey().getName()));
            StringGISModelWriter writer = new StringGISModelWriter(entry.getValue());
            writer.persist();
            maxent.appendChild(writer.toString());
            rescorer.appendChild(maxent);
        }
        return rescorer;
    }

    /**
     * Builds a {@link MEMMOutputRescorer} by round-tripping the trained models through
     * {@link #writeElement()}.
     *
     * @throws Error wrapping any failure, including calling this before {@link #finishTraining()}
     */
    public MEMMOutputRescorer getMEMMOutputRescorer() {
        MEMMOutputRescorer mEMMOutputRescorer = new MEMMOutputRescorer();
        try {
            mEMMOutputRescorer.readElement(writeElement());
            return mEMMOutputRescorer;
        } catch (Exception e) {
            throw new Error("Error while creating MEMM output rescorer: " + e.getMessage(), e);
        }
    }

    /** Returns the trained models, failing fast if {@link #finishTraining()} has not run. */
    private Map<NamedEntityType, GISModel> requireTrainedModels() {
        if (this.modelsByNamedEntityType == null) {
            throw new IllegalStateException("finishTraining() must be called before using the trained models");
        }
        return this.modelsByNamedEntityType;
    }

    /**
     * Parses {@code file} and removes {@code cmlPile} elements. It unwraps CPR {@code <ne>}
     * elements with {@code XOMTools.removeElementPreservingText} and tokenises the result. It then
     * collects the gold {@code <ne>} keys and the entities {@code model} proposes in each token
     * sequence.
     */
    private AnnotatedFile annotate(File file, MEMMModel model)
            throws IOException, ParsingException, DataFormatException {
        Document document = new Builder().build(file);
        Nodes cmlPiles = document.query("//cmlPile");
        for (int i = 0; i < cmlPiles.size(); i++) {
            cmlPiles.get(i).detach();
        }
        Nodes cprEntities = document.query("//ne[@type='CPR']");
        for (int i = 0; i < cprEntities.size(); i++) {
            // Nodes.get returns Node; XPath "//ne[...]" only selects elements.
            XOMTools.removeElementPreservingText((Element) cprEntities.get(i));
        }
        XOMBasedProcessingDocument processingDocument = XOMBasedProcessingDocumentFactory.getInstance()
                .makeTokenisedDocument(Tokeniser.getDefaultInstance(), document, true, false);
        AnnotatedFile annotated = new AnnotatedFile();
        for (TokenSequence tokenSequence : processingDocument.getTokenSequences()) {
            Nodes goldEntities = tokenSequence.getElem().query(".//ne");
            for (int i = 0; i < goldEntities.size(); i++) {
                annotated.goldKeys.add(goldKey((Element) goldEntities.get(i)));
            }
            annotated.predictedEntities.addAll(model.findNEs(tokenSequence, this.confidenceThreshold));
        }
        return annotated;
    }

    /**
     * Renders a gold {@code <ne>} element as {@code [NE:type:start:end:text]}. Gold membership is
     * tested with {@code goldKeys.contains(entity.toString())}, so this format is assumed to match
     * {@code NamedEntity.toString()}. Verify if either side changes.
     */
    private static String goldKey(Element ne) {
        return "[NE:" + ne.getAttributeValue("type") + ":" + ne.getAttributeValue("xtspanstart") + ":"
                + ne.getAttributeValue("xtspanend") + ":" + ne.getValue() + "]";
    }

    /** Gold-standard keys and MEMM predictions gathered from one annotated file. */
    private static final class AnnotatedFile {
        /** Gold {@code <ne>} elements rendered by {@link #goldKey(Element)}. */
        final Set<String> goldKeys = new HashSet<String>();
        /** Entities proposed by the MEMM, in token-sequence iteration order. */
        final ArrayList<NamedEntity> predictedEntities = new ArrayList<NamedEntity>();
    }
}
