diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 4c5836466..197692c6f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -44,10 +44,10 @@ import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.request.ChatRequest; import dev.langchain4j.model.chat.request.ResponseFormat; import dev.langchain4j.model.chat.request.json.JsonSchema; -import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.structured.StructuredPrompt; @@ -337,32 +337,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage, log.debug("Attempting to obtain AI response"); - Optional<JsonSchema> jsonSchema = Optional.empty(); - if (supportsJsonSchema) { - jsonSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema(); - } - - Response<AiMessage> response; - if (jsonSchema.isPresent()) { - ChatRequest chatRequest = ChatRequest.builder() - .messages(messagesToSend) - .toolSpecifications(toolSpecifications) - .responseFormat(ResponseFormat.builder() - .type(JSON) - .jsonSchema(jsonSchema.get()) - .build()) - .build(); - - ChatResponse chatResponse = context.chatModel.chat(chatRequest); - response = new Response<>( - chatResponse.aiMessage(), - chatResponse.tokenUsage(), - chatResponse.finishReason()); - } else { - response = toolSpecifications == null - ? context.chatModel.generate(messagesToSend) - : context.chatModel.generate(messagesToSend, toolSpecifications); - } + var response = executeRequest(context, methodCreateInfo, messagesToSend, toolSpecifications); log.debug("AI response obtained"); if (audit != null) { @@ -450,6 +425,46 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage, methodCreateInfo, responseAugmenterParam); } + private static Response<AiMessage> executeRequest(JsonSchema jsonSchema, List<ChatMessage> messagesToSend, + ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) { + var chatRequest = ChatRequest.builder() + .messages(messagesToSend) + .toolSpecifications(toolSpecifications) + .responseFormat( + ResponseFormat.builder() + .type(JSON) + .jsonSchema(jsonSchema) + .build()) + .build(); + + var response = chatModel.chat(chatRequest); + + return new Response<>( + response.aiMessage(), + response.tokenUsage(), + response.finishReason()); + } + + private static Response<AiMessage> executeRequest(List<ChatMessage> messagesToSend, ChatLanguageModel chatModel, + List<ToolSpecification> toolSpecifications) { + return (toolSpecifications == null) ? chatModel.generate(messagesToSend) + : chatModel.generate(messagesToSend, toolSpecifications); + } + + static Response<AiMessage> executeRequest(AiServiceMethodCreateInfo methodCreateInfo, List<ChatMessage> messagesToSend, + ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) { + var jsonSchema = supportsJsonSchema(chatModel) ? methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema() + : Optional.<JsonSchema> empty(); + + return jsonSchema.isPresent() ? executeRequest(jsonSchema.get(), messagesToSend, chatModel, toolSpecifications) + : executeRequest(messagesToSend, chatModel, toolSpecifications); + } + + static Response<AiMessage> executeRequest(QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo, + List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications) { + return executeRequest(methodCreateInfo, messagesToSend, context.chatModel, toolSpecifications); + } + private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context, Audit audit, Optional<SystemMessage> systemMessage, UserMessage userMessage, Object memoryId, Type returnType, Map<String, Object> templateVariables) { @@ -547,9 +562,12 @@ private static List<ChatMessage> createMessagesToSendForNoMemory(Optional<System return result; } + private static boolean supportsJsonSchema(ChatLanguageModel chatModel) { + return (chatModel != null) && chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA); + } + private static boolean supportsJsonSchema(AiServiceContext context) { - return context.chatModel != null - && context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA); + return supportsJsonSchema(context.chatModel); } private static Future<Moderation> triggerModerationIfNeeded(AiServiceContext context, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index a23f3a0e1..5ba5e9698 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -81,24 +81,18 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn if (!result.isSuccess()) { if (!result.isRetry()) { throw new GuardrailException(result.toString(), result.getFirstFailureException()); - } else if (result.getReprompt() != null) { - // Retry with re-prompting - chatMemory.add(userMessage(result.getReprompt())); - if (toolSpecifications == null) { - response = chatModel.generate(chatMemory.messages()); - } else { - response = chatModel.generate(chatMemory.messages(), toolSpecifications); - } - chatMemory.add(response.content()); } else { - // Retry without re-prompting - if (toolSpecifications == null) { - response = chatModel.generate(chatMemory.messages()); - } else { - response = chatModel.generate(chatMemory.messages(), toolSpecifications); + // Retry + if (result.getReprompt() != null) { + // Retry with reprompting + chatMemory.add(userMessage(result.getReprompt())); } + + response = AiServiceMethodImplementationSupport.executeRequest(methodCreateInfo, chatMemory.messages(), + chatModel, toolSpecifications); chatMemory.add(response.content()); } + attempt++; output = new OutputGuardrailParams(response.content(), output.memory(), output.augmentationResult(), output.userMessageTemplate(), output.variables());