package com.gengoai.hermes.tools;

import com.gengoai.LogUtils;
import com.gengoai.Stopwatch;
import com.gengoai.Validation;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.evaluation.ClassifierEvaluation;
import com.gengoai.apollo.ml.evaluation.Evaluation;
import com.gengoai.apollo.ml.evaluation.SequenceLabelerEvaluation;
import com.gengoai.apollo.ml.model.FitParameters;
import com.gengoai.apollo.ml.model.ModelIO;
import com.gengoai.application.Option;
import com.gengoai.collection.Maps;
import com.gengoai.config.Config;
import com.gengoai.conversion.Cast;
import com.gengoai.conversion.Converter;
import com.gengoai.conversion.TypeConversionException;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.en.ENPOSTagger;
import com.gengoai.hermes.ml.ElmoNERModel;
import com.gengoai.hermes.ml.EntityTagger;
import com.gengoai.hermes.ml.HStringMLModel;
import com.gengoai.hermes.ml.PhraseChunkTagger;
import com.gengoai.io.Resources;
import com.gengoai.io.resource.StringResource;
import com.gengoai.parsing.ParseException;
import com.gengoai.string.Strings;
import com.gengoai.tuple.Tuples;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:com/gengoai/hermes/tools/TaggerApp.class */
/**
 * Command line tool for training, testing, and inspecting taggers.
 *
 * <p>The first positional argument selects the mode (case-insensitive):
 * <ul>
 *   <li>{@code TRAIN} - fit the selected tagger on the document collection and save it to the model location</li>
 *   <li>{@code TEST} - load the model from the model location and evaluate it on the document collection</li>
 *   <li>{@code PARAMETERS} - log the fit parameters of the selected tagger</li>
 *   <li>{@code LS} / {@code LIST} - log the built-in named taggers</li>
 * </ul>
 */
public class TaggerApp extends HermesCLI {
    private static final Logger log = Logger.getLogger(TaggerApp.class.getName());
    private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy/MM/dd HH:mm:ss");
    /** Horizontal rule used to frame the banners written to the log. */
    private static final String SEPARATOR = "========================================================";
    /** Built-in taggers addressable by a short, upper-case name. */
    private static final Map<String, HStringMLModel> NAMED_TRAINERS = Collections.unmodifiableMap(Maps.hashMapOf(
            Tuples.$("PHRASE_CHUNK", new PhraseChunkTagger()),
            Tuples.$("ENTITY", new EntityTagger()),
            Tuples.$("TF_ENTITY", new ElmoNERModel()),
            Tuples.$("EN_POS", new ENPOSTagger())));

    @Option(description = "The specification or location the corpus or document collection to process.", name = "docFormat", aliases = {"df"})
    private String documentCollectionSpec;

    @Option(description = "The name or class of the sequence tagger to train.", aliases = {"tagger"}, required = true)
    private String sequenceTagger;

    @Option(description = "Location to save model", aliases = {"m"})
    private String model;

    @Option(description = "Print a Confusion Matrix", defaultValue = "false")
    private boolean printCM;

    @Option(description = "Query to generate data", defaultValue = "")
    private String query;

    /**
     * Program entry point.
     *
     * @param args the command line arguments
     */
    public static void main(String[] args) {
        new TaggerApp().run(args);
    }

    /**
     * Opens the document collection named by the {@code docFormat} option and, when a query was
     * supplied, narrows it to the documents matching that query.
     *
     * @return the (possibly filtered) document collection
     * @throws RuntimeException wrapping the {@link ParseException} raised for an unparsable query
     */
    private DocumentCollection getDocumentCollection() {
        DocumentCollection documents = DocumentCollection.create(
                Validation.notNullOrBlank(this.documentCollectionSpec, "No Document Collection Specified!"));
        if (Strings.isNotNullOrBlank(this.query)) {
            try {
                documents = documents.query(this.query);
            } catch (ParseException e) {
                throw new RuntimeException(e);
            }
        }
        return documents;
    }

    /**
     * Resolves the tagger selected on the command line: first as one of the built-in names
     * (case-insensitive), otherwise by converting the raw option value to an {@link HStringMLModel}.
     *
     * @return the tagger to train or inspect
     * @throws RuntimeException wrapping the {@link TypeConversionException} when the value cannot be converted
     */
    private HStringMLModel getTrainer() {
        // Locale.ROOT keeps the lookup independent of the JVM default locale (e.g. Turkish dotted I).
        HStringMLModel named = NAMED_TRAINERS.get(this.sequenceTagger.toUpperCase(Locale.ROOT));
        if (named != null) {
            return named;
        }
        try {
            return Converter.convert(this.sequenceTagger, HStringMLModel.class);
        } catch (TypeConversionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Logs the name, type, and current value of every parameter in the given fit parameters.
     *
     * @param fitParameters the parameters to report
     */
    private void logFitParameters(FitParameters<?> fitParameters) {
        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "                 FitParameters");
        LogUtils.logInfo(log, SEPARATOR);
        for (String name : fitParameters.parameterNames()) {
            LogUtils.logInfo(log,
                             "{0} ({1}), value={2}",
                             name,
                             fitParameters.getParam(name).type.getSimpleName(),
                             fitParameters.get(name));
        }
        LogUtils.logInfo(log, SEPARATOR);
    }

    /**
     * Dispatches on the mode given as the first positional argument.
     *
     * @throws Exception if the selected mode fails
     */
    protected void programLogic() throws Exception {
        Validation.checkState(getPositionalArgs().length > 0, "No Mode specified!");
        // Locale.ROOT keeps mode matching independent of the JVM default locale.
        String mode = getPositionalArgs()[0].toUpperCase(Locale.ROOT);
        switch (mode) {
            case "TRAIN":
                Validation.checkState(Strings.isNotNullOrBlank(this.model), "No Model Specified!");
                ModelIO.save(train(getDocumentCollection()), Resources.from(this.model));
                break;
            case "TEST": {
                Validation.checkState(Strings.isNotNullOrBlank(this.model), "No Model Specified!");
                DocumentCollection documents = getDocumentCollection();
                HStringMLModel tagger = Cast.as(ModelIO.load(Resources.from(this.model)));
                test(documents, tagger);
                break;
            }
            case "PARAMETERS":
                logFitParameters(getTrainer().getFitParameters());
                break;
            case "LS":
            case "LIST":
                NAMED_TRAINERS.forEach((name, tagger) -> LogUtils.logInfo(log, "{0}   {1}", name, tagger));
                break;
            default:
                // Previously an unrecognized mode was ignored without any feedback.
                log.warning("Unknown mode '" + mode + "'; expected one of TRAIN, TEST, PARAMETERS, LS, LIST");
                break;
        }
    }

    /**
     * Evaluates a trained tagger on the given documents and logs the evaluation report.
     *
     * @param documentCollection the documents to evaluate on
     * @param hStringMLModel     the trained tagger to evaluate
     * @throws RuntimeException wrapping any {@link IOException} raised while rendering the report
     */
    private void test(DocumentCollection documentCollection, HStringMLModel hStringMLModel) {
        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "                         TEST");
        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "   Data: {0}", this.documentCollectionSpec);
        LogUtils.logInfo(log, " Tagger: {0}", this.model);
        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "Loading data set");
        DataSet testData = hStringMLModel.transform(documentCollection);
        Evaluation evaluation = hStringMLModel.getEvaluator();

        Stopwatch timer = Stopwatch.createStarted();
        LogUtils.logInfo(log, "Testing Started at {0}", LocalDateTime.now().format(TIME_FORMATTER));
        evaluation.evaluate(hStringMLModel.delegate(), testData);
        timer.stop();
        LogUtils.logInfo(log, "Testing Stopped at {0} ({1})", LocalDateTime.now().format(TIME_FORMATTER), timer);

        // Render the report into an in-memory resource so that it is written to the log as a single entry.
        StringResource report = new StringResource();
        try {
            try (OutputStream out = report.outputStream();
                 PrintStream printer = new PrintStream(out)) {
                if (evaluation instanceof SequenceLabelerEvaluation) {
                    ((SequenceLabelerEvaluation) evaluation).output(printer, this.printCM);
                } else if (evaluation instanceof ClassifierEvaluation) {
                    ((ClassifierEvaluation) evaluation).output(printer, this.printCM);
                } else {
                    // Other evaluation types have no confusion-matrix variant of output.
                    evaluation.output(printer);
                }
            }
            LogUtils.logInfo(log, "\n{0}", report.readToString());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Fits the selected tagger on the given documents. Before fitting, every fit parameter that has
     * a matching {@code param.<name>} configuration property is overridden with that property's value.
     *
     * @param documentCollection the documents to train on
     * @return the fitted tagger
     */
    private HStringMLModel train(DocumentCollection documentCollection) {
        HStringMLModel trainer = getTrainer();
        FitParameters<?> fitParameters = trainer.getFitParameters();
        for (String name : fitParameters.parameterNames()) {
            String configKey = "param." + name;
            if (Config.hasProperty(configKey)) {
                fitParameters.set(name, Config.get(configKey));
            }
        }

        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "                         Train");
        LogUtils.logInfo(log, SEPARATOR);
        LogUtils.logInfo(log, "   Data: {0}", this.documentCollectionSpec);
        if (Strings.isNotNullOrBlank(this.query)) {
            LogUtils.logInfo(log, "  Query: {0}", this.query);
        }
        LogUtils.logInfo(log, "Trainer: {0}", this.sequenceTagger);
        LogUtils.logInfo(log, "  Model: {0}", this.model);
        LogUtils.logInfo(log, SEPARATOR);
        logFitParameters(fitParameters);

        Stopwatch timer = Stopwatch.createStarted();
        LogUtils.logInfo(log, "Training Started at {0}", LocalDateTime.now().format(TIME_FORMATTER));
        trainer.estimate(documentCollection);
        timer.stop();
        LogUtils.logInfo(log, "Training Stopped at {0} ({1})", LocalDateTime.now().format(TIME_FORMATTER), timer);
        return trainer;
    }
}
