package epic.parser.models;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.math.Field$fieldDouble$;
import breeze.optimize.CachedBatchDiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.GradientTester$;
import breeze.optimize.StochasticDiffFunction;
import breeze.util.Implicits$;
import breeze.util.LazyLogger;
import breeze.util.SerializableLogging;
import breeze.util.package$;
import epic.constraints.CachedChartConstraintsFactory;
import epic.constraints.CachedChartConstraintsFactory$;
import epic.constraints.ChartConstraints;
import epic.constraints.ChartConstraints$Factory$;
import epic.dense.AdadeltaGradientDescentDVD;
import epic.dense.AdadeltaGradientDescentDVD$;
import epic.framework.Model;
import epic.framework.ModelObjective;
import epic.lexicon.Lexicon;
import epic.parser.GenerativeParser$;
import epic.parser.ParseEval;
import epic.parser.ParseMarginal;
import epic.parser.Parser;
import epic.parser.ParserParams;
import epic.parser.ParserParams$XbarGrammar$;
import epic.parser.ParserPipeline;
import epic.parser.RuleTopology;
import epic.parser.StandardChartFactory;
import epic.parser.models.ParserTrainer;
import epic.parser.projections.OracleParser;
import epic.parser.projections.ParserChartConstraintsFactory;
import epic.parser.projections.ParserChartConstraintsFactory$;
import epic.trees.AnnotatedLabel;
import epic.trees.Debinarizer$AnnotatedLabelDebinarizer$;
import epic.trees.ProcessedTreebank;
import epic.trees.TreeInstance;
import epic.trees.annotations.IdentityAnnotator;
import epic.trees.annotations.TreeAnnotator$;
import epic.util.Optional$;
import java.io.File;
import scala.Console$;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IndexedSeq;
import scala.collection.Iterator;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Range$;
import scala.collection.mutable.StringBuilder;
import scala.collection.parallel.ParIterableLike;
import scala.collection.parallel.immutable.ParSeq$;
import scala.reflect.ClassTag$;
import scala.reflect.Manifest;
import scala.reflect.ManifestFactory$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: ParserTrainer.scala */
/* loaded from: input_file:epic/parser/models/ParserTrainer$.class */
public final class ParserTrainer$ implements ParserPipeline {
    public static final ParserTrainer$ MODULE$ = null;
    private final Manifest<ParserTrainer.Params> paramManifest;
    private volatile transient LazyLogger breeze$util$SerializableLogging$$_the_logger;

    // Instantiates the singleton once at class-load time; the private
    // constructor is what stores the new instance into MODULE$.
    static {
        new ParserTrainer$();
    }

    /**
     * Trains parsers from a processed treebank by forwarding to the shared
     * implementation in {@code ParserPipeline.Cclass}.
     *
     * @param processedTreebank treebank handed through unchanged to the pipeline implementation
     * @param obj parameter object, passed through untyped
     * @return the iterator of (name, parser) pairs produced by the pipeline implementation
     */
    @Override // epic.parser.ParserPipeline
    public Iterator<Tuple2<String, Parser<AnnotatedLabel, String>>> trainParser(ProcessedTreebank processedTreebank, Object obj) {
        return ParserPipeline.Cclass.trainParser(this, processedTreebank, obj);
    }

    /**
     * Command-line entry point; forwards the arguments to the shared
     * implementation in {@code ParserPipeline.Cclass}.
     *
     * @param strArr command-line arguments, passed through unchanged
     */
    @Override // epic.parser.ParserPipeline
    public void main(String[] strArr) {
        ParserPipeline.Cclass.main(this, strArr);
    }

    /**
     * Evaluates a parser on a set of tree instances by forwarding to the shared
     * implementation in {@code ParserPipeline.Cclass}.
     *
     * @param indexedSeq tree instances to evaluate against
     * @param parser parser under evaluation
     * @param str string argument passed through to the pipeline implementation
     *            (presumably a name/label for the evaluation run; confirm against ParserPipeline)
     * @return the statistics computed by the pipeline implementation
     */
    @Override // epic.parser.ParserPipeline
    public ParseEval.Statistics evalParser(IndexedSeq<TreeInstance<AnnotatedLabel, String>> indexedSeq, Parser<AnnotatedLabel, String> parser, String str) {
        return ParserPipeline.Cclass.evalParser(this, indexedSeq, parser, str);
    }

    /**
     * Accessor for the logger field backing {@code SerializableLogging}.
     *
     * @return the current value of the volatile, transient logger field (may be unset)
     */
    public LazyLogger breeze$util$SerializableLogging$$_the_logger() {
        return this.breeze$util$SerializableLogging$$_the_logger;
    }

    /**
     * Mutator for the logger field backing {@code SerializableLogging}.
     *
     * @param lazyLogger the logger instance to store in the field
     */
    public void breeze$util$SerializableLogging$$_the_logger_$eq(LazyLogger lazyLogger) {
        this.breeze$util$SerializableLogging$$_the_logger = lazyLogger;
    }

    /**
     * Returns this object's logger, as supplied by the {@code SerializableLogging}
     * implementation this call delegates to.
     *
     * @return the logger obtained from {@code SerializableLogging}
     */
    public LazyLogger logger() {
        // NOTE(review): "SerializableLogging.class.logger" is how the decompiler rendered the
        // call into the trait's static implementation; it is not a java.lang.Class literal.
        return SerializableLogging.class.logger(this);
    }

    /**
     * Returns the manifest describing {@code ParserTrainer.Params}, which the
     * constructor builds once and stores in a final field.
     *
     * @return the cached manifest for the parameter class
     */
    @Override // epic.parser.ParserPipeline
    public Manifest<ParserTrainer.Params> paramManifest() {
        return this.paramManifest;
    }

    /**
     * Sets up discriminative training and returns a lazy iterator of named parsers.
     *
     * <p>Steps, in order: obtain an initial parser (built from the training trees or
     * deserialized from {@code params.parser()}), wrap it in a cached chart-constraints
     * factory, filter the training trees, optionally rewrite them through an oracle
     * parser, build the model and its objective, optionally run gradient checks, choose
     * an optimizer, and finally transform the optimizer's iteration states into
     * (name, parser) pairs.
     *
     * @param indexedSeq training tree instances
     * @param function1 callback that evaluates a parser and returns statistics; it is only
     *                  handed to the closure used in the final iterator pipeline
     * @param params training configuration
     * @return iterator of (name, parser) pairs derived from the optimizer iterations
     * @throws MatchError if the x-bar grammar construction returns {@code null}
     */
    public Iterator<Tuple2<String, Parser<AnnotatedLabel, String>>> trainParser(IndexedSeq<TreeInstance<AnnotatedLabel, String>> indexedSeq, Function1<Parser<AnnotatedLabel, String>, ParseEval.Statistics> function1, ParserTrainer.Params params) {
        Parser<AnnotatedLabel, String> parser;
        Option option;
        Iterator iterations;
        // --- Initial parser: build one when no parser file is configured, else deserialize it.
        File parser2 = params.parser();
        if (parser2 == null) {
            // Unused local; left over from decompiling a Scala companion-object reference.
            ParserParams$XbarGrammar$ parserParams$XbarGrammar$ = ParserParams$XbarGrammar$.MODULE$;
            Tuple2<RuleTopology<AnnotatedLabel>, Lexicon<AnnotatedLabel, String>> xbarGrammar = new ParserParams.XbarGrammar(new File("xbar.gr")).xbarGrammar(indexedSeq);
            if (xbarGrammar == null) {
                throw new MatchError(xbarGrammar);
            }
            Tuple2 tuple2 = new Tuple2(xbarGrammar._1(), xbarGrammar._2());
            parser = GenerativeParser$.MODULE$.annotatedParser((RuleTopology) tuple2._1(), (Lexicon) tuple2._2(), params.annotator(), indexedSeq);
        } else {
            parser = (Parser) package$.MODULE$.readObject(parser2);
        }
        // --- Constraints: copy the parser (swapping a StandardChartFactory for one constructed with
        // `true` as its second argument, otherwise keeping the marginal factory as is) and wrap it
        // in a ParserChartConstraintsFactory, then in a cache configured by params.cache().
        ParseMarginal.Factory<AnnotatedLabel, String> marginalFactory = parser.marginalFactory();
        CachedChartConstraintsFactory cachedChartConstraintsFactory = new CachedChartConstraintsFactory(new ParserChartConstraintsFactory(parser.copy(parser.topology(), parser.lexicon(), parser.copy$default$3(), marginalFactory instanceof StandardChartFactory ? new StandardChartFactory(((StandardChartFactory) marginalFactory).refinedGrammar(), true) : marginalFactory, parser.copy$default$5(), Debinarizer$AnnotatedLabelDebinarizer$.MODULE$), new ParserTrainer$$anonfun$4(), ParserChartConstraintsFactory$.MODULE$.$lessinit$greater$default$3()), CachedChartConstraintsFactory$.MODULE$.$lessinit$greater$default$2(), params.cache());
        // --- Training set: drop instances matched by the anonfun$5 predicate (built from params;
        // presumably the sentence-length filter, see sentTooLong — TODO confirm).
        scala.collection.immutable.IndexedSeq indexedSeq2 = (scala.collection.immutable.IndexedSeq) indexedSeq.toIndexedSeq().filterNot(new ParserTrainer$$anonfun$5(params));
        if (params.useConstraints() && params.enforceReachability()) {
            GenerativeParser$ generativeParser$ = GenerativeParser$.MODULE$;
            RuleTopology<AnnotatedLabel> ruleTopology = parser.topology();
            Lexicon<AnnotatedLabel, String> lexicon = parser.lexicon();
            TreeAnnotator$ treeAnnotator$ = TreeAnnotator$.MODULE$;
            // Map every remaining tree, in parallel, through the trainParser$1 closure. The closure is
            // given the constraints factory and an OracleParser built from two generative grammars
            // (one using an IdentityAnnotator, one using params.annotator()); results are gathered back
            // into a sequential indexed sequence.
            indexedSeq2 = ((ParIterableLike) indexedSeq2.par().map(new ParserTrainer$$anonfun$trainParser$1(cachedChartConstraintsFactory, new OracleParser(generativeParser$.annotated(ruleTopology, lexicon, new IdentityAnnotator(), indexedSeq), Optional$.MODULE$.anyToOptional(GenerativeParser$.MODULE$.annotated(parser.topology(), parser.lexicon(), params.annotator(), indexedSeq)))), ParSeq$.MODULE$.canBuildFrom())).seq().toIndexedSeq();
        }
        // NOTE(review): the NoSparsityFactory created below is never stored or used; the model on the
        // next statement always receives cachedChartConstraintsFactory, so useConstraints == false
        // does not change which constraints the model sees. Confirm whether this is intended.
        if (!params.useConstraints()) {
            ChartConstraints$Factory$ chartConstraints$Factory$ = ChartConstraints$Factory$.MODULE$;
            new ChartConstraints.NoSparsityFactory();
        }
        // --- Model, objective and starting weights.
        Model make = params.modelFactory().make(indexedSeq2, parser.topology(), parser.lexicon(), cachedChartConstraintsFactory);
        ModelObjective modelObjective = new ModelObjective(make, (IndexedSeq) indexedSeq2, params.threads());
        CachedBatchDiffFunction cachedBatchDiffFunction = new CachedBatchDiffFunction(modelObjective, DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double()));
        DenseVector<Object> initialWeightVector = modelObjective.initialWeightVector(params.randomize());
        if (params.checkGradient()) {
            // --- Optional gradient check, run on an objective over only the first
            // params.opt().batchSize() training instances.
            CachedBatchDiffFunction cachedBatchDiffFunction2 = new CachedBatchDiffFunction(new ModelObjective(make, (IndexedSeq) indexedSeq2.take(params.opt().batchSize()), params.threads()), DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double()));
            Predef$ predef$ = Predef$.MODULE$;
            // One test index per element of Range(0, 10), each produced by the anonfun$1 closure over the model.
            scala.collection.immutable.IndexedSeq indexedSeq3 = (scala.collection.immutable.IndexedSeq) Range$.MODULE$.apply(0, 10).map(new ParserTrainer$$anonfun$1(make), IndexedSeq$.MODULE$.canBuildFrom());
            Predef$.MODULE$.println(new StringBuilder().append("testIndices: ").append(indexedSeq3).toString());
            // Both checks start from a fresh initialWeightVector(true), not from initialWeightVector above.
            GradientTester$.MODULE$.testIndices(cachedBatchDiffFunction2, modelObjective.initialWeightVector(true), indexedSeq3, true, new ParserTrainer$$anonfun$6(make), GradientTester$.MODULE$.testIndices$default$6(), GradientTester$.MODULE$.testIndices$default$7(), Predef$.MODULE$.$conforms(), Predef$.MODULE$.$conforms(), DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double()), DenseVector$.MODULE$.canNorm(Field$fieldDouble$.MODULE$), DenseVector$.MODULE$.canSubD());
            Predef$.MODULE$.println("test");
            GradientTester$.MODULE$.test(cachedBatchDiffFunction2, modelObjective.initialWeightVector(true), GradientTester$.MODULE$.test$default$3(), false, GradientTester$.MODULE$.test$default$5(), GradientTester$.MODULE$.test$default$6(), new ParserTrainer$$anonfun$7(make), Predef$.MODULE$.$conforms(), Predef$.MODULE$.$conforms(), DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double()), DenseVector$.MODULE$.canNorm(Field$fieldDouble$.MODULE$), DenseVector$.MODULE$.canSubD());
        } else {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
        // --- Output name: params.name() when non-null; otherwise the model class's simple name if it
        // is non-null and accepted by the anonfun$8 predicate; otherwise the literal "DiscrimParser".
        Option apply = Option$.MODULE$.apply(params.name());
        ParserTrainer$$anonfun$8 parserTrainer$$anonfun$8 = new ParserTrainer$$anonfun$8(make);
        if (apply.isEmpty()) {
            None$ apply2 = Option$.MODULE$.apply(make.getClass().getSimpleName());
            // Inlined Option.filter: keep the value only if the predicate accepts it.
            option = (Option) ((apply2.isEmpty() || BoxesRunTime.unboxToBoolean(new ParserTrainer$$anonfun$8$$anonfun$apply$1(parserTrainer$$anonfun$8).apply(apply2.get()))) ? apply2 : None$.MODULE$);
        } else {
            option = apply;
        }
        String str = (String) (!option.isEmpty() ? option.get() : "DiscrimParser");
        // --- Optimizer selection. determinizeTraining picks scanning batches; otherwise the Adadelta
        // branch uses random batches and the other branch hands the un-batched cached function to
        // params.opt(). The printed "Adagrad" label is just the message used whenever params.opt()
        // drives the iterations.
        if (params.determinizeTraining()) {
            StochasticDiffFunction withScanningBatches = cachedBatchDiffFunction.withScanningBatches(params.opt().batchSize());
            if (params.useAdadelta()) {
                Predef$ predef$2 = Predef$.MODULE$;
                Console$.MODULE$.println("OPTIMIZATION: Adadelta");
                int maxIterations = params.opt().maxIterations();
                AdadeltaGradientDescentDVD$ adadeltaGradientDescentDVD$ = AdadeltaGradientDescentDVD$.MODULE$;
                AdadeltaGradientDescentDVD$ adadeltaGradientDescentDVD$2 = AdadeltaGradientDescentDVD$.MODULE$;
                // 0.95 and 1.0E-5 are hard-coded constructor arguments (presumably the Adadelta decay
                // rate and epsilon — TODO confirm against AdadeltaGradientDescentDVD).
                iterations = new AdadeltaGradientDescentDVD(maxIterations, 0.95d, 1.0E-5d, AdadeltaGradientDescentDVD$.MODULE$.$lessinit$greater$default$4()).iterations(withScanningBatches, initialWeightVector);
            } else {
                Predef$ predef$3 = Predef$.MODULE$;
                Console$.MODULE$.println("OPTIMIZATION: Adagrad");
                iterations = params.opt().iterations(withScanningBatches, initialWeightVector, DenseVector$.MODULE$.space_Double());
            }
        } else if (params.useAdadelta()) {
            Predef$ predef$4 = Predef$.MODULE$;
            Console$.MODULE$.println("OPTIMIZATION: Adadelta");
            int maxIterations2 = params.opt().maxIterations();
            AdadeltaGradientDescentDVD$ adadeltaGradientDescentDVD$3 = AdadeltaGradientDescentDVD$.MODULE$;
            AdadeltaGradientDescentDVD$ adadeltaGradientDescentDVD$4 = AdadeltaGradientDescentDVD$.MODULE$;
            iterations = new AdadeltaGradientDescentDVD(maxIterations2, 0.95d, 1.0E-5d, AdadeltaGradientDescentDVD$.MODULE$.$lessinit$greater$default$4()).iterations(cachedBatchDiffFunction.withRandomBatches(params.opt().batchSize()), initialWeightVector);
        } else {
            Predef$.MODULE$.println("OPTIMIZATION: Adagrad");
            iterations = params.opt().iterations(cachedBatchDiffFunction, initialWeightVector, DenseVector$.MODULE$.space_Double());
        }
        // --- Result pipeline: cap the optimizer states at params.maxIterations(), pair each with its
        // index, pass each pair to the trainParser$2 side-effect closure (which holds function1, params
        // and the model; presumably it calls evalAndCache — TODO confirm), apply two filters whose
        // bodies are not visible here, and map what remains through trainParser$5 to build the result.
        return Implicits$.MODULE$.scEnrichIterator(iterations.take(params.maxIterations()).zipWithIndex()).tee(new ParserTrainer$$anonfun$trainParser$2(function1, params, make)).withFilter(new ParserTrainer$$anonfun$trainParser$3()).withFilter(new ParserTrainer$$anonfun$trainParser$4(params)).map(new ParserTrainer$$anonfun$trainParser$5(indexedSeq, params, make, str));
    }

    /**
     * Tells whether a tree instance exceeds a length limit.
     *
     * <p>Only the words accepted by the {@code ParserTrainer$$anonfun$sentTooLong$1}
     * predicate are counted; that predicate's body is not visible in this file.
     *
     * @param treeInstance the instance whose words are counted
     * @param i the limit; the instance is too long only when the count is strictly greater
     * @return {@code true} if the number of counted words is greater than {@code i}
     */
    public boolean sentTooLong(TreeInstance<AnnotatedLabel, String> treeInstance, int i) {
        final int countedWords = treeInstance.words().count(new ParserTrainer$$anonfun$sentTooLong$1());
        return countedWords > i;
    }

    /**
     * Checks for an on-demand evaluation request signalled by a sentinel file.
     *
     * <p>The sentinel is the relative path {@code EVALUATE_NOW}. When it exists it is
     * deleted so the request is consumed, a message is logged, and {@code true} is
     * returned. When it does not exist, nothing happens and {@code false} is returned.
     *
     * <p>If the sentinel exists but cannot be deleted, the method still returns
     * {@code true} (the request was made), but now reports the failure on standard
     * error. Previously the failure was ignored silently, which hid the fact that
     * every later call would also return {@code true}.
     *
     * @return {@code true} if the sentinel file was present, {@code false} otherwise
     */
    public boolean evaluateNow() {
        File file = new File("EVALUATE_NOW");
        if (!file.exists()) {
            return false;
        }
        if (!file.delete()) {
            // Best effort: keep going, but make the stuck sentinel visible to the operator.
            System.err.println("Could not delete sentinel file " + file.getAbsolutePath()
                    + "; evaluation will be re-triggered until it is removed");
        }
        logger().info(new ParserTrainer$$anonfun$evaluateNow$1());
        return true;
    }

    /**
     * Prints the log likelihood of the given examples under the given weights.
     *
     * <p>The value is obtained by a parallel {@code aggregate} over the examples, using
     * the testing form of the inference built from {@code denseVector}. The seed,
     * per-example and combine functions are the {@code anonfun$2}, {@code anonfun$10}
     * and {@code anonfun$3} closures, whose bodies are not visible in this file.
     *
     * @param indexedSeq examples to score
     * @param model model from which the inference is derived
     * @param denseVector weight vector passed to {@code inferenceFromWeights}
     */
    public void computeLL(IndexedSeq<TreeInstance<AnnotatedLabel, String>> indexedSeq, Model<TreeInstance<AnnotatedLabel, String>> model, DenseVector<Object> denseVector) {
        Predef$.MODULE$.println("Computing final log likelihood on the whole training set...");

        // Size is read before the aggregation runs, as in the single-expression form.
        final int exampleCount = indexedSeq.size();
        final Object aggregated = indexedSeq.par().aggregate(
                new ParserTrainer$$anonfun$2(),
                new ParserTrainer$$anonfun$10(model.inferenceFromWeights(denseVector).forTesting()),
                new ParserTrainer$$anonfun$3());
        final double logLikelihood = BoxesRunTime.unboxToDouble(aggregated);

        final StringBuilder message = new StringBuilder();
        message.append("Log likelihood on ");
        message.append(BoxesRunTime.boxToInteger(exampleCount));
        message.append(" examples: ");
        message.append(BoxesRunTime.boxToDouble(logLikelihood));
        Predef$.MODULE$.println(message.toString());
    }

    /**
     * Serialization hook: always resolves to the shared {@code MODULE$} instance.
     *
     * @return the singleton stored in {@code MODULE$}
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Synthetic bridge for the erased {@code ParserPipeline} signature: casts the
     * arguments to their concrete types and forwards to the typed {@code trainParser}
     * overload.
     *
     * @param indexedSeq training tree instances (unchecked cast)
     * @param function1 parser-evaluation callback (unchecked cast)
     * @param obj parameters; must be a {@code ParserTrainer.Params}
     * @return the iterator returned by the typed overload
     */
    @Override // epic.parser.ParserPipeline
    public /* bridge */ /* synthetic */ Iterator trainParser(IndexedSeq indexedSeq, Function1 function1, Object obj) {
        return trainParser((IndexedSeq<TreeInstance<AnnotatedLabel, String>>) indexedSeq, (Function1<Parser<AnnotatedLabel, String>, ParseEval.Statistics>) function1, (ParserTrainer.Params) obj);
    }

    /**
     * Evaluates the parser for one optimizer step when the step index is a multiple
     * of {@code params.iterPerValidate()}.
     *
     * <p>On such a step it logs one message, extracts a parser from the model using the
     * state's current weights, runs {@code function1} on that parser, and logs a second
     * message built from the resulting statistics. On any other step it does nothing.
     *
     * @param tuple2 pair of (optimizer state, iteration index)
     * @param function1 evaluation callback applied to the extracted parser
     * @param params supplies the validation interval
     * @param model model from which the parser is extracted; must be a {@code ParserExtractable}
     * @throws MatchError if {@code tuple2} is {@code null}
     */
    public final void epic$parser$models$ParserTrainer$$evalAndCache$1(Tuple2 tuple2, Function1 function1, ParserTrainer.Params params, Model model) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        // Read both components before casting, matching the destructuring order.
        final Object first = tuple2._1();
        final int iteration = tuple2._2$mcI$sp();
        final FirstOrderMinimizer.State state = (FirstOrderMinimizer.State) first;
        final DenseVector<Object> weights = (DenseVector) state.x();

        if (iteration % params.iterPerValidate() != 0) {
            return;
        }

        logger().info(new ParserTrainer$$anonfun$epic$parser$models$ParserTrainer$$evalAndCache$1$1());
        final LazyLogger log = logger();
        final Object extracted = ((ParserExtractable) model).extractParser(weights, Debinarizer$AnnotatedLabelDebinarizer$.MODULE$);
        final ParseEval.Statistics stats = (ParseEval.Statistics) function1.apply(extracted);
        log.info(new ParserTrainer$$anonfun$epic$parser$models$ParserTrainer$$evalAndCache$1$2(stats));
    }

    /**
     * Private singleton constructor, invoked only from the static initializer.
     *
     * <p>Publishes the instance through {@code MODULE$} first, then runs the
     * initializers of the two mixed-in traits, then builds the parameter manifest.
     */
    private ParserTrainer$() {
        // Publish the instance before running the trait initializers below, which receive `this`.
        MODULE$ = this;
        SerializableLogging.class.$init$(this);
        ParserPipeline.Cclass.$init$(this);
        this.paramManifest = Predef$.MODULE$.manifest(ManifestFactory$.MODULE$.classType(ParserTrainer.Params.class));
    }
}
