Skip to content

Commit

Permalink
Handle thread switch for tools execution when the initial call in not…
Browse files Browse the repository at this point in the history
… done from a Vert.x context
  • Loading branch information
cescoffier committed Nov 6, 2024
1 parent 00cd5ef commit bc7d6c0
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,6 @@ public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(
throw new RuntimeException("No tools detected in " + classname);
}
}

return requireSwitchToWorkerThread;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -120,6 +122,8 @@ public void handleTools(
List<String> generatedInvokerClasses = new ArrayList<>();
List<String> generatedArgumentMapperClasses = new ArrayList<>();

Set<String> toolsNames = new HashSet<>();

if (!instances.isEmpty()) {
ClassOutput classOutput = new GeneratedClassGizmoAdaptor(generatedClassProducer, true);

Expand Down Expand Up @@ -240,7 +244,13 @@ public void handleTools(

validateExecutionModel(methodCreateInfo, toolMethod, validation);

toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo));
if (toolsNames.add(toolName)) {
toolMethodBuildItemProducer.produce(new ToolMethodBuildItem(toolMethod, methodCreateInfo));
} else {
validation.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new IllegalStateException("A tool with the name '" + toolName
+ "' is already declared. Tools method name must be unique.")));
}

metadata.computeIfAbsent(className.toString(), (c) -> new ArrayList<>()).add(methodCreateInfo);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ private void createTokenStream(UnicastProcessor<String> processor) {
}

var stream = new QuarkusAiServiceTokenStream(messagesToSend, toolSpecifications,
toolsExecutors, contents, context, memoryId, ctxt);
toolsExecutors, contents, context, memoryId, ctxt, mustSwitchToWorkerThread);
TokenStream tokenStream = stream
.onNext(processor::onNext)
.onComplete(message -> processor.onComplete())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.tool.ToolExecutor;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.vertx.core.Context;

/**
Expand All @@ -45,6 +46,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
private final List<ToolSpecification> toolSpecifications;
private final Map<String, ToolExecutor> toolExecutors;
private final Context executionContext;
private final boolean mustSwitchToWorkerThread;

QuarkusAiServiceStreamingResponseHandler(AiServiceContext context,
Object memoryId,
Expand All @@ -54,7 +56,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
List<ChatMessage> temporaryMemory,
TokenUsage tokenUsage,
List<ToolSpecification> toolSpecifications,
Map<String, ToolExecutor> toolExecutors, Context cxtx) {
Map<String, ToolExecutor> toolExecutors, boolean mustSwitchToWorkerThread, Context cxtx) {
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");

Expand All @@ -68,6 +70,7 @@ public class QuarkusAiServiceStreamingResponseHandler implements StreamingRespon
this.toolSpecifications = copyIfNotNull(toolSpecifications);
this.toolExecutors = copyIfNotNull(toolExecutors);

this.mustSwitchToWorkerThread = mustSwitchToWorkerThread;
this.executionContext = cxtx;
}

Expand All @@ -77,14 +80,19 @@ public void onNext(String token) {
}

private void executeTools(Runnable runnable) {
if (executionContext != null && Context.isOnEventLoopThread()) {
executionContext.executeBlocking(new Callable<Object>() {
@Override
public Object call() throws Exception {
runnable.run();
return null;
}
});
if (mustSwitchToWorkerThread && Context.isOnEventLoopThread()) {
if (executionContext != null) {
executionContext.executeBlocking(new Callable<Object>() {
@Override
public Object call() {
runnable.run();
return null;
}
});
} else {
// We do not have a context, switching to worker thread.
Infrastructure.getDefaultWorkerPool().execute(runnable);
}
} else {
runnable.run();
}
Expand Down Expand Up @@ -125,7 +133,7 @@ public void run() {
TokenUsage.sum(tokenUsage, response.tokenUsage()),
toolSpecifications,
toolExecutors,
executionContext));
mustSwitchToWorkerThread, executionContext));
}
});
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class QuarkusAiServiceTokenStream implements TokenStream {
private final AiServiceContext context;
private final Object memoryId;
private final Context cxtx;
private final boolean mustSwitchToWorkerThread;

private Consumer<String> tokenHandler;
private Consumer<List<Content>> contentsHandler;
Expand All @@ -55,15 +56,16 @@ public QuarkusAiServiceTokenStream(List<ChatMessage> messages,
Map<String, ToolExecutor> toolExecutors,
List<Content> retrievedContents,
AiServiceContext context,
Object memoryId, Context ctxt) {
Object memoryId, Context ctxt, boolean mustSwitchToWorkerThread) {
this.messages = ensureNotEmpty(messages, "messages");
this.toolSpecifications = copyIfNotNull(toolSpecifications);
this.toolExecutors = copyIfNotNull(toolExecutors);
this.retrievedContents = retrievedContents;
this.context = ensureNotNull(context, "context");
this.memoryId = ensureNotNull(memoryId, "memoryId");
ensureNotNull(context.streamingChatModel, "streamingChatModel");
this.cxtx = ctxt; // If set, it means we need to switch to a worker thread to execute tools.
this.cxtx = ctxt; // If set, it means we need to handle the context propagation.
this.mustSwitchToWorkerThread = mustSwitchToWorkerThread; // If true, we need to switch to a worker thread to execute tools.
}

@Override
Expand Down Expand Up @@ -114,6 +116,7 @@ public void start() {
new TokenUsage(),
toolSpecifications,
toolExecutors,
mustSwitchToWorkerThread,
cxtx);

if (contentsHandler != null && retrievedContents != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,11 @@ public String memory() throws Exception {

@Override
public void onNext(String token) {
// System.out.print(token);
sb.append(token);
}

@Override
public void onComplete(Response<AiMessage> response) {
// System.out.println("\n\ndone\n\n");
if (response != null) {
futureRef.get().complete(response.content());
} else {
Expand All @@ -136,7 +134,6 @@ public void onError(Throwable throwable) {
}
};

// System.out.println("starting first\n\n");
streamingChatLanguageModel.generate(chatMemory.messages(), handler);
AiMessage firstAiMessage = futureRef.get().get(60, TimeUnit.SECONDS);
chatMemory.add(firstAiMessage);
Expand All @@ -151,7 +148,6 @@ public void onError(Throwable throwable) {
.append("\n[LLM]: ");

futureRef.set(new CompletableFuture<>());
// System.out.println("\n\n\nstarting second\n\n");
streamingChatLanguageModel.generate(chatMemory.messages(), handler);
futureRef.get().get(60, TimeUnit.SECONDS);
return sb.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ public void testSearch() {
.build();
WebSearchResults result = webSearchEngine.search(searchRequest);
LoggedRequest actualRequest = singleLoggedRequest();
System.out.println(actualRequest);

// verify the request
assertEquals(actualRequest.getHeader("Accept"), "application/json");
Expand Down

0 comments on commit bc7d6c0

Please sign in to comment.