package de.datexis.retrieval.tagger;

import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.retrieval.tagger.LabeledSentenceIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:de/datexis/retrieval/tagger/LSTMSentenceTaggerIterator.class */
/**
 * Labeled-sentence iterator that optionally removes stop words from every sentence and
 * converts each batch into a {@link MultiDataSet} consisting of one feature array, one
 * label array, one feature mask and one label mask.
 *
 * <p>Features are produced by {@link #inputEncoder} over the tokens of each sentence,
 * labels by {@link #targetEncoder} over the label string of each sentence.
 */
public class LSTMSentenceTaggerIterator extends LabeledSentenceIterator {
    /** Encoder used on the sentence tokens to build the feature array. */
    protected IEncoder inputEncoder;
    /** Encoder used on the label strings to build the label array. */
    protected IEncoder targetEncoder;
    /**
     * Tokens whose lower-cased, trimmed text is contained in this set are dropped.
     * Never {@code null} after construction; an empty set disables filtering.
     */
    protected Set<String> stopWords;

    /**
     * Creates an iterator without stop-word filtering. All arguments other than the two
     * encoders are forwarded unchanged to the superclass constructor.
     *
     * @param iEncoder  encoder for the input tokens
     * @param iEncoder2 encoder for the target labels
     */
    public LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage stage, IEncoder iEncoder, IEncoder iEncoder2, Resource resource, String str, WordHelpers.Language language, boolean z, int i) {
        this(stage, iEncoder, iEncoder2, resource, str, language, Collections.emptySet(), z, i, -1, -1);
    }

    /**
     * Creates an iterator with stop-word filtering; the two trailing superclass arguments
     * are passed as {@code -1}.
     *
     * @param iEncoder  encoder for the input tokens
     * @param iEncoder2 encoder for the target labels
     * @param set       stop words to remove; {@code null} is treated as "no stop words"
     */
    public LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage stage, IEncoder iEncoder, IEncoder iEncoder2, Resource resource, String str, WordHelpers.Language language, Set<String> set, boolean z, int i) {
        this(stage, iEncoder, iEncoder2, resource, str, language, set, z, i, -1, -1);
    }

    /**
     * Creates an iterator with stop-word filtering. All arguments other than the encoders
     * and the stop-word set are forwarded unchanged to the superclass constructor.
     *
     * @param iEncoder  encoder for the input tokens
     * @param iEncoder2 encoder for the target labels
     * @param set       stop words to remove; {@code null} is treated as "no stop words"
     */
    public LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage stage, IEncoder iEncoder, IEncoder iEncoder2, Resource resource, String str, WordHelpers.Language language, Set<String> set, boolean z, int i, int i2, int i3) {
        super(stage, resource, str, language, z, i, i2, i3);
        this.inputEncoder = iEncoder;
        this.targetEncoder = iEncoder2;
        // A null set would make every later stopWords.isEmpty() call fail.
        this.stopWords = set != null ? set : Collections.<String>emptySet();
    }

    /**
     * Creates an iterator that is not backed by a resource. The three int arguments are
     * forwarded unchanged to the superclass constructor.
     *
     * @param iEncoder  encoder for the input tokens
     * @param iEncoder2 encoder for the target labels
     * @param set       stop words to remove; {@code null} is treated as "no stop words"
     */
    public LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage stage, IEncoder iEncoder, IEncoder iEncoder2, Set<String> set, int i, int i2, int i3) {
        super(stage, i, i2, i3);
        this.inputEncoder = iEncoder;
        this.targetEncoder = iEncoder2;
        // A null set would make every later stopWords.isEmpty() call fail.
        this.stopWords = set != null ? set : Collections.<String>emptySet();
    }

    /**
     * Removes stop words from every sentence of the batch and drops sentences (together
     * with their labels, if any) that have no tokens left. The batch's {@code sentences},
     * {@code labels}, {@code size} and {@code maxSentenceLength} fields are updated; the
     * reported maximum length is never smaller than 1.
     *
     * <p>The lists previously held by the batch are not modified; new lists are assigned.
     * If no stop words are configured the batch is returned untouched.
     *
     * @param labeledSentenceBatch batch to filter
     * @return the same batch instance
     */
    public LabeledSentenceIterator.LabeledSentenceBatch applyStopWordFilter(LabeledSentenceIterator.LabeledSentenceBatch labeledSentenceBatch) {
        if (this.stopWords.isEmpty()) {
            return labeledSentenceBatch;
        }
        List<String> originalLabels = labeledSentenceBatch.labels;
        List<Sentence> keptSentences = new ArrayList<Sentence>();
        // Labels are optional; keep them null if the batch had none.
        List<String> keptLabels = null;
        if (originalLabels != null) {
            keptLabels = new ArrayList<String>();
        }
        int longest = 1;
        int index = 0;
        for (Sentence original : labeledSentenceBatch.sentences) {
            Sentence filtered = applyStopWordFilter(original);
            int tokenCount = filtered.countTokens();
            if (tokenCount > 0) {
                keptSentences.add(filtered);
                if (keptLabels != null) {
                    keptLabels.add(originalLabels.get(index));
                }
                longest = Math.max(longest, tokenCount);
            }
            index++;
        }
        labeledSentenceBatch.sentences = keptSentences;
        labeledSentenceBatch.labels = keptLabels;
        labeledSentenceBatch.size = keptSentences.size();
        labeledSentenceBatch.maxSentenceLength = longest;
        return labeledSentenceBatch;
    }

    /**
     * Builds a new sentence from the tokens of {@code sentence} whose lower-cased, trimmed
     * text is not contained in {@link #stopWords}.
     *
     * @param sentence sentence to filter; it is not modified by this method
     * @return a new sentence holding only the remaining tokens (possibly none)
     */
    public Sentence applyStopWordFilter(Sentence sentence) {
        List<Token> remaining = sentence.streamTokens()
                .filter(token -> !this.stopWords.contains(token.getText().toLowerCase().trim()))
                .collect(Collectors.toList());
        return new Sentence(remaining);
    }

    /**
     * Fetches the next batch from the superclass and applies the stop-word filter to it.
     */
    @Override // de.datexis.retrieval.tagger.LabeledSentenceIterator
    public LabeledSentenceIterator.LabeledSentenceBatch nextSentenceBatch(int i) {
        return applyStopWordFilter(super.nextSentenceBatch(i));
    }

    /**
     * Fetches the next labeled sentence from the superclass and, if stop words are
     * configured, replaces the entry's sentence with its filtered version.
     */
    @Override // de.datexis.retrieval.tagger.LabeledSentenceIterator
    public Map.Entry<String, Sentence> nextLabeledSentence() {
        Map.Entry<String, Sentence> entry = super.nextLabeledSentence();
        if (!this.stopWords.isEmpty()) {
            // NOTE(review): relies on the superclass returning an entry that supports setValue.
            entry.setValue(applyStopWordFilter(entry.getValue()));
        }
        return entry;
    }

    /**
     * Encodes a batch into a {@link MultiDataSet} with one array each for features,
     * labels, feature mask and label mask.
     *
     * <p>In the TRAIN and TEST stages the labels are encoded with {@link #encodeTarget};
     * in every other stage an all-zero label array of the same shape is used.
     *
     * @param labeledSentenceBatch batch to encode
     * @return the encoded data set
     */
    @Override // de.datexis.retrieval.tagger.LabeledSentenceIterator
    public MultiDataSet generateDataSet(LabeledSentenceIterator.LabeledSentenceBatch labeledSentenceBatch) {
        INDArray features = EncodingHelpers.encodeTimeStepMatrix(labeledSentenceBatch.sentences, this.inputEncoder, labeledSentenceBatch.maxSentenceLength, Token.class);

        boolean encodeLabels = this.stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN)
                || this.stage.equals(AbstractMultiDataSetIterator.Stage.TEST);
        INDArray labels;
        if (encodeLabels) {
            labels = encodeTarget(labeledSentenceBatch.sentences, labeledSentenceBatch.labels);
        } else {
            labels = Nd4j.zeros(DataType.FLOAT, new long[]{labeledSentenceBatch.size, this.targetEncoder.getEmbeddingVectorSize()});
        }

        INDArray featuresMask = createMask(labeledSentenceBatch.sentences, labeledSentenceBatch.maxSentenceLength, Token.class);
        INDArray labelsMask = createLabelsMask(labeledSentenceBatch.sentences, Token.class);

        return new org.nd4j.linalg.dataset.MultiDataSet(
                new INDArray[]{features},
                new INDArray[]{labels},
                new INDArray[]{featuresMask},
                new INDArray[]{labelsMask});
    }

    /**
     * Creates a float mask with one row per sentence and {@code i} columns. For each
     * sentence the first {@code min(token count, i)} columns are set to 1; column 0 is set
     * even for a sentence without tokens. If {@code cls} is not {@code Token.class} the
     * token count is taken as 0, so only column 0 is set.
     *
     * @param list sentences of the batch
     * @param i    number of columns (time steps) of the mask
     * @param cls  span type to count; only {@code Token.class} is counted
     * @return the mask array
     */
    public static INDArray createMask(List<Sentence> list, int i, Class<? extends Span> cls) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), i});
        for (int row = 0; row < list.size(); row++) {
            int tokenCount = cls == Token.class ? list.get(row).countTokens() : 0;
            // Column 0 is always active — presumably so that no row is fully masked
            // (TODO confirm). Guarded so that a zero-width mask does not fail here.
            if (i > 0) {
                mask.putScalar(new int[]{row, 0}, 1.0d);
            }
            for (int step = 1; step < tokenCount && step < i; step++) {
                mask.putScalar(new int[]{row, step}, 1.0d);
            }
        }
        return mask;
    }

    /**
     * Creates a float mask of shape {@code [list.size(), 1]} that is 1 for every sentence
     * with at least one token and 0 otherwise. If {@code cls} is not {@code Token.class}
     * the token count is taken as 0, so the mask is all zeros.
     *
     * @param list sentences of the batch
     * @param cls  span type to count; only {@code Token.class} is counted
     * @return the label mask array
     */
    public static INDArray createLabelsMask(List<Sentence> list, Class<? extends Span> cls) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), 1});
        for (int row = 0; row < list.size(); row++) {
            int tokenCount = cls == Token.class ? list.get(row).countTokens() : 0;
            if (tokenCount > 0) {
                mask.putScalar(new int[]{row, 0}, 1.0d);
            }
        }
        return mask;
    }

    /**
     * Creates a float mask with one row per sentence and {@code i} columns in which only
     * column 0 is set to 1, and only for sentences that have at least one token.
     *
     * @param list sentences of the batch
     * @param i    number of columns (time steps) of the mask
     * @param cls  span type to count; only {@code Token.class} is counted
     * @return the mask array
     */
    public INDArray createBwdMask(List<Sentence> list, int i, Class<? extends Span> cls) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), i});
        for (int row = 0; row < list.size(); row++) {
            int steps = cls == Token.class ? Math.min(list.get(row).countTokens(), i) : 0;
            if (steps > 0 && i > 0) {
                mask.putScalar(new int[]{row, 0}, 1.0d);
            }
        }
        return mask;
    }

    /**
     * Creates a float mask with one row per sentence and {@code i} columns in which only
     * the column of the last used time step, {@code min(token count, i) - 1}, is set to 1.
     * Rows of sentences without tokens stay all zero.
     *
     * @param list sentences of the batch
     * @param i    number of columns (time steps) of the mask
     * @param cls  span type to count; only {@code Token.class} is counted
     * @return the mask array
     */
    public INDArray createFwdMask(List<Sentence> list, int i, Class<? extends Span> cls) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), i});
        for (int row = 0; row < list.size(); row++) {
            int steps = cls == Token.class ? Math.min(list.get(row).countTokens(), i) : 0;
            if (steps > 0 && i > 0) {
                mask.putScalar(new int[]{row, steps - 1}, 1.0d);
            }
        }
        return mask;
    }

    /**
     * Encodes the labels as a 3-dimensional float array of shape
     * {@code [list2.size(), targetEncoder.getEmbeddingVectorSize(), i]}. For each example
     * the encoded label vector is written to every time step that has a token, up to
     * {@code i} time steps; the remaining time steps stay zero.
     *
     * @param list  sentences, index-aligned with {@code list2}
     * @param list2 label strings, one per sentence
     * @param i     number of time steps of the result
     * @param cls   not used by this method
     * @return the encoded target array
     */
    public INDArray encodeRNNTarget(List<Sentence> list, List<String> list2, int i, Class<? extends Span> cls) {
        INDArray target = Nd4j.zeros(DataType.FLOAT, new long[]{list2.size(), this.targetEncoder.getEmbeddingVectorSize(), i});
        for (int example = 0; example < list2.size(); example++) {
            Sentence sentence = list.get(example);
            // Encode once per example; assign() copies the values into the target view.
            INDArray encodedLabel = this.targetEncoder.encode(list2.get(example));
            int step = 0;
            for (Token ignored : sentence.getTokens()) {
                if (step >= i) {
                    break;
                }
                target.get(new INDArrayIndex[]{NDArrayIndex.point(example), NDArrayIndex.all(), NDArrayIndex.point(step)}).assign(encodedLabel);
                step++;
            }
        }
        return target;
    }

    /**
     * Encodes the labels as a float array of shape
     * {@code [list2.size(), targetEncoder.getEmbeddingVectorSize()]}, one row per label.
     *
     * @param list  not used by this method
     * @param list2 label strings to encode
     * @return the encoded target array
     */
    public INDArray encodeTarget(List<Sentence> list, List<String> list2) {
        INDArray target = Nd4j.zeros(DataType.FLOAT, new long[]{list2.size(), this.targetEncoder.getEmbeddingVectorSize()});
        for (int row = 0; row < list2.size(); row++) {
            // assign() copies the values, so no defensive copy of the encoding is needed.
            target.slice(row).assign(this.targetEncoder.encode(list2.get(row)));
        }
        return target;
    }
}
