package com.gengoai.hermes.lexicon.generation;

import com.gengoai.Tag;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.statistics.measure.Similarity;
import com.gengoai.apollo.ml.model.embedding.VSQuery;
import com.gengoai.apollo.ml.model.embedding.WordEmbedding;
import com.gengoai.collection.counter.MultiCounter;
import com.gengoai.collection.counter.MultiCounters;
import com.gengoai.collection.multimap.HashSetMultimap;
import com.gengoai.collection.multimap.Multimap;
import com.gengoai.collection.multimap.SetMultimap;
import com.gengoai.string.Strings;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/hermes/lexicon/generation/DistributionalLexiconGenerator.class */
/**
 * Lexicon generator that expands per-tag seed terms using a word embedding.
 *
 * <p>For every tag with positive seeds, the vectors of the seeds known to the embedding are
 * averaged and the embedding is queried with that average. Each returned candidate is scored as
 * its query weight minus its cosine similarity to the summed negative-seed vector of the same tag
 * (when that vector is non-zero). A candidate is assigned to the single tag for which it scores
 * highest among the scores that reach the threshold, and at most {@code maximumTermCount}
 * generated terms are kept per tag.
 *
 * @param <T> the tag type the generated lexicon entries are associated with
 */
public class DistributionalLexiconGenerator<T extends Tag> implements LexiconGenerator<T> {
    private final SetMultimap<T, String> negativeSeedTerms = new HashSetMultimap<>();
    private final SetMultimap<T, String> seedTerms = new HashSetMultimap<>();
    private final WordEmbedding wordEmbeddings;
    private int maximumTermCount = 100;
    private double threshold = 0.4d;

    /**
     * Creates a generator with no seed terms, a maximum term count of 100 and a threshold of 0.4.
     *
     * @param wordEmbeddings the embedding used to look up and query term vectors
     * @throws NullPointerException if {@code wordEmbeddings} is null
     */
    public DistributionalLexiconGenerator(@NonNull WordEmbedding wordEmbeddings) {
        if (wordEmbeddings == null) {
            throw new NullPointerException("wordEmbeddings is marked non-null but is null");
        }
        this.wordEmbeddings = wordEmbeddings;
    }

    /**
     * Creates a generator pre-populated with positive seed terms.
     *
     * @param wordEmbeddings the embedding used to look up and query term vectors
     * @param seedTerms      positive seed terms per tag; the entries are copied into this generator
     * @throws NullPointerException if either argument is null
     */
    public DistributionalLexiconGenerator(@NonNull WordEmbedding wordEmbeddings, @NonNull Multimap<T, String> seedTerms) {
        this(wordEmbeddings);
        if (seedTerms == null) {
            throw new NullPointerException("seedTerms is marked non-null but is null");
        }
        this.seedTerms.putAll(seedTerms);
    }

    /**
     * Creates a generator pre-populated with positive seed terms and a custom threshold.
     *
     * @param wordEmbeddings the embedding used to look up and query term vectors
     * @param seedTerms      positive seed terms per tag; the entries are copied into this generator
     * @param threshold      minimum score a candidate needs in order to be assigned to a tag
     * @throws NullPointerException if {@code wordEmbeddings} or {@code seedTerms} is null
     */
    public DistributionalLexiconGenerator(@NonNull WordEmbedding wordEmbeddings, @NonNull Multimap<T, String> seedTerms, double threshold) {
        this(wordEmbeddings, seedTerms);
        this.threshold = threshold;
    }

    /**
     * Registers negative seed phrases for a tag. Null or blank phrases are ignored.
     *
     * @param tag     the tag the phrases count against
     * @param phrases the phrases to register
     * @throws NullPointerException if {@code tag} or the {@code phrases} array is null
     */
    public void addNegativeSeeds(@NonNull T tag, @NonNull String... phrases) {
        if (tag == null) {
            throw new NullPointerException("tag is marked non-null but is null");
        }
        if (phrases == null) {
            throw new NullPointerException("phrases is marked non-null but is null");
        }
        for (String phrase : phrases) {
            if (Strings.isNotNullOrBlank(phrase)) {
                this.negativeSeedTerms.put(tag, phrase);
            }
        }
    }

    /**
     * Registers positive seed phrases for a tag. Null or blank phrases are ignored.
     *
     * @param tag     the tag the phrases belong to
     * @param phrases the phrases to register
     * @throws NullPointerException if {@code tag} or the {@code phrases} array is null
     */
    public void addPositiveSeeds(@NonNull T tag, @NonNull String... phrases) {
        if (tag == null) {
            throw new NullPointerException("tag is marked non-null but is null");
        }
        if (phrases == null) {
            throw new NullPointerException("phrases is marked non-null but is null");
        }
        for (String phrase : phrases) {
            if (Strings.isNotNullOrBlank(phrase)) {
                this.seedTerms.put(tag, phrase);
            }
        }
    }

    /**
     * Builds the lexicon from the current seed terms.
     *
     * @return a multimap holding every positive seed term plus the generated terms per tag; empty
     *         when no positive seed terms have been registered
     */
    @Override // com.gengoai.hermes.lexicon.generation.LexiconGenerator
    public Multimap<T, String> generate() {
        HashSetMultimap<T, String> lexicon = new HashSetMultimap<>();
        if (this.seedTerms.size() == 0) {
            return lexicon;
        }

        // One query vector (averaged positive seeds) and one penalty vector (summed negative
        // seeds) per tag that has positive seeds.
        Map<T, NDArray> positiveVectors = new HashMap<>();
        Map<T, NDArray> negativeVectors = new HashMap<>();
        for (T tag : this.seedTerms.keySet()) {
            positiveVectors.put(tag, combineEmbeddings(this.seedTerms.get(tag), true));
            negativeVectors.put(tag, combineEmbeddings(this.negativeSeedTerms.get(tag), false));
        }

        lexicon.putAll(this.seedTerms);

        // candidate term -> tag -> score
        MultiCounter<String, T> candidateScores = MultiCounters.newConcurrentMultiCounter();
        for (Map.Entry<T, NDArray> entry : positiveVectors.entrySet()) {
            T tag = entry.getKey();
            NDArray negativeVector = negativeVectors.get(tag);
            // The negative vector is not modified after construction, so test it once per tag.
            boolean hasNegativeSeeds = negativeVector.norm2() > 0.0d;
            // Over-fetch by a factor of ten because seed terms and low-scoring candidates are
            // discarded afterwards.
            this.wordEmbeddings.query(VSQuery.vectorQuery(entry.getValue()).limit(this.maximumTermCount * 10))
                    .filter(candidate -> !this.seedTerms.containsValue(candidate.getLabel()))
                    .forEach(candidate -> {
                        double penalty = hasNegativeSeeds
                                ? Similarity.Cosine.calculate(negativeVector, candidate)
                                : 0.0d;
                        candidateScores.set((String) candidate.getLabel(), tag, candidate.getWeight() - penalty);
                    });
        }

        // Assign each candidate to its best-scoring tag among the scores reaching the threshold.
        // Seed terms never enter candidateScores, so they need no further filtering here.
        MultiCounter<T, String> termsByTag = MultiCounters.newMultiCounter();
        for (String term : candidateScores.firstKeys()) {
            T bestTag = candidateScores.get(term).filterByValue(score -> score >= this.threshold).max();
            if (bestTag != null) {
                termsByTag.set(bestTag, term, candidateScores.get(term, bestTag));
            }
        }

        for (T tag : termsByTag.firstKeys()) {
            lexicon.putAll(tag, termsByTag.get(tag).topN(this.maximumTermCount).items());
        }
        return lexicon;
    }

    /**
     * Sums the embeddings of the given terms, skipping terms the embedding does not contain.
     *
     * @param terms   the terms whose vectors are combined
     * @param average whether to divide the sum by the number of vectors that were added
     * @return the combined vector; a freshly created, unmodified array when no term is in the
     *         embedding
     */
    private NDArray combineEmbeddings(Iterable<String> terms, boolean average) {
        NDArray combined = NDArrayFactory.DENSE.array(new int[]{this.wordEmbeddings.dimension()});
        int embeddedCount = 0;
        for (String term : terms) {
            if (this.wordEmbeddings.contains(term)) {
                combined.addi(this.wordEmbeddings.embed(term));
                embeddedCount++;
            }
        }
        // Divide by the number of vectors actually summed for this tag, not by the total number
        // of seed entries across all tags; skip the division when nothing was summed.
        if (average && embeddedCount > 0) {
            combined.divi(embeddedCount);
        }
        return combined;
    }

    /**
     * @return the maximum number of generated terms kept per tag
     */
    public int getMaximumTermCount() {
        return this.maximumTermCount;
    }

    /**
     * Sets the maximum number of generated terms kept per tag. The embedding query limit is ten
     * times this value.
     *
     * @param maximumTermCount the new maximum term count
     */
    public void setMaximumTermCount(int maximumTermCount) {
        this.maximumTermCount = maximumTermCount;
    }

    /**
     * @return the minimum score a candidate needs in order to be assigned to a tag
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the minimum score a candidate needs in order to be assigned to a tag.
     *
     * @param threshold the new threshold
     */
    public void setThreshold(double threshold) {
        this.threshold = threshold;
    }
}
