package uk.ac.sussex.gdsc.smlm.function;

import java.lang.ref.SoftReference;
import java.util.concurrent.atomic.AtomicReference;
import uk.ac.sussex.gdsc.core.utils.MathUtils;
import uk.ac.sussex.gdsc.smlm.utils.StdMath;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/function/PoissonGaussianConvolutionFunction.class */
/**
 * Likelihood function for an observation modelled as a Poisson count scaled by a gain and
 * convolved with zero-mean Gaussian noise.
 *
 * <p>For a Poisson mean {@code theta > 0} the likelihood of an observation {@code x} is the sum,
 * over the non-negative counts {@code n} for which {@code n * gain} lies within 5 Gaussian standard
 * deviations of {@code x}, of the Poisson probability of {@code n} multiplied by the Gaussian term
 * for the residual {@code x - n * gain}. Counts outside that window are ignored. For
 * {@code theta <= 0} only the Gaussian term for {@code x} itself is returned.
 *
 * <p>By default the Gaussian term is a probability density. When {@link #setComputePmf(boolean)}
 * is enabled it is instead the Gaussian probability mass over the unit interval centred on the
 * residual (in the {@code theta <= 0} case {@code x} is first rounded to the nearest integer).
 *
 * <p>The {@code computePmf} flag is a plain mutable field; this class does not synchronise access
 * to it.
 */
public final class PoissonGaussianConvolutionFunction implements LikelihoodFunction, LogLikelihoodFunction {
    /**
     * Log-factorial cache shared by all instances. It is held through a soft reference so the
     * garbage collector may reclaim it under memory pressure; it is recreated on demand.
     */
    private static final AtomicReference<SoftReference<LogFactorialCache>> LOG_FACTORIAL_CACHE =
        new AtomicReference<>(new SoftReference<>(null));

    /** Scale applied to each Poisson count: the reciprocal of {@code |alpha|}. */
    final double gain;
    /** Variance of the Gaussian noise. */
    private final double var;
    /** Standard deviation of the Gaussian noise. */
    private final double sd;
    /** Twice the Gaussian variance; the denominator of the Gaussian exponent. */
    private final double twoVar;
    /** {@code sqrt(2 * var)}; scales the argument passed to the error function. */
    private final double sqrtTwoVar;
    /** Log of the Gaussian normalisation constant, from {@code PoissonGaussianFunction}. */
    private final double logNormalisationGaussian;
    /** If true the Gaussian term is a probability mass over a unit interval, else a density. */
    private boolean computePmf;

    /**
     * Create an instance.
     *
     * @param alpha the inverse gain; the gain used is {@code 1 / |alpha|}. NOTE(review):
     *        {@code alpha == 0} is not rejected and results in an infinite gain.
     * @param s the Gaussian variance if {@code isVariance} is true, otherwise the standard deviation
     * @param isVariance true if {@code s} is a variance, false if it is a standard deviation
     * @throws IllegalArgumentException if {@code s <= 0}
     */
    private PoissonGaussianConvolutionFunction(double alpha, double s, boolean isVariance) {
        if (s <= 0.0) {
            throw new IllegalArgumentException("Gaussian variance must be strictly positive");
        }
        this.gain = 1.0 / Math.abs(alpha);
        if (isVariance) {
            this.sd = Math.sqrt(s);
            this.var = s;
        } else {
            this.sd = s;
            this.var = s * s;
        }
        this.twoVar = 2.0 * this.var;
        this.sqrtTwoVar = Math.sqrt(this.twoVar);
        this.logNormalisationGaussian = PoissonGaussianFunction.getLogNormalisation(this.var);
    }

    /**
     * Create a function whose Gaussian noise is specified by its standard deviation.
     *
     * @param alpha the inverse gain (the gain is {@code 1 / |alpha|})
     * @param standardDeviation the Gaussian standard deviation
     * @return the function
     * @throws IllegalArgumentException if {@code standardDeviation <= 0}
     */
    public static PoissonGaussianConvolutionFunction createWithStandardDeviation(double alpha,
            double standardDeviation) {
        return new PoissonGaussianConvolutionFunction(alpha, standardDeviation, false);
    }

    /**
     * Create a function whose Gaussian noise is specified by its variance.
     *
     * @param alpha the inverse gain (the gain is {@code 1 / |alpha|})
     * @param variance the Gaussian variance
     * @return the function
     * @throws IllegalArgumentException if {@code variance <= 0}
     */
    public static PoissonGaussianConvolutionFunction createWithVariance(double alpha, double variance) {
        return new PoissonGaussianConvolutionFunction(alpha, variance, true);
    }

    /**
     * Compute the likelihood of the observation.
     *
     * @param x the observed value
     * @param theta the Poisson mean; if {@code <= 0} only the Gaussian term for {@code x} is used
     * @return the likelihood (0 when no non-negative count falls inside the 5 SD window)
     */
    @Override // uk.ac.sussex.gdsc.smlm.function.LikelihoodFunction
    public double likelihood(double x, double theta) {
        if (theta <= 0.0) {
            return this.computePmf ? gaussianPmf(x) : StdMath.exp(logGaussianDensity(x));
        }
        return convolve(x, theta);
    }

    /**
     * Compute the log-likelihood of the observation.
     *
     * @param x the observed value
     * @param theta the Poisson mean; if {@code <= 0} only the Gaussian term for {@code x} is used
     * @return the log-likelihood ({@code -Infinity} when the likelihood is 0)
     */
    @Override // uk.ac.sussex.gdsc.smlm.function.LogLikelihoodFunction
    public double logLikelihood(double x, double theta) {
        if (theta <= 0.0) {
            // The density case is returned directly in log space to avoid an exp/log round trip.
            return this.computePmf ? Math.log(gaussianPmf(x)) : logGaussianDensity(x);
        }
        return Math.log(convolve(x, theta));
    }

    /**
     * Sum the Poisson-weighted Gaussian terms over the counts {@code n} whose scaled value
     * {@code n * gain} lies within 5 standard deviations of the observation.
     *
     * @param x the observed value
     * @param theta the Poisson mean (expected to be strictly positive)
     * @return the truncated convolution sum; 0 if the window lies entirely below zero counts
     */
    private double convolve(double x, double theta) {
        // NOTE(review): the int casts saturate for extreme (x +/- 5 sd) / gain; very large windows
        // are not guarded here.
        int maxCount = (int) Math.ceil((x + 5.0 * this.sd) / this.gain);
        if (maxCount < 0) {
            // No non-negative count can contribute.
            return 0.0;
        }
        int minCount = (int) Math.floor((x - 5.0 * this.sd) / this.gain);
        if (minCount < 0) {
            minCount = 0;
            // NOTE(review): widens a window that would only contain n = 0 to also include n = 1;
            // the intent is not documented here.
            if (maxCount == 0) {
                maxCount = 1;
            }
        }

        final LogFactorialCache logFactorials = getLogFactorialCache(minCount, maxCount);
        final double logTheta = Math.log(theta);
        double sum = 0.0;
        if (this.computePmf) {
            for (int n = minCount; n <= maxCount; n++) {
                // Poisson probability of n: theta^n * exp(-theta) / n!
                final double poisson =
                    StdMath.exp(n * logTheta - theta - logFactorials.getLogFactorialUnsafe(n));
                final double residual = getX(x, n);
                // gaussianCdf returns erf(...), i.e. 2 * CDF - 1, so the difference is halved.
                sum += poisson * (gaussianCdf(residual + 0.5) - gaussianCdf(residual - 0.5)) * 0.5;
            }
        } else {
            for (int n = minCount; n <= maxCount; n++) {
                // Poisson and Gaussian log terms combined in a single exponential.
                sum += StdMath.exp(n * logTheta - theta - logFactorials.getLogFactorialUnsafe(n)
                    - MathUtils.pow2(getX(x, n)) / this.twoVar + this.logNormalisationGaussian);
            }
        }
        return sum;
    }

    /**
     * Gaussian probability mass over the unit interval centred on {@code x} rounded to the nearest
     * integer.
     *
     * @param x the value
     * @return the probability mass
     */
    private double gaussianPmf(double x) {
        final double centre = Math.round(x);
        return (gaussianCdf(centre + 0.5) - gaussianCdf(centre - 0.5)) * 0.5;
    }

    /**
     * Log of the zero-mean Gaussian density at {@code x}.
     *
     * @param x the value
     * @return {@code -x^2 / (2 var) + logNormalisationGaussian}
     */
    private double logGaussianDensity(double x) {
        return -0.5 * x * x / this.var + this.logNormalisationGaussian;
    }

    /**
     * Residual between the observation and the scaled Poisson count.
     *
     * @param x the observed value
     * @param count the Poisson count
     * @return {@code x - count * gain}
     */
    private double getX(double x, int count) {
        return x - count * this.gain;
    }

    /**
     * Return {@code erf(x / sqrt(2 * var))}.
     *
     * <p>Despite the name this is not the Gaussian CDF itself. Assuming {@code Erf.erf} is the
     * standard error function, it equals {@code 2 * CDF(x) - 1} for a zero-mean Gaussian with this
     * variance, so differences of two values must be halved to give a probability.
     *
     * @param x the value
     * @return the error function of the scaled value
     */
    double gaussianCdf(double x) {
        return Erf.erf(x / this.sqrtTwoVar);
    }

    /**
     * Checks whether the Gaussian term is computed as a probability mass.
     *
     * @return true if the Gaussian term is computed as a probability mass rather than a density
     */
    public boolean isComputePmf() {
        return this.computePmf;
    }

    /**
     * Set whether the Gaussian term is computed as a probability mass over a unit interval
     * ({@code true}) or as a probability density ({@code false}, the default).
     *
     * @param computePmf the new setting
     */
    public void setComputePmf(boolean computePmf) {
        this.computePmf = computePmf;
    }

    /**
     * Get the shared log-factorial cache, creating it if the soft reference has been cleared, and
     * ensure it covers {@code [min, max]}.
     *
     * <p>Concurrent callers that all observe an empty reference may each create their own cache;
     * the last one stored is kept, and each caller still uses a valid cache. NOTE(review): this
     * assumes {@code LogFactorialCache.ensureRange} is safe to call concurrently on a shared
     * instance. Confirm this.
     *
     * @param min the smallest count required
     * @param max the largest count required
     * @return the cache
     */
    private static LogFactorialCache getLogFactorialCache(int min, int max) {
        LogFactorialCache cache = LOG_FACTORIAL_CACHE.get().get();
        if (cache == null) {
            cache = new LogFactorialCache(max);
            LOG_FACTORIAL_CACHE.set(new SoftReference<>(cache));
        }
        cache.ensureRange(min, max);
        return cache;
    }
}
