Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sketch out API for auditing #83

Merged
merged 2 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions .github/workflows/build-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ jobs:
build:
name: Build on ${{ matrix.os }} - ${{ matrix.java }}
strategy:
# PineconeEmbeddingStoreTest uses a single shared index, we can't run multiple CI runs on it at once
# If we have PINECONE_API_KEY available, then the test will run, so set max-parallel to 1
max-parallel: ${{ github.secret_source == 'Actions' && 1 || 16 }}
fail-fast: false
matrix:
os: [ubuntu-latest]
Expand All @@ -46,16 +43,6 @@ jobs:

- name: Build with Maven
run: mvn -B clean install -Dno-format
env: # note that secrets are not available when triggered by PR from a fork, so some tests will be skipped
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT }}
PINECONE_INDEX_NAME: ${{ secrets.PINECONE_INDEX_NAME }}
PINECONE_PROJECT_ID: ${{ secrets.PINECONE_PROJECT_ID }}

- name: Build with Maven (Native)
run: mvn -B install -Dnative -Dquarkus.native.container-build -Dnative.surefire.skip
env: # note that secrets are not available when triggered by PR from a fork, so some tests will be skipped
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
PINECONE_ENVIRONMENT: ${{ secrets.PINECONE_ENVIRONMENT }}
PINECONE_INDEX_NAME: ${{ secrets.PINECONE_INDEX_NAME }}
PINECONE_PROJECT_ID: ${{ secrets.PINECONE_PROJECT_ID }}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.V;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
Expand All @@ -53,6 +52,7 @@
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.DeclarativeAiServiceCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.MetricsWrapper;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.aiservice.SpanWrapper;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
Expand Down Expand Up @@ -101,6 +101,7 @@ public class AiServicesProcessor {
private static final MethodDescriptor SUPPORT_IMPLEMENT = MethodDescriptor.ofMethod(
AiServiceMethodImplementationSupport.class,
"implement", Object.class, AiServiceMethodImplementationSupport.Input.class);
public static final DotName CDI_INSTANCE = DotName.createSimple(Instance.class);

@BuildStep
public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
Expand Down Expand Up @@ -203,13 +204,21 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceClassSupplierValue != null) {
auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
chatLanguageModelSupplierClassDotName,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName));
retrieverSupplierClassDotName,
auditServiceClassSupplierName));
}

if (needChatModelBean) {
Expand Down Expand Up @@ -244,6 +253,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsChatModelBean = false;
boolean needsChatMemoryProviderBean = false;
boolean needsRetrieverBean = false;
boolean needsAuditServiceBean = false;
Set<DotName> allToolNames = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
Expand All @@ -264,12 +274,17 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getRetrieverSupplierClassDotName().toString()
: null;

String auditServiceClassSupplierName = bi.getAuditServiceClassSupplierDotName() != null
? bi.getAuditServiceClassSupplierDotName().toString()
: null;

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName)))
retrieverSupplierClassName,
auditServiceClassSupplierName)))
.setRuntimeInit()
.scope(ApplicationScoped.class);
if ((chatLanguageModelSupplierClassName == null) && selectedChatModelProvider.isPresent()) { // TODO: is second condition needed?
Expand All @@ -290,7 +305,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsChatMemoryProviderBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_CHAT_MEMORY_PROVIDER_SUPPLIER.toString()
.equals(chatMemoryProviderSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.CHAT_MEMORY_PROVIDER) }, null));
needsChatMemoryProviderBean = true;
}
Expand All @@ -301,13 +316,19 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsRetrieverBean = true;
} else if (Langchain4jDotNames.BEAN_IF_EXISTS_RETRIEVER_SUPPLIER.toString()
.equals(retrieverSupplierClassName)) {
configurator.addInjectionPoint(ParameterizedType.create(DotName.createSimple(Instance.class),
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ParameterizedType.create(Langchain4jDotNames.RETRIEVER,
new Type[] { ClassType.create(Langchain4jDotNames.TEXT_SEGMENT) }, null) },
null));
needsRetrieverBean = true;
}

if (Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER.toString().equals(auditServiceClassSupplierName)) {
configurator.addInjectionPoint(ParameterizedType.create(CDI_INSTANCE,
new Type[] { ClassType.create(Langchain4jDotNames.AUDIT_SERVICE) }, null));
needsAuditServiceBean = true;
}

syntheticBeanProducer.produce(configurator.done());
}

Expand All @@ -320,6 +341,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsRetrieverBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.RETRIEVER));
}
if (needsAuditServiceBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down Expand Up @@ -436,19 +460,19 @@ public void handleAiServices(AiServicesRecorder recorder,
.interfaces(iface.name().toString())
.build()) {

FieldDescriptor contextField = classCreator.getFieldCreator("context", AiServiceContext.class)
FieldDescriptor contextField = classCreator.getFieldCreator("context", QuarkusAiServiceContext.class)
.setModifiers(Modifier.PRIVATE | Modifier.FINAL)
.getFieldDescriptor();

for (MethodInfo methodInfo : methodsToImplement) {
// The implementation essentially gets method the context and delegates to
// The implementation essentially gets the context and delegates to
// MethodImplementationSupport#implement

String methodId = createMethodId(methodInfo);
perMethodMetadata.put(methodId,
gatherMethodMetadata(methodInfo, addMicrometerMetrics, addOpenTelemetrySpan));
MethodCreator constructor = classCreator.getMethodCreator(MethodDescriptor.INIT, "V",
AiServiceContext.class);
QuarkusAiServiceContext.class);
constructor.invokeSpecialMethod(OBJECT_CONSTRUCTOR, constructor.getThis());
constructor.writeInstanceField(contextField, constructor.getThis(), constructor.getMethodParam(0));
constructor.returnValue(null);
Expand All @@ -466,7 +490,7 @@ public void handleAiServices(AiServicesRecorder recorder,
ResultHandle supportHandle = getFromCDI(mc, AiServiceMethodImplementationSupport.class.getName());
ResultHandle inputHandle = mc.newInstance(
MethodDescriptor.ofConstructor(AiServiceMethodImplementationSupport.Input.class,
AiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
QuarkusAiServiceContext.class, AiServiceMethodCreateInfo.class, Object[].class),
contextHandle, methodCreateInfoHandle, paramsHandle);

ResultHandle resultHandle = mc.invokeVirtualMethod(SUPPORT_IMPLEMENT, supportHandle, inputHandle);
Expand Down Expand Up @@ -547,7 +571,8 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(MethodInfo method, boolea
Optional<AiServiceMethodCreateInfo.MetricsInfo> metricsInfo = gatherMetricsInfo(method, addMicrometerMetrics);
Optional<AiServiceMethodCreateInfo.SpanInfo> spanInfo = gatherSpanInfo(method, addOpenTelemetrySpans);

return new AiServiceMethodCreateInfo(systemMessageInfo, userMessageInfo, memoryIdParamPosition, requiresModeration,
return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnType, metricsInfo, spanInfo);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {

private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName) {
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
}

public ClassInfo getServiceClassInfo() {
Expand All @@ -49,4 +52,8 @@ public DotName getChatMemoryProviderSupplierClassDotName() {
public DotName getRetrieverSupplierClassDotName() {
return retrieverSupplierClassDotName;
}

public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dev.langchain4j.service.UserName;
import io.quarkiverse.langchain4j.CreatedAware;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;

public class Langchain4jDotNames {
public static final DotName CHAT_MODEL = DotName.createSimple(ChatLanguageModel.class);
Expand Down Expand Up @@ -53,10 +54,15 @@ public class Langchain4jDotNames {
static final DotName RETRIEVER = DotName.createSimple(Retriever.class);
static final DotName TEXT_SEGMENT = DotName.createSimple(TextSegment.class);

static final DotName AUDIT_SERVICE = DotName.createSimple(AuditService.class);

static final DotName BEAN_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_RETRIEVER_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsRetrieverSupplier.class);

static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.spi.services.AiServicesFactory;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.AiServicesRecorder;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceClassCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
Expand All @@ -27,7 +29,12 @@ public class QuarkusAiServicesFactory implements AiServicesFactory {

@Override
public <T> QuarkusAiServices<T> create(AiServiceContext context) {
return new QuarkusAiServices<>(context);
if (context instanceof QuarkusAiServiceContext) {
return new QuarkusAiServices<>(context);
} else {
// the context is always empty (except for the aiServiceClass) anyway and never escapes, so we can just use our own type
return new QuarkusAiServices<>(new QuarkusAiServiceContext(context.aiServiceClass));
}
}

public static class InstanceHolder {
Expand Down Expand Up @@ -70,6 +77,11 @@ public AiServices<T> tools(List<Object> objectsWithTools) {
return this;
}

public AiServices<T> auditService(AuditService auditService) {
((QuarkusAiServiceContext) context).auditService = auditService;
return this;
}

List<ToolMethodCreateInfo> lookup(Object bean, String className) {
Map<String, List<ToolMethodCreateInfo>> metadata = ToolsRecorder.getMetadata();
// Fast path first.
Expand Down Expand Up @@ -116,8 +128,8 @@ public T build() {

try {
return (T) Class.forName(classCreateInfo.getImplClassName(), true, Thread.currentThread()
.getContextClassLoader()).getConstructor(AiServiceContext.class)
.newInstance(context);
.getContextClassLoader()).getConstructor(QuarkusAiServiceContext.class)
.newInstance(((QuarkusAiServiceContext) context));
} catch (Exception e) {
throw new IllegalStateException("Unable to create class '" + classCreateInfo.getImplClassName(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import io.quarkiverse.langchain4j.audit.AuditService;

/**
* Used to create Langchain4j's {@link AiServices} in a declarative manner that the application can then use simply by
Expand Down Expand Up @@ -74,6 +75,15 @@
*/
Class<? extends Supplier<Retriever<TextSegment>>> retrieverSupplier() default NoRetrieverSupplier.class;

/**
* Configures the way to obtain the {@link AuditService} to use.
* By default, Quarkus will look for a CDI bean that implements {@link AuditService}, but will fall back to not using
* any memory if no such bean exists.
* If an arbitrary {@link AuditService} instance is needed, a custom implementation of
* {@link Supplier<AuditService>} needs to be provided.
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
Expand Down Expand Up @@ -155,4 +165,16 @@ public Retriever<TextSegment> get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link AuditService} that the user has configured as a CDI bean.
* If no such bean exists, then no audit service will be used.
*/
final class BeanIfExistsAuditServiceSupplier implements Supplier<AuditService> {

@Override
public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}
}
Loading
Loading