package cc.mallet.topics;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.util.Random;

/* loaded from: input_file:cc/mallet/topics/WordEmbeddingRunnable.class */
/**
 * Worker that trains a {@link WordEmbeddings} model with skip-gram negative sampling over a
 * contiguous slice of an {@link InstanceList}.
 *
 * <p>Layout of {@code model.weights} as used here: word {@code w} owns the row starting at
 * {@code w * stride}. Its input vector occupies {@code [w*stride, w*stride + numColumns)}, and its
 * output vector starts at {@code w*stride + model.numColumns}. Column 0 of each vector is used as
 * an additive bias rather than as part of the dot product.
 *
 * <p>No synchronization is performed: every worker reads and writes the shared
 * {@code model.weights} array directly. {@link #getMeanError()} reads {@code residual} and
 * {@code numUpdates} without locking (NOTE(review): presumably acceptable because the value is
 * only used for progress reporting; confirm with the caller).
 */
public class WordEmbeddingRunnable implements Runnable {
    /** Initial SGD step size; decays linearly with the number of words processed. */
    private static final double INITIAL_LEARNING_RATE = 0.025;
    /** Lower bound for the decayed learning rate. */
    private static final double MIN_LEARNING_RATE = 1.0E-4;
    /** Fraction of the corpus size used to scale each word's sub-sampling keep probability. */
    private static final double SUBSAMPLING_SCALE = 1.0E-4;
    /** Documents with fewer tokens than this after sub-sampling are skipped entirely. */
    private static final int MIN_SUBSAMPLED_LENGTH = 10;
    /** Initial capacity of the per-thread token buffer; it is grown on demand. */
    private static final int INITIAL_BUFFER_SIZE = 100000;

    public WordEmbeddings model;
    public InstanceList instances;
    /** Number of negative samples drawn per (input, output) pair. */
    public int numSamples;
    int numThreads;
    int threadID;
    /** Length of one word's row in {@code model.weights}. */
    int stride;
    /** Index of the next document this worker will read. */
    public int docID;
    int numColumns;
    /**
     * Cleared by another thread to stop {@link #run()}. It is volatile so that the worker is
     * guaranteed to see the change.
     */
    public volatile boolean shouldRun = true;
    /** Running sum of (positive probability - mean negative probability); see getMeanError. */
    double residual = 0.0d;
    /** Number of pair updates accumulated in {@link #residual}. */
    int numUpdates = 0;
    /** Tokens seen by this worker (before sub-sampling); drives learning-rate decay. */
    public int wordsSoFar = 0;
    public Random random = new Random();

    /**
     * Creates a worker for one thread's share of the corpus.
     *
     * @param model      the shared model whose weights are updated in place
     * @param instances  the corpus; each instance's data is expected to be a FeatureSequence
     * @param numSamples number of negative samples per training pair
     * @param numThreads total number of workers the corpus is divided among
     * @param threadID   this worker's index in {@code [0, numThreads)}
     */
    public WordEmbeddingRunnable(WordEmbeddings model, InstanceList instances, int numSamples,
                                 int numThreads, int threadID) {
        this.model = model;
        this.stride = model.stride;
        this.instances = instances;
        this.numSamples = numSamples;
        this.numThreads = numThreads;
        this.threadID = threadID;
        this.numColumns = model.numColumns;
    }

    /**
     * Returns the mean of (positive-pair probability - mean negative-sample probability) over all
     * pair updates since the previous call, then resets the accumulators.
     *
     * <p>NOTE(review): if no updates have happened, this returns the current {@link #docID}
     * rather than an error value. That behavior is preserved because callers may rely on it;
     * confirm whether it is intentional.
     *
     * @return the mean residual, or {@code docID} when there were no updates
     */
    public double getMeanError() {
        if (numUpdates == 0) {
            return docID;
        }
        double mean = residual / numUpdates;
        residual = 0.0d;
        numUpdates = 0;
        return mean;
    }

    /**
     * Trains on this worker's document slice repeatedly until {@link #shouldRun} is cleared.
     *
     * <p>Thread {@code t} owns documents {@code [t * (N / T), (t + 1) * (N / T))}. The last
     * thread also owns the remainder up to {@code N}. For each document the loop:
     * <ol>
     *   <li>computes a linearly decaying learning rate;</li>
     *   <li>sub-samples frequent tokens;</li>
     *   <li>skips documents that end up with fewer than 10 tokens;</li>
     *   <li>otherwise trains every kept token against its neighbours within
     *       {@code model.windowSize}.</li>
     * </ol>
     */
    @Override // java.lang.Runnable
    public void run() {
        int numDocuments = instances.size();
        // Computed in double precision. Guarded so that numSamples == 0 does not produce
        // Infinity, which would make the residual NaN via 0 * Infinity.
        double sampleNormalizer = numSamples > 0 ? 1.0 / numSamples : 0.0;
        double[] inputGradient = new double[numColumns];
        // Result is unused. The call is kept because Alphabet.lookupIndex may add queryWord to
        // the vocabulary. NOTE(review): confirm whether that side effect is still wanted.
        model.vocabulary.lookupIndex(model.queryWord);
        int outputOffset = model.numColumns;

        int docsPerThread = numDocuments / numThreads;
        int firstDoc = threadID * docsPerThread;
        // The last worker also takes the numDocuments % numThreads documents that the integer
        // division leaves over; otherwise those documents would never be trained on.
        int endDoc = (threadID == numThreads - 1) ? numDocuments : firstDoc + docsPerThread;
        if (endDoc > numDocuments) {
            endDoc = numDocuments;
        }
        docID = firstDoc;
        if (firstDoc >= endDoc) {
            // Empty slice (e.g. fewer documents than threads): nothing to train on, and
            // instances.get(firstDoc) could be out of range.
            return;
        }

        double cacheScale = 1.0 / (model.maxExpValue - model.minExpValue);
        int[] tokenBuffer = new int[INITIAL_BUFFER_SIZE];

        while (shouldRun) {
            Instance instance = instances.get(docID);
            docID++;
            if (docID >= endDoc) {
                docID = firstDoc; // wrap around and make another pass over the slice
            }

            // Cast before multiplying: this avoids int overflow and integer division when
            // totalWords is integral, either of which would break the linear decay.
            double progress = (double) numThreads * wordsSoFar / model.totalWords;
            double learningRate =
                    Math.max(MIN_LEARNING_RATE, INITIAL_LEARNING_RATE * (1.0 - progress));

            FeatureSequence tokens = (FeatureSequence) instance.getData();
            if (tokenBuffer.length < tokens.getLength()) {
                // Grow instead of overflowing the fixed-size buffer on very long documents.
                tokenBuffer = new int[tokens.getLength()];
            }
            int keptLength = subsample(tokens, tokenBuffer);
            if (keptLength >= MIN_SUBSAMPLED_LENGTH) {
                trainDocument(tokenBuffer, keptLength, learningRate, sampleNormalizer,
                        cacheScale, outputOffset, inputGradient);
            }
        }
    }

    /**
     * Copies the tokens that survive frequency sub-sampling into {@code buffer}.
     *
     * <p>Also advances {@link #wordsSoFar} once for every token, including discarded ones.
     *
     * @return the number of tokens written to {@code buffer}
     */
    private int subsample(FeatureSequence tokens, int[] buffer) {
        int length = tokens.getLength();
        int kept = 0;
        for (int position = 0; position < length; position++) {
            int word = tokens.getIndexAtPosition(position);
            wordsSoFar++;
            double frequencyScore =
                    model.wordCounts[word] / (SUBSAMPLING_SCALE * model.totalWords);
            // Keep probability (sqrt(f) + 1) / f: close to or above 1 for rare words, and small
            // for very frequent ones.
            if (random.nextDouble() < (Math.sqrt(frequencyScore) + 1.0) / frequencyScore) {
                buffer[kept++] = word;
            }
        }
        return kept;
    }

    /**
     * Trains each of the first {@code length} tokens against every other token within
     * {@code model.windowSize} positions. Pairs of identical word types are skipped.
     */
    private void trainDocument(int[] tokens, int length, double learningRate,
                               double sampleNormalizer, double cacheScale, int outputOffset,
                               double[] inputGradient) {
        for (int inputPosition = 0; inputPosition < length; inputPosition++) {
            int inputType = tokens[inputPosition];
            int windowSize = model.windowSize;
            int first = Math.max(0, inputPosition - windowSize);
            int last = Math.min(length - 1, inputPosition + windowSize);
            for (int outputPosition = first; outputPosition <= last; outputPosition++) {
                int outputType = tokens[outputPosition];
                if (outputPosition == inputPosition || outputType == inputType) {
                    continue;
                }
                trainPair(inputType, outputType, learningRate, sampleNormalizer, cacheScale,
                        outputOffset, inputGradient);
            }
        }
    }

    /**
     * Applies one SGD step for an observed (input, output) pair plus {@link #numSamples}
     * negative samples drawn from {@code model.samplingTable}.
     *
     * <p>The output vectors are updated immediately. The input vector is updated at the end
     * from the gradient accumulated in {@code inputGradient}, which is scratch space and is
     * fully overwritten on each call.
     */
    private void trainPair(int inputType, int outputType, double learningRate,
                           double sampleNormalizer, double cacheScale, int outputOffset,
                           double[] inputGradient) {
        double[] weights = model.weights;
        int inputBase = inputType * stride;
        int positiveBase = outputType * stride + outputOffset;

        // Positive example: move sigmoid(score) toward 1. The score uses the OUTPUT word's
        // output bias, the same bias that the update below modifies.
        double positiveProb = cachedSigmoid(score(weights, inputBase, positiveBase), cacheScale);
        double positiveError = 1.0 - positiveProb;
        inputGradient[0] = positiveError;
        weights[positiveBase] += learningRate * positiveError;
        for (int col = 1; col < numColumns; col++) {
            // Read the output weight before updating it, so the input gradient uses old values.
            inputGradient[col] = positiveError * weights[positiveBase + col];
            weights[positiveBase + col] += learningRate * positiveError * weights[inputBase + col];
        }

        // Negative samples: move sigmoid(score) toward 0, scaled by 1 / numSamples.
        double negativeProbSum = 0.0;
        for (int sample = 0; sample < numSamples; sample++) {
            int sampledType = model.samplingTable[random.nextInt(model.samplingTableSize)];
            int sampleBase = sampledType * stride + outputOffset;
            double negativeProb = cachedSigmoid(score(weights, inputBase, sampleBase), cacheScale);
            negativeProbSum += negativeProb;
            double negativeError = sampleNormalizer * -negativeProb;
            inputGradient[0] += negativeError;
            weights[sampleBase] += learningRate * negativeError;
            for (int col = 1; col < numColumns; col++) {
                inputGradient[col] += negativeError * weights[sampleBase + col];
                weights[sampleBase + col] += learningRate * negativeError * weights[inputBase + col];
            }
        }

        residual += positiveProb - negativeProbSum * sampleNormalizer;
        numUpdates++;

        for (int col = 0; col < numColumns; col++) {
            weights[inputBase + col] += learningRate * inputGradient[col];
        }
    }

    /**
     * Bias-augmented score: input bias + output bias + the dot product over columns
     * {@code 1..numColumns-1}. Column 0 enters only additively, matching the gradient updates
     * in {@link #trainPair}.
     */
    private double score(double[] weights, int inputBase, int outputBase) {
        double sum = weights[inputBase] + weights[outputBase];
        for (int col = 1; col < numColumns; col++) {
            sum += weights[inputBase + col] * weights[outputBase + col];
        }
        return sum;
    }

    /**
     * Looks up the logistic function in {@code model.sigmoidCache}. Scores outside
     * {@code [minExpValue, maxExpValue]} saturate to 0 or 1.
     */
    private double cachedSigmoid(double x, double cacheScale) {
        if (x < model.minExpValue) {
            return 0.0d;
        }
        if (x > model.maxExpValue) {
            return 1.0d;
        }
        int index = (int) Math.floor(model.sigmoidCacheSize * (x - model.minExpValue) * cacheScale);
        // x == maxExpValue (or rounding at the top edge) yields index sigmoidCacheSize; clamp it
        // so the lookup can never run past the end of the cache.
        return model.sigmoidCache[Math.min(index, model.sigmoidCache.length - 1)];
    }
}
