package com.gengoai.hermes.workflow.actions;

import com.gengoai.Validation;
import com.gengoai.collection.counter.Counter;
import com.gengoai.conversion.Cast;
import com.gengoai.hermes.Types;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.extraction.keyword.KeywordExtractor;
import com.gengoai.hermes.extraction.keyword.TermKeywordExtractor;
import com.gengoai.hermes.morphology.StandardTokenizer;
import com.gengoai.hermes.workflow.Action;
import com.gengoai.hermes.workflow.Context;
import com.gengoai.stream.MCounterAccumulator;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/workflow/actions/KeywordExtraction.class */
/**
 * Workflow {@link Action} that runs a {@link KeywordExtractor} over every document in a
 * {@link DocumentCollection} and stores the top {@code N} extracted items on each document under
 * {@link Types#KEYWORDS}.
 *
 * <p>When {@code keepGlobalCounts} is enabled, every keyword stored on a document also increments a
 * collection-wide counter by one; after all documents are updated the counter's value is placed in
 * the workflow {@link Context} under the name of {@link Types#KEYWORDS}.
 *
 * <p>Defaults (no-arg constructor): {@code N = Integer.MAX_VALUE}, a new
 * {@link TermKeywordExtractor}, and no global counts.
 */
public class KeywordExtraction implements Action {
    private static final long serialVersionUID = 1;
    private int N;
    private KeywordExtractor extractor;
    private boolean keepGlobalCounts;

    /**
     * Looks up the value stored in the context under the name of {@link Types#KEYWORDS} and casts
     * it to a keyword counter.
     *
     * @param context the workflow context to read from
     * @return whatever {@code context.get} yields for that key, cast to {@code Counter<String>}
     * @throws NullPointerException if {@code context} is null
     */
    public static Counter<String> getKeywords(@NonNull Context context) {
        if (context == null) {
            throw new NullPointerException("context is marked non-null but is null");
        }
        return Cast.as(context.get(Types.KEYWORDS.name()));
    }

    /**
     * Creates an action with an explicit extractor, keyword limit, and global-count setting.
     *
     * @param keywordExtractor the extractor used to find keywords in each document
     * @param i                the maximum number of keywords kept per document; must be positive
     * @param z                whether to accumulate collection-wide keyword counts
     * @throws NullPointerException if {@code keywordExtractor} is null
     */
    public KeywordExtraction(@NonNull KeywordExtractor keywordExtractor, int i, boolean z) {
        if (keywordExtractor == null) {
            throw new NullPointerException("extractor is marked non-null but is null");
        }
        // Validate before assigning so a rejected argument never leaves partially-set state.
        Validation.checkArgument(i > 0, "N must be >0");
        this.extractor = keywordExtractor;
        this.N = i;
        this.keepGlobalCounts = z;
    }

    /**
     * Creates an action with the defaults: no keyword limit ({@code Integer.MAX_VALUE}), a new
     * {@link TermKeywordExtractor}, and global counts disabled.
     */
    public KeywordExtraction() {
        this.N = Integer.MAX_VALUE;
        this.extractor = new TermKeywordExtractor();
        this.keepGlobalCounts = false;
    }

    /**
     * Fits the extractor on the collection, then updates every document with its top {@code N}
     * keywords.
     *
     * @param documentCollection the collection whose documents are annotated with keywords
     * @param context            the workflow context; receives the global counts when enabled
     * @return the same {@code documentCollection} instance that was passed in
     * @throws Exception propagated from the extractor or the collection update
     */
    @Override // com.gengoai.hermes.workflow.Action
    public DocumentCollection process(DocumentCollection documentCollection, Context context) throws Exception {
        this.extractor.fit(documentCollection);

        // Null means "global counts disabled". Every later decision keys off this reference rather
        // than re-reading the mutable keepGlobalCounts field, so a setter call made while the
        // update is running can never cause a null accumulator to be dereferenced.
        final MCounterAccumulator<String> globalCounts = this.keepGlobalCounts
                ? documentCollection.getStreamingContext().counterAccumulator()
                : null;

        documentCollection.update("KeywordExtraction", document -> {
            ArrayList<String> keywords =
                    new ArrayList<>(this.extractor.extract(document).count().topN(this.N).items());
            document.put(Types.KEYWORDS, keywords);
            if (globalCounts != null) {
                // Each keyword kept for this document adds 1.0 to its collection-wide count.
                keywords.forEach(keyword -> globalCounts.increment(keyword, 1.0d));
            }
        });

        if (globalCounts != null) {
            context.property(Types.KEYWORDS.name(), globalCounts.value());
        }
        return documentCollection;
    }

    /**
     * @return the maximum number of keywords kept per document
     */
    public int getN() {
        return this.N;
    }

    /**
     * @return the extractor used to find keywords
     */
    public KeywordExtractor getExtractor() {
        return this.extractor;
    }

    /**
     * @return {@code true} if collection-wide keyword counts are accumulated during processing
     */
    public boolean isKeepGlobalCounts() {
        return this.keepGlobalCounts;
    }

    /**
     * Sets the maximum number of keywords kept per document. Unlike the three-argument
     * constructor, this setter performs no range check.
     *
     * @param i the new limit
     */
    public void setN(int i) {
        this.N = i;
    }

    /**
     * Sets the extractor used to find keywords. Unlike the three-argument constructor, this setter
     * performs no null check.
     *
     * @param keywordExtractor the new extractor
     */
    public void setExtractor(KeywordExtractor keywordExtractor) {
        this.extractor = keywordExtractor;
    }

    /**
     * Enables or disables accumulation of collection-wide keyword counts.
     *
     * @param z {@code true} to accumulate global counts
     */
    public void setKeepGlobalCounts(boolean z) {
        this.keepGlobalCounts = z;
    }

    /**
     * Two actions are equal when they have the same {@code N}, equal extractors (both null counts
     * as equal), and the same {@code keepGlobalCounts} setting.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof KeywordExtraction)) {
            return false;
        }
        KeywordExtraction other = (KeywordExtraction) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        if (getN() != other.getN() || isKeepGlobalCounts() != other.isKeepGlobalCounts()) {
            return false;
        }
        KeywordExtractor mine = getExtractor();
        KeywordExtractor theirs = other.getExtractor();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Hook that lets subclasses which add state veto equality with instances of this class,
     * keeping {@link #equals(Object)} symmetric.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@code KeywordExtraction}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof KeywordExtraction;
    }

    /**
     * Hash over {@code N}, the extractor, and {@code keepGlobalCounts}, consistent with
     * {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = (result * prime) + getN();
        KeywordExtractor current = getExtractor();
        result = (result * prime) + (current == null ? 43 : current.hashCode());
        result = (result * prime) + (isKeepGlobalCounts() ? 79 : 97);
        return result;
    }

    @Override
    public String toString() {
        return "KeywordExtraction(N=" + getN() + ", extractor=" + getExtractor() + ", keepGlobalCounts=" + isKeepGlobalCounts() + ")";
    }
}
