package us.ihmc.convexOptimization.quadraticProgram;

import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;

/* loaded from: input_file:us/ihmc/convexOptimization/quadraticProgram/BlockDiagSquareMatrix.class */
/**
 * Square matrix whose non-zero entries are organised as a sequence of square
 * blocks along the main diagonal.
 *
 * <p>The matrix is stored as a regular dense {@link DMatrixRMaj}; the block
 * layout ({@link #blockSizes} / {@link #blockStarts}) is only used by the
 * block-wise helper operations of this class. Scratch matrices are kept as
 * fields and reused between calls, so an instance must not be used from
 * several threads at once.
 */
public class BlockDiagSquareMatrix extends DMatrixRMaj {
    private static final long serialVersionUID = 8813856249678942997L;

    /** Side length of each diagonal block, in block order. */
    int[] blockSizes;

    /**
     * Row/column index at which each block starts. Holds one extra trailing
     * entry equal to the total matrix dimension, so block {@code i} spans
     * {@code [blockStarts[i], blockStarts[i + 1])}.
     */
    int[] blockStarts;

    /** One reusable scratch matrix per block, used when extracting/inverting blocks. */
    DMatrixRMaj[] tmpMatrix;

    /** Scratch matrix holding the rows of the right-hand operand in {@link #mult}. */
    DMatrixRMaj multTempB;

    /** Scratch matrix holding a per-block product in {@link #mult}. */
    DMatrixRMaj multTempC;

    /**
     * Creates a matrix with the given diagonal block sizes. The resulting
     * dimension is the sum of all block sizes.
     *
     * @param blockSizes side length of each diagonal block, in order
     */
    public BlockDiagSquareMatrix(int... blockSizes) {
        super(0);
        this.multTempB = new DMatrixRMaj(0);
        this.multTempC = new DMatrixRMaj(0);
        // Defensive copy: a caller mutating its array afterwards must not
        // desynchronise blockSizes from the precomputed blockStarts.
        this.blockSizes = blockSizes.clone();

        int numBlocks = getNumBlocks();
        this.blockStarts = new int[numBlocks + 1];
        this.tmpMatrix = new DMatrixRMaj[numBlocks];

        int offset = 0;
        for (int block = 0; block < numBlocks; block++) {
            int size = this.blockSizes[block];
            this.tmpMatrix[block] = new DMatrixRMaj(size, size);
            this.blockStarts[block] = offset;
            offset += size;
        }
        // Trailing sentinel: end index of the last block == total dimension.
        this.blockStarts[numBlocks] = offset;
        super.reshape(offset, offset);
    }

    /**
     * @return the number of diagonal blocks
     */
    public int getNumBlocks() {
        return this.blockSizes.length;
    }

    /**
     * Writes {@code block} into the {@code blockIndex}-th diagonal block of this matrix.
     *
     * @param block      values to insert
     * @param blockIndex index of the diagonal block to overwrite
     */
    public void setBlock(DMatrixRMaj block, int blockIndex) {
        setBlock(block, blockIndex, this);
    }

    /**
     * Writes {@code block} into {@code matrixToPack} at the position of the
     * {@code blockIndex}-th diagonal block of this matrix.
     * {@code matrixToPack} is first reshaped to the dimensions of this matrix.
     *
     * @param block        values to insert
     * @param blockIndex   index of the diagonal block whose position is used
     * @param matrixToPack matrix that receives the block
     */
    public void setBlock(DMatrixRMaj block, int blockIndex, DMatrixRMaj matrixToPack) {
        matrixToPack.reshape(this.numRows, this.numCols);
        int start = this.blockStarts[blockIndex];
        CommonOps_DDRM.insert(block, matrixToPack, start, start);
    }

    /**
     * Copies the {@code blockIndex}-th diagonal block of this matrix into
     * {@code blockToPack}, with its upper-left corner at
     * {@code (destRow, destCol)}.
     *
     * @param blockToPack matrix that receives the block
     * @param blockIndex  index of the diagonal block to read
     * @param destRow     row in {@code blockToPack} of the block's first row
     * @param destCol     column in {@code blockToPack} of the block's first column
     */
    public void packBlock(DMatrixRMaj blockToPack, int blockIndex, int destRow, int destCol) {
        int start = this.blockStarts[blockIndex];
        int end = this.blockStarts[blockIndex + 1];
        CommonOps_DDRM.extract(this, start, end, start, end, blockToPack, destRow, destCol);
    }

    /**
     * Inverts every diagonal block with {@code solver} and stores each result
     * in the block of the same index of {@code matrixToPack}.
     * Passing {@code this} as {@code matrixToPack} inverts the blocks in place.
     *
     * <p>The result of {@code solver.setA(...)} is not checked; the caller is
     * responsible for supplying invertible blocks.
     *
     * @param solver       solver used to invert each block
     * @param matrixToPack block-diagonal matrix that receives the inverted blocks
     */
    public void packInverse(LinearSolverDense<DMatrixRMaj> solver, BlockDiagSquareMatrix matrixToPack) {
        for (int block = 0; block < this.blockSizes.length; block++) {
            matrixToPack.setBlock(invertBlockIntoScratch(solver, block), block);
        }
    }

    /**
     * Inverts every diagonal block with {@code solver} and writes the result
     * into the dense matrix {@code matrixToPack}, which is reshaped to the
     * dimensions of this matrix and has all entries outside the blocks set to zero.
     *
     * <p>The result of {@code solver.setA(...)} is not checked; the caller is
     * responsible for supplying invertible blocks.
     *
     * @param solver       solver used to invert each block
     * @param matrixToPack dense matrix that receives the block-wise inverse
     */
    public void packInverse(LinearSolverDense<DMatrixRMaj> solver, DMatrixRMaj matrixToPack) {
        // Invert all blocks into their scratch matrices before touching the
        // output, so the source data is still intact even when matrixToPack
        // refers to this very matrix.
        for (int block = 0; block < this.blockSizes.length; block++) {
            invertBlockIntoScratch(solver, block);
        }

        // Reshape BEFORE zeroing: zero() only clears the elements of the
        // current shape, so zeroing first could leave stale backing-array
        // values in the off-diagonal entries once the matrix is resized.
        matrixToPack.reshape(this.numRows, this.numCols);
        matrixToPack.zero();

        for (int block = 0; block < this.blockSizes.length; block++) {
            setBlock(this.tmpMatrix[block], block, matrixToPack);
        }
    }

    /**
     * Extracts block {@code blockIndex} into its scratch matrix and replaces
     * the scratch content with the inverse computed by {@code solver}.
     *
     * @return the scratch matrix of that block, now holding the inverse
     */
    private DMatrixRMaj invertBlockIntoScratch(LinearSolverDense<DMatrixRMaj> solver, int blockIndex) {
        DMatrixRMaj scratch = this.tmpMatrix[blockIndex];
        int size = this.blockSizes[blockIndex];
        scratch.reshape(size, size);
        packBlock(scratch, blockIndex, 0, 0);
        solver.setA(scratch);
        solver.invert(scratch);
        return scratch;
    }

    /**
     * Computes {@code c = this * b^T}, using only the entries of this matrix
     * that lie inside the diagonal blocks.
     *
     * <p>{@code c} is not reshaped; entry {@code (row, col)} is written for
     * every row of this matrix and every {@code col < c.numCols}.
     *
     * @param b right-hand operand, used transposed
     * @param c matrix that receives the product
     */
    public void multTransB(DMatrixRMaj b, DMatrixRMaj c) {
        for (int block = 0; block < this.blockSizes.length; block++) {
            int start = this.blockStarts[block];
            int end = this.blockStarts[block + 1];
            int size = this.blockSizes[block];

            for (int row = start; row < end; row++) {
                // Linear index of the first in-block element of this row.
                int thisOffset = getIndex(row, start);

                for (int col = 0; col < c.numCols; col++) {
                    // Row `col` of b is column `col` of b^T; only the columns
                    // covered by the current block contribute to the product.
                    int bOffset = b.getIndex(col, start);
                    double sum = 0.0d;
                    for (int k = 0; k < size; k++) {
                        sum += this.data[thisOffset + k] * b.data[bOffset + k];
                    }
                    c.set(row, col, sum);
                }
            }
        }
    }

    /**
     * Computes {@code c = alpha * this * b} block by block: for each diagonal
     * block, the matching rows of {@code b} are multiplied by the block and
     * the product is inserted into the same rows of {@code c}.
     *
     * <p>{@code c} itself is not reshaped here; the per-block product buffer is
     * sized using {@code c.numCols}, so {@code c} is presumably expected to
     * already have the dimensions of the product.
     *
     * @param alpha scalar factor applied to the product
     * @param b     right-hand operand
     * @param c     matrix that receives the product
     */
    public void mult(double alpha, DMatrixRMaj b, DMatrixRMaj c) {
        for (int block = 0; block < this.blockSizes.length; block++) {
            int size = this.blockSizes[block];
            int start = this.blockStarts[block];
            int end = this.blockStarts[block + 1];
            DMatrixRMaj blockMatrix = this.tmpMatrix[block];

            blockMatrix.reshape(size, size);
            packBlock(blockMatrix, block, 0, 0);

            this.multTempB.reshape(size, b.numCols);
            this.multTempC.reshape(size, c.numCols);

            // Rows [start, end) of b are the only ones this block acts on.
            CommonOps_DDRM.extract(b, start, end, 0, b.numCols, this.multTempB, 0, 0);
            CommonOps_DDRM.mult(alpha, blockMatrix, this.multTempB, this.multTempC);
            CommonOps_DDRM.insert(this.multTempC, c, start, 0);
        }
    }

    /**
     * Small manual demo: builds a matrix with a 1x1 and a 2x2 block, inverts
     * it in place and prints the blocks and the full matrix.
     */
    public static void main(String[] args) {
        BlockDiagSquareMatrix matrix = new BlockDiagSquareMatrix(1, 2);
        DMatrixRMaj firstBlock = new DMatrixRMaj(1, 1, true, new double[]{1.0d});
        DMatrixRMaj secondBlock = new DMatrixRMaj(2, 2, true, new double[]{2.0d, 3.0d, 4.0d, 5.0d});
        matrix.setBlock(firstBlock, 0);
        matrix.setBlock(secondBlock, 1);
        System.out.println(matrix);

        // In-place block-wise inversion (source and destination are the same matrix).
        matrix.packInverse(LinearSolverFactory_DDRM.general(matrix.numRows, matrix.numCols), matrix);

        firstBlock.zero();
        secondBlock.zero();
        matrix.packBlock(firstBlock, 0, 0, 0);
        matrix.packBlock(secondBlock, 1, 0, 0);
        System.out.println(firstBlock);
        System.out.println(secondBlock);
        System.out.println("m=\n" + matrix);
    }
}
