package ai.vespa.models.evaluation;

import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Determines which names a ranking-expression function refers to, and which of them are left
 * as arguments.
 *
 * <p>For a function body this produces a {@link FunctionInfo} holding:
 * <ul>
 *   <li>{@code bindTargets}: every name the expression refers to (function references,
 *       {@code constant(..)} references, ONNX invocations, plain references and the inputs of
 *       any ONNX model invoked),</li>
 *   <li>{@code arguments}: the subset of those names that this extractor does not resolve
 *       itself and therefore leaves as arguments,</li>
 *   <li>{@code onnxModelsInUse}: the ONNX models invoked, keyed by the invoking expression's
 *       string form.</li>
 * </ul>
 * Function references are followed transitively through the map given to the constructor, and
 * the result for each {@link FunctionReference} is memoized in an unsynchronized map.
 */
class BindingExtractor {

    private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
    private final List<OnnxModel> onnxModels;

    /** Memoized results per function reference. Null values record functions that are unknown. */
    private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>();

    // NOTE(review): the decompiler reports this class was package-private in the original source;
    // it is kept public here so no existing caller breaks.
    /** Binding information accumulated for one expression (sub)tree. */
    public static class FunctionInfo {
        /** Every name referenced by the expression. */
        final Set<String> bindTargets = new LinkedHashSet<>();
        /** The subset of {@link #bindTargets} left as arguments. */
        final Set<String> arguments = new LinkedHashSet<>();
        /** ONNX models invoked, keyed by the invoking expression's string form. */
        final Map<String, OnnxModel> onnxModelsInUse = new LinkedHashMap<>();

        FunctionInfo() {
        }

        /** Adds all targets, arguments and ONNX models of {@code functionInfo} to this one. */
        void merge(FunctionInfo functionInfo) {
            bindTargets.addAll(functionInfo.bindTargets);
            arguments.addAll(functionInfo.arguments);
            onnxModelsInUse.putAll(functionInfo.onnxModelsInUse);
        }

        /** Removes {@code str} both as a bind target and as an argument. */
        void removeTarget(String str) {
            bindTargets.remove(str);
            arguments.remove(str);
        }
    }

    /**
     * @param map  the functions that may be referenced, keyed by reference
     * @param list the ONNX models that may be invoked by name
     */
    public BindingExtractor(Map<FunctionReference, ExpressionFunction> map, List<OnnxModel> list) {
        this.referencedFunctions = map;
        this.onnxModels = list;
    }

    /**
     * Returns the binding information for the function identified by {@code functionReference}.
     * The information is computed on the first request and memoized.
     *
     * @return the binding information, or null if the reference maps to no function
     *         (this null result is memoized too)
     */
    FunctionInfo extractFrom(FunctionReference functionReference) {
        // containsKey rather than a null check on get(): a memoized null is a valid result.
        // computeIfAbsent is deliberately avoided here, for two reasons. First, extraction recurses
        // back into this method, which would modify the map from inside its own mapping function.
        // Second, computeIfAbsent does not store null results.
        // NOTE(review): results are only stored after extraction finishes, so a cycle of function
        // references would recurse without bound. Confirm that cycles are rejected upstream.
        if (functionsInfo.containsKey(functionReference)) {
            return functionsInfo.get(functionReference);
        }
        FunctionInfo info = extractFrom(referencedFunctions.get(functionReference));
        functionsInfo.put(functionReference, info);
        return info;
    }

    // NOTE(review): the decompiler reports this method was package-private in the original
    // source; it is kept public here so no existing caller breaks.
    /**
     * Returns the binding information for the body of {@code expressionFunction}. The result
     * is not memoized.
     *
     * @return the binding information, or null if {@code expressionFunction} is null
     */
    public FunctionInfo extractFrom(ExpressionFunction expressionFunction) {
        if (expressionFunction == null) {
            return null;
        }
        return extractBindTargets(expressionFunction.getBody().getRoot());
    }

    /**
     * Recursively collects binding information for {@code expressionNode}. Node kinds are
     * checked in this order: function reference, tensor function, ONNX invocation, constant,
     * other reference, composite. Any other node (e.g. a literal) binds nothing.
     */
    private FunctionInfo extractBindTargets(ExpressionNode expressionNode) {
        if (isFunctionReference(expressionNode)) {
            return extractFunctionReferenceTargets(expressionNode);
        }
        if (expressionNode instanceof TensorFunctionNode) {
            return extractTensorFunctionTargets((TensorFunctionNode) expressionNode);
        }
        if (isOnnx(expressionNode)) {
            return extractOnnxTargets(expressionNode);
        }
        FunctionInfo functionInfo = new FunctionInfo();
        if (isConstant(expressionNode)) {
            // A constant is a bind target but is never left as an argument.
            functionInfo.bindTargets.add(expressionNode.toString());
        } else if (expressionNode instanceof ReferenceNode) {
            functionInfo.bindTargets.add(expressionNode.toString());
            functionInfo.arguments.add(expressionNode.toString());
        } else if (expressionNode instanceof CompositeNode) {
            for (ExpressionNode child : ((CompositeNode) expressionNode).children()) {
                functionInfo.merge(extractBindTargets(child));
            }
        }
        return functionInfo;
    }

    /**
     * Handles a reference of the form {@code rankingExpression(name)}.
     *
     * <p>The reference itself becomes a bind target. If the referenced function is known, its
     * binding information is merged in; otherwise the reference is also left as an argument.
     *
     * @throws IllegalArgumentException if the node's string form is not a valid serialized
     *                                  function reference
     */
    private FunctionInfo extractFunctionReferenceTargets(ExpressionNode expressionNode) {
        Optional<FunctionReference> fromSerial = FunctionReference.fromSerial(expressionNode.toString());
        if (fromSerial.isEmpty()) {
            throw new IllegalArgumentException("Could not extract function " + expressionNode + " from serialized form '" + expressionNode.toString() + "'");
        }
        FunctionReference functionReference = fromSerial.get();
        FunctionInfo functionInfo = new FunctionInfo();
        functionInfo.bindTargets.add(functionReference.serialForm());
        FunctionInfo referenced = extractFrom(functionReference);
        if (referenced == null) {
            functionInfo.arguments.add(functionReference.serialForm());
        } else {
            functionInfo.merge(referenced);
        }
        return functionInfo;
    }

    /**
     * Collects binding information from a tensor function's children and from the
     * expressions embedded in it.
     */
    private FunctionInfo extractTensorFunctionTargets(TensorFunctionNode tensorFunctionNode) {
        FunctionInfo functionInfo = new FunctionInfo();
        for (ExpressionNode child : tensorFunctionNode.children()) {
            functionInfo.merge(extractBindTargets(child));
        }
        // The transformer returns each expression unchanged, so this call is used purely as a
        // visitor over embedded expressions. The transformed node it returns is discarded.
        tensorFunctionNode.withTransformedExpressions(expression -> {
            functionInfo.merge(extractBindTargets(expression));
            return expression;
        });
        if (tensorFunctionNode.function() instanceof Generate) {
            // The generate function's dimension names are dropped from targets and arguments.
            // Presumably this is because generate binds them itself; confirm if this changes.
            for (TensorType.Dimension dimension : tensorFunctionNode.function().type(null).dimensions()) {
                functionInfo.removeTarget(dimension.name());
            }
        }
        return functionInfo;
    }

    /**
     * Handles an {@code onnx(..)} / {@code onnxModel(..)} invocation.
     *
     * <p>The invocation's string form is always a bind target. If its first argument names one
     * of the known ONNX models, that model is loaded, its input names become targets and
     * arguments, and the model is recorded in {@code onnxModelsInUse}. Otherwise the invocation
     * itself is left as an argument.
     */
    private FunctionInfo extractOnnxTargets(ExpressionNode expressionNode) {
        FunctionInfo functionInfo = new FunctionInfo();
        String invocation = expressionNode.toString();
        functionInfo.bindTargets.add(invocation);
        Optional<String> modelName = getArgument(expressionNode);
        if (modelName.isPresent()) {
            for (OnnxModel onnxModel : onnxModels) {
                if (onnxModel.name().equals(modelName.get())) {
                    onnxModel.load();
                    for (String input : onnxModel.inputs().keySet()) {
                        functionInfo.bindTargets.add(input);
                        functionInfo.arguments.add(input);
                    }
                    functionInfo.onnxModelsInUse.put(invocation, onnxModel);
                    // Only the first model with a matching name is used.
                    return functionInfo;
                }
            }
        }
        functionInfo.arguments.add(invocation);
        return functionInfo;
    }

    /**
     * Returns the first argument of a reference node as a name. A quoted constant is returned
     * with its quotes stripped, and a reference is returned as its name.
     *
     * @return the name, or empty if the node is not a reference, has no arguments, or its first
     *         argument is neither a constant nor a reference
     */
    private static Optional<String> getArgument(ExpressionNode expressionNode) {
        if (!(expressionNode instanceof ReferenceNode)) {
            return Optional.empty();
        }
        ReferenceNode referenceNode = (ReferenceNode) expressionNode;
        if (referenceNode.getArguments().size() == 0) {
            return Optional.empty();
        }
        // Typed as ExpressionNode: the argument may be a ConstantNode or a ReferenceNode, and the
        // decompiled version declared it as ReferenceNode, which does not compile.
        ExpressionNode firstArgument = referenceNode.getArguments().expressions().get(0);
        if (firstArgument instanceof ConstantNode) {
            return Optional.of(stripQuotes(firstArgument.toString()));
        }
        if (firstArgument instanceof ReferenceNode) {
            return Optional.of(((ReferenceNode) firstArgument).getName());
        }
        return Optional.empty();
    }

    /**
     * Removes one pair of matching surrounding quotes ({@code "..."} or {@code '...'}).
     *
     * <p>Strings shorter than three characters are returned unchanged. As a result, an empty
     * quoted literal such as {@code ""} is not stripped.
     *
     * @param str the string to unquote; must not be null
     * @return {@code str} without its surrounding quotes, or {@code str} itself if it is not quoted
     */
    public static String stripQuotes(String str) {
        if (str.length() < 3) {
            return str;
        }
        int lastIndex = str.length() - 1;
        char first = str.charAt(0);
        char last = str.charAt(lastIndex);
        boolean doubleQuoted = first == '\"' && last == '\"';
        boolean singleQuoted = first == '\'' && last == '\'';
        return (doubleQuoted || singleQuoted) ? str.substring(1, lastIndex) : str;
    }

    /** True for {@code rankingExpression(x)} with exactly one argument. */
    private static boolean isFunctionReference(ExpressionNode expressionNode) {
        if (!(expressionNode instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode referenceNode = (ReferenceNode) expressionNode;
        return referenceNode.getName().equals("rankingExpression") && referenceNode.getArguments().size() == 1;
    }

    /** True for references named {@code onnx} or {@code onnxModel}, with any number of arguments. */
    private static boolean isOnnx(ExpressionNode expressionNode) {
        if (!(expressionNode instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode referenceNode = (ReferenceNode) expressionNode;
        return referenceNode.getName().equals("onnx") || referenceNode.getName().equals("onnxModel");
    }

    /** True for {@code constant(x)} with exactly one argument. */
    private static boolean isConstant(ExpressionNode expressionNode) {
        if (!(expressionNode instanceof ReferenceNode)) {
            return false;
        }
        ReferenceNode referenceNode = (ReferenceNode) expressionNode;
        return referenceNode.getName().equals("constant") && referenceNode.getArguments().size() == 1;
    }
}
