Skip to content

Commit

Permalink
Add OpenTelemetry support to tool execution
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed Dec 1, 2023
1 parent 10d0634 commit 4b4f212
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.tool.ToolParametersObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper;
import io.quarkiverse.langchain4j.runtime.tool.ToolSpecificationObjectSubstitution;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
Expand Down Expand Up @@ -75,6 +79,14 @@ public class ToolProcessor {
public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, Object.class,
Object.class);

@BuildStep
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
if (addOpenTelemetrySpan) {
additionalBeanProducer.produce(AdditionalBeanBuildItem.builder().addBeanClass(ToolSpanWrapper.class).build());
}
}

@BuildStep
@Record(ExecutionTime.STATIC_INIT)
public void handleTools(CombinedIndexBuildItem indexBuildItem,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ private ToolExecutor getToolExecutor(String methodName) {
ToolSpecification toolSpecification = methodCreateInfo.getToolSpecification();
if (methodName.equals(
toolSpecification.name())) { // this only works because TestTool does not contain overloaded methods
toolExecutor = new QuarkusToolExecutor(testTool, invokerClassName, methodCreateInfo.getMethodName(),
methodCreateInfo.getArgumentMapperClassName());
toolExecutor = new QuarkusToolExecutor(
new QuarkusToolExecutor.Context(testTool, invokerClassName, methodCreateInfo.getMethodName(),
methodCreateInfo.getArgumentMapperClassName()));
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ClientProxy;

public class QuarkusAiServicesFactory implements AiServicesFactory {
Expand All @@ -33,8 +35,12 @@ public static class InstanceHolder {
}

public static class QuarkusAiServices<T> extends AiServices<T> {

private final QuarkusToolExecutorFactory toolExecutorFactory;

public QuarkusAiServices(AiServiceContext context) {
super(context);
toolExecutorFactory = Arc.container().instance(QuarkusToolExecutorFactory.class).get();
}

@Override
Expand All @@ -54,9 +60,10 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
String invokerClassName = methodCreateInfo.getInvokerClassName();
ToolSpecification toolSpecification = methodCreateInfo.getToolSpecification();
context.toolSpecifications.add(toolSpecification);
context.toolExecutors.put(toolSpecification.name(),
new QuarkusToolExecutor(objectWithTool, invokerClassName, methodCreateInfo.getMethodName(),
methodCreateInfo.getArgumentMapperClassName()));
QuarkusToolExecutor.Context executorContext = new QuarkusToolExecutor.Context(objectWithTool,
invokerClassName, methodCreateInfo.getMethodName(),
methodCreateInfo.getArgumentMapperClassName());
context.toolExecutors.put(toolSpecification.name(), toolExecutorFactory.create(executorContext));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ public Object apply(Input input) {
});

for (Wrapper wrapper : wrappers) {
Function<Input, Object> currentFun = funRef.get();
Function<Input, Object> newFunction = new Function<Input, Object>() {
var currentFun = funRef.get();
var newFunction = new Function<Input, Object>() {
@Override
public Object apply(Input input) {
return wrapper.wrap(input, currentFun);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public SpanWrapper(OpenTelemetry openTelemetry) {

// TODO: there is probably more information here we need to set
this.instrumenter = builder
.buildInstrumenter(new SpanKindExtractor<AiServiceMethodImplementationSupport.Input>() {
.buildInstrumenter(new SpanKindExtractor<>() {
@Override
public SpanKind extract(AiServiceMethodImplementationSupport.Input input) {
return SpanKind.INTERNAL;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Map;
import java.util.function.BiFunction;

import org.jboss.logging.Logger;

Expand All @@ -18,17 +19,19 @@ public class QuarkusToolExecutor implements ToolExecutor {

private static final Logger log = Logger.getLogger(QuarkusToolExecutor.class);

private final Object tool;
private final String toolInvokerName;
private final String methodName;
private final Context context;

private final String argumentMapperClassName;
public record Context(Object tool, String toolInvokerName, String methodName, String argumentMapperClassName) {
}

public interface Wrapper {

public QuarkusToolExecutor(Object tool, String toolInvokerName, String methodName, String argumentMapperClassName) {
this.tool = tool;
this.toolInvokerName = toolInvokerName;
this.methodName = methodName;
this.argumentMapperClassName = argumentMapperClassName;
String wrap(ToolExecutionRequest toolExecutionRequest, Object memoryId,
BiFunction<ToolExecutionRequest, Object, String> fun);
}

public QuarkusToolExecutor(Context context) {
this.context = context;
}

public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
Expand All @@ -39,9 +42,9 @@ public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId
Object[] params = prepareArguments(toolExecutionRequest, invokerInstance.methodMetadata());
try {
if (log.isDebugEnabled()) {
log.debugv("Attempting to invoke tool '{0}' with parameters '{1}'", tool, Arrays.toString(params));
log.debugv("Attempting to invoke tool '{0}' with parameters '{1}'", context.tool, Arrays.toString(params));
}
Object invocationResult = invokerInstance.invoke(tool,
Object invocationResult = invokerInstance.invoke(context.tool,
params);
String result = handleResult(invokerInstance, invocationResult);
log.debugv("Tool execution result: '{0}'", result);
Expand All @@ -50,7 +53,7 @@ public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId
if (e instanceof IllegalArgumentException) {
throw (IllegalArgumentException) e;
}
log.error("Error while executing tool '" + tool.getClass() + "'", e);
log.error("Error while executing tool '" + context.tool.getClass() + "'", e);
return e.getMessage();
}
}
Expand All @@ -66,12 +69,14 @@ private static String handleResult(ToolInvoker invokerInstance, Object invocatio
private ToolInvoker createInvokerInstance() {
ToolInvoker invokerInstance;
try {
invokerInstance = (ToolInvoker) Class.forName(toolInvokerName, true, Thread.currentThread()
invokerInstance = (ToolInvoker) Class.forName(context.toolInvokerName, true, Thread.currentThread()
.getContextClassLoader()).getConstructor().newInstance();
} catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException
| InvocationTargetException e) {
throw new IllegalStateException(
"Unable to create instance of '" + toolInvokerName + "'. Please report this issue to the maintainers", e);
"Unable to create instance of '" + context.toolInvokerName
+ "'. Please report this issue to the maintainers",
e);
}
return invokerInstance;
}
Expand Down Expand Up @@ -113,18 +118,20 @@ private Map<String, Object> convertJsonToArguments(String argumentsJsonStr) thro
@SuppressWarnings("unchecked")
private Class<? extends Mappable> loadMapperClass() {
try {
return (Class<? extends Mappable>) Class.forName(argumentMapperClassName, true, Thread.currentThread()
return (Class<? extends Mappable>) Class.forName(context.argumentMapperClassName, true, Thread.currentThread()
.getContextClassLoader());
} catch (ClassNotFoundException e) {
throw new IllegalStateException(
"Unable to load argument mapper of '" + toolInvokerName + "'. Please report this issue to the maintainers",
"Unable to load argument mapper of '" + context.toolInvokerName
+ "'. Please report this issue to the maintainers",
e);
}
}

private void invalidMethodParams(String argumentsJsonStr) {
throw new IllegalArgumentException("params '" + argumentsJsonStr
+ "' from request do not map onto the parameters needed by '" + tool.getClass().getName() + "#" + methodName
+ "' from request do not map onto the parameters needed by '" + context.tool.getClass().getName() + "#"
+ context.methodName
+ "'");
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;

import jakarta.inject.Singleton;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import io.quarkus.arc.All;
import io.quarkus.arc.Unremovable;

@Singleton
@Unremovable
public class QuarkusToolExecutorFactory {

private final List<QuarkusToolExecutor.Wrapper> wrappers;

public QuarkusToolExecutorFactory(@All List<QuarkusToolExecutor.Wrapper> wrappers) {
this.wrappers = wrappers;
}

public QuarkusToolExecutor create(QuarkusToolExecutor.Context context) {
if (wrappers.isEmpty()) {
return new QuarkusToolExecutor(context);
}

return new QuarkusToolExecutor(context) {
final QuarkusToolExecutor originalTool = new QuarkusToolExecutor(context);

@Override
public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
AtomicReference<BiFunction<ToolExecutionRequest, Object, String>> funRef = new AtomicReference<>(
new BiFunction<>() {
@Override
public String apply(ToolExecutionRequest toolExecutionRequest, Object o) {
return originalTool.execute(toolExecutionRequest, memoryId);
}
});

for (QuarkusToolExecutor.Wrapper wrapper : wrappers) {
var currentFun = funRef.get();
BiFunction<ToolExecutionRequest, Object, String> newFunction = new BiFunction<>() {
@Override
public String apply(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
return wrapper.wrap(toolExecutionRequest, memoryId, currentFun);
}
};
funRef.set(newFunction);
}

return funRef.get().apply(toolExecutionRequest, memoryId);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.quarkiverse.langchain4j.runtime.tool;

import java.util.function.BiFunction;

import jakarta.inject.Inject;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.instrumenter.InstrumenterBuilder;
import io.opentelemetry.instrumentation.api.instrumenter.SpanKindExtractor;
import io.opentelemetry.instrumentation.api.instrumenter.SpanNameExtractor;

public class ToolSpanWrapper implements QuarkusToolExecutor.Wrapper {

private static final String INSTRUMENTATION_NAME = "io.quarkus.opentelemetry";

private final Instrumenter<ToolExecutionRequest, Void> instrumenter;

@Inject
public ToolSpanWrapper(OpenTelemetry openTelemetry) {
InstrumenterBuilder<ToolExecutionRequest, Void> builder = Instrumenter.builder(
openTelemetry,
INSTRUMENTATION_NAME,
InputSpanNameExtractor.INSTANCE);

// TODO: there is probably more information here we need to set
this.instrumenter = builder
.buildInstrumenter(new SpanKindExtractor<>() {
@Override
public SpanKind extract(ToolExecutionRequest toolExecutionRequest) {
return SpanKind.INTERNAL;
}
});
}

@Override
public String wrap(ToolExecutionRequest toolExecutionRequest, Object memoryId,
BiFunction<ToolExecutionRequest, Object, String> fun) {
Context parentContext = Context.current();
Context spanContext = null;
Scope scope = null;
boolean shouldStart = instrumenter.shouldStart(parentContext, toolExecutionRequest);
if (shouldStart) {
spanContext = instrumenter.start(parentContext, toolExecutionRequest);
scope = spanContext.makeCurrent();
}

try {
String result = fun.apply(toolExecutionRequest, memoryId);

if (shouldStart) {
instrumenter.end(spanContext, toolExecutionRequest, null, null);
}

return result;
} catch (Throwable t) {
if (shouldStart) {
instrumenter.end(spanContext, toolExecutionRequest, null, t);
}
throw t;
} finally {
if (scope != null) {
scope.close();
}
}
}

private static class InputSpanNameExtractor implements SpanNameExtractor<ToolExecutionRequest> {

private static final InputSpanNameExtractor INSTANCE = new InputSpanNameExtractor();

@Override
public String extract(ToolExecutionRequest toolExecutionRequest) {
return "langchain4j.tools." + toolExecutionRequest.name();
}
}
}

0 comments on commit 4b4f212

Please sign in to comment.