package ai.vespa.models.evaluation;

import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.text.Utf8;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import net.jpountz.lz4.LZ4FrameInputStream;

/**
 * Imports rank profiles, together with the large constants, large expressions and ONNX models
 * they use, from config into {@link Model} instances keyed by rank profile name.
 *
 * <p>Files referenced from config (constants, expressions, ONNX models) are obtained through the
 * supplied {@link FileAcquirer}, waiting up to {@value #FILE_WAIT_DAYS} days for each one.
 */
public class RankProfilesConfigImporter {

    /** How long to wait for a referenced file to become available, in {@link TimeUnit#DAYS}. */
    private static final long FILE_WAIT_DAYS = 7L;

    /** Initial buffer capacity used when reading lz4-compressed expression files. */
    private static final int LZ4_READ_BUFFER_SIZE = 65536;

    private final FileAcquirer fileAcquirer;

    /**
     * Collects small constants which are embedded directly in rank profile properties as pairs of
     * {@code constant(name).value} and {@code constant(name).type} properties.
     */
    public static class SmallConstantsInfo {

        private static final Pattern valuePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.value");
        private static final Pattern typePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.type");

        private final Map<String, TensorType> types = new HashMap<>();
        private final Map<String, String> values = new HashMap<>();

        private SmallConstantsInfo() {
        }

        /**
         * Records this property if its name is a small-constant value or type property; otherwise does nothing.
         *
         * @param name  the rank profile property name
         * @param value the rank profile property value
         */
        void addIfSmallConstantInfo(String name, String value) {
            tryValue(name, value);
            tryType(name, value);
        }

        private void tryValue(String name, String value) {
            Matcher matcher = valuePattern.matcher(name);
            if (matcher.matches()) {
                values.put(matcher.group(1), value);
            }
        }

        private void tryType(String name, String value) {
            Matcher matcher = typePattern.matcher(name);
            if (matcher.matches()) {
                types.put(matcher.group(1), TensorType.fromSpec(value));
            }
        }

        /**
         * Returns one constant per recorded value, parsed using the type recorded for the same name.
         *
         * @throws IllegalStateException if a value was recorded without a corresponding type
         */
        List<Constant> asConstants() {
            List<Constant> constants = new ArrayList<>();
            for (Map.Entry<String, String> entry : values.entrySet()) {
                TensorType type = types.get(entry.getKey());
                if (type == null) {
                    throw new IllegalStateException("Missing type of '" + entry.getKey() + "'");
                }
                constants.add(new Constant(entry.getKey(), Tensor.from(type, entry.getValue())));
            }
            return constants;
        }
    }

    /**
     * Creates an importer.
     *
     * @param fileAcquirer used to obtain the files referenced from config
     */
    public RankProfilesConfigImporter(FileAcquirer fileAcquirer) {
        this.fileAcquirer = fileAcquirer;
    }

    /**
     * Imports every rank profile in the given config as a model.
     *
     * @param rankProfilesConfig       the rank profiles to import
     * @param rankingConstantsConfig   large constants available to the profiles
     * @param rankingExpressionsConfig large (file-backed) expressions available to the profiles
     * @param onnxModelsConfig         ONNX models available to the profiles, may be null
     * @return a mutable map from model (rank profile) name to model
     * @throws IllegalArgumentException if an expression cannot be parsed or a model cannot be created
     */
    public Map<String, Model> importFrom(RankProfilesConfig rankProfilesConfig,
                                         RankingConstantsConfig rankingConstantsConfig,
                                         RankingExpressionsConfig rankingExpressionsConfig,
                                         OnnxModelsConfig onnxModelsConfig) {
        try {
            Map<String, Model> models = new HashMap<>();
            for (RankProfilesConfig.Rankprofile profile : rankProfilesConfig.rankprofile()) {
                Model model = importProfile(profile, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig);
                models.put(model.name(), model);
            }
            return models;
        } catch (ParseException e) {
            throw new IllegalArgumentException("Could not read rank profiles config - version mismatch?", e);
        }
    }

    /**
     * Imports every rank profile using no large expressions.
     *
     * @deprecated use {@link #importFrom(RankProfilesConfig, RankingConstantsConfig, RankingExpressionsConfig, OnnxModelsConfig)}
     */
    @Deprecated
    public Map<String, Model> importFrom(RankProfilesConfig rankProfilesConfig,
                                         RankingConstantsConfig rankingConstantsConfig,
                                         OnnxModelsConfig onnxModelsConfig) {
        return importFrom(rankProfilesConfig, rankingConstantsConfig, new RankingExpressionsConfig.Builder().build(), onnxModelsConfig);
    }

    /** Builds the model for a single rank profile from its properties and the shared constant/expression/ONNX config. */
    private Model importProfile(RankProfilesConfig.Rankprofile profile,
                                RankingConstantsConfig constantsConfig,
                                RankingExpressionsConfig expressionsConfig,
                                OnnxModelsConfig onnxModelsConfig) throws ParseException {
        List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
        List<Constant> constants = readLargeConstants(constantsConfig);
        Map<String, RankingExpression> largeExpressions = readLargeExpressions(expressionsConfig);

        // 'functions' receives only references whose isFree() is true (plus the phase fallbacks below);
        // 'referencedFunctions' receives every function seen, so later type properties can be applied to any of them.
        Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
        Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
        SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
        ExpressionFunction firstPhase = null;
        ExpressionFunction secondPhase = null;

        for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
            String name = property.name();
            Optional<FunctionReference> reference = FunctionReference.fromSerial(name);
            Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(name);
            Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(name);
            Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(name);

            if (externalReference.isPresent()) {
                // The property value is the name of a file-backed expression from RankingExpressionsConfig
                ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(),
                                                                     Collections.emptyList(),
                                                                     largeExpressions.get(property.value()));
                putFunction(externalReference.get(), function, functions, referencedFunctions);
            } else if (reference.isPresent()) {
                String functionName = reference.get().functionName();
                ExpressionFunction function = new ExpressionFunction(functionName,
                                                                     Collections.emptyList(),
                                                                     new RankingExpression(functionName, property.value()));
                putFunction(reference.get(), function, functions, referencedFunctions);
            } else if (argumentType.isPresent()) {
                FunctionReference functionReference = argumentType.get().getFirst();
                ExpressionFunction function = requireFunction(functionReference, referencedFunctions, name)
                        .withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value()));
                putFunction(functionReference, function, functions, referencedFunctions);
            } else if (returnType.isPresent()) {
                ExpressionFunction function = requireFunction(returnType.get(), referencedFunctions, name)
                        .withReturnType(TensorType.fromSpec(property.value()));
                putFunction(returnType.get(), function, functions, referencedFunctions);
            } else if (name.equals("vespa.rank.firstphase")) {
                firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
                                                    new RankingExpression("first-phase", property.value()));
            } else if (name.equals("vespa.rank.secondphase")) {
                secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(),
                                                     new RankingExpression("second-phase", property.value()));
            } else {
                smallConstantsInfo.addIfSmallConstantInfo(name, property.value());
            }
        }

        // Phase expressions are only added when no function of the same name was already defined
        if (functionByName("firstphase", functions.values()) == null && firstPhase != null) {
            functions.put(FunctionReference.fromName("firstphase"), firstPhase);
        }
        if (functionByName("secondphase", functions.values()) == null && secondPhase != null) {
            functions.put(FunctionReference.fromName("secondphase"), secondPhase);
        }

        constants.addAll(smallConstantsInfo.asConstants());
        try {
            return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
        }
    }

    /** Stores a function in the map of all functions and, if its reference is free, also in the map of free functions. */
    private static void putFunction(FunctionReference reference,
                                    ExpressionFunction function,
                                    Map<FunctionReference, ExpressionFunction> freeFunctions,
                                    Map<FunctionReference, ExpressionFunction> allFunctions) {
        if (reference.isFree()) {
            freeFunctions.put(reference, function);
        }
        allFunctions.put(reference, function);
    }

    /**
     * Returns the function already stored for the given reference.
     *
     * @throws IllegalArgumentException if no function has been stored for it, i.e. the type property
     *                                  named by {@code propertyName} precedes its function definition
     */
    private static ExpressionFunction requireFunction(FunctionReference reference,
                                                      Map<FunctionReference, ExpressionFunction> functions,
                                                      String propertyName) {
        ExpressionFunction function = functions.get(reference);
        if (function == null) {
            throw new IllegalArgumentException("Property '" + propertyName + "' refers to function " + reference +
                                               ", which is not defined before it");
        }
        return function;
    }

    /** Returns the first function with the given name, or null if there is none. */
    private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
        for (ExpressionFunction function : functions) {
            if (function.getName().equals(name)) {
                return function;
            }
        }
        return null;
    }

    /** Returns the ONNX models in the given config, or an empty list if the config is null. */
    private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) {
        List<OnnxModel> onnxModels = new ArrayList<>();
        if (onnxModelsConfig != null) {
            for (OnnxModelsConfig.Model modelConfig : onnxModelsConfig.model()) {
                onnxModels.add(readOnnxModelConfig(modelConfig));
            }
        }
        return onnxModels;
    }

    /** Waits for the model file and creates an ONNX model with the stateless evaluation options from config. */
    private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model modelConfig) {
        try {
            File file = fileAcquirer.waitFor(modelConfig.fileref(), FILE_WAIT_DAYS, TimeUnit.DAYS);
            OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
            options.setExecutionMode(modelConfig.stateless_execution_mode());
            options.setInterOpThreads(modelConfig.stateless_interop_threads());
            options.setIntraOpThreads(modelConfig.stateless_intraop_threads());
            return new OnnxModel(modelConfig.name(), file, options);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status for callers
            throw new IllegalStateException("Gave up waiting for ONNX model " + modelConfig.name(), e);
        }
    }

    /**
     * Reads all file-backed constants in the given config.
     *
     * @return a mutable list (callers append small constants to it); empty if the config is null
     */
    private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
        List<Constant> constants = new ArrayList<>();
        if (constantsConfig == null) return constants;
        for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
            TensorType type = TensorType.fromSpec(constantConfig.type());
            constants.add(new Constant(constantConfig.name(),
                                       readTensorFromFile(constantConfig.name(), type, constantConfig.fileref())));
        }
        return constants;
    }

    /** Reads all file-backed expressions in the given config into a map keyed by expression name; empty if the config is null. */
    private Map<String, RankingExpression> readLargeExpressions(RankingExpressionsConfig expressionsConfig) throws ParseException {
        Map<String, RankingExpression> expressions = new HashMap<>();
        if (expressionsConfig == null) return expressions;
        for (RankingExpressionsConfig.Expression expressionConfig : expressionsConfig.expression()) {
            expressions.put(expressionConfig.name(), readExpressionFromFile(expressionConfig.name(), expressionConfig.fileref()));
        }
        return expressions;
    }

    /**
     * Returns the UTF-8 content of the given file, lz4-frame-decompressing it first if its name ends in ".lz4".
     *
     * @throws IOException if the file cannot be read or decompressed
     */
    protected final String readExpressionFromFile(File file) throws IOException {
        if ( ! file.getName().endsWith(".lz4")) {
            return Utf8.toString(IOUtils.readFileBytes(file));
        }
        // Both streams are closed even if the lz4 stream constructor fails
        try (FileInputStream fileStream = new FileInputStream(file);
             LZ4FrameInputStream decompressed = new LZ4FrameInputStream(fileStream)) {
            return Utf8.toString(IOUtils.readBytes(decompressed, LZ4_READ_BUFFER_SIZE));
        }
    }

    /**
     * Waits for the referenced file and parses its content as a ranking expression with the given name.
     *
     * @throws ParseException        if the file content is not a valid ranking expression
     * @throws UncheckedIOException  if the file cannot be read
     * @throws IllegalStateException if interrupted while waiting for the file
     */
    protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException {
        try {
            File file = fileAcquirer.waitFor(fileReference, FILE_WAIT_DAYS, TimeUnit.DAYS);
            return new RankingExpression(name, readExpressionFromFile(file));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status for callers
            throw new IllegalStateException("Gave up waiting for expression " + name, e);
        }
    }

    /**
     * Waits for the referenced file and decodes it as a typed-binary-format (.tbf) tensor of the given type.
     *
     * @throws IllegalArgumentException if the file is not a .tbf file
     * @throws UncheckedIOException     if the file cannot be read
     * @throws IllegalStateException    if interrupted while waiting for the file
     */
    protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
        try {
            File file = fileAcquirer.waitFor(fileReference, FILE_WAIT_DAYS, TimeUnit.DAYS);
            if ( ! file.getName().endsWith(".tbf")) {
                throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " +
                                                   file + " for constant " + name);
            }
            return TypedBinaryFormat.decode(Optional.of(type), GrowableByteBuffer.wrap(IOUtils.readFileBytes(file)));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status for callers
            throw new IllegalStateException("Gave up waiting for constant " + name, e);
        }
    }
}
