Skip to content

Commit

Permalink
Merge pull request #124 from quarkiverse/moderation
Browse files Browse the repository at this point in the history
Allow users to enable moderation for @RegisterAiService
  • Loading branch information
geoand authored Dec 11, 2023
2 parents ef365fe + cb19983 commit 2202073
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,13 @@ public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
@BuildStep
public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
BuildProducer<RequestChatModelBeanBuildItem> requestChatModelBeanProducer,
BuildProducer<RequestModerationModelBeanBuildItem> requestModerationModelBeanProducer,
BuildProducer<DeclarativeAiServiceBuildItem> declarativeAiServiceProducer,
BuildProducer<ReflectiveClassBuildItem> reflectiveClassProducer) {
IndexView index = indexBuildItem.getIndex();

boolean needChatModelBean = false;
boolean needModerationModelBean = false;
for (AnnotationInstance instance : index.getAnnotations(Langchain4jDotNames.REGISTER_AI_SERVICES)) {
if (instance.target().kind() != AnnotationTarget.Kind.CLASS) {
continue; // should never happen
Expand Down Expand Up @@ -208,11 +210,24 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
}
}

DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceClassSupplierValue != null) {
auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer);
DotName auditServiceSupplierClassName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER;
AnnotationValue auditServiceSupplierValue = instance.value("auditServiceSupplier");
if (auditServiceSupplierValue != null) {
auditServiceSupplierClassName = auditServiceSupplierValue.asClass().name();
validateSupplierAndRegisterForReflection(auditServiceSupplierClassName, index, reflectiveClassProducer);
}

DotName moderationModelSupplierClassName = null;
AnnotationValue moderationModelSupplierValue = instance.value("moderationModelSupplier");
if (moderationModelSupplierValue != null) {
moderationModelSupplierClassName = moderationModelSupplierValue.asClass().name();
if (Langchain4jDotNames.NO_MODERATION_MODEL_SUPPLIER.equals(moderationModelSupplierClassName)) {
moderationModelSupplierClassName = null;
} else if (Langchain4jDotNames.BEAN_MODERATION_MODEL_SUPPLIER.equals(moderationModelSupplierClassName)) {
needModerationModelBean = true;
} else {
validateSupplierAndRegisterForReflection(moderationModelSupplierClassName, index, reflectiveClassProducer);
}
}

BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo);
Expand All @@ -225,13 +240,17 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
toolDotNames,
chatMemoryProviderSupplierClassDotName,
retrieverSupplierClassDotName,
auditServiceClassSupplierName,
auditServiceSupplierClassName,
moderationModelSupplierClassName,
cdiScope));
}

if (needChatModelBean) {
requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem());
}
if (needModerationModelBean) {
requestModerationModelBeanProducer.produce(new RequestModerationModelBeanBuildItem());
}
}

private void validateSupplierAndRegisterForReflection(DotName supplierDotName, IndexView index,
Expand Down Expand Up @@ -262,6 +281,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
boolean needsChatMemoryProviderBean = false;
boolean needsRetrieverBean = false;
boolean needsAuditServiceBean = false;
boolean needsModerationModelBean = false;
Set<DotName> allToolNames = new HashSet<>();

for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) {
Expand All @@ -286,13 +306,18 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
? bi.getAuditServiceClassSupplierDotName().toString()
: null;

String moderationModelSupplierClassName = (bi.getModerationModelSupplierDotName() != null
? bi.getModerationModelSupplierDotName().toString()
: null);

SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem
.configure(declarativeAiServiceClassInfo.name())
.createWith(recorder.createDeclarativeAiService(
new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName,
toolClassNames, chatMemoryProviderSupplierClassName,
retrieverSupplierClassName,
auditServiceClassSupplierName)))
auditServiceClassSupplierName,
moderationModelSupplierClassName)))
.destroyer(DeclarativeAiServiceBeanDestroyer.class)
.setRuntimeInit()
.scope(bi.getCdiScope());
Expand Down Expand Up @@ -333,6 +358,11 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
needsAuditServiceBean = true;
}

if (Langchain4jDotNames.BEAN_MODERATION_MODEL_SUPPLIER.toString().equals(moderationModelSupplierClassName)) {
configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.MODERATION_MODEL));
needsModerationModelBean = true;
}

syntheticBeanProducer.produce(configurator.done());
}

Expand All @@ -348,6 +378,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
if (needsAuditServiceBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE));
}
if (needsModerationModelBean) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.MODERATION_MODEL));
}
if (!allToolNames.isEmpty()) {
unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
List<EmbeddingModelProviderCandidateBuildItem> embeddingCandidateItems,
List<ModerationModelProviderCandidateBuildItem> moderationCandidateItems,
List<RequestChatModelBeanBuildItem> requestChatModelBeanItems,
List<RequestModerationModelBeanBuildItem> requestModerationModelBeanBuildItems,
LangChain4jBuildConfig buildConfig,
BuildProducer<SelectedChatModelProviderBuildItem> selectedChatProducer,
BuildProducer<SelectedEmbeddingModelCandidateBuildItem> selectedEmbeddingProducer,
Expand All @@ -74,6 +75,9 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished
if (!requestChatModelBeanItems.isEmpty()) {
chatModelBeanRequested = true;
}
if (!requestModerationModelBeanBuildItems.isEmpty()) {
moderationModelBeanRequested = true;
}

if (chatModelBeanRequested || streamingChatModelBeanRequested) {
selectedChatProducer.produce(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final DotName chatMemoryProviderSupplierClassDotName;
private final DotName retrieverSupplierClassDotName;
private final DotName auditServiceClassSupplierDotName;
private final DotName moderationModelSupplierDotName;
private final ScopeInfo cdiScope;

public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName,
List<DotName> toolDotNames,
DotName chatMemoryProviderSupplierClassDotName,
DotName retrieverSupplierClassDotName,
DotName auditServiceClassSupplierDotName,
DotName moderationModelSupplierDotName,
ScopeInfo cdiScope) {
this.serviceClassInfo = serviceClassInfo;
this.languageModelSupplierClassDotName = languageModelSupplierClassDotName;
this.toolDotNames = toolDotNames;
this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName;
this.retrieverSupplierClassDotName = retrieverSupplierClassDotName;
this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName;
this.moderationModelSupplierDotName = moderationModelSupplierDotName;
this.cdiScope = cdiScope;
}

Expand Down Expand Up @@ -61,6 +64,10 @@ public DotName getAuditServiceClassSupplierDotName() {
return auditServiceClassSupplierDotName;
}

public DotName getModerationModelSupplierDotName() {
return moderationModelSupplierDotName;
}

public ScopeInfo getCdiScope() {
return cdiScope;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,10 @@ public class Langchain4jDotNames {
static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanIfExistsAuditServiceSupplier.class);

static final DotName BEAN_MODERATION_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.BeanModerationModelSupplier.class);

static final DotName NO_MODERATION_MODEL_SUPPLIER = DotName.createSimple(
RegisterAiService.NoModerationModelSupplier.class);

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

/**
* Allows extension to request the creation of a {@link dev.langchain4j.model.chat.ChatLanguageModel} bean
* even if no injection point
* even if no injection point exists.
*/
public final class RequestChatModelBeanBuildItem extends MultiBuildItem {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.quarkiverse.langchain4j.deployment;

import io.quarkus.builder.item.MultiBuildItem;

/**
* Allows extension to request the creation of a {@link dev.langchain4j.model.moderation.ModerationModel} bean
* even if no injection point exists.
*/
public final class RequestModerationModelBeanBuildItem extends MultiBuildItem {
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
Expand Down Expand Up @@ -73,7 +74,7 @@

/**
* Configures the way to obtain the {@link Retriever} to use (when using RAG).
* By default, no chat memory is used.
* By default, no supplier is used.
* If a CDI bean of type {@link ChatMemory} is needed, the value should be {@link BeanRetrieverSupplier}.
* If an arbitrary {@link ChatMemory} instance is needed, a custom implementation of {@link Supplier<ChatMemory>}
* needs to be provided.
Expand All @@ -89,6 +90,15 @@
*/
Class<? extends Supplier<AuditService>> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class;

/**
* Configures the way to obtain the {@link ModerationModel} to use.
* By default, no moderation model is used.
* If a CDI bean of type {@link ChatMemory} is needed, the value should be {@link BeanRetrieverSupplier}.
* If an arbitrary {@link ChatMemory} instance is needed, a custom implementation of {@link Supplier<ChatMemory>}
* needs to be provided.
*/
Class<? extends Supplier<ModerationModel>> moderationModelSupplier() default NoModerationModelSupplier.class;

/**
* Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
Expand Down Expand Up @@ -162,4 +172,28 @@ public AuditService get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker that is used to tell Quarkus to use the {@link ModerationModel} that has been configured as a CDI bean by
* any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and
* {@code quarkus-langchain4j-azure-openai}).
*/
final class BeanModerationModelSupplier implements Supplier<ModerationModel> {

@Override
public ModerationModel get() {
throw new UnsupportedOperationException("should never be called");
}
}

/**
* Marker class to indicate that no moderation model should be used
*/
final class NoModerationModelSupplier implements Supplier<ModerationModel> {

@Override
public ModerationModel get() {
throw new UnsupportedOperationException("should never be called");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.audit.AuditService;
Expand All @@ -28,9 +29,6 @@

@Recorder
public class AiServicesRecorder {

private static final TypeLiteral<Instance<ChatMemoryProvider>> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {
};
private static final TypeLiteral<Instance<Retriever<TextSegment>>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() {

};
Expand Down Expand Up @@ -151,6 +149,21 @@ public T apply(SyntheticCreationalContext<T> creationalContext) {
}
}

if (info.getModerationModelSupplierClassName() != null) {
if (RegisterAiService.BeanModerationModelSupplier.class.getName()
.equals(info.getModerationModelSupplierClassName())) {
ModerationModel moderationModel = creationalContext.getInjectedReference(ModerationModel.class);
quarkusAiServices.moderationModel(moderationModel);
} else {
@SuppressWarnings("rawtypes")
Supplier<? extends AuditService> supplier = (Supplier<? extends AuditService>) Thread
.currentThread().getContextClassLoader()
.loadClass(info.getModerationModelSupplierClassName())
.getConstructor().newInstance();
quarkusAiServices.auditService(supplier.get());
}
}

return (T) quarkusAiServices.build();
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@ public class DeclarativeAiServiceCreateInfo {
private final String retrieverSupplierClassName;

private final String auditServiceClassSupplierName;
private final String moderationModelSupplierClassName;

@RecordableConstructor
public DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName,
List<String> toolsClassNames, String chatMemoryProviderSupplierClassName,
String retrieverSupplierClassName,
String auditServiceClassSupplierName) {
String auditServiceClassSupplierName,
String moderationModelSupplierClassName) {
this.serviceClassName = serviceClassName;
this.languageModelSupplierClassName = languageModelSupplierClassName;
this.toolsClassNames = toolsClassNames;
this.chatMemoryProviderSupplierClassName = chatMemoryProviderSupplierClassName;
this.retrieverSupplierClassName = retrieverSupplierClassName;
this.auditServiceClassSupplierName = auditServiceClassSupplierName;
this.moderationModelSupplierClassName = moderationModelSupplierClassName;
}

public String getServiceClassName() {
Expand All @@ -50,4 +53,8 @@ public String getRetrieverSupplierClassName() {
public String getAuditServiceClassSupplierName() {
return auditServiceClassSupplierName;
}

public String getModerationModelSupplierClassName() {
return moderationModelSupplierClassName;
}
}
28 changes: 28 additions & 0 deletions docs/modules/ROOT/pages/ai-services.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,34 @@ A document retriever fetches data from an external source and provides it to the

This guidance aims to cover all crucial aspects of designing AI services with Quarkus, ensuring robust and efficient interactions with LLMs.

== Moderation

By default, @RegisterAiService annotated interfaces don't moderate content. However, users can opt in to having the LLM moderate
content by annotating the method with `@Moderate`.

For moderation to work, the following criteria need to be met:

* A CDI bean for `dev.langchain4j.model.moderation.ModerationModel` must be configured (the `quarkus-langchain4j-openai` and `quarkus-langchain4j-azure-openai` provide one out of the box)
* The interface must be configured with `@RegisterAiService(moderationModelSupplier = RegisterAiService.BeanModerationModelSupplier.class)`

=== Advanced usage
An alternative to providing a CDI bean is to configure the interface with `@RegisterAiService(moderationModelSupplier = MyCustomSupplier.class)`
and implement `MyCustomModerationSupplier` like so:

[source,java]
----
import dev.langchain4j.model.moderation.ModerationModel;
public class MyCustomModerationSupplier implements Supplier<ModerationModel> {
@Override
public ModerationModel get(){
// TODO: implement
}
}
----

== Observability

Observability is built into services created via `@RegisterAiService` and is provided in the following form:
Expand Down
Loading

0 comments on commit 2202073

Please sign in to comment.