package de.julielab.jcore.ae.jnet.tagger;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.Segment;
import cc.mallet.fst.SumLatticeConstrained;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import de.julielab.jcore.ae.jnet.utils.IOEvaluation;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Properties;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.propertyeditors.StringArrayPropertyEditor;

/* loaded from: input_file:de/julielab/jcore/ae/jnet/tagger/NETagger.class */
/**
 * Named-entity tagger that trains, applies and (de)serializes a MALLET model.
 *
 * <p>Depending on the {@code max_ent} flag the underlying model is either a linear-chain
 * {@link CRF} (default) or a {@link MaxEnt} classifier. The feature configuration is held as a
 * {@link Properties} object and is stored together with the model by {@link #writeModel(String)}.
 *
 * <p>Error conditions are reported through logged unchecked exceptions; this class never
 * terminates the JVM on its own.
 */
public class NETagger {
    static final Logger LOGGER = LoggerFactory.getLogger(NETagger.class);

    /** Classpath location of the feature configuration used by the no-argument constructor. */
    private static final String DEFAULT_FEATURE_CONFIG = "/defaultFeatureConf.conf";

    /** Line appended to the IOB output after the tokens of each sentence. */
    private static final String SENTENCE_SEPARATOR_LINE = "O\tO";

    private Properties featureConfig;
    private Object model = null;
    private boolean trained = false;
    private int number_iterations = 0;
    private boolean max_ent = false;
    /** Feature pipe of the training data; only assigned by {@link #train(ArrayList)}. */
    private Pipe generalPipe = null;
    private Pipe dummyPipe = null;

    /**
     * Creates a tagger configured from the classpath resource {@code /defaultFeatureConf.conf}.
     * The loaded values are installed as defaults of the feature configuration.
     *
     * <p>An I/O error while reading the resource is logged and leaves the defaults empty.
     *
     * @throws IllegalStateException if the default configuration resource cannot be found
     */
    public NETagger() {
        Properties defaults = new Properties();
        LOGGER.debug("loading default configuration");
        try (InputStream in = getClass().getResourceAsStream(DEFAULT_FEATURE_CONFIG)) {
            if (in == null) {
                // Properties.load(null) would otherwise fail with an unexplained NullPointerException.
                throw failState("default feature configuration " + DEFAULT_FEATURE_CONFIG + " not found on classpath");
            }
            defaults.load(in);
        } catch (IOException e) {
            LOGGER.error("could not read default feature configuration " + DEFAULT_FEATURE_CONFIG, e);
        }
        this.featureConfig = new Properties(defaults);
    }

    /**
     * Creates a tagger configured from the given properties file.
     *
     * <p>An I/O error while reading the file is logged and leaves the configuration with whatever
     * was loaded up to that point.
     *
     * @param file feature configuration file
     * @throws IllegalStateException if {@code file} is not an existing regular file
     */
    public NETagger(File file) {
        this.featureConfig = new Properties();
        if (!file.isFile()) {
            throw failState("specified file for feature configuration not found!");
        }
        try (InputStream in = new FileInputStream(file)) {
            this.featureConfig.load(in);
        } catch (IOException e) {
            LOGGER.error("could not read feature configuration from " + file, e);
        }
    }

    /**
     * @return {@code true} once a model has been trained or loaded
     */
    public boolean isTrained() {
        return this.trained;
    }

    /**
     * Trains a new model on the given sentences, replacing any previous model.
     *
     * <p>A CRF is trained unless the maximum entropy flag is set. With an iteration count of 0
     * the trainer's default/optimized training routine is used, otherwise training runs for the
     * configured number of iterations.
     *
     * @param arrayList training sentences
     */
    public void train(ArrayList<Sentence> arrayList) {
        System.out.println("   * training model... on " + arrayList.size() + " sentences");
        InstanceList trainingData = new FeatureGenerator().createFeatureData(arrayList, this.featureConfig);
        this.generalPipe = trainingData.getPipe();
        LOGGER.info("  * number of features for training: {}", trainingData.getDataAlphabet().size());
        long start = System.currentTimeMillis();
        if (this.max_ent) {
            trainMaxEnt(trainingData);
        } else {
            trainCrf(trainingData);
        }
        LOGGER.info("  * learning took (sec): {}", (System.currentTimeMillis() - start) / 1000);
        this.trained = true;
    }

    /** Builds and trains a CRF on the given feature data and stores it as the current model. */
    private void trainCrf(InstanceList trainingData) {
        CRF crf = new CRF(trainingData.getPipe(), (Pipe) null);
        this.model = crf;
        crf.addStatesForBiLabelsConnectedAsIn(trainingData);
        CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
        if (this.number_iterations == 0) {
            LOGGER.info("JNET training: model converged: {}", trainer.trainOptimized(trainingData));
        } else {
            trainer.train(trainingData, this.number_iterations);
            LOGGER.info("JNET training: with iterations = {}", this.number_iterations);
        }
    }

    /** Converts the sequence features to per-token instances and trains a MaxEnt classifier. */
    private void trainMaxEnt(InstanceList trainingData) {
        this.dummyPipe = new SerialPipes(new Pipe[]{
                new METrainerDummyPipe(trainingData.getDataAlphabet(), trainingData.getTargetAlphabet())});
        InstanceList classifierData = FeatureGenerator.convertFeatsforClassifier(this.dummyPipe, trainingData);
        LOGGER.info("train() - now training on {} instances", trainingData.size());
        MaxEntTrainer trainer = new MaxEntTrainer();
        LOGGER.info("JNET ME training ...");
        this.model = this.number_iterations == 0
                ? trainer.train(classifierData)
                : trainer.train(classifierData, this.number_iterations);
    }

    /**
     * Predicts a label for every unit of the sentence and writes it into the unit.
     *
     * @param sentence sentence to label; its units are modified in place
     * @param z        if {@code true} and a CRF model is used, a confidence value is additionally
     *                 set on every unit; ignored for maximum entropy models
     * @throws IllegalStateException if no model is available, if the number of predicted labels
     *                               does not match the number of units, or if a maximum entropy
     *                               model is used without a feature pipe
     */
    public void predict(Sentence sentence, boolean z) {
        ensureModelAvailable("No model available. Train or load trained model first.");
        if (this.max_ent) {
            System.out.println("  * predicting with me model...");
            ensureGeneralPipeAvailable();
            Classifier classifier = (Classifier) this.model;
            InstanceList tokenInstances = FeatureGenerator.convertFeatsforClassifier(
                    classifier.getInstancePipe(), this.generalPipe.instanceFrom(new Instance(sentence, "", "", "")));
            labelUnitsWithClassifier(classifier, tokenInstances, sentence.getUnits(), null);
        } else {
            labelSentenceWithCrf(sentence, z, null);
        }
    }

    /**
     * Predicts labels for all sentences and returns the result as tab-separated lines.
     *
     * <p>For every unit one line {@code representation<TAB>label} is produced (followed by
     * {@code <TAB>confidence} for CRF models when {@code z} is set); after each sentence the line
     * {@code "O\tO"} is appended. The units are labelled in place as well.
     *
     * @param arrayList sentences to label
     * @param z         whether confidence values are estimated (CRF models only)
     * @return the generated output lines
     * @throws IllegalStateException if no model is available, if the number of predicted labels
     *                               does not match the number of units, or if a maximum entropy
     *                               model is used without a feature pipe
     */
    public ArrayList<String> predictIOB(ArrayList<Sentence> arrayList, boolean z) {
        ensureModelAvailable("no model available. Train or load trained model first.");
        long start = System.currentTimeMillis();
        ArrayList<String> iobLines = new ArrayList<>();
        if (this.max_ent) {
            System.out.println("  * predicting with me model...");
            ensureGeneralPipeAvailable();
            Classifier classifier = (Classifier) this.model;
            // All sentences are piped first and converted afterwards, in this order.
            InstanceList sentenceInstances = new InstanceList(this.generalPipe);
            for (Sentence sentence : arrayList) {
                sentenceInstances.add(this.generalPipe.instanceFrom(new Instance(sentence, "", "", "")));
            }
            for (int i = 0; i < sentenceInstances.size(); i++) {
                InstanceList tokenInstances = FeatureGenerator.convertFeatsforClassifier(
                        classifier.getInstancePipe(), sentenceInstances.get(i));
                labelUnitsWithClassifier(classifier, tokenInstances, arrayList.get(i).getUnits(), iobLines);
                iobLines.add(SENTENCE_SEPARATOR_LINE);
            }
        } else {
            System.out.println("  * predicting with crf model...");
            for (Sentence sentence : arrayList) {
                labelSentenceWithCrf(sentence, z, iobLines);
                iobLines.add(SENTENCE_SEPARATOR_LINE);
            }
        }
        System.out.println("prediction took: " + (System.currentTimeMillis() - start));
        return iobLines;
    }

    /**
     * Labels one sentence with the CRF model.
     *
     * @param sentence       sentence whose units receive label (and optionally confidence)
     * @param withConfidence whether segment confidences are estimated and stored
     * @param iobLines       if non-null, receives one output line per unit
     */
    @SuppressWarnings("rawtypes")
    private void labelSentenceWithCrf(Sentence sentence, boolean withConfidence, ArrayList<String> iobLines) {
        Transducer crf = (Transducer) this.model;
        Sequence input = (Sequence) crf.getInputPipe().instanceFrom(new Instance(sentence, "", "", "")).getData();
        Sequence output = crf.transduce(input);
        ArrayList<Unit> units = sentence.getUnits();
        if (output.size() != units.size()) {
            throw failState("Wrong number of labels predicted.");
        }
        double[] confidence = withConfidence ? getSegmentConfidence(input, output) : null;
        for (int i = 0; i < units.size(); i++) {
            Unit unit = sentence.get(i);
            String label = (String) output.get(i);
            unit.setLabel(label);
            if (withConfidence) {
                unit.setConfidence(confidence[i]);
            }
            if (iobLines != null) {
                String line = units.get(i).getRep() + "\t" + label;
                if (withConfidence) {
                    line = line + "\t" + confidence[i];
                }
                iobLines.add(line);
            }
        }
    }

    /**
     * Classifies every token instance and stores the best label in the corresponding unit.
     *
     * @param classifier     maximum entropy classifier to apply
     * @param tokenInstances per-token instances of one sentence
     * @param units          units of the same sentence, in token order
     * @param iobLines       if non-null, receives one {@code representation<TAB>label} line per unit
     */
    private void labelUnitsWithClassifier(Classifier classifier, InstanceList tokenInstances,
            ArrayList<Unit> units, ArrayList<String> iobLines) {
        LOGGER.info("current sentence has this number of token features: {}", tokenInstances.size());
        if (units.size() != tokenInstances.size()) {
            throw failState("predict() - something went wrong with sequence feature conversion");
        }
        for (int i = 0; i < tokenInstances.size(); i++) {
            Classification classification = classifier.classify(tokenInstances.get(i));
            String label = classification.getLabeling().getBestLabel().toString();
            Unit unit = units.get(i);
            unit.setLabel(label);
            if (iobLines != null) {
                iobLines.add(unit.getRep() + "\t" + label);
            }
        }
    }

    /**
     * Estimates a confidence value for every position of the predicted label sequence.
     *
     * <p>Positions that belong to a chunk reported by {@code IOEvaluation.getChunksIO} receive the
     * confidence of that chunk; all other positions keep the value {@code -1}.
     *
     * @param sequence  input (feature) sequence
     * @param sequence2 predicted label sequence
     * @return one confidence value per position of {@code sequence2}
     */
    private double[] getSegmentConfidence(Sequence<?> sequence, Sequence<?> sequence2) {
        double[] confidence = new double[sequence2.size()];
        for (int i = 0; i < confidence.length; i++) {
            confidence[i] = -1.0d;
        }
        ArrayList<String> labels = new ArrayList<>(sequence2.size());
        for (int i = 0; i < sequence2.size(); i++) {
            labels.add((String) sequence2.get(i));
        }
        HashMap<String, String> chunks = IOEvaluation.getChunksIO(labels);
        if (chunks.isEmpty()) {
            return confidence;
        }
        // The unconstrained lattice depends only on model and input, so it is built once and
        // shared by all segments instead of being recomputed per segment.
        SumLatticeDefault unconstrained = new SumLatticeDefault((Transducer) this.model, sequence);
        for (String offsets : chunks.keySet()) {
            // Chunk value has the form "<tag>#..."; chunk key has the form "<begin>,<end>".
            String tag = chunks.get(offsets).split("#")[0];
            String[] bounds = offsets.split(",");
            int begin = Integer.parseInt(bounds[0]);
            int end = Integer.parseInt(bounds[1]);
            Segment segment = new Segment(sequence, sequence2, sequence2, begin, end, tag, tag);
            double segmentConfidence = estimateConfidenceFor(segment, unconstrained);
            for (int i = begin; i <= end; i++) {
                confidence[i] = segmentConfidence;
            }
        }
        return confidence;
    }

    /**
     * Computes {@code exp(constrainedWeight - totalWeight)} for a segment, where the constrained
     * lattice is restricted to the segment's predicted labels.
     *
     * @param segment           segment to score
     * @param sumLatticeDefault unconstrained lattice over the segment's input, or {@code null} to
     *                          have it computed here
     * @return the confidence estimate for the segment
     */
    @SuppressWarnings("rawtypes")
    private double estimateConfidenceFor(Segment segment, SumLatticeDefault sumLatticeDefault) {
        Transducer transducer = (Transducer) this.model;
        Sequence predicted = segment.getPredicted();
        Sequence input = segment.getInput();
        SumLatticeDefault unconstrained = sumLatticeDefault;
        if (unconstrained == null) {
            unconstrained = new SumLatticeDefault(transducer, input);
        }
        SumLatticeConstrained constrained = new SumLatticeConstrained(transducer, input, null, segment, predicted);
        return Math.exp(constrained.getTotalWeight() - unconstrained.getTotalWeight());
    }

    /**
     * Serializes the model together with the feature configuration, gzip-compressed, to the file
     * {@code str + ".gz"}.
     *
     * @param str target file name without the {@code .gz} suffix
     * @throws IllegalStateException if no trained model or feature configuration is available, or
     *                               if writing fails (the I/O error is attached as cause)
     */
    public void writeModel(String str) {
        if (!this.trained || this.model == null || this.featureConfig == null) {
            throw failState("train or load trained model first.");
        }
        File target = new File(str + ".gz");
        try (FileOutputStream fileOut = new FileOutputStream(target);
                GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut);
                ObjectOutputStream objectOut = new ObjectOutputStream(gzipOut)) {
            objectOut.writeObject(new FeatureSubsetModel(this.model, this.featureConfig));
        } catch (IOException e) {
            IllegalStateException failure = new IllegalStateException("could not write model to " + target, e);
            LOGGER.error(failure.getMessage(), e);
            throw failure;
        }
    }

    /**
     * Loads a model from a gzip-compressed file written by {@link #writeModel(String)}.
     *
     * @param file model file
     * @throws IOException            if the file cannot be read
     * @throws FileNotFoundException  if the file does not exist
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public void readModel(File file) throws IOException, FileNotFoundException, ClassNotFoundException {
        try (InputStream in = new FileInputStream(file)) {
            readModel(in);
        }
    }

    /**
     * Loads a model and its feature configuration from a gzip-compressed object stream and marks
     * the tagger as trained. The data alphabet of the loaded model is frozen, and the maximum
     * entropy flag is set according to the type of the loaded model.
     *
     * <p>The given stream is closed by this method.
     *
     * <p>NOTE: this uses Java-native deserialization; only load models from trusted sources.
     *
     * @param inputStream stream providing the serialized model
     * @throws IOException            if the stream cannot be read
     * @throws FileNotFoundException  declared for compatibility with existing callers
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public void readModel(InputStream inputStream) throws IOException, FileNotFoundException, ClassNotFoundException {
        FeatureSubsetModel loaded;
        try (InputStream raw = inputStream;
                GZIPInputStream gzipIn = new GZIPInputStream(raw);
                ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
            loaded = (FeatureSubsetModel) objectIn.readObject();
        }
        this.model = loaded.getModel();
        this.featureConfig = loaded.getFeatureConfig();
        // Derive the flag from the loaded model so a stale value cannot cause a wrong cast later.
        this.max_ent = this.model instanceof MaxEnt;
        if (this.max_ent) {
            ((MaxEnt) this.model).getInstancePipe().getDataAlphabet().stopGrowth();
        } else {
            ((Transducer) this.model).getInputPipe().getDataAlphabet().stopGrowth();
        }
        this.trained = true;
    }

    /**
     * @return the current model object, or {@code null} if none has been trained or loaded
     */
    public Object getModel() {
        return this.model;
    }

    /**
     * @param properties feature configuration to use from now on
     */
    public void setFeatureConfig(Properties properties) {
        this.featureConfig = properties;
    }

    /**
     * @return the feature configuration currently in use
     */
    public Properties getFeatureConfig() {
        return this.featureConfig;
    }

    /**
     * Parses one sentence in piped format into a {@link Sentence}.
     *
     * <p>Tokens are separated by blanks or tabs; the fields of a token are separated by
     * {@code |}. The first field is the token representation, the last field its label, and the
     * fields in between are meta data whose positions are taken from the feature configuration
     * ({@code <meta>_feat_position}, {@code <meta>_feat_unit}). Meta values equal to the
     * configured {@code gap_character} are skipped.
     *
     * @param str sentence in piped format
     * @return the parsed sentence
     * @throws IllegalArgumentException if a token does not have exactly as many fields as the
     *                                  configuration requires
     */
    public Sentence PPDtoUnits(String str) {
        String[] tokens = str.trim().split("[\t ]+");
        String[] trueMetas = new FeatureConfiguration().getTrueMetas(this.featureConfig);
        String gapCharacter = this.featureConfig.getProperty("gap_character");
        ArrayList<Unit> units = new ArrayList<>(tokens.length);
        for (String token : tokens) {
            String[] fields = token.split("\\|+");
            // Validate before indexing: a token such as "|" splits into zero fields.
            if (trueMetas.length + 2 != fields.length) {
                IllegalArgumentException failure = new IllegalArgumentException(
                        "Error in input format (PipedFormat)! Mal-formatted sentence: " + str + "\n token: " + token
                        + "\nCheck your configuration file. Most probably you use more or less meta-data as specified in the configuration file.\nIf you don't use a config file, you should check whether your input files fit to the default configuration.");
                LOGGER.error(failure.getMessage(), failure);
                throw failure;
            }
            String representation = fields[0];
            String label = fields[fields.length - 1];
            HashMap<String, String> metas = new HashMap<>();
            for (String meta : trueMetas) {
                int position = Integer.parseInt(this.featureConfig.getProperty(meta + "_feat_position"));
                String metaName = this.featureConfig.getProperty(meta + "_feat_unit");
                if (!fields[position].equals(gapCharacter)) {
                    metas.put(metaName, fields[position]);
                }
            }
            units.add(new Unit(0, 0, representation, label, metas));
        }
        return new Sentence(units);
    }

    /**
     * @return the configured number of training iterations; 0 selects the trainer's default routine
     */
    public int getNumber_Iterations() {
        return this.number_iterations;
    }

    /**
     * @param i number of training iterations; 0 selects the trainer's default routine
     */
    public void set_Number_Iterations(int i) {
        this.number_iterations = i;
    }

    /**
     * @return {@code true} if a maximum entropy classifier is used instead of a CRF
     */
    public boolean is_Max_Ent() {
        return this.max_ent;
    }

    /**
     * @param z {@code true} to train and predict with a maximum entropy classifier instead of a CRF
     */
    public void set_Max_Ent(boolean z) {
        this.max_ent = z;
    }

    /** Fails with a logged {@link IllegalStateException} unless a trained model is present. */
    private void ensureModelAvailable(String message) {
        if (!this.trained || this.model == null) {
            throw failState(message);
        }
    }

    /**
     * Fails with a logged {@link IllegalStateException} if the feature pipe needed for maximum
     * entropy prediction is missing. The pipe is only assigned by {@link #train(ArrayList)}, so it
     * is absent when the model was loaded via {@code readModel}.
     */
    private void ensureGeneralPipeAvailable() {
        if (this.generalPipe == null) {
            throw failState("No feature pipe available for maximum entropy prediction; it is only set by train().");
        }
    }

    /** Creates an {@link IllegalStateException} with the given message and logs it as an error. */
    private static IllegalStateException failState(String message) {
        IllegalStateException failure = new IllegalStateException(message);
        LOGGER.error(message, failure);
        return failure;
    }
}
