diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index 08057d89f..30ee9a92b 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -811,7 +811,6 @@ public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread( throw new RuntimeException("No tools detected in " + classname); } } - return requireSwitchToWorkerThread; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 16384449f..6493730f9 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -21,9 +21,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiFunction; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -120,6 +122,8 @@ public void handleTools( List generatedInvokerClasses = new ArrayList<>(); List generatedArgumentMapperClasses = new ArrayList<>(); + Set toolsNames = new HashSet<>(); + if (!instances.isEmpty()) { ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true); @@ -240,7 +244,13 @@ public void handleTools( validateExecutionModel(methodCreateInfo, toolMethod, validation); - toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo)); + if (toolsNames.add(toolName)) { + toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo)); + } else { + validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem( + new IllegalStateException("A tool with the name '" + toolName + + "' is already declared. Tools method name must be unique."))); + } metadata.computeIfAbsent(className.toString(), (c) -> new ArrayList<>()).add(methodCreateInfo); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 70479fffc..111af826c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -797,7 +797,7 @@ private void createTokenStream(UnicastProcessor processor) { } var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications, - toolsExecutors, contents, context, memoryId, ctxt); + toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread); TokenStream tokenStream = stream .onNext(processor::onNext) .onComplete(message -> processor.onComplete()) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java index 8b103d09a..b76b98d04 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceStreamingResponseHandler.java @@ -21,6 +21,7 @@ import dev.langchain4j.model.output.TokenUsage; import dev.langchain4j.service.AiServiceContext; import dev.langchain4j.service.tool.ToolExecutor; +import io.smallrye.mutiny.infrastructure.Infrastructure; import io.vertx.core.Context; /** @@ -45,6 +46,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon private final List toolSpecifications; private final Map toolExecutors; private final Context executionContext; + private final boolean mustSwitchToWorkerThread; QuarkusAiServiceStreamingResponseHandler(AiServiceContext context, Object memoryId, @@ -54,7 +56,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon List temporaryMemory, TokenUsage tokenUsage, List toolSpecifications, - Map toolExecutors, Context cxtx) { + Map toolExecutors, boolean mustSwitchToWorkerThread, Context cxtx) { this.context = ensureNotNull(context, "context"); this.memoryId = ensureNotNull(memoryId, "memoryId"); @@ -68,6 +70,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon this.toolSpecifications = copyIfNotNull(toolSpecifications); this.toolExecutors = copyIfNotNull(toolExecutors); + this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; this.executionContext = cxtx; } @@ -77,14 +80,19 @@ public void onNext(String token) { } private void executeTools(Runnable runnable) { - if (executionContext != null && Context.isOnEventLoopThread()) { - executionContext.executeBlocking(new Callable() { - @Override - public Object call() throws Exception { - runnable.run(); - return null; - } - }); + if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) { + if (executionContext != null) { + executionContext.executeBlocking(new Callable() { + @Override + public Object call() { + runnable.run(); + return null; + } + }); + } else { + // We do not have a context, switching to worker thread. + Infrastructure.getDefaultWorkerPool().execute(runnable); + } } else { runnable.run(); } @@ -125,7 +133,7 @@ public void run() { TokenUsage.sum(tokenUsage, response.tokenUsage()), toolSpecifications, toolExecutors, - executionContext)); + mustSwitchToWorkerThread, executionContext)); } }); } else { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java index 820e5ae65..c81f35315 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceTokenStream.java @@ -38,6 +38,7 @@ public class QuarkusAiServiceTokenStream implements TokenStream { private final AiServiceContext context; private final Object memoryId; private final Context cxtx; + private final boolean mustSwitchToWorkerThread; private Consumer tokenHandler; private Consumer> contentsHandler; @@ -55,7 +56,7 @@ public QuarkusAiServiceTokenStream(List messages, Map toolExecutors, List retrievedContents, AiServiceContext context, - Object memoryId, Context ctxt) { + Object memoryId, Context ctxt, boolean mustSwitchToWorkerThread) { this.messages = ensureNotEmpty(messages, "messages"); this.toolSpecifications = copyIfNotNull(toolSpecifications); this.toolExecutors = copyIfNotNull(toolExecutors); @@ -63,7 +64,8 @@ public QuarkusAiServiceTokenStream(List messages, this.context = ensureNotNull(context, "context"); this.memoryId = ensureNotNull(memoryId, "memoryId"); ensureNotNull(context.streamingChatModel, "streamingChatModel"); - this.cxtx = ctxt; // If set, it means we need to switch to a worker thread to execute tools. + this.cxtx = ctxt; // If set, it means we need to handle the context propagation. + this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools. } @Override @@ -114,6 +116,7 @@ public void start() { new TokenUsage(), toolSpecifications, toolExecutors, + mustSwitchToWorkerThread, cxtx); if (contentsHandler != null && retrievedContents != null) { diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java index ae15843a5..0012a3779 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/chat/ChatLanguageModelResource.java @@ -116,13 +116,11 @@ public String memory() throws Exception { @Override public void onNext(String token) { - // System.out.print(token); sb.append(token); } @Override public void onComplete(Response response) { - // System.out.println("\n\ndone\n\n"); if (response != null) { futureRef.get().complete(response.content()); } else { @@ -136,7 +134,6 @@ public void onError(Throwable throwable) { } }; - // System.out.println("starting first\n\n"); streamingChatLanguageModel.generate(chatMemory.messages(), handler); AiMessage firstAiMessage = futureRef.get().get(60, TimeUnit.SECONDS); chatMemory.add(firstAiMessage); @@ -151,7 +148,6 @@ public void onError(Throwable throwable) { .append("\n[LLM]: "); futureRef.set(new CompletableFuture<>()); - // System.out.println("\n\n\nstarting second\n\n"); streamingChatLanguageModel.generate(chatMemory.messages(), handler); futureRef.get().get(60, TimeUnit.SECONDS); return sb.toString(); diff --git a/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java b/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java index f05efc882..82fc92ea0 100644 --- a/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java +++ b/tools/tavily/deployment/src/test/java/io/quarkiverse/langchain4j/tavily/test/TavilyTest.java @@ -104,7 +104,6 @@ public void testSearch() { .build(); WebSearchResults result = webSearchEngine.search(searchRequest); LoggedRequest actualRequest = singleLoggedRequest(); - System.out.println(actualRequest); // verify the request assertEquals(actualRequest.getHeader("Accept"), "application/json");