package uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient;

import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.BrentSolver;
import org.apache.commons.math3.analysis.solvers.UnivariateSolver;
import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.MathInternalError;
import org.apache.commons.math3.exception.MathUnsupportedOperationException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.GradientMultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.Preconditioner;
import org.apache.commons.math3.optim.univariate.BracketFinder;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.SimpleUnivariateValueChecker;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;

/* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer.class */
public class BoundedNonLinearConjugateGradientOptimizer extends GradientMultivariateOptimizer {
    private final Formula updateFormula;
    private final Preconditioner preconditioner;
    private final UnivariateSolver solver;
    double initialStep;
    private boolean isLower;
    private boolean isUpper;
    private double[] lower;
    private double[] upper;
    private boolean useGradientLineSearch;
    private boolean noBracket;
    private double sign;

    /* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer$BracketingStep.class */
    /**
     * Optimisation option that carries the initial step size used when bracketing
     * during the line search. The optimiser picks it up while parsing its
     * optimisation data.
     */
    public static class BracketingStep implements OptimizationData {
        /** Step size supplied at construction. */
        private final double initialStep;

        /**
         * Creates the option.
         *
         * @param d the initial bracketing step
         */
        public BracketingStep(double d) {
            initialStep = d;
        }

        /**
         * Gets the step size supplied at construction.
         *
         * @return the initial bracketing step
         */
        public double getBracketingStep() {
            return initialStep;
        }
    }

    /* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer$Formula.class */
    /**
     * Selects how the coefficient that mixes the previous search direction into the
     * new one is computed in the main optimisation loop.
     */
    public enum Formula {
        /** Coefficient is the new gradient/preconditioned-gradient dot product divided by the previous one. */
        FLETCHER_REEVES,
        /**
         * As {@link #FLETCHER_REEVES}, but the numerator first subtracts the dot product of
         * the new gradient with the previous preconditioned gradient.
         */
        POLAK_RIBIERE
    }

    /* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer$IdentityPreconditioner.class */
    public static class IdentityPreconditioner implements Preconditioner {
        public double[] precondition(double[] dArr, double[] dArr2) {
            return (double[]) dArr2.clone();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer$LineSearch.class */
    /**
     * Line search that uses objective values only: a bracket is located from step 0
     * using the enclosing optimiser's initial step, then the Brent optimiser is run
     * inside that bracket. Candidate points are clamped to the enclosing optimiser's
     * bounds before they are evaluated.
     */
    public class LineSearch extends BrentOptimizer {
        /** Relative tolerance handed to the Brent optimiser constructor. */
        private static final double REL_TOL_UNUSED = 1.0E-15d;
        /** Absolute tolerance handed to the Brent optimiser constructor. */
        private static final double ABS_TOL_UNUSED = Double.MIN_VALUE;
        /** Locates the initial bracket for each search. */
        private final BracketFinder bracket = new BracketFinder();

        /**
         * Creates the line search with a value-based convergence checker.
         *
         * @param d first argument of the {@link SimpleUnivariateValueChecker}
         * @param d2 second argument of the {@link SimpleUnivariateValueChecker}
         */
        LineSearch(double d, double d2) {
            super(REL_TOL_UNUSED, ABS_TOL_UNUSED, new SimpleUnivariateValueChecker(d, d2));
        }

        /**
         * Searches along {@code dArr + alpha * dArr2} for the best {@code alpha}.
         *
         * @param dArr the start point of the line
         * @param dArr2 the direction of the line
         * @return the result of the univariate optimisation over {@code alpha}
         */
        public UnivariatePointValuePair search(double[] dArr, double[] dArr2) {
            final BoundedNonLinearConjugateGradientOptimizer outer =
                    BoundedNonLinearConjugateGradientOptimizer.this;
            final int dimension = dArr.length;

            // Objective value at the (bounded) point reached after moving alpha along the line.
            final UnivariateFunction objectiveAlongLine = alpha -> {
                final double[] candidate = new double[dimension];
                for (int k = 0; k < dimension; k++) {
                    candidate[k] = dArr[k] + alpha * dArr2[k];
                }
                outer.applyBounds(candidate);
                return outer.computeObjectiveValue(candidate);
            };

            final GoalType goal = outer.getGoalType();
            bracket.search(objectiveAlongLine, goal, 0.0d, outer.initialStep);

            final SearchInterval interval =
                    new SearchInterval(bracket.getLo(), bracket.getHi(), bracket.getMid());
            return optimize(
                    new MaxEval(Integer.MAX_VALUE),
                    new UnivariateObjectiveFunction(objectiveAlongLine),
                    goal,
                    interval);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:uk/ac/sussex/gdsc/smlm/math3/optim/nonlinear/scalar/gradient/BoundedNonLinearConjugateGradientOptimizer$LineSearchFunction.class */
    public class LineSearchFunction implements UnivariateFunction {
        private final double[] currentPoint;
        private final double[] searchDirection;

        public LineSearchFunction(double[] dArr, double[] dArr2) {
            this.currentPoint = (double[]) dArr.clone();
            this.searchDirection = (double[]) dArr2.clone();
        }

        public double value(double d) {
            double[] dArr = (double[]) this.currentPoint.clone();
            for (int i = 0; i < dArr.length; i++) {
                int i2 = i;
                dArr[i2] = dArr[i2] + (d * this.searchDirection[i]);
            }
            double[] dArr2 = (double[]) dArr.clone();
            BoundedNonLinearConjugateGradientOptimizer.this.applyBounds(dArr);
            double[] computeObjectiveGradient = BoundedNonLinearConjugateGradientOptimizer.this.computeObjectiveGradient(dArr);
            if (BoundedNonLinearConjugateGradientOptimizer.this.checkGradients(computeObjectiveGradient, dArr2)) {
                return Double.NaN;
            }
            double d2 = 0.0d;
            for (int i3 = 0; i3 < computeObjectiveGradient.length; i3++) {
                d2 += computeObjectiveGradient[i3] * this.searchDirection[i3];
            }
            return d2;
        }
    }

    /**
     * Creates an optimiser using a {@link BrentSolver} for the gradient line search
     * and an {@link IdentityPreconditioner}.
     *
     * @param formula the update formula
     * @param convergenceChecker the convergence checker passed to the superclass
     */
    public BoundedNonLinearConjugateGradientOptimizer(Formula formula, ConvergenceChecker<PointValuePair> convergenceChecker) {
        this(formula, convergenceChecker, new BrentSolver());
    }

    /**
     * Creates an optimiser using the given solver and an {@link IdentityPreconditioner}.
     *
     * @param formula the update formula
     * @param convergenceChecker the convergence checker passed to the superclass
     * @param univariateSolver the solver used by the gradient line search
     */
    public BoundedNonLinearConjugateGradientOptimizer(Formula formula, ConvergenceChecker<PointValuePair> convergenceChecker, UnivariateSolver univariateSolver) {
        this(formula, convergenceChecker, univariateSolver, new IdentityPreconditioner());
    }

    /**
     * Creates an optimiser. The initial bracketing step starts at 1 and the
     * gradient-based line search is enabled.
     *
     * @param formula the update formula
     * @param convergenceChecker the convergence checker passed to the superclass
     * @param univariateSolver the solver used by the gradient line search
     * @param preconditioner the preconditioner applied to each gradient
     */
    public BoundedNonLinearConjugateGradientOptimizer(Formula formula, ConvergenceChecker<PointValuePair> convergenceChecker, UnivariateSolver univariateSolver, Preconditioner preconditioner) {
        super(convergenceChecker);
        updateFormula = formula;
        solver = univariateSolver;
        this.preconditioner = preconditioner;
        initialStep = 1.0d;
        useGradientLineSearch = true;
    }

    /* renamed from: optimize, reason: merged with bridge method [inline-methods] */
    /**
     * Delegates unchanged to the superclass {@code optimize} with the supplied
     * optimisation data and returns its result.
     *
     * <p>NOTE(review): the {@code m1830} prefix appears to be a decompiler rename of an
     * {@code optimize} override; confirm against the original class before relying on
     * this name.
     *
     * @param optimizationDataArr the optimisation data forwarded to the superclass
     * @return the value returned by the superclass
     */
    public PointValuePair m1830optimize(OptimizationData... optimizationDataArr) {
        return super.optimize(optimizationDataArr);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: doOptimize, reason: merged with bridge method [inline-methods] */
    /**
     * Runs the bounded non-linear conjugate gradient optimisation.
     *
     * <p>Each iteration evaluates the objective at the current (bounded) point, tests
     * for convergence, performs a line search along the current search direction,
     * moves the point, clamps it to the bounds, recomputes the gradient and updates
     * the search direction using the configured {@link Formula}. Gradient and
     * direction components that point out of an active bound are zeroed by
     * {@code checkGradients}.
     *
     * <p>NOTE(review): the {@code m1831} prefix appears to be a decompiler rename of a
     * {@code doOptimize} override; confirm against the original class.
     *
     * @return the point/value pair at convergence, or the current point when the
     *         gradient/preconditioned-gradient dot product is exactly zero
     * @throws MathInternalError if the update formula is not a known constant
     */
    public PointValuePair m1831doOptimize() {
        final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
        final double[] point = getStartPoint();
        final GoalType goal = getGoalType();
        final int n = point.length;
        final boolean minimise = goal == GoalType.MINIMIZE;
        this.sign = minimise ? -1.0d : 1.0d;

        // Remember the unclamped position so the gradient check can see which bounds are active.
        double[] unbounded = point.clone();
        applyBounds(point);
        double[] r = computeObjectiveGradient(point);
        checkGradients(r, unbounded);
        if (minimise) {
            for (int i = 0; i < n; i++) {
                r[i] = -r[i];
            }
        }

        // Initial search direction.
        double[] steepestDescent = this.preconditioner.precondition(point, r);
        double[] searchDirection = steepestDescent.clone();
        double delta = 0.0d;
        for (int i = 0; i < n; i++) {
            delta += r[i] * searchDirection[i];
        }

        // Tolerances for the value-only line search. When the checker exposes its
        // thresholds, use the relative one for the relative tolerance and the
        // absolute one for the absolute tolerance.
        double relativeTolerance = 1.0E-6d;
        double absoluteTolerance = 1.0E-10d;
        if (checker instanceof SimpleValueChecker) {
            final SimpleValueChecker valueChecker = (SimpleValueChecker) checker;
            relativeTolerance = valueChecker.getRelativeThreshold();
            absoluteTolerance = valueChecker.getAbsoluteThreshold();
        }
        final LineSearch lineSearch = new LineSearch(Math.sqrt(relativeTolerance), Math.sqrt(absoluteTolerance));

        PointValuePair current = null;
        int maxEval = getMaxEvaluations();
        while (true) {
            incrementIterationCount();

            final double objective = computeObjectiveValue(point);
            final PointValuePair previous = current;
            current = new PointValuePair(point, objective);
            if (previous != null && checker.converged(getIterations(), previous, current)) {
                // We have found an optimum.
                return current;
            }

            // Choose the step along the search direction.
            double step;
            if (this.useGradientLineSearch) {
                final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
                try {
                    final double upperBound = findUpperBound(lsf, 0.0d, this.initialStep);
                    // NOTE(review): noBracket is only written by findUpperBoundWithChecks,
                    // which is not the method called above, so with this call the flag
                    // keeps its previous value. Confirm whether
                    // findUpperBoundWithChecks was intended here.
                    if (this.noBracket) {
                        step = upperBound;
                    } else {
                        step = this.solver.solve(maxEval, lsf, 0.0d, upperBound, 1.0E-15d);
                        // Subtract the evaluations the solver used from the remaining budget.
                        maxEval -= this.solver.getEvaluations();
                    }
                } catch (MathIllegalStateException ex) {
                    // Deliberate fallback: the gradient line search could not bracket or
                    // solve, so use the value-only line search instead.
                    step = lineSearch.search(point, searchDirection).getPoint();
                }
            } else {
                step = lineSearch.search(point, searchDirection).getPoint();
            }

            // Move to the new point, then clamp it to the bounds.
            for (int i = 0; i < point.length; i++) {
                point[i] += step * searchDirection[i];
            }
            unbounded = point.clone();
            applyBounds(point);

            r = computeObjectiveGradient(point);
            checkGradients(r, unbounded);
            if (minimise) {
                for (int i = 0; i < n; i++) {
                    r[i] = -r[i];
                }
            }

            // Compute the coefficient for the direction update.
            final double deltaOld = delta;
            final double[] newSteepestDescent = this.preconditioner.precondition(point, r);
            delta = 0.0d;
            for (int i = 0; i < n; i++) {
                delta += r[i] * newSteepestDescent[i];
            }
            if (delta == 0.0d) {
                // Nothing left to follow: return the current point.
                return new PointValuePair(point, computeObjectiveValue(point));
            }

            final double beta;
            switch (this.updateFormula) {
                case FLETCHER_REEVES:
                    beta = delta / deltaOld;
                    break;
                case POLAK_RIBIERE:
                    double deltaMid = 0.0d;
                    for (int i = 0; i < r.length; i++) {
                        deltaMid += r[i] * steepestDescent[i];
                    }
                    beta = (delta - deltaMid) / deltaOld;
                    break;
                default:
                    // Should never happen.
                    throw new MathInternalError();
            }
            steepestDescent = newSteepestDescent;

            // Restart from the preconditioned gradient every n iterations or when beta
            // is negative; otherwise blend in the previous direction.
            if (getIterations() % n == 0 || beta < 0.0d) {
                searchDirection = steepestDescent.clone();
            } else {
                for (int i = 0; i < n; i++) {
                    searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
                }
            }

            // The direction has already been negated for minimisation, hence the flipped sign.
            checkGradients(searchDirection, unbounded, -this.sign);
        }
    }

    /**
     * Parses the optimisation data. After the superclass has processed the data, the
     * first {@link BracketingStep} found (if any) sets the initial bracketing step,
     * and the bound settings are then refreshed.
     *
     * @param optimizationDataArr the optimisation data
     */
    @Override
    protected void parseOptimizationData(OptimizationData... optimizationDataArr) {
        super.parseOptimizationData(optimizationDataArr);
        for (final OptimizationData item : optimizationDataArr) {
            if (item instanceof BracketingStep) {
                initialStep = ((BracketingStep) item).getBracketingStep();
                // Only the first bracketing step is honoured.
                break;
            }
        }
        checkParameters();
    }

    /**
     * Finds an upper bound {@code b} such that the function values at {@code a} and
     * {@code b} do not have the same strict sign.
     *
     * @param func the function
     * @param a the lower end of the interval
     * @param h the initial step tried from {@code a}
     * @return the first {@code a + step} whose value times the value at {@code a} is
     *         not positive
     * @throws MathIllegalStateException if a NaN value is met, or the step reaches
     *         {@code Double.MAX_VALUE} without bracketing
     */
    private static double findUpperBound(UnivariateFunction func, double a, double h) {
        final double yA = func.value(a);
        double step = h;
        // Written as a negated >= so that a NaN step keeps iterating.
        while (!(step >= Double.MAX_VALUE)) {
            final double b = a + step;
            final double yB = func.value(b);
            if (yA * yB <= 0.0d) {
                return b;
            }
            if (Double.isNaN(yB)) {
                throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH, new Object[0]);
            }
            // Grow the step by at least a factor of two.
            step *= Math.max(2.0d, yA / yB);
        }
        throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH, new Object[0]);
    }

    /**
     * Variant of the upper-bound search that tolerates NaN function values.
     *
     * <p>The step from {@code d} is grown until the function value changes sign
     * relative to the value at {@code d}. If a NaN value is met, {@code noBracket} is
     * set and the method returns the last position that gave a non-NaN value; when
     * there is none yet, the step is repeatedly shrunk by a factor of ten looking for
     * a usable position.
     *
     * <p>NOTE(review): no call to this method appears in the visible file; confirm
     * whether it is meant to be used by the gradient line search.
     *
     * @param univariateFunction the function
     * @param d the lower end of the interval
     * @param d2 the initial step tried from {@code d}
     * @return a position {@code b}; a bracketing upper bound when {@code noBracket}
     *         is false
     * @throws MathIllegalStateException if the value at {@code d} is NaN, the step
     *         reaches {@code Double.MAX_VALUE}, or no non-NaN position is found while
     *         shrinking the step
     */
    private double findUpperBoundWithChecks(UnivariateFunction univariateFunction, double d, double d2) {
        this.noBracket = false;
        double value = univariateFunction.value(d);
        // Nothing can be bracketed against a NaN start value.
        if (Double.isNaN(value)) {
            throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH, new Object[0]);
        }
        // d3: last position that produced a non-NaN value (NaN while there is none).
        double d3 = Double.NaN;
        // d4: step to try on the next pass of the outer loop.
        double d4 = d2;
        while (true) {
            double d5 = d4;
            if (d5 >= Double.MAX_VALUE) {
                throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH, new Object[0]);
            }
            double d6 = d + d5;
            double value2 = univariateFunction.value(d6);
            // A non-positive product means the sign changed (or a value is zero): bracket found.
            if (value * value2 <= 0.0d) {
                return d6;
            }
            if (Double.isNaN(value2)) {
                // Stepped into an invalid region: report that no bracket was established.
                this.noBracket = true;
                if (!Double.isNaN(d3)) {
                    return d3;
                }
                // No valid position yet: shrink the step tenfold each pass.
                while (true) {
                    d5 *= 0.1d;
                    if (d5 <= Double.MIN_VALUE) {
                        // Step exhausted: return the last valid position found, if any.
                        if (Double.isNaN(d3)) {
                            throw new MathIllegalStateException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH, new Object[0]);
                        }
                        return d3;
                    }
                    double d7 = d + d5;
                    double value3 = univariateFunction.value(d7);
                    // Note: noBracket is still true on this return.
                    if (value * value3 <= 0.0d) {
                        return d7;
                    }
                    // Remember the position but keep shrinking.
                    if (!Double.isNaN(value3)) {
                        d3 = d7;
                    }
                }
            } else {
                d3 = d6;
                // Grow the step by at least a factor of two.
                d4 = d5 * Math.max(2.0d, value / value2);
            }
        }
    }

    /**
     * Caches the lower and upper bounds from the superclass, records whether each
     * side has any finite-constraint entry, and validates that no lower bound
     * exceeds its upper bound when both sides are active.
     *
     * @throws MathUnsupportedOperationException if a lower bound is above the
     *         matching upper bound
     */
    private void checkParameters() {
        lower = getLowerBound();
        upper = getUpperBound();
        isLower = checkArray(lower, Double.NEGATIVE_INFINITY);
        isUpper = checkArray(upper, Double.POSITIVE_INFINITY);
        if (!isUpper || !isLower) {
            // Only one side (or neither) is constrained: nothing to cross-check.
            return;
        }
        for (int k = 0; k < lower.length; k++) {
            if (lower[k] > upper[k]) {
                throw new MathUnsupportedOperationException(LocalizedFormats.CONSTRAINT, new Object[0]);
            }
        }
    }

    /**
     * Tests whether the array holds at least one entry that differs from the given
     * value.
     *
     * @param values the array to scan; may be {@code null}
     * @param ignored the value that does not count
     * @return {@code true} if some entry is not equal to {@code ignored};
     *         {@code false} for a {@code null} array or when every entry equals it
     */
    private static boolean checkArray(double[] values, double ignored) {
        if (values != null) {
            for (int k = 0; k < values.length; k++) {
                if (values[k] != ignored) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Clamps the point in place to the active bounds: entries above the upper bound
     * are set to it, then entries below the lower bound are set to it. A side is
     * skipped entirely when it is not active.
     *
     * @param dArr the point to clamp (modified in place)
     */
    void applyBounds(double[] dArr) {
        final int size = dArr.length;
        if (isUpper) {
            for (int k = 0; k < size; k++) {
                final double limit = upper[k];
                if (dArr[k] > limit) {
                    dArr[k] = limit;
                }
            }
        }
        if (isLower) {
            for (int k = 0; k < size; k++) {
                final double limit = lower[k];
                if (dArr[k] < limit) {
                    dArr[k] = limit;
                }
            }
        }
    }

    /**
     * Checks the gradient against the bounds using the optimiser's current sign.
     *
     * @param dArr the gradient (modified in place)
     * @param dArr2 the unclamped position the gradient belongs to
     * @return {@code true} if the gradient contained a NaN
     */
    boolean checkGradients(double[] dArr, double[] dArr2) {
        return checkGradients(dArr, dArr2, sign);
    }

    /**
     * Zeroes gradient components that would push the point further past an active
     * bound, then replaces NaN components with zero.
     *
     * @param gradient the gradient (modified in place)
     * @param position the unclamped position the gradient belongs to
     * @param direction the signum a component must have to count as pointing past the
     *        upper bound; its negation is used for the lower bound
     * @return {@code true} if any component was NaN after the bound checks
     */
    private boolean checkGradients(double[] gradient, double[] position, double direction) {
        final int size = position.length;
        if (isUpper) {
            for (int k = 0; k < size; k++) {
                final boolean atOrAbove = position[k] >= upper[k];
                if (atOrAbove && Math.signum(gradient[k]) == direction) {
                    gradient[k] = 0.0d;
                }
            }
        }
        if (isLower) {
            final double opposite = -direction;
            for (int k = 0; k < size; k++) {
                final boolean atOrBelow = position[k] <= lower[k];
                if (atOrBelow && Math.signum(gradient[k]) == opposite) {
                    gradient[k] = 0.0d;
                }
            }
        }
        boolean foundNaN = false;
        for (int k = 0; k < size; k++) {
            if (!Double.isNaN(gradient[k])) {
                continue;
            }
            foundNaN = true;
            gradient[k] = 0.0d;
        }
        return foundNaN;
    }

    /**
     * Reports whether the gradient-based line search is enabled.
     *
     * @return {@code true} if the optimiser uses the gradient-based line search
     */
    public boolean isUseGradientLineSearch() {
        return useGradientLineSearch;
    }

    /**
     * Enables or disables the gradient-based line search. When disabled the
     * optimiser always uses the value-only line search.
     *
     * @param z {@code true} to use the gradient-based line search
     */
    public void setUseGradientLineSearch(boolean z) {
        useGradientLineSearch = z;
    }
}
