package fi.evolver.ai.spring.provider.openai;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import fi.evolver.ai.spring.ApiConnectionParameters;
import fi.evolver.ai.spring.ApiResponseException;
import fi.evolver.ai.spring.chat.ChatApi;
import fi.evolver.ai.spring.chat.prompt.ChatPrompt;
import fi.evolver.ai.spring.chat.prompt.Model;
import fi.evolver.ai.spring.embedding.EmbeddingApi;
import fi.evolver.ai.spring.embedding.EmbeddingCache;
import fi.evolver.ai.spring.embedding.EmbeddingService;
import fi.evolver.ai.spring.embedding.entity.Embedding;
import fi.evolver.ai.spring.embedding.model.EmbeddingData;
import fi.evolver.ai.spring.provider.openai.response.chat.OChatResult;
import fi.evolver.ai.spring.provider.openai.response.embeddings.OEmbeddingsResult;
import fi.evolver.ai.spring.util.Json;
import fi.evolver.basics.spring.http.LoggingHttpClient;
import fi.evolver.basics.spring.http.SseSubscriber;
import fi.evolver.basics.spring.log.MessageLogService;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

@Component
/**
 * Chat and embedding provider for OpenAI-compatible HTTP APIs.
 *
 * <p>Chat responses are consumed as server-sent events and fed into an {@link OpenAiChatResponse}
 * as they arrive. Embeddings are requested synchronously in batches and persisted through
 * {@link EmbeddingService}.
 *
 * <p>Configuration read from the environment:
 * <ul>
 *   <li>{@code AI_CONNECTION_HEADERS}: comma-separated {@code name=value} pairs added to every request.</li>
 *   <li>{@code ENDPOINT_URL_COMPLETIONS} / {@code ENDPOINT_URL_EMBEDDINGS}: fallback endpoint URLs.</li>
 *   <li>{@code MODEL_PARAMETERS_<ENGINE>}: JSON object with optional {@code headers} and {@code endpoint}
 *       for one model. {@code <ENGINE>} is the model's engine name upper-cased, with every non-word
 *       character replaced by {@code _}.</li>
 * </ul>
 *
 * <p>Endpoint precedence: {@link ApiConnectionParameters#endpointUrl()}, then the per-model endpoint,
 * then the environment default. Header precedence (later wins): environment headers, per-model
 * headers, then {@link ApiConnectionParameters#headers()}.
 */
public class OpenAiService implements ChatApi<OpenAiChatResponse>, EmbeddingApi {
    private static final Logger LOG = LoggerFactory.getLogger(OpenAiService.class);
    private static final EncodingRegistry ENCODING_REGISTRY = Encodings.newDefaultEncodingRegistry();
    private static final Model DEFAULT_EMBEDDING_MODEL = Model.TEXT_EMBEDDING_ADA;

    /** Upper bound on the number of inputs sent in a single embeddings request. */
    private static final int MAX_EMBEDDING_BATCH_SIZE = 16;

    // Declaration order matters: CONNECTION_HEADERS is parsed from AI_CONNECTION_HEADERS during
    // static initialization, so the environment lookups must be initialized first.
    private static final String AI_CONNECTION_HEADERS = System.getenv("AI_CONNECTION_HEADERS");
    private static final String ENDPOINT_URL_COMPLETIONS = System.getenv("ENDPOINT_URL_COMPLETIONS");
    private static final String ENDPOINT_URL_EMBEDDINGS = System.getenv("ENDPOINT_URL_EMBEDDINGS");
    private static final Map<String, String> CONNECTION_HEADERS = Collections.unmodifiableMap(parseHeaderString());

    /** Lazily populated per-model settings; concurrent because the bean is shared across request threads. */
    private static final Map<Model, ModelParameterConfig> MODEL_PARAMETERS = new ConcurrentHashMap<>();
    private static final Map<Model, EmbeddingModelConfig> EMBEDDING_MODEL_CONFIG =
            Map.of(DEFAULT_EMBEDDING_MODEL, new EmbeddingModelConfig(8142, EncodingType.CL100K_BASE));

    private final EmbeddingService embeddingService;
    private final MessageLogService messageLogService;

    /** One input text of an embeddings batch, paired with the caller's identifier for it. */
    private record EmbeddingBatchEntry(String identifier, String data) {
    }

    /**
     * Tokenizer settings of an embedding model.
     *
     * @param maxTokens token budget for a single input and for the sum of one batch
     * @param encoding  tokenizer used to count tokens
     */
    private record EmbeddingModelConfig(int maxTokens, EncodingType encoding) {
    }

    /**
     * Per-model overrides parsed from the {@code MODEL_PARAMETERS_<ENGINE>} environment variable.
     *
     * @param headers  extra request headers; a missing value is normalized to an empty map
     * @param endpoint endpoint URL override, or {@code null} to fall back to other sources
     */
    @JsonIgnoreProperties(ignoreUnknown = true)
    private record ModelParameterConfig(Map<String, String> headers, String endpoint) {
        ModelParameterConfig {
            // JSON without a "headers" property would otherwise leave this null and break prepareHeaders.
            headers = headers != null ? headers : Map.of();
        }

        static ModelParameterConfig empty() {
            return new ModelParameterConfig(Map.of(), null);
        }
    }

    /** Feeds streamed chat-completion chunks (server-sent events) into an {@link OpenAiChatResponse}. */
    private static final class StreamingCompletionsEventConsumer implements SseSubscriber.SseEventConsumer {
        private final OpenAiChatResponse response;

        StreamingCompletionsEventConsumer(OpenAiChatResponse response) {
            this.response = response;
        }

        /**
         * Parses one event into an {@link OChatResult}.
         *
         * <p>The {@code [DONE]} sentinel is ignored. Payloads that are not JSON objects, or that fail
         * to parse, are logged and skipped rather than failing the whole stream.
         */
        public void onEvent(SseSubscriber.SseEvent sseEvent) {
            String data = sseEvent.data();
            String payload = data.strip();
            if ("[DONE]".equals(payload)) {
                return;
            }
            if (!payload.startsWith("{")) {
                LOG.warn("Unknown chunk: {}", data);
                return;
            }
            try {
                response.addResult(Json.OBJECT_MAPPER.readValue(payload, OChatResult.class));
            } catch (JsonProcessingException e) {
                LOG.warn("Bad SSE event", e);
            }
        }

        /** Forwards a stream failure to the response. */
        public void onError(Throwable error) {
            response.handleError(error);
        }

        /** Signals the response that the stream has ended. */
        public void onComplete() {
            response.handleStreamEnd();
        }
    }

    @Autowired
    public OpenAiService(EmbeddingService embeddingService, MessageLogService messageLogService) {
        this.embeddingService = embeddingService;
        this.messageLogService = messageLogService;
    }

    /**
     * Starts an asynchronous chat completion request.
     *
     * <p>The returned response is filled in from the event stream as chunks arrive. Transport
     * failures are reported through {@link OpenAiChatResponse#handleError(Throwable)}.
     *
     * @throws ApiResponseException if no endpoint or no headers are configured
     */
    @Override
    public OpenAiChatResponse send(ChatPrompt chatPrompt, ApiConnectionParameters apiConnectionParameters) {
        String requestBody = OpenAiChatRequestGenerator.generate(chatPrompt);
        ModelParameterConfig modelParameters = getModelParameters(chatPrompt.model());
        HttpRequest request = buildPostRequest(apiConnectionParameters, modelParameters, ENDPOINT_URL_COMPLETIONS, requestBody);

        OpenAiChatResponse response = new OpenAiChatResponse(chatPrompt);
        createHttpClient(apiConnectionParameters)
                .sendAsync(
                        request,
                        SseSubscriber.createBodyHandler(new StreamingCompletionsEventConsumer(response)),
                        new LoggingHttpClient.LogParameters("CompletionsRequest"))
                .exceptionally(error -> {
                    response.handleError(error);
                    return null;
                });
        return response;
    }

    /**
     * Computes embeddings for every value of {@code map} and persists them.
     *
     * @param cacheKey identifier passed to {@link EmbeddingService#persistEmbeddings} for this set of embeddings
     * @param map      identifier to text
     */
    @Override
    public void createEmbeddings(Model model, String cacheKey, Map<String, String> map, ApiConnectionParameters apiConnectionParameters) {
        embeddingService.persistEmbeddings(getEmbeddings(model, map, apiConnectionParameters), Embedding.Source.OPEN_AI, model, cacheKey);
    }

    /** Same as {@link #createEmbeddings(Model, String, Map, ApiConnectionParameters)} with the default embedding model. */
    @Override
    public void createEmbeddings(String cacheKey, Map<String, String> map, ApiConnectionParameters apiConnectionParameters) {
        createEmbeddings(DEFAULT_EMBEDDING_MODEL, cacheKey, map, apiConnectionParameters);
    }

    /** Loads previously persisted embeddings through {@link EmbeddingService#fetchEmbeddings}. */
    @Override
    public EmbeddingCache fetchEmbeddings(Model model, String cacheKey) {
        return embeddingService.fetchEmbeddings(model, cacheKey);
    }

    /** Same as {@link #fetchEmbeddings(Model, String)} with the default embedding model. */
    @Override
    public EmbeddingCache fetchEmbeddings(String cacheKey) {
        return fetchEmbeddings(DEFAULT_EMBEDDING_MODEL, cacheKey);
    }

    /**
     * Embeds {@code text} with the cache's model and delegates the similarity search to
     * {@link EmbeddingService#findClosestMatches}.
     *
     * @param matchCount passed through to {@code findClosestMatches} (presumably the number of matches wanted)
     * @return the matches, or an empty list if no embedding was produced for the input
     * @throws ApiResponseException if {@code embeddingCache} is null or the API call fails
     */
    @Override
    public List<String> findMatches(String text, EmbeddingCache embeddingCache, int matchCount, ApiConnectionParameters apiConnectionParameters) {
        if (embeddingCache == null) {
            throw new ApiResponseException("Missing embedding cache");
        }
        return getEmbeddings(embeddingCache.getModel(), Collections.singletonMap("data", text), apiConnectionParameters)
                .stream()
                .findFirst()
                .map(embedding -> embeddingService.findClosestMatches(embedding, embeddingCache, matchCount))
                .orElseGet(() -> {
                    LOG.warn("Failed generating embedding for input");
                    return Collections.emptyList();
                });
    }

    /**
     * Computes embeddings for all values of {@code map}, batching inputs so that no request exceeds
     * {@value #MAX_EMBEDDING_BATCH_SIZE} inputs or the model's token budget.
     *
     * <p>All inputs are validated before the first request is sent.
     *
     * @param map identifier to text; returned {@link EmbeddingData} carry the identifier
     * @return one embedding per entry, in the map's iteration order
     * @throws ApiResponseException     if the model has no embedding configuration or a request fails
     * @throws IllegalArgumentException if a text is null/empty or longer than the model's token budget
     */
    public List<EmbeddingData> getEmbeddings(Model model, Map<String, String> map, ApiConnectionParameters apiConnectionParameters) {
        EmbeddingModelConfig config = EMBEDDING_MODEL_CONFIG.get(model);
        if (config == null) {
            throw new ApiResponseException("Missing embedding configuration for model %s".formatted(model));
        }
        if (map.isEmpty()) {
            return Collections.emptyList();
        }

        // Validate and tokenize every input up front. A bad input late in the map must not waste
        // (billable) requests for the batches before it.
        Encoding encoding = ENCODING_REGISTRY.getEncoding(config.encoding());
        List<EmbeddingBatchEntry> entries = new ArrayList<>(map.size());
        int[] tokenCounts = new int[map.size()];
        for (Map.Entry<String, String> entry : map.entrySet()) {
            String value = entry.getValue();
            if (value == null || value.isEmpty()) {
                throw new IllegalArgumentException("Cannot create embedding for empty string");
            }
            int tokenCount = encoding.encode(value).size();
            if (tokenCount > config.maxTokens()) {
                throw new IllegalArgumentException("Text too long for embedding: " + value);
            }
            tokenCounts[entries.size()] = tokenCount;
            entries.add(new EmbeddingBatchEntry(entry.getKey(), value));
        }

        List<EmbeddingData> results = new ArrayList<>(entries.size());
        List<EmbeddingBatchEntry> batch = new ArrayList<>();
        int batchTokens = 0;
        for (int index = 0; index < entries.size(); index++) {
            int tokenCount = tokenCounts[index];
            // Never true for an empty batch: every single input already fits the token budget.
            boolean batchFull = batch.size() == MAX_EMBEDDING_BATCH_SIZE || batchTokens + tokenCount > config.maxTokens();
            if (batchFull) {
                results.addAll(getEmbeddingsBatch(model, batch, apiConnectionParameters));
                batch = new ArrayList<>();
                batchTokens = 0;
            }
            batch.add(entries.get(index));
            batchTokens += tokenCount;
        }
        results.addAll(getEmbeddingsBatch(model, batch, apiConnectionParameters));
        return results;
    }

    /**
     * Sends one embeddings request and pairs each returned vector with its input.
     *
     * <p>Pairing is by position, so the API is assumed to return vectors in input order.
     *
     * @throws ApiResponseException if the response does not contain exactly one embedding per input
     */
    private List<EmbeddingData> getEmbeddingsBatch(Model model, List<EmbeddingBatchEntry> batch, ApiConnectionParameters apiConnectionParameters) {
        List<String> inputs = batch.stream().map(EmbeddingBatchEntry::data).toList();
        OEmbeddingsResult result = makeEmbeddingsRequest(model, generateEmbeddingsRequest(model, inputs), apiConnectionParameters);
        var embeddings = result != null ? result.data() : null;
        if (embeddings == null || embeddings.size() != batch.size()) {
            throw new ApiResponseException("Expected %d embeddings but received %s".formatted(
                    batch.size(), embeddings == null ? "none" : String.valueOf(embeddings.size())));
        }

        List<EmbeddingData> results = new ArrayList<>(batch.size());
        for (int i = 0; i < batch.size(); i++) {
            EmbeddingBatchEntry entry = batch.get(i);
            results.add(new EmbeddingData(
                    entry.identifier(),
                    EmbeddingService.calculateHash(entry.data()),
                    entry.data(),
                    embeddings.get(i).embedding()));
        }
        return results;
    }

    /**
     * Performs a blocking embeddings call and parses the JSON body.
     *
     * @throws ApiResponseException on a non-2xx status, an I/O or parse failure, or interruption
     *                              (the thread's interrupt flag is restored)
     */
    private OEmbeddingsResult makeEmbeddingsRequest(Model model, String requestBody, ApiConnectionParameters apiConnectionParameters) {
        HttpRequest request = buildPostRequest(apiConnectionParameters, getModelParameters(model), ENDPOINT_URL_EMBEDDINGS, requestBody);
        try {
            HttpResponse<String> response = createHttpClient(apiConnectionParameters).send(request, HttpResponse.BodyHandlers.ofString());
            int status = response.statusCode();
            if (status < 200 || status >= 300) {
                throw new ApiResponseException("Embeddings request failed with HTTP status %d".formatted(status));
            }
            return Json.OBJECT_MAPPER.readValue(response.body(), OEmbeddingsResult.class);
        } catch (IOException e) {
            throw new ApiResponseException(e, "Failed making embeddings request");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ApiResponseException(e, "Interrupted while making embeddings request");
        }
    }

    /** Serializes an embeddings request body with the model's engine name and the input texts. */
    private static String generateEmbeddingsRequest(Model model, List<String> inputs) {
        Map<String, Object> request = new LinkedHashMap<>();
        request.put(OpenAiRequestParameters.MODEL, model.getEngine());
        request.put("input", inputs);
        try {
            return Json.OBJECT_MAPPER.writeValueAsString(request);
        } catch (JsonProcessingException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Creates a logging HTTP client with the connection timeout from {@code apiConnectionParameters}. */
    private LoggingHttpClient createHttpClient(ApiConnectionParameters apiConnectionParameters) {
        HttpClient httpClient = HttpClient.newBuilder()
                .connectTimeout(apiConnectionParameters.connectionTimeout())
                .build();
        return new LoggingHttpClient(messageLogService, httpClient);
    }

    /**
     * Builds a JSON POST request with the resolved endpoint, read timeout and merged headers.
     *
     * @param defaultEndpoint environment-level fallback endpoint (may be null)
     */
    private static HttpRequest buildPostRequest(ApiConnectionParameters apiConnectionParameters, ModelParameterConfig modelParameters,
            String defaultEndpoint, String body) {
        HttpRequest.Builder builder = HttpRequest.newBuilder(prepareUri(apiConnectionParameters.endpointUrl(), modelParameters.endpoint(), defaultEndpoint))
                .header("Content-Type", "application/json")
                .timeout(apiConnectionParameters.readTimeout())
                .POST(HttpRequest.BodyPublishers.ofString(body));
        prepareHeaders(apiConnectionParameters, modelParameters.headers()).forEach(builder::header);
        return builder.build();
    }

    /**
     * Picks the first configured endpoint, in priority order.
     *
     * @throws ApiResponseException     if every candidate is null or blank
     * @throws IllegalArgumentException if the chosen endpoint is not a valid URI
     */
    private static URI prepareUri(String... candidates) {
        String endpoint = Arrays.stream(candidates)
                .filter(candidate -> candidate != null && !candidate.isBlank())
                .findFirst()
                .orElseThrow(() -> new ApiResponseException("The API connection has not been initialized correctly"));
        try {
            return new URI(endpoint);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Merges request headers; later sources override earlier ones: environment, model, connection.
     *
     * @throws ApiResponseException if no header is configured at all
     */
    private static Map<String, String> prepareHeaders(ApiConnectionParameters apiConnectionParameters, Map<String, String> modelHeaders) {
        Map<String, String> headers = new HashMap<>(CONNECTION_HEADERS);
        headers.putAll(modelHeaders);
        headers.putAll(apiConnectionParameters.headers());
        if (headers.isEmpty()) {
            throw new ApiResponseException("The API connection has not been initialized correctly");
        }
        return headers;
    }

    /**
     * Parses {@code AI_CONNECTION_HEADERS} ({@code name=value,name=value}) into a map.
     *
     * <p>Whitespace around separators is ignored, values may themselves contain {@code =}, and
     * entries without {@code =} are skipped.
     */
    private static Map<String, String> parseHeaderString() {
        Map<String, String> headers = new HashMap<>();
        // Limit 0: split on every comma; the original limit of 2 merged all headers after the first.
        for (String pair : split(AI_CONNECTION_HEADERS, ',', 0)) {
            String[] nameAndValue = split(pair, '=', 2);
            if (nameAndValue.length == 2) {
                headers.put(nameAndValue[0], nameAndValue[1]);
            }
        }
        return headers;
    }

    /** Trims {@code str} and splits it on {@code separator}, ignoring surrounding whitespace; null yields no parts. */
    private static String[] split(String str, char separator, int limit) {
        if (str == null) {
            return new String[0];
        }
        return str.trim().split("\\s*" + Pattern.quote(String.valueOf(separator)) + "\\s*", limit);
    }

    /** Returns the cached per-model settings, loading them from the environment on first use. */
    private static ModelParameterConfig getModelParameters(Model model) {
        return MODEL_PARAMETERS.computeIfAbsent(model, OpenAiService::loadModelParameters);
    }

    /**
     * Reads {@code MODEL_PARAMETERS_<ENGINE>} for {@code model}.
     *
     * @return the parsed settings, or empty settings if the variable is unset
     * @throws ApiResponseException if the variable is not valid JSON (nothing is cached in that case)
     */
    private static ModelParameterConfig loadModelParameters(Model model) {
        // Locale.ROOT: a default-locale upper-casing (e.g. Turkish dotted I) would produce a different variable name.
        String variableName = "MODEL_PARAMETERS_%s".formatted(model.getEngine().toUpperCase(Locale.ROOT).replaceAll("\\W", "_"));
        String json = System.getenv(variableName);
        if (json == null) {
            return ModelParameterConfig.empty();
        }
        try {
            ModelParameterConfig parsed = Json.OBJECT_MAPPER.readValue(json, ModelParameterConfig.class);
            return Objects.requireNonNullElseGet(parsed, ModelParameterConfig::empty);
        } catch (JsonProcessingException e) {
            // Do not echo the value: it may contain credential headers.
            throw new ApiResponseException(e, "Failed parsing model parameters from environment variable %s".formatted(variableName));
        }
    }
}
