package ai.vespa.models.evaluation;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An array-backed evaluation context for a single {@link ExpressionFunction}.
 * <p>
 * When the context is created, every name the function transitively refers to (plain references/arguments,
 * {@code constant(...)} references and referenced functions) is assigned a fixed index, so values can be
 * bound and read by index as well as by name. Referenced functions are bound to {@link LazyValue} instances
 * created for this context. Slots that have not been bound read as the context's missing value, which is
 * NaN unless changed with {@link #setMissingValue(Tensor)}.
 */
public final class LazyArrayContext extends Context implements ContextIndex {

    private final ExpressionFunction function;
    private final IndexedBindings indexedBindings;

    /**
     * The name-to-index mapping and the bound values of a {@link LazyArrayContext}.
     * The mapping and the argument set are immutable and shared between copies; the value array is per instance.
     */
    // Decompiler note: this class was originally private.
    public static class IndexedBindings {

        /** Sentinel stored in unbound slots; compared by identity in {@link #get(int)}. */
        private static final Value MISSING = new DoubleValue(Double.NaN).freeze();

        private final ImmutableMap<String, Integer> nameToIndex;
        private final ImmutableSet<String> arguments;
        private final Value[] values;

        /** The value returned for slots still holding {@link #MISSING}. */
        private Value missingValue = new DoubleValue(Double.NaN).freeze();

        private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
                                Value[] values,
                                ImmutableSet<String> arguments,
                                Value missingValue) {
            this.nameToIndex = nameToIndex;
            this.values = values;
            this.arguments = arguments;
            this.missingValue = missingValue;
        }

        /**
         * Creates bindings for the given function.
         *
         * @param function the function whose body is scanned for bind targets
         * @param referencedFunctions the functions that may be referenced (transitively) from {@code function}
         * @param constants constants to pre-bind where the function refers to {@code constant(name)}
         * @param owner the context the lazy function values are created for
         * @param model the model passed on to the lazy function values
         * @throws IllegalArgumentException if a referenced function is not present in {@code referencedFunctions}
         */
        IndexedBindings(ExpressionFunction function,
                        Map<FunctionReference, ExpressionFunction> referencedFunctions,
                        List<Constant> constants,
                        LazyArrayContext owner,
                        Model model) {
            // Insertion-ordered so that indexes follow the order targets are first met in the expression
            Set<String> bindTargets = new LinkedHashSet<>();
            Set<String> argumentNames = new LinkedHashSet<>();
            extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, argumentNames);

            this.arguments = ImmutableSet.copyOf(argumentNames);
            this.values = new Value[bindTargets.size()];
            Arrays.fill(this.values, MISSING);

            ImmutableMap.Builder<String, Integer> indexes = ImmutableMap.builder();
            int index = 0;
            for (String target : bindTargets) {
                indexes.put(target, index++);
            }
            this.nameToIndex = indexes.build();

            for (Constant constant : constants) {
                Integer constantIndex = nameToIndex.get("constant(" + constant.name() + ")");
                if (constantIndex != null) {
                    values[constantIndex] = new TensorValue(constant.value());
                }
            }
            for (FunctionReference reference : referencedFunctions.keySet()) {
                Integer functionIndex = nameToIndex.get(reference.serialForm());
                if (functionIndex != null) {
                    values[functionIndex] = new LazyValue(reference, owner, model);
                }
            }
        }

        private void setMissingValue(Tensor tensor) {
            this.missingValue = new TensorValue(tensor).freeze();
        }

        /**
         * Recursively collects everything in the expression rooted at {@code node} which can be bound.
         *
         * @param node the expression node to scan
         * @param functions the functions function references are resolved against
         * @param bindTargets receives function serial forms, constant references and other references
         * @param arguments receives the references which are neither functions nor constants
         * @throws IllegalArgumentException if a referenced function is not in {@code functions}
         */
        private void extractBindTargets(ExpressionNode node,
                                        Map<FunctionReference, ExpressionFunction> functions,
                                        Set<String> bindTargets,
                                        Set<String> arguments) {
            if (isFunctionReference(node)) {
                FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
                // Descend into a function body only the first time it is seen: a second visit adds nothing new,
                // while re-visiting makes shared sub-functions cost exponential time and cycles never end
                if ( ! bindTargets.add(reference.serialForm())) {
                    return;
                }
                ExpressionFunction referenced = functions.get(reference);
                if (referenced == null) {
                    throw new IllegalArgumentException("Function '" + reference.serialForm() +
                                                       "' is referenced but not defined");
                }
                extractBindTargets(referenced.getBody().getRoot(), functions, bindTargets, arguments);
            }
            else if (isConstant(node)) {
                bindTargets.add(node.toString());
            }
            else if (node instanceof ReferenceNode) {
                bindTargets.add(node.toString());
                arguments.add(node.toString());
            }
            else if (node instanceof CompositeNode) {
                for (ExpressionNode child : ((CompositeNode) node).children()) {
                    extractBindTargets(child, functions, bindTargets, arguments);
                }
            }
        }

        /** Returns whether the node is a reference of the form {@code rankingExpression(x)}. */
        private boolean isFunctionReference(ExpressionNode node) {
            return isSingleArgumentReference(node, "rankingExpression");
        }

        /** Returns whether the node is a reference of the form {@code constant(x)}. */
        private boolean isConstant(ExpressionNode node) {
            return isSingleArgumentReference(node, "constant");
        }

        /** Returns whether the node is a reference with the given name and exactly one argument. */
        private static boolean isSingleArgumentReference(ExpressionNode node, String name) {
            if ( ! (node instanceof ReferenceNode)) {
                return false;
            }
            ReferenceNode reference = (ReferenceNode) node;
            return reference.getName().equals(name) && reference.getArguments().size() == 1;
        }

        /** Returns the value at the given index, or the current missing value if that slot is unbound. */
        Value get(int index) {
            Value value = values[index];
            return value == MISSING ? missingValue : value;
        }

        void set(int index, Value value) {
            values[index] = value;
        }

        Set<String> names() {
            return nameToIndex.keySet();
        }

        Set<String> arguments() {
            return arguments;
        }

        /** Returns the index of the given name, or null if it cannot be bound. */
        Integer indexOf(String name) {
            return nameToIndex.get(name);
        }

        /**
         * Returns a copy of these bindings whose lazy function values are rebound to the given context.
         * Other values are shared (values stored through {@link LazyArrayContext#put} are frozen), and the
         * current missing value is carried over.
         */
        IndexedBindings copy(Context context) {
            Value[] valueCopy = new Value[values.length];
            for (int i = 0; i < values.length; i++) {
                Value value = values[i];
                valueCopy[i] = value instanceof LazyValue ? ((LazyValue) value).copyFor(context) : value;
            }
            return new IndexedBindings(nameToIndex, valueCopy, arguments, missingValue);
        }
    }

    /** Creates a context sharing the name mapping of the given bindings, with lazy values rebound to this. */
    private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
        this.function = function;
        this.indexedBindings = indexedBindings.copy(this);
    }

    /**
     * Creates a context for evaluating the given function.
     *
     * @param function the function this context binds values for
     * @param referencedFunctions the functions which may be referenced (transitively) by {@code function}
     * @param constants constants to pre-bind
     * @param model the model the lazy function values are created with
     * @throws IllegalArgumentException if a referenced function is not present in {@code referencedFunctions}
     */
    // Decompiler note: this constructor was originally package-private.
    public LazyArrayContext(ExpressionFunction function,
                            Map<FunctionReference, ExpressionFunction> referencedFunctions,
                            List<Constant> constants,
                            Model model) {
        this.function = function;
        this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model);
    }

    /** Sets the value returned for names which have not been bound. */
    public void setMissingValue(Tensor value) {
        indexedBindings.setMissingValue(value);
    }

    /**
     * Binds a value by name.
     *
     * @throws IllegalArgumentException if the name cannot be bound in this context
     */
    public void put(String name, Value value) {
        put(requireIndexOf(name).intValue(), value);
    }

    /** Binds a number at the given index. */
    public final void put(int index, double value) {
        put(index, (Value) DoubleValue.frozen(value));
    }

    /** Binds a value at the given index; the value is frozen before it is stored. */
    public void put(int index, Value value) {
        indexedBindings.set(index, value.freeze());
    }

    /**
     * Returns the type of the value currently bound to the given reference.
     *
     * @throws IllegalArgumentException if the reference cannot be bound in this context
     */
    public TensorType getType(Reference reference) {
        return get(requireIndexOf(reference.toString()).intValue()).type();
    }

    /**
     * Returns the value bound to the given name, or the missing value if it is unbound.
     *
     * @throws IllegalArgumentException if the name cannot be bound in this context
     */
    public Value get(String name) {
        return get(requireIndexOf(name).intValue());
    }

    /** Returns the value at the given index, or the missing value if it is unbound. */
    public Value get(int index) {
        return indexedBindings.get(index);
    }

    public double getDouble(int index) {
        return get(index).asDouble();
    }

    /**
     * Returns the index of the given name.
     *
     * @throws IllegalArgumentException if the name cannot be bound in this context
     */
    public int getIndex(String name) {
        return requireIndexOf(name).intValue();
    }

    /** Returns the number of bindable names. */
    public int size() {
        return indexedBindings.names().size();
    }

    /** Returns all names which can be bound in this context. */
    public Set<String> names() {
        return indexedBindings.names();
    }

    /** Returns the bindable names which are neither function references nor constants. */
    public Set<String> arguments() {
        return indexedBindings.arguments();
    }

    private Integer requireIndexOf(String name) {
        Integer index = indexedBindings.indexOf(name);
        if (index == null) {
            throw new IllegalArgumentException("Value '" + name + "' can not be bound in " + this);
        }
        return index;
    }

    /** Returns true if the given name cannot be bound in this context (not whether it is currently unbound). */
    // Decompiler note: this method was originally package-private.
    public boolean isMissing(String name) {
        return indexedBindings.indexOf(name) == null;
    }

    /** Returns the value currently used for unbound names. */
    public Value defaultValue() {
        return indexedBindings.missingValue;
    }

    /**
     * Returns a copy of this context: bound values and the missing value are carried over,
     * and lazy function values are rebound to the copy.
     */
    // Decompiler note: this method was originally package-private.
    public LazyArrayContext copy() {
        return new LazyArrayContext(function, indexedBindings);
    }
}
