package uk.ac.sussex.gdsc.smlm.fitting.nonlinear;

import java.util.Arrays;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import uk.ac.sussex.gdsc.core.utils.BitFlagUtils;
import uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix;
import uk.ac.sussex.gdsc.smlm.fitting.FitStatus;
import uk.ac.sussex.gdsc.smlm.fitting.FunctionSolverType;
import uk.ac.sussex.gdsc.smlm.function.Gradient1Function;
import uk.ac.sussex.gdsc.smlm.function.Gradient1FunctionStore;
import uk.ac.sussex.gdsc.smlm.function.GradientFunction;
import uk.ac.sussex.gdsc.smlm.function.ValueFunction;
import uk.ac.sussex.gdsc.smlm.function.ValueProcedure;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/fitting/nonlinear/SteppingFunctionSolver.class */
/**
 * Base class for function solvers that iterate by computing a parameter step, applying parameter
 * bounds to that step, evaluating the fit value at the new point and testing for convergence with a
 * {@link ToleranceChecker}. Subclasses supply the value/step/acceptance computations.
 */
public abstract class SteppingFunctionSolver extends BaseFunctionSolver {
    private static final Logger LOGGER = Logger.getLogger(SteppingFunctionSolver.class.getName());

    /** Level used for the per-iteration trace output. */
    private static final Level TRACE_LEVEL = Level.FINEST;

    /**
     * Bits of the {@link ToleranceChecker#converged} status that are treated as a successful fit
     * (value 30 = 0x1e, bits 1-4). NOTE(review): presumably these are the checker's
     * value/parameter/target convergence flags — confirm against ToleranceChecker.
     */
    private static final int CONVERGED_FLAGS = 30;

    /** Bit of the convergence status that is reported as {@link FitStatus#TOO_MANY_ITERATIONS}. */
    private static final int MAX_ITERATIONS_FLAG = 1;

    /** Indices of the parameters with gradients; refreshed from the function on each computation. */
    protected int[] gradientIndices;

    /** Convergence checker used by {@link #computeFit}. */
    protected final ToleranceChecker tc;

    /** Bounds applied to each step taken by {@link #computeFit}. */
    protected final ParameterBounds bounds;

    /** Optional weights; see {@link #getWeights(int)}. */
    private double[] weights;

    /**
     * Value procedure that writes each received value into consecutive positions of an array,
     * starting at index 0.
     */
    public static class SimpleValueProcedure implements ValueProcedure {
        /** Next write position in {@link #fx}. */
        int index;
        /** Destination for the values. */
        double[] fx;

        /**
         * Create an instance.
         *
         * @param fx the destination array for the values
         */
        SimpleValueProcedure(double[] fx) {
            this.fx = fx;
        }

        @Override
        public void execute(double value) {
            fx[index++] = value;
        }
    }

    /**
     * Create a solver with a default {@code ToleranceChecker(1e-3, 1e-6)} and bounds created from the
     * function.
     *
     * @param type the solver type
     * @param function the gradient function
     */
    public SteppingFunctionSolver(FunctionSolverType type, Gradient1Function function) {
        this(type, function, new ToleranceChecker(0.001, 1.0E-6), null);
    }

    /**
     * Create a solver.
     *
     * @param type the solver type
     * @param function the gradient function
     * @param tc the tolerance checker (must not be null)
     * @param bounds the parameter bounds; if null, bounds are created from {@code function}
     * @throws NullPointerException if {@code tc} is null
     * @throws IllegalArgumentException if {@code bounds} was constructed with a different function
     */
    public SteppingFunctionSolver(FunctionSolverType type, Gradient1Function function,
            ToleranceChecker tc, ParameterBounds bounds) {
        super(type, function);
        this.tc = Objects.requireNonNull(tc, "tolerance checker");
        if (bounds == null) {
            bounds = ParameterBounds.create(function);
        } else if (bounds.getGradientFunction() != function) {
            // Identity check: the bounds must operate on exactly this function instance.
            throw new IllegalArgumentException("Bounds must be constructed with the same gradient function");
        }
        this.bounds = bounds;
    }

    /**
     * Iterate until the tolerance checker reports a non-zero status.
     *
     * <p>The parameters {@code a} are updated in place with the final point. The iteration and
     * evaluation counters are always updated on exit, including when an exception propagates.
     *
     * @param y the data
     * @param yFit optional output for the fitted values (filled only on success)
     * @param a the initial parameters; overwritten with the final parameters
     * @param aDev optional output for the parameter deviations (filled only on success)
     * @return the fit status
     */
    @Override
    protected FitStatus computeFit(double[] y, double[] yFit, double[] a, double[] aDev) {
        gradientIndices = function.gradientIndices();
        final double[] step = new double[gradientIndices.length];
        final double[] newA = a.clone();
        bounds.initialise();
        tc.reset();
        final String name = getClass().getSimpleName();
        try {
            lastY = prepareFitValue(y, a);
            double currentValue = computeFitValue(a);
            if (isTraceEnabled()) {
                log("%s Value [%s] = %s : %s", name, tc.getIterations(), currentValue, a);
            }

            double newValue;
            int status;
            while (true) {
                computeStep(step);
                if (isTraceEnabled()) {
                    log("%s Step [%s] = %s", name, tc.getIterations(), step);
                }
                bounds.applyBounds(a, step, newA);
                newValue = computeFitValue(newA);
                if (isTraceEnabled()) {
                    log("%s Value [%s] = %s : %s", name, tc.getIterations(), newValue, newA);
                }
                status = tc.converged(currentValue, a, newValue, newA);
                if (isTraceEnabled()) {
                    log("%s Status [%s] = %s", name, tc.getIterations(), status);
                }
                if (status != 0) {
                    break;
                }
                if (accept(currentValue, a, newValue, newA)) {
                    if (isTraceEnabled()) {
                        log("%s Accepted [%s]", name, tc.getIterations());
                    }
                    currentValue = newValue;
                    System.arraycopy(newA, 0, a, 0, a.length);
                    bounds.accepted(a, newA);
                }
            }

            // The final (last evaluated) point is returned regardless of acceptance.
            value = newValue;
            System.arraycopy(newA, 0, a, 0, a.length);
            if (isTraceEnabled()) {
                log("%s End [%s] = %s", name, tc.getIterations(), status);
            }

            if (BitFlagUtils.anySet(status, CONVERGED_FLAGS)) {
                if (isTraceEnabled()) {
                    log("%s Converged [%s]", name, tc.getIterations());
                }
                if (aDev != null) {
                    computeDeviationsAndValues(aDev, yFit);
                } else if (yFit != null) {
                    computeValues(yFit);
                }
                return FitStatus.OK;
            }
            if (BitFlagUtils.areSet(status, MAX_ITERATIONS_FLAG)) {
                return FitStatus.TOO_MANY_ITERATIONS;
            }
            return FitStatus.FAILED_TO_CONVERGE;
        } catch (final FunctionSolverException ex) {
            final String message = ex.getMessage();
            LOGGER.log(Level.FINE, () -> String.format("%s failed: %s%s", name, ex.fitStatus.getName(),
                    message == null ? "" : " - " + message));
            return ex.fitStatus;
        } finally {
            iterations = tc.getIterations();
            if (evaluations == 0) {
                evaluations = iterations;
            }
        }
    }

    /** @return true if per-iteration trace logging is enabled */
    private static boolean isTraceEnabled() {
        return LOGGER.isLoggable(TRACE_LEVEL);
    }

    /**
     * Log a formatted trace message. Any {@code double[]} arguments are converted to strings
     * immediately (replacing them in {@code args}) so later modification of the arrays does not
     * affect the logged content.
     *
     * @param format the format string
     * @param args the format arguments (modified in place)
     */
    private static void log(String format, Object... args) {
        for (int i = 0; i < args.length; i++) {
            if (args[i] instanceof double[]) {
                args[i] = Arrays.toString((double[]) args[i]);
            }
        }
        LOGGER.log(TRACE_LEVEL, () -> String.format(format, args));
    }

    /**
     * Prepare for fitting the data.
     *
     * @param y the data
     * @param a the initial parameters
     * @return the data to store as {@code lastY}
     */
    protected abstract double[] prepareFitValue(double[] y, double[] a);

    /**
     * Compute the fit value at the given parameters.
     *
     * @param a the parameters
     * @return the fit value
     */
    protected abstract double computeFitValue(double[] a);

    /**
     * Compute the next step.
     *
     * @param step output for the step (one entry per gradient index)
     */
    protected abstract void computeStep(double[] step);

    /**
     * Decide whether to accept the new point.
     *
     * @param currentValue the current fit value
     * @param a the current parameters
     * @param newValue the new fit value
     * @param newA the new parameters
     * @return true to accept the new point
     */
    protected abstract boolean accept(double currentValue, double[] a, double newValue, double[] newA);

    /**
     * Compute the parameter deviations from the last Fisher information matrix, passing
     * {@code yFit} to that computation.
     *
     * @param aDev output for the parameter deviations
     * @param yFit optional output for the fitted values
     */
    protected void computeDeviationsAndValues(double[] aDev, double[] yFit) {
        setDeviations(aDev, computeLastFisherInformationMatrix(yFit));
    }

    /**
     * Compute the Fisher information matrix for the last fit.
     *
     * @param yFit optional output for the fitted values
     * @return the Fisher information matrix
     */
    protected abstract FisherInformationMatrix computeLastFisherInformationMatrix(double[] yFit);

    /**
     * Write the function values, in iteration order, into {@code yFit}.
     *
     * @param yFit the output array
     */
    public void computeValues(double[] yFit) {
        ((ValueFunction) function).forEach(new SimpleValueProcedure(yFit));
    }

    /**
     * Compute the function value at the parameters. If {@code yFit} matches the function size the
     * function is temporarily wrapped in a {@link Gradient1FunctionStore} backed by {@code yFit}
     * (presumably so evaluated values are captured there — confirm). The original function is
     * always restored, even if the computation throws.
     *
     * @param y the data
     * @param yFit optional output for the function values
     * @param a the parameters
     * @return true
     */
    @Override
    protected boolean computeValue(double[] y, double[] yFit, double[] a) {
        gradientIndices = function.gradientIndices();
        if (yFit == null || yFit.length != ((Gradient1Function) function).size()) {
            lastY = prepareFunctionValue(y, a);
            value = computeFunctionValue(a);
            return true;
        }
        final GradientFunction original = function;
        function = new Gradient1FunctionStore((Gradient1Function) original, yFit, (double[][]) null);
        try {
            lastY = prepareFunctionValue(y, a);
            value = computeFunctionValue(a);
        } finally {
            // Never leave the solver bound to the temporary wrapper.
            function = original;
        }
        return true;
    }

    /**
     * Prepare to compute the function value.
     *
     * @param y the data
     * @param a the parameters
     * @return the data to store as {@code lastY}
     */
    protected abstract double[] prepareFunctionValue(double[] y, double[] a);

    /**
     * Compute the function value.
     *
     * @param a the parameters
     * @return the value
     */
    protected abstract double computeFunctionValue(double[] a);

    @Override
    protected FisherInformationMatrix computeFisherInformationMatrix(double[] y, double[] a) {
        gradientIndices = function.gradientIndices();
        return computeFunctionFisherInformationMatrix(prepareFunctionFisherInformationMatrix(y, a), a);
    }

    /**
     * Prepare to compute the Fisher information matrix.
     *
     * @param y the data
     * @param a the parameters
     * @return the prepared data passed to {@link #computeFunctionFisherInformationMatrix}
     */
    protected abstract double[] prepareFunctionFisherInformationMatrix(double[] y, double[] a);

    /**
     * Compute the Fisher information matrix.
     *
     * @param y the prepared data
     * @param a the parameters
     * @return the Fisher information matrix
     */
    protected abstract FisherInformationMatrix computeFunctionFisherInformationMatrix(double[] y, double[] a);

    @Override
    public boolean isBounded() {
        return true;
    }

    @Override
    public void setBounds(double[] lower, double[] upper) {
        bounds.setBounds(lower, upper);
    }

    @Override
    public void setGradientFunction(GradientFunction function) {
        super.setGradientFunction(function);
        bounds.setGradientFunction(function);
    }

    @Override
    public boolean isWeighted() {
        return true;
    }

    /**
     * Set the weights. The array is stored by reference.
     *
     * @param weights the weights (may be null)
     */
    @Override
    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    /**
     * Get the weights if they were set with the expected length.
     *
     * @param size the expected number of weights
     * @return the stored weights array (not a copy), or null if unset or of a different length
     */
    public double[] getWeights(int size) {
        return (weights != null && weights.length == size) ? weights : null;
    }
}
