package com.gengoai.hermes.workflow.actions;

import com.gengoai.apollo.math.statistics.measure.Similarity;
import com.gengoai.apollo.ml.model.embedding.WordEmbedding;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.config.Config;
import com.gengoai.hermes.AnnotationType;
import com.gengoai.hermes.Types;
import com.gengoai.hermes.corpus.DocumentCollection;
import com.gengoai.hermes.extraction.TermExtractor;
import com.gengoai.hermes.extraction.lyre.LyreDSL;
import com.gengoai.hermes.lexicon.SimpleWordList;
import com.gengoai.hermes.lexicon.TrieWordList;
import com.gengoai.hermes.morphology.StandardTokenizer;
import com.gengoai.hermes.workflow.Action;
import com.gengoai.hermes.workflow.Context;
import com.gengoai.io.Resources;
import com.gengoai.io.resource.Resource;
import com.gengoai.tuple.Tuples;
import java.io.IOException;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.lang.invoke.SerializedLambda;
import java.util.Map;
import lombok.NonNull;

/**
 * Workflow {@link Action} that proposes spelling corrections for the tokens of a document collection.
 *
 * <p>A token lemma is treated as a potential misspelling when it:
 * <ul>
 *   <li>is at least {@value #MIN_TERM_LENGTH} characters long;</li>
 *   <li>is in the spelling embedding's alphabet;</li>
 *   <li>satisfies the Lyre {@code isLetter}/{@code isWhitespace} predicates;</li>
 *   <li>is NOT contained in the dictionary.</li>
 * </ul>
 *
 * <p>For each such word, the dictionary's suggestions within {@code maxCost} are narrowed. Only the
 * lowest-cost suggestions are kept, and only those that are also in the embedding alphabet and
 * have a document count of at least {@value #MIN_CANDIDATE_DOCUMENT_COUNT}. The remaining suggestion
 * with the highest positive cosine similarity to the word's embedding is chosen as the correction.
 * Every token whose lemma received a correction gets a {@link Types#SPELLING_CORRECTION} attribute.
 */
public class SpellChecker implements Action, Serializable {
    private static final long serialVersionUID = 1L;

    /** Maximum suggestion cost used by the single-argument constructor. */
    private static final int DEFAULT_MAX_COST = 2;

    /** Minimum lemma length (as the double the Lyre {@code gte} call is given) for a potential misspelling. */
    private static final double MIN_TERM_LENGTH = 3.0;

    /** Minimum document count a suggested word must have in the collection to be used as a correction. */
    private static final double MIN_CANDIDATE_DOCUMENT_COUNT = 10.0;

    private final TrieWordList dictionary;
    private final int maxCost;
    private final WordEmbedding spellingEmbedding;

    /**
     * Creates a spell checker with the default configuration.
     *
     * <p>The dictionary is read from the resource configured under {@code SpellcheckerModule.dictionary},
     * with {@code Resources.fromString()} passed as the fallback value. The maximum suggestion cost is
     * {@value #DEFAULT_MAX_COST}.
     *
     * @param spellingEmbedding embedding used to filter and rank suggestions
     * @throws NullPointerException if {@code spellingEmbedding} is null
     * @throws UncheckedIOException if the dictionary resource cannot be read
     */
    public SpellChecker(@NonNull WordEmbedding spellingEmbedding) {
        // Null checking of spellingEmbedding is done by the delegated constructor.
        this(spellingEmbedding,
             Config.get("SpellcheckerModule.dictionary").asResource(Resources.fromString()),
             DEFAULT_MAX_COST);
    }

    /**
     * Creates a spell checker.
     *
     * @param spellingEmbedding embedding used to filter and rank suggestions
     * @param dictionary        resource holding the word list, read via {@link TrieWordList#read}
     * @param maxCost           maximum cost passed to {@link TrieWordList#suggest}
     *                          (presumably an edit distance; confirm against TrieWordList)
     * @throws NullPointerException if {@code spellingEmbedding} or {@code dictionary} is null
     * @throws UncheckedIOException if the dictionary resource cannot be read
     */
    public SpellChecker(@NonNull WordEmbedding spellingEmbedding, @NonNull Resource dictionary, int maxCost) {
        if (spellingEmbedding == null) {
            throw new NullPointerException("spellingEmbedding is marked non-null but is null");
        }
        if (dictionary == null) {
            throw new NullPointerException("dictionary is marked non-null but is null");
        }
        this.spellingEmbedding = spellingEmbedding;
        this.maxCost = maxCost;
        try {
            this.dictionary = TrieWordList.read(dictionary);
        } catch (IOException e) {
            throw new UncheckedIOException("Unable to read spelling dictionary from " + dictionary, e);
        }
    }

    /**
     * Computes spelling corrections for the collection and attaches them to matching tokens.
     *
     * @param corpus  the documents to spell-check
     * @param context the workflow context (not used by this action beyond the null check)
     * @return the collection returned by {@code corpus.update("SpellChecker", ...)}
     * @throws NullPointerException if {@code corpus} or {@code context} is null
     * @throws Exception            propagated from the underlying collection operations
     */
    @Override
    public DocumentCollection process(@NonNull DocumentCollection corpus, @NonNull Context context) throws Exception {
        if (corpus == null) {
            throw new NullPointerException("corpus is marked non-null but is null");
        }
        if (context == null) {
            throw new NullPointerException("context is marked non-null but is null");
        }

        // Document counts of token lemmas that are long enough, known to the embedding,
        // and letter/whitespace-only. These are the potential misspellings, and the counts
        // also serve as the frequency filter for suggestions.
        Counter<String> documentCount = corpus.documentCount(
                TermExtractor.builder()
                             .annotations(Types.TOKEN)
                             .filter(LyreDSL.and(
                                     LyreDSL.and(
                                             LyreDSL.gte(LyreDSL.len(LyreDSL.$_), MIN_TERM_LENGTH),
                                             LyreDSL.in(LyreDSL.$_,
                                                        LyreDSL.wordList(new SimpleWordList(spellingEmbedding.getAlphabet())))),
                                     LyreDSL.or(LyreDSL.isLetter, LyreDSL.isWhitespace)))
                             .toLemma()
                             .build());

        // word -> best correction. The value may be null when no suggestion qualified.
        Map<String, String> corrections = corpus.getStreamingContext()
                                                .stream(documentCount.items())
                                                .filter(word -> !dictionary.contains(word))
                                                .mapToPair(word -> Tuples.$(word, suggestCorrection(word, documentCount)))
                                                .collectAsMap();

        return corpus.update("SpellChecker", document -> document.tokenStream().forEach(token -> {
            String correction = corrections.get(token.getLemma());
            // Skip words for which no suggestion survived filtering, instead of writing an empty correction.
            if (correction != null) {
                token.put(Types.SPELLING_CORRECTION, correction);
            }
        }));
    }

    /**
     * Chooses the best dictionary suggestion for a word judged to be misspelled.
     *
     * <p>Only suggestions are kept that:
     * <ul>
     *   <li>have the minimum cost among all suggestions;</li>
     *   <li>are in the embedding alphabet;</li>
     *   <li>have a document count of at least {@value #MIN_CANDIDATE_DOCUMENT_COUNT}.</li>
     * </ul>
     * These are scored by cosine similarity with the word's embedding, and non-positive scores are
     * discarded.
     *
     * @param word          the word not found in the dictionary
     * @param documentCount document counts computed over the collection
     * @return the highest-scoring suggestion as reported by {@link Counter#max()}; when nothing
     *         qualifies the counter is empty (presumably yielding null; callers must handle that)
     */
    private String suggestCorrection(String word, Counter<String> documentCount) {
        Map<String, Integer> suggestions = dictionary.suggest(word, maxCost);
        // The fallback value is irrelevant: it is used only when there are no suggestions to compare.
        int minCost = suggestions.values().stream().mapToInt(Integer::intValue).min().orElse(maxCost);
        Counter<String> candidates = Counters.newCounter();
        suggestions.forEach((candidate, cost) -> {
            if (cost <= minCost
                    && spellingEmbedding.getAlphabet().contains(candidate)
                    && documentCount.get(candidate) >= MIN_CANDIDATE_DOCUMENT_COUNT) {
                double similarity = Similarity.Cosine.calculate(spellingEmbedding.embed(candidate),
                                                                spellingEmbedding.embed(word));
                if (similarity > 0.0) {
                    candidates.increment(candidate, similarity);
                }
            }
        });
        return candidates.max();
    }
}
