From dd428ca02ac814a19b91dfee27a0acd404fc4b42 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Mon, 4 Dec 2023 11:32:12 +0200 Subject: [PATCH] Sketch out API for auditing Relates to: #8 --- .../deployment/AiServicesProcessor.java | 45 ++- .../DeclarativeAiServiceBuildItem.java | 9 +- .../deployment/Langchain4jDotNames.java | 6 + .../langchain4j/QuarkusAiServicesFactory.java | 18 +- .../langchain4j/RegisterAiService.java | 22 ++ .../quarkiverse/langchain4j/audit/Audit.java | 87 +++++ .../langchain4j/audit/AuditService.java | 21 ++ .../runtime/AiServicesRecorder.java | 25 +- .../aiservice/AiServiceMethodCreateInfo.java | 16 +- .../AiServiceMethodImplementationSupport.java | 56 +++- .../DeclarativeAiServiceCreateInfo.java | 10 +- .../aiservice/QuarkusAiServiceContext.java | 13 + docs/modules/ROOT/pages/ai-services.adoc | 5 + docs/modules/ROOT/pages/index.adoc | 4 + .../aiservices/AuditingServiceTest.java | 316 ++++++++++++++++++ 15 files changed, 631 insertions(+), 22 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/Audit.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/AuditService.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java create mode 100644 openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index e7f10598a..bc47fde56 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -44,7 +44,6 @@ import org.objectweb.asm.tree.analysis.AnalyzerException; import 
dev.langchain4j.exception.IllegalConfigurationException; -import dev.langchain4j.service.AiServiceContext; import dev.langchain4j.service.V; import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; @@ -53,6 +52,7 @@ import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper; +import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; @@ -101,6 +101,7 @@ public class AiServicesProcessor { private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod( AiServiceMethodImplementationSupport.class, "implement", Object.class, AiServiceMethodImplementationSupport.Input.class); + public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class); @BuildStep public void nativeSupport(CombinedIndexBuildItem indexBuildItem, @@ -203,13 +204,21 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } } + DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER; + AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier"); + if (auditServiceClassSupplierValue != null) { + auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name(); + validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer); + } + declarativeAiServiceProducer.produce( new DeclarativeAiServiceBuildItem( declarativeAiServiceClassInfo, chatLanguageModelSupplierClassDotName, toolDotNames, chatMemoryProviderSupplierClassDotName, - retrieverSupplierClassDotName)); + retrieverSupplierClassDotName, + 
auditServiceClassSupplierName)); } if (needChatModelBean) { @@ -244,6 +253,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, boolean needsChatModelBean = false; boolean needsChatMemoryProviderBean = false; boolean needsRetrieverBean = false; + boolean needsAuditServiceBean = false; Set allToolNames = new HashSet<>(); for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) { @@ -264,12 +274,17 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? bi.getRetrieverSupplierClassDotName().toString() : null; + String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null + ? bi.getAuditServiceClassSupplierDotName().toString() + : null; + SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem .configure(declarativeAiServiceClassInfo.name()) .createWith(recorder.createDeclarativeAiService( new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName, toolClassNames, chatMemoryProviderSupplierClassName, - retrieverSupplierClassName))) + retrieverSupplierClassName, + auditServiceClassSupplierName))) .setRuntimeInit() .scope(ApplicationScoped.class); if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed? 
@@ -290,7 +305,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsChatMemoryProviderBean = true; } else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString() .equals(chatMemoryProviderSupplierClassName)) { - configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class), + configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE, new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null)); needsChatMemoryProviderBean = true; } @@ -301,13 +316,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsRetrieverBean = true; } else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString() .equals(retrieverSupplierClassName)) { - configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class), + configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE, new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER, new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) }, null)); needsRetrieverBean = true; } + if (Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) { + configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE, + new Type[] { ClassType.create(Langchain4jDotNames.AUDIT_SERVICE) }, null)); + needsAuditServiceBean = true; + } + syntheticBeanProducer.produce(configurator.done()); } @@ -320,6 +341,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (needsRetrieverBean) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.RETRIEVER)); } + if (needsAuditServiceBean) { + unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE)); + } if (!allToolNames.isEmpty()) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames)); } @@ -436,19 +460,19 @@ public void 
handleAiServices(AiServicesRecorder recorder, .interfaces(iface.name().toString()) .build()) { - FieldDescriptor contextField = classCreator.getFieldCreator("context", AiServiceContext.class) + FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class) .setModifiers(Modifier.PRIVATE | Modifier.FINAL) .getFieldDescriptor(); for (MethodInfo methodInfo : methodsToImplement) { - // The implementation essentially gets method the context and delegates to + // The implementation essentially gets the context and delegates to // MethodImplementationSupport#implement String methodId = createMethodId(methodInfo); perMethodMetadata.put(methodId, gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan)); MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V", - AiServiceContext.class); + QuarkusAiServiceContext.class); constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis()); constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0)); constructor.returnValue(null); @@ -466,7 +490,7 @@ public void handleAiServices(AiServicesRecorder recorder, ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName()); ResultHandle inputHandle = mc.newInstance( MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class, - AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class), + QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class), contextHandle, methodCreateInfoHandle, paramsHandle); ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle); @@ -547,7 +571,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea Optional metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics); Optional spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans); - return new 
AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration, + return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo, + userMessageInfo, memoryIdParamPosition, requiresModeration, returnType, metricsInfo, spanInfo); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index b9c7a5ffe..fa32d3b91 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -18,16 +18,19 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final DotName chatMemoryProviderSupplierClassDotName; private final DotName retrieverSupplierClassDotName; + private final DotName auditServiceClassSupplierDotName; public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName, List toolDotNames, DotName chatMemoryProviderSupplierClassDotName, - DotName retrieverSupplierClassDotName) { + DotName retrieverSupplierClassDotName, + DotName auditServiceClassSupplierDotName) { this.serviceClassInfo = serviceClassInfo; this.languageModelSupplierClassDotName = languageModelSupplierClassDotName; this.toolDotNames = toolDotNames; this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; this.retrieverSupplierClassDotName = retrieverSupplierClassDotName; + this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName; } public ClassInfo getServiceClassInfo() { @@ -49,4 +52,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() { public DotName getRetrieverSupplierClassDotName() { return retrieverSupplierClassDotName; } + + public DotName 
getAuditServiceClassSupplierDotName() { + return auditServiceClassSupplierDotName; + } } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java index dad8425dc..428e5c447 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java @@ -20,6 +20,7 @@ import dev.langchain4j.service.UserName; import io.quarkiverse.langchain4j.CreatedAware; import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.audit.AuditService; public class Langchain4jDotNames { public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class); @@ -53,10 +54,15 @@ public class Langchain4jDotNames { static final DotName RETRIEVER = DotName.createSimple(Retriever.class); static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class); + static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class); + static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple( RegisterAiService.BeanRetrieverSupplier.class); static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple( RegisterAiService.BeanIfExistsRetrieverSupplier.class); + static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple( + RegisterAiService.BeanIfExistsAuditServiceSupplier.class); + } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServicesFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServicesFactory.java index 96aca42a1..d37a5c340 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServicesFactory.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusAiServicesFactory.java @@ -13,10 +13,12 @@ import 
dev.langchain4j.service.AiServiceContext; import dev.langchain4j.service.AiServices; import dev.langchain4j.spi.services.AiServicesFactory; +import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.AiServicesRecorder; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; +import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor; import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; @@ -27,7 +29,12 @@ public class QuarkusAiServicesFactory implements AiServicesFactory { @Override public QuarkusAiServices create(AiServiceContext context) { - return new QuarkusAiServices<>(context); + if (context instanceof QuarkusAiServiceContext) { + return new QuarkusAiServices<>(context); + } else { + // the context is always empty (except for the aiServiceClass) anyway and never escapes, so we can just use our own type + return new QuarkusAiServices<>(new QuarkusAiServiceContext(context.aiServiceClass)); + } } public static class InstanceHolder { @@ -70,6 +77,11 @@ public AiServices tools(List objectsWithTools) { return this; } + public AiServices auditService(AuditService auditService) { + ((QuarkusAiServiceContext) context).auditService = auditService; + return this; + } + List lookup(Object bean, String className) { Map> metadata = ToolsRecorder.getMetadata(); // Fast path first. 
@@ -116,8 +128,8 @@ public T build() { try { return (T) Class.forName(classCreateInfo.getImplClassName(), true, Thread.currentThread() - .getContextClassLoader()).getConstructor(AiServiceContext.class) - .newInstance(context); + .getContextClassLoader()).getConstructor(QuarkusAiServiceContext.class) + .newInstance(((QuarkusAiServiceContext) context)); } catch (Exception e) { throw new IllegalStateException("Unable to create class '" + classCreateInfo.getImplClassName(), e); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index ca85a6082..94332e917 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -15,6 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.retriever.Retriever; import dev.langchain4j.service.AiServices; +import io.quarkiverse.langchain4j.audit.AuditService; /** * Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by @@ -74,6 +75,15 @@ */ Class>> retrieverSupplier() default NoRetrieverSupplier.class; + /** + * Configures the way to obtain the {@link AuditService} to use. + * By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using + * any audit service if no such bean exists. + * If an arbitrary {@link AuditService} instance is needed, a custom implementation of + * {@link Supplier} needs to be provided. 
+ */ + Class> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class; + /** * Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and @@ -155,4 +165,16 @@ public Retriever get() { throw new UnsupportedOperationException("should never be called"); } } + + /** + * Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean. + * If no such bean exists, then no audit service will be used. + */ + final class BeanIfExistsAuditServiceSupplier implements Supplier { + + @Override + public AuditService get() { + throw new UnsupportedOperationException("should never be called"); + } + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/Audit.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/Audit.java new file mode 100644 index 000000000..3c96ce296 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/Audit.java @@ -0,0 +1,87 @@ +package io.quarkiverse.langchain4j.audit; + +import java.util.List; +import java.util.Optional; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.output.Response; + +/** + * Abstract class to be implemented in order to keep track of whatever information is useful for the application auditing. 
+ */ +public abstract class Audit { + + /** + * Information about the AiService that is being audited + */ + public record CreateInfo(String interfaceName, String methodName, Object[] parameters, + Optional memoryIDParamPosition) { + + } + + private final CreateInfo createInfo; + private Optional systemMessage; + private UserMessage userMessage; + + /** + * Invoked by {@link AuditService} when an AiService is invoked + */ + public Audit(CreateInfo createInfo) { + this.createInfo = createInfo; + } + + /** + * @return information about the AiService that is being audited + */ + public CreateInfo getCreateInfo() { + return createInfo; + } + + /** + * Invoked when the original user and system messages have been created + */ + public void initialMessages(Optional systemMessage, UserMessage userMessage) { + + } + + /** + * Invoked if a relevant document was added to the messages to be sent to the LLM + */ + public void addRelevantDocument(List segments, UserMessage userMessage) { + + } + + /** + * Invoked with a response from an LLM. It is important to note that this can be invoked multiple times + * when tools exist. + */ + public void addLLMToApplicationMessage(Response response) { + + } + + /** + * Invoked with a tool execution result that is sent back to the LLM. It is important to note that this can be invoked multiple times + * when tools exist. 
+ */ + public void addApplicationToLLMMessage(ToolExecutionResultMessage toolExecutionResultMessage) { + + } + + /** + * Invoked when the final result of the AiService method has been computed + */ + public void onCompletion(Object result) { + + } + + /** + * Invoked when there was an exception computing the result of the AiService method + */ + public void onFailure(Exception e) { + + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/AuditService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/AuditService.java new file mode 100644 index 000000000..dacbc34f7 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/audit/AuditService.java @@ -0,0 +1,21 @@ +package io.quarkiverse.langchain4j.audit; + +/** + * Allow applications to audit parts of the interactions with the LLM that interest them + *

+ * When using {@link io.quarkiverse.langchain4j.RegisterAiService} if the application provides an implementation + * of {@link AuditService} that is a CDI bean, it will be used by default. + */ +public interface AuditService { + + /** + * Invoked when an AiService method is invoked and before any interaction with the LLM is performed. + */ + Audit create(Audit.CreateInfo createInfo); + + /** + * Invoked just before the AiService method returns its result - or throws an exception. + * The {@code audit} parameter is meant to be built up by implementing its callbacks. + */ + void complete(Audit audit); +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index f38b4ebf1..05f6cc333 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -17,11 +17,12 @@ import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.retriever.Retriever; -import dev.langchain4j.service.AiServiceContext; import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.audit.AuditService; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo; +import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext; import io.quarkus.arc.SyntheticCreationalContext; import io.quarkus.runtime.annotations.Recorder; @@ -31,6 +32,9 @@ public class AiServicesRecorder { private static final TypeLiteral> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() { }; private static final TypeLiteral>> RETRIEVER_INSTANCE_TYPE_LITERAL = new 
TypeLiteral<>() { + + }; + private static final TypeLiteral> AUDIT_SERVICE_TYPE_LITERAL = new TypeLiteral<>() { }; // the key is the interface's class name @@ -72,7 +76,7 @@ public T apply(SyntheticCreationalContext creationalContext) { Class serviceClass = Thread.currentThread().getContextClassLoader() .loadClass(info.getServiceClassName()); - AiServiceContext aiServiceContext = new AiServiceContext(serviceClass); + QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass); var quarkusAiServices = INSTANCE.create(aiServiceContext); if (info.getLanguageModelSupplierClassName() != null) { @@ -137,6 +141,23 @@ public T apply(SyntheticCreationalContext creationalContext) { } } + if (info.getAuditServiceClassSupplierName() != null) { + if (RegisterAiService.BeanIfExistsAuditServiceSupplier.class.getName() + .equals(info.getAuditServiceClassSupplierName())) { + Instance instance = creationalContext + .getInjectedReference(AUDIT_SERVICE_TYPE_LITERAL); + if (instance.isResolvable()) { + quarkusAiServices.auditService(instance.get()); + } + } else { + @SuppressWarnings("rawtypes") + Supplier supplier = (Supplier) Thread + .currentThread().getContextClassLoader().loadClass(info.getAuditServiceClassSupplierName()) + .getConstructor().newInstance(); + quarkusAiServices.auditService(supplier.get()); + } + } + return (T) quarkusAiServices.build(); } catch (ClassNotFoundException e) { throw new IllegalStateException(e); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index 01c73414a..6b2921bf9 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -8,6 +8,9 @@ 
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") public class AiServiceMethodCreateInfo { + private final String interfaceName; + private final String methodName; + private final Optional systemMessageInfo; private final UserMessageInfo userMessageInfo; private final Optional memoryIdParamPosition; @@ -20,11 +23,14 @@ public class AiServiceMethodCreateInfo { private final Optional spanInfo; @RecordableConstructor - public AiServiceMethodCreateInfo(Optional systemMessageInfo, UserMessageInfo userMessageInfo, + public AiServiceMethodCreateInfo(String interfaceName, String methodName, + Optional systemMessageInfo, UserMessageInfo userMessageInfo, Optional memoryIdParamPosition, boolean requiresModeration, Class returnType, Optional metricsInfo, Optional spanInfo) { + this.interfaceName = interfaceName; + this.methodName = methodName; this.systemMessageInfo = systemMessageInfo; this.userMessageInfo = userMessageInfo; this.memoryIdParamPosition = memoryIdParamPosition; @@ -34,6 +40,14 @@ public AiServiceMethodCreateInfo(Optional systemMessageInfo, UserM this.spanInfo = spanInfo; } + public String getInterfaceName() { + return interfaceName; + } + + public String getMethodName() { + return methodName; + } + public Optional getSystemMessageInfo() { return systemMessageInfo; } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 56a587337..108ddd68e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -39,6 +39,8 @@ import dev.langchain4j.service.AiServiceTokenStream; import dev.langchain4j.service.ServiceOutputParser; import dev.langchain4j.service.TokenStream; +import 
io.quarkiverse.langchain4j.audit.Audit; +import io.quarkiverse.langchain4j.audit.AuditService; import io.smallrye.mutiny.infrastructure.Infrastructure; /** @@ -48,16 +50,48 @@ public class AiServiceMethodImplementationSupport { private static final Logger log = Logger.getLogger(AiServiceMethodImplementationSupport.class); + /** + * This method is called by the implementations of each ai service method. + */ public Object implement(Input input) { - AiServiceContext context = input.context; + QuarkusAiServiceContext context = input.context; AiServiceMethodCreateInfo createInfo = input.createInfo; Object[] methodArgs = input.methodArgs; + AuditService auditService = context.auditService; + Audit audit = null; + if (auditService != null) { + audit = auditService.create(new Audit.CreateInfo(createInfo.getInterfaceName(), createInfo.getMethodName(), + methodArgs, createInfo.getMemoryIdParamPosition())); + } + // TODO: add validation + try { + var result = doImplement(createInfo, methodArgs, context, audit); + if (audit != null) { + audit.onCompletion(result); + auditService.complete(audit); + } + return result; + } catch (Exception e) { + log.errorv(e, "Execution of {0}#{1} failed", createInfo.getInterfaceName(), createInfo.getMethodName()); + if (audit != null) { + audit.onFailure(e); + auditService.complete(audit); + } + throw e; + } + } + private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[] methodArgs, + QuarkusAiServiceContext context, Audit audit) { Optional systemMessage = prepareSystemMessage(createInfo, methodArgs); UserMessage userMessage = prepareUserMessage(context, createInfo, methodArgs); + if (audit != null) { + audit.initialMessages(systemMessage, userMessage); + } + if (context.retriever != null) { // TODO extract method/class List relevant = context.retriever.findRelevant(userMessage.text()); @@ -73,6 +107,10 @@ public Object implement(Input input) { userMessage = userMessage(userMessage.text() + "\n\nHere is some information 
that might be useful for answering:\n\n" + relevantConcatenated); + + if (audit != null) { + audit.addRelevantDocument(relevant, userMessage); + } } } @@ -97,7 +135,7 @@ public Object implement(Input input) { Class returnType = createInfo.getReturnType(); if (returnType.equals(TokenStream.class)) { - return new AiServiceTokenStream(messages, context, memoryId); // TODO: moderation + return new AiServiceTokenStream(messages, context, memoryId); } Future moderationFuture = triggerModerationIfNeeded(context, createInfo, messages); @@ -107,6 +145,9 @@ public Object implement(Input input) { ? context.chatModel.generate(messages, context.toolSpecifications) : context.chatModel.generate(messages); log.debug("AI response obtained"); + if (audit != null) { + audit.addLLMToApplicationMessage(response); + } verifyModerationIfNeeded(moderationFuture); ToolExecutionRequest toolExecutionRequest; @@ -128,6 +169,9 @@ public Object implement(Input input) { log.debugv("Result of {0} is '{1}'", toolExecutionRequest, toolExecutionResult); ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest.name(), toolExecutionResult); + if (audit != null) { + audit.addApplicationToLLMMessage(toolExecutionResultMessage); + } ChatMemory chatMemory = context.chatMemory(memoryId); chatMemory.add(toolExecutionResultMessage); @@ -135,6 +179,10 @@ public Object implement(Input input) { log.debug("Attempting to obtain AI response"); response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications); log.debug("AI response obtained"); + + if (audit != null) { + audit.addLLMToApplicationMessage(response); + } } return ServiceOutputParser.parse(response, returnType); @@ -257,11 +305,11 @@ private static String arrayToString(Object arg) { } public static class Input { - final AiServiceContext context; + final QuarkusAiServiceContext context; final AiServiceMethodCreateInfo createInfo; final Object[] methodArgs; - public 
Input(AiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) { + public Input(QuarkusAiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) { this.context = context; this.createInfo = createInfo; this.methodArgs = methodArgs; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index e2467893a..c50046f19 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -12,15 +12,19 @@ public class DeclarativeAiServiceCreateInfo { private final String chatMemoryProviderSupplierClassName; private final String retrieverSupplierClassName; + private final String auditServiceClassSupplierName; + @RecordableConstructor public DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName, List toolsClassNames, String chatMemoryProviderSupplierClassName, - String retrieverSupplierClassName) { + String retrieverSupplierClassName, + String auditServiceClassSupplierName) { this.serviceClassName = serviceClassName; this.languageModelSupplierClassName = languageModelSupplierClassName; this.toolsClassNames = toolsClassNames; this.chatMemoryProviderSupplierClassName = chatMemoryProviderSupplierClassName; this.retrieverSupplierClassName = retrieverSupplierClassName; + this.auditServiceClassSupplierName = auditServiceClassSupplierName; } public String getServiceClassName() { @@ -42,4 +46,8 @@ public String getChatMemoryProviderSupplierClassName() { public String getRetrieverSupplierClassName() { return retrieverSupplierClassName; } + + public String getAuditServiceClassSupplierName() { + return auditServiceClassSupplierName; + } } diff --git 
a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java new file mode 100644 index 000000000..564d4e787 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/QuarkusAiServiceContext.java @@ -0,0 +1,13 @@ +package io.quarkiverse.langchain4j.runtime.aiservice; + +import dev.langchain4j.service.AiServiceContext; +import io.quarkiverse.langchain4j.audit.AuditService; + +public class QuarkusAiServiceContext extends AiServiceContext { + + public AuditService auditService; + + public QuarkusAiServiceContext(Class aiServiceClass) { + super(aiServiceClass); + } +} diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc index 3fc6a7109..1b78ff2e0 100644 --- a/docs/modules/ROOT/pages/ai-services.adoc +++ b/docs/modules/ROOT/pages/ai-services.adoc @@ -356,3 +356,8 @@ In the trace above we can see the parent span which corresponds to the handling interesting thing is the `langchain4j.aiservices.MyAiService.writeAPoem` span which corresponds to the invocation of the AI service. The child spans of this span correspond (from to right) to calling the OpenAI API, invoking the `sendEmail` tool and finally invoking calling the OpenAI API again. +=== Auditing + +The extension allows users to audit the process of implementing an AiService by introducing `io.quarkiverse.langchain4j.audit.AuditService` and `io.quarkiverse.langchain4j.audit.Audit`. +By default, if a bean of type `AuditService` is present in the application, it will be used in order to create an `Audit`, which receives various callbacks pertaining to the implementation +of the AiService method. More information can be found in the Javadoc of these two classes. 
diff --git a/docs/modules/ROOT/pages/index.adoc b/docs/modules/ROOT/pages/index.adoc index 75a69710f..1b06c63cf 100644 --- a/docs/modules/ROOT/pages/index.adoc +++ b/docs/modules/ROOT/pages/index.adoc @@ -73,6 +73,10 @@ The extension offers the following advantages over using the vanilla https://git ** CDI beans for the Langchain4j models ** Standard configuration properties for configuring said models * xref:ai-services.adoc[Declarative AI Services] +* Built-in observability +** Metrics +** Tracing +** Auditing * Build time wiring ** Reduced footprint of the library ** Feedback about misuse at build time diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java new file mode 100644 index 000000000..d1facbfb6 --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/AuditingServiceTest.java @@ -0,0 +1,316 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static io.quarkiverse.langchain4j.openai.test.WiremockUtils.DEFAULT_TOKEN; +import static org.acme.examples.aiservices.MessageAssertUtils.assertMultipleRequestMessage; +import static org.acme.examples.aiservices.MessageAssertUtils.assertSingleRequestMessage; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.stubbing.Scenario; +import com.github.tomakehurst.wiremock.stubbing.ServeEvent; +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.ToolExecutionResultMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.audit.Audit; +import io.quarkiverse.langchain4j.audit.AuditService; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.arc.Arc; +import io.quarkus.arc.InstanceHandle; +import io.quarkus.test.QuarkusUnitTest; + +public class AuditingServiceTest { + + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + public static class ChatMemoryProviderProducer { + + @Singleton + 
ChatMemoryProvider chatMemory() { + return memoryId -> MessageWindowChatMemory.builder() + .id(memoryId) + .maxMessages(10) + .chatMemoryStore(new InMemoryChatMemoryStore()) + .build(); + } + } + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); + } + + private static final String scenario = "tools"; + private static final String secondState = "second"; + + @Singleton + public static class CalculatorAfter implements Runnable { + + @Override + public void run() { + wireMockServer.setScenarioState(scenario, secondState); + } + } + + @Singleton + static class Calculator { + + private final CalculatorAfter after; + + Calculator(CalculatorAfter after) { + this.after = after; + } + + @Tool("calculates the square root of the provided number") + double squareRoot(double number) { + var result = Math.sqrt(number); + after.run(); + return result; + } + } + + @RegisterAiService(tools = Calculator.class) + interface Assistant { + + String chat(String message); + } + + @Inject + Assistant assistant; + + @Test + void should_execute_tool_then_answer() throws IOException { + var firstResponse = """ + { + "id": "chatcmpl-8D88Dag1gAKnOPP9Ed4bos7vSpaNz", + "object": "chat.completion", + "created": 1698140213, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "function_call": { + "name": "squareRoot", + "arguments": "{\\n \\"number\\": 485906798473894056\\n}" + } + }, + "finish_reason": "function_call" + } + ], + "usage": { + "prompt_tokens": 65, + "completion_tokens": 20, + "total_tokens": 85 + } + } + """; + + var secondResponse = """ + { + "id": "chatcmpl-8D88FIAUWSpwLaShFr0w8G1SWuVdl", + 
"object": "chat.completion", + "created": 1698140215, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8." + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 102, + "completion_tokens": 33, + "total_tokens": 135 + } + } + """; + + wireMockServer.stubFor( + WiremockUtils.chatCompletionMapping(DEFAULT_TOKEN) + .inScenario(scenario) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(WiremockUtils.CHAT_RESPONSE_WITHOUT_BODY.withBody(firstResponse))); + wireMockServer.stubFor( + WiremockUtils.chatCompletionMapping(DEFAULT_TOKEN) + .inScenario(scenario) + .whenScenarioStateIs(secondState) + .willReturn(WiremockUtils.CHAT_RESPONSE_WITHOUT_BODY.withBody(secondResponse))); + + wireMockServer.setScenarioState(scenario, Scenario.STARTED); + + String userMessage = "What is the square root of 485906798473894056 in scientific notation?"; + + String answer = assistant.chat(userMessage); + + String expectedResult = "The square root of 485,906,798,473,894,056 in scientific notation is approximately 6.97070153193991E8."; + assertThat(answer).isEqualTo(expectedResult); + + assertThat(wireMockServer.getAllServeEvents()).hasSize(2); + + assertSingleRequestMessage(getRequestAsMap(getRequestBody(wireMockServer.getAllServeEvents().get(1))), + "What is the square root of 485906798473894056 in scientific notation?"); + assertMultipleRequestMessage(getRequestAsMap(getRequestBody(wireMockServer.getAllServeEvents().get(0))), + List.of( + new MessageAssertUtils.MessageContent("user", + "What is the square root of 485906798473894056 in scientific notation?"), + new MessageAssertUtils.MessageContent("assistant", null), + new MessageAssertUtils.MessageContent("function", "6.97070153193991E8"))); + + InstanceHandle auditServiceInstance = Arc.container().instance(SimpleAuditService.class); + 
assertTrue(auditServiceInstance.isAvailable()); + SimpleAuditService auditService = auditServiceInstance.get(); + SimpleAuditService.SimpleAudit audit = auditService.audit; + assertThat(audit).isNotNull(); + assertThat(audit.originalUserMessage).isEqualTo(userMessage); + assertThat(audit.systemMessage).isNull(); + assertThat(audit.aiMessageToolExecution).isEqualTo("squareRoot"); + assertThat(audit.toolResult).isEqualTo("6.97070153193991E8"); + assertThat(audit.result).isEqualTo(expectedResult); + assertThat(audit.failed).isZero(); + assertThat(audit.failed).isZero(); + + } + + private Map getRequestAsMap(byte[] body) throws IOException { + return mapper.readValue(body, MAP_TYPE_REF); + } + + private byte[] getRequestBody(ServeEvent serveEvent) { + LoggedRequest request = serveEvent.getRequest(); + assertThat(request.getBody()).isNotEmpty(); + return request.getBody(); + } + + @Singleton + public static class SimpleAuditService implements AuditService { + + SimpleAudit audit; + + @Override + public Audit create(Audit.CreateInfo createInfo) { + return new SimpleAudit(createInfo); + } + + @Override + public void complete(Audit audit) { + this.audit = (SimpleAudit) audit; + } + + public static class SimpleAudit extends Audit { + + String systemMessage; + String originalUserMessage; + String aiMessageToolExecution; + String toolResult; + String aiMessageResult; + Object result; + + int relevantDocs; + int failed; + + public SimpleAudit(CreateInfo createInfo) { + super(createInfo); + } + + @Override + public void initialMessages(Optional systemMessage, UserMessage userMessage) { + if (systemMessage.isPresent()) { + this.systemMessage = systemMessage.get().text(); + } + originalUserMessage = userMessage.text(); + } + + @Override + public void addRelevantDocument(List segments, UserMessage userMessage) { + relevantDocs++; + } + + @Override + public void addLLMToApplicationMessage(Response response) { + if (response.content().text() != null) { + aiMessageResult = 
response.content().text(); + } else { + aiMessageToolExecution = response.content().toolExecutionRequest().name(); + } + } + + @Override + public void addApplicationToLLMMessage(ToolExecutionResultMessage toolExecutionResultMessage) { + toolResult = toolExecutionResultMessage.text(); + } + + @Override + public void onCompletion(Object result) { + this.result = result; + } + + @Override + public void onFailure(Exception e) { + failed++; + } + } + } +}