package cc.mallet.regression;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MVNormal;
import cc.mallet.util.StatFunctions;
import java.io.File;
import java.text.NumberFormat;
import java.util.Iterator;

/* loaded from: input_file:cc/mallet/regression/LeastSquares.class */
/**
 * Fits a {@link LinearRegression} to an {@link InstanceList} by solving the normal equations
 * {@code (X'X) b = X'y} and keeps the summary statistics of the fit.
 *
 * <p>Layout of the parameter array obtained from the regression, as used by this class:
 * indices {@code 0 .. length-3} are feature coefficients, {@code length-2} is the intercept and
 * {@code length-1} (the "precision" slot) is left out of the solved system.
 *
 * <p>Each instance is expected to carry a {@link FeatureVector} as data and a {@link Double} as
 * target; anything else fails with a {@link ClassCastException} during construction.
 */
public class LeastSquares {
    /** Regression model whose parameter array is filled in by the constructor. */
    LinearRegression regression;
    /** The array returned by {@code regression.getParameters()}; written in place. */
    double[] parameters;
    /** Instances the model was fitted on. */
    InstanceList trainingData;
    /** Per-instance {@code target - prediction}, in training-data order. */
    double[] residuals;
    /** {@code sumSquaredError / degreesOfFreedom}. */
    double meanSquaredError;
    /** Sum of squared residuals. */
    double sumSquaredError;
    /** Sum of squared deviations of the predictions from the mean target. */
    double sumSquaredModel;
    /** Number of training instances minus {@link #dimension}. */
    int degreesOfFreedom;
    /** Number formatter used by {@link #printSummary()} (at most 8 fraction digits). */
    NumberFormat formatter;
    /** Index of the last parameter slot, which is excluded from the fit. */
    int precisionIndex;
    /** Index of the intercept in {@link #parameters} and in the normal-equation matrices. */
    int interceptIndex;
    /** Number of fitted coefficients: all features plus the intercept. */
    int dimension;
    /** Row-major {@code dimension x dimension} result of inverting the (regularized) X'X. */
    double[] xTransposeXInverse;

    /**
     * Fits the model without regularization.
     *
     * @param instanceList training instances
     */
    public LeastSquares(InstanceList instanceList) {
        this(instanceList, 0.0d);
    }

    /**
     * Fits the model, optionally adding a constant to the diagonal of X'X before inversion.
     *
     * <p>When the number of instances does not exceed the number of fitted coefficients,
     * {@link #degreesOfFreedom} is zero or negative and {@link #meanSquaredError} (and everything
     * derived from it) is not meaningful; this is not guarded against.
     *
     * @param instanceList training instances
     * @param d regularization constant; when greater than zero it is added to every diagonal
     *     entry of X'X (the intercept entry included), otherwise it is ignored
     */
    public LeastSquares(InstanceList instanceList, double d) {
        this.meanSquaredError = 0.0d;
        this.trainingData = instanceList;
        this.regression = new LinearRegression(this.trainingData.getDataAlphabet());
        this.parameters = this.regression.getParameters();
        this.interceptIndex = this.parameters.length - 2;
        this.precisionIndex = this.parameters.length - 1;
        this.residuals = new double[this.trainingData.size()];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(8);
        // The last parameter slot is not part of the linear system.
        this.dimension = this.parameters.length - 1;

        double[] xTransposeX = new double[this.dimension * this.dimension];
        double[] xTransposeY = new double[this.dimension];
        double targetSum = 0.0d;

        // Accumulate X'X and X'y, treating the intercept as an implicit feature of value 1.
        for (Instance instance : instanceList) {
            FeatureVector predictors = (FeatureVector) instance.getData();
            double target = ((Double) instance.getTarget()).doubleValue();
            targetSum += target;

            int numLocations = predictors.numLocations();
            for (int i = 0; i < numLocations; i++) {
                int rowIndex = predictors.indexAtLocation(i);
                double rowValue = predictors.valueAtLocation(i);
                int rowOffset = this.dimension * rowIndex;

                for (int j = 0; j < numLocations; j++) {
                    int columnIndex = predictors.indexAtLocation(j);
                    xTransposeX[rowOffset + columnIndex] += rowValue * predictors.valueAtLocation(j);
                }

                // Cross terms between this feature and the intercept (both symmetric entries).
                xTransposeX[rowOffset + this.interceptIndex] += rowValue;
                xTransposeX[(this.dimension * this.interceptIndex) + rowIndex] += rowValue;
                xTransposeY[rowIndex] += rowValue * target;
            }

            xTransposeX[(this.dimension * this.interceptIndex) + this.interceptIndex] += 1.0d;
            xTransposeY[this.interceptIndex] += target;
        }

        if (d > 0.0d) {
            for (int i = 0; i < this.dimension; i++) {
                xTransposeX[(this.dimension * i) + i] += d;
            }
        }

        double meanTarget = targetSum / instanceList.size();

        // Presumably inverts a symmetric positive-definite matrix stored row-major; see MVNormal.
        this.xTransposeXInverse = MVNormal.invertSPD(xTransposeX, this.dimension);

        // b = (X'X)^-1 X'y, accumulated directly into the regression's own parameter array so
        // that regression.predict() below uses the fitted coefficients.
        for (int row = 0; row < this.dimension; row++) {
            for (int column = 0; column < this.dimension; column++) {
                this.parameters[row] +=
                        this.xTransposeXInverse[(row * this.dimension) + column] * xTransposeY[column];
            }
        }

        this.sumSquaredError = 0.0d;
        this.sumSquaredModel = 0.0d;
        this.degreesOfFreedom = this.trainingData.size() - this.dimension;
        for (int i = 0; i < this.trainingData.size(); i++) {
            Instance instance = this.trainingData.get(i);
            double prediction = this.regression.predict(instance);
            double residual = ((Double) instance.getTarget()).doubleValue() - prediction;
            this.residuals[i] = residual;
            this.sumSquaredError += residual * residual;
            this.sumSquaredModel += (meanTarget - prediction) * (meanTarget - prediction);
        }
        this.meanSquaredError = this.sumSquaredError / this.degreesOfFreedom;
    }

    /**
     * Computes a two-sided p-value for every fitted coefficient.
     *
     * @return a new array of length {@link #dimension}, indexed like {@link #parameters}
     *     (the intercept's p-value is at {@link #interceptIndex})
     */
    public double[] pValues() {
        double[] result = new double[this.dimension];
        for (int i = 0; i < this.dimension; i++) {
            result[i] = twoSidedPValue(this.parameters[i] / standardError(i));
        }
        return result;
    }

    /**
     * Prints a coefficient table (estimate, standard error, t value, p-value with significance
     * stars) to standard output, intercept first, followed by SSE, degrees of freedom and R^2.
     */
    public void printSummary() {
        System.out.println("\tparam\tStd.Err\tt value\tPr(>|t|)");
        printCoefficientRow("(Int)", this.interceptIndex);
        // Every index below the intercept is a feature named by the data alphabet.
        for (int i = 0; i < this.dimension - 1; i++) {
            printCoefficientRow(String.valueOf(this.trainingData.getDataAlphabet().lookupObject(i)), i);
        }
        System.out.println();
        System.out.println("SSE: " + this.formatter.format(this.sumSquaredError) + " DF: " + this.degreesOfFreedom);
        System.out.println("R^2: " + this.formatter.format(this.sumSquaredModel / (this.sumSquaredError + this.sumSquaredModel)));
    }

    /**
     * Maps a p-value to a significance marker.
     *
     * @param d the p-value
     * @return {@code "***"} below 0.001, {@code "**"} below 0.01, {@code "*"} below 0.05,
     *     {@code "."} below 0.1, and a single space otherwise (including for NaN)
     */
    public String significanceStars(double d) {
        if (d < 0.001d) {
            return "***";
        }
        if (d < 0.01d) {
            return "**";
        }
        if (d < 0.05d) {
            return "*";
        }
        if (d < 0.1d) {
            return ".";
        }
        return " ";
    }

    /**
     * @return the length of the full parameter array, including the slot excluded from the fit
     */
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * @param i index into the parameter array
     * @return the parameter stored at {@code i}
     */
    public double getParameter(int i) {
        return this.parameters[i];
    }

    /**
     * Copies all parameters into the supplied array.
     *
     * @param dArr destination; must hold at least {@link #getNumParameters()} elements
     */
    public void getParameters(double[] dArr) {
        System.arraycopy(this.parameters, 0, dArr, 0, this.parameters.length);
    }

    /**
     * @return the regression object whose parameters were fitted by the constructor
     */
    public LinearRegression getRegression() {
        return this.regression;
    }

    /**
     * Loads an instance list from a file, fits it and prints the summary.
     *
     * @param strArr {@code strArr[0]} is the instance-list file; the optional {@code strArr[1]}
     *     is the regularization constant
     * @throws IllegalArgumentException if no file argument is given
     * @throws Exception declared for anything raised while loading or fitting
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr == null || strArr.length < 1) {
            throw new IllegalArgumentException("Usage: LeastSquares <instance-list-file> [regularization]");
        }
        InstanceList load = InstanceList.load(new File(strArr[0]));
        LeastSquares fit = strArr.length > 1
                ? new LeastSquares(load, Double.parseDouble(strArr[1]))
                : new LeastSquares(load);
        fit.printSummary();
    }

    /** Standard error of coefficient {@code index}: sqrt(MSE * diagonal entry of the inverse). */
    private double standardError(int index) {
        return Math.sqrt(this.meanSquaredError * this.xTransposeXInverse[(this.dimension * index) + index]);
    }

    /** Two-sided p-value for a t statistic; StatFunctions.pt is presumably the Student-t CDF. */
    private double twoSidedPValue(double tValue) {
        return 2.0d * (1.0d - StatFunctions.pt(Math.abs(tValue), this.degreesOfFreedom));
    }

    /** Prints one tab-separated table row for the coefficient at {@code index}. */
    private void printCoefficientRow(String label, int index) {
        System.out.print(label + "\t");
        double estimate = this.parameters[index];
        System.out.print(this.formatter.format(estimate) + "\t");
        double stdErr = standardError(index);
        System.out.print(this.formatter.format(stdErr) + "\t");
        double tValue = estimate / stdErr;
        System.out.print(this.formatter.format(tValue) + "\t");
        double pValue = twoSidedPValue(tValue);
        System.out.println(this.formatter.format(pValue) + " " + significanceStars(pValue));
    }
}
