package com.gengoai.hermes.ml.trainer;

import com.gengoai.LogUtils;
import com.gengoai.Stopwatch;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.FitParameters;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.ml.HStringDataSetGenerator;
import com.gengoai.hermes.ml.SequenceTagger;
import java.util.logging.Logger;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/ml/trainer/SequenceTaggerTrainer.class */
/**
 * Base class for trainers that produce a {@link SequenceTagger} from a {@link DocumentCollection}.
 *
 * <p>Subclasses supply three hooks:
 * <ul>
 *   <li>the untrained {@link Model};</li>
 *   <li>the {@link HStringDataSetGenerator} used to turn documents into examples;</li>
 *   <li>the factory that wraps a trained model into the concrete tagger type.</li>
 * </ul>
 *
 * @param <T> the concrete sequence tagger type produced by {@link #fit}
 */
public abstract class SequenceTaggerTrainer<T extends SequenceTagger> {
    private static final Logger log = Logger.getLogger(SequenceTaggerTrainer.class.getName());

    /**
     * Converts the given documents into a cached {@link DataSet}.
     *
     * <p>Conversion uses the generator returned by {@link #getExampleGenerator()}. The elapsed
     * time and the number of resulting examples are logged at info level.
     *
     * @param documentCollection the documents to convert
     * @return the cached data set built from {@code documentCollection}
     * @throws NullPointerException if {@code documentCollection} is {@code null}
     */
    public DataSet createDataset(@NonNull DocumentCollection documentCollection) {
        // Explicit guard; Lombok does not add a second check when one is already present.
        if (documentCollection == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }

        final Stopwatch timer = Stopwatch.createStarted();
        final HStringDataSetGenerator generator = getExampleGenerator();
        final DataSet dataset = documentCollection.asDataSet(generator).cache();
        timer.stop();

        final long exampleCount = dataset.size();
        final Object[] logArgs = {timer, exampleCount};
        LogUtils.logInfo(log, "Took {0} to create dataset with {1} examples.", logArgs);
        return dataset;
    }

    /**
     * Trains a new sequence labeler on the given documents and wraps it as a tagger.
     *
     * <p>The model comes from {@link #createSequenceLabeler}. It is estimated on the data set
     * produced by {@link #createDataset}. It is then passed to {@link #createTagger}, together
     * with a generator obtained from a fresh {@link #getExampleGenerator()} call.
     *
     * @param documentCollection the training documents
     * @param fitParameters      parameters forwarded to {@link #createSequenceLabeler}
     * @return the tagger wrapping the trained model
     */
    public T fit(DocumentCollection documentCollection, FitParameters<?> fitParameters) {
        final Model labeler = createSequenceLabeler(fitParameters);
        final DataSet trainingData = createDataset(documentCollection);
        labeler.estimate(trainingData);
        final HStringDataSetGenerator taggerGenerator = getExampleGenerator();
        return createTagger(labeler, taggerGenerator);
    }

    /**
     * Creates the (untrained) model that {@link #fit} will estimate.
     *
     * @param fitParameters the parameters passed to {@link #fit}
     * @return a new model instance
     */
    protected abstract Model createSequenceLabeler(FitParameters<?> fitParameters);

    /**
     * Wraps a trained model and an example generator into the concrete tagger type.
     *
     * @param model                  the model trained by {@link #fit}
     * @param hStringDataSetGenerator the generator returned by {@link #getExampleGenerator()}
     * @return the tagger
     */
    protected abstract T createTagger(Model model, HStringDataSetGenerator hStringDataSetGenerator);

    /**
     * Returns the generator used to convert documents into examples.
     *
     * <p>It is used by {@link #createDataset}, and a further instance is handed to
     * {@link #createTagger} in {@link #fit}.
     *
     * @return the example generator
     */
    protected abstract HStringDataSetGenerator getExampleGenerator();

    /**
     * Returns the fit parameters associated with this trainer.
     *
     * <p>NOTE(review): not used within this class. Presumably these are the default parameters
     * that callers pass to {@link #fit}; confirm against implementations.
     *
     * @return the fit parameters
     */
    public abstract FitParameters<?> getFitParameters();
}
