package de.julielab.jcore.ae.jtbd;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.tsf.OffsetConjunctions;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/julielab/jcore/ae/jtbd/Tokenizer.class */
/**
 * CRF-based tokenizer. Wraps a Mallet {@link CRF} that assigns a label to each {@link Unit} of a
 * sentence.
 *
 * <p>A model must be trained ({@link #train}), loaded ({@link #readModel(File)} /
 * {@link #readModel(InputStream)}) or set ({@link #setModel}) before {@link #predict(String)} or
 * {@link #writeModel} can be used.
 *
 * <p>NOTE(review): instances hold mutable model state without synchronization; do not assume
 * thread-safety.
 */
public class Tokenizer {
    private static final Logger LOGGER = LoggerFactory.getLogger(Tokenizer.class);

    /** Message of the exception thrown when a prediction is requested without a usable model. */
    private static final String NO_MODEL_MESSAGE = "No model available. Train or load trained model first.";

    /** The underlying CRF; {@code null} until a model is trained, loaded or set. */
    CRF model;

    /** {@code true} once {@link #model} has been trained, loaded or explicitly set. */
    private boolean trained;

    /** Creates a tokenizer without a model. */
    public Tokenizer() {
        LOGGER.debug("this is the JTBD constuctor");
        this.model = null;
        this.trained = false;
    }

    /**
     * Copies the entries of a label sequence into a list of strings.
     *
     * @param labelSequence the sequence to read; each entry is cast to {@link String}
     * @return a new list holding the labels in sequence order
     */
    public ArrayList<String> getLabelsFromLabelSequence(LabelSequence labelSequence) {
        int size = labelSequence.size();
        ArrayList<String> labels = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            labels.add((String) labelSequence.get(i));
        }
        return labels;
    }

    /** @return the current CRF model, or {@code null} if none has been trained, loaded or set */
    public CRF getModel() {
        return this.model;
    }

    /**
     * Pipes pairs of texts through the model's input pipe to build prediction instances.
     *
     * <p>The i-th element of {@code list} becomes the instance data and the i-th element of
     * {@code list2} its source. A single trailing end-of-sentence symbol (per {@code EOSSymbols}) is
     * stripped from both first.
     *
     * @param list texts to use as instance data
     * @param list2 texts to use as instance sources; must be at least as long as {@code list}
     * @return the piped instances whose source is no longer a {@link String} after piping
     * @throws IllegalStateException if no model is available
     */
    public InstanceList makePredictionData(List<String> list, List<String> list2) {
        LOGGER.debug("makePredictionData() - making prediction data");
        if (this.model == null) {
            throw new IllegalStateException(NO_MODEL_MESSAGE);
        }
        InstanceList instanceList = new InstanceList(this.model.getInputPipe());
        for (int i = 0; i < list.size(); i++) {
            Instance instance = makePredictionData(new StringBuilder(list.get(i)), new StringBuilder(list2.get(i)));
            // NOTE(review): filters on the piped instance's source type; presumably the input pipe
            // replaces the String source for usable instances -- confirm against the pipe.
            if (!(instance.getSource() instanceof String)) {
                instanceList.add(instance);
            }
        }
        return instanceList;
    }

    /**
     * Strips one trailing end-of-sentence symbol from each buffer and pipes the result through the
     * model's input pipe.
     *
     * @param data text used as instance data; modified in place
     * @param source text used as instance source; modified in place
     * @return the piped instance
     */
    private Instance makePredictionData(StringBuilder data, StringBuilder source) {
        stripTrailingEosSymbol(source);
        stripTrailingEosSymbol(data);
        // Any error from the pipe (e.g. a NoSuchMethodError from a mismatched Mallet version)
        // now propagates to the caller instead of silently terminating the JVM with status 0.
        return this.model.getInputPipe().instanceFrom(new Instance(data.toString(), (Object) null, (Object) null, source.toString()));
    }

    /**
     * Removes the last character of {@code text} if it is an end-of-sentence symbol. Empty input is
     * left untouched.
     */
    private static void stripTrailingEosSymbol(StringBuilder text) {
        int last = text.length() - 1;
        if (last >= 0 && EOSSymbols.contains(Character.valueOf(text.charAt(last)))) {
            text.deleteCharAt(last);
        }
    }

    /**
     * Builds a training instance list using a fresh feature pipe.
     *
     * <p>The pipe chain is {@code Sentence2TokenPipe} followed by offset conjunctions with the
     * previous (-1) and next (+1) token, then conversion to feature-vector sequences. Each input
     * pair is trimmed and has a single trailing end-of-sentence symbol removed. Blank lines are
     * tolerated instead of raising {@link StringIndexOutOfBoundsException}.
     *
     * @param list texts used as instance data
     * @param list2 texts used as instance sources; must be at least as long as {@code list}
     * @return the piped training instances; the instance name is the input index
     */
    public InstanceList makeTrainingData(List<String> list, List<String> list2) {
        LOGGER.debug("makeTrainingData() - making training data...");
        SerialPipes serialPipes = new SerialPipes(new Pipe[] {
            new Sentence2TokenPipe(),
            new OffsetConjunctions(new int[][] {{-1}, {1}}),
            new TokenSequence2FeatureVectorSequence(true, true)
        });
        InstanceList instanceList = new InstanceList(serialPipes);
        LOGGER.info("preparing training data...");
        for (int i = 0; i < list.size(); i++) {
            StringBuilder data = new StringBuilder(list.get(i).trim());
            StringBuilder source = new StringBuilder(list2.get(i).trim());
            stripTrailingEosSymbol(source);
            stripTrailingEosSymbol(data);
            instanceList.addThruPipe(new Instance(data.toString(), "", Integer.valueOf(i), source.toString()));
        }
        LOGGER.debug("makeTrainingData() -  number of features on training data: {}", serialPipes.getDataAlphabet().size());
        return instanceList;
    }

    /**
     * Labels the units of an already piped instance.
     *
     * <p>Assumes the input pipe stored the instance's {@code ArrayList<Unit>} as its name. TODO
     * confirm against {@code Sentence2TokenPipe}. The i-th transduced label is written into the
     * i-th unit's {@code label} field.
     *
     * @param instance a piped instance
     * @return the (mutated) list of units stored as the instance name
     * @throws IllegalStateException if no trained model is available
     */
    public ArrayList<Unit> predict(Instance instance) {
        if (!this.trained || this.model == null) {
            throw new IllegalStateException(NO_MODEL_MESSAGE);
        }
        @SuppressWarnings("unchecked")
        ArrayList<Unit> units = (ArrayList<Unit>) instance.getName();
        if (!units.isEmpty()) {
            Sequence labels = this.model.transduce((Sequence) instance.getData());
            for (int i = 0; i < labels.size(); i++) {
                units.get(i).label = (String) labels.get(i);
            }
        }
        return units;
    }

    /**
     * Pipes a single text (with an empty source) through the model and labels its units.
     *
     * @param str the text to tokenize; a trailing end-of-sentence symbol is stripped first
     * @return the labeled units
     * @throws IllegalStateException if no trained model is available
     */
    public ArrayList<Unit> predict(String str) {
        LOGGER.debug("predict() - before pedicting labels ...");
        if (!this.trained || this.model == null) {
            throw new IllegalStateException(NO_MODEL_MESSAGE);
        }
        LOGGER.debug("predict() - now making pedictions ...");
        Instance instance = makePredictionData(new StringBuilder(str), new StringBuilder());
        LOGGER.debug("predict() - after pedicting labels ...");
        return predict(instance);
    }

    /**
     * Loads a gzipped, Java-serialized CRF model from a file. The file is always closed, even
     * if decoding fails.
     *
     * @param file the model file
     * @throws FileNotFoundException if the file cannot be opened
     * @throws IOException if reading or decompressing fails
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    public void readModel(File file) throws IOException, FileNotFoundException, ClassNotFoundException {
        try (InputStream in = new FileInputStream(file)) {
            readModel(in);
        }
    }

    /**
     * Loads a gzipped, Java-serialized CRF model from a stream and freezes its data alphabet.
     *
     * <p>The stream is closed when reading completes or fails after the gzip header has been read.
     * The tokenizer's model is replaced only if loading succeeds.
     *
     * <p>NOTE(review): this uses native Java deserialization; only load models from trusted sources.
     *
     * @param inputStream gzip-compressed serialized {@link CRF}
     * @throws IOException if reading or decompressing fails
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     * @throws ClassCastException if the stream does not contain a {@link CRF}
     */
    public void readModel(InputStream inputStream) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(inputStream))) {
            CRF loaded = (CRF) in.readObject();
            loaded.getInputPipe().getDataAlphabet().stopGrowth();
            this.model = loaded;
            this.trained = true;
        }
    }

    /**
     * Installs an externally provided model and marks the tokenizer as trained.
     *
     * @param crf the model to use
     */
    public void setModel(CRF crf) {
        this.trained = true;
        this.model = crf;
    }

    /**
     * Renders the units around position {@code i} (two on each side) twice. The first line is
     * joined according to the predicted labels and the second according to the reference labels.
     * A space follows each unit labeled {@code "P"}.
     *
     * @param i the position of interest
     * @param arrayList units carrying their {@code rep} text and predicted {@code label}
     * @param arrayList2 reference labels, parallel to {@code arrayList}
     * @return {@code predicted + "\n" + reference + "\n"}
     */
    public String showErrorContext(int i, ArrayList<Unit> arrayList, ArrayList<String> arrayList2) {
        StringBuilder reference = new StringBuilder();
        StringBuilder predicted = new StringBuilder();
        int from = Math.max(0, i - 2);
        int to = Math.min(arrayList.size() - 1, i + 2);
        for (int k = from; k <= to; k++) {
            Unit unit = arrayList.get(k);
            reference.append(unit.rep).append("P".equals(arrayList2.get(k)) ? " " : "");
            predicted.append(unit.rep).append("P".equals(unit.label) ? " " : "");
        }
        return predicted + "\n" + reference + "\n";
    }

    /**
     * Trains a new CRF on the given instances by label likelihood. The trained CRF replaces the
     * current model, and its data alphabet is then frozen.
     *
     * @param instanceList training instances
     * @param pipe the input pipe for the new CRF
     */
    public void train(InstanceList instanceList, Pipe pipe) {
        long start = System.currentTimeMillis();
        this.model = new CRF(pipe, (Pipe) null);
        this.model.addStatesForLabelsConnectedAsIn(instanceList);
        boolean converged = new CRFTrainerByLabelLikelihood(this.model).train(instanceList);
        LOGGER.info("Tokenizer training: model converged: {}", converged);
        long end = System.currentTimeMillis();
        this.model.getInputPipe().getDataAlphabet().stopGrowth();
        this.trained = true;
        LOGGER.debug("train() - training time: {} sec", (end - start) / 1000);
    }

    /**
     * Writes the model, gzip-compressed and Java-serialized, to {@code str + ".gz"}.
     *
     * @param str output path without the {@code .gz} suffix
     * @throws IllegalStateException if no trained model is available
     * @throws UncheckedIOException if writing fails (previously the JVM was terminated with status 0)
     */
    public void writeModel(String str) {
        if (!this.trained || this.model == null) {
            throw new IllegalStateException("train or load trained model first.");
        }
        File target = new File(str + ".gz");
        // Each stream is a separate resource so the file handle is closed even if a wrapper's
        // constructor throws.
        try (FileOutputStream fileOut = new FileOutputStream(target);
             GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut);
             ObjectOutputStream objectOut = new ObjectOutputStream(gzipOut)) {
            objectOut.writeObject(this.model);
        } catch (IOException e) {
            throw new UncheckedIOException("Could not write model to " + target, e);
        }
    }
}
