package de.datexis.retrieval.tagger;

import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.model.Sentence;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/retrieval/tagger/LabeledSentenceIterator.class */
public abstract class LabeledSentenceIterator extends AbstractMultiDataSetIterator {
    protected final Logger log;
    protected DocumentFactory df;
    protected Resource source;
    protected LineIterator iterator;
    protected WordHelpers.Language lang;
    protected String encoding;
    protected boolean tokenized;
    protected static Pattern TAB_SEPARATOR = Pattern.compile("^(.*)\t(.*)$");

    /* loaded from: input_file:de/datexis/retrieval/tagger/LabeledSentenceIterator$LabeledSentenceBatch.class */
    public static class LabeledSentenceBatch {
        public List<Sentence> sentences;
        public List<String> labels;
        public MultiDataSet dataset;
        public int size;
        public int maxSentenceLength;

        public LabeledSentenceBatch(int i, List<Sentence> list, List<String> list2, int i2, MultiDataSet multiDataSet) {
            this.size = i;
            this.sentences = list;
            this.labels = list2;
            this.dataset = multiDataSet;
            this.maxSentenceLength = i2;
        }

        public LabeledSentenceBatch(List<Sentence> list) {
            this.sentences = list;
            this.size = list.size();
            int i = 1;
            Iterator<Sentence> it = list.iterator();
            while (it.hasNext()) {
                i = Math.max(i, it.next().countTokens());
            }
            this.maxSentenceLength = i;
        }
    }

    public LabeledSentenceIterator(AbstractMultiDataSetIterator.Stage stage, int i, int i2, int i3) {
        super(stage, i2, i3, i, false);
        this.log = LoggerFactory.getLogger(getClass());
    }

    public LabeledSentenceIterator(AbstractMultiDataSetIterator.Stage stage, Resource resource, String str, WordHelpers.Language language, boolean z, int i, int i2, int i3) {
        super(stage, i2, i3, i, false);
        this.log = LoggerFactory.getLogger(getClass());
        this.df = DocumentFactory.getInstance();
        this.lang = language;
        this.encoding = str;
        this.tokenized = z;
        this.source = resource;
        if (i2 < 0) {
            try {
                this.numExamples = Files.lines(resource.getPath()).count();
            } catch (IOException e) {
                this.log.error(e.getMessage());
            }
        }
        reset();
    }

    public void reset() {
        try {
            if (this.iterator != null) {
                this.iterator.close();
            }
            this.iterator = IOUtils.lineIterator(this.source.getInputStream(), this.encoding);
        } catch (IOException e) {
            this.log.error(e.getMessage());
        }
        super.reset();
    }

    protected boolean hasNextSentence() {
        return this.iterator != null && this.iterator.hasNext();
    }

    public boolean hasNext() {
        return hasNextSentence() && !reachedEnd();
    }

    public List<String> getLabels() {
        reset();
        LinkedList linkedList = new LinkedList();
        while (hasNext()) {
            linkedList.add(nextLabeledSentence().getKey());
        }
        reset();
        return linkedList;
    }

    public Map.Entry<String, Sentence> nextLabeledSentence() {
        this.cursor++;
        String next = this.iterator.next();
        Matcher matcher = TAB_SEPARATOR.matcher(next);
        if (!matcher.matches()) {
            this.log.warn("Could not read line '{}'", next);
            return new AbstractMap.SimpleEntry(new String(), new Sentence());
        }
        String group = matcher.group(1);
        String group2 = matcher.group(2);
        return new AbstractMap.SimpleEntry(group, this.tokenized ? DocumentFactory.createSentenceFromTokenizedString(group2) : DocumentFactory.createSentenceFromString(group2, this.lang.toString()));
    }

    public LabeledSentenceBatch nextSentenceBatch(int i) {
        ArrayList arrayList = new ArrayList(i);
        ArrayList arrayList2 = new ArrayList(i);
        int i2 = 1;
        for (int i3 = 0; i3 < i; i3++) {
            Map.Entry<String, Sentence> nextLabeledSentence = hasNext() ? nextLabeledSentence() : new AbstractMap.SimpleEntry<>(new String(), new Sentence());
            arrayList2.add(nextLabeledSentence.getKey());
            arrayList.add(nextLabeledSentence.getValue());
            i2 = Math.max(i2, nextLabeledSentence.getValue().countTokens());
        }
        return new LabeledSentenceBatch(i, arrayList, arrayList2, i2, null);
    }

    public LabeledSentenceBatch nextSentenceBatch() {
        return nextSentenceBatch(this.batchSize);
    }

    public MultiDataSet next(int i) {
        LabeledSentenceBatch nextSentenceBatch = nextSentenceBatch(i);
        reportProgress(nextSentenceBatch.maxSentenceLength);
        return generateDataSet(nextSentenceBatch);
    }

    public abstract MultiDataSet generateDataSet(LabeledSentenceBatch labeledSentenceBatch);
}
