package ch.rasc.openai4j.chatcompletions.service;

import ch.rasc.openai4j.chatcompletions.AssistantMessage;
import ch.rasc.openai4j.chatcompletions.ChatCompletionCreateRequest;
import ch.rasc.openai4j.chatcompletions.ChatCompletionMessage;
import ch.rasc.openai4j.chatcompletions.ChatCompletionResponse;
import ch.rasc.openai4j.chatcompletions.ChatCompletionTool;
import ch.rasc.openai4j.chatcompletions.ChatCompletionsClient;
import ch.rasc.openai4j.chatcompletions.SystemMessage;
import ch.rasc.openai4j.chatcompletions.ToolMessage;
import ch.rasc.openai4j.chatcompletions.UserMessage;
import ch.rasc.openai4j.chatcompletions.service.ChatCompletionsJavaFunctionRequest;
import ch.rasc.openai4j.chatcompletions.service.ChatCompletionsModelRequest;
import ch.rasc.openai4j.common.FunctionParameters;
import ch.rasc.openai4j.common.ToolCall;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import jakarta.validation.ConstraintViolation;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Type;
import java.lang.runtime.ObjectMethods;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import org.hibernate.validator.messageinterpolation.ParameterMessageInterpolator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ch/rasc/openai4j/chatcompletions/service/ChatCompletionsService.class */
public class ChatCompletionsService {
    private static final Logger log = LoggerFactory.getLogger(ChatCompletionsService.class);
    private final SchemaGenerator schemaGenerator;
    private final ChatCompletionsClient chatCompletionsClient;
    private final ObjectMapper objectMapper;
    private final Validator validator;

    /* loaded from: input_file:ch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse.class */
    public static final class ChatCompletionsModelResponse<T> extends Record {
        private final ChatCompletionResponse response;
        private final T responseModel;
        private final String error;

        public ChatCompletionsModelResponse(ChatCompletionResponse chatCompletionResponse, T t, String str) {
            this.response = chatCompletionResponse;
            this.responseModel = t;
            this.error = str;
        }

        public ChatCompletionResponse response() {
            return this.response;
        }

        public T responseModel() {
            return this.responseModel;
        }

        public String error() {
            return this.error;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, ChatCompletionsModelResponse.class), ChatCompletionsModelResponse.class, "response;responseModel;error", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->response:Lch/rasc/openai4j/chatcompletions/ChatCompletionResponse;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->responseModel:Ljava/lang/Object;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->error:Ljava/lang/String;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, ChatCompletionsModelResponse.class), ChatCompletionsModelResponse.class, "response;responseModel;error", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->response:Lch/rasc/openai4j/chatcompletions/ChatCompletionResponse;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->responseModel:Ljava/lang/Object;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->error:Ljava/lang/String;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, ChatCompletionsModelResponse.class, Object.class), ChatCompletionsModelResponse.class, "response;responseModel;error", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->response:Lch/rasc/openai4j/chatcompletions/ChatCompletionResponse;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->responseModel:Ljava/lang/Object;", "FIELD:Lch/rasc/openai4j/chatcompletions/service/ChatCompletionsService$ChatCompletionsModelResponse;->error:Ljava/lang/String;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }
    }

    public ChatCompletionsService(ChatCompletionsClient chatCompletionsClient, ObjectMapper objectMapper) {
        this.schemaGenerator = new SchemaGenerator(new SchemaGeneratorConfigBuilder(SchemaVersion.DRAFT_2020_12, OptionPreset.PLAIN_JSON).with(new JacksonModule(new JacksonOption[]{JacksonOption.RESPECT_JSONPROPERTY_REQUIRED})).build());
        this.chatCompletionsClient = chatCompletionsClient;
        this.objectMapper = objectMapper;
        ValidatorFactory buildValidatorFactory = Validation.byDefaultProvider().configure().messageInterpolator(new ParameterMessageInterpolator()).buildValidatorFactory();
        try {
            this.validator = buildValidatorFactory.getValidator();
            if (buildValidatorFactory != null) {
                buildValidatorFactory.close();
            }
        } catch (Throwable th) {
            if (buildValidatorFactory != null) {
                try {
                    buildValidatorFactory.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    public ChatCompletionsService(ChatCompletionsClient chatCompletionsClient) {
        this(chatCompletionsClient, new ObjectMapper());
    }

    public ChatCompletionResponse createJavaFunctions(Function<ChatCompletionsJavaFunctionRequest.Builder, ChatCompletionsJavaFunctionRequest.Builder> function) throws JsonProcessingException {
        ChatCompletionsJavaFunctionRequest build = function.apply(ChatCompletionsJavaFunctionRequest.builder()).build();
        HashMap hashMap = new HashMap();
        ArrayList arrayList = new ArrayList();
        for (JavaFunction<?, ?> javaFunction : build.javaFunctions()) {
            hashMap.put(javaFunction.name(), javaFunction);
            arrayList.add(javaFunction.toTool(this.schemaGenerator));
        }
        ChatCompletionCreateRequest.Builder convertToChatCompletionsCreateRequestBuilder = build.convertToChatCompletionsCreateRequestBuilder();
        ChatCompletionCreateRequest build2 = convertToChatCompletionsCreateRequestBuilder.tools(arrayList).build();
        ChatCompletionResponse create = this.chatCompletionsClient.create(build2);
        ArrayList arrayList2 = new ArrayList(build2.messages());
        ChatCompletionResponse.Choice choice = create.choices().get(0);
        int i = 1;
        while (choice.finishReason() == ChatCompletionResponse.Choice.FinishReason.TOOL_CALLS) {
            if (i > build.maxIterations().intValue()) {
                log.debug("Max iterations reached");
                return create;
            }
            log.debug("Iteration {}", Integer.valueOf(i));
            ChatCompletionResponse.Message message = choice.message();
            arrayList2.add(AssistantMessage.of(choice.message()));
            for (ToolCall toolCall : message.toolCalls()) {
                JavaFunction javaFunction2 = (JavaFunction) hashMap.get(toolCall.function().name());
                if (javaFunction2 == null) {
                    throw new IllegalStateException("Unknown function " + toolCall.function().name());
                }
                Object readValue = this.objectMapper.readValue(toolCall.function().arguments(), javaFunction2.parameterClass());
                log.debug("Calling function {}", javaFunction2.name());
                log.debug("with argument {}", readValue);
                Object call = javaFunction2.call(readValue);
                if (call != null) {
                    arrayList2.add(ToolMessage.of(toolCall.id(), this.objectMapper.writeValueAsString(call)));
                } else {
                    arrayList2.add(ToolMessage.of(toolCall.id(), null));
                }
            }
            create = this.chatCompletionsClient.create(convertToChatCompletionsCreateRequestBuilder.messages(arrayList2).build());
            i++;
        }
        return create;
    }

    public <T> ChatCompletionsModelResponse<T> createModel(Function<ChatCompletionsModelRequest.Builder<T>, ChatCompletionsModelRequest.Builder<T>> function) {
        ArrayList arrayList;
        ChatCompletionsModelRequest<T> build = function.apply(ChatCompletionsModelRequest.builder()).build();
        ChatCompletionCreateRequest.Builder convertToChatCompletionsCreateRequestBuilder = build.convertToChatCompletionsCreateRequestBuilder();
        ObjectNode generateSchema = this.schemaGenerator.generateSchema(build.responseModel(), new Type[0]);
        String simpleName = build.responseModel().getSimpleName();
        if (build.mode() == ChatCompletionsModelRequest.Mode.JSON) {
            convertToChatCompletionsCreateRequestBuilder.responseFormat(ChatCompletionCreateRequest.ResponseFormat.jsonObject());
            String str = "Make sure that your response to any message matches the json_schema below, do not deviate at all: \n" + String.valueOf(generateSchema);
            List<ChatCompletionMessage> messages = build.messages();
            arrayList = new ArrayList();
            if (!messages.isEmpty()) {
                ChatCompletionMessage chatCompletionMessage = messages.get(0);
                if (chatCompletionMessage instanceof SystemMessage) {
                    SystemMessage of = SystemMessage.of(((SystemMessage) chatCompletionMessage).content() + "\n\n" + str);
                    arrayList.add(of);
                    arrayList.addAll(messages.subList(1, messages.size()));
                    log.debug("Replacing system message: {}", of.content());
                }
            }
            arrayList.add(SystemMessage.of(str));
            arrayList.addAll(messages);
            log.debug("Adding system message: {}", str);
        } else {
            arrayList = new ArrayList(build.messages());
            JsonNode jsonNode = generateSchema.get("description");
            List<ChatCompletionTool> of2 = List.of(ChatCompletionTool.of(FunctionParameters.of(simpleName, jsonNode != null ? jsonNode.textValue() : null, generateSchema)));
            convertToChatCompletionsCreateRequestBuilder.tools(of2);
            convertToChatCompletionsCreateRequestBuilder.toolChoice(ChatCompletionCreateRequest.ToolChoice.function(simpleName));
            log.debug("Adding tool: {}", of2);
        }
        ChatCompletionResponse chatCompletionResponse = null;
        for (int i = 0; i < build.maxRetries().intValue(); i++) {
            log.debug("Retry {}", Integer.valueOf(i));
            chatCompletionResponse = this.chatCompletionsClient.create(convertToChatCompletionsCreateRequestBuilder.messages(arrayList).build());
            ChatCompletionResponse.Choice choice = chatCompletionResponse.choices().get(0);
            if (choice.finishReason() != ChatCompletionResponse.Choice.FinishReason.STOP) {
                return new ChatCompletionsModelResponse<>(chatCompletionResponse, null, "finish reason not STOP");
            }
            try {
                Object obj = null;
                if (build.mode() == ChatCompletionsModelRequest.Mode.JSON) {
                    obj = this.objectMapper.readValue(choice.message().content(), build.responseModel());
                } else {
                    ToolCall toolCall = choice.message().toolCalls().get(0);
                    if (toolCall.function().name().equals(simpleName)) {
                        obj = this.objectMapper.readValue(toolCall.function().arguments(), build.responseModel());
                    } else {
                        String str2 = "Recall the correct function, function " + toolCall.function().name() + " does not exist";
                        arrayList.add(AssistantMessage.of(choice.message()));
                        arrayList.add(UserMessage.of(str2));
                    }
                }
                if (obj != null) {
                    Set<ConstraintViolation> validate = this.validator.validate(obj, new Class[0]);
                    if (validate.isEmpty()) {
                        return new ChatCompletionsModelResponse<>(chatCompletionResponse, obj, null);
                    }
                    StringBuilder sb = build.mode() == ChatCompletionsModelRequest.Mode.JSON ? new StringBuilder("Validation errors found\n") : new StringBuilder("Recall the function correctly, validation errors found\n");
                    for (ConstraintViolation constraintViolation : validate) {
                        sb.append(constraintViolation.getPropertyPath()).append(": ").append(constraintViolation.getMessage()).append("\n");
                    }
                    arrayList.add(AssistantMessage.of(choice.message()));
                    arrayList.add(UserMessage.of(sb.toString()));
                    log.debug("Adding validation error user message: {}", sb);
                }
            } catch (JsonProcessingException e) {
                String str3 = build.mode() == ChatCompletionsModelRequest.Mode.JSON ? "Could not deserialize response\n" + e.getMessage() : "Recall the function correctly, exceptions during deserialization found\n" + e.getMessage();
                arrayList.add(AssistantMessage.of(choice.message()));
                arrayList.add(UserMessage.of(str3));
                log.debug("Adding deserialization error user message: {}", str3);
            }
        }
        return new ChatCompletionsModelResponse<>(chatCompletionResponse, null, "max retries reached");
    }
}
