package gate.plugin.learningframework.engines;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.C45Trainer;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.DecisionTreeTrainer;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.classify.WinnowTrainer;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import gate.Annotation;
import gate.AnnotationSet;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.LFUtils;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.log4j.Logger;

/* loaded from: input_file:gate/plugin/learningframework/engines/EngineMBMalletClass.class */
/**
 * Learning Framework engine for Mallet classification algorithms.
 *
 * <p>The trainer is created from an {@link Algorithm}. It may optionally be configured through a
 * parameter string parsed with {@code Parms}. The trained model is a Mallet {@link Classifier}
 * stored in {@code this.model}.
 */
public class EngineMBMalletClass extends EngineMBMallet {
    private static final Logger LOGGER = Logger.getLogger(EngineMBMalletClass.class);

    /**
     * Trains a Mallet classifier on the Mallet representation of the current corpus.
     *
     * <p>The result is stored in {@code this.model}, and then {@code updateInfo()} is called.
     * None of the parameters are used by this implementation.
     *
     * @param file not used here
     * @param str not used here
     * @param str2 not used here
     * @throws ClassCastException if the configured trainer is not a Mallet {@link ClassifierTrainer}
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public void trainModel(File file, String str, String str2) {
        // Diagnostics go through the logger rather than being printed unconditionally to stderr.
        LOGGER.debug("EngineMalletClass.trainModel: trainer=" + this.trainer);
        LOGGER.debug("EngineMalletClass.trainModel: CR=" + this.corpusRepresentation);
        this.model = ((ClassifierTrainer) this.trainer).train(this.corpusRepresentation.getRepresentationMallet());
        updateInfo();
    }

    /**
     * Classifies every annotation of {@code annotationSet}, in document order, with the trained
     * Mallet classifier.
     *
     * <p>For each annotation:
     * <ul>
     *   <li>the independent features are extracted using {@code annotationSet2};</li>
     *   <li>the best label is written to the annotation's {@code gate.LF.target} feature;</li>
     *   <li>a {@link ModelApplication} is created holding the best label, its score, and the
     *       full rank-ordered label and score lists.</li>
     * </ul>
     *
     * <p>{@code stopGrowth()} is called on the corpus representation before classifying.
     * {@code startGrowth()} is always called afterwards, even if classification fails.
     *
     * @param annotationSet the instance annotations to classify (their features are updated)
     * @param annotationSet2 the annotation set passed to feature extraction
     * @param annotationSet3 not used here
     * @param str not used here
     * @return one {@link ModelApplication} per classified annotation, in document order
     * @throws GateRuntimeException if the corpus representation is not a
     *     {@link CorpusRepresentationMalletTarget}
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public List<ModelApplication> applyModel(AnnotationSet annotationSet, AnnotationSet annotationSet2, AnnotationSet annotationSet3, String str) {
        if (!(this.corpusRepresentation instanceof CorpusRepresentationMalletTarget)) {
            // Guard against null so the error message itself cannot throw an NPE.
            Object crClass = this.corpusRepresentation == null ? null : this.corpusRepresentation.getClass();
            throw new GateRuntimeException("Cannot perform classification with data from " + crClass);
        }
        CorpusRepresentationMalletTarget target = (CorpusRepresentationMalletTarget) this.corpusRepresentation;
        target.stopGrowth();
        try {
            List<ModelApplication> results = new ArrayList<>();
            LFPipe pipe = (LFPipe) target.getRepresentationMallet().getPipe();
            Classifier classifier = (Classifier) this.model;
            for (Annotation instanceAnn : annotationSet.inDocumentOrder()) {
                Labeling labeling = classifier.classify(
                        pipe.instanceFrom(target.extractIndependentFeatures(instanceAnn, annotationSet2)))
                        .getLabeling();
                LabelVector ranked = labeling.toLabelVector();
                int nrLabels = ranked.numLocations();
                ArrayList<String> labels = new ArrayList<>(nrLabels);
                ArrayList<Double> scores = new ArrayList<>(nrLabels);
                for (int rank = 0; rank < nrLabels; rank++) {
                    labels.add(ranked.getLabelAtRank(rank).toString());
                    scores.add(ranked.getValueAtRank(rank));
                }
                String bestLabel = labeling.getBestLabel().toString();
                ModelApplication application =
                        new ModelApplication(instanceAnn, bestLabel, labeling.getBestValue(), labels, scores);
                instanceAnn.getFeatures().put("gate.LF.target", bestLabel);
                results.add(application);
            }
            return results;
        } finally {
            // Re-enable growth even if classification throws, so the representation is not left frozen.
            target.startGrowth();
        }
    }

    /**
     * Creates and stores the Mallet trainer for the given algorithm.
     *
     * <p>If {@code parms} is null or blank, the trainer class's no-arg constructor is used. For
     * C4.5, DecisionTree, MaxEnt, BalancedWinnow and Winnow, the parameters are parsed and applied.
     * For any other algorithm, the parameters are ignored (with a warning) and the no-arg
     * constructor is used.
     *
     * @param algorithm the algorithm whose trainer should be created
     * @param parms optional algorithm parameter string
     * @throws GateRuntimeException if the trainer cannot be instantiated reflectively
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public void initializeAlgorithm(Algorithm algorithm, String parms) {
        if (parms == null || parms.trim().isEmpty()) {
            this.trainer = newDefaultTrainer(algorithm);
            return;
        }
        if (algorithm.equals(AlgorithmClassification.MalletC45_CL_MR)) {
            Parms p = new Parms(parms, "m:maxDepth:i", "p:prune:B", "n:minNumInsts:i");
            int maxDepth = ((Integer) p.getValueOrElse("maxDepth", 0)).intValue();
            int minNumInsts = ((Integer) p.getValueOrElse("minNumInsts", 2)).intValue();
            boolean prune = ((Boolean) p.getValueOrElse("prune", true)).booleanValue();
            // A non-positive maxDepth means: use the constructor without a depth limit.
            C45Trainer c45Trainer = maxDepth > 0 ? new C45Trainer(maxDepth, prune) : new C45Trainer(prune);
            c45Trainer.setMinNumInsts(minNumInsts);
            this.trainer = c45Trainer;
            return;
        }
        if (algorithm.equals(AlgorithmClassification.MalletDecisionTree_CL_MR)) {
            Parms p = new Parms(parms, "m:maxDepth:i", "i:minInfoGainSplit:d");
            DecisionTreeTrainer decisionTreeTrainer = new DecisionTreeTrainer();
            decisionTreeTrainer.setMaxDepth(((Integer) p.getValueOrElse("maxDepth", 5)).intValue());
            decisionTreeTrainer.setMinInfoGainSplit(((Double) p.getValueOrElse("minInfoGainSplit", 0.001d)).doubleValue());
            this.trainer = decisionTreeTrainer;
            return;
        }
        if (algorithm.equals(AlgorithmClassification.MalletMexEnt_CL_MR)) {
            Parms p = new Parms(parms, "v:gaussianPriorVariance:d", "l:l1Weight:d", "i:numIterations:i");
            MaxEntTrainer maxEntTrainer = new MaxEntTrainer();
            maxEntTrainer.setGaussianPriorVariance(((Double) p.getValueOrElse("gaussianPriorVariance", 1.0d)).doubleValue());
            maxEntTrainer.setL1Weight(((Double) p.getValueOrElse("l1Weight", 0.0d)).doubleValue());
            maxEntTrainer.setNumIterations(((Integer) p.getValueOrElse("numIterations", Integer.MAX_VALUE)).intValue());
            this.trainer = maxEntTrainer;
            return;
        }
        if (algorithm.equals(AlgorithmClassification.MalletBalancedWinnow_CL_MR)) {
            Parms p = new Parms(parms, "e:epsilon:d", "d:delta:d", "i:maxIterations:i", "c:coolingRate:d");
            double epsilon = ((Double) p.getValueOrElse("epsilon", 0.5d)).doubleValue();
            double delta = ((Double) p.getValueOrElse("delta", 0.1d)).doubleValue();
            // Look up the name declared in the Parms spec ("maxIterations"); the previous lookup
            // used "int", so a user-supplied value was silently ignored.
            int maxIterations = ((Integer) p.getValueOrElse("maxIterations", 30)).intValue();
            double coolingRate = ((Double) p.getValueOrElse("coolingRate", 0.5d)).doubleValue();
            this.trainer = new BalancedWinnowTrainer(epsilon, delta, maxIterations, coolingRate);
            return;
        }
        if (algorithm.equals(AlgorithmClassification.MalletWinnow_CL_MR)) {
            Parms p = new Parms(parms, "a:alpha:d", "b:beta:d", "n:nfact:d");
            double alpha = ((Double) p.getValueOrElse("alpha", 2.0d)).doubleValue();
            double beta = ((Double) p.getValueOrElse("beta", 2.0d)).doubleValue();
            double nfact = ((Double) p.getValueOrElse("nfact", 0.5d)).doubleValue();
            this.trainer = new WinnowTrainer(alpha, beta, nfact);
            return;
        }
        LOGGER.warn("IMPORTANT: parameters ignored when creating Mallet trainer " + algorithm.getTrainerClass());
        this.trainer = newDefaultTrainer(algorithm);
    }

    /**
     * Instantiates the algorithm's trainer class through its declared no-arg constructor.
     *
     * @param algorithm the algorithm whose trainer class is instantiated
     * @return the new trainer instance
     * @throws GateRuntimeException wrapping any reflective failure
     */
    private static Object newDefaultTrainer(Algorithm algorithm) {
        Class<?> trainerClass = algorithm.getTrainerClass();
        try {
            return trainerClass.getDeclaredConstructor().newInstance();
        } catch (IllegalAccessException | IllegalArgumentException | InstantiationException | NoSuchMethodException | SecurityException | InvocationTargetException e) {
            throw new GateRuntimeException("Could not create trainer instance for " + trainerClass, e);
        }
    }

    /**
     * Loads a serialized Mallet {@link Classifier} into {@code this.model}.
     *
     * <p>The model is read from {@code Engine.FILENAME_MODEL}, resolved against {@code url}.
     * Both streams are always closed.
     *
     * <p>NOTE(review): this uses native Java deserialization. Only load models from trusted
     * locations, or consider adding an {@code ObjectInputFilter}.
     *
     * @param url location against which the model file name is resolved
     * @param str not used here
     * @throws GateRuntimeException if the model cannot be opened, read or cast to {@link Classifier}
     */
    @Override // gate.plugin.learningframework.engines.Engine
    protected void loadModel(URL url, String str) {
        try (InputStream in = LFUtils.newURL(url, Engine.FILENAME_MODEL).openStream();
             ObjectInputStream objectIn = new ObjectInputStream(in)) {
            this.model = (Classifier) objectIn.readObject();
        } catch (Exception e) {
            throw new GateRuntimeException("Could not load Mallet model", e);
        }
    }

    /**
     * Estimates classification accuracy by cross-validation or by repeated hold-out.
     *
     * <p>The random seed is taken from the {@code seed} parameter in {@code parms} (default 1).
     *
     * @param parms parameter string; only {@code -s/seed} is read
     * @param evaluationMethod {@code CROSSVALIDATION} for k-fold; any other value means hold-out
     * @param nrFolds number of folds for cross-validation (must be at least 1)
     * @param trainingFraction fraction of instances used for training in each hold-out split
     * @param nrRepeats number of hold-out repetitions (must be at least 1 for hold-out)
     * @return an {@code EvaluationResultClXval} or {@code EvaluationResultClHO} holding the mean
     *     accuracy
     * @throws GateRuntimeException if the algorithm is not a classification algorithm, or the
     *     fold/repeat count is below 1
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public EvaluationResult evaluate(String parms, EvaluationMethod evaluationMethod, int nrFolds, double trainingFraction, int nrRepeats) {
        if (!(this.algorithm instanceof AlgorithmClassification)) {
            throw new GateRuntimeException("Mallet evaluation: not available for regression!");
        }
        int seed = ((Integer) new Parms(parms, "s:seed:i").getValueOrElse("seed", 1)).intValue();
        ClassifierTrainer classifierTrainer = (ClassifierTrainer) this.trainer;
        InstanceList instances = this.corpusRepresentation.getRepresentationMallet();

        if (evaluationMethod == EvaluationMethod.CROSSVALIDATION) {
            // Guard against dividing by zero, which would otherwise report NaN accuracy.
            if (nrFolds < 1) {
                throw new GateRuntimeException("Mallet evaluation: number of folds must be at least 1, got " + nrFolds);
            }
            InstanceList.CrossValidationIterator folds = instances.crossValidationIterator(nrFolds, seed);
            double accuracySum = 0.0d;
            while (folds.hasNext()) {
                InstanceList[] split = folds.nextSplit();
                accuracySum += classifierTrainer.train(split[0]).getAccuracy(split[1]);
            }
            EvaluationResultClXval result = new EvaluationResultClXval();
            result.accuracyEstimate = accuracySum / nrFolds;
            result.nrFolds = nrFolds;
            return result;
        }

        // Guard against dividing by zero, which would otherwise report NaN accuracy.
        if (nrRepeats < 1) {
            throw new GateRuntimeException("Mallet evaluation: number of repeats must be at least 1, got " + nrRepeats);
        }
        Random random = new Random(seed);
        double accuracySum = 0.0d;
        for (int repeat = 0; repeat < nrRepeats; repeat++) {
            InstanceList[] split = instances.split(random, new double[]{trainingFraction, 1.0d - trainingFraction});
            accuracySum += classifierTrainer.train(split[0]).getAccuracy(split[1]);
        }
        EvaluationResultClHO result = new EvaluationResultClHO();
        result.accuracyEstimate = accuracySum / nrRepeats;
        result.trainingFraction = trainingFraction;
        result.nrRepeats = nrRepeats;
        return result;
    }
}
