diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml
index 3153acca2..ab47f8411 100644
--- a/core/deployment/pom.xml
+++ b/core/deployment/pom.xml
@@ -102,6 +102,11 @@
+
+ io.quarkus
+ quarkus-test-vertx
+ test
+
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
index 25e497c84..30ee9a92b 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java
@@ -8,6 +8,7 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS;
+import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
@@ -69,6 +70,7 @@
import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
+import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
@@ -123,6 +125,7 @@
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.smallrye.mutiny.Multi;
+import io.smallrye.mutiny.Uni;
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AiServicesProcessor {
@@ -760,6 +763,57 @@ public void markUsedOutputGuardRailsUnremovable(List
}
}
+ /**
+ * Because the tools execution uses an imperative API (`String execute(...)`) and uses the caller thread, we need
+ * to anticipate the need to dispatch the invocation on a worker thread.
+ * This is the case for AI service methods that returns `Uni`, `CompletionStage` and `Multi` (stream) and that uses
+ * tools returning `Uni`, `CompletionStage` or that are blocking.
+ * Basically, for "reactive AI service method, the switch is necessary except if all the tools are imperative (return `T`)
+ * and marked explicitly as blocking (using `@Blocking`).
+ *
+ * @param method the AI method
+ * @param tools the tools
+ */
+ public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(
+ MethodInfo method,
+ List associatedTools,
+ List tools) {
+ boolean reactive = method.returnType().name().equals(DotNames.UNI)
+ || method.returnType().name().equals(DotNames.COMPLETION_STAGE)
+ || method.returnType().name().equals(DotNames.MULTI);
+
+ boolean requireSwitchToWorkerThread = false;
+
+ if (associatedTools.isEmpty()) {
+ // No tools, no need to dispatch
+ return false;
+ }
+
+ if (!reactive) {
+ // We are already on a thread we can block.
+ return false;
+ }
+
+ // We need to find if any of the tools that could be used by the method is requiring a blocking execution
+ for (String classname : associatedTools) {
+ // Look for the tool in the list of tools
+ boolean found = false;
+ for (ToolMethodBuildItem tool : tools) {
+ if (tool.getDeclaringClassName().equals(classname)) {
+ found = true;
+ if (tool.requiresSwitchToWorkerThread()) {
+ requireSwitchToWorkerThread = true;
+ break;
+ }
+ }
+ }
+ if (!found) {
+ throw new RuntimeException("No tools detected in " + classname);
+ }
+ }
+ return requireSwitchToWorkerThread;
+ }
+
@BuildStep
public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
List methods,
@@ -857,7 +911,8 @@ public void handleAiServices(
BuildProducer additionalBeanProducer,
BuildProducer unremovableBeanProducer,
Optional metricsCapability,
- Capabilities capabilities) {
+ Capabilities capabilities,
+ List tools) {
IndexView index = indexBuildItem.getIndex();
@@ -1026,7 +1081,8 @@ public void handleAiServices(
addOpenTelemetrySpan,
config.responseSchema(),
allowedPredicates,
- ignoredPredicates);
+ ignoredPredicates,
+ tools);
if (!methodCreateInfo.getToolClassNames().isEmpty()) {
unremovableBeanProducer.produce(UnremovableBeanBuildItem
.beanClassNames(methodCreateInfo.getToolClassNames().toArray(EMPTY_STRING_ARRAY)));
@@ -1068,7 +1124,9 @@ public void handleAiServices(
aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo,
methodCreateInfo.getInputGuardrailsClassNames(),
- methodCreateInfo.getOutputGuardrailsClassNames()));
+ methodCreateInfo.getOutputGuardrailsClassNames(),
+ gatherMethodToolClassNames(methodInfo),
+ methodCreateInfo));
}
}
@@ -1155,7 +1213,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
MethodInfo method, IndexView index, boolean addMicrometerMetrics,
boolean addOpenTelemetrySpans, boolean generateResponseSchema,
Collection> allowedPredicates,
- Collection> ignoredPredicates) {
+ Collection> ignoredPredicates,
+ List tools) {
validateReturnType(method);
boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
@@ -1200,11 +1259,28 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);
+ // Detect if tools execution may block the caller thread.
+ boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);
+
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
- metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails,
- outputGuardrails, accumulatorClassName);
+ metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
+ inputGuardrails, outputGuardrails, accumulatorClassName);
+ }
+
+ private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List tools,
+ List methodToolClassNames) {
+ List allTools = new ArrayList<>(methodToolClassNames);
+ // We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing.
+ AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES);
+ if (annotation != null) {
+ AnnotationValue value = annotation.value("tools");
+ if (value != null) {
+ allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList());
+ }
+ }
+ return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools);
}
private void validateReturnType(MethodInfo method) {
@@ -1688,11 +1764,17 @@ public static final class AiServicesMethodBuildItem extends MultiBuildItem {
private final MethodInfo methodInfo;
private final List outputGuardrails;
private final List inputGuardrails;
+ private final List tools;
+ private final AiServiceMethodCreateInfo methodCreateInfo;
- public AiServicesMethodBuildItem(MethodInfo methodInfo, List inputGuardrails, List outputGuardrails) {
+ public AiServicesMethodBuildItem(MethodInfo methodInfo, List inputGuardrails, List outputGuardrails,
+ List tools,
+ AiServiceMethodCreateInfo methodCreateInfo) {
this.methodInfo = methodInfo;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
+ this.tools = tools;
+ this.methodCreateInfo = methodCreateInfo;
}
public List getOutputGuardrails() {
@@ -1707,6 +1789,10 @@ public MethodInfo getMethodInfo() {
return methodInfo;
}
+ public AiServiceMethodCreateInfo getMethodCreateInfo() {
+ return methodCreateInfo;
+ }
+
public static List gatherGuardrails(MethodInfo methodInfo, DotName annotation) {
List guardrails = new ArrayList<>();
AnnotationInstance instance = methodInfo.annotation(annotation);
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java
index 395cb65dd..25d8571a8 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DotNames.java
@@ -4,13 +4,19 @@
import java.math.BigInteger;
import java.util.List;
import java.util.Set;
+import java.util.concurrent.CompletionStage;
import jakarta.enterprise.inject.Instance;
import org.jboss.jandex.DotName;
+import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.listener.ChatModelListener;
+import io.smallrye.common.annotation.Blocking;
+import io.smallrye.common.annotation.NonBlocking;
+import io.smallrye.common.annotation.RunOnVirtualThread;
import io.smallrye.mutiny.Multi;
+import io.smallrye.mutiny.Uni;
public class DotNames {
@@ -38,10 +44,16 @@ public class DotNames {
public static final DotName LIST = DotName.createSimple(List.class);
public static final DotName SET = DotName.createSimple(Set.class);
public static final DotName MULTI = DotName.createSimple(Multi.class);
+ public static final DotName UNI = DotName.createSimple(Uni.class);
+ public static final DotName BLOCKING = DotName.createSimple(Blocking.class);
+ public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class);
+ public static final DotName COMPLETION_STAGE = DotName.createSimple(CompletionStage.class);
+ public static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class);
public static final DotName OBJECT = DotName.createSimple(Object.class.getName());
public static final DotName RECORD = DotName.createSimple(Record.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);
public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
+ public static final DotName TOOL = DotName.createSimple(Tool.class);
}
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java
index 7201e1835..6493730f9 100644
--- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java
@@ -8,6 +8,12 @@
import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums;
+import static io.quarkiverse.langchain4j.deployment.DotNames.BLOCKING;
+import static io.quarkiverse.langchain4j.deployment.DotNames.COMPLETION_STAGE;
+import static io.quarkiverse.langchain4j.deployment.DotNames.MULTI;
+import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING;
+import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD;
+import static io.quarkiverse.langchain4j.deployment.DotNames.UNI;
import static java.util.Arrays.stream;
import static java.util.stream.Collectors.toList;
@@ -15,10 +21,13 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Set;
import java.util.function.BiFunction;
+import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.jboss.jandex.AnnotationInstance;
@@ -41,6 +50,7 @@
import dev.langchain4j.agent.tool.ToolMemoryId;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
+import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
@@ -61,6 +71,7 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
+import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem;
import io.quarkus.deployment.recording.RecorderContext;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
@@ -74,12 +85,14 @@ public class ToolProcessor {
private static final DotName TOOL = DotName.createSimple(Tool.class);
private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class);
+
private static final DotName P = DotName.createSimple(dev.langchain4j.agent.tool.P.class);
private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor
.ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class);
private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class);
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
Object.class);
+
private static final Logger log = Logger.getLogger(ToolProcessor.class);
@BuildStep
@@ -91,7 +104,9 @@ public void telemetry(Capabilities capabilities, BuildProducer toolMethodBuildItemProducer,
+ CombinedIndexBuildItem indexBuildItem,
BuildProducer additionalBeanProducer,
BuildProducer transformerProducer,
BuildProducer generatedClassProducer,
@@ -107,6 +122,8 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem,
List generatedInvokerClasses = new ArrayList<>();
List generatedArgumentMapperClasses = new ArrayList<>();
+ Set toolsNames = new HashSet<>();
+
if (!instances.isEmpty()) {
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);
@@ -223,7 +240,17 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem,
ToolSpecification toolSpecification = builder.build();
ToolMethodCreateInfo methodCreateInfo = new ToolMethodCreateInfo(
toolMethod.name(), invokerClassName,
- toolSpecification, argumentMapperClassName);
+ toolSpecification, argumentMapperClassName, determineExecutionModel(toolMethod));
+
+ validateExecutionModel(methodCreateInfo, toolMethod, validation);
+
+ if (toolsNames.add(toolName)) {
+ toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo));
+ } else {
+ validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
+ new IllegalStateException("A tool with the name '" + toolName
+ + "' is already declared. Tools method name must be unique.")));
+ }
metadata.computeIfAbsent(className.toString(), (c) -> new ArrayList<>()).add(methodCreateInfo);
@@ -248,13 +275,49 @@ public void handleTools(CombinedIndexBuildItem indexBuildItem,
toolsMetadataProducer.produce(new ToolsMetadataBeforeRemovalBuildItem(metadata));
}
+ @BuildStep
+ ExecutionModelAnnotationsAllowedBuildItem toolsMethods() {
+ return new ExecutionModelAnnotationsAllowedBuildItem(new Predicate() {
+ @Override
+ public boolean test(MethodInfo method) {
+ return method.hasDeclaredAnnotation(DotNames.TOOL);
+ }
+ });
+ }
+
+ private void validateExecutionModel(ToolMethodCreateInfo methodCreateInfo, MethodInfo toolMethod,
+ BuildProducer validation) {
+ String methodName = toolMethod.declaringClass().name() + "." + toolMethod.name();
+
+ if (MULTI.equals(toolMethod.returnType().name())) {
+ validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
+ new Exception("Method " + methodName + " returns Multi, which is not supported for tools")));
+ }
+
+ if (methodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.VIRTUAL_THREAD) {
+ // We can't use Uni or CS with virtual thread
+ if (UNI.equals(toolMethod.returnType().name())) {
+ validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
+ new Exception("Method " + methodName
+ + " returns Uni, which is not supported with @RunOnVirtualThread for tools")));
+ }
+ if (COMPLETION_STAGE.equals(toolMethod.returnType().name())) {
+ validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
+ new Exception("Method " + methodName
+ + " returns CompletionStage, which is not supported with @RunOnVirtualThread for tools")));
+ }
+ }
+
+ }
+
/**
* Transforms ToolsMetadataBeforeRemovalBuildItem into ToolsMetadataBuildItem by filtering
* out tools belonging to beans that have been removed by ArC.
*/
@BuildStep
@Record(ExecutionTime.STATIC_INIT)
- public ToolsMetadataBuildItem filterOutRemovedTools(ToolsMetadataBeforeRemovalBuildItem beforeRemoval,
+ public ToolsMetadataBuildItem filterOutRemovedTools(
+ ToolsMetadataBeforeRemovalBuildItem beforeRemoval,
ValidationPhaseBuildItem validationPhase,
RecorderContext recorderContext,
ToolsRecorder recorder) {
@@ -552,4 +615,20 @@ public ClassVisitor apply(String className, ClassVisitor classVisitor) {
return transformer.applyTo(classVisitor);
}
}
+
+ private ToolMethodCreateInfo.ExecutionModel determineExecutionModel(MethodInfo methodInfo) {
+ if (methodInfo.hasAnnotation(BLOCKING)) {
+ return ToolMethodCreateInfo.ExecutionModel.BLOCKING;
+ }
+ Type returnedType = methodInfo.returnType();
+ if (methodInfo.hasAnnotation(NON_BLOCKING)
+ || UNI.equals(returnedType.name()) || COMPLETION_STAGE.equals(returnedType.name())
+ || MULTI.equals(returnedType.name())) {
+ return ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING;
+ }
+ if (methodInfo.hasAnnotation(RUN_ON_VIRTUAL_THREAD)) {
+ return ToolMethodCreateInfo.ExecutionModel.VIRTUAL_THREAD;
+ }
+ return ToolMethodCreateInfo.ExecutionModel.BLOCKING;
+ }
}
diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java
new file mode 100644
index 000000000..9c1f25b34
--- /dev/null
+++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/items/ToolMethodBuildItem.java
@@ -0,0 +1,53 @@
+package io.quarkiverse.langchain4j.deployment.items;
+
+import org.jboss.jandex.MethodInfo;
+
+import io.quarkiverse.langchain4j.deployment.DotNames;
+import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
+import io.quarkus.builder.item.MultiBuildItem;
+
+/**
+ * A build item that represents a method that is annotated with {@link dev.langchain4j.agent.tool.Tool}.
+ * It contains the method info and the tool method create info.
+ */
+public final class ToolMethodBuildItem extends MultiBuildItem {
+
+ private final MethodInfo toolsMethodInfo;
+
+ private final ToolMethodCreateInfo toolMethodCreateInfo;
+
+ public ToolMethodBuildItem(MethodInfo toolsMethodInfo, ToolMethodCreateInfo toolMethodCreateInfo) {
+ this.toolsMethodInfo = toolsMethodInfo;
+ this.toolMethodCreateInfo = toolMethodCreateInfo;
+ }
+
+ public MethodInfo getToolsMethodInfo() {
+ return toolsMethodInfo;
+ }
+
+ public String getDeclaringClassName() {
+ return toolsMethodInfo.declaringClass().name().toString();
+ }
+
+ public ToolMethodCreateInfo getToolMethodCreateInfo() {
+ return toolMethodCreateInfo;
+ }
+
+ /**
+ * Returns true if the method requires a switch to a worker thread, even if the method is non-blocking.
+ * This is because of the tools executor limitation (imperative API).
+ *
+ * @return true if the method requires a switch to a worker thread
+ */
+ public boolean requiresSwitchToWorkerThread() {
+ return !(toolMethodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING
+ && isImperativeMethod());
+ }
+
+ private boolean isImperativeMethod() {
+ var type = toolsMethodInfo.returnType();
+ return !DotNames.UNI.equals(type.name())
+ && !DotNames.MULTI.equals(type.name())
+ && !DotNames.COMPLETION_STAGE.equals(type.name());
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java
new file mode 100644
index 000000000..850b1ad57
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/Lists.java
@@ -0,0 +1,13 @@
+package io.quarkiverse.langchain4j.test;
+
+import java.util.List;
+
+public class Lists {
+
+ public static T last(List list) {
+ if (list.isEmpty()) {
+ return null;
+ }
+ return list.get(list.size() - 1);
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java
index f68de2ff5..feb6d3e3e 100644
--- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/ToolExecutorTest.java
@@ -291,7 +291,7 @@ private ToolExecutor getToolExecutor(String methodName) {
toolSpecification.name())) { // this only works because TestTool does not contain overloaded methods
toolExecutor = new QuarkusToolExecutor(
new QuarkusToolExecutor.Context(testTool, invokerClassName, methodCreateInfo.methodName(),
- methodCreateInfo.argumentMapperClassName()));
+ methodCreateInfo.argumentMapperClassName(), methodCreateInfo.executionModel()));
break;
}
}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java
new file mode 100644
index 000000000..ad9044c42
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelTest.java
@@ -0,0 +1,340 @@
+package io.quarkiverse.langchain4j.test.tools;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledForJreRange;
+import org.junit.jupiter.api.condition.JRE;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.testcontainers.shaded.org.awaitility.Awaitility;
+
+import dev.langchain4j.agent.tool.Tool;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.output.FinishReason;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.ToolBox;
+import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
+import io.quarkiverse.langchain4j.test.Lists;
+import io.quarkus.arc.Arc;
+import io.quarkus.test.QuarkusUnitTest;
+import io.quarkus.virtual.threads.VirtualThreadsRecorder;
+import io.smallrye.common.annotation.NonBlocking;
+import io.smallrye.common.annotation.RunOnVirtualThread;
+import io.smallrye.mutiny.Uni;
+import io.vertx.core.Vertx;
+
+public class ToolExecutionModelTest {
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class));
+
+ @Inject
+ MyAiService aiService;
+
+ @Inject
+ Vertx vertx;
+
+ @Test
+ @ActivateRequestContext
+ void testBlockingToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.hello("abc", "hi - " + uuid);
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testBlockingToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.hello("abc", "hi - " + uuid);
+ } catch (IllegalStateException e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> failure.get() != null);
+ assertThat(failure.get()).hasMessageContaining("Cannot execute blocking tools on event loop thread");
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hi - " + uuid);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.hello("abc", "hiNonBlocking - " + uuid);
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+ ;
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ caller.set(Thread.currentThread().getName());
+ Arc.container().requestContext().activate();
+ result.set(aiService.hello("abc", "hiNonBlocking - " + uuid));
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(uuid, caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hiNonBlocking - " + uuid);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.hello("abc", "hiUni - " + uuid);
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.hello("abc", "hiUni - " + uuid);
+ } catch (Exception e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> failure.get() != null);
+ assertThat(failure.get()).hasMessageContaining("Cannot execute tools returning Uni on event loop thread");
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hiUni - " + uuid);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ @ActivateRequestContext
+ void testToolInvocationOnVirtualThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.hello("abc", "hiVirtualThread - " + uuid);
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-");
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hiVirtualThread - " + uuid);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // At the moment, we create a virtual thread every time.
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .doesNotContain(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testToolInvocationOnVirtualThreadFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.hello("abc", "hiVirtualThread - " + uuid);
+ } catch (IllegalStateException e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> failure.get() != null);
+ assertThat(failure.get()).hasMessageContaining("Cannot execute virtual thread tools on event loop thread");
+ }
+
+ @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+ public interface MyAiService {
+
+ @ToolBox(MyTool.class)
+ String hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+ }
+
+ @Singleton
+ public static class MyTool {
+ @Tool
+ public String hi(String m) {
+ return m + " " + Thread.currentThread();
+ }
+
+ @Tool
+ @NonBlocking
+ public String hiNonBlocking(String m) {
+ return m + " " + Thread.currentThread();
+ }
+
+ @Tool
+ public Uni hiUni(String m) {
+ return Uni.createFrom().item(() -> m + " " + Thread.currentThread());
+ }
+
+ @Tool
+ @RunOnVirtualThread
+ public String hiVirtualThread(String m) {
+ return m + " " + Thread.currentThread();
+ }
+ }
+
+ public static class MyChatModelSupplier implements Supplier {
+
+ @Override
+ public ChatLanguageModel get() {
+ return new MyChatModel();
+ }
+ }
+
+ public static class MyChatModel implements ChatLanguageModel {
+
+ @Override
+ public Response generate(List messages) {
+ throw new UnsupportedOperationException("Should not be called");
+ }
+
+ @Override
+ public Response generate(List messages, List toolSpecifications) {
+ if (messages.size() == 1) {
+ // Only the user message, extract the tool id from it
+ String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText();
+ var segments = text.split(" - ");
+ var toolId = segments[0];
+ var content = segments[1];
+ // Only the user message
+ return new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder()
+ .id("my-tool-" + toolId)
+ .name(toolId)
+ .arguments("{\"m\":\"" + content + "\"}")
+ .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION);
+ } else if (messages.size() == 3) {
+ // user -> tool request -> tool response
+ ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages);
+ return new Response<>(AiMessage.from("response: " + last.text()));
+
+ }
+ return new Response<>(new AiMessage("Unexpected"));
+ }
+ }
+
+ public static class MyMemoryProviderSupplier implements Supplier {
+ @Override
+ public ChatMemoryProvider get() {
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return new NoopChatMemory();
+ }
+ };
+ }
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java
new file mode 100644
index 000000000..116efcff4
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingAndRequestScopePropagationTest.java
@@ -0,0 +1,427 @@
+package io.quarkiverse.langchain4j.test.tools;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.RequestScoped;
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledForJreRange;
+import org.junit.jupiter.api.condition.JRE;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.testcontainers.shaded.org.awaitility.Awaitility;
+
+import dev.langchain4j.agent.tool.Tool;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.FinishReason;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.ToolBox;
+import io.quarkiverse.langchain4j.test.Lists;
+import io.quarkus.arc.Arc;
+import io.quarkus.test.QuarkusUnitTest;
+import io.quarkus.test.vertx.RunOnVertxContext;
+import io.quarkus.virtual.threads.VirtualThreadsRecorder;
+import io.smallrye.common.annotation.NonBlocking;
+import io.smallrye.common.annotation.RunOnVirtualThread;
+import io.smallrye.common.vertx.VertxContext;
+import io.smallrye.mutiny.Multi;
+import io.smallrye.mutiny.Uni;
+import io.vertx.core.Vertx;
+
+public class ToolExecutionModelWithStreamingAndRequestScopePropagationTest {
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class));
+
+ @Inject
+ MyAiService aiService;
+
+ @Inject
+ Vertx vertx;
+
+ @Inject
+ UUIDGenerator uuidGenerator;
+
+ @Test
+ @ActivateRequestContext
+ void testBlockingToolInvocationFromWorkerThread() {
+ String uuid = uuidGenerator.get();
+ var r = aiService.hello("abc", "hi")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ void testBlockingToolInvocationFromEventLoop() {
+ AtomicReference failure = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference value = new AtomicReference<>();
+ var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx);
+ ctxt.runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ String uuid = uuidGenerator.get();
+ value.set(uuid);
+ aiService.hello("abc", "hi")
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage()
+ .thenAccept(result::set)
+ .whenComplete((r, t) -> Arc.container().requestContext().deactivate());
+ } catch (IllegalStateException e) {
+ failure.set(e);
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ // We would automatically detect this case and switch to a worker thread at subscription time.
+
+ Awaitility.await().until(() -> failure.get() != null || result.get() != null);
+ assertThat(failure.get()).isNull();
+ assertThat(result.get()).doesNotContain("event", "loop")
+ .contains(value.get(), "executor-thread");
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ AtomicReference value = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ value.set(uuidGenerator.get());
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hi")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(value.get(), "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromWorkerThread() {
+ String uuid = uuidGenerator.get();
+ var r = aiService.helloNonBlocking("abc", "hiNonBlocking")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromEventLoop() {
+ AtomicReference value = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+
+ var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx);
+ ctxt.runOnContext(x -> {
+ caller.set(Thread.currentThread().getName());
+ Arc.container().requestContext().activate();
+ value.set(uuidGenerator.get());
+ aiService.helloNonBlocking("abc", "hiNonBlocking")
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set)
+ .whenComplete((r, t) -> Arc.container().requestContext().deactivate());
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(value.get(), caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() {
+ AtomicReference value = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+
+ var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx);
+ ctxt.runOnContext(x -> {
+ caller.set(Thread.currentThread().getName());
+ Arc.container().requestContext().activate();
+ value.set(uuidGenerator.get());
+ aiService.helloNonBlockingWithSwitch("abc", "hiNonBlocking")
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set)
+ .whenComplete((r, t) -> Arc.container().requestContext().deactivate());
+
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(value.get(), "executor-thread")
+ .doesNotContain(caller.get());
+ }
+
+ @Test
+ @RunOnVertxContext
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = uuidGenerator.get();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloNonBlocking("abc", "hiNonBlocking")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromWorkerThread() {
+ String uuid = uuidGenerator.get();
+ var r = aiService.helloUni("abc", "hiUni")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromEventLoop() {
+ AtomicReference value = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx);
+ ctxt.runOnContext(x -> {
+ Arc.container().requestContext().activate();
+ value.set(uuidGenerator.get());
+ aiService.helloUni("abc", "hiUni")
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage()
+ .thenAccept(result::set)
+ .whenComplete((r, t) -> Arc.container().requestContext().deactivate());
+
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(value.get(), "executor-thread");
+ }
+
+ @Test
+ @RunOnVertxContext
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = uuidGenerator.get();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloUni("abc", "hiUni")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ }).get();
+
+ // The blocking tool is executed on the same thread (synchronous emission)
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ @RunOnVertxContext(runOnEventLoop = false)
+ @ActivateRequestContext
+ void testToolInvocationOnVirtualThread() {
+ String uuid = uuidGenerator.get();
+ var r = aiService.helloVirtualTools("abc", "hiVirtualThread")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-");
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ @RunOnVertxContext
+ @ActivateRequestContext
+ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = uuidGenerator.get();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloVirtualTools("abc", "hiVirtualThread")
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ }).get();
+
+ // At the moment, we create a virtual thread every time.
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .doesNotContain(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ @ActivateRequestContext
+ void testToolInvocationOnVirtualThreadFromEventLoop() {
+ AtomicReference value = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ var ctxt = VertxContext.getOrCreateDuplicatedContext(vertx);
+ ctxt.runOnContext(x -> {
+ Arc.container().requestContext().activate();
+ value.set(uuidGenerator.get());
+ aiService.helloVirtualTools("abc", "hiVirtualThread")
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set)
+ .whenComplete((r, t) -> Arc.container().requestContext().deactivate());
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(value.get(), "quarkus-virtual-thread-");
+ }
+
+ @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+ public interface MyAiService {
+
+ @ToolBox(BlockingTool.class)
+ Multi hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(NonBlockingTool.class)
+ Multi helloNonBlocking(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox({ NonBlockingTool.class, BlockingTool.class })
+ Multi helloNonBlockingWithSwitch(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(UniTool.class)
+ Multi helloUni(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(VirtualTool.class)
+ Multi helloVirtualTools(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+ }
+
+ @Singleton
+ public static class BlockingTool {
+ @Inject
+ UUIDGenerator uuidGenerator;
+
+ @Tool
+ public String hi() {
+ return uuidGenerator.get() + " " + Thread.currentThread();
+ }
+ }
+
+ @Singleton
+ public static class NonBlockingTool {
+ @Inject
+ UUIDGenerator uuidGenerator;
+
+ @Tool
+ @NonBlocking
+ public String hiNonBlocking() {
+ return uuidGenerator.get() + " " + Thread.currentThread();
+ }
+ }
+
+ @Singleton
+ public static class UniTool {
+ @Inject
+ UUIDGenerator uuidGenerator;
+
+ @Tool
+ public Uni hiUni() {
+ return Uni.createFrom().item(() -> uuidGenerator.get() + " " + Thread.currentThread());
+ }
+ }
+
+ @Singleton
+ public static class VirtualTool {
+
+ @Inject
+ UUIDGenerator uuidGenerator;
+
+ @Tool
+ @RunOnVirtualThread
+ public String hiVirtualThread() {
+ return uuidGenerator.get() + " " + Thread.currentThread();
+ }
+ }
+
+ public static class MyChatModelSupplier implements Supplier {
+
+ @Override
+ public StreamingChatLanguageModel get() {
+ return new MyChatModel();
+ }
+ }
+
+ public static class MyChatModel implements StreamingChatLanguageModel {
+
+ @Override
+ public void generate(List messages, StreamingResponseHandler handler) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void generate(List messages, List toolSpecifications,
+ StreamingResponseHandler handler) {
+ if (messages.size() == 1) {
+ // Only the user message, extract the tool id from it
+ String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText();
+ // Only the user message
+ handler.onComplete(new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder()
+ .id("my-tool-" + text)
+ .name(text)
+ .arguments("{}")
+ .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION));
+ } else if (messages.size() == 3) {
+ // user -> tool request -> tool response
+ ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages);
+ handler.onNext("response: ");
+ handler.onNext(last.text());
+ handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP));
+
+ } else {
+ handler.onError(new RuntimeException("Invalid number of messages"));
+ }
+ }
+
+ }
+
+ public static class MyMemoryProviderSupplier implements Supplier {
+ @Override
+ public ChatMemoryProvider get() {
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return MessageWindowChatMemory.withMaxMessages(10);
+ }
+ };
+ }
+ }
+
+ @RequestScoped
+ public static class UUIDGenerator {
+ private final String uuid = UUID.randomUUID().toString();
+
+ public String get() {
+ return uuid;
+ }
+ }
+}
diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java
new file mode 100644
index 000000000..0b8092018
--- /dev/null
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/tools/ToolExecutionModelWithStreamingTest.java
@@ -0,0 +1,420 @@
+package io.quarkiverse.langchain4j.test.tools;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Supplier;
+
+import jakarta.enterprise.context.control.ActivateRequestContext;
+import jakarta.inject.Inject;
+import jakarta.inject.Singleton;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledForJreRange;
+import org.junit.jupiter.api.condition.JRE;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.testcontainers.shaded.org.awaitility.Awaitility;
+
+import dev.langchain4j.agent.tool.Tool;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.agent.tool.ToolSpecification;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.ChatMemoryProvider;
+import dev.langchain4j.memory.chat.MessageWindowChatMemory;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.output.FinishReason;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.output.TokenUsage;
+import dev.langchain4j.service.MemoryId;
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.ToolBox;
+import io.quarkiverse.langchain4j.test.Lists;
+import io.quarkus.arc.Arc;
+import io.quarkus.test.QuarkusUnitTest;
+import io.quarkus.virtual.threads.VirtualThreadsRecorder;
+import io.smallrye.common.annotation.NonBlocking;
+import io.smallrye.common.annotation.RunOnVirtualThread;
+import io.smallrye.mutiny.Multi;
+import io.smallrye.mutiny.Uni;
+import io.vertx.core.Vertx;
+
+public class ToolExecutionModelWithStreamingTest {
+
+ @RegisterExtension
+ static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+ .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class).addClasses(MyAiService.class, Lists.class));
+
+ @Inject
+ MyAiService aiService;
+
+ @Inject
+ Vertx vertx;
+
+ @Test
+ @ActivateRequestContext
+ void testBlockingToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.hello("abc", "hi - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testBlockingToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.hello("abc", "hi - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage()
+ .thenAccept(result::set);
+ } catch (IllegalStateException e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ // We would automatically detect this case and switch to a worker thread at subscription time.
+
+ Awaitility.await().until(() -> failure.get() != null || result.get() != null);
+ assertThat(failure.get()).isNull();
+ assertThat(result.get()).doesNotContain("event", "loop")
+ .contains(uuid, "executor-thread");
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.hello("abc", "hi - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ caller.set(Thread.currentThread().getName());
+ Arc.container().requestContext().activate();
+ aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(uuid, caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testNonBlockingToolInvocationFromEventLoopWhenWeSwitchToWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference result = new AtomicReference<>();
+ AtomicReference caller = new AtomicReference<>();
+
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ caller.set(Thread.currentThread().getName());
+ Arc.container().requestContext().activate();
+ aiService.helloNonBlockingWithSwitch("abc", "hiNonBlocking - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> result.get() != null);
+ assertThat(result.get()).contains(uuid, "executor-thread")
+ .doesNotContain(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testNonBlockingToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloNonBlocking("abc", "hiNonBlocking - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromWorkerThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.helloUni("abc", "hiUni - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, Thread.currentThread().getName()); // We are invoked on the same thread
+ }
+
+ @Test
+ @ActivateRequestContext
+ void testUniToolInvocationFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.helloUni("abc", "hiUni - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage()
+ .thenAccept(result::set);
+ } catch (Exception e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> failure.get() != null || result.get() != null);
+ assertThat(failure.get()).isNull();
+ assertThat(result.get()).contains(uuid, "executor-thread");
+ }
+
+ @Test
+ @ActivateRequestContext
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testUniToolInvocationFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloUni("abc", "hiUni - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // The blocking tool is executed on the same thread (synchronous emission)
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .contains(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ @ActivateRequestContext
+ void testToolInvocationOnVirtualThread() {
+ String uuid = UUID.randomUUID().toString();
+ var r = aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-");
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testToolInvocationOnVirtualThreadFromVirtualThread() throws ExecutionException, InterruptedException {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference caller = new AtomicReference<>();
+ var r = VirtualThreadsRecorder.getCurrent().submit(() -> {
+ try {
+ Arc.container().requestContext().activate();
+ caller.set(Thread.currentThread().getName());
+ return aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l)).await().indefinitely();
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ }).get();
+
+ // At the moment, we create a virtual thread every time.
+ assertThat(r).contains(uuid, "quarkus-virtual-thread-")
+ .doesNotContain(caller.get());
+ }
+
+ @Test
+ @EnabledForJreRange(min = JRE.JAVA_21)
+ void testToolInvocationOnVirtualThreadFromEventLoop() {
+ String uuid = UUID.randomUUID().toString();
+ AtomicReference failure = new AtomicReference<>();
+ AtomicReference result = new AtomicReference<>();
+ vertx.getOrCreateContext().runOnContext(x -> {
+ try {
+ Arc.container().requestContext().activate();
+ aiService.helloVirtualTools("abc", "hiVirtualThread - " + uuid)
+ .collect().asList().map(l -> String.join(" ", l))
+ .subscribeAsCompletionStage().thenAccept(result::set);
+ } catch (IllegalStateException e) {
+ failure.set(e);
+ } finally {
+ Arc.container().requestContext().deactivate();
+ }
+ });
+
+ Awaitility.await().until(() -> failure.get() != null || result.get() != null);
+ assertThat(failure.get()).isNull();
+ assertThat(result.get()).contains(uuid, "quarkus-virtual-thread-");
+ }
+
+ @RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
+ public interface MyAiService {
+
+ @ToolBox(BlockingTool.class)
+ Multi hello(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(NonBlockingTool.class)
+ Multi helloNonBlocking(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox({ NonBlockingTool.class, BlockingTool.class })
+ Multi helloNonBlockingWithSwitch(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(UniTool.class)
+ Multi helloUni(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+
+ @ToolBox(VirtualTool.class)
+ Multi helloVirtualTools(@MemoryId String memoryId, @UserMessage String userMessageContainingTheToolId);
+ }
+
+ @Singleton
+ public static class BlockingTool {
+ @Tool
+ public String hi(String m) {
+ return m + " " + Thread.currentThread();
+ }
+ }
+
+ @Singleton
+ public static class NonBlockingTool {
+ @Tool
+ @NonBlocking
+ public String hiNonBlocking(String m) {
+ return m + " " + Thread.currentThread();
+ }
+ }
+
+ @Singleton
+ public static class UniTool {
+ @Tool
+ public Uni hiUni(String m) {
+ return Uni.createFrom().item(() -> m + " " + Thread.currentThread());
+ }
+ }
+
+ @Singleton
+ public static class VirtualTool {
+
+ @Tool
+ @RunOnVirtualThread
+ public String hiVirtualThread(String m) {
+ return m + " " + Thread.currentThread();
+ }
+ }
+
+ public static class MyChatModelSupplier implements Supplier {
+
+ @Override
+ public StreamingChatLanguageModel get() {
+ return new MyChatModel();
+ }
+ }
+
+ public static class MyChatModel implements StreamingChatLanguageModel {
+
+ @Override
+ public void generate(List messages, StreamingResponseHandler handler) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void generate(List messages, List toolSpecifications,
+ StreamingResponseHandler handler) {
+ if (messages.size() == 1) {
+ // Only the user message, extract the tool id from it
+ String text = ((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText();
+ var segments = text.split(" - ");
+ var toolId = segments[0];
+ var content = segments[1];
+ // Only the user message
+ handler.onComplete(new Response<>(new AiMessage("cannot be blank", List.of(ToolExecutionRequest.builder()
+ .id("my-tool-" + toolId)
+ .name(toolId)
+ .arguments("{\"m\":\"" + content + "\"}")
+ .build())), new TokenUsage(0, 0), FinishReason.TOOL_EXECUTION));
+ } else if (messages.size() == 3) {
+ // user -> tool request -> tool response
+ ToolExecutionResultMessage last = (ToolExecutionResultMessage) Lists.last(messages);
+ handler.onNext("response: ");
+ handler.onNext(last.text());
+ handler.onComplete(new Response<>(new AiMessage(""), new TokenUsage(0, 0), FinishReason.STOP));
+
+ } else {
+ handler.onError(new RuntimeException("Invalid number of messages: " + messages.size()));
+ }
+ }
+
+ }
+
+ public static class MyMemoryProviderSupplier implements Supplier {
+ @Override
+ public ChatMemoryProvider get() {
+ return new ChatMemoryProvider() {
+ @Override
+ public ChatMemory get(Object memoryId) {
+ return MessageWindowChatMemory.withMaxMessages(10);
+ }
+ };
+ }
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java
index 61b63a6af..5e2a4720b 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/ToolsRecorder.java
@@ -51,7 +51,7 @@ public static void populateToolMetadata(List