package me.yingrui.segment.word2vec.apps;

import java.io.File;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import me.yingrui.segment.math.Matrix;
import me.yingrui.segment.math.Matrix$;
import me.yingrui.segment.neural.BPLayer;
import me.yingrui.segment.neural.BPRecurrentLayer;
import me.yingrui.segment.neural.BackPropagation;
import me.yingrui.segment.neural.SoftmaxLayer$;
import me.yingrui.segment.neural.errors.CrossEntropyLoss;
import me.yingrui.segment.util.SerializeHandler;
import me.yingrui.segment.util.SerializeHandler$;
import me.yingrui.segment.word2vec.RNNSegmentViterbiClassifier;
import me.yingrui.segment.word2vec.SegmentCorpus;
import me.yingrui.segment.word2vec.Vocabulary;
import me.yingrui.segment.word2vec.Vocabulary$;
import scala.App;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.IndexedSeq;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.generic.TraversableForwarder;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ListBuffer;
import scala.collection.mutable.ListBuffer$;
import scala.collection.mutable.StringBuilder;
import scala.concurrent.ExecutionContext$Implicits$;
import scala.concurrent.ExecutionContextExecutor;
import scala.math.Numeric$DoubleIsFractional$;
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.util.Random;

/* compiled from: RNNSegmentTrainingApp.scala */
/* loaded from: input_file:me/yingrui/segment/word2vec/apps/RNNSegmentTrainingApp$.class */
/**
 * Singleton module class behind the Scala object {@code RNNSegmentTrainingApp}
 * (decompiler output of {@code RNNSegmentTrainingApp.scala}).
 *
 * <p>The object extends {@code scala.App}; its whole body therefore runs as a
 * delayed-init closure registered from the private constructor. That closure
 * parses command-line options, loads a word2vec model and previously saved
 * per-word networks, builds a recurrent network on top of them, runs the
 * training loop and finally prints two evaluation results for the training file.
 *
 * <p>NOTE(review): several constructs here are decompiler artefacts (assignments
 * to {@code final} fields, {@code App.class.*} trait forwarders,
 * {@code Tuple2.mcDD.sp}); they mirror the Scala bytecode and are kept as-is.
 */
public final class RNNSegmentTrainingApp$ implements App {
    /** The singleton instance; assigned by the private constructor. */
    public static final RNNSegmentTrainingApp$ MODULE$ = null;
    private ExecutionContextExecutor executionContext;
    private Random random;
    /** Path given by {@code --word2vec-model}. */
    private String word2VecModelFile;
    /** Path given by {@code --train-file}. */
    private String trainFile;
    /** Path given by {@code --save-file}; it is read by the load step at start-up. */
    private String saveFile;
    private int ngram;
    private int maxIteration;
    private int punishment;
    private boolean skipSelf;
    private int taskCount;
    private SerializeHandler reader;
    private Vocabulary vocab;
    private double[][] word2VecModel;
    /** Length of the first row of {@link #word2VecModel}. */
    private int numberOfFeatures;
    private int labelNgram;
    /** {@code 4 ^ labelNgram}. */
    private int numberOfClasses;
    /** Holder for the (networks, transitionProb) pair returned by the load step. */
    private Tuple2<IndexedSeq<BackPropagation>, Matrix> x$1;
    private IndexedSeq<BackPropagation> networks;
    private Matrix transitionProb;
    private BackPropagation rnn;
    private SegmentCorpus corpus;
    private Seq<String> files;
    private int iteration;
    private double cost;
    private double lastCost;
    /** Cost of every finished iteration, in order. */
    private ListBuffer<Object> costs;
    private double lastAverageCost;
    private boolean hasImprovement;
    private double learningRate;
    private final long executionStart;
    private String[] scala$App$$_args;
    private final ListBuffer<Function0<BoxedUnit>> scala$App$$initCode;

    static {
        new RNNSegmentTrainingApp$();
    }

    public long executionStart() {
        return this.executionStart;
    }

    public String[] scala$App$$_args() {
        return this.scala$App$$_args;
    }

    public void scala$App$$_args_$eq(String[] strArr) {
        this.scala$App$$_args = strArr;
    }

    public ListBuffer<Function0<BoxedUnit>> scala$App$$initCode() {
        return this.scala$App$$initCode;
    }

    public void scala$App$_setter_$executionStart_$eq(long j) {
        this.executionStart = j;
    }

    public void scala$App$_setter_$scala$App$$initCode_$eq(ListBuffer listBuffer) {
        this.scala$App$$initCode = listBuffer;
    }

    /** Forwards to the {@code scala.App} trait implementation. */
    public String[] args() {
        return App.class.args(this);
    }

    /** Forwards to the {@code scala.App} trait implementation. */
    public void delayedInit(Function0<BoxedUnit> function0) {
        App.class.delayedInit(this, function0);
    }

    /** Program entry point; forwards to the {@code scala.App} trait implementation. */
    public void main(String[] strArr) {
        App.class.main(this, strArr);
    }

    public ExecutionContextExecutor executionContext() {
        return this.executionContext;
    }

    public Random random() {
        return this.random;
    }

    public String word2VecModelFile() {
        return this.word2VecModelFile;
    }

    public String trainFile() {
        return this.trainFile;
    }

    public String saveFile() {
        return this.saveFile;
    }

    public int ngram() {
        return this.ngram;
    }

    public int maxIteration() {
        return this.maxIteration;
    }

    public int punishment() {
        return this.punishment;
    }

    public boolean skipSelf() {
        return this.skipSelf;
    }

    public int taskCount() {
        return this.taskCount;
    }

    public SerializeHandler reader() {
        return this.reader;
    }

    public Vocabulary vocab() {
        return this.vocab;
    }

    public double[][] word2VecModel() {
        return this.word2VecModel;
    }

    public int numberOfFeatures() {
        return this.numberOfFeatures;
    }

    public int labelNgram() {
        return this.labelNgram;
    }

    public int numberOfClasses() {
        return this.numberOfClasses;
    }

    public IndexedSeq<BackPropagation> networks() {
        return this.networks;
    }

    public Matrix transitionProb() {
        return this.transitionProb;
    }

    public BackPropagation rnn() {
        return this.rnn;
    }

    public SegmentCorpus corpus() {
        return this.corpus;
    }

    public Seq<String> files() {
        return this.files;
    }

    public int iteration() {
        return this.iteration;
    }

    public void iteration_$eq(int i) {
        this.iteration = i;
    }

    public double cost() {
        return this.cost;
    }

    public void cost_$eq(double d) {
        this.cost = d;
    }

    public double lastCost() {
        return this.lastCost;
    }

    public void lastCost_$eq(double d) {
        this.lastCost = d;
    }

    public ListBuffer<Object> costs() {
        return this.costs;
    }

    public double lastAverageCost() {
        return this.lastAverageCost;
    }

    public void lastAverageCost_$eq(double d) {
        this.lastAverageCost = d;
    }

    public boolean hasImprovement() {
        return this.hasImprovement;
    }

    public void hasImprovement_$eq(boolean z) {
        this.hasImprovement = z;
    }

    public double learningRate() {
        return this.learningRate;
    }

    public void learningRate_$eq(double d) {
        this.learningRate = d;
    }

    /**
     * Adjusts the learning rate after an iteration.
     *
     * <p>When the relative improvement is at most 3% the rate is multiplied by 0.1.
     * Afterwards any rate below 1e-4 is pinned to exactly 1e-5.
     *
     * @param d relative cost improvement of the last iteration
     */
    public void updateLearningRate(double d) {
        if (d <= 0.03d) {
            learningRate_$eq(learningRate() * 0.1d);
        }
        if (learningRate() < 1.0E-4d) {
            learningRate_$eq(1.0E-5d);
        }
    }

    /**
     * Prints an (error, total) pair and the accuracy {@code 1 - error / total}.
     * A {@code null} pair is silently ignored.
     *
     * @param tuple2 pair of doubles: error count first, total count second
     */
    public void me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$displayResult(Tuple2<Object, Object> tuple2) {
        if (tuple2 == null) {
            return;
        }
        double errorCount = tuple2._1$mcD$sp();
        double total = tuple2._2$mcD$sp();
        Predef$.MODULE$.println(new StringBuilder().append("error = ").append(BoxesRunTime.boxToDouble(errorCount)).append(" total = ").append(BoxesRunTime.boxToDouble(total)).toString());
        Predef$.MODULE$.println(new StringBuilder().append("accuracy = ").append(BoxesRunTime.boxToDouble(1.0d - (errorCount / total))).toString());
    }

    /**
     * Walks every document of the given corpus file and returns the two counters
     * the per-document closure accumulates.
     *
     * <p>NOTE(review): the closure body is not in this file; callers print the
     * result as (error, total) via {@code displayResult}.
     *
     * @param str path of the corpus file
     * @return pair of the two accumulated doubles
     */
    public Tuple2<Object, Object> testSegmentCorpus(String str) {
        DoubleRef firstCounter = new DoubleRef(0.0d);
        DoubleRef secondCounter = new DoubleRef(0.0d);
        corpus().foreachDocuments(str, new RNNSegmentTrainingApp$$anonfun$testSegmentCorpus$1(firstCounter, secondCounter));
        return new Tuple2.mcDD.sp(firstCounter.elem, secondCounter.elem);
    }

    /**
     * Classifies a sequence of inputs.
     *
     * <p>If every element satisfies the {@code classify$3} predicate the result is
     * the element-wise mapping through {@code classify$4}; otherwise the sequence
     * is handed to an {@link RNNSegmentViterbiClassifier} built from {@code seq2},
     * {@link #rnn()}, {@link #transitionProb()} and {@link #ngram()}.
     *
     * @param seq  inputs to classify
     * @param seq2 networks passed to the Viterbi classifier
     * @return one label per input
     */
    public Seq<Object> classify(Seq<Tuple2<Object, Matrix>> seq, Seq<BackPropagation> seq2) {
        return seq.forall(new RNNSegmentTrainingApp$$anonfun$classify$3()) ? (Seq) seq.map(new RNNSegmentTrainingApp$$anonfun$classify$4(), Seq$.MODULE$.canBuildFrom()) : new RNNSegmentViterbiClassifier(seq2, rnn(), transitionProb(), ngram()).classify(seq);
    }

    /**
     * Maps the input through {@code anonfun$4} and cuts the mapped sequence into
     * consecutive slices: every element matched by the split predicate becomes a
     * slice of length one, and each stretch of unmatched elements between two
     * matches (or before the first / after the last) becomes one slice.
     *
     * <p>NOTE(review): going by the method name the predicate presumably detects
     * unknown words; the predicate classes are not in this file.
     *
     * @param seq input triples
     * @return the slices, in original order
     */
    public Seq<Seq<Tuple2<Object, Matrix>>> splitByUnknownWords(Seq<Tuple3<Object, Matrix, Object>> seq) {
        Seq seq2 = (Seq) seq.map(new RNNSegmentTrainingApp$$anonfun$4(), Seq$.MODULE$.canBuildFrom());
        int i = 0;
        int indexWhere = seq2.indexWhere(new RNNSegmentTrainingApp$$anonfun$5(), 0);
        ListBuffer apply = ListBuffer$.MODULE$.apply(Nil$.MODULE$);
        while (i < seq2.length()) {
            if (indexWhere < 0) {
                // No further match: the remainder is the last slice.
                apply.$plus$eq(seq2.slice(i, seq2.length()));
                i = seq2.length();
            } else {
                if (i < indexWhere) {
                    // Unmatched stretch in front of the next match.
                    apply.$plus$eq(seq2.slice(i, indexWhere));
                }
                // The matched element forms its own slice.
                apply.$plus$eq(seq2.slice(indexWhere, indexWhere + 1));
                i = indexWhere + 1;
                indexWhere = seq2.indexWhere(new RNNSegmentTrainingApp$$anonfun$splitByUnknownWords$1(), i);
            }
        }
        return apply;
    }

    /**
     * Writes the network count, every network and the transition matrix to
     * {@link #saveFile()}. The handler is closed even when serialization fails.
     *
     * <p>NOTE(review): no call site for this method is visible in this file.
     */
    private void saveModel() {
        SerializeHandler writer = SerializeHandler$.MODULE$.apply(new File(saveFile()), SerializeHandler$.MODULE$.WRITE_ONLY());
        try {
            writer.serializeInt(networks().size());
            networks().foreach(new RNNSegmentTrainingApp$$anonfun$saveModel$1(writer));
            writer.serializeMatrix(transitionProb());
        } finally {
            writer.close();
        }
    }

    /**
     * Walks every document of the given corpus file and returns the two counters
     * the per-document closure accumulates.
     *
     * <p>NOTE(review): the closure body is not in this file; callers print the
     * result as (error, total) via {@code displayResult}.
     *
     * @param str path of the corpus file
     * @return pair of the two accumulated doubles
     */
    public Tuple2<Object, Object> me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$test(String str) {
        DoubleRef firstCounter = new DoubleRef(0.0d);
        DoubleRef secondCounter = new DoubleRef(0.0d);
        corpus().foreachDocuments(str, new RNNSegmentTrainingApp$$anonfun$test$1(firstCounter, secondCounter));
        return new Tuple2.mcDD.sp(firstCounter.elem, secondCounter.elem);
    }

    /**
     * Feeds {@code matrix} through {@code backPropagation}, then through
     * {@link #rnn()}, and post-processes the output columns in two passes that
     * share an index holder.
     *
     * <p>NOTE(review): the two closures presumably locate the largest column and
     * then rewrite the output around it; confirm against the closure classes.
     *
     * @param backPropagation network applied first
     * @param matrix          input
     * @return the (post-processed) output matrix of the recurrent network
     */
    public Matrix classify(BackPropagation backPropagation, Matrix matrix) {
        Matrix computeOutput = rnn().computeOutput(backPropagation.computeOutput(matrix));
        IntRef intRef = new IntRef(0);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), computeOutput.col()).foreach$mVc$sp(new RNNSegmentTrainingApp$$anonfun$classify$1(computeOutput, intRef, new DoubleRef(0.0d)));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), computeOutput.col()).foreach$mVc$sp(new RNNSegmentTrainingApp$$anonfun$classify$2(computeOutput, intRef));
        return computeOutput;
    }

    /**
     * Runs one pass over all split corpus files and returns the resulting loss.
     *
     * <p>The error calculator of {@link #rnn()} is cleared first; each file is
     * mapped through {@code anonfun$6} (which receives the learning rate) and the
     * mapped values are then consumed by {@code takeARound$1}.
     *
     * @param i iteration number (not used by this method)
     * @param d learning rate for this pass
     * @return {@code rnn().getLoss()} after the pass
     */
    public double me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$takeARound(int i, double d) {
        rnn().errorCalculator().clear();
        ((Seq) files().map(new RNNSegmentTrainingApp$$anonfun$6(d), Seq$.MODULE$.canBuildFrom())).foreach(new RNNSegmentTrainingApp$$anonfun$takeARound$1());
        return rnn().getLoss();
    }

    /**
     * Tells the training loop whether to keep going: returns {@code false} (and
     * removes the marker) when a file named {@code stop-training.tmp} exists in
     * the working directory, {@code true} otherwise.
     *
     * <p>NOTE(review): the result of {@code delete()} is ignored, so a marker that
     * cannot be removed will also stop the next run.
     */
    public boolean shouldContinue() {
        File file = new File("stop-training.tmp");
        if (!Files.exists(file.toPath(), new LinkOption[0])) {
            return true;
        }
        file.delete();
        return false;
    }

    /**
     * Creates {@code i4} networks, each built by {@code initializeNetworks$1}
     * from {@code i}, {@code i2} and {@code 4 ^ i3}.
     *
     * @param i  first size argument (callers pass the feature count)
     * @param i2 second size argument (callers pass the class count)
     * @param i3 exponent applied to 4 (callers pass the n-gram size)
     * @param i4 number of networks (callers pass the vocabulary size)
     */
    private IndexedSeq<BackPropagation> initializeNetworks(int i, int i2, int i3, int i4) {
        return (IndexedSeq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), i4).map(new RNNSegmentTrainingApp$$anonfun$initializeNetworks$1(i, i2, (int) Math.pow(4.0d, i3)), IndexedSeq$.MODULE$.canBuildFrom());
    }

    /**
     * Builds the recurrent network: a {@link BPRecurrentLayer} followed by a
     * softmax layer, trained with {@link CrossEntropyLoss}. Both layers are sized
     * from {@code 4 ^ ngram()} and initialised from {@code Matrix.randomize} with
     * the arguments -1.0 and 1.0.
     *
     * @param i number of output classes
     */
    public BackPropagation me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$initializeRNN(int i) {
        BackPropagation backPropagation = new BackPropagation(numberOfFeatures(), i, 0.1d, 0.0d, new CrossEntropyLoss());
        int pow = (int) Math.pow(4.0d, ngram());
        BPRecurrentLayer bPRecurrentLayer = new BPRecurrentLayer(Matrix$.MODULE$.randomize(pow, pow, -1.0d, 1.0d), Matrix$.MODULE$.randomize(1, pow, -1.0d, 1.0d), false);
        BPLayer apply = SoftmaxLayer$.MODULE$.apply(Matrix$.MODULE$.randomize(pow, i, -1.0d, 1.0d));
        backPropagation.addLayer(bPRecurrentLayer);
        backPropagation.addLayer(apply);
        return backPropagation;
    }

    /**
     * Loads previously saved networks and the transition matrix.
     *
     * <p>Fresh networks are created first (one per vocabulary entry); the stored
     * count must equal that number. The handler is closed even when reading fails.
     *
     * @param str path of the saved model
     * @return pair of (networks, transition matrix)
     */
    public Tuple2<IndexedSeq<BackPropagation>, Matrix> me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load(String str) {
        SerializeHandler handler = SerializeHandler$.MODULE$.apply(new File(str), SerializeHandler$.MODULE$.READ_ONLY());
        try {
            IndexedSeq<BackPropagation> initializeNetworks = initializeNetworks(numberOfFeatures(), numberOfClasses(), ngram(), vocab().size());
            int deserializeInt = handler.deserializeInt();
            Predef$.MODULE$.assert(deserializeInt == initializeNetworks.size());
            RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), deserializeInt).foreach$mVc$sp(new RNNSegmentTrainingApp$$anonfun$load$1(handler, initializeNetworks));
            Matrix deserializeMatrix = handler.deserializeMatrix();
            return new Tuple2<>(initializeNetworks, deserializeMatrix);
        } finally {
            handler.close();
        }
    }

    public void executionContext_$eq(ExecutionContextExecutor executionContextExecutor) {
        this.executionContext = executionContextExecutor;
    }

    public void random_$eq(Random random) {
        this.random = random;
    }

    public void word2VecModelFile_$eq(String str) {
        this.word2VecModelFile = str;
    }

    public void trainFile_$eq(String str) {
        this.trainFile = str;
    }

    public void saveFile_$eq(String str) {
        this.saveFile = str;
    }

    public void ngram_$eq(int i) {
        this.ngram = i;
    }

    public void maxIteration_$eq(int i) {
        this.maxIteration = i;
    }

    public void punishment_$eq(int i) {
        this.punishment = i;
    }

    public void skipSelf_$eq(boolean z) {
        this.skipSelf = z;
    }

    public void taskCount_$eq(int i) {
        this.taskCount = i;
    }

    public void reader_$eq(SerializeHandler serializeHandler) {
        this.reader = serializeHandler;
    }

    public void vocab_$eq(Vocabulary vocabulary) {
        this.vocab = vocabulary;
    }

    public void word2VecModel_$eq(double[][] dArr) {
        this.word2VecModel = dArr;
    }

    public void numberOfFeatures_$eq(int i) {
        this.numberOfFeatures = i;
    }

    public void labelNgram_$eq(int i) {
        this.labelNgram = i;
    }

    public void numberOfClasses_$eq(int i) {
        this.numberOfClasses = i;
    }

    public void x$1_$eq(Tuple2 tuple2) {
        this.x$1 = tuple2;
    }

    public void networks_$eq(IndexedSeq indexedSeq) {
        this.networks = indexedSeq;
    }

    public Tuple2 x$1() {
        return this.x$1;
    }

    public void transitionProb_$eq(Matrix matrix) {
        this.transitionProb = matrix;
    }

    public void rnn_$eq(BackPropagation backPropagation) {
        this.rnn = backPropagation;
    }

    public void corpus_$eq(SegmentCorpus segmentCorpus) {
        this.corpus = segmentCorpus;
    }

    public void files_$eq(Seq seq) {
        this.files = seq;
    }

    public void costs_$eq(ListBuffer listBuffer) {
        this.costs = listBuffer;
    }

    /**
     * Returns the argument that follows the first occurrence of {@code flag} on
     * the command line, or {@code defaultValue} when the flag is absent.
     *
     * @throws IllegalArgumentException if the flag is the last argument and so
     *                                  has no value
     */
    private String optionValue(String flag, String defaultValue) {
        String[] arguments = args();
        for (int position = 0; position < arguments.length; position++) {
            if (flag.equals(arguments[position])) {
                if (position + 1 >= arguments.length) {
                    throw new IllegalArgumentException("Missing value for option " + flag);
                }
                return arguments[position + 1];
            }
        }
        return defaultValue;
    }

    /** Integer variant of {@link #optionValue}; the value is parsed with Scala's {@code toInt}. */
    private int intOption(String flag, int defaultValue) {
        String raw = optionValue(flag, null);
        return raw == null ? defaultValue : new StringOps(Predef$.MODULE$.augmentString(raw)).toInt();
    }

    /** Boolean variant of {@link #optionValue}; the value is parsed with Scala's {@code toBoolean}. */
    private boolean booleanOption(String flag, boolean defaultValue) {
        String raw = optionValue(flag, null);
        return raw == null ? defaultValue : new StringOps(Predef$.MODULE$.augmentString(raw)).toBoolean();
    }

    /**
     * Publishes the singleton and registers the application body with
     * {@code delayedInit}, so that it is run by the {@code scala.App} machinery
     * rather than directly by this constructor.
     */
    private RNNSegmentTrainingApp$() {
        MODULE$ = this;
        App.class.$init$(this);
        delayedInit(new AbstractFunction0(this) { // from class: me.yingrui.segment.word2vec.apps.RNNSegmentTrainingApp$delayedInit$body
            private final RNNSegmentTrainingApp$ $outer;

            public final Object apply() {
                this.$outer.executionContext_$eq(ExecutionContext$Implicits$.MODULE$.global());
                this.$outer.random_$eq(new Random(System.currentTimeMillis()));

                // Command-line options and their defaults.
                this.$outer.word2VecModelFile_$eq(this.$outer.optionValue("--word2vec-model", "vectors.cn.hs.dat"));
                this.$outer.trainFile_$eq(this.$outer.optionValue("--train-file", "lib-segment/training-100000.txt"));
                this.$outer.saveFile_$eq(this.$outer.optionValue("--save-file", "segment-vector-100000.dat"));
                this.$outer.ngram_$eq(this.$outer.intOption("-ngram", 2));
                this.$outer.maxIteration_$eq(this.$outer.intOption("-iter", 20));
                this.$outer.punishment_$eq(this.$outer.intOption("-punishment", 0));
                this.$outer.skipSelf_$eq(this.$outer.booleanOption("-skip-self", true));
                this.$outer.taskCount_$eq(1);

                // Vocabulary and embedding matrix come from the same stream, in that order.
                // NOTE(review): this reader is never closed; confirm Vocabulary does not read
                // from it lazily before adding a close() here.
                Predef$.MODULE$.print("loading word2vec model...\r");
                this.$outer.reader_$eq(SerializeHandler$.MODULE$.apply(new File(this.$outer.word2VecModelFile()), SerializeHandler$.MODULE$.READ_ONLY()));
                this.$outer.vocab_$eq(Vocabulary$.MODULE$.apply(this.$outer.reader()));
                this.$outer.word2VecModel_$eq(this.$outer.reader().deserialize2DArrayDouble());
                Predef$.MODULE$.assert(this.$outer.vocab().size() == this.$outer.word2VecModel().length, new RNNSegmentTrainingApp$$anonfun$1());
                this.$outer.numberOfFeatures_$eq(this.$outer.word2VecModel()[0].length);
                this.$outer.labelNgram_$eq(1);
                this.$outer.numberOfClasses_$eq((int) Math.pow(4.0d, this.$outer.labelNgram()));

                // Previously saved per-word networks and transition matrix.
                RNNSegmentTrainingApp$ rNNSegmentTrainingApp$ = this.$outer;
                Tuple2<IndexedSeq<BackPropagation>, Matrix> me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load = this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load(this.$outer.saveFile());
                if (me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load == null) {
                    throw new MatchError(me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load);
                }
                rNNSegmentTrainingApp$.x$1_$eq(new Tuple2((IndexedSeq) me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load._1(), (Matrix) me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$load._2()));
                this.$outer.networks_$eq((IndexedSeq) this.$outer.x$1()._1());
                this.$outer.transitionProb_$eq((Matrix) this.$outer.x$1()._2());
                this.$outer.rnn_$eq(this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$initializeRNN(this.$outer.numberOfClasses()));

                Predef$.MODULE$.print("loading training corpus...\r");
                this.$outer.corpus_$eq(new SegmentCorpus(this.$outer.word2VecModel(), this.$outer.vocab(), this.$outer.ngram(), this.$outer.labelNgram()));
                this.$outer.files_$eq(this.$outer.corpus().splitCorpus(this.$outer.trainFile(), this.$outer.taskCount()));

                Predef$.MODULE$.print("training...\r");
                this.$outer.iteration_$eq(0);
                this.$outer.cost_$eq(0.0d);
                this.$outer.lastCost_$eq(Double.MAX_VALUE);
                this.$outer.costs_$eq(new ListBuffer());
                this.$outer.lastAverageCost_$eq(Double.MAX_VALUE);
                this.$outer.hasImprovement_$eq(true);
                this.$outer.learningRate_$eq(1.0E-4d);
                // Stop on the marker file, on the iteration limit, or once the 5-iteration
                // moving average of the cost stops dropping by more than 1e-5.
                while (this.$outer.shouldContinue() && this.$outer.iteration() < this.$outer.maxIteration() && this.$outer.hasImprovement()) {
                    long currentTimeMillis = System.currentTimeMillis();
                    this.$outer.cost_$eq(this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$takeARound(this.$outer.iteration(), this.$outer.learningRate()));
                    long currentTimeMillis2 = System.currentTimeMillis();
                    this.$outer.costs().$plus$eq(BoxesRunTime.boxToDouble(this.$outer.cost()));
                    // Mean of the (up to) five most recent costs.
                    double unboxToDouble = BoxesRunTime.unboxToDouble(((TraversableForwarder) this.$outer.costs().takeRight(5)).sum(Numeric$DoubleIsFractional$.MODULE$)) / ((ListBuffer) this.$outer.costs().takeRight(5)).size();
                    // Relative improvement over the previous iteration.
                    double lastCost = (this.$outer.lastCost() - this.$outer.cost()) / this.$outer.lastCost();
                    Predef$.MODULE$.println(new StringOps(Predef$.MODULE$.augmentString("Iteration: %2d learning rate: %2.5f improved: %2.5f cost: %2.5f average cost: %2.5f elapse: %ds")).format(Predef$.MODULE$.genericWrapArray(new Object[]{BoxesRunTime.boxToInteger(this.$outer.iteration()), BoxesRunTime.boxToDouble(this.$outer.learningRate()), BoxesRunTime.boxToDouble(lastCost), BoxesRunTime.boxToDouble(this.$outer.cost()), BoxesRunTime.boxToDouble(unboxToDouble), BoxesRunTime.boxToLong((currentTimeMillis2 - currentTimeMillis) / 1000)})));
                    this.$outer.updateLearningRate(lastCost);
                    this.$outer.hasImprovement_$eq(this.$outer.lastAverageCost() - unboxToDouble > 1.0E-5d);
                    this.$outer.lastAverageCost_$eq(unboxToDouble);
                    this.$outer.lastCost_$eq(this.$outer.cost());
                    this.$outer.iteration_$eq(this.$outer.iteration() + 1);
                }

                // Both evaluations run against the training file itself.
                Predef$.MODULE$.println("testing...");
                this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$displayResult(this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$test(this.$outer.trainFile()));
                this.$outer.me$yingrui$segment$word2vec$apps$RNNSegmentTrainingApp$$displayResult(this.$outer.testSegmentCorpus(this.$outer.trainFile()));
                return BoxedUnit.UNIT;
            }

            {
                // Capture the enclosing module instance (not the closure itself) so that
                // apply() can reach the module's accessors.
                this.$outer = RNNSegmentTrainingApp$.this;
            }
        });
    }
}
