Skip to content

Commit

Permalink
Sketch out API for auditing
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed Dec 4, 2023
1 parent 1ba5659 commit 596ced6
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand All @@ -53,6 +52,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
Expand Down Expand Up @@ -266,12 +266,15 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getRetrieverSupplierClassDotName().toString()
: null;

String auditServiceClassSupplierName = null; // TODO: implement

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName)))
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.setRuntimeInit()
.scope(ApplicationScoped.class);
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
Expand Down Expand Up @@ -438,19 +441,19 @@ public void handleAiServices(AiServicesRecorder recorder,
.interfaces(iface.name().toString())
.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", AiServiceContext.class)
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
.getFieldDescriptor();

for (MethodInfo methodInfo : methodsToImplement) {
// The implementation essentially gets method the context and delegates to
// The implementation essentially gets the context and delegates to
// MethodImplementationSupport#implement

String methodId = createMethodId(methodInfo);
perMethodMetadata.put(methodId,
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
AiServiceContext.class);
QuarkusAiServiceContext.class);
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
constructor.returnValue(null);
Expand All @@ -468,7 +471,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
Expand Down Expand Up @@ -549,7 +552,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
Optional<AiServiceMethodCreateInfo.MetricsInfo> metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics);
Optional<AiServiceMethodCreateInfo.SpanInfo> spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans);

return new AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration,
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnType, metricsInfo, spanInfo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.spi.services.AiServicesFactory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
Expand All @@ -27,7 +29,12 @@ public class QuarkusAiServicesFactory implements AiServicesFactory {

@Override
public <T> QuarkusAiServices<T> create(AiServiceContext context) {
return new QuarkusAiServices<>(context);
if (context instanceof QuarkusAiServiceContext) {
return new QuarkusAiServices<>(context);
} else {
// the context is always empty (except for the aiServiceClass) anyway and never escapes, so we can just use our own type
return new QuarkusAiServices<>(new QuarkusAiServiceContext(context.aiServiceClass));
}
}

public static class InstanceHolder {
Expand Down Expand Up @@ -70,6 +77,11 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
return this;
}

public AiServices<T> auditService(AuditService auditService) {
((QuarkusAiServiceContext) context).auditService = auditService;
return this;
}

List<ToolMethodCreateInfo> lookup(Object bean, String className) {
Map<String, List<ToolMethodCreateInfo>> metadata = ToolsRecorder.getMetadata();
// Fast path first.
Expand Down Expand Up @@ -116,8 +128,8 @@ public T build() {

try {
return (T) Class.forName(classCreateInfo.getImplClassName(), true, Thread.currentThread()
.getContextClassLoader()).getConstructor(AiServiceContext.class)
.newInstance(context);
.getContextClassLoader()).getConstructor(QuarkusAiServiceContext.class)
.newInstance(((QuarkusAiServiceContext) context));
} catch (Exception e) {
throw new IllegalStateException("Unable to create class '" + classCreateInfo.getImplClassName(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
Expand Down Expand Up @@ -69,6 +70,11 @@
*/
Class<? extends Supplier<Retriever<TextSegment>>> retrieverSupplier() default NoRetrieverSupplier.class;

/**
* TODO: write
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
Expand Down Expand Up @@ -150,4 +156,15 @@ public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* TODO: write
*/
final class BeanIfExistsAuditServiceSupplier implements Supplier<AuditService> {

@Override
public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.quarkiverse.langchain4j.audit;

import java.util.List;
import java.util.Optional;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.output.Response;

public abstract class Audit {

public record CreateInfo(String interfaceName, String methodName, Object[] parameters,
Optional<Integer> memoryIDParamPosition) {

}

private final CreateInfo createInfo;

public Audit(CreateInfo createInfo) {
this.createInfo = createInfo;
}

public void initialMessages(Optional<SystemMessage> systemMessage, UserMessage userMessage) {

}

public void addRelevantDocument(List<TextSegment> segments, UserMessage userMessage) {

}

public void addLLMToApplicationMessage(Response<AiMessage> response) {

}

public void addApplicationToLLMMessage(ToolExecutionResultMessage toolExecutionResultMessage) {

}

public void onCompletion(Object result) {

}

public void onFailure(Exception e) {

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.quarkiverse.langchain4j.audit;

public interface AuditService {

Audit create(Audit.CreateInfo createInfo);

void complete(Audit audit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServiceContext;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkus.arc.SyntheticCreationalContext;
import io.quarkus.runtime.annotations.Recorder;

Expand All @@ -31,6 +32,9 @@ public class AiServicesRecorder {
private static final TypeLiteral<Instance<ChatMemoryProvider>> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};
private static final TypeLiteral<Instance<Retriever<TextSegment>>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {

};
private static final TypeLiteral<Instance<AuditService>> AUDIT_SERVICE_TYPE_LITERAL = new TypeLiteral<>() {
};

// the key is the interface's class name
Expand Down Expand Up @@ -72,7 +76,7 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
Class<?> serviceClass = Thread.currentThread().getContextClassLoader()
.loadClass(info.getServiceClassName());

AiServiceContext aiServiceContext = new AiServiceContext(serviceClass);
QuarkusAiServiceContext aiServiceContext = new QuarkusAiServiceContext(serviceClass);
var quarkusAiServices = INSTANCE.create(aiServiceContext);

if (info.getLanguageModelSupplierClassName() != null) {
Expand Down Expand Up @@ -137,6 +141,23 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

if (info.getAuditServiceClassSupplierName() != null) {
if (RegisterAiService.BeanIfExistsAuditServiceSupplier.class.getName()
.equals(info.getAuditServiceClassSupplierName())) {
Instance<AuditService> instance = creationalContext
.getInjectedReference(AUDIT_SERVICE_TYPE_LITERAL);
if (instance.isResolvable()) {
quarkusAiServices.auditService(instance.get());
}
} else {
@SuppressWarnings("rawtypes")
Supplier<? extends AuditService> supplier = (Supplier<? extends AuditService>) Thread
.currentThread().getContextClassLoader().loadClass(info.getAuditServiceClassSupplierName())
.getConstructor().newInstance();
quarkusAiServices.auditService(supplier.get());
}
}

return (T) quarkusAiServices.build();
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AiServiceMethodCreateInfo {

private final String interfaceName;
private final String methodName;

private final Optional<TemplateInfo> systemMessageInfo;
private final UserMessageInfo userMessageInfo;
private final Optional<Integer> memoryIdParamPosition;
Expand All @@ -20,11 +23,14 @@ public class AiServiceMethodCreateInfo {
private final Optional<SpanInfo> spanInfo;

@RecordableConstructor
public AiServiceMethodCreateInfo(Optional<TemplateInfo> systemMessageInfo, UserMessageInfo userMessageInfo,
public AiServiceMethodCreateInfo(String interfaceName, String methodName,
Optional<TemplateInfo> systemMessageInfo, UserMessageInfo userMessageInfo,
Optional<Integer> memoryIdParamPosition,
boolean requiresModeration, Class<?> returnType,
Optional<MetricsInfo> metricsInfo,
Optional<SpanInfo> spanInfo) {
this.interfaceName = interfaceName;
this.methodName = methodName;
this.systemMessageInfo = systemMessageInfo;
this.userMessageInfo = userMessageInfo;
this.memoryIdParamPosition = memoryIdParamPosition;
Expand All @@ -34,6 +40,14 @@ public AiServiceMethodCreateInfo(Optional<TemplateInfo> systemMessageInfo, UserM
this.spanInfo = spanInfo;
}

public String getInterfaceName() {
return interfaceName;
}

public String getMethodName() {
return methodName;
}

public Optional<TemplateInfo> getSystemMessageInfo() {
return systemMessageInfo;
}
Expand Down
Loading

0 comments on commit 596ced6

Please sign in to comment.