package de.datexis.retrieval.tagger;

import java.util.Objects;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.schedule.ExponentialSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

/* loaded from: input_file:de/datexis/retrieval/tagger/ModelBuilder.class */
public class ModelBuilder {
    public static ComputationGraph buildLSTMSentenceTagger(long j, long j2, long j3, long j4, int i, double d, double d2, ILossFunction iLossFunction, Activation activation) {
        ComputationGraph computationGraph = new ComputationGraph(new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Adam(new ExponentialSchedule(ScheduleType.EPOCH, d, 0.95d))).weightInit(WeightInit.XAVIER).l2(1.0E-5d).dropOut(0.0d).gradientNormalization(GradientNormalization.ClipL2PerLayer).trainingWorkspaceMode(WorkspaceMode.ENABLED).inferenceWorkspaceMode(WorkspaceMode.ENABLED).cacheMode(CacheMode.HOST).graphBuilder().addInputs(new String[]{"input"}).addVertex("reverse", new ReverseTimeSeriesVertex("input"), new String[]{"input"}).addLayer("LSTM_FW", new LastTimeStep(new LSTM.Builder().nIn(j).nOut(j2).activation(Activation.TANH).gateActivationFunction(Activation.SIGMOID).dropOut(d2).build()), new String[]{"input"}).addLayer("LSTM_BW", new LastTimeStep(new LSTM.Builder().nIn(j).nOut(j2).activation(Activation.TANH).gateActivationFunction(Activation.SIGMOID).dropOut(d2).build()), new String[]{"reverse"}).addVertex("merge", new MergeVertex(), new String[]{"LSTM_FW", "LSTM_BW"}).addLayer("embedding", new DenseLayer.Builder().nIn(2 * j2).nOut(j3).dropOut(0.0d).activation(Activation.TANH).build(), new String[]{"merge"}).addLayer("target", new OutputLayer.Builder(iLossFunction).nIn(j3).nOut(j4).activation(activation).dropOut(0.0d).weightInit(WeightInit.SIGMOID_UNIFORM).build(), new String[]{"embedding"}).setOutputs(new String[]{"target"}).setInputTypes(new InputType[]{InputType.recurrent(j)}).backpropType(BackpropType.Standard).build());
        computationGraph.init();
        computationGraph.setListeners(new TrainingListener[]{new PerformanceListener(128, true), new ScoreIterationListener(16)});
        return computationGraph;
    }
}
