package gate.plugin.learningframework;

import gate.Annotation;
import gate.AnnotationSet;
import gate.Controller;
import gate.Document;
import gate.FeatureMap;
import gate.creole.metadata.CreoleParameter;
import gate.creole.metadata.CreoleResource;
import gate.creole.metadata.Optional;
import gate.creole.metadata.RunTime;
import gate.plugin.learningframework.data.CorpusRepresentation;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.engines.AlgorithmClassification;
import gate.plugin.learningframework.engines.AlgorithmKind;
import gate.plugin.learningframework.engines.Engine;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.FeatureSpecification;
import gate.plugin.learningframework.features.TargetType;
import gate.util.Files;
import gate.util.GateRuntimeException;
import java.io.File;
import java.net.URL;
import org.apache.commons.clipatched.HelpFormatter;
import org.apache.log4j.Logger;

/**
 * Processing resource that trains a classification model.
 *
 * <p>Lifecycle, as implemented by the callbacks below:
 * <ol>
 *   <li>{@link #beforeFirstDocument(Controller)} validates the runtime parameters, reads the
 *       feature specification and creates the {@link Engine} together with its
 *       {@link CorpusRepresentation};</li>
 *   <li>{@link #process(Document)} adds the instances of each document to the corpus
 *       representation;</li>
 *   <li>{@link #afterLastDocument(Controller, Throwable)} trains the model and saves the engine
 *       into the data directory, unless an error was reported for the run.</li>
 * </ol>
 */
@CreoleResource(name = "LF_TrainClassification", helpURL = "https://gatenlp.github.io/gateplugin-LearningFramework/LF_TrainClassification", comment = "Train a machine learning model for classification")
public class LF_TrainClassification extends LearningFrameworkPRBase {
    private static final long serialVersionUID = 4218101157699142046L;

    /** Shared logger; static so that it is not part of the instance state. */
    private static final Logger LOGGER = Logger.getLogger(LF_TrainClassification.class);

    /** Annotation feature on which {@link #process(Document)} stores a copy of the target value. */
    private static final String INTERNAL_TARGET_FEATURE = "gate.LF.target";

    /** Data alphabets larger than this are reported by size only instead of being listed. */
    private static final int MAX_ATTRIBUTES_TO_LIST = 20;

    protected URL dataDirectory;
    private URL featureSpecURL;
    private AlgorithmClassification trainingAlgorithm;
    protected String targetFeature;
    protected String sequenceSpan;
    private int nrDocuments;
    private File dataDirFile;
    protected String instanceWeightFeature = "";
    protected ScalingMethod scaleFeatures = ScalingMethod.NONE;
    private CorpusRepresentation corpusRepresentation = null;
    private FeatureSpecification featureSpec = null;
    private Engine engine = null;

    /**
     * Sets the directory used for storing the trained engine.
     *
     * @param url location of the data directory; must be a {@code file:} URL for training
     */
    @CreoleParameter(comment = "The directory where all data will be stored and read from")
    @RunTime
    public void setDataDirectory(URL url) {
        this.dataDirectory = url;
    }

    /** @return the configured data directory URL, or {@code null} if none was set */
    public URL getDataDirectory() {
        return this.dataDirectory;
    }

    /**
     * Sets the name of the feature holding the instance weight.
     *
     * @param str feature name; the value is handed unchanged to the corpus representation
     */
    @CreoleParameter(comment = "The feature that contains the instance weight. If empty, no instance weights are used", defaultValue = "")
    @RunTime
    @Optional
    public void setInstanceWeightFeature(String str) {
        this.instanceWeightFeature = str;
    }

    /** @return the name of the instance weight feature */
    public String getInstanceWeightFeature() {
        return this.instanceWeightFeature;
    }

    /**
     * Sets the location of the feature specification file.
     *
     * @param url URL of the feature specification
     */
    @CreoleParameter(comment = "The feature specification file.")
    @RunTime
    public void setFeatureSpecURL(URL url) {
        this.featureSpecURL = url;
    }

    /** @return the URL of the feature specification file, or {@code null} if none was set */
    public URL getFeatureSpecURL() {
        return this.featureSpecURL;
    }

    /**
     * Sets the algorithm used for training.
     *
     * @param algorithmClassification the training algorithm; although declared optional,
     *     {@link #beforeFirstDocument(Controller)} rejects {@code null}
     */
    @CreoleParameter(comment = "The algorithm to be used for training the classifier")
    @RunTime
    @Optional
    public void setTrainingAlgorithm(AlgorithmClassification algorithmClassification) {
        this.trainingAlgorithm = algorithmClassification;
    }

    /** @return the configured training algorithm, or {@code null} if none was set */
    public AlgorithmClassification getTrainingAlgorithm() {
        return this.trainingAlgorithm;
    }

    /**
     * Sets the global feature scaling method.
     *
     * @param scalingMethod the scaling method applied to the feature info before engine creation
     */
    @CreoleParameter(defaultValue = "NONE", comment = "If and how to scale features. ")
    @RunTime
    public void setScaleFeatures(ScalingMethod scalingMethod) {
        this.scaleFeatures = scalingMethod;
    }

    /** @return the configured feature scaling method */
    public ScalingMethod getScaleFeatures() {
        return this.scaleFeatures;
    }

    /**
     * Sets the name of the instance annotation feature that holds the class label.
     *
     * @param str feature name
     */
    @CreoleParameter(comment = "The feature containing the class label")
    @RunTime
    @Optional
    public void setTargetFeature(String str) {
        this.targetFeature = str;
    }

    /** @return the name of the feature that holds the class label */
    public String getTargetFeature() {
        return this.targetFeature;
    }

    /**
     * Sets the annotation type that delimits sequences.
     *
     * @param str annotation type; required for sequence tagging algorithms and rejected for
     *     all others (see {@link #beforeFirstDocument(Controller)})
     */
    @CreoleParameter(comment = "For sequence learners, an annotation type defining a meaningful sequence span. Ignored by non-sequence learners.")
    @RunTime
    @Optional
    public void setSequenceSpan(String str) {
        this.sequenceSpan = str;
    }

    /** @return the annotation type that delimits sequences, possibly {@code null} or empty */
    public String getSequenceSpan() {
        return this.sequenceSpan;
    }

    /**
     * Adds the instances of one document to the corpus representation.
     *
     * <p>As a side effect, the target value of every instance annotation is copied to the
     * {@code gate.LF.target} feature of that annotation.
     *
     * @param document the document to take training instances from
     * @return the same document
     * @throws GateRuntimeException if an interrupt was requested or an instance annotation has
     *     no value for the target feature
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public Document process(Document document) {
        if (isInterrupted()) {
            this.interrupted = false;
            throw new GateRuntimeException("Execution was requested to be interrupted");
        }
        this.nrDocuments++;

        AnnotationSet inputAS = document.getAnnotations(getInputASName());
        AnnotationSet instanceAS = inputAS.get(getInstanceType());
        // Sequence spans are only looked up for sequence taggers; otherwise null is passed on.
        AnnotationSet sequenceAS = null;
        if (getTrainingAlgorithm().getAlgorithmKind() == AlgorithmKind.SEQUENCE_TAGGER) {
            sequenceAS = inputAS.get(getSequenceSpan());
        }

        for (Annotation instance : instanceAS) {
            FeatureMap features = instance.getFeatures();
            Object target = features.get(getTargetFeature());
            if (target == null) {
                // Report the document that is actually being processed.
                throw new GateRuntimeException("Target value is null in document " + document.getName() + " for instance " + instance);
            }
            features.put(INTERNAL_TARGET_FEATURE, target);
        }

        this.corpusRepresentation.add(instanceAS, sequenceAS, inputAS, null, getTargetFeature(), TargetType.NOMINAL, this.instanceWeightFeature, null, null);
        return document;
    }

    /**
     * Validates the parameters and creates the engine and corpus representation for this run.
     *
     * @param controller the controller running this resource (not used here)
     * @throws GateRuntimeException if the data directory is missing, is not a {@code file:} URL
     *     or does not exist, if no feature specification or training algorithm is set, or if
     *     the sequence span setting does not match the kind of algorithm
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    protected void beforeFirstDocument(Controller controller) {
        if (this.dataDirectory == null) {
            throw new GateRuntimeException("LearningFramework: no dataDirectory specified");
        }
        if (!"file".equals(this.dataDirectory.getProtocol())) {
            throw new GateRuntimeException("Training is only possible if the dataDirectory URL is a file: URL");
        }
        this.dataDirFile = Files.fileFromURL(this.dataDirectory);
        if (!this.dataDirFile.exists()) {
            throw new GateRuntimeException("Data directory not found: " + this.dataDirFile.getAbsolutePath());
        }
        if (this.featureSpecURL == null) {
            throw new GateRuntimeException("LearningFramework: no feature specification URL specified");
        }

        AlgorithmClassification algorithm = getTrainingAlgorithm();
        if (algorithm == null) {
            throw new GateRuntimeException("LearningFramework: no training algorithm specified");
        }

        String span = getSequenceSpan();
        boolean haveSpan = span != null && !span.isEmpty();
        if (algorithm.getAlgorithmKind() == AlgorithmKind.SEQUENCE_TAGGER) {
            if (!haveSpan) {
                throw new GateRuntimeException("SequenceSpan parameter is required for sequence tagging algorithm");
            }
        } else if (haveSpan) {
            throw new GateRuntimeException("SequenceSpan parameter must not be specified with non-sequence tagging algorithm");
        }

        LOGGER.debug("Before Document.");
        LOGGER.debug("  Training algorithm engine class is " + algorithm.getEngineClass());
        LOGGER.debug("  Training algorithm algor class is " + algorithm.getTrainerClass());

        this.featureSpec = new FeatureSpecification(this.featureSpecURL);
        LOGGER.debug("Read the feature specification: " + this.featureSpec);

        FeatureInfo featureInfo = this.featureSpec.getFeatureInfo();
        featureInfo.setGlobalScalingMethod(this.scaleFeatures);

        this.engine = Engine.create(algorithm, getAlgorithmParameters(), featureInfo, TargetType.NOMINAL, this.dataDirectory);
        this.corpusRepresentation = this.engine.getCorpusRepresentation();
        LOGGER.debug("created the engine: " + this.engine + " with CR=" + this.corpusRepresentation);

        this.nrDocuments = 0;
        LOGGER.debug("setup of the training PR complete");
    }

    /**
     * Trains the model on the collected instances and saves the engine to the data directory.
     *
     * @param controller the controller running this resource (not used here)
     * @param th an error reported for the run, or {@code null}; when non-null no training is
     *     done
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public void afterLastDocument(Controller controller, Throwable th) {
        if (th != null) {
            // Pass the throwable to the logger so that the stack trace is not lost.
            LOGGER.error("An exception occurred during processing of documents, no training will be done", th);
            return;
        }
        System.out.println("LearningFramework: Starting training engine " + this.engine);
        if (this.corpusRepresentation instanceof CorpusRepresentationMallet) {
            printMalletSummary((CorpusRepresentationMallet) this.corpusRepresentation);
        }

        this.engine.getInfo().nrTrainingInstances = this.corpusRepresentation.nrInstances();
        this.engine.getInfo().nrTrainingDocuments = this.nrDocuments;
        this.engine.getInfo().targetFeature = getTargetFeature();
        // NOTE(review): presumably a corpus is always set when run from a corpus controller;
        // the guard only avoids a NullPointerException should that not hold.
        this.engine.getInfo().trainingCorpusName = this.corpus != null ? this.corpus.getName() : null;

        // Train into and save to the directory resolved in beforeFirstDocument.
        this.engine.trainModel(this.dataDirFile, getInstanceType(), getAlgorithmParameters());
        LOGGER.info("LearningFramework: Training complete!");
        this.engine.saveEngine(this.dataDirFile);
    }

    /**
     * Prints the target classes, the training set size and the attributes of a Mallet corpus
     * representation to standard output.
     *
     * @param mallet the Mallet-backed corpus representation to summarize
     */
    private void printMalletSummary(CorpusRepresentationMallet mallet) {
        String classes = mallet.getRepresentationMallet().getPipe().getTargetAlphabet().toString().replaceAll("\\n", " ");
        System.out.println("Training set classes: " + classes);
        System.out.println("Training set size: " + mallet.getRepresentationMallet().size());
        if (mallet.getRepresentationMallet().getDataAlphabet().size() > MAX_ATTRIBUTES_TO_LIST) {
            System.out.println("LearningFramework: Attributes " + mallet.getRepresentationMallet().getDataAlphabet().size());
        } else {
            String attributes = mallet.getRepresentationMallet().getDataAlphabet().toString().replaceAll("\\n", " ");
            System.out.println("LearningFramework: Attributes " + attributes);
        }
    }

    /**
     * Logs that no training could be done for this run.
     *
     * @param controller the controller running this resource (not used here)
     * @param th an error reported for the run, or {@code null}
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    protected void finishedNoDocument(Controller controller, Throwable th) {
        String message = "Processing finished, but got an error, no documents seen, or the PR was disabled in the pipeline - cannot train!";
        if (th != null) {
            LOGGER.error(message, th);
        } else {
            LOGGER.error(message);
        }
    }
}
