Skip to content

Commit

Permalink
Allows methods annotated with @tool to used @Blocking / @nonblocking
Browse files Browse the repository at this point in the history
…and @RunOnVirtualThread
  • Loading branch information
cescoffier committed Nov 5, 2024
1 parent 04c8bd3 commit b33c558
Show file tree
Hide file tree
Showing 17 changed files with 1,146 additions and 19 deletions.
5 changes: 5 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-test-vertx</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
Expand All @@ -31,7 +32,6 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -778,9 +778,9 @@ public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(
MethodInfo method,
List<String> associatedTools,
List<ToolMethodBuildItem> tools) {
boolean reactive = method.returnType().name().equals(DotName.createSimple(Uni.class.getName()))
|| method.returnType().name().equals(DotName.createSimple(CompletionStage.class.getName()))
|| method.returnType().name().equals(DotName.createSimple(Multi.class.getName()));
boolean reactive = method.returnType().name().equals(DotNames.UNI)
|| method.returnType().name().equals(DotNames.COMPLETION_STAGE)
|| method.returnType().name().equals(DotNames.MULTI);

boolean requireSwitchToWorkerThread = false;

Expand Down Expand Up @@ -1260,8 +1260,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(

String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);

boolean switchToWorkerThread = detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, methodToolClassNames,
tools);
// Detect if tools execution may block the caller thread.
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
Expand All @@ -1270,6 +1270,20 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
inputGuardrails, outputGuardrails, accumulatorClassName);
}

private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
List<String> methodToolClassNames) {
List<String> allTools = new ArrayList<>(methodToolClassNames);
// We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing.
AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES);
if (annotation != null) {
AnnotationValue value = annotation.value("tools");
if (value != null) {
allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList());
}
}
return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools);
}

private void validateReturnType(MethodInfo method) {
Type returnType = method.returnType();
Type.Kind returnTypeKind = returnType.kind();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.jboss.jandex.DotName;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;
Expand Down Expand Up @@ -54,4 +55,5 @@ public class DotNames {
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.jboss.jandex.AnnotationInstance;
Expand Down Expand Up @@ -68,6 +69,7 @@
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.GeneratedClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.execannotations.ExecutionModelAnnotationsAllowedBuildItem;
import io.quarkus.deployment.recording.RecorderContext;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
Expand Down Expand Up @@ -263,6 +265,16 @@ public void handleTools(
toolsMetadataProducer.produce(new ToolsMetadataBeforeRemovalBuildItem(metadata));
}

@BuildStep
ExecutionModelAnnotationsAllowedBuildItem toolsMethods() {
return new ExecutionModelAnnotationsAllowedBuildItem(new Predicate<MethodInfo>() {
@Override
public boolean test(MethodInfo method) {
return method.hasDeclaredAnnotation(DotNames.TOOL);
}
});
}

private void validateExecutionModel(ToolMethodCreateInfo methodCreateInfo, MethodInfo toolMethod,
BuildProducer<ValidationPhaseBuildItem.ValidationErrorBuildItem> validation) {
String methodName = toolMethod.declaringClass().name() + "." + toolMethod.name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
* A build item that represents a method that is annotated with {@link dev.langchain4j.agent.tool.Tool}.
* It contains the method info and the tool method create info.
*/
public final class ToolMethodBuildItem extends MultiBuildItem {

private final MethodInfo toolsMethodInfo;
Expand Down Expand Up @@ -33,13 +37,11 @@ public ToolMethodCreateInfo getToolMethodCreateInfo() {
* Returns true if the method requires a switch to a worker thread, even if the method is non-blocking.
* This is because of the tools executor limitation (imperative API).
*
*
* @return true if the method requires a switch to a worker thread
*/
public boolean requiresSwitchToWorkerThread() {
return !(toolMethodCreateInfo.executionModel() == ToolMethodCreateInfo.ExecutionModel.NON_BLOCKING
&& isImperativeMethod());

}

private boolean isImperativeMethod() {
Expand Down
Loading

0 comments on commit b33c558

Please sign in to comment.