package de.datexis.retrieval.encoder;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.IEncoder;
import de.datexis.retrieval.tagger.LSTMSentenceTagger;
import de.datexis.retrieval.tagger.ModelBuilder;
import de.datexis.tagger.Tagger;
import java.util.Collection;
import java.util.Locale;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/retrieval/encoder/LSTMSentenceAnnotator.class */
/**
 * Annotator that wraps an {@link LSTMSentenceTagger}.
 *
 * <p>Instances are normally assembled through {@link Builder}, which wires the input and target
 * encoders into the tagger, initializes the network and records the chosen hyper-parameters in
 * the tagger's training log. The trained tagger can be exposed as an
 * {@link LSTMSentenceEncoder} via {@link #asEncoder()}.
 */
public class LSTMSentenceAnnotator extends Annotator {
    /** Per-instance logger named after the runtime class (so subclasses log under their own name). */
    protected final Logger log = LoggerFactory.getLogger(getClass());

    /* loaded from: input_file:de/datexis/retrieval/encoder/LSTMSentenceAnnotator$Builder.class */
    /**
     * Fluent builder for {@link LSTMSentenceAnnotator}.
     *
     * <p>A builder owns exactly one tagger and one annotator, both created when the builder is
     * constructed; {@link #build()} configures and returns that same annotator instance.
     * {@link #withInputEncoders(String, Encoder)} and {@link #withTargetEncoder(Encoder)} must be
     * called before {@link #build()}.
     */
    public static class Builder {
        IEncoder inputEncoder;
        IEncoder targetEncoder;
        private Collection<String> stopWords;
        protected ILossFunction lossFunc = LossFunctions.LossFunction.MCXENT.getILossFunction();
        protected Activation activation = Activation.SOFTMAX;
        // NOTE(review): -1 presumably means "no limit" for the two values below; confirm
        // against LSTMSentenceTagger.setTrainingParams.
        private int examplesPerEpoch = -1;
        private int maxTimeSeriesLength = -1;
        private int lstmLayerSize = 256;
        private int embeddingLayerSize = 128;
        private double learningRate = 0.01d;
        private double dropOut = 0.5d;
        // No builder method changes this value; it is always passed to the model as 1.
        private int iterations = 1;
        private int batchSize = 16;
        private int numEpochs = 1;
        private boolean trainingUiEnabled = false;
        // Declaration order matters: the annotator is constructed around the tagger above it.
        LSTMSentenceTagger tagger = new LSTMSentenceTagger();
        LSTMSentenceAnnotator ann = new LSTMSentenceAnnotator(this.tagger);

        /**
         * Sets the id of the underlying tagger. {@link #build()} also uses this id as the
         * provenance task.
         *
         * @param str tagger id
         * @return this builder
         */
        public Builder withId(String str) {
            tagger.setId(str);
            return this;
        }

        /**
         * Records the training dataset name and language in the annotator's provenance.
         *
         * @param str dataset name
         * @param language dataset language; stored as its lower-cased {@code toString()} value
         * @return this builder
         */
        public Builder withDataset(String str, WordHelpers.Language language) {
            ann.getProvenance().setDataset(str);
            // Locale.ROOT keeps the result independent of the JVM default locale
            // (e.g. the Turkish locale lower-cases 'I' to a dotless 'ı').
            ann.getProvenance().setLanguage(language.toString().toLowerCase(Locale.ROOT));
            return this;
        }

        /**
         * Selects the loss function (by enum constant) and output activation.
         *
         * @param lossFunction loss function constant; converted via {@code getILossFunction()}
         * @param activation output activation
         * @return this builder
         */
        public Builder withLossFunction(LossFunctions.LossFunction lossFunction, Activation activation) {
            return withLossFunction(lossFunction.getILossFunction(), activation);
        }

        /**
         * Selects the loss function implementation and output activation.
         *
         * @param iLossFunction loss function implementation
         * @param activation output activation
         * @return this builder
         */
        public Builder withLossFunction(ILossFunction iLossFunction, Activation activation) {
            this.lossFunc = iLossFunction;
            this.activation = activation;
            return this;
        }

        /**
         * Sets the network layer sizes.
         *
         * @param lstmLayerSize size of the LSTM layer (default 256)
         * @param embeddingLayerSize size of the embedding layer (default 128)
         * @return this builder
         */
        public Builder withModelParams(int lstmLayerSize, int embeddingLayerSize) {
            this.lstmLayerSize = lstmLayerSize;
            this.embeddingLayerSize = embeddingLayerSize;
            return this;
        }

        /**
         * Sets the training hyper-parameters, leaving the maximum time series length untouched.
         *
         * @param learningRate learning rate (default 0.01)
         * @param dropOut drop-out value (default 0.5)
         * @param examplesPerEpoch examples per epoch (default -1)
         * @param batchSize batch size (default 16)
         * @param numEpochs number of epochs (default 1)
         * @return this builder
         */
        public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch,
                int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.dropOut = dropOut;
            this.examplesPerEpoch = examplesPerEpoch;
            this.batchSize = batchSize;
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Sets the training hyper-parameters including the maximum time series length.
         *
         * <p>Note the argument order: {@code maxTimeSeriesLength} comes <em>before</em>
         * {@code batchSize}, unlike a naive extension of the five-argument overload.
         *
         * @param learningRate learning rate (default 0.01)
         * @param dropOut drop-out value (default 0.5)
         * @param examplesPerEpoch examples per epoch (default -1)
         * @param maxTimeSeriesLength maximum time series length (default -1)
         * @param batchSize batch size (default 16)
         * @param numEpochs number of epochs (default 1)
         * @return this builder
         */
        public Builder withTrainingParams(double learningRate, double dropOut, int examplesPerEpoch,
                int maxTimeSeriesLength, int batchSize, int numEpochs) {
            withTrainingParams(learningRate, dropOut, examplesPerEpoch, batchSize, numEpochs);
            this.maxTimeSeriesLength = maxTimeSeriesLength;
            return this;
        }

        /**
         * Sets the input encoder, registers it as an annotator component and records the
         * feature description in the provenance.
         *
         * @param str feature description stored in the provenance
         * @param encoder input encoder
         * @return this builder
         */
        public Builder withInputEncoders(String str, Encoder encoder) {
            this.inputEncoder = encoder;
            tagger.setInputEncoders(encoder);
            ann.getProvenance().setFeatures(str);
            ann.addComponent(encoder);
            return this;
        }

        /**
         * Sets the target encoder and registers it as an annotator component.
         *
         * @param encoder target encoder
         * @return this builder
         */
        public Builder withTargetEncoder(Encoder encoder) {
            this.targetEncoder = encoder;
            tagger.setTargetEncoder(encoder);
            ann.addComponent(encoder);
            return this;
        }

        /**
         * Chooses whether {@link #build()} enables the tagger's training UI.
         *
         * @param z {@code true} to enable the training UI
         * @return this builder
         */
        public Builder enableTrainingUI(boolean z) {
            this.trainingUiEnabled = z;
            return this;
        }

        /**
         * Sets the stop words handed to the tagger on {@link #build()}. The collection is stored
         * by reference, not copied.
         *
         * @param collection stop words, or {@code null} to leave the tagger's stop words untouched
         * @return this builder
         */
        public Builder withStopWords(Collection<String> collection) {
            this.stopWords = collection;
            return this;
        }

        /**
         * Initializes the tagger's network from the configured parameters, applies the optional
         * settings, and returns the annotator owned by this builder.
         *
         * @return the configured annotator (the same instance on every call)
         * @throws IllegalStateException if the input or target encoder has not been set
         */
        public LSTMSentenceAnnotator build() {
            // Fail with an actionable message instead of an anonymous NullPointerException.
            if (inputEncoder == null) {
                throw new IllegalStateException(
                        "No input encoder configured: call withInputEncoders(...) before build()");
            }
            if (targetEncoder == null) {
                throw new IllegalStateException(
                        "No target encoder configured: call withTargetEncoder(...) before build()");
            }
            tagger.initializeNetwork(ModelBuilder.buildLSTMSentenceTagger(
                    inputEncoder.getEmbeddingVectorSize(),
                    lstmLayerSize,
                    embeddingLayerSize,
                    targetEncoder.getEmbeddingVectorSize(),
                    iterations,
                    learningRate,
                    dropOut,
                    lossFunc,
                    activation));
            if (stopWords != null) {
                tagger.setStopWords(stopWords);
            }
            if (trainingUiEnabled) {
                tagger.enableTrainingUI();
            }
            tagger.setEmbeddingLayerSize(embeddingLayerSize);
            // NOTE(review): the meaning of the trailing 'true' flag is defined by
            // LSTMSentenceTagger.setTrainingParams and is not visible here.
            tagger.setTrainingParams(examplesPerEpoch, maxTimeSeriesLength, batchSize, numEpochs, true);
            // The task must be set before the provenance string is used as the tagger name,
            // and the name must be set before printParams() reads it.
            ann.getProvenance().setTask(tagger.getId());
            tagger.setName(ann.getProvenance().toString());
            tagger.appendTrainLog(printParams());
            return ann;
        }

        /**
         * Renders the tagger name, its encoders and all network/training parameters as a
         * tab-separated, human-readable report for the training log.
         *
         * @return the formatted parameter report
         */
        private String printParams() {
            StringBuilder report = new StringBuilder();
            report.append("TRAINING PARAMS: ").append(tagger.getName()).append("\n");
            report.append("\nEncoders:\n");
            for (Encoder encoder : tagger.getEncoders()) {
                report.append(encoder.getId())
                        .append("\t").append(encoder.getClass().getSimpleName())
                        .append("\t").append(encoder.getEmbeddingVectorSize())
                        .append("\n");
            }
            report.append("\nNetwork Params:\n");
            report.append("BLSTM\t").append(lstmLayerSize).append("\n");
            report.append("EMB\t").append(embeddingLayerSize).append("\n");
            report.append("\nTraining Params:\n");
            report.append("examples per epoch\t").append(examplesPerEpoch).append("\n");
            report.append("max time series length\t").append(maxTimeSeriesLength).append("\n");
            report.append("epochs\t").append(numEpochs).append("\n");
            report.append("iterations\t").append(iterations).append("\n");
            report.append("batch size\t").append(batchSize).append("\n");
            report.append("learning rate\t").append(learningRate).append("\n");
            report.append("dropout\t").append(dropOut).append("\n");
            // append(Object) performs the toString() call and tolerates null.
            report.append("loss\t").append(lossFunc).append("\n");
            report.append("\n");
            return report.toString();
        }
    }

    /** Creates an annotator using the superclass's no-argument constructor. */
    public LSTMSentenceAnnotator() {
    }

    /**
     * Creates an annotator around the given tagger.
     *
     * @param tagger tagger passed to the superclass constructor
     */
    public LSTMSentenceAnnotator(Tagger tagger) {
        super(tagger);
    }

    /**
     * Creates an annotator around the given component.
     *
     * @param annotatorComponent component passed to the superclass constructor
     */
    protected LSTMSentenceAnnotator(AnnotatorComponent annotatorComponent) {
        super(annotatorComponent);
    }

    /* renamed from: getTagger, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the superclass's tagger cast to {@link LSTMSentenceTagger}.
     *
     * @return the tagger as an {@link LSTMSentenceTagger}
     * @throws ClassCastException if the superclass holds a tagger of a different type
     */
    public LSTMSentenceTagger m0getTagger() {
        return (LSTMSentenceTagger) super.getTagger();
    }

    /**
     * Trains the underlying tagger by delegating to its {@code trainModel} method.
     *
     * @param resource resource handed to the tagger's {@code trainModel}
     */
    public void trainModel(Resource resource) {
        m0getTagger().trainModel(resource);
    }

    /**
     * Wraps the underlying tagger in a new {@link LSTMSentenceEncoder}.
     *
     * @return a new encoder backed by this annotator's tagger
     */
    public LSTMSentenceEncoder asEncoder() {
        return new LSTMSentenceEncoder(m0getTagger());
    }
}
