package cc.mallet.grmm.types;

import cc.mallet.grmm.inference.Utils;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Randoms;
import gnu.trove.TObjectIntHashMap;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Serializable;
import java.io.StringWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:cc/mallet/grmm/types/Assignment.class */
/**
 * A table of values for a set of variables. Each row is an {@code Object[]} holding one
 * value per variable; rows added through the primitive-array overloads hold boxed
 * {@code Integer} or {@code Double} values.
 *
 * <p>An assignment is also a factor: {@link #lookupValueInternal(int)} returns the number
 * of rows whose single index equals the requested index, multiplied by {@code scale}.
 */
public class Assignment extends AbstractFactor implements Serializable {
    /** Maps each variable to the column that holds its value in every row. */
    transient TObjectIntHashMap var2idx;

    /** The rows; every element is an {@code Object[]} with one entry per variable. */
    transient ArrayList values;

    /** Multiplier applied to the row count when a value is looked up. */
    double scale;

    private static final long serialVersionUID = 1;

    /** Stream format version written by {@code writeObject}; {@code scale} is stored from version 2 on. */
    private static final int SERIAL_VERSION = 2;

    /** Creates an assignment with no variables, no rows and a scale of 1. */
    public Assignment() {
        super(new HashVarSet());
        this.scale = 1.0d;
        this.var2idx = new TObjectIntHashMap();
        this.values = new ArrayList();
    }

    /** Creates a one-row assignment giving {@code variable} the integer value {@code i}. */
    public Assignment(Variable variable, int i) {
        this();
        addRow(new Variable[]{variable}, new int[]{i});
    }

    /** Creates a one-row assignment giving {@code variable} the double value {@code d}. */
    public Assignment(Variable variable, double d) {
        this();
        addRow(new Variable[]{variable}, new double[]{d});
    }

    /** Creates a one-row assignment pairing {@code variableArr[k]} with {@code iArr[k]}. */
    public Assignment(Variable[] variableArr, int[] iArr) {
        this.scale = 1.0d;
        this.var2idx = new TObjectIntHashMap(variableArr.length);
        this.values = new ArrayList();
        addRow(variableArr, iArr);
    }

    /** Creates a one-row assignment pairing {@code variableArr[k]} with {@code dArr[k]}. */
    public Assignment(Variable[] variableArr, double[] dArr) {
        this.scale = 1.0d;
        this.var2idx = new TObjectIntHashMap(variableArr.length);
        this.values = new ArrayList();
        addRow(variableArr, dArr);
    }

    /**
     * Creates a one-row assignment from a list of variables and their integer values.
     *
     * @param list the variables; every element must be a {@code Variable}
     * @param iArr the values, in the same order as {@code list}
     */
    public Assignment(List list, int[] iArr) {
        this.scale = 1.0d;
        this.var2idx = new TObjectIntHashMap(list.size());
        this.values = new ArrayList();
        addRow((Variable[]) list.toArray(new Variable[0]), iArr);
    }

    /**
     * Creates a one-row assignment over the variables {@code factorGraph.get(0)} through
     * {@code factorGraph.get(numVariables() - 1)}, paired positionally with {@code iArr}.
     */
    public Assignment(FactorGraph factorGraph, int[] iArr) {
        this.scale = 1.0d;
        this.var2idx = new TObjectIntHashMap(factorGraph.numVariables());
        this.values = new ArrayList();
        Variable[] variableArr = new Variable[factorGraph.numVariables()];
        for (int i = 0; i < variableArr.length; i++) {
            variableArr[i] = factorGraph.get(i);
        }
        addRow(variableArr, iArr);
    }

    /**
     * Combines two assignments row by row into one assignment over the union of their
     * variables.
     *
     * <p>If either argument has no rows, a duplicate of the other is returned. Otherwise
     * both must have the same number of rows, and row {@code r} of the result takes its
     * values from row {@code r} of each argument.
     *
     * @throws IllegalArgumentException if the row counts differ, or if a variable present
     *         in both assignments has different values in corresponding rows
     */
    public static Assignment union(Assignment assignment, Assignment assignment2) {
        if (assignment.numRows() == 0) {
            return (Assignment) assignment2.duplicate();
        }
        if (assignment2.numRows() == 0) {
            return (Assignment) assignment.duplicate();
        }
        if (assignment.numRows() != assignment2.numRows()) {
            throw new IllegalArgumentException("Number of rows not equal.");
        }

        HashVarSet allVars = new HashVarSet();
        allVars.addAll(assignment.vars);
        allVars.addAll(assignment2.vars);
        Variable[] columns = allVars.toVariableArray();

        Assignment result = new Assignment();
        int rowCount = assignment2.numRows();
        for (int row = 0; row < rowCount; row++) {
            Object[] merged = new Object[columns.length];
            for (int col = 0; col < columns.length; col++) {
                Variable variable = columns[col];
                // Read by explicit row index: the single-row accessors throw as soon as
                // an assignment holds more than one row.
                if (!assignment.containsVar(variable)) {
                    merged[col] = assignment2.getObject(row, variable);
                } else if (!assignment2.containsVar(variable)) {
                    merged[col] = assignment.getObject(row, variable);
                } else {
                    Object first = assignment.getObject(row, variable);
                    Object second = assignment2.getObject(row, variable);
                    // Null-safe: addVar() pads existing rows with null entries.
                    boolean same = (first == null) ? (second == null) : first.equals(second);
                    if (!same) {
                        throw new IllegalArgumentException("Assignments don't match on intersection.\n  ASSN1[" + variable + "] = " + first + "\n  ASSN2[" + variable + "] = " + second);
                    }
                    merged[col] = first;
                }
            }
            result.addRow(columns, merged);
        }
        return result;
    }

    /** Returns {@code assignment} marginalized onto {@code varSet}. */
    public static Assignment restriction(Assignment assignment, VarSet varSet) {
        return (Assignment) assignment.marginalize(varSet);
    }

    /**
     * Returns row {@code i} as a one-row assignment.
     *
     * <p>The result wraps this assignment's variables in an {@code UnmodifiableVarSet} and
     * holds the very same value array as row {@code i} (no copy is made), so writing a
     * value through the result also changes this assignment.
     */
    public Assignment getRow(int i) {
        Assignment assignment = new Assignment();
        assignment.var2idx = this.var2idx.clone();
        assignment.vars = new UnmodifiableVarSet(this.vars);
        assignment.addRow(rowAt(i));
        return assignment;
    }

    /** Appends a row of integer values; see {@link #addRow(Variable[], Object[])}. */
    public void addRow(Variable[] variableArr, int[] iArr) {
        addRow(variableArr, boxArray(iArr));
    }

    /** Appends a row of double values; see {@link #addRow(Variable[], Object[])}. */
    public void addRow(Variable[] variableArr, double[] dArr) {
        addRow(variableArr, boxArray(dArr));
    }

    /**
     * Appends a row. When this assignment has no rows yet, {@code variableArr} defines its
     * variables; otherwise {@code variableArr} must list the existing variables in order.
     *
     * @throws IllegalArgumentException if the variables are incompatible or
     *         {@code objArr} has the wrong length
     */
    public void addRow(Variable[] variableArr, Object[] objArr) {
        checkAssignmentsMatch(variableArr);
        addRow(objArr);
    }

    /**
     * Appends {@code objArr} as a row without copying it.
     *
     * @throws IllegalArgumentException if {@code objArr.length} differs from
     *         {@link #numVariables()}
     */
    public void addRow(Object[] objArr) {
        if (objArr.length != numVariables()) {
            throw new IllegalArgumentException("Wrong number of variables when adding to " + this + "\nwas:\n" + Arrays.toString(objArr));
        }
        this.values.add(objArr);
    }

    /**
     * Appends every row of {@code assignment}, reordering values to this assignment's
     * column order. If this assignment has no variables yet it adopts those of
     * {@code assignment}.
     *
     * @throws IllegalArgumentException if the two assignments cover different variables
     */
    public void addRow(Assignment assignment) {
        checkAssignmentsMatch(assignment);
        // Snapshot the count so that a.addRow(a) terminates instead of chasing its own tail.
        int rowCount = assignment.numRows();
        int width = numVariables();
        for (int row = 0; row < rowCount; row++) {
            Object[] copy = new Object[width];
            for (int col = 0; col < width; col++) {
                copy[col] = assignment.getObject(row, getVariable(col));
            }
            addRow(copy);
        }
    }

    /**
     * Adopts the variables of {@code assignment} when this one has none; otherwise
     * verifies that both cover the same number of variables and that each of this
     * assignment's variables occurs in {@code assignment}. Column order is not compared.
     */
    private void checkAssignmentsMatch(Assignment assignment) {
        if (numVariables() == 0) {
            setVariables(assignment.vars.toVariableArray());
            return;
        }
        if (numVariables() != assignment.numVariables()) {
            throw new IllegalArgumentException("Attempt to add row with non-matching variables.\n  This has vars: " + varSet() + "\n  Other has vars:" + assignment.varSet());
        }
        for (int i = 0; i < numVariables(); i++) {
            if (!assignment.containsVar(getVariable(i))) {
                throw new IllegalArgumentException("Attempt to add row with non-matching variables");
            }
        }
    }

    /** Installs {@code variableArr} when there are no rows yet; otherwise validates it. */
    private void checkAssignmentsMatch(Variable[] variableArr) {
        if (numRows() == 0) {
            setVariables(variableArr);
        } else {
            checkVariables(variableArr);
        }
    }

    /**
     * Verifies that {@code variableArr} lists exactly this assignment's variables, in
     * order, comparing by reference.
     *
     * @throws IllegalArgumentException on any mismatch, including a different length
     */
    private void checkVariables(Variable[] variableArr) {
        if (variableArr.length != numVariables()) {
            throw new IllegalArgumentException("Attempt to add row with incompatible variables.");
        }
        for (int i = 0; i < variableArr.length; i++) {
            if (variableArr[i] != this.vars.get(i)) {
                throw new IllegalArgumentException("Attempt to add row with incompatible variables.");
            }
        }
    }

    /** Adds the variables to {@code vars} and records each one's position in {@code var2idx}. */
    private void setVariables(Variable[] variableArr) {
        this.vars.addAll(Arrays.asList(variableArr));
        for (int i = 0; i < variableArr.length; i++) {
            this.var2idx.put(variableArr[i], i);
        }
    }

    /** Boxes each element of {@code iArr} as an {@code Integer}. */
    private Object[] boxArray(int[] iArr) {
        Object[] boxed = new Object[iArr.length];
        for (int i = 0; i < boxed.length; i++) {
            boxed[i] = Integer.valueOf(iArr[i]);
        }
        return boxed;
    }

    /** Boxes each element of {@code dArr} as a {@code Double}. */
    private Object[] boxArray(double[] dArr) {
        Object[] boxed = new Object[dArr.length];
        for (int i = 0; i < boxed.length; i++) {
            boxed[i] = Double.valueOf(dArr[i]);
        }
        return boxed;
    }

    /** Returns the value array stored for row {@code i}; the single place the raw list is cast. */
    private Object[] rowAt(int i) {
        return (Object[]) this.values.get(i);
    }

    /** Returns the number of rows. */
    public int numRows() {
        return this.values.size();
    }

    /**
     * Returns the integer value of {@code variable} in the only row.
     *
     * @throws IllegalArgumentException if this assignment does not have exactly one row
     */
    public int get(Variable variable) {
        if (numRows() != 1) {
            throw new IllegalArgumentException("Attempt to call get() with no row specified: " + this);
        }
        return get(0, variable);
    }

    /**
     * Returns the double value of {@code variable} in the only row.
     *
     * @throws IllegalArgumentException if this assignment does not have exactly one row
     */
    public double getDouble(Variable variable) {
        if (numRows() != 1) {
            throw new IllegalArgumentException("Attempt to call getDouble() with no row specified: " + this);
        }
        return getDouble(0, variable);
    }

    /**
     * Returns the value of {@code variable} in the only row.
     *
     * @throws IllegalArgumentException if this assignment does not have exactly one row
     */
    public Object getObject(Variable variable) {
        if (numRows() != 1) {
            throw new IllegalArgumentException("Attempt to call getObject() with no row specified: " + this);
        }
        return getObject(0, variable);
    }

    /**
     * Returns the integer value of {@code variable} in row {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code variable} has no column here
     * @throws ClassCastException if the stored value is not an {@code Integer}
     */
    public int get(int i, Variable variable) {
        int col = colOfVar(variable, false);
        if (col == -1) {
            throw new IndexOutOfBoundsException("Assignment does not give a value for variable " + variable);
        }
        return ((Integer) rowAt(i)[col]).intValue();
    }

    /**
     * Returns the double value of {@code variable} in row {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code variable} has no column here
     * @throws ClassCastException if the stored value is not a {@code Double}
     */
    public double getDouble(int i, Variable variable) {
        int col = colOfVar(variable, false);
        if (col == -1) {
            throw new IndexOutOfBoundsException("Assignment does not give a value for variable " + variable);
        }
        return ((Double) rowAt(i)[col]).doubleValue();
    }

    /**
     * Returns the value of {@code variable} in row {@code i}.
     *
     * @throws IllegalArgumentException if {@code variable} has no column here
     */
    public Object getObject(int i, Variable variable) {
        Object[] row = rowAt(i);
        int col = colOfVar(variable, false);
        if (col < 0) {
            throw new IllegalArgumentException("Variable " + variable + " does not exist in this assignment.");
        }
        return row[col];
    }

    /** Returns the variable at position {@code i} of {@code vars}. */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public Variable getVariable(int i) {
        return this.vars.get(i);
    }

    /** Returns this assignment's variables as a new array. */
    public Variable[] getVars() {
        return (Variable[]) this.vars.toArray(new Variable[0]);
    }

    /** Returns the number of variables (not rows); same as {@link #numVariables()}. */
    public int size() {
        return numVariables();
    }

    /**
     * Builds a one-row assignment over {@code varSet} from a single index. The index is
     * expanded by {@code Matrixn.singleToIndices}, with each variable's outcome count as
     * the size of its dimension.
     */
    public static Assignment makeFromSingleIndex(VarSet varSet, int i) {
        int size = varSet.size();
        Variable[] variableArray = varSet.toVariableArray();
        int[] outcomes = new int[size];
        int[] sizes = new int[size];
        for (int v = 0; v < size; v++) {
            sizes[v] = variableArray[v].getNumOutcomes();
        }
        Matrixn.singleToIndices(i, outcomes, sizes);
        return new Assignment(variableArray, outcomes);
    }

    /**
     * Returns the single index of the only row, or -1 when there are no rows.
     *
     * @throws IllegalArgumentException if there is more than one row
     */
    public int singleIndex() {
        int rowCount = numRows();
        if (rowCount == 0) {
            return -1;
        }
        if (rowCount > 1) {
            throw new IllegalArgumentException("No row specified.");
        }
        return singleIndex(0);
    }

    /** Throws {@code IllegalArgumentException} unless this assignment has exactly one row. */
    private void checkIsSingleRow() {
        if (numRows() != 1) {
            throw new IllegalArgumentException("No row specified.");
        }
    }

    /**
     * Returns the single index of row {@code i}, computed by {@code Matrixn.singleIndex}
     * from the variables' outcome counts and the row's integer values.
     */
    public int singleIndex(int i) {
        int[] sizes = new int[numVariables()];
        for (int v = 0; v < numVariables(); v++) {
            sizes[v] = this.vars.get(v).getNumOutcomes();
        }
        return Matrixn.singleIndex(sizes, toIntArray(i));
    }

    /** Returns the number of variables. */
    public int numVariables() {
        return this.vars.size();
    }

    /** Unboxes row {@code i}, whose entries must all be {@code Integer}s. */
    private int[] toIntArray(int i) {
        int[] unboxed = new int[numVariables()];
        Object[] row = rowAt(i);
        for (int col = 0; col < row.length; col++) {
            unboxed[col] = ((Integer) row[col]).intValue();
        }
        return unboxed;
    }

    /** Unboxes row {@code i}, whose entries must all be {@code Double}s. */
    public double[] toDoubleArray(int i) {
        double[] unboxed = new double[numVariables()];
        Object[] row = rowAt(i);
        for (int col = 0; col < row.length; col++) {
            unboxed[col] = ((Double) row[col]).doubleValue();
        }
        return unboxed;
    }

    /**
     * Returns a copy with its own variable set, column map and row arrays, and the same
     * scale. The boxed values inside the rows are shared.
     */
    @Override // cc.mallet.grmm.types.Factor
    public Factor duplicate() {
        Assignment copy = new Assignment();
        copy.vars = new HashVarSet(this.vars);
        copy.var2idx = this.var2idx.clone();
        copy.values = new ArrayList(this.values.size());
        for (int row = 0; row < this.values.size(); row++) {
            copy.values.add(rowAt(row).clone());
        }
        copy.scale = this.scale;
        return copy;
    }

    /** Dumps this assignment to {@code System.out}. */
    public void dump() {
        dump(new PrintWriter((Writer) new OutputStreamWriter(System.out), true));
    }

    /**
     * Writes a header line with the variable set, a line listing the variables, and then
     * one line of space-separated values per row.
     */
    public void dump(PrintWriter printWriter) {
        printWriter.print("ASSIGNMENT ");
        printWriter.println(varSet());
        for (int col = 0; col < this.var2idx.size(); col++) {
            printWriter.print(this.vars.get(col));
            printWriter.print(" ");
        }
        printWriter.println();
        for (int row = 0; row < numRows(); row++) {
            for (int col = 0; col < this.var2idx.size(); col++) {
                printWriter.print(getObject(row, this.vars.get(col)));
                printWriter.print(" ");
            }
            printWriter.println();
        }
    }

    /**
     * Prints each variable with its integer value to {@code System.out}; requires exactly
     * one row because it uses {@link #get(Variable)}.
     */
    public void dumpNumeric() {
        for (int col = 0; col < this.var2idx.size(); col++) {
            Variable variable = this.vars.get(col);
            System.out.println(variable + " " + get(variable));
        }
    }

    /** Returns whether {@code variable} has a column in this assignment. */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public boolean containsVar(Variable variable) {
        return colOfVar(variable, false) != -1;
    }

    /**
     * Sets the integer value of {@code variable} in the only row.
     *
     * @throws IllegalArgumentException if this assignment does not have exactly one row
     */
    public void setValue(Variable variable, int i) {
        checkIsSingleRow();
        setValue(0, variable, i);
    }

    /**
     * Sets the integer value of {@code variable} in row {@code i}, adding a column for the
     * variable if it has none yet.
     */
    public void setValue(int i, Variable variable, int i2) {
        // Resolve the column BEFORE fetching the row: adding a new variable replaces every
        // row array with a wider one, and a reference fetched earlier would be stale and
        // one slot too short.
        int col = colOfVar(variable, true);
        rowAt(i)[col] = Integer.valueOf(i2);
    }

    /**
     * Sets the double value of {@code variable} in row {@code i}, adding a column for the
     * variable if it has none yet.
     */
    public void setDouble(int i, Variable variable, double d) {
        // Column first, row second; see setValue(int, Variable, int).
        int col = colOfVar(variable, true);
        rowAt(i)[col] = Double.valueOf(d);
    }

    /**
     * Returns the column of {@code variable}. When the variable is unknown, a column is
     * added if {@code z} is true; otherwise -1 is returned.
     */
    private int colOfVar(Variable variable, boolean z) {
        if (this.var2idx.containsKey(variable)) {
            return this.var2idx.get(variable);
        }
        if (z) {
            return addVar(variable);
        }
        return -1;
    }

    /**
     * Appends {@code variable} as a new last column and replaces every row with a copy
     * that is one slot longer; the new slot is left {@code null}.
     *
     * @return the index of the new column
     */
    private int addVar(Variable variable) {
        int oldWidth = this.vars.size();
        this.vars.add(variable);
        this.var2idx.put(variable, oldWidth);
        for (int row = 0; row < numRows(); row++) {
            Object[] widened = new Object[oldWidth + 1];
            System.arraycopy(rowAt(row), 0, widened, 0, oldWidth);
            this.values.set(row, widened);
        }
        return oldWidth;
    }

    /**
     * Replaces row {@code i} with the values of row {@code i} of {@code assignment},
     * matched up by variable so that a different column order in {@code assignment}
     * cannot scramble the row.
     *
     * <p>NOTE(review): the row read from {@code assignment} is row {@code i}, not row 0,
     * so a single-row argument only works for {@code i == 0} — confirm this is intended.
     *
     * @throws IllegalArgumentException if the two assignments cover different variables
     */
    public void setRow(int i, Assignment assignment) {
        checkAssignmentsMatch(assignment);
        int width = numVariables();
        Object[] replacement = new Object[width];
        for (int col = 0; col < width; col++) {
            replacement[col] = assignment.getObject(i, getVariable(col));
        }
        this.values.set(i, replacement);
    }

    /**
     * Replaces row {@code i} with the boxed contents of {@code iArr}.
     *
     * @throws IllegalArgumentException if {@code iArr.length} differs from
     *         {@link #numVariables()}
     */
    public void setRow(int i, int[] iArr) {
        if (iArr.length != numVariables()) {
            throw new IllegalArgumentException("Wrong number of values for row " + i + ": expected " + numVariables() + ", got " + iArr.length);
        }
        this.values.set(i, boxArray(iArr));
    }

    /** Delegates to the table form of this assignment. */
    @Override // cc.mallet.grmm.types.AbstractFactor
    protected Factor extractMaxInternal(VarSet varSet) {
        return asTable().extractMax(varSet);
    }

    /** Returns the number of rows whose single index equals {@code i}, times {@code scale}. */
    @Override // cc.mallet.grmm.types.AbstractFactor
    protected double lookupValueInternal(int i) {
        int matches = 0;
        for (int row = 0; row < numRows(); row++) {
            if (singleIndex(row) == i) {
                matches++;
            }
        }
        return matches * this.scale;
    }

    /**
     * Returns a new assignment holding, for every row, only the values of the variables
     * in {@code varSet}; the scale is carried over.
     */
    @Override // cc.mallet.grmm.types.AbstractFactor
    protected Factor marginalizeInternal(VarSet varSet) {
        Assignment result = new Assignment();
        Variable[] kept = varSet.toVariableArray();
        for (int row = 0; row < numRows(); row++) {
            Object[] projected = new Object[kept.length];
            for (int col = 0; col < varSet.size(); col++) {
                projected[col] = getObject(row, varSet.get(col));
            }
            result.addRow(kept, projected);
        }
        result.scale = this.scale;
        return result;
    }

    /** Compares the table form of this assignment with {@code factor}. */
    @Override // cc.mallet.grmm.types.Factor
    public boolean almostEquals(Factor factor, double d) {
        return asTable().almostEquals(factor, d);
    }

    /** Always returns {@code false}. */
    @Override // cc.mallet.grmm.types.Factor
    public boolean isNaN() {
        return false;
    }

    /** Sets {@code scale} to {@code 1 / numRows()} and returns this assignment. */
    @Override // cc.mallet.grmm.types.Factor
    public Factor normalize() {
        this.scale = 1.0d / numRows();
        return this;
    }

    /**
     * Returns a new one-row assignment holding a copy of a row chosen with
     * {@code randoms.nextInt(numRows())}.
     */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public Assignment sample(Randoms randoms) {
        Object[] chosen = rowAt(randoms.nextInt(numRows()));
        Assignment sampled = new Assignment();
        // Clone so that later writes to the sample cannot leak back into this assignment.
        sampled.addRow((Variable[]) this.vars.toArray(new Variable[numVariables()]), chosen.clone());
        return sampled;
    }

    /** Returns the output of {@link #dump(PrintWriter)} as a string. */
    @Override // cc.mallet.grmm.types.Factor
    public String dumpToString() {
        StringWriter stringWriter = new StringWriter();
        dump(new PrintWriter(stringWriter));
        return stringWriter.toString();
    }

    /** Not supported; always throws {@code UnsupportedOperationException}. */
    @Override // cc.mallet.grmm.types.Factor
    public Factor slice(Assignment assignment) {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds a sparse table factor with one entry of 1.0 per row, located at that row's
     * single index.
     *
     * <p>NOTE(review): {@code scale} is not applied here although
     * {@link #lookupValueInternal(int)} multiplies by it, and rows that share a single
     * index are passed on as repeated indices — confirm both against
     * {@code SparseMatrixn} and the callers before relying on this for scaled or
     * duplicated rows.
     */
    @Override // cc.mallet.grmm.types.AbstractFactor, cc.mallet.grmm.types.Factor
    public AbstractTableFactor asTable() {
        Variable[] variableArr = (Variable[]) this.vars.toArray(new Variable[0]);
        int[] indices = new int[numRows()];
        double[] weights = new double[numRows()];
        for (int row = 0; row < numRows(); row++) {
            indices[row] = singleIndex(row);
            weights[row] = 1.0d;
        }
        return new TableFactor(variableArr, new SparseMatrixn(Utils.toSizesArray(variableArr), indices, weights));
    }

    /** Returns a list with one single-row assignment per row, each built by {@link #getRow(int)}. */
    public List asList() {
        ArrayList rows = new ArrayList(numRows());
        for (int row = 0; row < numRows(); row++) {
            rows.add(getRow(row));
        }
        return rows;
    }

    /** Returns a new assignment holding rows {@code i} (inclusive) to {@code i2} (exclusive). */
    public Assignment subAssn(int i, int i2) {
        Assignment result = new Assignment();
        for (int row = i; row < i2; row++) {
            result.addRow(getRow(row));
        }
        return result;
    }

    /** Returns the integer value of {@code variable} in every row, in row order. */
    public int[] getColumnInt(Variable variable) {
        int[] column = new int[numRows()];
        for (int row = 0; row < column.length; row++) {
            column[row] = get(row, variable);
        }
        return column;
    }

    /**
     * Reads the format produced by {@code writeObject}: version, variable count, the
     * variables, row count, the rows, and (from version 2 on) the scale.
     *
     * <p>NOTE(review): only {@code var2idx} and {@code values} are rebuilt here;
     * {@code vars} is presumably restored by the superclass — confirm.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        int version = objectInputStream.readInt();
        int variableCount = objectInputStream.readInt();
        this.var2idx = new TObjectIntHashMap(variableCount);
        for (int col = 0; col < variableCount; col++) {
            this.var2idx.put((Variable) objectInputStream.readObject(), col);
        }
        int rowCount = objectInputStream.readInt();
        this.values = new ArrayList(rowCount);
        for (int row = 0; row < rowCount; row++) {
            this.values.add((Object[]) objectInputStream.readObject());
        }
        // Streams older than version 2 carry no scale.
        this.scale = version >= 2 ? objectInputStream.readDouble() : 1.0d;
    }

    /** Writes the version, the variables in column order, the rows, and the scale. */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(SERIAL_VERSION);
        objectOutputStream.writeInt(numVariables());
        for (int col = 0; col < numVariables(); col++) {
            objectOutputStream.writeObject(getVariable(col));
        }
        objectOutputStream.writeInt(numRows());
        for (int row = 0; row < numRows(); row++) {
            objectOutputStream.writeObject(rowAt(row));
        }
        objectOutputStream.writeDouble(this.scale);
    }
}
