Skip to content

Commit

Permalink
Sketch out API for auditing
Browse files Browse the repository at this point in the history
Relates to: #8
  • Loading branch information
geoand committed Dec 6, 2023
1 parent 374abca commit 16d9da4
Show file tree
Hide file tree
Showing 16 changed files with 633 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand All @@ -53,6 +52,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
Expand Down Expand Up @@ -101,6 +101,7 @@ public class AiServicesProcessor {
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
AiServiceMethodImplementationSupport.class,
"implement", Object.class, AiServiceMethodImplementationSupport.Input.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
Expand Down Expand Up @@ -203,13 +204,21 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceClassSupplierValue != null) {
auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName));
retrieverSupplierClassDotName,
auditServiceClassSupplierName));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -244,6 +253,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsChatModelBean = false;
boolean needsChatMemoryProviderBean = false;
boolean needsRetrieverBean = false;
boolean needsAuditServiceBean = false;
Set<DotName> allToolNames = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
Expand All @@ -264,12 +274,17 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getRetrieverSupplierClassDotName().toString()
: null;

String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null
? bi.getAuditServiceClassSupplierDotName().toString()
: null;

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName)))
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.setRuntimeInit()
.scope(ApplicationScoped.class);
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
Expand All @@ -290,7 +305,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}
Expand All @@ -301,13 +316,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
needsRetrieverBean = true;
}

if (Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) {
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.AUDIT_SERVICE) }, null));
needsAuditServiceBean = true;
}

syntheticBeanProducer.produce(configurator.done());
}

Expand All @@ -320,6 +341,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsRetrieverBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.RETRIEVER));
}
if (needsAuditServiceBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down Expand Up @@ -436,19 +460,19 @@ public void handleAiServices(AiServicesRecorder recorder,
.interfaces(iface.name().toString())
.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", AiServiceContext.class)
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
.getFieldDescriptor();

for (MethodInfo methodInfo : methodsToImplement) {
// The implementation essentially gets method the context and delegates to
// The implementation essentially gets the context and delegates to
// MethodImplementationSupport#implement

String methodId = createMethodId(methodInfo);
perMethodMetadata.put(methodId,
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
AiServiceContext.class);
QuarkusAiServiceContext.class);
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
constructor.returnValue(null);
Expand All @@ -466,7 +490,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
Expand Down Expand Up @@ -547,7 +571,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
Optional<AiServiceMethodCreateInfo.MetricsInfo> metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics);
Optional<AiServiceMethodCreateInfo.SpanInfo> spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans);

return new AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration,
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnType, metricsInfo, spanInfo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName) {
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -49,4 +52,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() {
public DotName getRetrieverSupplierClassDotName() {
return retrieverSupplierClassDotName;
}

public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.CreatedAware;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;

public class Langchain4jDotNames {
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
Expand Down Expand Up @@ -53,10 +54,15 @@ public class Langchain4jDotNames {
static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);

static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class);

static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.spi.services.AiServicesFactory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
Expand All @@ -27,7 +29,12 @@ public class QuarkusAiServicesFactory implements AiServicesFactory {

@Override
public <T> QuarkusAiServices<T> create(AiServiceContext context) {
return new QuarkusAiServices<>(context);
if (context instanceof QuarkusAiServiceContext) {
return new QuarkusAiServices<>(context);
} else {
// the context is always empty (except for the aiServiceClass) anyway and never escapes, so we can just use our own type
return new QuarkusAiServices<>(new QuarkusAiServiceContext(context.aiServiceClass));
}
}

public static class InstanceHolder {
Expand Down Expand Up @@ -70,6 +77,11 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
return this;
}

public AiServices<T> auditService(AuditService auditService) {
((QuarkusAiServiceContext) context).auditService = auditService;
return this;
}

List<ToolMethodCreateInfo> lookup(Object bean, String className) {
Map<String, List<ToolMethodCreateInfo>> metadata = ToolsRecorder.getMetadata();
// Fast path first.
Expand Down Expand Up @@ -116,8 +128,8 @@ public T build() {

try {
return (T) Class.forName(classCreateInfo.getImplClassName(), true, Thread.currentThread()
.getContextClassLoader()).getConstructor(AiServiceContext.class)
.newInstance(context);
.getContextClassLoader()).getConstructor(QuarkusAiServiceContext.class)
.newInstance(((QuarkusAiServiceContext) context));
} catch (Exception e) {
throw new IllegalStateException("Unable to create class '" + classCreateInfo.getImplClassName(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
Expand Down Expand Up @@ -74,6 +75,15 @@
*/
Class<? extends Supplier<Retriever<TextSegment>>> retrieverSupplier() default NoRetrieverSupplier.class;

/**
* Configures the way to obtain the {@link AuditService} to use.
* By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link AuditService} instance is needed, a custom implementation of
* {@link Supplier<AuditService>} needs to be provided.
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
Expand Down Expand Up @@ -155,4 +165,16 @@ public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean.
* If no such bean exists, then no audit service will be used.
*/
final class BeanIfExistsAuditServiceSupplier implements Supplier<AuditService> {

@Override
public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package io.quarkiverse.langchain4j.audit;

import java.util.List;
import java.util.Optional;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;

/**
* Abstract class to be implemented in order to keep track of whatever information is useful for the application auditing.
*/
public abstract class Audit {

/**
* Information about the AiService that is being audited
*/
public record CreateInfo(String interfaceName, String methodName, Object[] parameters,
Optional<Integer> memoryIDParamPosition) {

}

private final CreateInfo createInfo;
private Optional<SystemMessage> systemMessage;
private UserMessage userMessage;

/**
* Invoked by {@link AuditService} when an AiService is invoked
*/
public Audit(CreateInfo createInfo) {
this.createInfo = createInfo;
}

/**
* @return information about the AiService that is being audited
*/
public CreateInfo getCreateInfo() {
return createInfo;
}

/**
* Invoked when the original user and system messages have been created
*/
public void initialMessages(Optional<SystemMessage> systemMessage, UserMessage userMessage) {

}

/**
* Invoked if a relevant document was added to the messages to be sent to the LLM
*/
public void addRelevantDocument(List<TextSegment> segments, UserMessage userMessage) {

}

/**
* Invoked with a response from an LLM. It is important to note that this can be invoked multiple times
* when tools exist.
*/
public void addLLMToApplicationMessage(Response<AiMessage> response) {

}

/**
* Invoked with a response from an LLM. It is important to note that this can be invoked multiple times
* when tools exist.
*/
public void addApplicationToLLMMessage(ToolExecutionResultMessage toolExecutionResultMessage) {

}

/**
* Invoked when the final result of the AiService method has been computed
*/
public void onCompletion(Object result) {

}

/**
* Invoked when there was an exception computing the result of the AiService method
*/
public void onFailure(Exception e) {

}
}
Loading

0 comments on commit 16d9da4

Please sign in to comment.