package edu.umd.cloud9.util.count;

import edu.umd.cloud9.util.pair.PairOfObjectInt;
import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
import java.lang.Comparable;

/* loaded from: input_file:edu/umd/cloud9/util/count/OpenObject2IntConditionalFrequencyDistribution.class */
/**
 * Conditional frequency distribution of events of type {@code K}, backed by open-addressing hash maps.
 *
 * <p>Counts are stored per conditioning key: {@link #distributions} maps each condition to the
 * frequency distribution of events observed under it. Two aggregates are maintained incrementally
 * on every update: {@link #marginals}, the count of each event summed over all conditions, and
 * {@link #sumOfAllFrequencies}, the sum of every stored count.
 *
 * <p>Throughout this class the first key argument ({@code k}) is the event and the second
 * ({@code k2}) is the condition.
 *
 * @param <K> type of both events and conditions
 */
public class OpenObject2IntConditionalFrequencyDistribution<K extends Comparable<K>> implements Object2IntConditionalFrequencyDistribution<K> {
    /** Condition -> distribution of events observed under that condition. */
    private final Object2ObjectMap<K, OpenObject2IntFrequencyDistribution<K>> distributions = new Object2ObjectOpenHashMap<>();
    /** Event -> count summed over all conditions. */
    private final OpenObject2IntFrequencyDistribution<K> marginals = new OpenObject2IntFrequencyDistribution<>();
    /** Sum of every count stored in {@link #distributions}. */
    private long sumOfAllFrequencies = 0;

    /**
     * Sets the count of event {@code k} under condition {@code k2} to {@code i}, adjusting the
     * marginal count of {@code k} and the grand total by the difference from the previous count.
     *
     * @param k the event
     * @param k2 the condition
     * @param i the new count
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public void set(K k, K k2, int i) {
        OpenObject2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        if (conditional == null) {
            // First observation under this condition: register an empty distribution for it.
            conditional = new OpenObject2IntFrequencyDistribution<>();
            this.distributions.put(k2, conditional);
        }
        // An event not yet present is expected to report a count of 0, so the delta below
        // is simply i for new entries.
        int previous = conditional.get(k);
        conditional.set(k, i);
        this.marginals.increment(k, i - previous);
        // Widen before subtracting so the running total is not affected by int overflow.
        this.sumOfAllFrequencies += (long) i - previous;
    }

    /**
     * Adds one to the count of event {@code k} under condition {@code k2}.
     *
     * @param k the event
     * @param k2 the condition
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public void increment(K k, K k2) {
        increment(k, k2, 1);
    }

    /**
     * Adds {@code i} to the count of event {@code k} under condition {@code k2}.
     *
     * @param k the event
     * @param k2 the condition
     * @param i the amount to add
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public void increment(K k, K k2, int i) {
        // get() yields 0 for an unseen (event, condition) pair, so no special case is needed.
        set(k, k2, get(k, k2) + i);
    }

    /**
     * Returns the count of event {@code k} under condition {@code k2}.
     *
     * @param k the event
     * @param k2 the condition
     * @return the stored count, or 0 if nothing has been recorded under {@code k2}
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public int get(K k, K k2) {
        OpenObject2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        return conditional == null ? 0 : conditional.get(k);
    }

    /**
     * Returns the count of event {@code k} summed over all conditions.
     *
     * @param k the event
     * @return the marginal count of {@code k}
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public int getMarginalCount(K k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events recorded under condition {@code k}.
     *
     * <p>For a known condition this is the internal, live distribution object. For an unknown
     * condition a fresh empty distribution is returned that is not attached to this object.
     *
     * @param k the condition
     * @return the distribution for {@code k}; never {@code null}
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public OpenObject2IntFrequencyDistribution<K> getConditionalDistribution(K k) {
        OpenObject2IntFrequencyDistribution<K> conditional = this.distributions.get(k);
        return conditional != null ? conditional : new OpenObject2IntFrequencyDistribution<K>();
    }

    /**
     * Returns the sum of all counts over every event and condition.
     *
     * @return the grand total
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public long getSumOfAllFrequencies() {
        return this.sumOfAllFrequencies;
    }

    /**
     * Verifies internal consistency by recomputing every aggregate from the per-condition
     * distributions and comparing with the incrementally maintained values.
     *
     * @throws RuntimeException if a per-condition sum, the grand total, or any marginal count
     *     disagrees with the recomputed value
     */
    @Override // edu.umd.cloud9.util.count.Object2IntConditionalFrequencyDistribution
    public void check() {
        OpenObject2IntFrequencyDistribution<K> recomputedMarginals = new OpenObject2IntFrequencyDistribution<>();
        long recomputedTotal = 0;
        for (OpenObject2IntFrequencyDistribution<K> conditional : this.distributions.values()) {
            long conditionalSum = 0;
            for (PairOfObjectInt<K> event : conditional.getSortedEvents()) {
                conditionalSum += event.getRightElement();
                recomputedMarginals.increment(event.getLeftElement(), event.getRightElement());
            }
            if (conditionalSum != conditional.getSumOfFrequencies()) {
                throw new RuntimeException("Internal Error!");
            }
            recomputedTotal += conditional.getSumOfFrequencies();
        }
        if (recomputedTotal != getSumOfAllFrequencies()) {
            throw new RuntimeException("Internal Error! Got " + recomputedTotal + ", Expected " + getSumOfAllFrequencies());
        }
        // Every recomputed marginal must match the stored one ...
        for (PairOfObjectInt<K> event : recomputedMarginals.getSortedEvents()) {
            if (event.getRightElement() != this.marginals.get(event.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
        // ... and every stored marginal must match the recomputed one, which catches events
        // that exist in the stored marginals but in no conditional distribution.
        for (PairOfObjectInt<K> event : this.marginals.getSortedEvents()) {
            if (event.getRightElement() != recomputedMarginals.get(event.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}
