package com.gengoai.hermes.ml.trainer;

import com.gengoai.LogUtils;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.feature.FeatureExtractor;
import com.gengoai.apollo.ml.feature.Featurizer;
import com.gengoai.apollo.ml.model.FitParameters;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.apollo.ml.model.sequence.Crf;
import com.gengoai.conversion.Cast;
import com.gengoai.hermes.HString;
import com.gengoai.hermes.Types;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.ml.HStringDataSetGenerator;
import com.gengoai.hermes.ml.feature.BasicCategoryFeature;
import com.gengoai.hermes.ml.feature.Features;
import java.util.logging.Logger;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/ml/trainer/EntityTrainer.class */
/**
 * An {@link IOBTaggerTrainer} constructed with {@code Types.ML_ENTITY} and {@code Types.ENTITY}
 * that uses a {@link Crf} as its sequence labeler.
 *
 * <p>Token features are produced by a chained {@link Featurizer} extended with a fixed set of
 * context templates (see {@link #getFeaturizer()}).
 */
public class EntityTrainer extends IOBTaggerTrainer {
    private static final Logger log = Logger.getLogger(EntityTrainer.class.getName());

    /**
     * Creates a trainer bound to {@code Types.ML_ENTITY} and {@code Types.ENTITY}.
     */
    public EntityTrainer() {
        super(Types.ML_ENTITY, Types.ENTITY);
    }

    /**
     * Registers the token-sequence input named {@code "input"} on the given builder, using the
     * featurizer returned by {@link #getFeaturizer()}. The featurizer is logged at info level
     * before it is registered.
     *
     * @param builder the data set generator builder to configure
     * @throws NullPointerException if {@code builder} is {@code null}
     */
    @Override // com.gengoai.hermes.ml.trainer.IOBTaggerTrainer
    protected void addInputs(@NonNull HStringDataSetGenerator.Builder builder) {
        if (builder == null) {
            throw new NullPointerException("builder is marked non-null but is null");
        }
        FeatureExtractor<HString> featurizer = getFeaturizer();
        LogUtils.logInfo(log, "\n{0}", new Object[]{featurizer});
        builder.tokenSequence("input", featurizer);
    }

    /**
     * Annotates the given collection with {@code Types.CATEGORY} and converts the result into a
     * {@link DataSet} using {@code getExampleGenerator()}.
     *
     * @param documentCollection the documents to build the data set from
     * @return the data set produced from the annotated collection
     * @throws NullPointerException if {@code documentCollection} is {@code null}
     */
    @Override // com.gengoai.hermes.ml.trainer.SequenceTaggerTrainer
    public DataSet createDataset(@NonNull DocumentCollection documentCollection) {
        if (documentCollection == null) {
            // The message names the actual parameter so the failure is traceable from the stack trace.
            throw new NullPointerException("documentCollection is marked non-null but is null");
        }
        return documentCollection.annotate(Types.CATEGORY).asDataSet(getExampleGenerator());
    }

    /**
     * Creates the {@link Crf} model used as the sequence labeler.
     *
     * @param fitParameters the fit parameters; cast to {@code Crf.Parameters} before use
     * @return a new {@link Crf} built from the given parameters
     */
    @Override // com.gengoai.hermes.ml.trainer.SequenceTaggerTrainer
    protected Model createSequenceLabeler(FitParameters<?> fitParameters) {
        return new Crf((Crf.Parameters) Cast.as(fitParameters));
    }

    /**
     * Builds the token feature extractor: a chain of the base features listed below, extended
     * with a fixed list of context templates.
     *
     * <p>The templates and their order are reproduced exactly as they were originally declared.
     *
     * @return the feature extractor applied to each token
     */
    private FeatureExtractor<HString> getFeaturizer() {
        Featurizer[] baseFeatures = new Featurizer[]{
            Features.LowerCaseWord,
            Features.PartOfSpeech,
            Features.WordShape,
            Features.WordClass,
            Features.IsBeginOfSentence,
            Features.IsEndOfSentence,
            Features.IsDigit,
            Features.IsPercent,
            Features.IsCardinalNumber,
            Features.IsOrdinalNumber,
            Features.IsTitleCase,
            Features.IsAllCaps,
            Features.HasCapital,
            Features.IsLanguageName,
            Features.IsAlphaNumeric,
            Features.IsCurrency,
            Features.IsPunctuation,
            Features.PhraseChunkBIO,
            new BasicCategoryFeature()
        };

        // NOTE(review): the meaning of the leading '~' and of the [n] offsets is defined by the
        // context-template parser, which is not visible here - confirm before editing templates.
        String[] contextPatterns = new String[]{
            // LowerWord context
            "LowerWord[-1]",
            "~LowerWord[-2]|LowerWord[-1]",
            "LowerWord[+1]",
            "~LowerWord[+1]|LowerWord[+2]",
            "~LowerWord[-1]|LowerWord[+1]",
            "~LowerWord[-2]|LowerWord[-1]|LowerWord[+1]|LowerWord[+2]",
            "LowerWord[-1]|LowerWord[0]",
            "~LowerWord[-2]|LowerWord[-1]|LowerWord[0]",
            "LowerWord[0]|LowerWord[+1]",
            "~LowerWord[0]|LowerWord[+1]|LowerWord[+2]",
            "~LowerWord[-1]|LowerWord[0]|LowerWord[+1]",
            "~LowerWord[-2]|LowerWord[-1]|LowerWord[0]|LowerWord[+1]|LowerWord[+2]",
            // POS context
            "POS[-1]",
            "~POS[-2]|POS[-1]",
            "POS[+1]",
            "~POS[+1]|POS[+2]",
            "~POS[-1]|POS[+1]",
            "~POS[-2]|POS[-1]|POS[+1]|POS[+2]",
            "POS[-1]|LowerWord[0]",
            "~POS[-2]|POS[-1]|LowerWord[0]",
            "LowerWord[0]|POS[+1]",
            "~LowerWord[0]|POS[+1]|POS[+2]",
            // WordShape context
            "WordShape[-1]",
            "~WordShape[-2]|WordShape[-1]",
            "WordShape[+1]",
            "~WordShape[+1]|WordShape[+2]",
            "~WordShape[-1]|WordShape[+1]",
            "~WordShape[-2]|WordShape[-1]|WordShape[+1]|WordShape[+2]",
            "WordShape[-1]|LowerWord[0]",
            "~WordShape[-2]|WordShape[-1]|LowerWord[0]",
            "LowerWord[0]|WordShape[+1]",
            "~LowerWord[0]|WordShape[+1]|WordShape[+2]",
            // WordClass context
            "WordClass[-1]",
            "~WordClass[-2]|WordClass[-1]",
            "WordClass[+1]",
            "~WordClass[+1]|WordClass[+2]",
            "~WordClass[-1]|WordClass[+1]",
            "~WordClass[-2]|WordClass[-1]|WordClass[+1]|WordClass[+2]",
            "WordClass[-1]|LowerWord[0]",
            "~WordClass[-2]|WordClass[-1]|LowerWord[0]",
            "LowerWord[0]|WordClass[+1]",
            "~LowerWord[0]|WordClass[+1]|WordClass[+2]",
            // Sentence-boundary context
            "IsBeginOfSentence[-1]|LowerWord[0]",
            "LowerWord[0]|IsEndOfSentence[+1]",
            // Numeric and capitalization predicates.
            // NOTE(review): these templates reference "IsOrdinal"/"IsCardinal" while the features
            // above are Features.IsOrdinalNumber/IsCardinalNumber - confirm the prefixes match.
            "IsDigit[-1]|IsDigit[0]",
            "IsPercent[-1]|IsPercent[0]",
            "IsDigit[-1]|IsOrdinal[0]",
            "IsOrdinal[-1]|IsOrdinal[0]",
            "IsCardinal[-1]|IsCardinal[0]",
            "IsOrdinal[-1]|POS[0]",
            "IsCardinal[-1]|POS[0]",
            // NOTE(review): "IsDigit[-1]|LowerWord[0]", "IsOrdinal[-1]|LowerWord[0]" and
            // "IsCardinal[-1]|LowerWord[0]" each appear twice below. The duplicates are kept
            // because removing them may change the generated features - TODO confirm and dedupe.
            "IsDigit[-1]|LowerWord[0]",
            "IsPercent[-1]|LowerWord[0]",
            "IsDigit[-1]|LowerWord[0]",
            "IsOrdinal[-1]|LowerWord[0]",
            "IsCardinal[-1]|LowerWord[0]",
            "IsOrdinal[-1]|LowerWord[0]",
            "IsCardinal[-1]|LowerWord[0]",
            "IsTitleCase[-1]|LowerWord[0]",
            "IsAllCaps[-1]|LowerWord[0]",
            "HasCapital[-1]|LowerWord[0]"
        };

        return Featurizer.chain(baseFeatures).withContext(contextPatterns);
    }

    /**
     * Returns the default {@code Crf.Parameters} for this trainer, with {@code minFeatureFreq}
     * set to 5 and {@code maxIterations} set to 500.
     *
     * @return the default fit parameters
     */
    @Override // com.gengoai.hermes.ml.trainer.SequenceTaggerTrainer
    public FitParameters<?> getFitParameters() {
        return new Crf.Parameters().update(crfParameters -> {
            crfParameters.minFeatureFreq.set(5);
            crfParameters.maxIterations.set(500);
        });
    }
}
