package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Formatter;
import java.util.Iterator;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/* loaded from: input_file:cc/mallet/topics/WordEmbeddings.class */
/**
 * Trains continuous word embeddings (skip-gram with negative sampling, per the
 * command-line summary in {@link #main(String[])}).
 *
 * <p>This class owns the shared model state: the weight matrix, the word counts,
 * the negative-sampling table and a precomputed sigmoid lookup table. The actual
 * gradient updates are performed by {@code WordEmbeddingRunnable} workers, which
 * receive a reference to this object; fields are therefore package-visible and
 * mutable on purpose.
 *
 * <p>Each word owns {@link #stride} consecutive entries of {@link #weights}. Only
 * the first {@link #numColumns} entries of a row are read by the methods in this
 * class (presumably the remaining half holds the output/context vectors used by
 * the workers; confirm against {@code WordEmbeddingRunnable}).
 */
public class WordEmbeddings {
    static CommandOption.String inputFile = new CommandOption.String(WordEmbeddings.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String outputFile = new CommandOption.String(WordEmbeddings.class, "output", "FILENAME", true, "weights.txt", "The filename to write text-formatted word vectors.", null);
    static CommandOption.Integer numDimensions = new CommandOption.Integer(WordEmbeddings.class, "num-dimensions", "INTEGER", true, 50, "The number of dimensions to fit.", null);
    static CommandOption.Integer windowSizeOption = new CommandOption.Integer(WordEmbeddings.class, "window-size", "INTEGER", true, 5, "The number of adjacent words to consider.", null);
    static CommandOption.Integer numThreads = new CommandOption.Integer(WordEmbeddings.class, "num-threads", "INTEGER", true, 1, "The number of threads for parallel training.", null);
    static CommandOption.Integer numSamples = new CommandOption.Integer(WordEmbeddings.class, "num-samples", "INTEGER", true, 5, "The number of negative samples to use in training.", null);
    static CommandOption.String exampleWord = new CommandOption.String(WordEmbeddings.class, "example-word", "STRING", true, null, "If defined, periodically show the closest vectors to this word.", null);

    /** Maximum number of nearest neighbours printed by {@link #findClosest(double[])}. */
    private static final int MAX_NEIGHBORS_SHOWN = 10;

    /** Training stops once the workers have processed more than this many times the corpus token count. */
    private static final long TRAINING_WORD_MULTIPLIER = 5L;

    /** Pause between two progress reports in {@link #train(InstanceList, int, int)}. */
    private static final long REPORT_INTERVAL_MILLIS = 5000L;

    /** Word-to-index mapping; row {@code i} of {@link #weights} belongs to word {@code i}. */
    Alphabet vocabulary;
    /** Vocabulary size captured at construction time. */
    int numWords;
    /** Number of embedding dimensions per word. */
    int numColumns;
    /** Flat weight matrix of {@code numWords * stride} entries. */
    double[] weights;
    /** Same shape as {@link #weights}; allocated here but only used outside this class. */
    double[] squaredGradientSums;
    /** Row length inside {@link #weights}: {@code 2 * numColumns}. */
    int stride;
    /** Token count per word index, filled by {@link #countWords(InstanceList)}. */
    int[] wordCounts;
    /** Cumulative sum over words of (relative frequency)^0.75. */
    double[] samplingDistribution;
    /** Lookup table mapping evenly spaced quantiles of {@link #samplingDistribution} to word indices. */
    int[] samplingTable;
    int samplingTableSize = 100000000;
    /** Last (largest) value of {@link #samplingDistribution}. */
    double samplingSum = 0.0d;
    /** Total number of tokens seen by {@link #countWords(InstanceList)}. */
    int totalWords = 0;
    /** Upper end of the input range covered by {@link #sigmoidCache}. */
    double maxExpValue = 6.0d;
    /** Lower end of the input range covered by {@link #sigmoidCache}. */
    double minExpValue = -6.0d;
    /** Precomputed logistic function values on an even grid over [minExpValue, maxExpValue]. */
    double[] sigmoidCache;
    int sigmoidCacheSize = 1000;
    int windowSize = 5;
    /** Word whose nearest neighbours are printed during training; {@code null} disables the report. */
    String queryWord = "the";
    Randoms random = new Randoms();

    /**
     * Creates an instance holding only the default settings. No vocabulary or
     * arrays are allocated, so the model is not usable for training as-is.
     */
    public WordEmbeddings() {
    }

    /**
     * Creates a model for the given vocabulary with randomly initialised weights.
     *
     * @param alphabet   vocabulary; its current size fixes the number of rows
     * @param numColumns number of embedding dimensions per word
     * @param windowSize context window size stored in {@link #windowSize}
     */
    public WordEmbeddings(Alphabet alphabet, int numColumns, int windowSize) {
        this.vocabulary = alphabet;
        this.numWords = this.vocabulary.size();
        System.out.format("Vocab size: %d\n", Integer.valueOf(this.numWords));
        this.numColumns = numColumns;
        this.stride = 2 * numColumns;
        this.weights = new double[this.numWords * this.stride];
        this.squaredGradientSums = new double[this.numWords * this.stride];
        // Rows are contiguous and every entry of a row is initialised, so one
        // pass over the flat array visits the entries in the same order as a
        // (word, column) double loop: uniform in [-0.5, 0.5) / numColumns.
        for (int k = 0; k < this.weights.length; k++) {
            this.weights[k] = (this.random.nextDouble() - 0.5d) / numColumns;
        }
        this.wordCounts = new int[this.numWords];
        this.samplingDistribution = new double[this.numWords];
        this.samplingTable = new int[this.samplingTableSize];
        this.windowSize = windowSize;

        this.sigmoidCache = new double[this.sigmoidCacheSize + 1];
        double range = this.maxExpValue - this.minExpValue;
        // The fraction must be computed in floating point: with int operands
        // k / sigmoidCacheSize is 0 for every k below the cache size, which
        // would fill the table with the single value sigmoid(minExpValue).
        // The loop is inclusive so the extra slot allocated above holds
        // sigmoid(maxExpValue) instead of a stray 0.0.
        for (int k = 0; k <= this.sigmoidCacheSize; k++) {
            double x = ((double) k / this.sigmoidCacheSize) * range + this.minExpValue;
            this.sigmoidCache[k] = 1.0d / (1.0d + Math.exp(-x));
        }
    }

    /**
     * Counts word occurrences in the corpus and builds the negative-sampling
     * table from the counts (relative frequency raised to the power 0.75).
     *
     * @param instanceList corpus; every instance's data must be a {@code FeatureSequence}
     * @throws IllegalStateException if the vocabulary is empty or no tokens were counted
     */
    public void countWords(InstanceList instanceList) {
        for (Instance instance : instanceList) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            int length = tokens.getLength();
            for (int position = 0; position < length; position++) {
                this.wordCounts[tokens.getIndexAtPosition(position)]++;
            }
            this.totalWords += length;
        }
        // Without this check an empty vocabulary fails with an index error and
        // an empty corpus silently yields a NaN distribution, after which
        // train() could never reach its stopping condition.
        if (this.numWords == 0 || this.totalWords == 0) {
            throw new IllegalStateException("Cannot build a sampling distribution from an empty corpus (vocabulary size "
                    + this.numWords + ", total tokens " + this.totalWords + ")");
        }

        // Double-precision normaliser (a float literal here loses precision).
        double normalizer = 1.0d / this.totalWords;
        double cumulative = 0.0d;
        for (int word = 0; word < this.numWords; word++) {
            cumulative += Math.pow(normalizer * this.wordCounts[word], 0.75d);
            this.samplingDistribution[word] = cumulative;
        }
        this.samplingSum = this.samplingDistribution[this.numWords - 1];

        // Walk the cumulative distribution once: slot s receives the first word
        // whose cumulative mass reaches the fraction s / samplingTableSize.
        int word = 0;
        int lastWord = this.numWords - 1;
        for (int slot = 0; slot < this.samplingTableSize; slot++) {
            double threshold = (this.samplingSum * slot) / this.samplingTableSize;
            while (word < lastWord && threshold > this.samplingDistribution[word]) {
                word++;
            }
            this.samplingTable[slot] = word;
        }
        System.out.println("done counting");
    }

    /**
     * Runs training on a fixed pool of worker threads and prints a progress
     * line every five seconds until the workers have processed more than five
     * times the corpus token count, or until the calling thread is interrupted.
     *
     * @param instanceList corpus handed to every worker
     * @param threadCount  number of worker threads (must be positive)
     * @param sampleCount  passed through to each {@code WordEmbeddingRunnable};
     *                     {@link #main(String[])} supplies the number of negative samples
     */
    public void train(InstanceList instanceList, int threadCount, int sampleCount) {
        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
        WordEmbeddingRunnable[] workers = new WordEmbeddingRunnable[threadCount];
        try {
            for (int t = 0; t < threadCount; t++) {
                workers[t] = new WordEmbeddingRunnable(this, instanceList, sampleCount, threadCount, t);
                executor.submit(workers[t]);
            }

            // Resolve the example word once, without letting the lookup grow the
            // alphabet: a word that has no row in the weight matrix cannot be
            // queried and previously crashed training with an index error.
            int queryIndex = -1;
            if (this.queryWord != null) {
                queryIndex = this.vocabulary.lookupIndex(this.queryWord, false);
                if (queryIndex < 0 || queryIndex >= this.numWords) {
                    System.err.println("Example word \"" + this.queryWord + "\" is not in the vocabulary; skipping nearest-neighbour reports.");
                    queryIndex = -1;
                }
            }

            long startMillis = System.currentTimeMillis();
            boolean finished = false;
            while (!finished) {
                boolean interrupted = false;
                try {
                    Thread.sleep(REPORT_INTERVAL_MILLIS);
                } catch (InterruptedException e) {
                    // Restore the flag for callers and treat it as a request to
                    // stop; continuing to loop would busy-spin on sleep().
                    Thread.currentThread().interrupt();
                    interrupted = true;
                }

                // long accumulator: the sum over all workers can exceed int range.
                long wordsSoFar = 0L;
                for (WordEmbeddingRunnable worker : workers) {
                    wordsSoFar += worker.wordsSoFar;
                    System.out.format("%.3f ", Double.valueOf(worker.getMeanError()));
                }
                long elapsedMillis = System.currentTimeMillis() - startMillis;
                // Words per millisecond equals thousands of words per second.
                // Floating-point division avoids truncation and cannot throw
                // when no time has elapsed. The "loss" column is not tracked
                // here and is always reported as zero.
                double wordsPerMilli = (double) wordsSoFar / elapsedMillis;
                System.out.format("%d\t%d\t%fk w/s %f loss %f avg\n", Long.valueOf(wordsSoFar), Long.valueOf(elapsedMillis), Double.valueOf(wordsPerMilli), Double.valueOf(0.0d), Double.valueOf(averageAbsWeight()));

                if (interrupted || wordsSoFar > TRAINING_WORD_MULTIPLIER * this.totalWords) {
                    finished = true;
                    stopWorkers(workers);
                }
                if (queryIndex >= 0) {
                    findClosest(copy(queryIndex));
                }
            }
        } finally {
            // Also reached when the monitoring loop throws, so worker threads
            // are never left running behind a failed call.
            stopWorkers(workers);
            executor.shutdownNow();
        }
    }

    /** Clears the run flag of every worker that has been created. */
    private static void stopWorkers(WordEmbeddingRunnable[] workers) {
        for (WordEmbeddingRunnable worker : workers) {
            if (worker != null) {
                worker.shouldRun = false;
            }
        }
    }

    /**
     * Prints the squared norm of the target vector followed by up to ten
     * vocabulary words ranked by cosine similarity to it, one per line as
     * {@code similarity<TAB>index<TAB>word}. The ranking relies on
     * {@code IDSorter}'s natural ordering (presumably descending by weight;
     * confirm against {@code IDSorter}).
     *
     * @param dArr target vector with at least {@link #numColumns} entries
     */
    public void findClosest(double[] dArr) {
        IDSorter[] ranked = new IDSorter[this.numWords];

        double targetSquaredNorm = 0.0d;
        for (int col = 0; col < this.numColumns; col++) {
            targetSquaredNorm += dArr[col] * dArr[col];
        }
        double targetScale = 1.0d / Math.sqrt(targetSquaredNorm);
        System.out.println(targetSquaredNorm);

        for (int word = 0; word < this.numWords; word++) {
            int offset = word * this.stride;
            double wordSquaredNorm = 0.0d;
            for (int col = 0; col < this.numColumns; col++) {
                wordSquaredNorm += this.weights[offset + col] * this.weights[offset + col];
            }
            double wordScale = 1.0d / Math.sqrt(wordSquaredNorm);

            double similarity = 0.0d;
            for (int col = 0; col < this.numColumns; col++) {
                similarity += targetScale * dArr[col] * wordScale * this.weights[offset + col];
            }
            ranked[word] = new IDSorter(word, similarity);
        }

        Arrays.sort(ranked);
        // Vocabularies smaller than the report size must not overrun the array.
        int shown = Math.min(MAX_NEIGHBORS_SHOWN, this.numWords);
        for (int rank = 0; rank < shown; rank++) {
            IDSorter entry = ranked[rank];
            System.out.format("%f\t%d\t%s\n", Double.valueOf(entry.getWeight()), Integer.valueOf(entry.getID()), this.vocabulary.lookupObject(entry.getID()));
        }
    }

    /**
     * Returns the mean absolute value over the first {@link #numColumns}
     * entries of every word's row (NaN when the model has no entries).
     */
    public double averageAbsWeight() {
        double total = 0.0d;
        for (int word = 0; word < this.numWords; word++) {
            int offset = word * this.stride;
            for (int col = 0; col < this.numColumns; col++) {
                total += Math.abs(this.weights[offset + col]);
            }
        }
        // Widen before multiplying so a large matrix cannot overflow int.
        return total / ((double) this.numWords * this.numColumns);
    }

    /**
     * Writes one line per word: the word followed by its first
     * {@link #numColumns} weights, each formatted as {@code " %.6f"}.
     *
     * @param printWriter destination; not closed by this method
     */
    public void write(PrintWriter printWriter) {
        for (int word = 0; word < this.numWords; word++) {
            int offset = word * this.stride;
            printWriter.format("%s", this.vocabulary.lookupObject(word));
            for (int col = 0; col < this.numColumns; col++) {
                printWriter.format(" %.6f", Double.valueOf(this.weights[offset + col]));
            }
            printWriter.println();
        }
    }

    /**
     * Returns a copy of the vector of the given word.
     *
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] copy(String str) {
        return copy(requireIndex(str));
    }

    /** Returns a copy of the first {@link #numColumns} weights of word {@code i}. */
    public double[] copy(int i) {
        double[] vector = new double[this.numColumns];
        System.arraycopy(this.weights, i * this.stride, vector, 0, this.numColumns);
        return vector;
    }

    /**
     * Adds the vector of the given word to {@code dArr} in place.
     *
     * @return {@code dArr} itself
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] add(double[] dArr, String str) {
        return add(dArr, requireIndex(str));
    }

    /**
     * Adds the vector of word {@code i} to {@code dArr} in place.
     *
     * @return {@code dArr} itself
     */
    public double[] add(double[] dArr, int i) {
        int offset = i * this.stride;
        for (int col = 0; col < this.numColumns; col++) {
            dArr[col] += this.weights[offset + col];
        }
        return dArr;
    }

    /**
     * Subtracts the vector of the given word from {@code dArr} in place.
     *
     * @return {@code dArr} itself
     * @throws IllegalArgumentException if the word is not in the vocabulary
     */
    public double[] subtract(double[] dArr, String str) {
        return subtract(dArr, requireIndex(str));
    }

    /**
     * Subtracts the vector of word {@code i} from {@code dArr} in place.
     *
     * @return {@code dArr} itself
     */
    public double[] subtract(double[] dArr, int i) {
        int offset = i * this.stride;
        for (int col = 0; col < this.numColumns; col++) {
            dArr[col] -= this.weights[offset + col];
        }
        return dArr;
    }

    /**
     * Looks up a word without adding it to the alphabet and verifies that it
     * has a row in the weight matrix. A plain {@code lookupIndex(word)} on an
     * unknown word yields an index outside the matrix, which previously
     * surfaced as an {@code ArrayIndexOutOfBoundsException}.
     */
    private int requireIndex(String word) {
        int index = this.vocabulary.lookupIndex(word, false);
        if (index < 0 || index >= this.numWords) {
            throw new IllegalArgumentException("Word not in vocabulary: " + word);
        }
        return index;
    }

    /**
     * Command-line entry point: loads the instance list, trains the model and
     * writes the resulting vectors as text to the configured output file.
     */
    public static void main(String[] strArr) throws Exception {
        CommandOption.setSummary(WordEmbeddings.class, "Train continuous word embeddings using the skip-gram method with negative sampling.");
        CommandOption.process(WordEmbeddings.class, strArr);
        InstanceList load = InstanceList.load(new File(inputFile.value));
        WordEmbeddings wordEmbeddings = new WordEmbeddings(load.getDataAlphabet(), numDimensions.value, windowSizeOption.value);
        wordEmbeddings.queryWord = exampleWord.value;
        wordEmbeddings.countWords(load);
        wordEmbeddings.train(load, numThreads.value, numSamples.value);
        // try-with-resources closes the file even if writing fails part-way.
        try (PrintWriter printWriter = new PrintWriter(outputFile.value)) {
            wordEmbeddings.write(printWriter);
            // PrintWriter swallows I/O errors; surface them rather than leave
            // a silently truncated vectors file.
            if (printWriter.checkError()) {
                throw new java.io.IOException("Failed to write word vectors to " + outputFile.value);
            }
        }
    }
}
