package edu.cmu.graphchi.toolkits.collaborative_filtering;

import edu.cmu.graphchi.ChiVertex;
import edu.cmu.graphchi.GraphChiContext;
import edu.cmu.graphchi.GraphChiProgram;
import edu.cmu.graphchi.datablocks.FloatConverter;
import edu.cmu.graphchi.engine.GraphChiEngine;
import edu.cmu.graphchi.engine.VertexInterval;
import edu.cmu.graphchi.preprocessing.VertexIdTranslate;
import edu.cmu.graphchi.util.FileUtils;
import java.io.BufferedWriter;
import java.io.FileWriter;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.BlockRealMatrix;
import org.apache.commons.math.linear.CholeskyDecompositionImpl;
import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
import org.apache.commons.math.linear.RealVector;

/* loaded from: input_file:edu/cmu/graphchi/toolkits/collaborative_filtering/ALS.class */
/**
 * Alternating Least Squares matrix factorization as a GraphChi vertex program.
 *
 * <p>Each {@link #update} solves a D x D regularized least-squares system for one vertex.
 * The system is built from the latent factors of its neighbours and the edge values, and is
 * solved by Cholesky decomposition. The regularization added to the diagonal is
 * {@code LAMBDA * degree}. After the run, {@link #main} writes the factor matrices to
 * {@code <training>_U.mm} and {@code <training>_V.mm}.
 */
public class ALS extends ProblemSetup implements GraphChiProgram<Integer, Float> {
    /** Regularization weight; multiplied by the vertex degree before being added to the diagonal. */
    double LAMBDA = 0.065d;

    /**
     * Number of edges whose squared error has been added to {@code train_rmse} in the current
     * iteration. Used as the RMSE denominator. Guarded by {@code this}.
     */
    private long trainErrorCount = 0L;

    private ALS() {
    }

    /**
     * Predicts a value as the dot product of two factor vectors, clamped to the range
     * {@code [minval, maxval]} configured in {@code ProblemSetup}.
     *
     * @param realVector  first factor vector
     * @param realVector2 second factor vector (same dimension)
     * @return {@code max(min(dot, maxval), minval)}
     */
    public static double als_predict(RealVector realVector, RealVector realVector2) {
        return Math.max(Math.min(realVector.dotProduct(realVector2), maxval), minval);
    }

    /**
     * Recomputes this vertex's latent factors from its neighbours' factors and edge values.
     *
     * <p>The method solves {@code (N + LAMBDA * deg * I) x = b}, where {@code N} is the sum of
     * the neighbour outer products and {@code b} is the sum of {@code value * neighbour}.
     * For vertices with at least one out-edge, it also adds the squared prediction error of
     * the pre-update factors to {@code train_rmse}. If the system is not positive definite, a
     * warning is logged and the vertex keeps its previous factors.
     */
    @Override
    public void update(ChiVertex<Integer, Float> chiVertex, GraphChiContext graphChiContext) {
        final int numEdges = chiVertex.numEdges();
        if (numEdges == 0) {
            return;
        }
        BlockRealMatrix blockRealMatrix = new BlockRealMatrix(D, D);
        double[] rhs = new double[D];
        RealVector rowAsVector = latent_factors_inmem.getRowAsVector(chiVertex.getId());
        try {
            double squaredError = 0.0d;
            boolean accumulateTrainError = chiVertex.numOutEdges() > 0;
            // One snapshot of this vertex's current factors, instead of a fresh copy per edge.
            RealVector currentFactors = accumulateTrainError ? new ArrayRealVector(rowAsVector) : null;

            for (int e = 0; e < numEdges; e++) {
                float rating = chiVertex.edge(e).getValue().floatValue();
                RealVector neighbor = latent_factors_inmem.getRowAsVector(chiVertex.edge(e).getVertexId());
                for (int i = 0; i < D; i++) {
                    double ni = neighbor.getEntry(i);
                    rhs[i] += ni * rating;
                    // Accumulate only the lower triangle; it is mirrored once after the loop.
                    for (int j = i; j < D; j++) {
                        blockRealMatrix.addToEntry(j, i, ni * neighbor.getEntry(j));
                    }
                }
                if (accumulateTrainError) {
                    double err = als_predict(neighbor, currentFactors) - rating;
                    squaredError += err * err;
                }
            }

            // Mirror lower triangle into upper triangle and add degree-scaled regularization.
            for (int i = 0; i < D; i++) {
                for (int j = i + 1; j < D; j++) {
                    blockRealMatrix.setEntry(i, j, blockRealMatrix.getEntry(j, i));
                }
                blockRealMatrix.addToEntry(i, i, this.LAMBDA * numEdges);
            }

            latent_factors_inmem.setRow(chiVertex.getId(),
                    new CholeskyDecompositionImpl(blockRealMatrix).getSolver()
                            .solve(new ArrayRealVector(rhs)).getData());

            if (accumulateTrainError) {
                // update() may run on several threads concurrently (non-deterministic execution).
                synchronized (this) {
                    train_rmse += squaredError;
                    trainErrorCount += numEdges;
                }
            }
        } catch (NotPositiveDefiniteMatrixException e2) {
            // Must precede the generic catch: it is a checked subclass of Exception.
            logger.warning("Matrix was not positive definite: " + blockRealMatrix);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Resets the training-error accumulators. On the first iteration, also initializes the
     * feature vectors for all vertices.
     */
    @Override
    public void beginIteration(GraphChiContext graphChiContext) {
        synchronized (this) {
            train_rmse = 0.0d;
            trainErrorCount = 0L;
        }
        if (graphChiContext.getIteration() == 0) {
            init_feature_vectors(graphChiContext.getNumVertices());
        }
    }

    /**
     * Converts the accumulated squared error into an RMSE over the edges that contributed to
     * it, then runs the validation RMSE computation.
     */
    @Override
    public void endIteration(GraphChiContext graphChiContext) {
        synchronized (this) {
            // Previously divided by a hard-coded rating count valid for a single dataset.
            train_rmse = trainErrorCount > 0 ? Math.sqrt(train_rmse / trainErrorCount) : 0.0d;
        }
        ProblemSetup.validation_rmse_engine.calc_validation_rmse(training, nShards);
    }

    /** No per-interval work is needed. */
    @Override
    public void beginInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-interval work is needed. */
    @Override
    public void endInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-sub-interval work is needed. */
    @Override
    public void beginSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /** No per-sub-interval work is needed. */
    @Override
    public void endSubInterval(GraphChiContext graphChiContext, VertexInterval vertexInterval) {
    }

    /**
     * Entry point. The method:
     * <ol>
     *   <li>parses the command line;</li>
     *   <li>converts the MatrixMarket input;</li>
     *   <li>initializes validation;</li>
     *   <li>runs 5 ALS iterations;</li>
     *   <li>writes the factor matrices.</li>
     * </ol>
     */
    public static void main(String[] strArr) throws Exception {
        ALS als = new ALS();
        Common.parse_command_line_arguments(strArr);
        logger.info("Set latent factor dimension to: " + D);
        IO.convert_matrix_market();
        validation_rmse_engine = new RMSEEngine();
        validation_rmse_engine.init_validation();
        GraphChiEngine graphChiEngine = new GraphChiEngine(ProblemSetup.training, ProblemSetup.nShards);
        graphChiEngine.setEdataConverter(new FloatConverter());
        graphChiEngine.setEnableDeterministicExecution(false);
        graphChiEngine.setVertexDataConverter(null);
        graphChiEngine.setModifiesInedges(false);
        graphChiEngine.setModifiesOutedges(false);
        graphChiEngine.run(als, 5);
        als.writeOutputMatrices(graphChiEngine.getVertexIdTranslate());
    }

    /**
     * Writes the two latent-factor matrices in MatrixMarket "array" format.
     *
     * <p>The row and column counts come from the first two tab-separated fields of
     * {@code <training>.matrixinfo}. Original ids {@code [0, rows)} go to {@code _U.mm}.
     * Original ids {@code [rows, rows + cols)} go to {@code _V.mm}.
     *
     * @param vertexIdTranslate maps original ids to GraphChi internal ids
     */
    private void writeOutputMatrices(VertexIdTranslate vertexIdTranslate) throws Exception {
        String infoFile = training + ".matrixinfo";
        String[] split = FileUtils.readToString(infoFile).split("\t");
        if (split.length < 2) {
            throw new IllegalStateException("Malformed matrix info file: " + infoFile);
        }
        int numRows = Integer.parseInt(split[0].trim());
        int numCols = Integer.parseInt(split[1].trim());
        writeFactorMatrix(training + "_U.mm", numRows, 0, vertexIdTranslate);
        writeFactorMatrix(training + "_V.mm", numCols, numRows, vertexIdTranslate);
        logger.info("Latent factor matrices saved: " + training + "_U.mm, " + training + "_V.mm");
    }

    /**
     * Writes {@code count} factor vectors, for original ids starting at {@code idOffset}, as a
     * D x count MatrixMarket array. MatrixMarket array data is column-major, so each column
     * holds one vertex's factor vector. The writer is closed even if writing fails.
     */
    private void writeFactorMatrix(String path, int count, int idOffset,
                                   VertexIdTranslate vertexIdTranslate) throws Exception {
        BufferedWriter out = new BufferedWriter(new FileWriter(path));
        try {
            out.write("%%MatrixMarket matrix array real general\n");
            out.write(D + " " + count + "\n");
            for (int v = 0; v < count; v++) {
                int internalId = vertexIdTranslate.forward(idOffset + v);
                for (int k = 0; k < D; k++) {
                    out.write(latent_factors_inmem.getValue(internalId, k) + "\n");
                }
            }
        } finally {
            out.close();
        }
    }
}
