Skip to content

Commit

Permalink
Merge pull request #1023 from cescoffier/tools-execution-model
Browse files Browse the repository at this point in the history
Enable Tools to Define Execution Model
  • Loading branch information
cescoffier authored Nov 6, 2024
2 parents 1b1115e + 790932f commit b9d826a
Show file tree
Hide file tree
Showing 33 changed files with 2,266 additions and 50 deletions.
5 changes: 5 additions & 0 deletions core/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-test-vertx</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVAL_AUGMENTOR_SUPPLIER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.NO_RETRIEVER;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.OUTPUT_GUARDRAILS;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.REGISTER_AI_SERVICES;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.SEED_MEMORY;
import static io.quarkiverse.langchain4j.deployment.LangChain4jDotNames.V;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
Expand Down Expand Up @@ -69,6 +70,7 @@
import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand Down Expand Up @@ -123,6 +125,7 @@
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AiServicesProcessor {
Expand Down Expand Up @@ -760,6 +763,57 @@ public void markUsedOutputGuardRailsUnremovable(List<AiServicesMethodBuildItem>
}
}

/**
* Because the tools execution uses an imperative API (`String execute(...)`) and uses the caller thread, we need
* to anticipate the need to dispatch the invocation on a worker thread.
* This is the case for AI service methods that returns `Uni`, `CompletionStage` and `Multi` (stream) and that uses
* tools returning `Uni`, `CompletionStage` or that are blocking.
* Basically, for "reactive AI service method, the switch is necessary except if all the tools are imperative (return `T`)
* and marked explicitly as blocking (using `@Blocking`).
*
* @param method the AI method
* @param tools the tools
*/
public boolean detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(
MethodInfo method,
List<String> associatedTools,
List<ToolMethodBuildItem> tools) {
boolean reactive = method.returnType().name().equals(DotNames.UNI)
|| method.returnType().name().equals(DotNames.COMPLETION_STAGE)
|| method.returnType().name().equals(DotNames.MULTI);

boolean requireSwitchToWorkerThread = false;

if (associatedTools.isEmpty()) {
// No tools, no need to dispatch
return false;
}

if (!reactive) {
// We are already on a thread we can block.
return false;
}

// We need to find if any of the tools that could be used by the method is requiring a blocking execution
for (String classname : associatedTools) {
// Look for the tool in the list of tools
boolean found = false;
for (ToolMethodBuildItem tool : tools) {
if (tool.getDeclaringClassName().equals(classname)) {
found = true;
if (tool.requiresSwitchToWorkerThread()) {
requireSwitchToWorkerThread = true;
break;
}
}
}
if (!found) {
throw new RuntimeException("No tools detected in " + classname);
}
}
return requireSwitchToWorkerThread;
}

@BuildStep
public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
List<AiServicesMethodBuildItem> methods,
Expand Down Expand Up @@ -857,7 +911,8 @@ public void handleAiServices(
BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer,
BuildProducer<UnremovableBeanBuildItem> unremovableBeanProducer,
Optional<MetricsCapabilityBuildItem> metricsCapability,
Capabilities capabilities) {
Capabilities capabilities,
List<ToolMethodBuildItem> tools) {

IndexView index = indexBuildItem.getIndex();

Expand Down Expand Up @@ -1026,7 +1081,8 @@ public void handleAiServices(
addOpenTelemetrySpan,
config.responseSchema(),
allowedPredicates,
ignoredPredicates);
ignoredPredicates,
tools);
if (!methodCreateInfo.getToolClassNames().isEmpty()) {
unremovableBeanProducer.produce(UnremovableBeanBuildItem
.beanClassNames(methodCreateInfo.getToolClassNames().toArray(EMPTY_STRING_ARRAY)));
Expand Down Expand Up @@ -1068,7 +1124,9 @@ public void handleAiServices(

aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo,
methodCreateInfo.getInputGuardrailsClassNames(),
methodCreateInfo.getOutputGuardrailsClassNames()));
methodCreateInfo.getOutputGuardrailsClassNames(),
gatherMethodToolClassNames(methodInfo),
methodCreateInfo));
}
}

Expand Down Expand Up @@ -1155,7 +1213,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
MethodInfo method, IndexView index, boolean addMicrometerMetrics,
boolean addOpenTelemetrySpans, boolean generateResponseSchema,
Collection<Predicate<AnnotationInstance>> allowedPredicates,
Collection<Predicate<AnnotationInstance>> ignoredPredicates) {
Collection<Predicate<AnnotationInstance>> ignoredPredicates,
List<ToolMethodBuildItem> tools) {
validateReturnType(method);

boolean requiresModeration = method.hasAnnotation(LangChain4jDotNames.MODERATE);
Expand Down Expand Up @@ -1200,11 +1259,28 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(

String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);

// Detect if tools execution may block the caller thread.
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, inputGuardrails,
outputGuardrails, accumulatorClassName);
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
inputGuardrails, outputGuardrails, accumulatorClassName);
}

private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
List<String> methodToolClassNames) {
List<String> allTools = new ArrayList<>(methodToolClassNames);
// We need to combine it with the tools that are registered globally - unfortunately, we don't have access to the AI service here, so, re-parsing.
AnnotationInstance annotation = method.declaringClass().annotation(REGISTER_AI_SERVICES);
if (annotation != null) {
AnnotationValue value = annotation.value("tools");
if (value != null) {
allTools.addAll(Arrays.stream(value.asClassArray()).map(t -> t.name().toString()).toList());
}
}
return detectAiServiceMethodThanNeedToBeDispatchedOnWorkerThread(method, allTools, tools);
}

private void validateReturnType(MethodInfo method) {
Expand Down Expand Up @@ -1688,11 +1764,17 @@ public static final class AiServicesMethodBuildItem extends MultiBuildItem {
private final MethodInfo methodInfo;
private final List<String> outputGuardrails;
private final List<String> inputGuardrails;
private final List<String> tools;
private final AiServiceMethodCreateInfo methodCreateInfo;

public AiServicesMethodBuildItem(MethodInfo methodInfo, List<String> inputGuardrails, List<String> outputGuardrails) {
public AiServicesMethodBuildItem(MethodInfo methodInfo, List<String> inputGuardrails, List<String> outputGuardrails,
List<String> tools,
AiServiceMethodCreateInfo methodCreateInfo) {
this.methodInfo = methodInfo;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
this.tools = tools;
this.methodCreateInfo = methodCreateInfo;
}

public List<String> getOutputGuardrails() {
Expand All @@ -1707,6 +1789,10 @@ public MethodInfo getMethodInfo() {
return methodInfo;
}

public AiServiceMethodCreateInfo getMethodCreateInfo() {
return methodCreateInfo;
}

public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annotation) {
List<String> guardrails = new ArrayList<>();
AnnotationInstance instance = methodInfo.annotation(annotation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@
import java.math.BigInteger;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletionStage;

import jakarta.enterprise.inject.Instance;

import org.jboss.jandex.DotName;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;
import io.smallrye.common.annotation.RunOnVirtualThread;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;

public class DotNames {

Expand Down Expand Up @@ -38,10 +44,16 @@ public class DotNames {
public static final DotName LIST = DotName.createSimple(List.class);
public static final DotName SET = DotName.createSimple(Set.class);
public static final DotName MULTI = DotName.createSimple(Multi.class);
public static final DotName UNI = DotName.createSimple(Uni.class);
public static final DotName BLOCKING = DotName.createSimple(Blocking.class);
public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class);
public static final DotName COMPLETION_STAGE = DotName.createSimple(CompletionStage.class);
public static final DotName RUN_ON_VIRTUAL_THREAD = DotName.createSimple(RunOnVirtualThread.class);

public static final DotName OBJECT = DotName.createSimple(Object.class.getName());
public static final DotName RECORD = DotName.createSimple(Record.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);
}
Loading

0 comments on commit b9d826a

Please sign in to comment.