diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
index 1f98219d0..4b07a17c3 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
@@ -48,6 +48,7 @@
 import dev.langchain4j.rag.query.Metadata;
 import dev.langchain4j.service.AiServiceContext;
 import dev.langchain4j.service.AiServiceTokenStream;
+import dev.langchain4j.service.Result;
 import dev.langchain4j.service.TokenStream;
 import dev.langchain4j.service.output.ServiceOutputParser;
 import dev.langchain4j.service.tool.ToolExecutor;
@@ -134,7 +135,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob
         boolean needsMemorySeed = needsMemorySeed(context, memoryId); // we need to know figure this out before we add the system and user message
 
         Type returnType = methodCreateInfo.getReturnType();
-        AugmentationResult augmentationResult;
+        AugmentationResult augmentationResult = null;
         if (context.retrievalAugmentor != null) {
             List<ChatMessage> chatMemory = context.hasChatMemory()
                     ? context.chatMemory(memoryId).messages()
@@ -276,7 +277,17 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
         chatMemory.commit();
 
         response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
-        return SERVICE_OUTPUT_PARSER.parse(response, returnType);
+        if (isResult(returnType)) {
+            var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType));
+            return Result.builder()
+                    .content(parsedResponse)
+                    .tokenUsage(tokenUsageAccumulator)
+                    .sources(augmentationResult == null ? null : augmentationResult.contents())
+                    .finishReason(response.finishReason())
+                    .build();
+        } else {
+            return SERVICE_OUTPUT_PARSER.parse(response, returnType);
+        }
     }
 
     private static boolean needsMemorySeed(QuarkusAiServiceContext context, Object memoryId) {
@@ -343,6 +354,17 @@ private static boolean isMulti(Type returnType) {
         return isTypeOf(returnType, Multi.class);
     }
 
+    private static boolean isResult(Type returnType) {
+        return isTypeOf(returnType, Result.class);
+    }
+
+    private static Type resultTypeParam(ParameterizedType returnType) {
+        if (!isTypeOf(returnType, Result.class)) {
+            throw new IllegalStateException("Can only be called with Result type");
+        }
+        return returnType.getActualTypeArguments()[0];
+    }
+
     private static boolean isTypeOf(Type type, Class<?> clazz) {
         if (type instanceof Class) {
             return type.equals(clazz);
diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AiServicesTest.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AiServicesTest.java
index 8c9eeb884..97b286c8c 100644
--- a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AiServicesTest.java
+++ b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AiServicesTest.java
@@ -9,6 +9,7 @@
 import static dev.langchain4j.data.message.ChatMessageType.AI;
 import static dev.langchain4j.data.message.ChatMessageType.SYSTEM;
 import static dev.langchain4j.data.message.ChatMessageType.USER;
+import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static java.time.Month.JULY;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -44,11 +45,13 @@
 import dev.langchain4j.model.input.structured.StructuredPrompt;
 import dev.langchain4j.model.openai.OpenAiChatModel;
 import dev.langchain4j.model.openai.OpenAiModerationModel;
+import dev.langchain4j.model.output.TokenUsage;
 import dev.langchain4j.model.output.structured.Description;
 import dev.langchain4j.service.AiServices;
 import dev.langchain4j.service.MemoryId;
 import dev.langchain4j.service.Moderate;
 import dev.langchain4j.service.ModerationException;
+import dev.langchain4j.service.Result;
 import dev.langchain4j.service.SystemMessage;
 import dev.langchain4j.service.UserMessage;
 import dev.langchain4j.service.V;
@@ -901,6 +904,38 @@ public void deleteMessages(Object memoryId) {
                         tuple(USER, secondsMessageFromSecondUser),
                         tuple(AI, secondAiMessageToSecondUser));
     }
 
+    interface AssistantReturningResult {
+
+        Result<String> chat(String userMessage);
+    }
+
+    @Test
+    void should_return_result() throws IOException {
+        setChatCompletionMessageContent("Berlin is the capital of Germany");
+
+        // given
+        AssistantReturningResult assistant = AiServices.create(AssistantReturningResult.class, createChatModel());
+
+        String userMessage = "What is the capital of Germany?";
+
+        // when
+        Result<String> result = assistant.chat(userMessage);
+
+        // then
+        assertThat(result.content()).containsIgnoringCase("Berlin");
+
+        TokenUsage tokenUsage = result.tokenUsage();
+        assertThat(tokenUsage).isNotNull();
+        assertThat(tokenUsage.inputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage.outputTokenCount()).isGreaterThan(0);
+        assertThat(tokenUsage.totalTokenCount())
+                .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount());
+
+        assertThat(result.sources()).isNull();
+
+        assertSingleRequestMessage(getRequestAsMap(), "What is the capital of Germany?");
+    }
+
     static class Calculator {
         private final Runnable after;