package org.apache.mahout.cf.taste.impl.recommender.svd;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.SparseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.7.jar:org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.class */
/**
 * Factorizes a {@link DataModel} into user and item feature matrices by alternately re-solving one
 * side while the other side is held fixed.
 *
 * <p>{@link #factorize()} first fills both feature matrices with small random values. It then runs
 * {@code numIterations} half-steps. Each half-step recomputes either every user vector or every item
 * vector. Each vector comes from its own regularized linear system:
 * {@code x = (F^T F + F^T (C - I) F + lambda I)^-1 F^T C p^T}, where {@code F} is the fixed side's
 * feature matrix, {@code C} the diagonal confidence matrix and {@code p} the preference row vector.
 * The per-user (per-item) solves run in parallel on a thread pool.</p>
 *
 * <p>Training state lives in unsynchronized fields, so concurrent {@link #factorize()} calls on the
 * same instance are not supported.</p>
 */
public final class ImplicitLinearRegressionFactorizer extends AbstractFactorizer {
    private static final Logger log = LoggerFactory.getLogger(ImplicitLinearRegressionFactorizer.class);

    /** Queued solve tasks per available processor before a batch is executed. */
    private static final int CALLABLES_PER_PROCESSOR = 200;

    /** Ridge regularization weight (lambda) added to the diagonal of every system. */
    private final double preventOverfitting;
    private final int numFeatures;
    private final int numIterations;
    private final DataModel dataModel;
    /** One row of {@code numFeatures} values per user, indexed by {@code userIndex(userID)}. */
    private double[][] userMatrix;
    /** One row of {@code numFeatures} values per item, indexed by {@code itemIndex(itemID)}. */
    private double[][] itemMatrix;
    /** Cached {@code U^T U}; refreshed before each item half-step. */
    private Matrix userTransUser;
    /** Cached {@code Y^T Y} for the item matrix; refreshed before each user half-step. */
    private Matrix itemTransItem;
    /** Solve tasks queued for the current half-step. */
    Collection<Callable<Void>> fVectorCallables;
    /** {@code true} while user vectors are being recomputed (item vectors held fixed). */
    private boolean recomputeUserFeatures;
    /** Mean cosine similarity between old and new vectors within the current half-step. */
    private RunningAverage avrChange;

    /**
     * Re-solves the feature vector of a single user or item against the currently fixed side.
     * Which side is solved is decided by {@code recomputeUserFeatures} when {@link #call()} runs.
     */
    public class FeatureVectorCallable implements Callable<Void> {
        private final Matrix confidence;
        private final Matrix prefVector;
        private final int id;

        private FeatureVectorCallable(Matrix confidence, Matrix prefVector, int id) {
            this.confidence = confidence;
            this.prefVector = prefVector;
            this.id = id;
        }

        @Override
        public Void call() throws Exception {
            boolean solveUsers = recomputeUserFeatures;
            // The "fixed" side is the one whose features stay constant during this half-step.
            int numFixed = solveUsers ? dataModel.getNumItems() : dataModel.getNumUsers();
            Matrix fixedTransFixed = solveUsers ? itemTransItem : userTransUser;
            Matrix fixed = new DenseMatrix(solveUsers ? itemMatrix : userMatrix);
            Matrix fixedTrans = fixed.transpose();

            Matrix identityFixed = identityV(numFixed);
            Matrix identityFeatures = identityV(numFeatures);

            // A = F^T F + F^T (C - I) F + lambda * I   (plus() returns a new matrix, so the cache is untouched)
            Matrix lhs = fixedTransFixed
                    .plus(fixedTrans.times(confidence.minus(identityFixed)).times(fixed))
                    .plus(identityFeatures.times(preventOverfitting));
            // x = A^-1 * F^T * C * p^T, a numFeatures x 1 column
            Matrix solution = solve(lhs, identityFeatures)
                    .times(fixedTrans.times(confidence))
                    .times(prefVector.transpose());
            updateMatrix(id, solution);
            return null;
        }
    }

    /**
     * Wraps a task, records its wall-clock duration in a shared statistic and, if requested, logs the
     * running average duration and current heap usage afterwards.
     */
    public static class StatsCallable implements Callable<Void> {
        private final Callable<Void> delegate;
        private final boolean logStats;
        private final RunningAverageAndStdDev timing;

        private StatsCallable(Callable<Void> delegate, boolean logStats, RunningAverageAndStdDev timing) {
            this.delegate = delegate;
            this.logStats = logStats;
            this.timing = timing;
        }

        @Override
        public Void call() throws Exception {
            long start = System.currentTimeMillis();
            delegate.call();
            long elapsed = System.currentTimeMillis() - start;
            double averageMillis;
            // The timing object is shared by all tasks of a batch and is not thread-safe itself.
            synchronized (timing) {
                timing.addDatum(elapsed);
                averageMillis = timing.getAverage();
            }
            if (logStats) {
                Runtime runtime = Runtime.getRuntime();
                log.info("Average time per task: {}ms", (int) averageMillis);
                long totalMemory = runtime.totalMemory();
                log.info("Approximate memory used: {}MB / {}MB",
                        (totalMemory - runtime.freeMemory()) / 1000000, totalMemory / 1000000);
            }
            return null;
        }
    }

    /** Creates a factorizer with 64 features, 10 iterations and a regularization weight of 0.1. */
    public ImplicitLinearRegressionFactorizer(DataModel dataModel) throws TasteException {
        this(dataModel, 64, 10, 0.1d);
    }

    /**
     * @param dataModel          preferences to factorize
     * @param numFeatures        length of every user and item feature vector
     * @param numIterations      number of half-steps; each one recomputes either users or items
     * @param preventOverfitting regularization weight added to the diagonal of every system
     */
    public ImplicitLinearRegressionFactorizer(DataModel dataModel, int numFeatures, int numIterations,
            double preventOverfitting) throws TasteException {
        super(dataModel);
        this.dataModel = dataModel;
        this.numFeatures = numFeatures;
        this.numIterations = numIterations;
        this.preventOverfitting = preventOverfitting;
        this.fVectorCallables = Lists.newArrayList();
        this.avrChange = new FullRunningAverage();
    }

    /**
     * Randomly initializes both feature matrices, trains them and wraps the result.
     *
     * @return the user/item factorization
     * @throws TasteException if the data model cannot be read
     */
    @Override
    public Factorization factorize() throws TasteException {
        Random random = RandomUtils.getRandom();
        int numUsers = dataModel.getNumUsers();
        int numItems = dataModel.getNumItems();
        userMatrix = new double[numUsers][numFeatures];
        itemMatrix = new double[numItems][numFeatures];
        // Start with the user side; finishProcessing() flips this after every half-step.
        recomputeUserFeatures = true;

        double average = getAveragePreference();
        double prefInterval = dataModel.getMaxPreference() - dataModel.getMinPreference();
        double radicand = (average - prefInterval * 0.1d) / numFeatures;
        if (radicand < 0.0) {
            // The average is padded with zeros for unrated items, so on sparse data it is often below
            // 10% of the range; sqrt of a negative value would make every initial feature NaN.
            log.warn("Average preference {} is below 10% of the preference range {}; initializing features around 0",
                    average, prefInterval);
            radicand = 0.0;
        }
        double defaultValue = Math.sqrt(radicand);
        double noiseScale = prefInterval * 0.1d / numFeatures;
        // Feature-major order with users before items fixes the sequence of random draws.
        for (int feature = 0; feature < numFeatures; feature++) {
            for (int user = 0; user < numUsers; user++) {
                userMatrix[user][feature] = defaultValue + (random.nextDouble() - 0.5d) * noiseScale * random.nextDouble();
            }
            for (int item = 0; item < numItems; item++) {
                itemMatrix[item][feature] = defaultValue + (random.nextDouble() - 0.5d) * noiseScale * random.nextDouble();
            }
        }
        train();
        return createFactorization(userMatrix, itemMatrix);
    }

    /**
     * Runs {@code numIterations} half-steps, alternating between recomputing all user vectors and all
     * item vectors.
     *
     * @throws TasteException if the data model cannot be read
     */
    public void train() throws TasteException {
        for (int iteration = 0; iteration < numIterations; iteration++) {
            if (recomputeUserFeatures) {
                LongPrimitiveIterator userIDs = dataModel.getUserIDs();
                log.info("Calculating Y^TY");
                reCalculateTrans(recomputeUserFeatures);
                log.info("Building callables for users.");
                while (userIDs.hasNext()) {
                    long userID = userIDs.nextLong();
                    buildCallables(buildConfidenceMatrixForUser(userID), buildPreferenceVectorForUser(userID),
                            userIndex(userID));
                }
            } else {
                LongPrimitiveIterator itemIDs = dataModel.getItemIDs();
                log.info("Calculating X^TX");
                reCalculateTrans(recomputeUserFeatures);
                log.info("Building callables for items.");
                while (itemIDs.hasNext()) {
                    long itemID = itemIDs.nextLong();
                    buildCallables(buildConfidenceMatrixForItem(itemID), buildPreferenceVectorForItem(itemID),
                            itemIndex(itemID));
                }
            }
            finishProcessing();
        }
    }

    /**
     * Builds a 1 x numItems row vector holding the user's preference values at their item indices.
     *
     * @param userID user whose preferences are read
     * @return sparse row vector of preference values (0 where the user has no preference)
     * @throws TasteException if the user's preferences cannot be read
     */
    public Matrix buildPreferenceVectorForUser(long userID) throws TasteException {
        Matrix prefs = new SparseMatrix(1, dataModel.getNumItems());
        for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
            prefs.setQuick(0, itemIndex(pref.getItemID()), pref.getValue());
        }
        return prefs;
    }

    /** Returns a numUsers x numUsers diagonal matrix with 1 for each user that has a preference for the item. */
    private Matrix buildConfidenceMatrixForItem(long itemID) throws TasteException {
        PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
        Matrix confidence = new SparseMatrix(dataModel.getNumUsers(), dataModel.getNumUsers());
        for (Preference pref : prefs) {
            int index = userIndex(pref.getUserID());
            confidence.setQuick(index, index, 1.0d);
        }
        return new DiagonalMatrix(confidence);
    }

    /** Returns a numItems x numItems diagonal matrix with 1 for each item the user has a preference for. */
    private Matrix buildConfidenceMatrixForUser(long userID) throws TasteException {
        PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
        Matrix confidence = new SparseMatrix(dataModel.getNumItems(), dataModel.getNumItems());
        for (Preference pref : prefs) {
            int index = itemIndex(pref.getItemID());
            confidence.setQuick(index, index, 1.0d);
        }
        return new DiagonalMatrix(confidence);
    }

    /** Builds a 1 x numUsers row vector holding the item's preference values at their user indices. */
    private Matrix buildPreferenceVectorForItem(long itemID) throws TasteException {
        Matrix prefs = new SparseMatrix(1, dataModel.getNumUsers());
        for (Preference pref : dataModel.getPreferencesForItem(itemID)) {
            prefs.setQuick(0, userIndex(pref.getUserID()), pref.getValue());
        }
        return prefs;
    }

    /** Returns a size x size diagonal matrix with 1 on the diagonal (the identity). */
    private Matrix ones(int size) {
        double[] diagonal = new double[size];
        for (int i = 0; i < size; i++) {
            diagonal[i] = 1.0d;
        }
        return new DiagonalMatrix(diagonal);
    }

    /**
     * Averages all preference values over the full user x item grid, counting every missing
     * preference as 0.
     *
     * @return sum of preference values divided by (readable users x numItems); NaN if nothing was read
     */
    private double getAveragePreference() throws TasteException {
        int numItems = dataModel.getNumItems();
        double sum = 0.0d;
        long count = 0L;
        LongPrimitiveIterator userIDs = dataModel.getUserIDs();
        while (userIDs.hasNext()) {
            PreferenceArray prefs;
            try {
                prefs = dataModel.getPreferencesFromUser(userIDs.nextLong());
            } catch (NoSuchUserException ignored) {
                // Best effort: a user that cannot be read contributes neither values nor padding zeros.
                continue;
            }
            for (Preference pref : prefs) {
                sum += pref.getValue();
            }
            // Each readable user contributes one datum per item: its values plus zeros for the rest.
            count += Math.max(numItems, prefs.length());
        }
        return sum / count;
    }

    /**
     * Refreshes the cached Gram matrix of the side that will be held fixed next.
     *
     * @param recomputeUsers {@code true} to refresh {@code itemMatrix^T itemMatrix} (before a user
     *                       half-step), {@code false} to refresh {@code userMatrix^T userMatrix}
     */
    public void reCalculateTrans(boolean recomputeUsers) {
        if (recomputeUsers) {
            Matrix items = new DenseMatrix(itemMatrix);
            itemTransItem = items.transpose().times(items);
        } else {
            Matrix users = new DenseMatrix(userMatrix);
            userTransUser = users.transpose().times(users);
        }
    }

    /**
     * Replaces one row of the feature matrix being recomputed with the first column of
     * {@code matrix}. It also records the cosine similarity between the old and new row in
     * {@code avrChange}.
     *
     * @param index  row (user or item index) to overwrite
     * @param matrix numFeatures x 1 solution column
     */
    public synchronized void updateMatrix(int index, Matrix matrix) {
        double[] row = recomputeUserFeatures ? userMatrix[index] : itemMatrix[index];
        double oldNormSquared = 0.0d;
        double newNormSquared = 0.0d;
        double dot = 0.0d;
        for (int feature = 0; feature < numFeatures; feature++) {
            double oldValue = row[feature];
            double newValue = matrix.get(feature, 0);
            oldNormSquared += oldValue * oldValue;
            newNormSquared += newValue * newValue;
            dot += oldValue * newValue;
            row[feature] = newValue;
        }
        double cosine = dot / (Math.sqrt(oldNormSquared) * Math.sqrt(newNormSquared));
        if (Double.isNaN(cosine)) {
            log.info("Cosine similarity is NaN, recomputeUserFeatures={} id={}", recomputeUserFeatures, index);
        } else {
            avrChange.addDatum(cosine);
        }
    }

    /** Discards all queued solve tasks. */
    public void resetCallables() {
        fVectorCallables = Lists.newArrayList();
    }

    /** Logs the mean change of the finished half-step and starts a fresh average. */
    private void resetAvrChange() {
        log.info("Avr Change: {}", avrChange.getAverage());
        avrChange = new FullRunningAverage();
    }

    /**
     * Queues a solve task and runs the queue once it reaches the batch size.
     *
     * @param confidence diagonal confidence matrix of the user/item
     * @param prefVector 1 x n preference row vector of the user/item
     * @param index      row index to update with the solution
     * @throws TasteException propagated from {@link #execute(Collection)}
     */
    public void buildCallables(Matrix confidence, Matrix prefVector, int index) throws TasteException {
        fVectorCallables.add(new FeatureVectorCallable(confidence, prefVector, index));
        int batchSize = CALLABLES_PER_PROCESSOR * Runtime.getRuntime().availableProcessors();
        // ">=" rather than "% == 0" so a changed processor count can never skip the flush.
        if (fVectorCallables.size() >= batchSize) {
            execute(fVectorCallables);
            resetCallables();
        }
    }

    /**
     * Runs any remaining queued tasks, logs the half-step's average change and switches to the other
     * side for the next half-step.
     *
     * @throws TasteException propagated from {@link #execute(Collection)}
     */
    public void finishProcessing() throws TasteException {
        if (fVectorCallables != null && !fVectorCallables.isEmpty()) {
            execute(fVectorCallables);
        }
        resetCallables();
        int expectedCount = recomputeUserFeatures ? userMatrix.length : itemMatrix.length;
        if (avrChange.getCount() != expectedCount) {
            log.info("Matrix length is not equal to count");
        }
        resetAvrChange();
        recomputeUserFeatures = !recomputeUserFeatures;
    }

    /** Returns the size x size identity matrix. */
    public Matrix identityV(int size) {
        return ones(size);
    }

    /**
     * Runs the given tasks on a fresh fixed-size pool (one thread per available processor) and waits
     * for all of them. Task failures are logged rather than propagated.
     *
     * @param callables tasks to run
     * @throws TasteException declared for subclasses/callers; not thrown by this implementation
     */
    void execute(Collection<Callable<Void>> callables) throws TasteException {
        Collection<Callable<Void>> wrapped = wrapWithStatsCallables(callables);
        int numProcessors = Runtime.getRuntime().availableProcessors();
        ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
        log.info("Starting timing of {} tasks in {} threads", wrapped.size(), numProcessors);
        try {
            for (Future<Void> future : executor.invokeAll(wrapped)) {
                future.get();
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            log.warn("error in factorization", e);
        } catch (ExecutionException e) {
            log.warn("error in factorization", e);
        } finally {
            executor.shutdown();
        }
    }

    /** Wraps every task in a {@link StatsCallable} sharing one timing statistic; every 1000th task logs it. */
    private Collection<Callable<Void>> wrapWithStatsCallables(Collection<Callable<Void>> callables) {
        Collection<Callable<Void>> wrapped = Lists.newArrayListWithExpectedSize(callables.size());
        RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
        int position = 0;
        for (Callable<Void> callable : callables) {
            position++;
            wrapped.add(new StatsCallable(callable, position % 1000 == 0, timing));
        }
        return wrapped;
    }

    /** Solves {@code matrix * X = matrix2} via {@link QRDecomposition}; used with the identity to invert. */
    public Matrix solve(Matrix matrix, Matrix matrix2) {
        return new QRDecomposition(matrix).solve(matrix2);
    }
}
