package gate.plugin.learningframework.engines;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gate.Annotation;
import gate.AnnotationSet;
import gate.lib.interaction.process.Process4JsonStream;
import gate.lib.interaction.process.ProcessBase;
import gate.lib.interaction.process.ProcessSimple;
import gate.plugin.learningframework.EvaluationMethod;
import gate.plugin.learningframework.ModelApplication;
import gate.plugin.learningframework.data.CorpusRepresentationMallet;
import gate.plugin.learningframework.data.CorpusRepresentationMalletTarget;
import gate.plugin.learningframework.export.CorpusExporter;
import gate.plugin.learningframework.export.Exporter;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.plugin.learningframework.mallet.NominalTargetWithCosts;
import gate.util.Files;
import gate.util.GateRuntimeException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.yaml.snakeyaml.Yaml;

/* loaded from: input_file:gate/plugin/learningframework/engines/EngineMBPythonNetworksBase.class */
/**
 * Base class for engines that train and apply models by running external wrapper scripts in a
 * separate process.
 *
 * <p>Subclasses configure the wrapper by assigning the protected configuration fields below.
 * These fields hold the environment variable / system property / YAML key naming the wrapper
 * home, and the script and model basenames. Training data is written by a {@link CorpusExporter}
 * of type {@code CSV_CL_MR}; the training script is then run to completion. For application, a
 * {@link Process4JsonStream} is started once in {@link #loadModel} and feature vectors are
 * exchanged with it per document in {@link #applyModel}.
 */
public abstract class EngineMBPythonNetworksBase extends EngineMB {
    /** Human-readable name of the wrapper, used in error messages. */
    protected String WRAPPER_NAME;
    /** Name of the environment variable holding the wrapper home; also set for the child process. */
    protected String ENV_WRAPPER_HOME;
    /** Name of the Java system property that, when set, overrides the environment variable. */
    protected String PROP_WRAPPER_HOME;
    /** Name of the optional YAML settings file inside the data directory. */
    protected String YAML_FILE;
    /** Key in the YAML settings file that, when present, overrides the wrapper home. */
    protected String YAML_SETTING_WRAPPER_HOME;
    /** Basename (without ".sh"/".cmd") of the application script in {@code <home>/bin}. */
    protected String SCRIPT_APPLY_BASENAME;
    /** Basename (without ".sh"/".cmd") of the training script in {@code <home>/bin}. */
    protected String SCRIPT_TRAIN_BASENAME;
    /** Basename of an evaluation script; not used by this class. */
    protected String SCRIPT_EVAL_BASENAME;
    /** File name, relative to the data directory, passed to the wrapper as the model location. */
    protected String MODEL_BASENAME;
    /** Object assigned to {@code model} after training or loading. */
    protected Object MODEL_INSTANCE;
    /** The most recently started wrapper process. */
    protected ProcessBase process;
    /** Optional shell command (from the YAML file) prepended to the wrapper command line. */
    protected String shellcmd = null;
    /** Optional whitespace-separated shell parameters (from the YAML file) placed after shellcmd. */
    protected String shellparms = null;
    /** Absolute path of the wrapper home directory, set by {@link #findWrapperCommand}. */
    protected String wrapperhome = null;
    /** Exporter used for the training data; it also provides the corpus representation. */
    protected CorpusExporter corpusExporter = null;

    /**
     * Creates the CSV corpus exporter and adopts its corpus representation for this engine.
     *
     * @param directory the data directory URL handed to the exporter
     * @param algorithm the algorithm (not used here)
     * @param parms algorithm parameters; appended to the exporter options "-t -n"
     * @param featureInfo feature specification used by the exporter
     * @param targetType the target type (not used here)
     */
    @Override // gate.plugin.learningframework.engines.EngineMB, gate.plugin.learningframework.engines.Engine
    protected void initWhenCreating(URL directory, Algorithm algorithm, String parms, FeatureInfo featureInfo, TargetType targetType) {
        // Use the featureInfo that was passed in rather than the field, which may not be set yet.
        // NOTE(review): parms is also passed as the exporter's 4th argument; confirm this is intended.
        this.corpusExporter = CorpusExporter.create(Exporter.CSV_CL_MR, "-t -n " + parms, featureInfo, parms, directory);
        this.corpusRepresentation = (CorpusRepresentationMallet) this.corpusExporter.getCorpusRepresentation();
    }

    /**
     * Locates the wrapper home and returns the executable training or application script in it.
     *
     * <p>The wrapper home is taken from, in increasing priority: the environment variable
     * {@code ENV_WRAPPER_HOME}, the system property {@code PROP_WRAPPER_HOME}, and the key
     * {@code YAML_SETTING_WRAPPER_HOME} in the file {@code YAML_FILE} inside {@code dataDir}.
     * A relative home is resolved against {@code dataDir}. As side effects, this method sets
     * {@link #wrapperhome} and, if the YAML file exists, {@link #shellcmd} and {@link #shellparms}.
     *
     * @param dataDir the data directory
     * @param forApplication true to look for the application script, false for the training script
     * @return the absolute, executable script file
     * @throws GateRuntimeException if the home is unset or not a directory, the YAML file cannot be
     *     read or is not a mapping, the OS path separator is unsupported, or the script is not
     *     executable
     */
    protected File findWrapperCommand(File dataDir, boolean forApplication) {
        String home = System.getenv(this.ENV_WRAPPER_HOME);
        String property = System.getProperty(this.PROP_WRAPPER_HOME);
        if (property != null) {
            home = property;
        }
        File yamlFile = new File(dataDir, this.YAML_FILE);
        if (yamlFile.exists()) {
            Map<?, ?> settings = loadYamlSettings(yamlFile);
            String yamlHome = (String) settings.get(this.YAML_SETTING_WRAPPER_HOME);
            if (yamlHome != null) {
                home = yamlHome;
            }
            this.shellcmd = (String) settings.get("shellcmd");
            this.shellparms = (String) settings.get("shellparms");
        }
        if (home == null) {
            throw new GateRuntimeException(this.WRAPPER_NAME + " home not set, please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingSklearn");
        }
        File homeDir = new File(home);
        if (!homeDir.isAbsolute()) {
            homeDir = new File(dataDir, home);
        }
        if (!homeDir.isDirectory()) {
            throw new GateRuntimeException(this.WRAPPER_NAME + " home is not a directory: " + homeDir.getAbsolutePath());
        }
        this.wrapperhome = homeDir.getAbsolutePath();

        String separator = System.getProperty("file.separator");
        String extension;
        if ("/".equals(separator)) {
            extension = ".sh";
        } else if ("\\".equals(separator)) {
            extension = ".cmd";
        } else {
            throw new GateRuntimeException("It appears this OS is not supported");
        }
        String basename = forApplication ? this.SCRIPT_APPLY_BASENAME : this.SCRIPT_TRAIN_BASENAME;
        // Build the script path from the absolute home so a relative dataDir is not applied twice.
        File script = new File(new File(homeDir.getAbsoluteFile(), "bin"), basename + extension);
        if (script.canExecute()) {
            return script;
        }
        throw new GateRuntimeException("Not an executable file or not found: " + script + " please see https://github.com/GateNLP/gateplugin-LearningFramework/wiki/UsingPythonWrappers");
    }

    /**
     * Reads the YAML settings file and returns its top-level mapping.
     *
     * <p>The stream is always closed. NOTE(review): {@code new Yaml().load} can instantiate
     * arbitrary classes on older SnakeYAML versions; this is only acceptable as long as the
     * settings file is trusted.
     */
    private static Map<?, ?> loadYamlSettings(File yamlFile) {
        Object loaded;
        try (Reader reader = new InputStreamReader(new FileInputStream(yamlFile), StandardCharsets.UTF_8)) {
            loaded = new Yaml().load(reader);
        } catch (IOException e) {
            throw new GateRuntimeException("Could not load yaml file " + yamlFile, e);
        }
        if (!(loaded instanceof Map)) {
            throw new GateRuntimeException("Info file has strange format: " + yamlFile.getAbsolutePath());
        }
        return (Map<?, ?>) loaded;
    }

    /**
     * Splits a whitespace-separated argument string into individual arguments.
     * Returns an empty list for null or blank input, so no empty arguments are produced.
     */
    private static List<String> splitArguments(String arguments) {
        if (arguments == null || arguments.trim().isEmpty()) {
            return new ArrayList<>();
        }
        return Arrays.asList(arguments.trim().split("\\s+"));
    }

    /**
     * If a shell command was configured, inserts it and its shell parameters at the front of
     * {@code command}, in that order.
     */
    private void prependShellCommand(List<String> command) {
        if (this.shellcmd == null) {
            return;
        }
        List<String> prefix = new ArrayList<>();
        prefix.add(this.shellcmd);
        prefix.addAll(splitArguments(this.shellparms));
        command.addAll(0, prefix);
    }

    /** Environment for the child process: the wrapper-home variable set to {@link #wrapperhome}. */
    private Map<String, String> wrapperEnvironment() {
        Map<String, String> env = new HashMap<>();
        env.put(this.ENV_WRAPPER_HOME, this.wrapperhome);
        return env;
    }

    /**
     * Restores the corpus representation from {@code directory} and starts the application script
     * as a JSON-stream process.
     *
     * <p>The command line is: [shellcmd [shellparms...]] script modelPath mode modeValue.
     *
     * @param directory data directory; must be a {@code file:} URL
     * @param parms not used by this implementation
     * @throws GateRuntimeException if the URL is not a file URL or the wrapper cannot be found
     */
    @Override // gate.plugin.learningframework.engines.Engine
    protected void loadModel(URL directory, String parms) {
        if (!"file".equals(directory.getProtocol())) {
            throw new GateRuntimeException("The dataDirectory for " + this.WRAPPER_NAME + " must be a file: URL not " + directory);
        }
        File dataDir = Files.fileFromURL(directory);
        loadAndSetCorpusRepresentation(directory);
        AbstractMap.SimpleEntry<String, Integer> mode = findOutMode((CorpusRepresentationMalletTarget) this.corpusRepresentation);
        File script = findWrapperCommand(dataDir, true);
        String modelPath = new File(dataDir, this.MODEL_BASENAME).getAbsolutePath();

        List<String> command = new ArrayList<>();
        command.add(script.getAbsolutePath());
        command.add(modelPath);
        command.add(mode.getKey());
        command.add(mode.getValue().toString());
        prependShellCommand(command);

        this.model = this.MODEL_INSTANCE;
        this.process = Process4JsonStream.create(dataDir, wrapperEnvironment(), command);
    }

    /**
     * Records this engine's class name in the info object and saves it via {@code info.save}.
     * Only the info is written here; the model file itself is not handled by this method.
     */
    @Override // gate.plugin.learningframework.engines.Engine
    protected void saveModel(File directory) {
        this.info.engineClass = getClass().getName();
        this.info.save(directory);
    }

    /**
     * Exports the training data and runs the training script, waiting for it to finish.
     *
     * <p>The command line is:
     * [shellcmd [shellparms...]] script dataDir/ modelPath mode modeValue [parms...].
     *
     * @param dataDirectory directory that receives the exported data and the model
     * @param unusedParm not used by this implementation
     * @param parms extra whitespace-separated arguments for the training script; may be null
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public void trainModel(File dataDirectory, String unusedParm, String parms) {
        AbstractMap.SimpleEntry<String, Integer> mode = findOutMode((CorpusRepresentationMalletTarget) this.corpusRepresentation);
        File script = findWrapperCommand(dataDirectory, false);
        this.corpusExporter.export();
        String dataPrefix = dataDirectory.getAbsolutePath() + File.separator;
        String modelPath = new File(dataDirectory, this.MODEL_BASENAME).getAbsolutePath();

        List<String> command = new ArrayList<>();
        command.add(script.getAbsolutePath());
        command.add(dataPrefix);
        command.add(modelPath);
        command.add(mode.getKey());
        command.add(mode.getValue().toString());
        // Trimmed split: avoids empty arguments from leading/trailing whitespace.
        command.addAll(splitArguments(parms));
        prependShellCommand(command);

        this.model = this.MODEL_INSTANCE;
        this.process = ProcessSimple.create(dataDirectory, wrapperEnvironment(), command);
        this.process.waitFor();
    }

    /** Evaluation is not implemented for this engine. */
    @Override // gate.plugin.learningframework.engines.Engine
    public EvaluationResult evaluate(String str, EvaluationMethod evaluationMethod, int i, double d, int i2) {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Sends the feature vectors of all instance annotations to the running wrapper process and
     * converts its reply into {@link ModelApplication}s, in document order.
     *
     * <p>Growth of the corpus representation is stopped for the duration of the call and is
     * restarted even if an exception occurs.
     *
     * @param instanceAS the instance annotations to classify
     * @param inputAS the annotation set features are extracted from
     * @param otherAS not used by this implementation
     * @param parms not used by this implementation
     * @return one model application per instance annotation
     * @throws RuntimeException if the wrapper's response is malformed or not OK
     */
    @Override // gate.plugin.learningframework.engines.Engine
    public List<ModelApplication> applyModel(AnnotationSet instanceAS, AnnotationSet inputAS, AnnotationSet otherAS, String parms) {
        CorpusRepresentationMalletTarget data = (CorpusRepresentationMalletTarget) this.corpusRepresentation;
        data.stopGrowth();
        try {
            return applyToInstances(data, instanceAS, inputAS);
        } finally {
            data.startGrowth();
        }
    }

    /** Performs the request/response exchange for {@link #applyModel}; growth is already stopped. */
    private List<ModelApplication> applyToInstances(CorpusRepresentationMalletTarget data, AnnotationSet instanceAS, AnnotationSet inputAS) {
        int numFeatures = data.getPipe().getDataAlphabet().size();
        LFPipe pipe = (LFPipe) data.getRepresentationMallet().getPipe();
        Alphabet targetAlphabet = pipe.getTargetAlphabet();

        // A null target alphabet means regression; otherwise collect the class labels.
        List<String> classLabels = null;
        if (targetAlphabet != null) {
            classLabels = new ArrayList<>(targetAlphabet.size());
            for (int i = 0; i < targetAlphabet.size(); i++) {
                classLabels.add(targetAlphabet.lookupObject(i).toString());
            }
        }

        // Each instance is sent as a dense row of length numFeatures.
        List<Annotation> instances = instanceAS.inDocumentOrder();
        List<double[]> rows = new ArrayList<>(instances.size());
        for (Annotation instance : instances) {
            FeatureVector features = (FeatureVector) pipe.instanceFrom(data.extractIndependentFeatures(instance, inputAS)).getData();
            double[] row = new double[numFeatures];
            for (int i = 0; i < numFeatures; i++) {
                row[i] = features.value(i);
            }
            rows.add(row);
        }

        Map<String, Object> request = new HashMap<>();
        request.put("cmd", targetAlphabet == null ? "AR" : "AC");
        request.put("values", rows);
        request.put("n", Integer.valueOf(numFeatures));
        this.process.writeObject(request);

        Object response = this.process.readObject();
        if (!(response instanceof Map)) {
            throw new RuntimeException("Got a response from Wrapper process which cannot be used: " + response);
        }
        Map<?, ?> responseMap = (Map<?, ?>) response;
        Object status = responseMap.get("status");
        if (!"OK".equals(status)) {
            throw new RuntimeException("Status of response is not OK but " + status);
        }
        Object probasObj = responseMap.get("probas");
        if (targetAlphabet == null && probasObj != null) {
            throw new RuntimeException("We think we have regression but the " + this.WRAPPER_NAME + " process sent probabilities");
        }
        List<?> targets = requireList(responseMap.get("targets"), "targets", instances.size());
        List<?> probas = probasObj == null ? null : requireList(probasObj, "probas", instances.size());

        List<ModelApplication> result = new ArrayList<>(instances.size());
        int i = 0;
        for (Annotation instance : instances) {
            ModelApplication application;
            if (targetAlphabet == null) {
                application = new ModelApplication(instance, toNumber(targets.get(i), "target").doubleValue());
            } else {
                int classIndex = toNumber(targets.get(i), "target").intValue();
                String label = targetAlphabet.lookupObject(classIndex).toString();
                // Probabilities are optional. NOTE(review): confirm ModelApplication accepts a null list.
                List<Double> probs = probas == null ? null : toDoubleList(probas.get(i));
                Double confidence = (probs == null || probs.isEmpty())
                        ? Double.valueOf(Double.NaN)
                        : Collections.max(probs);
                application = new ModelApplication(instance, label, confidence, classLabels, probs);
            }
            result.add(application);
            i++;
        }
        return result;
    }

    /** Ensures {@code value} is a list with at least {@code minSize} elements. */
    private List<?> requireList(Object value, String what, int minSize) {
        if (!(value instanceof List)) {
            throw new RuntimeException("Response from " + this.WRAPPER_NAME + " process has no usable " + what + " list: " + value);
        }
        List<?> list = (List<?>) value;
        if (list.size() < minSize) {
            throw new RuntimeException("Expected " + minSize + " " + what + " from " + this.WRAPPER_NAME + " process but got " + list.size());
        }
        return list;
    }

    /** Converts a JSON-decoded element to a Number, accepting integer and floating values. */
    private Number toNumber(Object value, String what) {
        if (!(value instanceof Number)) {
            throw new RuntimeException("Expected a number for " + what + " from " + this.WRAPPER_NAME + " process but got: " + value);
        }
        return (Number) value;
    }

    /** Converts a JSON-decoded list of numbers to a list of Doubles. */
    private List<Double> toDoubleList(Object value) {
        if (!(value instanceof List)) {
            throw new RuntimeException("Expected a list of probabilities from " + this.WRAPPER_NAME + " process but got: " + value);
        }
        List<?> raw = (List<?>) value;
        List<Double> converted = new ArrayList<>(raw.size());
        for (Object element : raw) {
            converted.add(toNumber(element, "probability").doubleValue());
        }
        return converted;
    }

    /** No algorithm initialization is needed; the external wrapper handles the algorithm. */
    @Override // gate.plugin.learningframework.engines.Engine
    public void initializeAlgorithm(Algorithm algorithm, String str) {
    }

    /**
     * Determines the mode string and mode value passed to the wrapper scripts.
     *
     * @param crm the corpus representation to inspect
     * @return ("regr", 0) if there is no target alphabet. Otherwise, for an empty or missing
     *     instance list ("classind", -1). For a first instance with a
     *     {@link NominalTargetWithCosts} target ("classcosts", number of costs). Otherwise
     *     ("classind", size of the target alphabet).
     */
    protected AbstractMap.SimpleEntry<String, Integer> findOutMode(CorpusRepresentationMalletTarget crm) {
        Alphabet targetAlphabet = crm.getPipe().getTargetAlphabet();
        if (targetAlphabet == null) {
            return new AbstractMap.SimpleEntry<>("regr", Integer.valueOf(0));
        }
        InstanceList instances = crm.getRepresentationMallet();
        if (instances == null || instances.isEmpty()) {
            return new AbstractMap.SimpleEntry<>("classind", Integer.valueOf(-1));
        }
        // Only the first instance is inspected to decide between cost vectors and class indices.
        Object target = ((Instance) instances.get(0)).getTarget();
        if (target instanceof NominalTargetWithCosts) {
            return new AbstractMap.SimpleEntry<>("classcosts", Integer.valueOf(((NominalTargetWithCosts) target).getCosts().length));
        }
        return new AbstractMap.SimpleEntry<>("classind", Integer.valueOf(targetAlphabet.size()));
    }
}
