package gate.plugin.learningframework;

import gate.AnnotationSet;
import gate.Controller;
import gate.Document;
import gate.creole.metadata.CreoleParameter;
import gate.creole.metadata.CreoleResource;
import gate.creole.metadata.Optional;
import gate.creole.metadata.RunTime;
import gate.plugin.learningframework.engines.AlgorithmKind;
import gate.plugin.learningframework.engines.Engine;
import gate.plugin.learningframework.engines.EngineMBServer;
import gate.util.GateRuntimeException;
import java.net.URL;
import java.util.List;
import org.apache.log4j.Logger;

/**
 * Processing resource that applies a previously trained classification model to documents.
 *
 * <p>By default the model is loaded from {@link #getDataDirectory() dataDirectory} via
 * {@code Engine.load}. When {@link #getServerUrl() serverUrl} is set, classification is
 * delegated to an {@link EngineMBServer} instead. The loaded local engine is cached between
 * runs and only reloaded when the data directory or the algorithm parameters change.
 */
@CreoleResource(name = "LF_ApplyClassification", helpURL = "https://gatenlp.github.io/gateplugin-LearningFramework/LF_ApplyClassification", comment = "Apply a trained classification model to documents")
public class LF_ApplyClassification extends LearningFrameworkPRBase {
    static final Logger LOGGER = Logger.getLogger(LF_ApplyClassification.class.getCanonicalName());
    private static final long serialVersionUID = -754439854542759988L;

    /** Directory the model and its metadata are read from. */
    protected URL dataDirectory;
    /** Directory the currently cached local engine was loaded from; null if none is cached. */
    protected URL oldDataDirectory;
    protected String outputASName;
    private Double confidenceThreshold;
    protected String targetFeature;
    private String sequenceSpan;
    protected String serverUrl;
    /** Engine used for classification: either a loaded local model or a server proxy. */
    private Engine engine;
    /** Effective target feature name: the PR parameter if set, otherwise the one from the model info. */
    private String targetFeatureToUse;

    @CreoleParameter(comment = "The directory where all data will be stored and read from")
    @RunTime
    public void setDataDirectory(URL url) {
        this.dataDirectory = url;
    }

    public URL getDataDirectory() {
        return this.dataDirectory;
    }

    @CreoleParameter(comment = "If identical to the input AS, existing annotations are updated", defaultValue = "LearningFramework")
    @RunTime
    @Optional
    public void setOutputASName(String str) {
        this.outputASName = str;
    }

    public String getOutputASName() {
        return this.outputASName;
    }

    @CreoleParameter(comment = "The minimum confidence/probability for including an annotation at application time. If empty, ignore.")
    @RunTime
    @Optional
    public void setConfidenceThreshold(Double d) {
        this.confidenceThreshold = d;
    }

    public Double getConfidenceThreshold() {
        return this.confidenceThreshold;
    }

    @CreoleParameter(comment = "Name of class feature to add to the original instance annotations. Default is the name that was used for training.", defaultValue = "")
    @RunTime
    @Optional
    public void setTargetFeature(String str) {
        this.targetFeature = str;
    }

    public String getTargetFeature() {
        return this.targetFeature;
    }

    @CreoleParameter(comment = "For sequence learners, an annotation type defining a meaningful sequence span. Ignored by non-sequence learners. ")
    @RunTime
    @Optional
    public void setSequenceSpan(String str) {
        this.sequenceSpan = str;
    }

    public String getSequenceSpan() {
        return this.sequenceSpan;
    }

    @CreoleParameter(comment = "Classify from a server instead of a stored model")
    @RunTime
    @Optional
    public void setServerUrl(String str) {
        this.serverUrl = str;
    }

    public String getServerUrl() {
        return this.serverUrl;
    }

    /**
     * Classifies the instance annotations of one document and stores the results.
     *
     * @param document the document to classify
     * @return the same document, with the classification results applied
     * @throws GateRuntimeException if execution was requested to be interrupted
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public Document process(Document document) {
        if (isInterrupted()) {
            this.interrupted = false;
            throw new GateRuntimeException("Execution was requested to be interrupted");
        }
        AnnotationSet inputAS = document.getAnnotations(getInputASName());
        AnnotationSet instanceAS = inputAS.get(getInstanceType());
        // Sequence spans are only needed by sequence taggers; other learners get null.
        AnnotationSet sequenceAS = null;
        if (this.engine.getAlgorithm().getAlgorithmKind() == AlgorithmKind.SEQUENCE_TAGGER) {
            sequenceAS = inputAS.get(getSequenceSpan());
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("classifying doc " + document.getName()
                    + " instanceAS:" + instanceAS.size()
                    + ", inputAS:" + inputAS.size()
                    + (sequenceAS == null ? "" : ", sequenceAS:" + sequenceAS.size()));
        }
        List<ModelApplication> results = this.engine.applyModel(instanceAS, inputAS, sequenceAS, getAlgorithmParameters());
        // Per the outputASName parameter description, when output and input set are the same the
        // existing instance annotations are updated; that case is signalled by passing null here.
        AnnotationSet outputAS = null;
        if (!sameAnnotationSetName(getOutputASName(), getInputASName())) {
            outputAS = document.getAnnotations(getOutputASName());
        }
        ModelApplication.applyClassification(document, results, this.targetFeatureToUse, outputAS, getConfidenceThreshold());
        return document;
    }

    /**
     * Prepares the engine and the effective target feature before the first document is processed.
     *
     * @param controller the controller running this PR
     * @throws GateRuntimeException if the parameters are inconsistent with the chosen engine, or
     *     no target feature name can be determined
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    protected void beforeFirstDocument(Controller controller) {
        if (this.serverUrl == null || this.serverUrl.isEmpty()) {
            initLocalEngine();
        } else {
            initServerEngine();
        }
        this.targetFeatureToUse = resolveTargetFeature();
        if (getConfidenceThreshold() == null) {
            LOGGER.debug("Parameter confidenceThreshold not given, not using confidence threshold");
        }
    }

    /** Loads (or reuses) the local model engine and validates sequence-tagger parameters. */
    private void initLocalEngine() {
        // Reload if nothing is cached, the directory changed, or the algorithm parameters changed.
        if (this.engine == null
                || !sameUrl(this.dataDirectory, this.oldDataDirectory)
                || getAlgorithmParametersIsChanged()) {
            this.oldDataDirectory = this.dataDirectory;
            this.engine = Engine.load(this.dataDirectory, getAlgorithmParameters());
        }
        System.out.println("LF-Info: loaded model is " + this.engine);
        if (this.engine.getModel() == null) {
            System.err.println("WARNING: no internal model to apply, this is ok if an external model is used");
        } else {
            System.out.println("LearningFramework: Applying model " + this.engine.getModel().getClass() + " ...");
        }
        if (this.engine.getAlgorithm().getAlgorithmKind() == AlgorithmKind.SEQUENCE_TAGGER
                && (getSequenceSpan() == null || getSequenceSpan().isEmpty())) {
            throw new GateRuntimeException("sequenceSpan parameter must not be empty when a sequence tagging algorithm is used for classification");
        }
    }

    /** Creates a server-backed engine; sequence spans are not supported in this mode. */
    private void initServerEngine() {
        if (getSequenceSpan() != null && !getSequenceSpan().isEmpty()) {
            throw new GateRuntimeException("Sequence span not supported for server");
        }
        this.engine = new EngineMBServer(this.dataDirectory, this.serverUrl);
        // The cached engine is no longer a locally loaded model: force a reload on the next local run.
        this.oldDataDirectory = null;
    }

    /** Returns the target feature name from the PR parameter, falling back to the model info. */
    private String resolveTargetFeature() {
        String fromParameter = getTargetFeature();
        if (fromParameter != null && !fromParameter.isEmpty()) {
            LOGGER.warn("Using target feature name from PR parameter: " + fromParameter);
            return fromParameter;
        }
        String fromModel = this.engine.getInfo().targetFeature;
        if (fromModel == null || fromModel.isEmpty()) {
            throw new GateRuntimeException("No targetFeature parameter specified and none available from the model info file either.");
        }
        LOGGER.warn("Using target feature name from model: " + fromModel);
        return fromModel;
    }

    /**
     * Null-safe URL comparison by string form. URL.equals is avoided because it may resolve
     * host names.
     */
    private static boolean sameUrl(URL a, URL b) {
        if (a == null || b == null) {
            return a == b;
        }
        return a.toString().equals(b.toString());
    }

    /** Compares annotation set names, treating null and "" as the same (default) set. */
    private static boolean sameAnnotationSetName(String a, String b) {
        String left = a == null ? "" : a;
        String right = b == null ? "" : b;
        return left.equals(right);
    }

    /**
     * Nothing to clean up. The engine is intentionally kept so that a later run can reuse it when
     * the data directory and algorithm parameters are unchanged.
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public void afterLastDocument(Controller controller, Throwable th) {
    }

    /** Nothing to do when the controller ran without processing any document. */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public void finishedNoDocument(Controller controller, Throwable th) {
    }
}
