diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml index 3153acca2..ab47f8411 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -102,6 +102,11 @@ + + io.quarkus + quarkus-test-vertx + test + diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 25e497c84..30ee9a92b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -8,6 +8,7 @@ import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS; +import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY; import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW; @@ -69,6 +70,7 @@ import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem; import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; +import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; @@ -123,6 +125,7 @@ import io.quarkus.gizmo.ResultHandle; import io.quarkus.runtime.metrics.MetricsFactory; import 
io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class AiServicesProcessor { @@ -760,6 +763,57 @@ public void markUsedOutputGuardRailsUnremovable(List } } + /** + * Because the tools execution uses an imperative API (`String execute(...)`) and uses the caller thread, we need + * to anticipate the need to dispatch the invocation on a worker thread. + * This is the case for AI service methods that returns `Uni`, `CompletionStage` and `Multi` (stream) and that uses + * tools returning `Uni`, `CompletionStage` or that are blocking. + * Basically, for "reactive AI service method, the switch is necessary except if all the tools are imperative (return `T`) + * and marked explicitly as blocking (using `@Blocking`). + * + * @param method the AI method + * @param tools the tools + */ + public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread( + MethodInfo method, + List associatedTools, + List tools) { + boolean reactive = method.returnType().name().equals(DotNames.UNI) + || method.returnType().name().equals(DotNames.COMPLETION_STAGE) + || method.returnType().name().equals(DotNames.MULTI); + + boolean requireSwitchToWorkerThread = false; + + if (associatedTools.isEmpty()) { + // No tools, no need to dispatch + return false; + } + + if (!reactive) { + // We are already on a thread we can block. 
+ return false; + } + + // We need to find if any of the tools that could be used by the method is requiring a blocking execution + for (String classname : associatedTools) { + // Look for the tool in the list of tools + boolean found = false; + for (ToolMethodBuildItem tool : tools) { + if (tool.getDeclaringClassName().equals(classname)) { + found = true; + if (tool.requiresSwitchToWorkerThread()) { + requireSwitchToWorkerThread = true; + break; + } + } + } + if (!found) { + throw new RuntimeException("No tools detected in " + classname); + } + } + return requireSwitchToWorkerThread; + } + @BuildStep public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished, List methods, @@ -857,7 +911,8 @@ public void handleAiServices( BuildProducer additionalBeanProducer, BuildProducer unremovableBeanProducer, Optional metricsCapability, - Capabilities capabilities) { + Capabilities capabilities, + List tools) { IndexView index = indexBuildItem.getIndex(); @@ -1026,7 +1081,8 @@ public void handleAiServices( addOpenTelemetrySpan, config.responseSchema(), allowedPredicates, - ignoredPredicates); + ignoredPredicates, + tools); if (!methodCreateInfo.getToolClassNames().isEmpty()) { unremovableBeanProducer.produce(UnremovableBeanBuildItem .beanClassNames(methodCreateInfo.getToolClassNames().toArray(EMPTY_STRING_ARRAY))); @@ -1068,7 +1124,9 @@ public void handleAiServices( aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo, methodCreateInfo.getInputGuardrailsClassNames(), - methodCreateInfo.getOutputGuardrailsClassNames())); + methodCreateInfo.getOutputGuardrailsClassNames(), + gatherMethodToolClassNames(methodInfo), + methodCreateInfo)); } } @@ -1155,7 +1213,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( MethodInfo method, IndexView index, boolean addMicrometerMetrics, boolean addOpenTelemetrySpans, boolean generateResponseSchema, Collection> allowedPredicates, - Collection> ignoredPredicates) { + Collection> ignoredPredicates, 
+ List tools) { validateReturnType(method); boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE); @@ -1200,11 +1259,28 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method); + // Detect if tools execution may block the caller thread. + boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames); + return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration, returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)), - metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails, - outputGuardrails, accumulatorClassName); + metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread, + inputGuardrails, outputGuardrails, accumulatorClassName); + } + + private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List tools, + List methodToolClassNames) { + List allTools = new ArrayList<>(methodToolClassNames); + // We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing. 
+ AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES); + if (annotation != null) { + AnnotationValue value = annotation.value("tools"); + if (value != null) { + allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList()); + } + } + return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools); } private void validateReturnType(MethodInfo method) { @@ -1688,11 +1764,17 @@ public static final class AiServicesMethodBuildItem extends MultiBuildItem { private final MethodInfo methodInfo; private final List outputGuardrails; private final List inputGuardrails; + private final List tools; + private final AiServiceMethodCreateInfo methodCreateInfo; - public AiServicesMethodBuildItem(MethodInfo methodInfo, List inputGuardrails, List outputGuardrails) { + public AiServicesMethodBuildItem(MethodInfo methodInfo, List inputGuardrails, List outputGuardrails, + List tools, + AiServiceMethodCreateInfo methodCreateInfo) { this.methodInfo = methodInfo; this.inputGuardrails = inputGuardrails; this.outputGuardrails = outputGuardrails; + this.tools = tools; + this.methodCreateInfo = methodCreateInfo; } public List getOutputGuardrails() { @@ -1707,6 +1789,10 @@ public MethodInfo getMethodInfo() { return methodInfo; } + public AiServiceMethodCreateInfo getMethodCreateInfo() { + return methodCreateInfo; + } + public static List gatherGuardrails(MethodInfo methodInfo, DotName annotation) { List guardrails = new ArrayList<>(); AnnotationInstance instance = methodInfo.annotation(annotation); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java index 395cb65dd..25d8571a8 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java @@ -4,13 +4,19 @@ 
import java.math.BigInteger; import java.util.List; import java.util.Set; +import java.util.concurrent.CompletionStage; import jakarta.enterprise.inject.Instance; import org.jboss.jandex.DotName; +import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.model.chat.listener.ChatModelListener; +import io.smallrye.common.annotation.Blocking; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; public class DotNames { @@ -38,10 +44,16 @@ public class DotNames { public static final DotName LIST = DotName.createSimple(List.class); public static final DotName SET = DotName.createSimple(Set.class); public static final DotName MULTI = DotName.createSimple(Multi.class); + public static final DotName UNI = DotName.createSimple(Uni.class); + public static final DotName BLOCKING = DotName.createSimple(Blocking.class); + public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class); + public static final DotName COMPLETION_STAGE = DotName.createSimple(CompletionStage.class); + public static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class); public static final DotName OBJECT = DotName.createSimple(Object.class.getName()); public static final DotName RECORD = DotName.createSimple(Record.class); public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class); public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class); + public static final DotName TOOL = DotName.createSimple(Tool.class); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 7201e1835..6493730f9 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ 
b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -8,6 +8,12 @@ import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; import static dev.langchain4j.agent.tool.JsonSchemaProperty.description; import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums; +import static io.quarkiverse.langchain4j.deployment.DotNames.BLOCKING; +import static io.quarkiverse.langchain4j.deployment.DotNames.COMPLETION_STAGE; +import static io.quarkiverse.langchain4j.deployment.DotNames.MULTI; +import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING; +import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD; +import static io.quarkiverse.langchain4j.deployment.DotNames.UNI; import static java.util.Arrays.stream; import static java.util.stream.Collectors.toList; @@ -15,10 +21,13 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiFunction; +import java.util.function.Predicate; import java.util.stream.Collectors; import org.jboss.jandex.AnnotationInstance; @@ -41,6 +50,7 @@ import dev.langchain4j.agent.tool.ToolMemoryId; import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker; @@ -61,6 +71,7 @@ import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.GeneratedClassBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; +import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem; import 
io.quarkus.deployment.recording.RecorderContext; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; @@ -74,12 +85,14 @@ public class ToolProcessor { private static final DotName TOOL = DotName.createSimple(Tool.class); private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class); + private static final DotName P = DotName.createSimple(dev.langchain4j.agent.tool.P.class); private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor .ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class); private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class); public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class, Object.class); + private static final Logger log = Logger.getLogger(ToolProcessor.class); @BuildStep @@ -91,7 +104,9 @@ public void telemetry(Capabilities capabilities, BuildProducer toolMethodBuildItemProducer, + CombinedIndexBuildItem indexBuildItem, BuildProducer additionalBeanProducer, BuildProducer transformerProducer, BuildProducer generatedClassProducer, @@ -107,6 +122,8 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem, List generatedInvokerClasses = new ArrayList<>(); List generatedArgumentMapperClasses = new ArrayList<>(); + Set toolsNames = new HashSet<>(); + if (!instances.isEmpty()) { ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true); @@ -223,7 +240,17 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem, ToolSpecification toolSpecification = builder.build(); ToolMethodCreateInfo methodCreateInfo = new ToolMethodCreateInfo( toolMethod.name(), invokerClassName, - toolSpecification, argumentMapperClassName); + toolSpecification, argumentMapperClassName, determineExecutionModel(toolMethod)); + + validateExecutionModel(methodCreateInfo, toolMethod, validation); + + if (toolsNames.add(toolName)) { 
+ toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo)); + } else { + validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new IllegalStateException("A tool with the name '" + toolName + + "' is already declared. Tools method name must be unique."))); + } metadata.computeIfAbsent(className.toString(), (c) -> new ArrayList<>()).add(methodCreateInfo); @@ -248,13 +275,49 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem, toolsMetadataProducer.produce(new ToolsMetadataBeforeRemovalBuildItem(metadata)); } + @BuildStep + ExecutionModelAnnotationsAllowedBuildItem toolsMethods() { + return new ExecutionModelAnnotationsAllowedBuildItem(new Predicate() { + @Override + public boolean test(MethodInfo method) { + return method.hasDeclaredAnnotation(DotNames.TOOL); + } + }); + } + + private void validateExecutionModel(ToolMethodCreateInfo methodCreateInfo, MethodInfo toolMethod, + BuildProducer validation) { + String methodName = toolMethod.declaringClass().name() + "." 
+ toolMethod.name(); + + if (MULTI.equals(toolMethod.returnType().name())) { + validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new Exception("Method " + methodName + " returns Multi, which is not supported for tools"))); + } + + if (methodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.VIRTUAL_THREAD) { + // We can't use Uni or CS with virtual thread + if (UNI.equals(toolMethod.returnType().name())) { + validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new Exception("Method " + methodName + + " returns Uni, which is not supported with @RunOnVirtualThread for tools"))); + } + if (COMPLETION_STAGE.equals(toolMethod.returnType().name())) { + validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new Exception("Method " + methodName + + " returns CompletionStage, which is not supported with @RunOnVirtualThread for tools"))); + } + } + + } + /** * Transforms ToolsMetadataBeforeRemovalBuildItem into ToolsMetadataBuildItem by filtering * out tools belonging to beans that have been removed by ArC. 
*/ @BuildStep @Record(ExecutionTime.STATIC_INIT) - public ToolsMetadataBuildItem filterOutRemovedTools(ToolsMetadataBeforeRemovalBuildItem beforeRemoval, + public ToolsMetadataBuildItem filterOutRemovedTools( + ToolsMetadataBeforeRemovalBuildItem beforeRemoval, ValidationPhaseBuildItem validationPhase, RecorderContext recorderContext, ToolsRecorder recorder) { @@ -552,4 +615,20 @@ public ClassVisitor apply(String className, ClassVisitor classVisitor) { return transformer.applyTo(classVisitor); } } + + private ToolMethodCreateInfo.ExecutionModel determineExecutionModel(MethodInfo methodInfo) { + if (methodInfo.hasAnnotation(BLOCKING)) { + return ToolMethodCreateInfo.ExecutionModel.BLOCKING; + } + Type returnedType = methodInfo.returnType(); + if (methodInfo.hasAnnotation(NON_BLOCKING) + || UNI.equals(returnedType.name()) || COMPLETION_STAGE.equals(returnedType.name()) + || MULTI.equals(returnedType.name())) { + return ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING; + } + if (methodInfo.hasAnnotation(RUN_ON_VIRTUAL_THREAD)) { + return ToolMethodCreateInfo.ExecutionModel.VIRTUAL_THREAD; + } + return ToolMethodCreateInfo.ExecutionModel.BLOCKING; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java new file mode 100644 index 000000000..9c1f25b34 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java @@ -0,0 +1,53 @@ +package io.quarkiverse.langchain4j.deployment.items; + +import org.jboss.jandex.MethodInfo; + +import io.quarkiverse.langchain4j.deployment.DotNames; +import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; +import io.quarkus.builder.item.MultiBuildItem; + +/** + * A build item that represents a method that is annotated with {@link dev.langchain4j.agent.tool.Tool}. 
+ * It contains the method info and the tool method create info. + */ +public final class ToolMethodBuildItem extends MultiBuildItem { + + private final MethodInfo toolsMethodInfo; + + private final ToolMethodCreateInfo toolMethodCreateInfo; + + public ToolMethodBuildItem(MethodInfo toolsMethodInfo, ToolMethodCreateInfo toolMethodCreateInfo) { + this.toolsMethodInfo = toolsMethodInfo; + this.toolMethodCreateInfo = toolMethodCreateInfo; + } + + public MethodInfo getToolsMethodInfo() { + return toolsMethodInfo; + } + + public String getDeclaringClassName() { + return toolsMethodInfo.declaringClass().name().toString(); + } + + public ToolMethodCreateInfo getToolMethodCreateInfo() { + return toolMethodCreateInfo; + } + + /** + * Returns true if the method requires a switch to a worker thread, even if the method is non-blocking. + * This is because of the tools executor limitation (imperative API). + * + * @return true if the method requires a switch to a worker thread + */ + public boolean requiresSwitchToWorkerThread() { + return !(toolMethodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING + && isImperativeMethod()); + } + + private boolean isImperativeMethod() { + var type = toolsMethodInfo.returnType(); + return !DotNames.UNI.equals(type.name()) + && !DotNames.MULTI.equals(type.name()) + && !DotNames.COMPLETION_STAGE.equals(type.name()); + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java new file mode 100644 index 000000000..850b1ad57 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java @@ -0,0 +1,13 @@ +package io.quarkiverse.langchain4j.test; + +import java.util.List; + +public class Lists { + + public static T last(List list) { + if (list.isEmpty()) { + return null; + } + return list.get(list.size() - 1); + } +} diff --git 
a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java index f68de2ff5..feb6d3e3e 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java @@ -291,7 +291,7 @@ private ToolExecutor getToolExecutor(String methodName) { toolSpecification.name())) { // this only works because TestTool does not contain overloaded methods toolExecutor = new QuarkusToolExecutor( new QuarkusToolExecutor.Context(testTool, invokerClassName, methodCreateInfo.methodName(), - methodCreateInfo.argumentMapperClassName())); + methodCreateInfo.argumentMapperClassName(), methodCreateInfo.executionModel())); break; } } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java new file mode 100644 index 000000000..ad9044c42 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java @@ -0,0 +1,340 @@ +package io.quarkiverse.langchain4j.test.tools; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.extension.RegisterExtension; +import 
org.testcontainers.shaded.org.awaitility.Awaitility; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkiverse.langchain4j.test.Lists; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class ToolExecutionModelTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class)); + + @Inject + MyAiService aiService; + + @Inject + Vertx vertx; + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.hello("abc", "hi - " + uuid); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + 
AtomicReference failure = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.hello("abc", "hi - " + uuid); + } catch (IllegalStateException e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> failure.get() != null); + assertThat(failure.get()).hasMessageContaining("Cannot execute blocking tools on event loop thread"); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hi - " + uuid); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.hello("abc", "hiNonBlocking - " + uuid); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + ; + vertx.getOrCreateContext().runOnContext(x -> { + try { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + result.set(aiService.hello("abc", "hiNonBlocking - " + uuid)); + } finally { + 
Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(uuid, caller.get()); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hiNonBlocking - " + uuid); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.hello("abc", "hiUni - " + uuid); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference failure = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.hello("abc", "hiUni - " + uuid); + } catch (Exception e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> failure.get() != null); + assertThat(failure.get()).hasMessageContaining("Cannot execute tools returning Uni on event loop thread"); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String 
uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hiUni - " + uuid); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @ActivateRequestContext + void testToolInvocationOnVirtualThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.hello("abc", "hiVirtualThread - " + uuid); + assertThat(r).contains(uuid, "quarkus-virtual-thread-"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hiVirtualThread - " + uuid); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // At the moment, we create a virtual thread every time. 
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .doesNotContain(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testToolInvocationOnVirtualThreadFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference failure = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.hello("abc", "hiVirtualThread - " + uuid); + } catch (IllegalStateException e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> failure.get() != null); + assertThat(failure.get()).hasMessageContaining("Cannot execute virtual thread tools on event loop thread"); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @ToolBox(MyTool.class) + String hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + } + + @Singleton + public static class MyTool { + @Tool + public String hi(String m) { + return m + " " + Thread.currentThread(); + } + + @Tool + @NonBlocking + public String hiNonBlocking(String m) { + return m + " " + Thread.currentThread(); + } + + @Tool + public Uni hiUni(String m) { + return Uni.createFrom().item(() -> m + " " + Thread.currentThread()); + } + + @Tool + @RunOnVirtualThread + public String hiVirtualThread(String m) { + return m + " " + Thread.currentThread(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatLanguageModel { + + @Override + public Response generate(List messages) { + throw new UnsupportedOperationException("Should not be called"); + } + + @Override + public Response generate(List messages, List toolSpecifications) { + if 
(messages.size() == 1) { + // Only the user message, extract the tool id from it + String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText(); + var segments = text.split(" - "); + var toolId = segments[0]; + var content = segments[1]; + // Only the user message + return new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder() + .id("my-tool-" + toolId) + .name(toolId) + .arguments("{\"m\":\"" + content + "\"}") + .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION); + } else if (messages.size() == 3) { + // user -> tool request -> tool response + ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages); + return new Response<>(AiMessage.from("response: " + last.text())); + + } + return new Response<>(new AiMessage("Unexpected")); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java new file mode 100644 index 000000000..116efcff4 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java @@ -0,0 +1,427 @@ +package io.quarkiverse.langchain4j.test.tools; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import 
jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.shaded.org.awaitility.Awaitility; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.test.Lists; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.vertx.RunOnVertxContext; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; +import io.smallrye.common.vertx.VertxContext; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class ToolExecutionModelWithStreamingAndRequestScopePropagationTest { + + @RegisterExtension + static 
final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class)); + + @Inject + MyAiService aiService; + + @Inject + Vertx vertx; + + @Inject + UUIDGenerator uuidGenerator; + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + void testBlockingToolInvocationFromEventLoop() { + AtomicReference failure = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference value = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + String uuid = uuidGenerator.get(); + value.set(uuid); + aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + } catch (IllegalStateException e) { + failure.set(e); + Arc.container().requestContext().deactivate(); + } + }); + + // We would automatically detect this case and switch to a worker thread at subscription time. 
+ + Awaitility.await().until(() -> failure.get() != null || result.get() != null); + assertThat(failure.get()).isNull(); + assertThat(result.get()).doesNotContain("event", "loop") + .contains(value.get(), "executor-thread"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + AtomicReference value = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hi") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(value.get(), "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)) + 
.subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloNonBlockingWithSwitch("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "executor-thread") + .doesNotContain(caller.get()); + } + + @Test + @RunOnVertxContext + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloNonBlocking("abc", "hiNonBlocking") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromWorkerThread() { + String uuid = uuidGenerator.get(); + var r = 
aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "executor-thread"); + } + + @Test + @RunOnVertxContext + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloUni("abc", "hiUni") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // The blocking tool is executed on the same thread (synchronous emission) + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @RunOnVertxContext(runOnEventLoop = false) + @ActivateRequestContext + void testToolInvocationOnVirtualThread() { + String uuid = uuidGenerator.get(); + var r = aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + 
assertThat(r).contains(uuid, "quarkus-virtual-thread-"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @RunOnVertxContext + @ActivateRequestContext + void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = uuidGenerator.get(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + caller.set(Thread.currentThread().getName()); + return aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + }).get(); + + // At the moment, we create a virtual thread every time. + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .doesNotContain(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @ActivateRequestContext + void testToolInvocationOnVirtualThreadFromEventLoop() { + AtomicReference value = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx); + ctxt.runOnContext(x -> { + Arc.container().requestContext().activate(); + value.set(uuidGenerator.get()); + aiService.helloVirtualTools("abc", "hiVirtualThread") + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set) + .whenComplete((r, t) -> Arc.container().requestContext().deactivate()); + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(value.get(), "quarkus-virtual-thread-"); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @ToolBox(BlockingTool.class) + Multi hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(NonBlockingTool.class) + Multi helloNonBlocking(@MemoryId String memoryId, @UserMessage String 
userMessageContainingTheToolId); + + @ToolBox({ NonBlockingTool.class, BlockingTool.class }) + Multi helloNonBlockingWithSwitch(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(UniTool.class) + Multi helloUni(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(VirtualTool.class) + Multi helloVirtualTools(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + } + + @Singleton + public static class BlockingTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + public String hi() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + @Singleton + public static class NonBlockingTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + @NonBlocking + public String hiNonBlocking() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + @Singleton + public static class UniTool { + @Inject + UUIDGenerator uuidGenerator; + + @Tool + public Uni hiUni() { + return Uni.createFrom().item(() -> uuidGenerator.get() + " " + Thread.currentThread()); + } + } + + @Singleton + public static class VirtualTool { + + @Inject + UUIDGenerator uuidGenerator; + + @Tool + @RunOnVirtualThread + public String hiVirtualThread() { + return uuidGenerator.get() + " " + Thread.currentThread(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements StreamingChatLanguageModel { + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + throw new UnsupportedOperationException(); + } + + @Override + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { + if (messages.size() == 1) { + // Only the user message, extract the tool id from it + String text = ((dev.langchain4j.data.message.UserMessage) 
messages.get(0)).singleText(); + // Only the user message + handler.onComplete(new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder() + .id("my-tool-" + text) + .name(text) + .arguments("{}") + .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION)); + } else if (messages.size() == 3) { + // user -> tool request -> tool response + ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages); + handler.onNext("response: "); + handler.onNext(last.text()); + handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP)); + + } else { + handler.onError(new RuntimeException("Invalid number of messages")); + } + } + + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return MessageWindowChatMemory.withMaxMessages(10); + } + }; + } + } + + @RequestScoped + public static class UUIDGenerator { + private final String uuid = UUID.randomUUID().toString(); + + public String get() { + return uuid; + } + } +} diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java new file mode 100644 index 000000000..0b8092018 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java @@ -0,0 +1,420 @@ +package io.quarkiverse.langchain4j.test.tools; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; 
+import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.shaded.org.awaitility.Awaitility; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.FinishReason; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; +import io.quarkiverse.langchain4j.test.Lists; +import io.quarkus.arc.Arc; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; + +public class ToolExecutionModelWithStreamingTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class)); + + @Inject + MyAiService aiService; + + @Inject 
+ Vertx vertx; + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.hello("abc", "hi - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testBlockingToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference failure = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.hello("abc", "hi - " + uuid) + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set); + } catch (IllegalStateException e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + // We would automatically detect this case and switch to a worker thread at subscription time. 
+ + Awaitility.await().until(() -> failure.get() != null || result.get() != null); + assertThat(failure.get()).isNull(); + assertThat(result.get()).doesNotContain("event", "loop") + .contains(uuid, "executor-thread"); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.hello("abc", "hi - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + vertx.getOrCreateContext().runOnContext(x -> { + try { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid) + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set); + } finally { + 
Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(uuid, caller.get()); + } + + @Test + @ActivateRequestContext + void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() { + String uuid = UUID.randomUUID().toString(); + AtomicReference result = new AtomicReference<>(); + AtomicReference caller = new AtomicReference<>(); + + vertx.getOrCreateContext().runOnContext(x -> { + try { + caller.set(Thread.currentThread().getName()); + Arc.container().requestContext().activate(); + aiService.helloNonBlockingWithSwitch("abc", "hiNonBlocking - " + uuid) + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> result.get() != null); + assertThat(result.get()).contains(uuid, "executor-thread") + .doesNotContain(caller.get()); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromWorkerThread() { + String uuid = UUID.randomUUID().toString(); + var r = aiService.helloUni("abc", "hiUni - " 
+ uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread + } + + @Test + @ActivateRequestContext + void testUniToolInvocationFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference failure = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.helloUni("abc", "hiUni - " + uuid) + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage() + .thenAccept(result::set); + } catch (Exception e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> failure.get() != null || result.get() != null); + assertThat(failure.get()).isNull(); + assertThat(result.get()).contains(uuid, "executor-thread"); + } + + @Test + @ActivateRequestContext + @EnabledForJreRange(min = JRE.JAVA_21) + void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.helloUni("abc", "hiUni - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // The blocking tool is executed on the same thread (synchronous emission) + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .contains(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + @ActivateRequestContext + void testToolInvocationOnVirtualThread() { + String uuid = UUID.randomUUID().toString(); + var r = 
aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + assertThat(r).contains(uuid, "quarkus-virtual-thread-"); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException { + String uuid = UUID.randomUUID().toString(); + AtomicReference caller = new AtomicReference<>(); + var r = VirtualThreadsRecorder.getCurrent().submit(() -> { + try { + Arc.container().requestContext().activate(); + caller.set(Thread.currentThread().getName()); + return aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid) + .collect().asList().map(l -> String.join(" ", l)).await().indefinitely(); + } finally { + Arc.container().requestContext().deactivate(); + } + }).get(); + + // At the moment, we create a virtual thread every time. + assertThat(r).contains(uuid, "quarkus-virtual-thread-") + .doesNotContain(caller.get()); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_21) + void testToolInvocationOnVirtualThreadFromEventLoop() { + String uuid = UUID.randomUUID().toString(); + AtomicReference failure = new AtomicReference<>(); + AtomicReference result = new AtomicReference<>(); + vertx.getOrCreateContext().runOnContext(x -> { + try { + Arc.container().requestContext().activate(); + aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid) + .collect().asList().map(l -> String.join(" ", l)) + .subscribeAsCompletionStage().thenAccept(result::set); + } catch (IllegalStateException e) { + failure.set(e); + } finally { + Arc.container().requestContext().deactivate(); + } + }); + + Awaitility.await().until(() -> failure.get() != null || result.get() != null); + assertThat(failure.get()).isNull(); + assertThat(result.get()).contains(uuid, "quarkus-virtual-thread-"); + } + + @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = 
MyMemoryProviderSupplier.class) + public interface MyAiService { + + @ToolBox(BlockingTool.class) + Multi hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(NonBlockingTool.class) + Multi helloNonBlocking(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox({ NonBlockingTool.class, BlockingTool.class }) + Multi helloNonBlockingWithSwitch(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(UniTool.class) + Multi helloUni(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + + @ToolBox(VirtualTool.class) + Multi helloVirtualTools(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId); + } + + @Singleton + public static class BlockingTool { + @Tool + public String hi(String m) { + return m + " " + Thread.currentThread(); + } + } + + @Singleton + public static class NonBlockingTool { + @Tool + @NonBlocking + public String hiNonBlocking(String m) { + return m + " " + Thread.currentThread(); + } + } + + @Singleton + public static class UniTool { + @Tool + public Uni hiUni(String m) { + return Uni.createFrom().item(() -> m + " " + Thread.currentThread()); + } + } + + @Singleton + public static class VirtualTool { + + @Tool + @RunOnVirtualThread + public String hiVirtualThread(String m) { + return m + " " + Thread.currentThread(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public StreamingChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements StreamingChatLanguageModel { + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + throw new UnsupportedOperationException(); + } + + @Override + public void generate(List messages, List toolSpecifications, + StreamingResponseHandler handler) { + if (messages.size() == 1) { + // Only the user message, extract the tool id 
from it + String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText(); + var segments = text.split(" - "); + var toolId = segments[0]; + var content = segments[1]; + // Only the user message + handler.onComplete(new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder() + .id("my-tool-" + toolId) + .name(toolId) + .arguments("{\"m\":\"" + content + "\"}") + .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION)); + } else if (messages.size() == 3) { + // user -> tool request -> tool response + ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages); + handler.onNext("response: "); + handler.onNext(last.text()); + handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP)); + + } else { + handler.onError(new RuntimeException("Invalid number of messages: " + messages.size())); + } + } + + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return MessageWindowChatMemory.withMaxMessages(10); + } + }; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java index 61b63a6af..5e2a4720b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java @@ -51,7 +51,7 @@ public static void populateToolMetadata(List objectsWithTools, List guardrailsMaxRetry; + private final boolean switchToWorkerThread; @RecordableConstructor public AiServiceMethodCreateInfo(String interfaceName, String methodName, @@ -69,6 +70,7 @@ public AiServiceMethodCreateInfo(String interfaceName, String methodName, Optional spanInfo, ResponseSchemaInfo responseSchemaInfo, List 
toolClassNames, + boolean switchToWorkerThread, List inputGuardrailsClassNames, List outputGuardrailsClassNames, String outputTokenAccumulatorClassName) { @@ -101,6 +103,7 @@ public Integer get() { .orElse(GuardrailsConfig.MAX_RETRIES_DEFAULT); } }); + this.switchToWorkerThread = switchToWorkerThread; } public String getInterfaceName() { @@ -202,6 +205,10 @@ public String getUserMessageTemplate() { return userMessageTemplateOpt.orElse(EMPTY); } + public boolean isSwitchToWorkerThread() { + return switchToWorkerThread; + } + public record UserMessageInfo(Optional template, Optional paramPosition, Optional userNameParamPosition, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index a06cb9880..111af826c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -25,7 +25,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Flow; import java.util.concurrent.Future; -import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -70,11 +69,13 @@ import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; +import io.smallrye.common.vertx.VertxContext; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.infrastructure.Infrastructure; import io.smallrye.mutiny.operators.AbstractMulti; import io.smallrye.mutiny.operators.multi.processors.UnicastProcessor; import io.smallrye.mutiny.subscription.MultiSubscriber; +import io.vertx.core.Context; /** * Provides the basic building blocks 
that the generated Interface methods call into @@ -192,6 +193,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); userMessage = (UserMessage) augmentationResult.chatMessage(); } else { + // TODO duplicated context propagation. // this a special case where we can't block, so we need to delegate the retrieval augmentation to a worker pool CompletableFuture augmentationResultCF = CompletableFuture.supplyAsync(new Supplier<>() { @Override @@ -210,7 +212,8 @@ public Flow.Publisher apply(AugmentationResult ar) { context.chatMemory(memoryId), ar, templateVariables); List messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed); return new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, - finalToolExecutors, ar.contents(), context, memoryId); + finalToolExecutors, ar.contents(), context, memoryId, + methodCreateInfo.isSwitchToWorkerThread()); } private List messagesToSend(ChatMessage augmentedUserMessage, @@ -261,12 +264,14 @@ private List messagesToSend(ChatMessage augmentedUserMessage, if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, - (augmentationResult != null ? augmentationResult.contents() : null), context, memoryId); + (augmentationResult != null ? augmentationResult.contents() : null), context, memoryId, + methodCreateInfo.isSwitchToWorkerThread()); } var actualAugmentationResult = augmentationResult; return new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, - (augmentationResult != null ? augmentationResult.contents() : null), context, memoryId) + (augmentationResult != null ? 
augmentationResult.contents() : null), context, memoryId, + methodCreateInfo.isSwitchToWorkerThread()) .plug(s -> GuardrailsSupport.accumulate(s, methodCreateInfo)) .map(chunk -> { OutputGuardrailResult result; @@ -761,10 +766,11 @@ private static class TokenStreamMulti extends AbstractMulti implements M private final List contents; private final QuarkusAiServiceContext context; private final Object memoryId; + private final boolean mustSwitchToWorkerThread; public TokenStreamMulti(List messagesToSend, List toolSpecifications, Map toolExecutors, - List contents, QuarkusAiServiceContext context, Object memoryId) { + List contents, QuarkusAiServiceContext context, Object memoryId, boolean mustSwitchToWorkerThread) { // We need to pass and store the parameters to the constructor because we need to re-create a stream on every subscription. this.messagesToSend = messagesToSend; this.toolSpecifications = toolSpecifications; @@ -772,24 +778,42 @@ public TokenStreamMulti(List messagesToSend, List subscriber) { UnicastProcessor processor = UnicastProcessor.create(); processor.subscribe(subscriber); - var stream = new AiServiceTokenStream(messagesToSend, toolSpecifications, - toolsExecutors, contents, context, memoryId); - stream + + createTokenStream(processor); + } + + private void createTokenStream(UnicastProcessor processor) { + Context ctxt = null; + if (mustSwitchToWorkerThread) { + // we create or retrieve the current context, to use `executeBlocking` when required. 
+ ctxt = VertxContext.getOrCreateDuplicatedContext(); + } + + var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, + toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread); + TokenStream tokenStream = stream .onNext(processor::onNext) - .onComplete(new Consumer<>() { - @Override - public void accept(Response message) { - processor.onComplete(); - } - }) - .onError(processor::onError) - .start(); + .onComplete(message -> processor.onComplete()) + .onError(processor::onError); + // This is equivalent to "run subscription on worker thread" + if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { + ctxt.executeBlocking(new Callable() { + @Override + public Void call() { + tokenStream.start(); + return null; + } + }); + } else { + tokenStream.start(); + } } } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java new file mode 100644 index 000000000..b76b98d04 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java @@ -0,0 +1,176 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.function.Consumer; + +import org.jboss.logging.Logger; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.model.StreamingResponseHandler; 
+import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.service.AiServiceContext; +import dev.langchain4j.service.tool.ToolExecutor; +import io.smallrye.mutiny.infrastructure.Infrastructure; +import io.vertx.core.Context; + +/** + * A {@link StreamingResponseHandler} implementation for Quarkus. + * The main difference with the upstream implementation is the thread switch when receiving the `completion` event + * when there is tool execution requests. + */ +public class QuarkusAiServiceStreamingResponseHandler implements StreamingResponseHandler { + + private final Logger log = Logger.getLogger(QuarkusAiServiceStreamingResponseHandler.class); + + private final AiServiceContext context; + private final Object memoryId; + + private final Consumer tokenHandler; + private final Consumer> completionHandler; + private final Consumer errorHandler; + + private final List temporaryMemory; + private final TokenUsage tokenUsage; + + private final List toolSpecifications; + private final Map toolExecutors; + private final Context executionContext; + private final boolean mustSwitchToWorkerThread; + + QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, + Object memoryId, + Consumer tokenHandler, + Consumer> completionHandler, + Consumer errorHandler, + List temporaryMemory, + TokenUsage tokenUsage, + List toolSpecifications, + Map toolExecutors, boolean mustSwitchToWorkerThread, Context cxtx) { + this.context = ensureNotNull(context, "context"); + this.memoryId = ensureNotNull(memoryId, "memoryId"); + + this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler"); + this.completionHandler = completionHandler; + this.errorHandler = errorHandler; + + this.temporaryMemory = new ArrayList<>(temporaryMemory); + this.tokenUsage = ensureNotNull(tokenUsage, "tokenUsage"); + + this.toolSpecifications = copyIfNotNull(toolSpecifications); + this.toolExecutors = copyIfNotNull(toolExecutors); + + 
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; + this.executionContext = cxtx; + } + + @Override + public void onNext(String token) { + tokenHandler.accept(token); + } + + private void executeTools(Runnable runnable) { + if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { + if (executionContext != null) { + executionContext.executeBlocking(new Callable() { + @Override + public Object call() { + runnable.run(); + return null; + } + }); + } else { + // We do not have a context, switching to worker thread. + Infrastructure.getDefaultWorkerPool().execute(runnable); + } + } else { + runnable.run(); + } + } + + @Override + public void onComplete(Response response) { + + AiMessage aiMessage = response.content(); + addToMemory(aiMessage); + + if (aiMessage.hasToolExecutionRequests()) { + // Tools execution may block the caller thread. When the caller thread is the event loop thread, and + // when tools have been detected to be potentially blocking, we need to switch to a worker thread. 
+ executeTools(new Runnable() { + @Override + public void run() { + for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) { + String toolName = toolExecutionRequest.name(); + ToolExecutor toolExecutor = toolExecutors.get(toolName); + String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId); + ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from( + toolExecutionRequest, + toolExecutionResult); + QuarkusAiServiceStreamingResponseHandler.this.addToMemory(toolExecutionResultMessage); + } + + context.streamingChatModel.generate( + QuarkusAiServiceStreamingResponseHandler.this.messagesToSend(memoryId), + toolSpecifications, + new QuarkusAiServiceStreamingResponseHandler( + context, + memoryId, + tokenHandler, + completionHandler, + errorHandler, + temporaryMemory, + TokenUsage.sum(tokenUsage, response.tokenUsage()), + toolSpecifications, + toolExecutors, + mustSwitchToWorkerThread, executionContext)); + } + }); + } else { + if (completionHandler != null) { + completionHandler.accept(Response.from( + aiMessage, + TokenUsage.sum(tokenUsage, response.tokenUsage()), + response.finishReason())); + } + } + } + + private void addToMemory(ChatMessage chatMessage) { + if (context.hasChatMemory()) { + context.chatMemory(memoryId).add(chatMessage); + } else { + temporaryMemory.add(chatMessage); + } + } + + private List messagesToSend(Object memoryId) { + return context.hasChatMemory() + ? 
context.chatMemory(memoryId).messages() + : temporaryMemory; + } + + @Override + public void onError(Throwable error) { + if (errorHandler != null) { + try { + errorHandler.accept(error); + } catch (Exception e) { + log.error("While handling the following error...", error); + log.error("...the following error happened", e); + } + } else { + log.warn("Ignored error", error); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java new file mode 100644 index 000000000..c81f35315 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java @@ -0,0 +1,158 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import static dev.langchain4j.internal.Utils.copyIfNotNull; +import static dev.langchain4j.internal.Utils.isNullOrEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.Collections.emptyList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.exception.IllegalConfigurationException; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import dev.langchain4j.rag.content.Content; +import dev.langchain4j.service.AiServiceContext; +import dev.langchain4j.service.TokenStream; +import dev.langchain4j.service.tool.ToolExecutor; +import io.vertx.core.Context; + +/** + * An implementation of token stream for Quarkus. 
+ * The only difference with the upstream implementation is the usage of the custom + * {@link QuarkusAiServiceStreamingResponseHandler} instead of the upstream one. + * It allows handling blocking tools execution, when we are invoked on the event loop. + */ +public class QuarkusAiServiceTokenStream implements TokenStream { + + private final List messages; + private final List toolSpecifications; + private final Map toolExecutors; + private final List retrievedContents; + private final AiServiceContext context; + private final Object memoryId; + private final Context cxtx; + private final boolean mustSwitchToWorkerThread; + + private Consumer tokenHandler; + private Consumer> contentsHandler; + private Consumer errorHandler; + private Consumer> completionHandler; + + private int onNextInvoked; + private int onCompleteInvoked; + private int onRetrievedInvoked; + private int onErrorInvoked; + private int ignoreErrorsInvoked; + + public QuarkusAiServiceTokenStream(List messages, + List toolSpecifications, + Map toolExecutors, + List retrievedContents, + AiServiceContext context, + Object memoryId, Context ctxt, boolean mustSwitchToWorkerThread) { + this.messages = ensureNotEmpty(messages, "messages"); + this.toolSpecifications = copyIfNotNull(toolSpecifications); + this.toolExecutors = copyIfNotNull(toolExecutors); + this.retrievedContents = retrievedContents; + this.context = ensureNotNull(context, "context"); + this.memoryId = ensureNotNull(memoryId, "memoryId"); + ensureNotNull(context.streamingChatModel, "streamingChatModel"); + this.cxtx = ctxt; // If set, it means we need to handle the context propagation. + this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools. 
+ } + + @Override + public TokenStream onNext(Consumer tokenHandler) { + this.tokenHandler = tokenHandler; + this.onNextInvoked++; + return this; + } + + @Override + public TokenStream onRetrieved(Consumer> contentsHandler) { + this.contentsHandler = contentsHandler; + this.onRetrievedInvoked++; + return this; + } + + @Override + public TokenStream onComplete(Consumer> completionHandler) { + this.completionHandler = completionHandler; + this.onCompleteInvoked++; + return this; + } + + @Override + public TokenStream onError(Consumer errorHandler) { + this.errorHandler = errorHandler; + this.onErrorInvoked++; + return this; + } + + @Override + public TokenStream ignoreErrors() { + this.errorHandler = null; + this.ignoreErrorsInvoked++; + return this; + } + + @Override + public void start() { + validateConfiguration(); + QuarkusAiServiceStreamingResponseHandler handler = new QuarkusAiServiceStreamingResponseHandler( + context, + memoryId, + tokenHandler, + completionHandler, + errorHandler, + initTemporaryMemory(context, messages), + new TokenUsage(), + toolSpecifications, + toolExecutors, + mustSwitchToWorkerThread, + cxtx); + + if (contentsHandler != null && retrievedContents != null) { + contentsHandler.accept(retrievedContents); + } + + if (isNullOrEmpty(toolSpecifications)) { + context.streamingChatModel.generate(messages, handler); + } else { + context.streamingChatModel.generate(messages, toolSpecifications, handler); + } + } + + private void validateConfiguration() { + if (onNextInvoked != 1) { + throw new IllegalConfigurationException("onNext must be invoked exactly 1 time"); + } + + if (onCompleteInvoked > 1) { + throw new IllegalConfigurationException("onComplete must be invoked at most 1 time"); + } + + if (onRetrievedInvoked > 1) { + throw new IllegalConfigurationException("onRetrieved must be invoked at most 1 time"); + } + + if (onErrorInvoked + ignoreErrorsInvoked != 1) { + throw new IllegalConfigurationException("One of onError or ignoreErrors must be 
invoked exactly 1 time"); + } + } + + private List initTemporaryMemory(AiServiceContext context, List messagesToSend) { + if (context.hasChatMemory()) { + return emptyList(); + } else { + return new ArrayList<>(messagesToSend); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java index 900f48221..662b9c922 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/devui/ChatJsonRPCService.java @@ -109,7 +109,7 @@ public ChatJsonRPCService(@All List models, // don't use Chat } QuarkusToolExecutor.Context executorContext = new QuarkusToolExecutor.Context(objectWithTool, methodCreateInfo.invokerClassName(), methodCreateInfo.methodName(), - methodCreateInfo.argumentMapperClassName()); + methodCreateInfo.argumentMapperClassName(), methodCreateInfo.executionModel()); toolExecutors.put(methodCreateInfo.toolSpecification().name(), toolExecutorFactory.create(executorContext)); toolSpecifications.add(methodCreateInfo.toolSpecification()); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java index 5f417ba77..48c0590ef 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/QuarkusToolExecutor.java @@ -4,6 +4,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.Map; +import java.util.concurrent.CompletionException; import java.util.function.BiFunction; import org.jboss.logging.Logger; @@ -15,6 +16,9 @@ import dev.langchain4j.service.tool.ToolExecutor; import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; import 
io.quarkiverse.langchain4j.runtime.prompt.Mappable; +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Context; public class QuarkusToolExecutor implements ToolExecutor { @@ -22,7 +26,8 @@ public class QuarkusToolExecutor implements ToolExecutor { private final Context context; - public record Context(Object tool, String toolInvokerName, String methodName, String argumentMapperClassName) { + public record Context(Object tool, String toolInvokerName, String methodName, String argumentMapperClassName, + ToolMethodCreateInfo.ExecutionModel executionModel) { } public interface Wrapper { @@ -38,22 +43,67 @@ public QuarkusToolExecutor(Context context) { public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) { log.debugv("About to execute {0}", toolExecutionRequest); - ToolInvoker invokerInstance = createInvokerInstance(); + // TODO Tools invocation are "imperative" + // TODO This method is called from the caller thread + // TODO So, we need to handle the dispatch here, depending on the caller thread and the tool invocation + // TODO Note that we need to return a String in an imperative manner. + // TODO We may have to check who's going to call this method from a non-blocking thread to handle the dispatch there. + ToolInvoker invokerInstance = createInvokerInstance(); Object[] params = prepareArguments(toolExecutionRequest, invokerInstance.methodMetadata(), memoryId); + // When required to block, we are invoked on a worker thread (stream with blocking tools). 
+ switch (context.executionModel) { + case BLOCKING: + if (io.vertx.core.Context.isOnEventLoopThread()) { + throw new IllegalStateException("Cannot execute blocking tools on event loop thread"); + } + return invoke(params, invokerInstance); + case NON_BLOCKING: + return invoke(params, invokerInstance); + case VIRTUAL_THREAD: + if (io.vertx.core.Context.isOnEventLoopThread()) { + throw new IllegalStateException("Cannot execute virtual thread tools on event loop thread"); + } + try { + return VirtualThreadsRecorder.getCurrent().submit(() -> invoke(params, invokerInstance)) + .get(); + } catch (Exception e) { + if (e instanceof CompletionException) { + return e.getCause().getMessage(); + } + return e.getMessage(); + } + default: + throw new IllegalStateException("Unknown execution model: " + context.executionModel); + } + + } + + private String invoke(Object[] params, ToolInvoker invokerInstance) { try { if (log.isDebugEnabled()) { log.debugv("Attempting to invoke tool {0} with parameters {1}", context.tool, Arrays.toString(params)); } - Object invocationResult = invokerInstance.invoke(context.tool, - params); - String result = handleResult(invokerInstance, invocationResult); + Object invocationResult = invokerInstance.invoke(context.tool, params); + String result; + if (invocationResult instanceof Uni) { // TODO CS + if (io.vertx.core.Context.isOnEventLoopThread()) { + throw new IllegalStateException( + "Cannot execute tools returning Uni on event loop thread due to a tool executor limitation"); + } + result = handleResult(invokerInstance, ((Uni) invocationResult).await().indefinitely()); + } else { + result = handleResult(invokerInstance, invocationResult); + } log.debugv("Tool execution result: {0}", result); return result; } catch (Exception e) { if (e instanceof IllegalArgumentException) { throw (IllegalArgumentException) e; } + if (e instanceof IllegalStateException) { + throw (IllegalStateException) e; + } log.error("Error while executing tool '" + 
context.tool.getClass() + "'", e); return e.getMessage(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolMethodCreateInfo.java index d8580459f..182dd2c0d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolMethodCreateInfo.java @@ -5,5 +5,13 @@ public record ToolMethodCreateInfo(String methodName, String invokerClassName, ToolSpecification toolSpecification, - String argumentMapperClassName) { + String argumentMapperClassName, + ExecutionModel executionModel) { + + public enum ExecutionModel { + BLOCKING, + NON_BLOCKING, + VIRTUAL_THREAD + } + } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolSpanWrapper.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolSpanWrapper.java index df92a5328..f3f6ba75e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolSpanWrapper.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolSpanWrapper.java @@ -23,6 +23,7 @@ public String wrap(ToolExecutionRequest toolExecutionRequest, Object memoryId, BiFunction fun) { Span span = tracer.spanBuilder("langchain4j.tools." + toolExecutionRequest.name()).startSpan(); try (Scope scope = span.makeCurrent()) { + // TODO Handle async method here. 
return fun.apply(toolExecutionRequest, memoryId); } catch (Throwable t) { span.recordException(t); diff --git a/docs/modules/ROOT/pages/agent-and-tools.adoc b/docs/modules/ROOT/pages/agent-and-tools.adoc index 489d8b4b6..019c596ec 100644 --- a/docs/modules/ROOT/pages/agent-and-tools.adoc +++ b/docs/modules/ROOT/pages/agent-and-tools.adoc @@ -90,6 +90,87 @@ If a method of an AI Service needs to use tools other than the ones configured o In this case `@Toolbox` completely overrides `@RegisterAiService(tools=...)` ==== +== Tools execution model + +Tools can have different execution models: + +- _blocking_ - the tools execution blocks the caller thread +- _non-blocking_ - the tools execution is asynchronous and does not block the caller thread +- _run on a virtual thread_ - the tools execution is asynchronous and runs on a new virtual thread + +The execution model is configured using the `@Blocking`, `@NonBlocking` and `@RunOnVirtualThread` annotations. +In addition, tool methods returning `CompletionStage` or `Uni` are considered non-blocking (if no annotations are used). + +For example, the following tool is considered blocking: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +public String getCustomerName(long id) { + return find("id", id).firstResult().name; +} +---- + +The following tool is considered non-blocking: + +[source,java] +---- +@Tool("add a and b") +@NonBlocking +int sum(int a, int b) { + return a + b; +} +---- + +This other tool is considered non-blocking as well: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +public Uni getCustomerName(long id) { + // ... 
+} +---- + +The following tool runs on a virtual thread: + +[source,java] +---- +@Tool("get the customer name for the given customerId") +@RunOnVirtualThread +public String getCustomerName(long id) { + return find("id", id).firstResult().name; +} +---- + +=== Streaming + +The execution model is particularly important when using streamed responses. +Indeed, streamed responses are executed on the _event loop_, which cannot be blocked. +Thus, in this case, based on the execution model of the tool, Quarkus LangChain4J will automatically switch to a worker thread to execute the tool. +This mechanism avoids blocking the _event loop_, while still allowing the use of blocking tools (like database access). + +For example, let's imagine the following AI service method: + +[source,java] +---- +@UserMessage("...") +@ToolBox({TransactionRepository.class, CustomerRepository.class}) +Multi detectAmountFraudForCustomerStreamed(long customerId); +---- + +This method returns a stream of tokens. +Each token is emitted on the event loop. +The given tools are blocking (they access a database in a blocking manner). +Thus, Quarkus LangChain4J will automatically switch to a worker thread before invoking the tools, to avoid blocking the event loop. + +=== Request scope + +If the request scope is active when invoking the LLM, the tool execution is also bound to the request scope. +It means that you can propagate data to the tool execution, for example, a security context or a transaction context. + +The propagation also works with virtual threads or when Quarkus LangChain4J automatically switches to a worker thread. + == Dynamically Use Tools at Runtime By default, tools in Quarkus are static, defined during the build phase. 
diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java index ae15843a5..0012a3779 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java @@ -116,13 +116,11 @@ public String memory() throws Exception { @Override public void onNext(String token) { - // System.out.print(token); sb.append(token); } @Override public void onComplete(Response response) { - // System.out.println("\n\ndone\n\n"); if (response != null) { futureRef.get().complete(response.content()); } else { @@ -136,7 +134,6 @@ public void onError(Throwable throwable) { } }; - // System.out.println("starting first\n\n"); streamingChatLanguageModel.generate(chatMemory.messages(), handler); AiMessage firstAiMessage = futureRef.get().get(60, TimeUnit.SECONDS); chatMemory.add(firstAiMessage); @@ -151,7 +148,6 @@ public void onError(Throwable throwable) { .append("\n[LLM]: "); futureRef.set(new CompletableFuture<>()); - // System.out.println("\n\n\nstarting second\n\n"); streamingChatLanguageModel.generate(chatMemory.messages(), handler); futureRef.get().get(60, TimeUnit.SECONDS); return sb.toString(); diff --git a/integration-tests/pom.xml b/integration-tests/pom.xml index 4362e44c0..75d6836de 100644 --- a/integration-tests/pom.xml +++ b/integration-tests/pom.xml @@ -26,6 +26,7 @@ rag-pgvector rag-pgvector-flyway easy-rag + tools diff --git a/integration-tests/tools/pom.xml b/integration-tests/tools/pom.xml new file mode 100644 index 000000000..e81e63022 --- /dev/null +++ b/integration-tests/tools/pom.xml @@ -0,0 +1,124 @@ + + + 4.0.0 + + io.quarkiverse.langchain4j + quarkus-langchain4j-integration-tests-parent + 999-SNAPSHOT + + quarkus-langchain4j-integration-test-tools + Quarkus LangChain4j - 
Integration Tests - Tools + + true + + + + + io.quarkus + quarkus-rest-jackson + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx + ${project.version} + + + io.quarkus + quarkus-junit5 + test + + + io.rest-assured + rest-assured + test + + + io.quarkus + quarkus-junit5-mockito + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + io.quarkus + quarkus-devtools-testing + test + + + + + io.quarkiverse.langchain4j + quarkus-langchain4j-watsonx-deployment + ${project.version} + pom + test + + + * + * + + + + + + + + io.quarkus + quarkus-maven-plugin + + + + build + + + + + + maven-failsafe-plugin + + + + integration-test + verify + + + + ${project.build.directory}/${project.build.finalName}-runner + org.jboss.logmanager.LogManager + ${maven.home} + + + + + + + + + + native-image + + + native + + + + + + maven-surefire-plugin + + ${native.surefire.skip} + + + + + + false + native + + + + diff --git a/integration-tests/tools/src/main/java/org/acme/tools/AiService.java b/integration-tests/tools/src/main/java/org/acme/tools/AiService.java new file mode 100644 index 000000000..80c531308 --- /dev/null +++ b/integration-tests/tools/src/main/java/org/acme/tools/AiService.java @@ -0,0 +1,11 @@ +package org.acme.tools; + +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.ToolBox; + +@RegisterAiService +public interface AiService { + @ToolBox(Calculator.class) + public String chat(@UserMessage String message); +} diff --git a/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java b/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java new file mode 100644 index 000000000..4fd82b4e9 --- /dev/null +++ b/integration-tests/tools/src/main/java/org/acme/tools/Calculator.java @@ -0,0 +1,30 @@ +package org.acme.tools; + +import jakarta.enterprise.context.ApplicationScoped; + +import dev.langchain4j.agent.tool.Tool; +import 
io.smallrye.common.annotation.Blocking; +import io.smallrye.common.annotation.NonBlocking; +import io.smallrye.common.annotation.RunOnVirtualThread; + +@ApplicationScoped +public class Calculator { + + @Tool + @Blocking + public int blockingSum(int a, int b) { + return a + b; + } + + @Tool + @NonBlocking + public int nonBlockingSum(int a, int b) { + return a + b; + } + + @Tool + @RunOnVirtualThread + public int virtualSum(int a, int b) { + return a + b; + } +} diff --git a/integration-tests/tools/src/main/resources/application.properties b/integration-tests/tools/src/main/resources/application.properties new file mode 100644 index 000000000..28bf72df9 --- /dev/null +++ b/integration-tests/tools/src/main/resources/application.properties @@ -0,0 +1,4 @@ +quarkus.langchain4j.watsonx.base-url=https://toolstest.com +quarkus.langchain4j.watsonx.api-key=api-key +quarkus.langchain4j.watsonx.project-id=project-id +quarkus.langchain4j.watsonx.version=yyyy-mm-dd \ No newline at end of file diff --git a/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java b/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java new file mode 100644 index 000000000..ed7bac0d8 --- /dev/null +++ b/integration-tests/tools/src/test/java/org/acme/tools/ToolsTest.java @@ -0,0 +1,80 @@ +package org.acme.tools; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.anyList; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.mockito.Mockito; + +import dev.langchain4j.agent.tool.ToolExecutionRequest; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.output.TokenUsage; +import io.quarkus.test.InjectMock; +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest 
+@EnabledForJreRange(min = JRE.JAVA_21) +@SuppressWarnings("unchecked") +public class ToolsTest { + + @Inject + AiService aiService; + + @InjectMock + ChatLanguageModel model; + + @Test + void blockingSum() { + var toolExecution = create("blockingSum"); + + Mockito.when(model.generate(anyList(), anyList())) + .thenReturn( + Response.from(AiMessage.from(toolExecution), new TokenUsage(1)), + Response.from(AiMessage.from("The result is 2"), new TokenUsage(1))); + + assertEquals("The result is 2", aiService.chat("Execute 1 + 1")); + } + + @Test + void nonBlockingSum() { + var toolExecution = create("nonBlockingSum"); + + Mockito.when(model.generate(anyList(), anyList())) + .thenReturn( + Response.from(AiMessage.from(toolExecution), new TokenUsage(1)), + Response.from(AiMessage.from("The result is 2"), new TokenUsage(1))); + + assertEquals("The result is 2", aiService.chat("Execute 1 + 1")); + } + + @Test + void virtualThreadSum() { + var toolExecution = create("virtualSum"); + + Mockito.when(model.generate(anyList(), anyList())) + .thenReturn( + Response.from(AiMessage.from(toolExecution), new TokenUsage(1)), + Response.from(AiMessage.from("The result is 2"), new TokenUsage(1))); + + assertEquals("The result is 2", aiService.chat("Execute 1 + 1")); + } + + private ToolExecutionRequest create(String methodName) { + return ToolExecutionRequest.builder() + .id("1") + .name(methodName) + .arguments(""" + { + "a": 1, + "b": 1 + } + """) + .build(); + } +} diff --git a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java index 8d77072d5..069771476 100644 --- a/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java +++ 
b/model-providers/watsonx/deployment/src/test/java/io/quarkiverse/langchain4j/watsonx/deployment/AiChatServiceTest.java @@ -45,6 +45,7 @@ import io.quarkiverse.langchain4j.watsonx.runtime.config.ChatModelConfig; import io.quarkiverse.langchain4j.watsonx.runtime.config.LangChain4jWatsonxConfig; import io.quarkus.test.QuarkusUnitTest; +import io.smallrye.common.annotation.Blocking; import io.smallrye.mutiny.Multi; public class AiChatServiceTest extends WireMockAbstract { @@ -106,6 +107,7 @@ interface AIServiceWithTool { @Singleton static class Calculator { @Tool("Execute the sum of two numbers") + @Blocking public int sum(int first, int second) { return first + second; } diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java index 924cd8d09..a600376ac 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/WatsonxChatModel.java @@ -32,7 +32,6 @@ import io.quarkiverse.langchain4j.watsonx.bean.TextStreamingChatResponse; import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest; import io.smallrye.mutiny.Context; -import io.smallrye.mutiny.infrastructure.Infrastructure; public class WatsonxChatModel extends Watsonx implements ChatLanguageModel, StreamingChatLanguageModel, TokenCountEstimator { @@ -110,14 +109,8 @@ public void generate(List messages, List toolSpe context.put(TOOLS_CONTEXT, new ArrayList()); context.put(COMPLETE_MESSAGE_CONTEXT, new StringBuilder()); - var mutiny = client.streamingChat(request, version); - if (tools != null) { - // Today Langchain4j doesn't allow to use the async operation with tools. - // One idea might be to give to the developer the possibility to use the VirtualThread. 
- mutiny.emitOn(Infrastructure.getDefaultWorkerPool()); - } - - mutiny.subscribe() + client.streamingChat(request, version) + .subscribe() .with(context, new Consumer() { @Override diff --git a/samples/fraud-detection/pom.xml b/samples/fraud-detection/pom.xml index 061115b2a..79780cb88 100644 --- a/samples/fraud-detection/pom.xml +++ b/samples/fraud-detection/pom.xml @@ -18,7 +18,7 @@ 3.15.1 true 3.2.5 - 0.21.0.CR4 + 999-SNAPSHOT diff --git a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java index b3b53b6f1..e7083bb5c 100644 --- a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java +++ b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionAi.java @@ -2,6 +2,7 @@ import java.time.temporal.ChronoUnit; +import io.smallrye.mutiny.Multi; import org.eclipse.microprofile.faulttolerance.Timeout; import dev.langchain4j.service.SystemMessage; @@ -36,6 +37,31 @@ public interface FraudDetectionAi { @Timeout(value = 2, unit = ChronoUnit.MINUTES) String detectAmountFraudForCustomer(long customerId); + @SystemMessage(""" + You are a bank account fraud detection AI. You have to detect frauds in transactions. + """) + @UserMessage(""" + Your task is to detect whether a fraud was committed for the customer {{customerId}}. + + To detect a fraud, perform the following actions: + 1 - Retrieve the name of the customer {{customerId}} + 2 - Retrieve the transactions for the customer {{customerId}} for the last 15 minutes. + 3 - Sum the amount of the all these transactions. Make sure the sum is correct. + 4 - If the amount is greater than 10000, a fraud is detected. 
+ + Answer with a **single** JSON document containing: + - the customer name in the 'customer-name' key + - the computed sum in the 'total' key + - the 'fraud' key set to a boolean value indicating if a fraud was detected + - the 'transactions' key containing the list of transaction amounts + - the 'explanation' key containing a explanation of your answer, especially how the sum is computed. + - if there is a fraud, the 'email' key containing an email to the customer {{customerId}} to warn him about the fraud. The text must be formal and polite. It must ask the customer to contact the bank ASAP. + + Your response must be just the raw JSON document, without ```json, ``` or anything else. + """) + @Timeout(value = 2, unit = ChronoUnit.MINUTES) + Multi detectAmountFraudForCustomerStreamed(long customerId); + @SystemMessage(""" You are a bank account fraud detection AI. You have to detect frauds in transactions. """) diff --git a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java index 7bed6df42..a3f4b3eff 100644 --- a/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java +++ b/samples/fraud-detection/src/main/java/io/quarkiverse/langchain4j/sample/FraudDetectionResource.java @@ -2,9 +2,12 @@ import java.util.List; +import io.smallrye.mutiny.Multi; import jakarta.ws.rs.GET; import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; import org.jboss.resteasy.reactive.RestQuery; @Path("/fraud") @@ -30,6 +33,13 @@ public String detectBaseOnAmount(@RestQuery long customerId) { return service.detectAmountFraudForCustomer(customerId); } + @GET + @Path("/amount/streamed") + @Produces(MediaType.SERVER_SENT_EVENTS) + public Multi detectBaseOnAmountReactive(@RestQuery long customerId) { + return 
service.detectAmountFraudForCustomerStreamed(customerId); + } + @GET @Path("/transactions") public List list(@RestQuery long customerId) { diff --git a/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java b/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java index f05efc882..82fc92ea0 100644 --- a/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java +++ b/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java @@ -104,7 +104,6 @@ public void testSearch() { .build(); WebSearchResults result = webSearchEngine.search(searchRequest); LoggedRequest actualRequest = singleLoggedRequest(); - System.out.println(actualRequest); // verify the request assertEquals(actualRequest.getHeader("Accept"), "application/json");