From b33c558704433581427cc093f338cfb43097bbf8 Mon Sep 17 00:00:00 2001 From: Clement Escoffier Date: Thu, 31 Oct 2024 14:54:51 +0100 Subject: [PATCH] Allows methods annotated with @Tool to use @Blocking / @NonBlocking and @RunOnVirtualThread --- core/deployment/pom.xml | 5 + .../deployment/AiServicesProcessor.java | 26 +- .../langchain4j/deployment/DotNames.java | 2 + .../langchain4j/deployment/ToolProcessor.java | 12 + .../deployment/items/ToolMethodBuildItem.java | 6 +- ...reamingAndRequestScopePropagationTest.java | 427 ++++++++++++++++++ .../ToolExecutionModelWithStreamingTest.java | 3 +- .../AiServiceMethodImplementationSupport.java | 30 +- ...rkusAiServiceStreamingResponseHandler.java | 168 +++++++ .../QuarkusAiServiceTokenStream.java | 155 +++++++ docs/modules/ROOT/pages/agent-and-tools.adoc | 81 ++++ integration-tests/pom.xml | 1 + integration-tests/tools/pom.xml | 124 +++++ .../main/java/org/acme/tools/AiService.java | 11 + .../main/java/org/acme/tools/Calculator.java | 30 ++ .../src/main/resources/application.properties | 4 + .../test/java/org/acme/tools/ToolsTest.java | 80 ++++ 17 files changed, 1146 insertions(+), 19 deletions(-) create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java create mode 100644 integration-tests/tools/pom.xml create mode 100644 integration-tests/tools/src/main/java/org/acme/tools/AiService.java create mode 100644 integration-tests/tools/src/main/java/org/acme/tools/Calculator.java create mode 100644 integration-tests/tools/src/main/resources/application.properties create mode 100644 integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java diff --git 
a/core/deployment/pom.xml b/core/deployment/pom.xml index 3153acca2..ab47f8411 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -102,6 +102,11 @@ + + io.quarkus + quarkus-test-vertx + test + diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index b8b248f20..08057d89f 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -8,6 +8,7 @@ import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW; @@ -31,7 +32,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletionStage; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; @@ -778,9 +778,9 @@ public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread( MethodInfo method, List associatedTools, List tools) { - boolean reactive = method.returnType().name().equals(DotName.createSimple(Uni.class.getName())) - || method.returnType().name().equals(DotName.createSimple(CompletionStage.class.getName())) - || method.returnType().name().equals(DotName.createSimple(Multi.class.getName())); + boolean reactive 
= method.returnType().name().equals(DotNames.UNI) + || method.returnType().name().equals(DotNames.COMPLETION_STAGE) + || method.returnType().name().equals(DotNames.MULTI); boolean requireSwitchToWorkerThread = false; @@ -1260,8 +1260,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method); - boolean switchToWorkerThread = detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, methodToolClassNames, - tools); + // Detect if tools execution may block the caller thread. + boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames); return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration, @@ -1270,6 +1270,20 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( inputGuardrails, outputGuardrails, accumulatorClassName); } + private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List tools, + List methodToolClassNames) { + List allTools = new ArrayList<>(methodToolClassNames); + // We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing. 
+ AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES); + if (annotation != null) { + AnnotationValue value = annotation.value("tools"); + if (value != null) { + allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList()); + } + } + return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools); + } + private void validateReturnType(MethodInfo method) { Type returnType = method.returnType(); Type.Kind returnTypeKind = returnType.kind(); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index b7a9dd1dd..25d8571a8 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -10,6 +10,7 @@ import org.jboss.jandex.DotName; +import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.model.chat.listener.ChatModelListener; import io.smallrye.common.annotation.Blocking; import io.smallrye.common.annotation.NonBlocking; @@ -54,4 +55,5 @@ public class DotNames { public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class); public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class); + public static final DotName TOOL = DotName.createSimple(Tool.class); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 8dcbced00..16384449f 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; +import 
java.util.function.Predicate; import java.util.stream.Collectors; import org.jboss.jandex.AnnotationInstance; @@ -68,6 +69,7 @@ import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.GeneratedClassBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; +import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem; import io.quarkus.deployment.recording.RecorderContext; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; @@ -263,6 +265,16 @@ public void handleTools( toolsMetadataProducer.produce(new ToolsMetadataBeforeRemovalBuildItem(metadata)); } + @BuildStep + ExecutionModelAnnotationsAllowedBuildItem toolsMethods() { + return new ExecutionModelAnnotationsAllowedBuildItem(new Predicate() { + @Override + public boolean test(MethodInfo method) { + return method.hasDeclaredAnnotation(DotNames.TOOL); + } + }); + } + private void validateExecutionModel(ToolMethodCreateInfo methodCreateInfo, MethodInfo toolMethod, BuildProducer validation) { String methodName = toolMethod.declaringClass().name() + "." + toolMethod.name(); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java index b01d6943d..9c1f25b34 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java @@ -6,6 +6,10 @@ import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; import io.quarkus.builder.item.MultiBuildItem; +/** + * A build item that represents a method that is annotated with {@link dev.langchain4j.agent.tool.Tool}. + * It contains the method info and the tool method create info. 
+ */ public final class ToolMethodBuildItem extends MultiBuildItem { private final MethodInfo toolsMethodInfo; @@ -33,13 +37,11 @@ public ToolMethodCreateInfo getToolMethodCreateInfo() { * Returns true if the method requires a switch to a worker thread, even if the method is non-blocking. * This is because of the tools executor limitation (imperative API). * - * * @return true if the method requires a switch to a worker thread */ public boolean requiresSwitchToWorkerThread() { return !(toolMethodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING && isImperativeMethod()); - } private boolean isImperativeMethod() { diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java new file mode 100644 index 000000000..116efcff4 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java @@ -0,0 +1,427 @@ +package io.quarkiverse.langchain4j.test.tools; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.shaded.org.awaitility.Awaitility; + 
+import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.test.Lists; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.vertx.RunOnVertxContext; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class ToolExecutionModelWithStreamingAndRequestScopePropagationTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class)); + + @Inject + MyAiService aiService; + + @Inject + Vertx vertx; + + @Inject + UUIDGenerator uuidGenerator; + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", 
l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + void testBlockingToolInvocationFromEventLoop() { + AtomicReference failure = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference value = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + String uuid = uuidGenerator.get(); + value.set(uuid); + aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + } catch (IllegalStateException e) { + failure.set(e); + Arc.container().requestContext().deactivate(); + } + }); + + // We would automatically detect this case and switch to a worker thread at subscription time. + + Awaitility.await().until(() -> failure.get() != null || result.get() != null); + assertThat(failure.get()).isNull(); + assertThat(result.get()).doesNotContain("event", "loop") + .contains(value.get(), "executor-thread"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + AtomicReference value = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(value.get(), "quarkus-virtual-thread-") + 
.contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloNonBlockingWithSwitch("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + + }); + + 
Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "executor-thread") + .doesNotContain(caller.get()); + } + + @Test + @RunOnVertxContext + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "executor-thread"); + } + + @Test + @RunOnVertxContext + @ActivateRequestContext + 
@EnabledForJreRange(min = JRE.JAVA_21) + void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // The blocking tool is executed on the same thread (synchronous emission) + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @RunOnVertxContext(runOnEventLoop = false) + @ActivateRequestContext + void testToolInvocationOnVirtualThread() { + String uuid = uuidGenerator.get(); + var r = aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, "quarkus-virtual-thread-"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @RunOnVertxContext + @ActivateRequestContext + void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // At the moment, we create a virtual thread every time. 
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .doesNotContain(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @ActivateRequestContext + void testToolInvocationOnVirtualThreadFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "quarkus-virtual-thread-"); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @ToolBox(BlockingTool.class) + Multi hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(NonBlockingTool.class) + Multi helloNonBlocking(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox({ NonBlockingTool.class, BlockingTool.class }) + Multi helloNonBlockingWithSwitch(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(UniTool.class) + Multi helloUni(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(VirtualTool.class) + Multi helloVirtualTools(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + } + + @Singleton + public static class BlockingTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + public String hi() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + @Singleton + public 
static class NonBlockingTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + @NonBlocking + public String hiNonBlocking() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + @Singleton + public static class UniTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + public Uni hiUni() { + return Uni.createFrom().item(() -> uuidGenerator.get() + " " + Thread.currentThread()); + } + } + + @Singleton + public static class VirtualTool { + + @Inject + UUIDGenerator uuidGenerator; + + @Tool + @RunOnVirtualThread + public String hiVirtualThread() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements StreamingChatLanguageModel { + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + throw new UnsupportedOperationException(); + } + + @Override + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { + if (messages.size() == 1) { + // Only the user message, extract the tool id from it + String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText(); + // Only the user message + handler.onComplete(new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder() + .id("my-tool-" + text) + .name(text) + .arguments("{}") + .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION)); + } else if (messages.size() == 3) { + // user -> tool request -> tool response + ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages); + handler.onNext("response: "); + handler.onNext(last.text()); + handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP)); + + } else { + handler.onError(new RuntimeException("Invalid number of messages")); + } + } + + 
} + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return MessageWindowChatMemory.withMaxMessages(10); + } + }; + } + } + + @RequestScoped + public static class UUIDGenerator { + private final String uuid = UUID.randomUUID().toString(); + + public String get() { + return uuid; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java index e51f1ea18..0b8092018 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java @@ -399,8 +399,9 @@ public void generate(List messages, List toolSpe handler.onNext(last.text()); handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP)); + } else { + handler.onError(new RuntimeException("Invalid number of messages: " + messages.size())); } - handler.onError(new RuntimeException("Invalid number of messages")); } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index b16bd11f1..70479fffc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -25,7 +25,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Flow; import java.util.concurrent.Future; -import 
java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -70,6 +69,7 @@ import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; +import io.smallrye.common.vertx.VertxContext; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.infrastructure.Infrastructure; import io.smallrye.mutiny.operators.AbstractMulti; @@ -193,6 +193,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); userMessage = (UserMessage) augmentationResult.chatMessage(); } else { + // TODO duplicated context propagation. // this a special case where we can't block, so we need to delegate the retrieval augmentation to a worker pool CompletableFuture augmentationResultCF = CompletableFuture.supplyAsync(new Supplier<>() { @Override @@ -784,23 +785,32 @@ public TokenStreamMulti(List messagesToSend, List subscriber) { UnicastProcessor processor = UnicastProcessor.create(); processor.subscribe(subscriber); + createTokenStream(processor); } private void createTokenStream(UnicastProcessor processor) { - var stream = new AiServiceTokenStream(messagesToSend, toolSpecifications, - toolsExecutors, contents, context, memoryId); + Context ctxt = null; + if (mustSwitchToWorkerThread) { + // we create or retrieve the current context, to use `executeBlocking` when required. 
+ ctxt = VertxContext.getOrCreateDuplicatedContext(); + } + + var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, + toolsExecutors, contents, context, memoryId, ctxt); TokenStream tokenStream = stream .onNext(processor::onNext) - .onComplete(new Consumer<>() { - @Override - public void accept(Response message) { - processor.onComplete(); - } - }) + .onComplete(message -> processor.onComplete()) .onError(processor::onError); + // This is equivalent to "run subscription on worker thread" if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { - Infrastructure.getDefaultWorkerPool().execute(tokenStream::start); + ctxt.executeBlocking(new Callable() { + @Override + public Void call() { + tokenStream.start(); + return null; + } + }); } else { tokenStream.start(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java new file mode 100644 index 000000000..8b103d09a --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java @@ -0,0 +1,168 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.function.Consumer; + +import org.jboss.logging.Logger; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import 
dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.AiServiceContext; +import dev.langchain4j.service.tool.ToolExecutor; +import io.vertx.core.Context; + +/** + * A {@link StreamingResponseHandler} implementation for Quarkus. + * The main difference with the upstream implementation is the thread switch when receiving the `completion` event + * when there are tool execution requests. + */ +public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler { + + private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class); + + private final AiServiceContext context; + private final Object memoryId; + + private final Consumer tokenHandler; + private final Consumer> completionHandler; + private final Consumer errorHandler; + + private final List temporaryMemory; + private final TokenUsage tokenUsage; + + private final List toolSpecifications; + private final Map toolExecutors; + private final Context executionContext; + + QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, + Object memoryId, + Consumer tokenHandler, + Consumer> completionHandler, + Consumer errorHandler, + List temporaryMemory, + TokenUsage tokenUsage, + List toolSpecifications, + Map toolExecutors, Context cxtx) { + this.context = ensureNotNull(context, "context"); + this.memoryId = ensureNotNull(memoryId, "memoryId"); + + this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler"); + this.completionHandler = completionHandler; + this.errorHandler = errorHandler; + + this.temporaryMemory = new ArrayList<>(temporaryMemory); + this.tokenUsage = ensureNotNull(tokenUsage, "tokenUsage"); + + this.toolSpecifications = copyIfNotNull(toolSpecifications); + this.toolExecutors = copyIfNotNull(toolExecutors); + + this.executionContext = cxtx; + } + + @Override + public void onNext(String token) { + tokenHandler.accept(token); + } + + private void executeTools(Runnable 
runnable) { + if (executionContext != null && Context.isOnEventLoopThread()) { + executionContext.executeBlocking(new Callable() { + @Override + public Object call() throws Exception { + runnable.run(); + return null; + } + }); + } else { + runnable.run(); + } + } + + @Override + public void onComplete(Response response) { + + AiMessage aiMessage = response.content(); + addToMemory(aiMessage); + + if (aiMessage.hasToolExecutionRequests()) { + // Tools execution may block the caller thread. When the caller thread is the event loop thread, and + // when tools have been detected to be potentially blocking, we need to switch to a worker thread. + executeTools(new Runnable() { + @Override + public void run() { + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + String toolName = toolExecutionRequest.name(); + ToolExecutor toolExecutor = toolExecutors.get(toolName); + String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( + toolExecutionRequest, + toolExecutionResult); + QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage); + } + + context.streamingChatModel.generate( + QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId), + toolSpecifications, + new QuarkusAiServiceStreamingResponseHandler( + context, + memoryId, + tokenHandler, + completionHandler, + errorHandler, + temporaryMemory, + TokenUsage.sum(tokenUsage, response.tokenUsage()), + toolSpecifications, + toolExecutors, + executionContext)); + } + }); + } else { + if (completionHandler != null) { + completionHandler.accept(Response.from( + aiMessage, + TokenUsage.sum(tokenUsage, response.tokenUsage()), + response.finishReason())); + } + } + } + + private void addToMemory(ChatMessage chatMessage) { + if (context.hasChatMemory()) { + context.chatMemory(memoryId).add(chatMessage); + } else { + 
temporaryMemory.add(chatMessage); + } + } + + private List messagesToSend(Object memoryId) { + return context.hasChatMemory() + ? context.chatMemory(memoryId).messages() + : temporaryMemory; + } + + @Override + public void onError(Throwable error) { + if (errorHandler != null) { + try { + errorHandler.accept(error); + } catch (Exception e) { + log.error("While handling the following error...", error); + log.error("...the following error happened", e); + } + } else { + log.warn("Ignored error", error); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java new file mode 100644 index 000000000..820e5ae65 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java @@ -0,0 +1,155 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.Collections.emptyList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.exception.IllegalConfigurationException; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.service.AiServiceContext; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.tool.ToolExecutor; +import io.vertx.core.Context; + +/** + * An implementation of token 
stream for Quarkus.
+ * The only difference with the upstream implementation is the usage of the custom
+ * {@link QuarkusAiServiceStreamingResponseHandler} instead of the upstream one.
+ * It allows handling blocking tools execution, when we are invoked on the event loop.
+ */
+public class QuarkusAiServiceTokenStream implements TokenStream {
+
+    private final List<ChatMessage> messages;
+    private final List<ToolSpecification> toolSpecifications;
+    private final Map<String, ToolExecutor> toolExecutors;
+    private final List<Content> retrievedContents;
+    private final AiServiceContext context;
+    private final Object memoryId;
+    // If set, it means we need to switch to a worker thread to execute tools.
+    private final Context executionContext;
+
+    private Consumer<String> tokenHandler;
+    private Consumer<List<Content>> contentsHandler;
+    private Consumer<Throwable> errorHandler;
+    private Consumer<Response<AiMessage>> completionHandler;
+
+    // Number of times each configuration method was called; checked when the stream is started.
+    private int onNextInvoked;
+    private int onCompleteInvoked;
+    private int onRetrievedInvoked;
+    private int onErrorInvoked;
+    private int ignoreErrorsInvoked;
+
+    /**
+     * Creates a token stream; nothing is sent to the model until {@link #start()} is called.
+     *
+     * @param messages the messages to send, must not be empty
+     * @param toolSpecifications the tool specifications, may be {@code null} or empty
+     * @param toolExecutors the tool executors, may be {@code null}
+     * @param retrievedContents the retrieved contents passed to the {@code onRetrieved} handler, may be {@code null}
+     * @param context the AI service context, must not be {@code null} and must have a streaming chat model
+     * @param memoryId the memory id, must not be {@code null}
+     * @param ctxt the context handed to the response handler for tool execution, may be {@code null}
+     */
+    public QuarkusAiServiceTokenStream(List<ChatMessage> messages,
+            List<ToolSpecification> toolSpecifications,
+            Map<String, ToolExecutor> toolExecutors,
+            List<Content> retrievedContents,
+            AiServiceContext context,
+            Object memoryId, Context ctxt) {
+        this.messages = ensureNotEmpty(messages, "messages");
+        this.toolSpecifications = copyIfNotNull(toolSpecifications);
+        this.toolExecutors = copyIfNotNull(toolExecutors);
+        this.retrievedContents = retrievedContents;
+        this.context = ensureNotNull(context, "context");
+        this.memoryId = ensureNotNull(memoryId, "memoryId");
+        ensureNotNull(context.streamingChatModel, "streamingChatModel");
+        this.executionContext = ctxt;
+    }
+
+    @Override
+    public TokenStream onNext(Consumer<String> tokenHandler) {
+        this.tokenHandler = tokenHandler;
+        this.onNextInvoked++;
+        return this;
+    }
+
+    @Override
+    public TokenStream onRetrieved(Consumer<List<Content>> contentsHandler) {
+        this.contentsHandler = contentsHandler;
+        this.onRetrievedInvoked++;
+        return this;
+    }
+
+    @Override
+    public TokenStream onComplete(Consumer<Response<AiMessage>> completionHandler) {
+        this.completionHandler = completionHandler;
+        this.onCompleteInvoked++;
+        return this;
+    }
+
+    @Override
+    public TokenStream onError(Consumer<Throwable> errorHandler) {
+        this.errorHandler = errorHandler;
+        this.onErrorInvoked++;
+        return this;
+    }
+
+    @Override
+    public TokenStream ignoreErrors() {
+        this.errorHandler = null;
+        this.ignoreErrorsInvoked++;
+        return this;
+    }
+
+    /**
+     * Validates the configuration, passes the retrieved contents to the {@code onRetrieved} handler when both are
+     * set, then sends the messages to the streaming model (with the tool specifications when there are some).
+     */
+    @Override
+    public void start() {
+        validateConfiguration();
+        QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler(
+                context,
+                memoryId,
+                tokenHandler,
+                completionHandler,
+                errorHandler,
+                initTemporaryMemory(context, messages),
+                new TokenUsage(),
+                toolSpecifications,
+                toolExecutors,
+                executionContext);
+
+        if (contentsHandler != null && retrievedContents != null) {
+            contentsHandler.accept(retrievedContents);
+        }
+
+        if (isNullOrEmpty(toolSpecifications)) {
+            context.streamingChatModel.generate(messages, handler);
+        } else {
+            context.streamingChatModel.generate(messages, toolSpecifications, handler);
+        }
+    }
+
+    // Rejects a stream on which the configuration methods were not called the expected number of times.
+    private void validateConfiguration() {
+        if (onNextInvoked != 1) {
+            throw new IllegalConfigurationException("onNext must be invoked exactly 1 time");
+        }
+
+        if (onCompleteInvoked > 1) {
+            throw new IllegalConfigurationException("onComplete must be invoked at most 1 time");
+        }
+
+        if (onRetrievedInvoked > 1) {
+            throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time");
+        }
+
+        if (onErrorInvoked + ignoreErrorsInvoked != 1) {
+            throw new IllegalConfigurationException("One of onError or ignoreErrors must be invoked exactly 1 time");
+
} + } + + private List initTemporaryMemory(AiServiceContext context, List messagesToSend) { + if (context.hasChatMemory()) { + return emptyList(); + } else { + return new ArrayList<>(messagesToSend); + } + } +} diff --git a/docs/modules/ROOT/pages/agent-and-tools.adoc b/docs/modules/ROOT/pages/agent-and-tools.adoc index 489d8b4b6..019c596ec 100644 --- a/docs/modules/ROOT/pages/agent-and-tools.adoc +++ b/docs/modules/ROOT/pages/agent-and-tools.adoc @@ -90,6 +90,87 @@ If a method of an AI Service needs to use tools other than the ones configured o In this case `@Toolbox` completely overrides `@RegisterAiService(tools=...)` ==== +== Tools execution model + +Tools can have different execution models: + +- _blocking_ - the tool execution blocks the caller thread +- _non-blocking_ - the tool execution is asynchronous and does not block the caller thread +- _run on a virtual thread_ - the tool execution is asynchronous and runs on a new virtual thread + +The execution model is configured using the `@Blocking`, `@NonBlocking` and `@RunOnVirtualThread` annotations. +In addition, tool methods returning `CompletionStage` or `Uni` are considered non-blocking (if no annotations are used). + +For example, the following tool is considered blocking: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +public String getCustomerName(long id) { + return find("id", id).firstResult().name; +} +---- + +The following tool is considered non-blocking: + +[source,java] +---- +@Tool("add a and b") +@NonBlocking +int sum(int a, int b) { + return a + b; +} +---- + +This other tool is considered non-blocking as well: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +public Uni<String> getCustomerName(long id) { + // ... 
+} +---- + +The following tool runs on a virtual thread: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +@RunOnVirtualThread +public String getCustomerName(long id) { + return find("id", id).firstResult().name; +} +---- + +=== Streaming + +The execution model is particularly important when using streamed responses. +Indeed, streamed responses are executed on the _event loop_, which cannot be blocked. +Thus, in this case, based on the execution model of the tool, Quarkus LangChain4j will automatically switch to a worker thread to execute the tool. +This mechanism avoids blocking the _event loop_ while still allowing the use of blocking tools (like database access). + +For example, let's imagine the following AI service method: + +[source,java] +---- +@UserMessage("...") +@ToolBox({TransactionRepository.class, CustomerRepository.class}) +Multi<String> detectAmountFraudForCustomerStreamed(long customerId); +---- + +This method returns a stream of tokens. +Each token is emitted on the event loop. +The given tools are blocking (they access a database in a blocking manner). +Thus, Quarkus LangChain4j will automatically switch to a worker thread before invoking the tools, to avoid blocking the event loop. + +=== Request scope + +If the request scope is active when the LLM is invoked, the tool execution is also bound to the request scope. +It means that you can propagate data to the tool execution, for example, a security context or a transaction context. + +The propagation also works with virtual threads or when Quarkus LangChain4j automatically switches to a worker thread. + == Dynamically Use Tools at Runtime By default, tools in Quarkus are static, defined during the build phase. 
diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 4362e44c0..75d6836de 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -26,6 +26,7 @@ rag-pgvector rag-pgvector-flyway easy-rag + tools diff --git a/integration-tests/tools/pom.xml b/integration-tests/tools/pom.xml new file mode 100644 index 000000000..e81e63022 --- /dev/null +++ b/integration-tests/tools/pom.xml @@ -0,0 +1,124 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-test-tools + Quarkus LangChain4j - Integration Tests - Tools + + true + + + + + io.quarkus + quarkus-rest-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + io.quarkus + quarkus-junit5-mockito + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/tools/src/main/java/org/acme/tools/AiService.java b/integration-tests/tools/src/main/java/org/acme/tools/AiService.java new file mode 100644 index 000000000..80c531308 --- /dev/null +++ b/integration-tests/tools/src/main/java/org/acme/tools/AiService.java @@ -0,0 +1,11 @@ +package org.acme.tools; + +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import 
io.quarkiverse.langchain4j.ToolBox;
+
+/**
+ * AI service of the tools integration test.
+ */
+@RegisterAiService
+public interface AiService {
+    /**
+     * Sends the user message to the model, with the {@link Calculator} tools declared for this method.
+     *
+     * @param message the user message
+     * @return the answer as text
+     */
+    @ToolBox(Calculator.class)
+    String chat(@UserMessage String message);
+}
diff --git a/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java b/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java
new file mode 100644
index 000000000..4fd82b4e9
--- /dev/null
+++ b/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java
@@ -0,0 +1,30 @@
+package org.acme.tools;
+
+import jakarta.enterprise.context.ApplicationScoped;
+
+import dev.langchain4j.agent.tool.Tool;
+import io.smallrye.common.annotation.Blocking;
+import io.smallrye.common.annotation.NonBlocking;
+import io.smallrye.common.annotation.RunOnVirtualThread;
+
+/**
+ * Tool bean exposing the same addition three times, once per execution-model annotation.
+ */
+@ApplicationScoped
+public class Calculator {
+
+    /** Addition declared with {@code @Blocking}. */
+    @Tool
+    @Blocking
+    public int blockingSum(int a, int b) {
+        return Integer.sum(a, b);
+    }
+
+    /** Addition declared with {@code @NonBlocking}. */
+    @Tool
+    @NonBlocking
+    public int nonBlockingSum(int a, int b) {
+        return Integer.sum(a, b);
+    }
+
+    /** Addition declared with {@code @RunOnVirtualThread}. */
+    @Tool
+    @RunOnVirtualThread
+    public int virtualSum(int a, int b) {
+        return Integer.sum(a, b);
+    }
+}
diff --git a/integration-tests/tools/src/main/resources/application.properties b/integration-tests/tools/src/main/resources/application.properties
new file mode 100644
index 000000000..28bf72df9
--- /dev/null
+++ b/integration-tests/tools/src/main/resources/application.properties
@@ -0,0 +1,4 @@
+quarkus.langchain4j.watsonx.base-url=https://toolstest.com
+quarkus.langchain4j.watsonx.api-key=api-key
+quarkus.langchain4j.watsonx.project-id=project-id
+quarkus.langchain4j.watsonx.version=yyyy-mm-dd
\ No newline at end of file
diff --git a/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java b/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java
new file mode 100644
index 000000000..ed7bac0d8
--- /dev/null
+++ b/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java
@@ -0,0 +1,80 @@
+package org.acme.tools;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.ArgumentMatchers.anyList;
+
+import jakarta.inject.Inject;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledForJreRange;
+import org.junit.jupiter.api.condition.JRE;
+import org.mockito.Mockito;
+
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import io.quarkus.test.InjectMock;
+import io.quarkus.test.junit.QuarkusTest;
+
+/**
+ * Checks, against a mocked chat model, that the AI service answers after the model requested one of the
+ * {@link Calculator} tools, for each of the three tool methods.
+ */
+@QuarkusTest
+@EnabledForJreRange(min = JRE.JAVA_21)
+@SuppressWarnings("unchecked")
+public class ToolsTest {
+
+    @Inject
+    AiService aiService;
+
+    @InjectMock
+    ChatLanguageModel model;
+
+    @Test
+    void blockingSum() {
+        assertAnswerAfterToolCall("blockingSum");
+    }
+
+    @Test
+    void nonBlockingSum() {
+        assertAnswerAfterToolCall("nonBlockingSum");
+    }
+
+    @Test
+    void virtualThreadSum() {
+        assertAnswerAfterToolCall("virtualSum");
+    }
+
+    /**
+     * Stubs the model to first request the given tool and then answer, and asserts that the service returns
+     * that answer.
+     */
+    private void assertAnswerAfterToolCall(String toolMethodName) {
+        var toolCall = Response.from(AiMessage.from(toolRequest(toolMethodName)), new TokenUsage(1));
+        var finalAnswer = Response.from(AiMessage.from("The result is 2"), new TokenUsage(1));
+        Mockito.when(model.generate(anyList(), anyList())).thenReturn(toolCall, finalAnswer);
+
+        String answer = aiService.chat("Execute 1 + 1");
+
+        assertEquals("The result is 2", answer);
+    }
+
+    /** Builds a request for the named tool with the arguments {@code a = 1} and {@code b = 1}. */
+    private ToolExecutionRequest toolRequest(String methodName) {
+        var builder = ToolExecutionRequest.builder();
+        builder.id("1");
+        builder.name(methodName);
+        builder.arguments("""
+                {
+                    "a": 1,
+                    "b": 1
+                }
+                """);
+        return builder.build();
+    }
+}