package eus.ixa.ixa.pipe.pos.train;

import java.io.File;
import java.io.IOException;
import opennlp.tools.cmdline.TerminateToolException;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.postag.MutableTagDictionary;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.TagDictionary;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/* loaded from: input_file:eus/ixa/ixa/pipe/pos/train/AbstractTrainer.class */
/**
 * Base class for POS tagger trainers built on OpenNLP's {@link POSTaggerME}.
 *
 * <p>The constructor reads the language, the train and test data-set locations
 * and the dictionary cut-offs from the supplied {@link TrainingParameters}.
 * Subclasses must install a {@link POSTaggerFactory} through
 * {@link #setPosTaggerFactory(POSTaggerFactory)} before {@link #train} or the
 * tag-dictionary helpers are invoked.
 */
public abstract class AbstractTrainer implements Trainer {

    /** Cut-off value that disables building the corresponding dictionary. */
    private static final int DISABLED_CUTOFF = -1;

    /** Error reported when a subclass has not installed a tagger factory. */
    private static final String MISSING_FACTORY_MESSAGE =
        "Classes derived from AbstractTrainer must  create a POSTaggerFactory features!";

    private final String lang;
    private final ObjectStream<POSSample> trainSamples;
    private final ObjectStream<POSSample> testSamples;
    private WordTagSampleStream dictSamples;
    private final int dictCutOff;
    private final int ngramCutOff;
    private POSTaggerFactory posTaggerFactory;

    /**
     * Reads the training configuration and opens the train, test and
     * dictionary sample streams.
     *
     * @param trainingParameters the training configuration
     * @throws IOException propagated from opening the data sets
     */
    public AbstractTrainer(TrainingParameters trainingParameters) throws IOException {
        this.lang = Flags.getLanguage(trainingParameters);
        String trainSet = Flags.getDataSet("TrainSet", trainingParameters);
        String testSet = Flags.getDataSet("TestSet", trainingParameters);
        this.trainSamples = new WordTagSampleStream(InputOutputUtils.readFileIntoMarkableStreamFactory(trainSet));
        this.testSamples = new WordTagSampleStream(InputOutputUtils.readFileIntoMarkableStreamFactory(testSet));
        // The dictionary samples are a second, independent stream over the training set.
        setDictSamples(new WordTagSampleStream(InputOutputUtils.readFileIntoMarkableStreamFactory(trainSet)));
        this.dictCutOff = Flags.getAutoDictFeatures(trainingParameters).intValue();
        this.ngramCutOff = Flags.getNgramDictFeatures(trainingParameters).intValue();
    }

    /**
     * Trains a model on the train set, evaluates it on the test set and prints
     * the word accuracy to standard output.
     *
     * @param trainingParameters the OpenNLP training parameters
     * @return the trained model
     * @throws IllegalStateException if no tagger factory has been installed
     * @throws TerminateToolException if reading the train or test samples fails
     */
    @Override // eus.ixa.ixa.pipe.pos.train.Trainer
    public final POSModel train(TrainingParameters trainingParameters) {
        final POSTaggerFactory factory = requireFactory();
        final POSModel model;
        final POSEvaluator evaluator;
        try {
            model = POSTaggerME.train(this.lang, this.trainSamples, trainingParameters, factory);
            evaluator = new POSEvaluator(new POSTaggerME(model));
            evaluator.evaluate(this.testSamples);
        } catch (IOException e) {
            // Report the failure to the caller instead of killing the JVM from library code.
            // This is consistent with the dictionary helpers below. Exit code 1 matches the
            // former System.exit(1).
            throw new TerminateToolException(1, "IO error while loading traing and test sets!", e);
        }
        System.out.println("Final result: " + evaluator.getWordAccuracy());
        return model;
    }

    /**
     * Loads a tag dictionary from a file and installs it in the tagger factory.
     *
     * @param str path to the dictionary file, or {@code "off"} (case-insensitive) to skip
     * @throws IllegalStateException if no tagger factory has been installed
     * @throws TerminateToolException if the dictionary cannot be read
     */
    public final void createTagDictionary(String str) {
        if (str.equalsIgnoreCase("off")) {
            return;
        }
        final POSTaggerFactory factory = requireFactory();
        try {
            factory.setTagDictionary(factory.createTagDictionary(new File(str)));
        } catch (IOException e) {
            throw new TerminateToolException(-1, "IO error while loading POS Dictionary: " + e.getMessage(), e);
        }
    }

    /**
     * Populates the factory's tag dictionary from samples.
     *
     * <p>If the factory has no dictionary, an empty one is created and installed
     * first. The given stream is reset after it has been consumed.
     *
     * @param objectStream samples used to populate the dictionary
     * @param i cut-off passed to {@link POSTaggerME#populatePOSDictionary}; {@code -1} disables this step
     * @throws IllegalStateException if no tagger factory has been installed
     * @throws IllegalArgumentException if the existing dictionary is not a {@link MutableTagDictionary}
     * @throws TerminateToolException on IO errors while reading or resetting the samples
     */
    public final void createAutomaticDictionary(ObjectStream<POSSample> objectStream, int i) {
        if (i == DISABLED_CUTOFF) {
            return;
        }
        final POSTaggerFactory factory = requireFactory();
        try {
            TagDictionary tagDictionary = factory.getTagDictionary();
            if (tagDictionary == null) {
                tagDictionary = factory.createEmptyTagDictionary();
                factory.setTagDictionary(tagDictionary);
            }
            if (!(tagDictionary instanceof MutableTagDictionary)) {
                throw new IllegalArgumentException("Can't extend a POSDictionary that does not implement MutableTagDictionary.");
            }
            POSTaggerME.populatePOSDictionary(objectStream, (MutableTagDictionary) tagDictionary, i);
            // Rewind the stream that was actually consumed so it can be read again.
            objectStream.reset();
        } catch (IOException e) {
            throw new TerminateToolException(-1, "IO error while creating/extending POS Dictionary: " + e.getMessage(), e);
        }
    }

    /**
     * Builds an n-gram dictionary from samples.
     *
     * <p>The given stream is reset after it has been consumed.
     *
     * @param objectStream samples used to build the dictionary
     * @param i cut-off passed to {@link POSTaggerME#buildNGramDictionary}; {@code -1} disables this step
     * @return the built dictionary, or {@code null} when disabled
     * @throws TerminateToolException on IO errors while reading or resetting the samples
     */
    public final Dictionary createNgramDictionary(ObjectStream<POSSample> objectStream, int i) {
        if (i == DISABLED_CUTOFF) {
            return null;
        }
        System.err.print("Building ngram dictionary ... ");
        try {
            Dictionary dictionary = POSTaggerME.buildNGramDictionary(objectStream, i);
            // Rewind the stream that was actually consumed so it can be read again.
            objectStream.reset();
            System.err.println("done");
            return dictionary;
        } catch (IOException e) {
            throw new TerminateToolException(-1, "IO error while building NGram Dictionary: " + e.getMessage(), e);
        }
    }

    /** @return the sample stream opened over the training set for dictionary building */
    public final WordTagSampleStream getDictSamples() {
        return this.dictSamples;
    }

    /** @param wordTagSampleStream the sample stream to use for dictionary building */
    protected final void setDictSamples(WordTagSampleStream wordTagSampleStream) {
        this.dictSamples = wordTagSampleStream;
    }

    /** @return the tagger factory installed by the subclass, or {@code null} if none yet */
    protected final POSTaggerFactory getPosTaggerFactory() {
        return this.posTaggerFactory;
    }

    /** @param pOSTaggerFactory the tagger factory used for training and dictionaries */
    public final void setPosTaggerFactory(POSTaggerFactory pOSTaggerFactory) {
        this.posTaggerFactory = pOSTaggerFactory;
    }

    /** @return the automatic tag-dictionary cut-off read from the training parameters */
    public final Integer getDictCutOff() {
        return Integer.valueOf(this.dictCutOff);
    }

    /** @return the n-gram dictionary cut-off read from the training parameters */
    public final Integer getNgramDictCutOff() {
        return Integer.valueOf(this.ngramCutOff);
    }

    /** Returns the installed tagger factory, failing fast if a subclass has not set one. */
    private POSTaggerFactory requireFactory() {
        final POSTaggerFactory factory = getPosTaggerFactory();
        if (factory == null) {
            throw new IllegalStateException(MISSING_FACTORY_MESSAGE);
        }
        return factory;
    }
}
