package net.sf.tweety.math.opt.solver;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import net.sf.tweety.commons.util.VectorTools;
import net.sf.tweety.math.GeneralMathException;
import net.sf.tweety.math.opt.ConstraintSatisfactionProblem;
import net.sf.tweety.math.opt.OptimizationProblem;
import net.sf.tweety.math.opt.Solver;
import net.sf.tweety.math.term.FloatConstant;
import net.sf.tweety.math.term.IntegerConstant;
import net.sf.tweety.math.term.Term;
import net.sf.tweety.math.term.Variable;
import org.riso.numerical.LBFGS;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:net.sf.tweety.math-1.15.jar:net/sf/tweety/math/opt/solver/LbfgsSolver.class */
/**
 * Solves unconstrained optimization problems with the limited-memory BFGS (L-BFGS)
 * method, delegating the numerical iteration to {@link LBFGS}.
 *
 * <p>The gradient of the target function is derived symbolically once per call to
 * {@link #solve(ConstraintSatisfactionProblem)}. The target function and the gradient
 * are then evaluated at every point L-BFGS asks for, until L-BFGS reports termination.
 */
public class LbfgsSolver extends Solver {

    /** Number of correction pairs kept by L-BFGS (its {@code m} parameter). */
    private static final int NUM_CORRECTIONS = 1000;

    /** Accuracy parameter passed to L-BFGS as {@code eps}. */
    private static final double EPSILON = 1.0E-5d;

    /** Machine-precision estimate passed to L-BFGS as {@code xtol}. */
    private static final double XTOL = 1.0E-15d;

    private static final Logger log = LoggerFactory.getLogger(LbfgsSolver.class);

    /** Point from which every search starts; never modified by this class. */
    private final Map<Variable, Term> startingPoint;

    /**
     * Creates a new L-BFGS solver.
     *
     * @param map the starting point of the search. It must assign a value to every
     *     variable occurring in the target function of the problems to be solved.
     *     The map is copied on each call to {@link #solve} and is not modified.
     */
    public LbfgsSolver(Map<Variable, Term> map) {
        this.startingPoint = map;
    }

    /**
     * Minimizes the target function of an unconstrained optimization problem with
     * L-BFGS. For maximization problems, it maximizes the target function instead.
     *
     * @param problem an {@link OptimizationProblem} without constraints
     * @return a new map holding the starting point's entries, with every variable of
     *     the target function set to its value at the point where L-BFGS terminated
     * @throws IllegalArgumentException in any of these cases:
     *     <ul>
     *       <li>the problem has constraints;</li>
     *       <li>the problem is not an {@link OptimizationProblem};</li>
     *       <li>the starting point lacks a value for a variable of the target function.</li>
     *     </ul>
     * @throws GeneralMathException if the L-BFGS call, or an evaluation performed
     *     during the search, fails. The underlying exception is attached as the cause.
     */
    @Override // net.sf.tweety.math.opt.Solver
    public Map<Variable, Term> solve(ConstraintSatisfactionProblem problem) throws GeneralMathException {
        if (problem.size() > 0) {
            throw new IllegalArgumentException("The L-BFGS method works only for optimization problems without constraints.");
        }
        log.trace("Solving the following optimization problem using L-BFGS:\n===BEGIN===\n{}\n===END===", problem);
        if (!(problem instanceof OptimizationProblem)) {
            throw new IllegalArgumentException("The L-BFGS method requires an optimization problem with a target function.");
        }
        OptimizationProblem optProblem = (OptimizationProblem) problem;

        Term target = optProblem.getTargetFunction();
        // L-BFGS only minimizes, so maximizing f is done by minimizing -f.
        if (optProblem.getType() == OptimizationProblem.MAXIMIZE) {
            target = new IntegerConstant(-1).mult(target);
        }

        List<Variable> variables = new ArrayList<>(target.getVariables());
        // Work on a copy: the configured starting point must neither be mutated nor
        // carried over from one call of solve() to the next.
        Map<Variable, Term> solution = new HashMap<>(this.startingPoint);
        int n = variables.size();
        if (n == 0) {
            // A constant target function leaves nothing to optimize.
            // The RISO lbfgs routine rejects n <= 0, so return the starting point directly.
            log.trace("Optimum found: {}", solution);
            return solution;
        }

        // Gradient component i is the partial derivative with respect to variables.get(i).
        List<Term> gradient = new ArrayList<>(n);
        for (Variable v : variables) {
            gradient.add(target.derive(v).simplify());
        }

        double[] x = new double[n];
        for (int i = 0; i < n; i++) {
            Term start = solution.get(variables.get(i));
            if (start == null) {
                throw new IllegalArgumentException("No starting value given for variable " + variables.get(i) + ".");
            }
            x[i] = start.doubleValue();
        }
        double f = target.replaceAllTerms(solution).doubleValue();
        double[] g = new double[n];
        evaluateGradient(gradient, solution, g);

        // Required by the API but not read, because diagco is passed as false.
        double[] diag = new double[n];
        // iprint[0] < 0 suppresses L-BFGS's own output.
        int[] iprint = {-1, 3};
        // Reverse-communication flag: 0 starts the search.
        // L-BFGS sets 1 to request f and g at x, and 0 once it has terminated.
        int[] iflag = {0};
        log.trace("Starting optimization.");
        while (iflag[0] >= 0) {
            try {
                new LBFGS().lbfgs(n, NUM_CORRECTIONS, x, f, g, false, diag, iprint, EPSILON, XTOL, iflag);
                if (log.isTraceEnabled()) {
                    log.trace("Current manhattan distance of gradient to zero: {}", VectorTools.manhattanDistanceToZero(g));
                }
                if (iflag[0] == 0) {
                    break;
                }
                if (iflag[0] == 1) {
                    for (int i = 0; i < n; i++) {
                        solution.put(variables.get(i), new FloatConstant(x[i]));
                    }
                    f = target.replaceAllTerms(solution).doubleValue();
                    evaluateGradient(gradient, solution, g);
                }
            } catch (Exception e) {
                GeneralMathException failure = new GeneralMathException("Call to L-BFGS failed.");
                // Keep the underlying reason (e.g. the L-BFGS iflag error) for diagnosis.
                failure.initCause(e);
                throw failure;
            }
        }
        log.trace("Optimum found: {}", solution);
        return solution;
    }

    /**
     * Evaluates each gradient component at {@code point} and writes component i into
     * {@code out[i]}.
     */
    private static void evaluateGradient(List<Term> gradient, Map<Variable, Term> point, double[] out) {
        for (int i = 0; i < out.length; i++) {
            out[i] = gradient.get(i).replaceAllTerms(point).doubleValue();
        }
    }

    /**
     * Reports whether this solver is available.
     *
     * @return always {@code true}; this implementation calls {@link LBFGS} in-process
     */
    public static boolean isInstalled() throws UnsupportedOperationException {
        return true;
    }
}
