package edu.umd.cloud9.util.cfd;

import edu.umd.cloud9.io.pair.PairOfInts;
import edu.umd.cloud9.util.fd.Int2IntFrequencyDistributionOpen;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.util.Iterator;

/* loaded from: input_file:edu/umd/cloud9/util/cfd/Int2IntConditionalFrequencyDistributionOpen.class */
/**
 * Conditional frequency distribution over {@code int} keys and {@code int} conditions, backed by one
 * {@link Int2IntFrequencyDistributionOpen} per condition.
 *
 * <p>Besides the per-condition distributions, this class incrementally maintains two aggregates on
 * every {@link #set}: the marginal count of each key (summed over all conditions) and the sum of all
 * counts. {@link #check()} recomputes both aggregates from the conditional distributions and throws
 * if they disagree with the stored values.
 *
 * <p>NOTE(review): no synchronization is performed here; treat instances as not thread-safe unless
 * callers synchronize externally.
 */
public class Int2IntConditionalFrequencyDistributionOpen implements Int2IntConditionalFrequencyDistribution {
    /** Per-condition distributions, keyed by condition. Stored values are never null. */
    private final Int2ObjectMap<Int2IntFrequencyDistributionOpen> distributions =
            new Int2ObjectOpenHashMap<Int2IntFrequencyDistributionOpen>();

    /** Count of each key summed over all conditions. */
    private final Int2IntFrequencyDistributionOpen marginals = new Int2IntFrequencyDistributionOpen();

    /** Sum of every count stored in every conditional distribution. */
    private long sumOfAllFrequencies = 0;

    /**
     * Sets the count of {@code key} under condition {@code cond} to {@code value}, adjusting the
     * marginal count of {@code key} and the overall sum by the difference from the previous count.
     *
     * @param key the key whose count is set
     * @param cond the condition under which the key is counted
     * @param value the new count
     */
    @Override
    public void set(int key, int cond, int value) {
        // Single lookup: the map never stores null values, so null means "no such condition".
        Int2IntFrequencyDistributionOpen dist = this.distributions.get(cond);
        int previous;
        if (dist == null) {
            dist = new Int2IntFrequencyDistributionOpen();
            this.distributions.put(cond, dist);
            previous = 0;
        } else {
            previous = dist.get(key);
        }
        dist.set(key, value);
        // The aggregates move by the delta between the new and previous count.
        this.marginals.increment(key, value - previous);
        this.sumOfAllFrequencies += (long) value - previous;
    }

    /**
     * Increments the count of {@code key} under condition {@code cond} by one.
     *
     * @param key the key whose count is incremented
     * @param cond the condition under which the key is counted
     */
    @Override
    public void increment(int key, int cond) {
        increment(key, cond, 1);
    }

    /**
     * Increments the count of {@code key} under condition {@code cond} by {@code value}.
     *
     * @param key the key whose count is incremented
     * @param cond the condition under which the key is counted
     * @param value the amount to add to the current count
     */
    @Override
    public void increment(int key, int cond, int value) {
        set(key, cond, get(key, cond) + value);
    }

    /**
     * Returns the count of {@code key} under condition {@code cond}.
     *
     * @param key the key to look up
     * @param cond the condition to look up
     * @return the stored count, or 0 if no distribution exists for {@code cond}
     */
    @Override
    public int get(int key, int cond) {
        Int2IntFrequencyDistributionOpen dist = this.distributions.get(cond);
        return dist == null ? 0 : dist.get(key);
    }

    /**
     * Returns the marginal count of {@code key}, i.e. its count summed over all conditions.
     *
     * @param key the key to look up
     * @return the marginal count as stored in the marginal distribution
     */
    @Override
    public int getMarginalCount(int key) {
        return this.marginals.get(key);
    }

    /**
     * Returns the distribution of keys under condition {@code cond}.
     *
     * <p>For a known condition this is the internal distribution object itself; modifying it directly
     * bypasses the marginal and total bookkeeping done by {@link #set}. For an unknown condition a new,
     * detached empty distribution is returned; it is not stored, so writes to it are not retained.
     *
     * @param cond the condition to look up
     * @return the conditional distribution for {@code cond}, never null
     */
    @Override
    public Int2IntFrequencyDistributionOpen getConditionalDistribution(int cond) {
        Int2IntFrequencyDistributionOpen dist = this.distributions.get(cond);
        return dist != null ? dist : new Int2IntFrequencyDistributionOpen();
    }

    /**
     * Returns the sum of all counts across all conditions.
     *
     * @return the incrementally maintained total
     */
    @Override
    public long getSumOfAllFrequencies() {
        return this.sumOfAllFrequencies;
    }

    /**
     * Verifies internal consistency by recomputing the aggregates from the conditional distributions.
     *
     * <p>Checks that each conditional distribution's entries sum to its reported total, that those
     * totals add up to {@link #getSumOfAllFrequencies()}, and that the recomputed marginals agree with
     * the stored marginals in both directions.
     *
     * @throws RuntimeException if any of these invariants is violated
     */
    @Override
    public void check() {
        Int2IntFrequencyDistributionOpen recomputed = new Int2IntFrequencyDistributionOpen();
        long total = 0;
        for (Int2IntFrequencyDistributionOpen dist : this.distributions.values()) {
            long conditionalSum = 0;
            Iterator<PairOfInts> entries = dist.iterator();
            while (entries.hasNext()) {
                PairOfInts entry = entries.next();
                conditionalSum += entry.getRightElement();
                recomputed.increment(entry.getLeftElement(), entry.getRightElement());
            }
            if (conditionalSum != dist.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            total += dist.getSumOfCounts();
        }
        if (total != getSumOfAllFrequencies()) {
            throw new RuntimeException("Internal Error! Got " + total + ", Expected " + getSumOfAllFrequencies());
        }

        // Every recomputed marginal must match the stored marginal.
        Iterator<PairOfInts> recomputedEntries = recomputed.iterator();
        while (recomputedEntries.hasNext()) {
            PairOfInts entry = recomputedEntries.next();
            if (entry.getRightElement() != this.marginals.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }

        // Every stored marginal must match the recomputed one; this direction catches stale or
        // spurious entries in the stored marginals that no conditional distribution accounts for.
        Iterator<PairOfInts> storedEntries = this.marginals.iterator();
        while (storedEntries.hasNext()) {
            PairOfInts entry = storedEntries.next();
            if (entry.getRightElement() != recomputed.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}
