package org.deeplearning4j.cli.subcommands;

import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.canova.api.formats.input.InputFormat;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CLI sub-command that trains a network described by a JSON model configuration on data read
 * through a Canova {@link RecordReader}, then writes the learned parameters to disk.
 *
 * <p>Settings come from two places: the args4j {@code @Option} fields below and a
 * {@link Properties} file named by {@code -conf}. {@link #loadConfigFile()} copies the runtime
 * mode, output location and input location from that file into the corresponding fields.
 */
public class Train extends BaseSubCommand {
    /** Properties key selecting the execution runtime ("local", "hadoop" or "spark"). */
    public static final String EXECUTION_RUNTIME_MODE_KEY = "execution.runtime";
    /** Runtime used when the properties file does not name one. */
    public static final String EXECUTION_RUNTIME_MODE_DEFAULT = "local";
    /** Properties key for the model output location. */
    public static final String OUTPUT_FILENAME_KEY = "output.directory";
    /** Properties key for the training input location. */
    public static final String INPUT_DATA_FILENAME_KEY = "input.directory";
    /** Properties key naming the fully-qualified {@link InputFormat} implementation class. */
    public static final String INPUT_FORMAT_KEY = "input.format";
    /** {@link InputFormat} class used when {@link #INPUT_FORMAT_KEY} is absent or blank. */
    public static final String DEFAULT_INPUT_FORMAT_CLASSNAME = "org.canova.api.formats.input.impl.SVMLightInputFormat";

    @Option(name = "-conf", usage = "configuration file for training", required = true)
    public String configurationFile;
    /** Properties read from {@link #configurationFile}; {@code null} until first loaded. */
    public Properties configProps;
    private static final Logger log = LoggerFactory.getLogger(Train.class);

    @Option(name = "-input", usage = "input data", aliases = {"-i"}, required = true)
    private String input;

    @Option(name = "-output", usage = "location for saving model", aliases = {"-o"})
    private String outputDirectory;

    @Option(name = "-model", usage = "location for configuration of model", aliases = {"-m"})
    private String modelPath;

    @Option(name = "-type", usage = "type of network (layer or multi layer)")
    private String type;

    @Option(name = "-runtime", usage = "runtime- local, Hadoop, Spark, etc.", aliases = {"-r"}, required = false)
    private String runtime;

    @Option(name = "-properties", usage = "configuration for distributed systems", aliases = {"-p"}, required = false)
    private String properties;

    @Option(name = "-savemode", usage = "output: (binary | txt)")
    private String saveMode;

    @Option(name = "-verbose", usage = "verbose(true | false)", aliases = {"-v"})
    private boolean verbose;

    /** Creates the command with a single (null) argument slot. */
    public Train() {
        this(new String[1]);
    }

    /**
     * Creates the command from raw CLI arguments.
     *
     * @param strArr arguments handed to {@link BaseSubCommand}
     */
    public Train(String[] strArr) {
        super(strArr);
        this.configurationFile = "";
        this.configProps = null;
        this.input = "input.txt";
        this.outputDirectory = "output.txt";
        this.type = "multi";
        this.runtime = EXECUTION_RUNTIME_MODE_DEFAULT;
        this.saveMode = "txt";
        this.verbose = false;
    }

    /**
     * Loads the properties file and dispatches to the runtime it selects.
     *
     * <p>If the configuration cannot be loaded, the failure is logged. Execution then proceeds
     * with whatever values the fields currently hold.
     */
    @Override // org.deeplearning4j.cli.subcommands.SubCommand
    public void execute() {
        try {
            loadConfigFile();
        } catch (Exception e) {
            log.error("Unable to load configuration file '{}'; continuing with current settings",
                    this.configurationFile, e);
        }
        // equalsIgnoreCase is locale-independent, unlike toLowerCase() with the default locale.
        String mode = this.runtime == null ? EXECUTION_RUNTIME_MODE_DEFAULT : this.runtime.trim();
        if ("hadoop".equalsIgnoreCase(mode)) {
            execOnHadoop();
        } else if ("spark".equalsIgnoreCase(mode)) {
            execOnSpark();
        } else {
            execLocal();
        }
    }

    /**
     * Trains in-process on the data at {@link #input}.
     *
     * <p>When {@code type} is "multi", a multi-layer network is built from the JSON at
     * {@link #modelPath}. Otherwise a single layer is built from that JSON. Learned parameters
     * are written in binary form when {@code saveMode} is "binary", and as comma-separated text
     * otherwise.
     */
    public void execLocal() {
        log.warn("[dl4j] - executing local ... ");
        log.warn("using training input: " + this.input);
        if (this.modelPath == null) {
            log.error("No model configuration given; pass one with -model");
            return;
        }
        RecordReader recordReader;
        try {
            recordReader = createInputFormat().createReader(new FileSplit(new File(this.input)));
        } catch (Exception e) {
            // Without a reader there is nothing to train on; continuing would only fail later with an NPE.
            log.error("Unable to create a record reader for input '{}'", this.input, e);
            return;
        }
        if ("multi".equals(this.type)) {
            trainMultiLayer(recordReader);
        } else {
            trainSingleLayer(recordReader);
        }
    }

    /**
     * Trains one layer described by a {@link NeuralNetConfiguration} JSON file on feature
     * matrices only, then saves its parameters directly to {@link #outputDirectory}.
     */
    private void trainSingleLayer(RecordReader recordReader) {
        try {
            NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(FileUtils.readFileToString(new File(this.modelPath)));
            Layer layer = LayerFactories.getFactory(conf).create(conf);
            RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, conf.getBatchSize());
            while (iterator.hasNext()) {
                layer.fit(((DataSet) iterator.next()).getFeatureMatrix());
            }
            // The single-layer path treats outputDirectory as the target file itself.
            if (this.saveMode.equals("binary")) {
                // Closing the stream flushes the BufferedOutputStream; without it the file may be truncated.
                try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(this.outputDirectory)))) {
                    Nd4j.write(layer.params(), out);
                }
            } else {
                Nd4j.writeTxt(layer.params(), this.outputDirectory, ",");
            }
        } catch (IOException e) {
            log.error("Single-layer training with model '{}' failed", this.modelPath, e);
        }
    }

    /**
     * Trains a {@link MultiLayerNetwork} described by a {@link MultiLayerConfiguration} JSON file.
     * Its parameters are saved as {@code outputmodel.bin} or {@code outputmodel.txt} inside
     * {@link #outputDirectory}.
     */
    private void trainMultiLayer(RecordReader recordReader) {
        try {
            MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(FileUtils.readFileToString(new File(this.modelPath)));
            // Batch size comes from the first layer; the number of outcomes from the last layer's nOut.
            RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(
                    recordReader,
                    conf.getConf(0).getBatchSize(),
                    -1,
                    conf.getConf(conf.getConfs().size() - 1).getLayer().getNOut());
            MultiLayerNetwork network = new MultiLayerNetwork(conf);
            // init() is required before training. Previously it ran only when -verbose was set.
            network.init();
            if (this.verbose) {
                network.setListeners(Collections.singletonList(new ScoreIterationListener(1)));
            }
            network.fit(iterator);
            if (this.saveMode.equals("binary")) {
                // Closing the stream flushes the BufferedOutputStream; without it the file may be truncated.
                try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(
                        new FileOutputStream(this.outputDirectory + File.separator + "outputmodel.bin")))) {
                    Nd4j.write(network.params(), out);
                }
            } else {
                Nd4j.writeTxt(network.params(), this.outputDirectory + File.separator + "outputmodel.txt", ",");
            }
        } catch (IOException e) {
            log.error("Multi-layer training with model '{}' failed", this.modelPath, e);
        }
    }

    /** Placeholder: Spark execution is not implemented; only logs a warning. */
    public void execOnSpark() {
        log.warn("DL4J: Execution on spark from CLI not yet supported");
    }

    /** Placeholder: Hadoop execution is not implemented; only logs a warning. */
    public void execOnHadoop() {
        log.warn("DL4J: Execution on hadoop from CLI not yet supported");
    }

    /**
     * Instantiates the {@link InputFormat} named by {@link #INPUT_FORMAT_KEY}. Falls back to
     * {@link #DEFAULT_INPUT_FORMAT_CLASSNAME} when that key is absent or blank.
     *
     * <p>Loads the properties file first if it has not been loaded yet. A load failure is
     * logged and the default format is used.
     *
     * @return a new instance of the configured input format
     * @throws RuntimeException if the class cannot be found, instantiated or is not an InputFormat
     */
    public InputFormat createInputFormat() {
        if (this.configProps == null) {
            try {
                loadConfigFile();
            } catch (Exception e) {
                log.error("Unable to load configuration file '{}'", this.configurationFile, e);
            }
        }
        String className = this.configProps == null ? null : this.configProps.getProperty(INPUT_FORMAT_KEY);
        // Properties.load keeps trailing whitespace in values, which would break Class.forName.
        if (className == null || className.trim().isEmpty()) {
            className = DEFAULT_INPUT_FORMAT_CLASSNAME;
        } else {
            className = className.trim();
        }
        try {
            return (InputFormat) Class.forName(className).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Unable to instantiate input format " + className, e);
        }
    }

    /**
     * Reads {@link #configurationFile} into {@link #configProps} and applies its settings.
     *
     * <p>{@code runtime} and {@code outputDirectory} are always overwritten. Each takes the file's
     * value or a hard-coded default, so command-line {@code -runtime}/{@code -output} values do not
     * survive a successful load. NOTE(review): confirm that this precedence is intended.
     *
     * @throws java.io.FileNotFoundException if the configuration file cannot be opened; the
     *         fields are left untouched, except that {@code configProps} is reset to empty
     * @throws IOException if the file cannot be read or parsed
     * @throws RuntimeException if the file does not define {@link #INPUT_DATA_FILENAME_KEY}
     */
    public void loadConfigFile() throws Exception {
        this.configProps = new Properties();
        // try-with-resources closes the stream even when load() fails.
        try (FileInputStream in = new FileInputStream(this.configurationFile)) {
            this.configProps.load(in);
        }
        this.runtime = this.configProps.getProperty(EXECUTION_RUNTIME_MODE_KEY, EXECUTION_RUNTIME_MODE_DEFAULT);
        this.outputDirectory = this.configProps.getProperty(OUTPUT_FILENAME_KEY, "/tmp/dl4_model_default.txt");
        String inputPath = this.configProps.getProperty(INPUT_DATA_FILENAME_KEY);
        if (inputPath == null) {
            throw new RuntimeException("no input file to train on!");
        }
        this.input = inputPath;
    }
}
