package tech.tablesaw.api.ml.classification;

import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.columns.Column;

/* loaded from: input_file:tech/tablesaw/api/ml/classification/CategoryConfusionMatrix.class */
/**
 * Confusion matrix for a classifier whose labels come from a {@link CategoryColumn}.
 *
 * <p>Counts are keyed by the integer codes of the label column's dictionary. The first argument
 * of {@link #increment(Integer, Integer)} selects the row and the second selects the column.
 * {@link #toTable()} renders rows as "Predicted &lt;label&gt;" and columns as
 * "Actual &lt;label&gt;", translating each code back to its category string.
 *
 * <p>Not synchronized: {@link #increment(Integer, Integer)} does an unguarded read-modify-write
 * on a Guava {@link TreeBasedTable}.
 */
public class CategoryConfusionMatrix implements ConfusionMatrix {
    /** Raw counts: row code -> column code -> number of observations. */
    private final Table<Integer, Integer, Integer> table = TreeBasedTable.create();

    /** Column whose dictionary maps the integer codes used as keys back to label strings. */
    private final CategoryColumn labelColumn;

    /**
     * Creates an empty confusion matrix.
     *
     * @param labelColumn column whose dictionary translates the codes passed to
     *     {@link #increment(Integer, Integer)} into label strings
     * @param labels the sorted label values; accepted for API compatibility but not used by this
     *     class (label text is always resolved through {@code labelColumn})
     */
    public CategoryConfusionMatrix(CategoryColumn labelColumn, SortedSet<String> labels) {
        this.labelColumn = labelColumn;
    }

    /**
     * Adds one observation to the cell at ({@code predicted}, {@code actual}).
     *
     * @param predicted dictionary code selecting the row (rendered as "Predicted ...")
     * @param actual dictionary code selecting the column (rendered as "Actual ...")
     */
    @Override
    public void increment(Integer predicted, Integer actual) {
        Integer current = table.get(predicted, actual);
        table.put(predicted, actual, current == null ? 1 : current + 1);
    }

    /** Returns the string form of {@link #toTable()}. */
    @Override
    public String toString() {
        return toTable().toString();
    }

    /**
     * Renders the matrix as a Tablesaw table named "Confusion Matrix".
     *
     * <p>The first column holds the row headers "Predicted &lt;label&gt;" and is named
     * "n = &lt;total observations&gt;". Each further column is an {@link IntColumn} named
     * "Actual &lt;label&gt;" that holds the counts. Labels are ordered by natural String order, and
     * every label seen as either a row or a column key appears on both axes (missing cells are 0).
     *
     * @return a new table; this matrix is not modified
     */
    @Override
    public tech.tablesaw.api.Table toTable() {
        TreeBasedTable<String, String, Integer> counts = sortedTable();
        tech.tablesaw.api.Table result = tech.tablesaw.api.Table.create("Confusion Matrix");
        result.addColumn(new Column[] {new CategoryColumn("")});
        for (String label : counts.rowKeySet()) {
            result.addColumn(new Column[] {new IntColumn("Actual " + label)});
            result.column(0).appendCell("Predicted " + label);
        }

        int total = 0;
        for (String predicted : counts.rowKeySet()) {
            int columnIndex = 1;
            // sortedTable() is square, so the row-key set also gives the column order used above.
            for (String actual : counts.rowKeySet()) {
                Integer count = counts.get(predicted, actual);
                int value = count == null ? 0 : count;
                result.intColumn(columnIndex++).append(value);
                total += value;
            }
        }
        result.column(0).setName("n = " + total);
        return result;
    }

    /**
     * Builds a square, label-keyed copy of the counts.
     *
     * <p>Every code that appears as a row or column key is placed on both axes, and absent cells
     * are filled with 0. Codes are translated through the label column's dictionary using the
     * code itself, not its position among the codes present.
     *
     * @return label -> label -> count, sorted by label
     * @throws IllegalStateException if a recorded code has no entry in the label dictionary
     */
    private TreeBasedTable<String, String, Integer> sortedTable() {
        Int2ObjectMap<String> codeToLabel = labelColumn.dictionaryMap().keyToValueMap();
        SortedSet<Integer> codes = new TreeSet<>(table.rowKeySet());
        codes.addAll(table.columnKeySet());

        TreeBasedTable<String, String, Integer> result = TreeBasedTable.create();
        for (int rowCode : codes) {
            String rowLabel = labelFor(codeToLabel, rowCode);
            for (int columnCode : codes) {
                Integer count = table.get(rowCode, columnCode);
                result.put(rowLabel, labelFor(codeToLabel, columnCode), count == null ? 0 : count);
            }
        }
        return result;
    }

    /**
     * Looks up the label string for a dictionary code.
     *
     * @throws IllegalStateException if the dictionary has no entry for {@code code}
     */
    private static String labelFor(Int2ObjectMap<String> codeToLabel, int code) {
        String label = codeToLabel.get(code);
        if (label == null) {
            throw new IllegalStateException(
                    "No label in the category dictionary for code " + code);
        }
        return label;
    }

    /**
     * Returns the fraction of observations on the diagonal (row code equals column code).
     *
     * @return correct / total, or NaN when no observations have been recorded
     */
    @Override
    public double accuracy() {
        long correct = 0;
        long total = 0;
        for (Table.Cell<Integer, Integer, Integer> cell : table.cellSet()) {
            int count = cell.getValue();
            total += count;
            if (cell.getRowKey().equals(cell.getColumnKey())) {
                correct += count;
            }
        }
        return (double) correct / total;
    }
}
