package de.datexis.retrieval.tagger;

import com.google.common.collect.Lists;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Token;
import de.datexis.retrieval.tagger.LabeledSentenceIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import de.datexis.tagger.Tagger;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tagger that wraps a DL4J {@link ComputationGraph} and reads sentence representations from the
 * graph's {@code "embedding"} vertex.
 *
 * <p>Token features are produced with the input encoder. Training data is streamed from a resource
 * through a {@link LSTMSentenceTaggerIterator} built from the input and target encoders.
 *
 * <p>NOTE(review): this source looks like decompiler output (see {@link #m4getNN()}); member and
 * parameter names may differ from the original source.
 */
public class LSTMSentenceTagger extends Tagger {

    protected static final Logger log = LoggerFactory.getLogger(LSTMSentenceTagger.class);

    /** Key of the graph vertex whose activations are returned as the embedding. */
    private static final String EMBEDDING_LAYER = "embedding";

    /** Encoder used to build the per-token input features. */
    protected IEncoder inputEncoder;

    /** Encoder handed to the training iterator for the targets. */
    protected IEncoder targetEncoder;

    /** Converts feed-forward activations back into recurrent layout in {@link #encodeMatrix}. */
    protected final FeedForwardToRnnPreProcessor ff2rnn = new FeedForwardToRnnPreProcessor();

    /** Creates a tagger with id {@code "EMB"}. */
    public LSTMSentenceTagger() {
        super("EMB");
    }

    /**
     * Creates a tagger, passing {@code id} to the {@link Tagger} constructor.
     *
     * @param id presumably the tagger id (the no-arg constructor passes {@code "EMB"} here)
     */
    public LSTMSentenceTagger(String id) {
        super(id);
    }

    /**
     * Creates a tagger from the given resource and then sets its id to {@code "EMB"}.
     *
     * @param resource resource handed to the {@link Tagger} constructor
     */
    public LSTMSentenceTagger(Resource resource) {
        super(resource);
        setId("EMB");
    }

    /**
     * Returns the underlying computation graph.
     *
     * <p>NOTE(review): the decompiler renamed this from {@code getNN} (merged with a bridge method).
     * The name is kept so that callers referencing it keep compiling.
     *
     * @return the current network, may be {@code null} before initialization or loading
     */
    @JsonIgnore
    public ComputationGraph m4getNN() {
        return this.net;
    }

    /**
     * Replaces the network used by this tagger.
     *
     * @param computationGraph the graph to use
     */
    public void initializeNetwork(ComputationGraph computationGraph) {
        this.net = computationGraph;
    }

    /**
     * Sets the encoder used for input token features.
     *
     * @param iEncoder input encoder
     */
    public void setInputEncoders(IEncoder iEncoder) {
        this.inputEncoder = iEncoder;
    }

    /**
     * Sets the encoder used for training targets.
     *
     * @param iEncoder target encoder
     */
    public void setTargetEncoder(IEncoder iEncoder) {
        this.targetEncoder = iEncoder;
    }

    /** @return the input encoder */
    @JsonIgnore
    public IEncoder getInputEncoder() {
        return this.inputEncoder;
    }

    /** @return the target encoder */
    @JsonIgnore
    public IEncoder getTargetEncoder() {
        return this.targetEncoder;
    }

    /**
     * Returns the encoders as {@code [input, target]}.
     *
     * @return a new mutable list containing both encoders
     * @throws ClassCastException if either encoder is not an {@link Encoder}
     */
    public List<Encoder> getEncoders() {
        return Lists.newArrayList((Encoder) this.inputEncoder, (Encoder) this.targetEncoder);
    }

    /**
     * Sets the encoders from a list ordered {@code [input, target]}, the same order that
     * {@link #getEncoders()} returns.
     *
     * @param list exactly two encoders
     * @throws IllegalArgumentException if {@code list} does not contain exactly two encoders
     */
    public void setEncoders(List<Encoder> list) {
        if (list.size() != 2) {
            // Previously reported "expected=3" although exactly two encoders are required.
            throw new IllegalArgumentException(
                    "wrong number of encoders given (expected=2, actual=" + list.size() + ")");
        }
        this.inputEncoder = list.get(0);
        this.targetEncoder = list.get(1);
    }

    /**
     * Trains on the given resource for the configured number of epochs ({@code numEpochs}).
     *
     * @param resource training data (read as UTF-8, English)
     */
    public void trainModel(Resource resource) {
        trainModel(resource, this.numEpochs);
    }

    /**
     * Not supported; training must be done from a resource file.
     *
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Dataset dataset) {
        throw new UnsupportedOperationException("training from Dataset not implemented, please use a TSV file");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void tag(Collection<Document> collection) {
        throw new UnsupportedOperationException("not implemented yet");
    }

    /**
     * Trains the network on {@code resource} for {@code epochs} epochs, then marks the model as
     * available.
     *
     * <p>Epoch start and end listeners are triggered around every {@code fit} call. The iterator is
     * reset between epochs, and ND4J garbage collection is invoked after each epoch.
     *
     * <p>NOTE(review): periodic ND4J GC is disabled here and never re-enabled by this method.
     *
     * @param resource training data (read as UTF-8, English)
     * @param epochs   number of epochs to run
     */
    protected void trainModel(Resource resource, int epochs) {
        LSTMSentenceTaggerIterator iterator = new LSTMSentenceTaggerIterator(
                AbstractMultiDataSetIterator.Stage.TRAIN, this.inputEncoder, this.targetEncoder,
                resource, "utf-8", WordHelpers.Language.EN, true, this.batchSize);
        this.timer.start();
        appendTrainLog("Training " + getName() + " for " + epochs + " epochs.");
        Nd4j.getMemoryManager().togglePeriodicGc(false);
        for (int epoch = 1; epoch <= epochs; epoch++) {
            appendTrainLog("Starting epoch " + epoch + " of " + epochs + "\t");
            // Listeners receive a zero-based epoch index.
            triggerEpochListeners(true, epoch - 1);
            m4getNN().fit(iterator);
            this.timer.setSplit("epoch");
            appendTrainLog("Completed epoch " + epoch + " of " + epochs + "\t", this.timer.getLong("epoch"));
            triggerEpochListeners(false, epoch - 1);
            if (epoch < epochs) {
                iterator.reset();
            }
            Nd4j.getMemoryManager().invokeGc();
        }
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Placeholder evaluation: it only starts and stops the timer and logs completion. The dataset
     * is not evaluated.
     *
     * @param dataset ignored
     */
    public void testModel(Dataset dataset) {
        this.timer.start();
        this.timer.stop();
        appendTestLog("Testing complete", this.timer.getLong());
    }

    /**
     * Encodes a single sentence and returns the transposed activations of the embedding vertex.
     *
     * @param sentence the sentence to encode
     * @return transposed embedding activations
     * @throws IllegalStateException if the graph produced no {@code "embedding"} activations
     */
    public INDArray encodeSentence(Sentence sentence) {
        List<Sentence> single = Collections.singletonList(sentence);
        int length = sentence.getLength();
        INDArray featuresMask = LSTMSentenceTaggerIterator.createMask(single, length, Token.class);
        INDArray labelsMask = Nd4j.ones(DataType.FLOAT, 1, 1);
        INDArray features = EncodingHelpers.encodeTimeStepMatrix(single, this.inputEncoder, length, Token.class);
        return feedForwardEmbedding(features, featuresMask, labelsMask).transpose();
    }

    /**
     * Encodes a labeled sentence batch and returns the embedding vertex activations.
     *
     * @param labeledSentenceBatch batch providing the sentences and their maximum length
     * @return embedding activations
     * @throws IllegalStateException if the graph produced no {@code "embedding"} activations
     */
    public INDArray encodeBatch(LabeledSentenceIterator.LabeledSentenceBatch labeledSentenceBatch) {
        List<Sentence> sentences = labeledSentenceBatch.sentences;
        int maxLength = labeledSentenceBatch.maxSentenceLength;
        INDArray featuresMask = LSTMSentenceTaggerIterator.createMask(sentences, maxLength, Token.class);
        INDArray labelsMask = LSTMSentenceTaggerIterator.createLabelsMask(sentences, Token.class);
        INDArray features = EncodingHelpers.encodeTimeStepMatrix(sentences, this.inputEncoder, maxLength, Token.class);
        return feedForwardEmbedding(features, featuresMask, labelsMask);
    }

    /**
     * Encodes a list of sentences and returns the embedding vertex activations. The time dimension
     * is the largest token count in {@code list}, with a minimum of 1.
     *
     * @param list sentences to encode
     * @return embedding activations
     * @throws IllegalStateException if the graph produced no {@code "embedding"} activations
     */
    public INDArray encodeBatchMatrix(List<Sentence> list) {
        int maxLength = 1;
        for (Sentence sentence : list) {
            maxLength = Math.max(maxLength, sentence.countTokens());
        }
        INDArray featuresMask = LSTMSentenceTaggerIterator.createMask(list, maxLength, Token.class);
        INDArray labelsMask = LSTMSentenceTaggerIterator.createLabelsMask(list, Token.class);
        INDArray features = EncodingHelpers.encodeTimeStepMatrix(list, this.inputEncoder, maxLength, Token.class);
        return feedForwardEmbedding(features, featuresMask, labelsMask);
    }

    /**
     * Runs the document batch through the network and returns all vertex activations.
     *
     * <p>If an {@code "embedding"} entry exists, it is replaced by its recurrent-layout version,
     * using {@code documentBatch.size} as the minibatch size.
     *
     * @param documentBatch batch whose {@code dataset} supplies features and masks
     * @return map from vertex name to activations
     */
    public Map<String, INDArray> encodeMatrix(DocumentSentenceIterator.DocumentBatch documentBatch) {
        Map<String, INDArray> activations = feedForward(m4getNN(), documentBatch.dataset);
        if (activations.containsKey(EMBEDDING_LAYER)) {
            activations.put(EMBEDDING_LAYER, this.ff2rnn.preProcess(
                    activations.get(EMBEDDING_LAYER), documentBatch.size, LayerWorkspaceMgr.noWorkspaces()));
        }
        return activations;
    }

    /**
     * Applies the feature and label masks of {@code multiDataSet} to the graph, then runs a
     * non-training feed-forward pass with input clearing enabled.
     *
     * @param computationGraph the graph to run
     * @param multiDataSet     source of features and mask arrays
     * @return map from vertex name to activations
     */
    public static Map<String, INDArray> feedForward(ComputationGraph computationGraph, MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        computationGraph.setLayerMaskArrays(multiDataSet.getFeaturesMaskArrays(), multiDataSet.getLabelsMaskArrays());
        return computationGraph.feedForward(features, false, true);
    }

    /**
     * Sets the graph configuration's epoch count and notifies all registered training listeners.
     *
     * @param epochStart {@code true} to call {@code onEpochStart}, {@code false} for {@code onEpochEnd}
     * @param epoch      epoch index stored in the configuration
     */
    protected void triggerEpochListeners(boolean epochStart, int epoch) {
        ComputationGraph graph = m4getNN();
        Collection<TrainingListener> listeners = graph.getListeners();
        graph.getConfiguration().setEpochCount(epoch);
        if (listeners == null || listeners.isEmpty()) {
            return;
        }
        for (TrainingListener listener : listeners) {
            if (epochStart) {
                listener.onEpochStart(graph);
            } else {
                listener.onEpochEnd(graph);
            }
        }
    }

    /**
     * Not supported in this module.
     *
     * @throws UnsupportedOperationException always
     */
    public void enableTrainingUI() {
        throw new UnsupportedOperationException("Training UI is not part of texoo-retrieval. Please use deeplearning4j-ui_2.11 in your code for that.");
    }

    /**
     * Writes the network, without updater state, to {@code <str>.zip} inside {@code resource}.
     *
     * <p>I/O failures are logged, not thrown. The written file is registered as this tagger's model
     * only after the stream has been written and closed successfully.
     *
     * @param resource target directory resource
     * @param str      file name without the {@code .zip} extension
     */
    public void saveModel(Resource resource, String str) {
        Resource target = resource.resolve(str + ".zip");
        try (OutputStream out = target.getOutputStream()) {
            ModelSerializer.writeModel(this.net, out, false);
        } catch (IOException e) {
            log.error("Could not save model to {}", target.getFileName(), e);
            return;
        }
        setModel(target);
    }

    /**
     * Restores the computation graph, without updater state, from {@code resource} and marks the
     * model as available.
     *
     * <p>I/O failures are logged, not thrown. The input stream is always closed.
     *
     * @param resource model file to load
     */
    public void loadModel(Resource resource) {
        try (InputStream in = resource.getInputStream()) {
            this.net = ModelSerializer.restoreComputationGraph(in, false);
            setModel(resource);
            setModelAvailable(true);
            log.info("loaded Computation Graph from {}", resource.getFileName());
        } catch (IOException e) {
            log.error("Could not load model from {}", resource.getFileName(), e);
        }
    }

    /** @return always {@code null}; the graph configuration is not exposed by this tagger */
    public ComputationGraphConfiguration getGraphConfiguration() {
        return null;
    }

    /** No-op; the graph configuration is not deserialized by this tagger. */
    public void setGraphConfiguration(JsonNode jsonNode) {
    }

    /**
     * Applies the given masks, runs a non-training feed-forward pass, and returns the
     * {@code "embedding"} activations.
     */
    private INDArray feedForwardEmbedding(INDArray features, INDArray featuresMask, INDArray labelsMask) {
        ComputationGraph graph = m4getNN();
        graph.setLayerMaskArrays(new INDArray[] {featuresMask}, new INDArray[] {labelsMask});
        Map<String, INDArray> activations = graph.feedForward(new INDArray[] {features}, false, true);
        if (!activations.containsKey(EMBEDDING_LAYER)) {
            throw new IllegalStateException("Embedding does not have an embeddding layer");
        }
        return activations.get(EMBEDDING_LAYER);
    }
}
