package gate.plugin.learningframework;

import gate.AnnotationSet;
import gate.Controller;
import gate.Document;
import gate.creole.metadata.CreoleParameter;
import gate.creole.metadata.CreoleResource;
import gate.creole.metadata.Optional;
import gate.creole.metadata.RunTime;
import gate.plugin.learningframework.data.CorpusRepresentation;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.engines.AlgorithmRegression;
import gate.plugin.learningframework.engines.Engine;
import gate.plugin.learningframework.engines.EvaluationResult;
import gate.plugin.learningframework.engines.EvaluationResultRegression;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.FeatureSpecification;
import gate.plugin.learningframework.features.TargetType;
import gate.util.GateRuntimeException;
import java.net.URL;
import org.apache.commons.clipatched.HelpFormatter;
import org.apache.log4j.Logger;

/**
 * Processing resource that collects training instances from every processed
 * document and, once the last document has been seen, evaluates the configured
 * regression algorithm on the collected instances.
 *
 * <p>Life cycle as implemented in this class:
 * <ol>
 *   <li>{@link #beforeFirstDocument(Controller)} validates the parameters, reads the
 *       feature specification and creates the engine and its corpus representation;</li>
 *   <li>{@link #process(Document)} adds the instances of one document to the corpus
 *       representation;</li>
 *   <li>{@link #afterLastDocument(Controller, Throwable)} runs the evaluation and, if a
 *       corpus is available, stores the RMSE as the corpus feature
 *       {@code LearningFramework.rmse}.</li>
 * </ol>
 */
@CreoleResource(name = "LF_EvaluateRegression", helpURL = "https://gatenlp.github.io/gateplugin-LearningFramework/LF_EvaluateRegression", comment = "Evaluate an algorithm and parameter settings for regression")
/* loaded from: input_file:gate/plugin/learningframework/LF_EvaluateRegression.class */
public class LF_EvaluateRegression extends LearningFrameworkPRBase {
    private static final long serialVersionUID = -4216855026883354L;

    /** Defaults mirrored from the {@code defaultValue} of the corresponding CREOLE parameters. */
    private static final int DEFAULT_NUMBER_OF_FOLDS = 10;
    private static final double DEFAULT_TRAINING_FRACTION = 0.6667d;
    private static final int DEFAULT_NUMBER_OF_REPEATS = 1;

    // Static so that the logger is shared and never becomes part of the instance state.
    private static final Logger logger = Logger.getLogger(LF_EvaluateRegression.class.getCanonicalName());

    private URL featureSpecURL;
    private AlgorithmRegression trainingAlgorithm;
    protected String targetFeature;
    protected String sequenceSpan;
    /** Number of documents passed to {@link #process(Document)} since the last reset. */
    private int nrDocuments;
    private URL dataDirURL;
    protected ScalingMethod scaleFeatures = ScalingMethod.NONE;
    private CorpusRepresentation corpusRepresentation = null;
    private FeatureSpecification featureSpec = null;
    private Engine engine = null;
    protected EvaluationMethod evaluationMethod = EvaluationMethod.CROSSVALIDATION;
    protected int numberOfFolds = DEFAULT_NUMBER_OF_FOLDS;
    protected double trainingFraction = DEFAULT_TRAINING_FRACTION;
    protected int numberOfRepeats = DEFAULT_NUMBER_OF_REPEATS;
    protected String instanceWeightFeature = "";

    /**
     * Sets the location of the feature specification file.
     *
     * @param url URL of the feature specification
     */
    @CreoleParameter(comment = "The feature specification file.")
    @RunTime
    public void setFeatureSpecURL(URL url) {
        this.featureSpecURL = url;
    }

    /** @return the URL of the feature specification file */
    public URL getFeatureSpecURL() {
        return this.featureSpecURL;
    }

    /**
     * Sets the regression algorithm to evaluate. Although the parameter is optional,
     * {@link #beforeFirstDocument(Controller)} rejects a {@code null} value.
     *
     * @param algorithmRegression the algorithm to evaluate
     */
    @CreoleParameter(comment = "The algorithm to be used for training the classifier")
    @RunTime
    @Optional
    public void setTrainingAlgorithm(AlgorithmRegression algorithmRegression) {
        this.trainingAlgorithm = algorithmRegression;
    }

    /** @return the regression algorithm to evaluate, may be {@code null} if not set */
    public AlgorithmRegression getTrainingAlgorithm() {
        return this.trainingAlgorithm;
    }

    /**
     * Sets the scaling method that is applied globally to the features.
     *
     * @param scalingMethod the scaling method
     */
    @CreoleParameter(defaultValue = "NONE", comment = "If and how to scale features. ")
    @RunTime
    public void setScaleFeatures(ScalingMethod scalingMethod) {
        this.scaleFeatures = scalingMethod;
    }

    /** @return the scaling method applied to the features */
    public ScalingMethod getScaleFeatures() {
        return this.scaleFeatures;
    }

    /**
     * Sets the name of the feature holding the numeric target. Although the parameter is
     * optional, {@link #beforeFirstDocument(Controller)} rejects a {@code null} or empty value.
     *
     * @param str name of the target feature
     */
    @CreoleParameter(comment = "The feature containing the target label")
    @RunTime
    @Optional
    public void setTargetFeature(String str) {
        this.targetFeature = str;
    }

    /** @return the name of the target feature, may be {@code null} if not set */
    public String getTargetFeature() {
        return this.targetFeature;
    }

    /**
     * Sets the evaluation method that is handed to the engine.
     *
     * @param evaluationMethod the evaluation method
     */
    @CreoleParameter(comment = "Evaluation Method, not all algorithms may support all methods", defaultValue = "CROSSVALIDATION")
    @RunTime
    @Optional
    public void setEvaluationMethod(EvaluationMethod evaluationMethod) {
        this.evaluationMethod = evaluationMethod;
    }

    /** @return the evaluation method handed to the engine */
    public EvaluationMethod getEvaluationMethod() {
        return this.evaluationMethod;
    }

    /**
     * Sets the number of folds for cross validation.
     *
     * @param num number of folds, must be at least 2; {@code null} restores the default of 10
     * @throws GateRuntimeException if {@code num} is smaller than 2
     */
    @CreoleParameter(comment = "Number of folds for the cross validation", defaultValue = "10")
    @RunTime
    @Optional
    public void setNumberOfFolds(Integer num) {
        // The parameter is optional, so a missing value must not fail on unboxing.
        if (num == null) {
            this.numberOfFolds = DEFAULT_NUMBER_OF_FOLDS;
            return;
        }
        if (num.intValue() < 2) {
            throw new GateRuntimeException("numberOfFolds must be > 1");
        }
        this.numberOfFolds = num.intValue();
    }

    /** @return the number of folds for cross validation */
    public Integer getNumberOfFolds() {
        return Integer.valueOf(this.numberOfFolds);
    }

    /**
     * Sets the fraction of instances used for training.
     *
     * @param d fraction strictly between 0.0 and 1.0; {@code null} restores the default of 0.6667
     * @throws GateRuntimeException if {@code d} is not strictly between 0.0 and 1.0
     */
    @CreoleParameter(comment = "Fraction of instances to use for training, > 0.0 and < 1.0", defaultValue = "0.6667")
    @RunTime
    @Optional
    public void setTrainingFraction(Double d) {
        // The parameter is optional, so a missing value must not fail on unboxing.
        if (d == null) {
            this.trainingFraction = DEFAULT_TRAINING_FRACTION;
            return;
        }
        if (d.doubleValue() <= 0.0d || d.doubleValue() >= 1.0d) {
            throw new GateRuntimeException("trainingFraction must be > 0.0 and < 1.0");
        }
        this.trainingFraction = d.doubleValue();
    }

    /** @return the fraction of instances used for training */
    public Double getTrainingFraction() {
        return Double.valueOf(this.trainingFraction);
    }

    /**
     * Sets how often the holdout evaluation is repeated.
     *
     * @param num number of repeats; {@code null} restores the default of 1
     */
    @CreoleParameter(comment = "Number of times to perform holdout evaluation to get an average", defaultValue = "1")
    @RunTime
    @Optional
    public void setNumberOfRepeats(Integer num) {
        // NOTE(review): values < 1 are still accepted as before; confirm whether the
        // engine tolerates them before adding a range check here.
        this.numberOfRepeats = (num == null) ? DEFAULT_NUMBER_OF_REPEATS : num.intValue();
    }

    /** @return how often the holdout evaluation is repeated */
    public Integer getNumberOfRepeats() {
        return Integer.valueOf(this.numberOfRepeats);
    }

    /**
     * Sets the name of the feature holding the instance weight.
     *
     * @param str name of the instance weight feature
     */
    @CreoleParameter(comment = "The feature that contains the instance weight. If empty, no instance weights are used", defaultValue = "")
    @RunTime
    @Optional
    public void setInstanceWeightFeature(String str) {
        this.instanceWeightFeature = str;
    }

    /** @return the name of the feature holding the instance weight */
    public String getInstanceWeightFeature() {
        return this.instanceWeightFeature;
    }

    /**
     * Adds the instances of one document to the corpus representation.
     *
     * @param document the document to take the instance annotations from
     * @return the same document that was passed in
     * @throws GateRuntimeException if an interrupt was requested; the interrupt flag is
     *         cleared before the exception is thrown
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public Document process(Document document) {
        if (isInterrupted()) {
            this.interrupted = false;
            throw new GateRuntimeException("Execution was requested to be interrupted");
        }
        AnnotationSet inputAS = document.getAnnotations(getInputASName());
        AnnotationSet instances = inputAS.get(getInstanceType());
        this.corpusRepresentation.add(instances, null, inputAS, null, getTargetFeature(), TargetType.NUMERIC, this.instanceWeightFeature, null, null);
        this.nrDocuments++;
        return document;
    }

    /**
     * Validates the parameters and sets up the feature specification, the engine and the
     * corpus representation for a new run.
     *
     * @param controller the controller running this processing resource (not used here)
     * @throws GateRuntimeException if no training algorithm or no target feature is set
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    protected void beforeFirstDocument(Controller controller) {
        AlgorithmRegression algorithm = getTrainingAlgorithm();
        if (algorithm == null) {
            throw new GateRuntimeException("LF_EvaluateRegression: no training algorithm specified");
        }
        String target = getTargetFeature();
        if (target == null || target.isEmpty()) {
            throw new GateRuntimeException("LF_EvaluateRegression: no target feature specified");
        }
        logger.debug("Before Document.");
        logger.debug("  Training algorithm engine class is " + algorithm.getEngineClass());
        logger.debug("  Training algorithm algor class is " + algorithm.getTrainerClass());
        this.featureSpec = new FeatureSpecification(this.featureSpecURL);
        logger.debug("Read the feature specification: " + this.featureSpec);
        FeatureInfo featureInfo = this.featureSpec.getFeatureInfo();
        featureInfo.setGlobalScalingMethod(this.scaleFeatures);
        this.engine = Engine.create(algorithm, getAlgorithmParameters(), featureInfo, TargetType.NUMERIC, this.dataDirURL);
        logger.debug("created the engine: " + this.engine);
        this.corpusRepresentation = this.engine.getCorpusRepresentation();
        logger.debug("created the corpusRepresentationMallet: " + this.corpusRepresentation);
        this.nrDocuments = 0;
        logger.debug("setup of the training PR complete");
    }

    /**
     * Runs the evaluation on the collected instances and logs the result. If a corpus is
     * available and the result is an {@link EvaluationResultRegression}, its RMSE is stored
     * as the corpus feature {@code LearningFramework.rmse}.
     *
     * <p>No evaluation is attempted when {@code th} is non-null or when the engine was never
     * created, so that the original failure is not masked by a follow-up error.
     *
     * @param controller the controller running this processing resource (not used here)
     * @param th presumably the error that aborted processing, or {@code null} on normal
     *        completion (verify against the base class)
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    public void afterLastDocument(Controller controller, Throwable th) {
        if (th != null) {
            logger.error("LearningFramework: processing did not complete normally, skipping evaluation", th);
            return;
        }
        if (this.engine == null) {
            logger.error("LearningFramework: no engine has been created, skipping evaluation");
            return;
        }
        System.out.println("LearningFramework: Starting evaluating engine " + this.engine);
        if (this.corpusRepresentation instanceof CorpusRepresentationMallet) {
            CorpusRepresentationMallet malletRepresentation = (CorpusRepresentationMallet) this.corpusRepresentation;
            System.out.println("Training set size: " + malletRepresentation.getRepresentationMallet().size());
            // Only list the attributes themselves when there are few of them, otherwise just the count.
            if (malletRepresentation.getRepresentationMallet().getDataAlphabet().size() > 20) {
                System.out.println("LearningFramework: Attributes " + malletRepresentation.getRepresentationMallet().getDataAlphabet().size());
            } else {
                // NOTE(review): the separator constant looks like a decompiler substitution for a
                // plain string literal - confirm its value before replacing it.
                System.out.println("LearningFramework: Attributes " + malletRepresentation.getRepresentationMallet().getDataAlphabet().toString().replaceAll("\\n", HelpFormatter.DEFAULT_LONG_OPT_SEPARATOR));
            }
        }
        EvaluationResult result = this.engine.evaluate(getAlgorithmParameters(), this.evaluationMethod, this.numberOfFolds, this.trainingFraction, this.numberOfRepeats);
        logger.info("LearningFramework: Evaluation complete!");
        logger.info(result);
        if (getCorpus() != null && result instanceof EvaluationResultRegression) {
            getCorpus().getFeatures().put("LearningFramework.rmse", Double.valueOf(((EvaluationResultRegression) result).rmse));
        }
    }

    /**
     * Logs an error, because nothing can be evaluated when no document has been seen.
     *
     * @param controller the controller running this processing resource (not used here)
     * @param th not used here
     */
    @Override // gate.plugin.learningframework.AbstractDocumentProcessor
    protected void finishedNoDocument(Controller controller, Throwable th) {
        logger.error("Processing finished, but no documents seen, cannot evaluate!");
    }
}
