package com.gengoai.hermes.extraction.keyword;

import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.hermes.HString;
import com.gengoai.hermes.Types;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.extraction.Extraction;
import com.gengoai.hermes.extraction.FeaturizingExtractor;
import com.gengoai.hermes.extraction.lyre.LyreDSL;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/extraction/keyword/TFIDFKeywordExtractor.class */
/**
 * Keyword extractor that scores the terms of a text by TF-IDF.
 *
 * <p>Term frequencies are normalized against the most frequent term of the text being scored
 * ({@code 0.5 + 0.5 * tf / maxTf}) and multiplied by an inverse document frequency
 * ({@code log(N / df)}) learned from a {@link DocumentCollection} via
 * {@link #fit(DocumentCollection)}. {@code fit} must be called before {@link #extract(HString)}.
 */
public class TFIDFKeywordExtractor implements KeywordExtractor {
    private static final long serialVersionUID = 1;
    private final FeaturizingExtractor termExtractor;
    // Remains null until fit(...) has been called.
    private Counter<String> inverseDocumentFrequencies;

    /**
     * Creates an extractor whose terms are the lower-cased token annotations accepted by
     * {@code LyreDSL.isContentWord}.
     */
    public TFIDFKeywordExtractor() {
        this(LyreDSL.lower(LyreDSL.filter(LyreDSL.annotation(Types.TOKEN), LyreDSL.isContentWord)));
    }

    /**
     * Creates an extractor that uses the given extractor to produce candidate terms.
     *
     * @param termExtractor the extractor used to generate the terms to score
     * @throws NullPointerException if {@code termExtractor} is null
     */
    public TFIDFKeywordExtractor(@NonNull FeaturizingExtractor termExtractor) {
        if (termExtractor == null) {
            throw new NullPointerException("termExtractor is marked non-null but is null");
        }
        this.termExtractor = termExtractor;
    }

    /**
     * Scores every term extracted from the given text by its TF-IDF weight.
     *
     * <p>Terms absent from the fitted inverse document frequencies receive whatever value the
     * counter returns for a missing key (presumably 0 -- confirm against {@code Counter#get}).
     *
     * @param hString the text to extract keywords from
     * @return an extraction built from the term-to-score counter
     * @throws IllegalStateException if {@link #fit(DocumentCollection)} has not been called
     */
    @Override // com.gengoai.hermes.extraction.Extractor
    public Extraction extract(HString hString) {
        // Fail fast with an actionable message instead of an opaque NullPointerException raised
        // from inside the scoring lambda below.
        final Counter<String> idf = this.inverseDocumentFrequencies;
        if (idf == null) {
            throw new IllegalStateException(
                    "TFIDFKeywordExtractor has not been fit; call fit(DocumentCollection) before extract(HString)");
        }
        final Counter<String> termFrequencies = this.termExtractor.extract(hString).count();
        final Counter<String> scores = Counters.newCounter();
        // Read once up front; when no terms were extracted the loop body never runs, so this
        // value is never used as a divisor in that case.
        final double maxTermFrequency = termFrequencies.maximumCount();
        termFrequencies.forEach((term, frequency) -> {
            double normalizedTf = 0.5d + ((0.5d * frequency.doubleValue()) / maxTermFrequency);
            scores.set(term, normalizedTf * idf.get(term));
        });
        return Extraction.fromCounter(scores);
    }

    /**
     * Learns inverse document frequencies from the given collection, replacing any previously
     * fitted values.
     *
     * <p>Each per-term value returned by {@code documentCount} (presumably the number of documents
     * containing the term -- confirm against {@code DocumentCollection}) is replaced with
     * {@code log(size / value)}, where {@code size} is {@code documentCollection.size()}.
     *
     * @param documentCollection the collection to compute document frequencies over
     */
    @Override // com.gengoai.hermes.extraction.keyword.KeywordExtractor
    public void fit(DocumentCollection documentCollection) {
        // Held as a double so the division inside the logarithm is floating-point.
        final double numberOfDocuments = documentCollection.size();
        this.inverseDocumentFrequencies = documentCollection.documentCount(this.termExtractor)
                .adjustValuesSelf(documentFrequency -> Math.log(numberOfDocuments / documentFrequency));
    }
}
