package fi.evolver.ai.spring.util;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import fi.evolver.ai.spring.Model;
import fi.evolver.ai.spring.chat.function.FunctionSpec;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:fi/evolver/ai/spring/util/TokenUtils.class */
/**
 * Static helpers for counting tokens in strings, function specifications and chat prompts.
 *
 * <p>Encodings are obtained from a shared jtokkit {@link EncodingRegistry}. When no model is
 * given, or no encoding can be resolved for the model's engine name, the
 * {@link EncodingType#CL100K_BASE} encoding is used.
 */
public class TokenUtils {
    /** Fixed number of tokens added on top of the encoded JSON schema of every function. */
    private static final int PER_FUNCTION_OVERHEAD = 30;
    private static final Logger log = LoggerFactory.getLogger(TokenUtils.class);
    /** Encoding used when the model is {@code null} or its engine name cannot be resolved. */
    private static final EncodingType DEFAULT_TOKENIZER_ENCODING = EncodingType.CL100K_BASE;
    private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();

    /**
     * Counts the tokens of a string using the encoding resolved for the given model.
     *
     * @param str the text to encode
     * @param model the model whose encoding should be used; {@code null} selects the default encoding
     * @return the number of tokens produced by encoding {@code str}
     */
    public static int calculateTokens(String str, Model model) {
        Encoding encoding = getEncodingForModel(model);
        return encoding.encode(str).size();
    }

    /**
     * Counts the tokens of a string using the default encoding.
     *
     * @param str the text to encode
     * @return the number of tokens produced by encoding {@code str}
     */
    public static int calculateTokens(String str) {
        return calculateTokens(str, (Model) null);
    }

    /**
     * Counts the tokens of a function specification: the tokens of its JSON schema plus a fixed
     * per-function overhead.
     *
     * @param functionSpec the function specification to measure
     * @param model the model whose encoding should be used; {@code null} selects the default encoding
     * @return the token count of the JSON schema plus the per-function overhead
     */
    public static int calculateTokens(FunctionSpec<?> functionSpec, Model model) {
        return calculateTokens(functionSpec.toJsonSchema(), model) + PER_FUNCTION_OVERHEAD;
    }

    /**
     * Counts the tokens of a whole chat prompt: the sum over the content of every message plus
     * the sum over every function specification, all encoded for the prompt's model.
     *
     * @param chatPrompt the prompt to measure
     * @return the combined token count of all message contents and functions
     */
    public static int calculateTokens(ChatPrompt chatPrompt) {
        // Read the model once; every message and function is encoded for the same model.
        Model model = chatPrompt.model();

        int messageTokens = chatPrompt.messages().stream()
                .map(message -> message.getContent())
                .mapToInt(content -> calculateTokens(content, model))
                .sum();
        int functionTokens = chatPrompt.functions().stream()
                .mapToInt(spec -> calculateTokens(spec, model))
                .sum();

        return messageTokens + functionTokens;
    }

    /**
     * Resolves the encoding for a model by looking up its engine name as a jtokkit
     * {@link ModelType}. If the full engine name is unknown, the part after the last {@code '-'}
     * is dropped and the lookup is retried, until a match is found or no dash remains.
     *
     * @param model the model to resolve; {@code null} selects the default encoding
     * @return the matching encoding, or the default encoding if none could be resolved
     */
    private static Encoding getEncodingForModel(Model model) {
        if (model == null) {
            return registry.getEncoding(DEFAULT_TOKENIZER_ENCODING);
        }

        String engine = model.getEngine();
        Optional<ModelType> modelType = ModelType.fromName(engine);
        while (modelType.isEmpty()) {
            int lastDash = engine.lastIndexOf('-');
            if (lastDash == -1) {
                // Nothing left to strip: fall back rather than fail.
                log.warn("No token encoding found for model {}, using default", model.name());
                return registry.getEncoding(DEFAULT_TOKENIZER_ENCODING);
            }
            // Drop the trailing "-suffix" and retry with the shorter name.
            engine = engine.substring(0, lastDash);
            modelType = ModelType.fromName(engine);
        }
        return registry.getEncodingForModel(modelType.get());
    }
}
