package de.julielab.jcore.ae.fte;

import com.google.gson.Gson;
import de.julielab.ipc.javabridge.Options;
import de.julielab.ipc.javabridge.ResultDecoders;
import de.julielab.ipc.javabridge.StdioBridge;
import de.julielab.jcore.types.EmbeddingVector;
import de.julielab.jcore.types.Sentence;
import de.julielab.jcore.types.Token;
import de.julielab.jcore.utility.JCoReTools;
import de.julielab.jcore.utility.index.Comparators;
import de.julielab.jcore.utility.index.JCoReSetAnnotationIndex;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.FSIterator;
import org.apache.uima.cas.Type;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ResourceMetaData;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ResourceMetaData(name = "JCoRe Flair Token Embedding Annotator", description = "Adds the Flair compatible embedding vectors to the token annotations.")
@TypeCapability(inputs = {"de.julielab.jcore.types.Sentence", "de.julielab.jcore.types.Token"}, outputs = {"de.julielab.jcore.types.EmbeddingVector"})
/* loaded from: input_file:de/julielab/jcore/ae/fte/FlairTokenEmbeddingAnnotator.class */
public class FlairTokenEmbeddingAnnotator extends JCasAnnotator_ImplBase {
    public static final String PARAM_EMBEDDING_PATH = "EmbeddingPath";
    public static final String PARAM_COMPUTATION_FILTER = "ComputationFilter";
    public static final String PARAM_EMBEDDING_SOURCE = "EmbeddingSource";
    public static final String PARAM_PYTHON_EXECUTABLE = "PythonExecutable";
    private static final Logger log = LoggerFactory.getLogger(FlairTokenEmbeddingAnnotator.class);
    private static final int TIME_OUTPUT_INTERVAL = 1000;

    @ConfigurationParameter(name = PARAM_EMBEDDING_PATH, description = "Path to a Flair compatible embedding file. Since flair supports a range of different embeddings, a type prefix is required. The syntax is 'prefix:<path or built-in flair embedding name>. The possible prefixes are 'word', 'char', 'bytepair', 'flair', 'bert', 'elmo'.")
    private String embeddingPath;

    @ConfigurationParameter(name = PARAM_COMPUTATION_FILTER, mandatory = false, description = "This parameter may be set to a fully qualified annotation type. If given, only for documents containing at least one annotation of this type embeddings will be retrieved from the computing flair python script. However, for contextualized embeddings, all embedding vectors are computed anyway and the the I/O cost is minor in comparison to the embedding computation. Thus, setting this parameter will most probably only result in small time savings.")
    private String computationFilter;

    @ConfigurationParameter(name = PARAM_EMBEDDING_SOURCE, mandatory = false, description = "The value of this parameter will be set to the source feature of the EmbeddingVector annotation instance created on the tokens. If left blank, the value of the EmbeddingPath will be used.")
    private String embeddingSource;

    @ConfigurationParameter(name = PARAM_PYTHON_EXECUTABLE, mandatory = false, description = "The path to the python executable. Required is a python verion >=3.6.")
    private String pythonExecutable;
    private StdioBridge<byte[]> flairBridge;
    private Gson gson;
    private long embeddingRequestTime;
    private long embeddingRequestTimeForLastInterval;
    private int docsProcessed;

    public void initialize(UimaContext uimaContext) throws ResourceInitializationException {
        this.embeddingPath = (String) uimaContext.getConfigParameterValue(PARAM_EMBEDDING_PATH);
        this.computationFilter = (String) uimaContext.getConfigParameterValue(PARAM_COMPUTATION_FILTER);
        this.embeddingSource = (String) Optional.ofNullable((String) uimaContext.getConfigParameterValue(PARAM_EMBEDDING_SOURCE)).orElse(this.embeddingPath);
        Optional ofNullable = Optional.ofNullable((String) uimaContext.getConfigParameterValue(PARAM_PYTHON_EXECUTABLE));
        if (ofNullable.isPresent()) {
            this.pythonExecutable = (String) ofNullable.get();
            log.info("Python executable: {} (from descriptor)", this.pythonExecutable);
        } else {
            log.debug("No python executable given in the component descriptor, trying to read PYTHON environment variable.");
            String str = System.getenv("PYTHON");
            if (str != null) {
                this.pythonExecutable = str;
                log.info("Python executable: {} (from environment variable PYTHON).", this.pythonExecutable);
            }
        }
        if (this.pythonExecutable == null) {
            this.pythonExecutable = "python3.6";
            log.info("Python executable: {} (default)", this.pythonExecutable);
        }
        try {
            Options options = new Options(byte[].class);
            options.setExecutable(this.pythonExecutable);
            options.setExternalProgramTerminationSignal("exit");
            options.setExternalProgramReadySignal("Script is ready");
            options.setTerminationSignalFromErrorStream("SyntaxError");
            this.flairBridge = new StdioBridge<>(options, new String[]{"-u", "-c", IOUtils.toString(getClass().getResourceAsStream("/de/julielab/jcore/ae/fte/python/getEmbeddingScript.py"), StandardCharsets.UTF_8), this.embeddingPath});
            this.flairBridge.start();
            this.gson = new Gson();
            this.docsProcessed = 0;
            this.embeddingRequestTime = 0L;
            this.embeddingRequestTimeForLastInterval = 0L;
        } catch (IOException e) {
            log.error("Could not create the IO bridge object.", e);
            throw new ResourceInitializationException(e);
        }
    }

    public void process(JCas jCas) throws AnalysisEngineProcessException {
        ArrayList arrayList = new ArrayList();
        JCoReSetAnnotationIndex<Annotation> jCoReSetAnnotationIndex = null;
        if (!StringUtils.isBlank(this.computationFilter)) {
            Type type = jCas.getTypeSystem().getType(this.computationFilter);
            if (type == null) {
                throw new AnalysisEngineProcessException(new IllegalArgumentException("The type " + this.computationFilter + " was not found in the type system."));
            }
            if (!jCas.getAnnotationIndex(type).iterator().hasNext()) {
                return;
            } else {
                jCoReSetAnnotationIndex = new JCoReSetAnnotationIndex<>(Comparators.overlapComparator(), jCas, type);
            }
        }
        String constructEmbeddingRequest = constructEmbeddingRequest(jCas, arrayList, jCoReSetAnnotationIndex);
        try {
            long currentTimeMillis = System.currentTimeMillis();
            Optional<double[][]> findAny = this.flairBridge.sendAndReceive(constructEmbeddingRequest).map(ResultDecoders.decodeVectors).findAny();
            long currentTimeMillis2 = System.currentTimeMillis() - currentTimeMillis;
            log.trace("Sending and receiving token embeddings took {} ms", Long.valueOf(currentTimeMillis2));
            this.embeddingRequestTime += currentTimeMillis2;
            this.embeddingRequestTimeForLastInterval += currentTimeMillis2;
            writeEmbeddingsToCas(jCas, arrayList, findAny);
            this.docsProcessed++;
            if (this.docsProcessed % TIME_OUTPUT_INTERVAL == 0) {
                if (log.isDebugEnabled()) {
                    log.debug("Embedding computation for the last {} documents took {}ms (avg: {}ms). Total time for all {} processed documents until here: {}ms ({}s)", new Object[]{Integer.valueOf(TIME_OUTPUT_INTERVAL), Long.valueOf(this.embeddingRequestTimeForLastInterval), Long.valueOf(this.embeddingRequestTimeForLastInterval / 1000), Integer.valueOf(this.docsProcessed), Long.valueOf(this.embeddingRequestTime), Long.valueOf(this.embeddingRequestTime / 60)});
                }
                this.embeddingRequestTimeForLastInterval = 0L;
            }
        } catch (InterruptedException e) {
            log.error("Computation of embedding vectors was interrupted", e);
            throw new AnalysisEngineProcessException(e);
        }
    }

    private void writeEmbeddingsToCas(JCas jCas, List<Token> list, Optional<double[][]> optional) {
        if (optional.isPresent()) {
            double[][] dArr = optional.get();
            for (int i = 0; i < list.size(); i++) {
                Token token = list.get(i);
                double[] dArr2 = dArr[i];
                DoubleArray doubleArray = new DoubleArray(jCas, dArr2.length);
                doubleArray.copyFromArray(dArr2, 0, 0, dArr2.length);
                EmbeddingVector embeddingVector = new EmbeddingVector(jCas, token.getBegin(), token.getEnd());
                embeddingVector.setSource(this.embeddingSource);
                embeddingVector.setVector(doubleArray);
                token.setEmbeddingVectors(JCoReTools.addToFSArray(token.getEmbeddingVectors(), embeddingVector));
            }
        }
    }

    private String constructEmbeddingRequest(JCas jCas, List<Token> list, JCoReSetAnnotationIndex<Annotation> jCoReSetAnnotationIndex) {
        Map indexCovered = JCasUtil.indexCovered(jCas, Sentence.class, Token.class);
        ArrayList arrayList = new ArrayList();
        FSIterator it = jCas.getAnnotationIndex(Sentence.type).iterator();
        while (it.hasNext()) {
            Annotation annotation = (Annotation) it.next();
            List arrayList2 = jCoReSetAnnotationIndex != null ? new ArrayList() : Collections.emptyList();
            int i = 0;
            StringBuilder sb = new StringBuilder();
            for (Token token : (Collection) indexCovered.get(annotation)) {
                sb.append(token.getCoveredText()).append(" ");
                if (jCoReSetAnnotationIndex == null) {
                    list.add(token);
                } else if (!jCoReSetAnnotationIndex.searchSubset(token).isEmpty()) {
                    arrayList2.add(Integer.valueOf(i));
                    list.add(token);
                }
                i++;
            }
            sb.deleteCharAt(sb.length() - 1);
            HashMap hashMap = new HashMap();
            hashMap.put("sentence", sb.toString());
            hashMap.put("tokenIndicesToReturn", arrayList2);
            arrayList.add(hashMap);
        }
        return this.gson.toJson(arrayList);
    }

    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        if (log.isDebugEnabled()) {
            log.debug("The total time for embedding computation, including I/O, was {}ms ({}s)", Long.valueOf(this.embeddingRequestTime), Long.valueOf(this.embeddingRequestTime / 1000));
        }
        try {
            this.flairBridge.stop();
        } catch (IOException | InterruptedException e) {
            log.error("Exception when trying shut down IO bridge to the python embedding computation script", e);
            throw new AnalysisEngineProcessException(e);
        }
    }
}
