From 68b4a1fbd2f5868360d10878db441e65372370f0 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Mon, 11 Dec 2023 11:42:33 +0200 Subject: [PATCH] Allow users to enable moderation for @RegisterAiService --- .../deployment/AiServicesProcessor.java | 47 ++++- .../deployment/BeansProcessor.java | 4 + .../DeclarativeAiServiceBuildItem.java | 7 + .../deployment/Langchain4jDotNames.java | 6 + .../RequestChatModelBeanBuildItem.java | 2 +- .../RequestModerationModelBeanBuildItem.java | 10 + .../langchain4j/RegisterAiService.java | 36 +++- .../runtime/AiServicesRecorder.java | 19 +- .../DeclarativeAiServiceCreateInfo.java | 9 +- docs/modules/ROOT/pages/ai-services.adoc | 9 + .../aiservices/ModerationModelTest.java | 176 ++++++++++++++++++ 11 files changed, 312 insertions(+), 13 deletions(-) create mode 100644 core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java create mode 100644 openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/ModerationModelTest.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index b6ea54884..afcb08147 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -153,11 +153,13 @@ public void nativeSupport(CombinedIndexBuildItem indexBuildItem, @BuildStep public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, BuildProducer requestChatModelBeanProducer, + BuildProducer requestModerationModelBeanProducer, BuildProducer declarativeAiServiceProducer, BuildProducer reflectiveClassProducer) { IndexView index = indexBuildItem.getIndex(); boolean needChatModelBean = false; + boolean needModerationModelBean = false; for (AnnotationInstance instance : index.getAnnotations(Langchain4jDotNames.REGISTER_AI_SERVICES)) { if (instance.target().kind() != AnnotationTarget.Kind.CLASS) { continue; // should never happen @@ -208,11 +210,24 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, } } - DotName auditServiceClassSupplierName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER; - AnnotationValue auditServiceClassSupplierValue = instance.value("auditServiceSupplier"); - if (auditServiceClassSupplierValue != null) { - auditServiceClassSupplierName = auditServiceClassSupplierValue.asClass().name(); - validateSupplierAndRegisterForReflection(auditServiceClassSupplierName, index, reflectiveClassProducer); + DotName auditServiceSupplierClassName = Langchain4jDotNames.BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER; + AnnotationValue auditServiceSupplierValue = instance.value("auditServiceSupplier"); + if (auditServiceSupplierValue != null) { + auditServiceSupplierClassName = auditServiceSupplierValue.asClass().name(); + validateSupplierAndRegisterForReflection(auditServiceSupplierClassName, index, reflectiveClassProducer); + } + + DotName moderationModelSupplierClassName = null; + AnnotationValue moderationModelSupplierValue = instance.value("moderationModelSupplier"); + if (moderationModelSupplierValue != null) { + moderationModelSupplierClassName = moderationModelSupplierValue.asClass().name(); + if (Langchain4jDotNames.NO_MODERATION_MODEL_SUPPLIER.equals(moderationModelSupplierClassName)) { + moderationModelSupplierClassName = null; + } else if (Langchain4jDotNames.BEAN_MODERATION_MODEL_SUPPLIER.equals(moderationModelSupplierClassName)) { + needModerationModelBean = true; + } else { + validateSupplierAndRegisterForReflection(moderationModelSupplierClassName, index, reflectiveClassProducer); + } } BuiltinScope declaredScope = BuiltinScope.from(declarativeAiServiceClassInfo); @@ -225,13 +240,17 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem, toolDotNames, chatMemoryProviderSupplierClassDotName, retrieverSupplierClassDotName, - auditServiceClassSupplierName, + auditServiceSupplierClassName, + moderationModelSupplierClassName, cdiScope)); } if (needChatModelBean) { requestChatModelBeanProducer.produce(new RequestChatModelBeanBuildItem()); } + if (needModerationModelBean) { + requestModerationModelBeanProducer.produce(new RequestModerationModelBeanBuildItem()); + } } private void validateSupplierAndRegisterForReflection(DotName supplierDotName, IndexView index, @@ -262,6 +281,7 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, boolean needsChatMemoryProviderBean = false; boolean needsRetrieverBean = false; boolean needsAuditServiceBean = false; + boolean needsModerationModelBean = false; Set allToolNames = new HashSet<>(); for (DeclarativeAiServiceBuildItem bi : declarativeAiServiceItems) { @@ -286,13 +306,18 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, ? bi.getAuditServiceClassSupplierDotName().toString() : null; + String moderationModelSupplierClassName = (bi.getModerationModelSupplierDotName() != null + ? bi.getModerationModelSupplierDotName().toString() + : null); + SyntheticBeanBuildItem.ExtendedBeanConfigurator configurator = SyntheticBeanBuildItem .configure(declarativeAiServiceClassInfo.name()) .createWith(recorder.createDeclarativeAiService( new DeclarativeAiServiceCreateInfo(serviceClassName, chatLanguageModelSupplierClassName, toolClassNames, chatMemoryProviderSupplierClassName, retrieverSupplierClassName, - auditServiceClassSupplierName))) + auditServiceClassSupplierName, + moderationModelSupplierClassName))) .destroyer(DeclarativeAiServiceBeanDestroyer.class) .setRuntimeInit() .scope(bi.getCdiScope()); @@ -333,6 +358,11 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, needsAuditServiceBean = true; } + if (Langchain4jDotNames.BEAN_MODERATION_MODEL_SUPPLIER.toString().equals(moderationModelSupplierClassName)) { + configurator.addInjectionPoint(ClassType.create(Langchain4jDotNames.MODERATION_MODEL)); + needsModerationModelBean = true; + } + syntheticBeanProducer.produce(configurator.done()); } @@ -348,6 +378,9 @@ public void handleDeclarativeServices(AiServicesRecorder recorder, if (needsAuditServiceBean) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.AUDIT_SERVICE)); } + if (needsModerationModelBean) { + unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(Langchain4jDotNames.MODERATION_MODEL)); + } if (!allToolNames.isEmpty()) { unremoveableProducer.produce(UnremovableBeanBuildItem.beanTypes(allToolNames)); } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java index f76887830..e4a7f9329 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/BeansProcessor.java @@ -49,6 +49,7 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished List embeddingCandidateItems, List moderationCandidateItems, List requestChatModelBeanItems, + List requestModerationModelBeanBuildItems, LangChain4jBuildConfig buildConfig, BuildProducer selectedChatProducer, BuildProducer selectedEmbeddingProducer, @@ -74,6 +75,9 @@ public void handleProviders(BeanDiscoveryFinishedBuildItem beanDiscoveryFinished if (!requestChatModelBeanItems.isEmpty()) { chatModelBeanRequested = true; } + if (!requestModerationModelBeanBuildItems.isEmpty()) { + moderationModelBeanRequested = true; + } if (chatModelBeanRequested || streamingChatModelBeanRequested) { selectedChatProducer.produce( diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java index 97194e993..cd87db3c7 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/DeclarativeAiServiceBuildItem.java @@ -20,6 +20,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem { private final DotName chatMemoryProviderSupplierClassDotName; private final DotName retrieverSupplierClassDotName; private final DotName auditServiceClassSupplierDotName; + private final DotName moderationModelSupplierDotName; private final ScopeInfo cdiScope; public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languageModelSupplierClassDotName, @@ -27,6 +28,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag DotName chatMemoryProviderSupplierClassDotName, DotName retrieverSupplierClassDotName, DotName auditServiceClassSupplierDotName, + DotName moderationModelSupplierDotName, ScopeInfo cdiScope) { this.serviceClassInfo = serviceClassInfo; this.languageModelSupplierClassDotName = languageModelSupplierClassDotName; @@ -34,6 +36,7 @@ public DeclarativeAiServiceBuildItem(ClassInfo serviceClassInfo, DotName languag this.chatMemoryProviderSupplierClassDotName = chatMemoryProviderSupplierClassDotName; this.retrieverSupplierClassDotName = retrieverSupplierClassDotName; this.auditServiceClassSupplierDotName = auditServiceClassSupplierDotName; + this.moderationModelSupplierDotName = moderationModelSupplierDotName; this.cdiScope = cdiScope; } @@ -61,6 +64,10 @@ public DotName getAuditServiceClassSupplierDotName() { return auditServiceClassSupplierDotName; } + public DotName getModerationModelSupplierDotName() { + return moderationModelSupplierDotName; + } + public ScopeInfo getCdiScope() { return cdiScope; } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java index a00aee0c0..1fdf792c5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/Langchain4jDotNames.java @@ -61,4 +61,10 @@ public class Langchain4jDotNames { static final DotName BEAN_IF_EXISTS_AUDIT_SERVICE_SUPPLIER = DotName.createSimple( RegisterAiService.BeanIfExistsAuditServiceSupplier.class); + static final DotName BEAN_MODERATION_MODEL_SUPPLIER = DotName.createSimple( + RegisterAiService.BeanModerationModelSupplier.class); + + static final DotName NO_MODERATION_MODEL_SUPPLIER = DotName.createSimple( + RegisterAiService.NoModerationModelSupplier.class); + } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java index 838aa90c7..3a3ddbad0 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestChatModelBeanBuildItem.java @@ -4,7 +4,7 @@ /** * Allows extension to request the creation of a {@link dev.langchain4j.model.chat.ChatLanguageModel} bean - * even if no injection point + * even if no injection point exists. */ public final class RequestChatModelBeanBuildItem extends MultiBuildItem { } diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java new file mode 100644 index 000000000..da94cb365 --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/RequestModerationModelBeanBuildItem.java @@ -0,0 +1,10 @@ +package io.quarkiverse.langchain4j.deployment; + +import io.quarkus.builder.item.MultiBuildItem; + +/** + * Allows extension to request the creation of a {@link dev.langchain4j.model.moderation.ModerationModel} bean + * even if no injection point exists. + */ +public final class RequestModerationModelBeanBuildItem extends MultiBuildItem { +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java index ef916d670..d8a33cc97 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/RegisterAiService.java @@ -12,6 +12,7 @@ import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.retriever.Retriever; import dev.langchain4j.service.AiServices; import dev.langchain4j.store.memory.chat.ChatMemoryStore; @@ -73,7 +74,7 @@ /** * Configures the way to obtain the {@link Retriever} to use (when using RAG). - * By default, no chat memory is used. + * By default, no supplier is used. * If a CDI bean of type {@link ChatMemory} is needed, the value should be {@link BeanRetrieverSupplier}. * If an arbitrary {@link ChatMemory} instance is needed, a custom implementation of {@link Supplier} * needs to be provided. @@ -89,6 +90,15 @@ */ Class> auditServiceSupplier() default BeanIfExistsAuditServiceSupplier.class; + /** + * Configures the way to obtain the {@link ModerationModel} to use. + * By default, no moderation model is used. + * If a CDI bean of type {@link ChatMemory} is needed, the value should be {@link BeanRetrieverSupplier}. + * If an arbitrary {@link ChatMemory} instance is needed, a custom implementation of {@link Supplier} + * needs to be provided. + */ + Class> moderationModelSupplier() default NoModerationModelSupplier.class; + /** * Marker that is used to tell Quarkus to use the {@link ChatLanguageModel} that has been configured as a CDI bean by * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and @@ -162,4 +172,28 @@ public AuditService get() { throw new UnsupportedOperationException("should never be called"); } } + + /** + * Marker that is used to tell Quarkus to use the {@link ModerationModel} that has been configured as a CDI bean by + * any of the extensions providing such capability (such as {@code quarkus-langchain4j-openai} and + * {@code quarkus-langchain4j-azure-openai}). + */ + final class BeanModerationModelSupplier implements Supplier { + + @Override + public ModerationModel get() { + throw new UnsupportedOperationException("should never be called"); + } + } + + /** + * Marker class to indicate that no moderation model should be used + */ + final class NoModerationModelSupplier implements Supplier { + + @Override + public ModerationModel get() { + throw new UnsupportedOperationException("should never be called"); + } + } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java index 59d17f148..d7bc6bd42 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/AiServicesRecorder.java @@ -16,6 +16,7 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; import dev.langchain4j.retriever.Retriever; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.audit.AuditService; @@ -28,9 +29,6 @@ @Recorder public class AiServicesRecorder { - - private static final TypeLiteral> CHAT_MEMORY_PROVIDER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() { - }; private static final TypeLiteral>> RETRIEVER_INSTANCE_TYPE_LITERAL = new TypeLiteral<>() { }; @@ -151,6 +149,21 @@ public T apply(SyntheticCreationalContext creationalContext) { } } + if (info.getModerationModelSupplierClassName() != null) { + if (RegisterAiService.BeanModerationModelSupplier.class.getName() + .equals(info.getModerationModelSupplierClassName())) { + ModerationModel moderationModel = creationalContext.getInjectedReference(ModerationModel.class); + quarkusAiServices.moderationModel(moderationModel); + } else { + @SuppressWarnings("rawtypes") + Supplier supplier = (Supplier) Thread + .currentThread().getContextClassLoader() + .loadClass(info.getModerationModelSupplierClassName()) + .getConstructor().newInstance(); + quarkusAiServices.auditService(supplier.get()); + } + } + return (T) quarkusAiServices.build(); } catch (ClassNotFoundException e) { throw new IllegalStateException(e); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java index c50046f19..83fda3d0b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/DeclarativeAiServiceCreateInfo.java @@ -13,18 +13,21 @@ public class DeclarativeAiServiceCreateInfo { private final String retrieverSupplierClassName; private final String auditServiceClassSupplierName; + private final String moderationModelSupplierClassName; @RecordableConstructor public DeclarativeAiServiceCreateInfo(String serviceClassName, String languageModelSupplierClassName, List toolsClassNames, String chatMemoryProviderSupplierClassName, String retrieverSupplierClassName, - String auditServiceClassSupplierName) { + String auditServiceClassSupplierName, + String moderationModelSupplierClassName) { this.serviceClassName = serviceClassName; this.languageModelSupplierClassName = languageModelSupplierClassName; this.toolsClassNames = toolsClassNames; this.chatMemoryProviderSupplierClassName = chatMemoryProviderSupplierClassName; this.retrieverSupplierClassName = retrieverSupplierClassName; this.auditServiceClassSupplierName = auditServiceClassSupplierName; + this.moderationModelSupplierClassName = moderationModelSupplierClassName; } public String getServiceClassName() { @@ -50,4 +53,8 @@ public String getRetrieverSupplierClassName() { public String getAuditServiceClassSupplierName() { return auditServiceClassSupplierName; } + + public String getModerationModelSupplierClassName() { + return moderationModelSupplierClassName; + } } diff --git a/docs/modules/ROOT/pages/ai-services.adoc b/docs/modules/ROOT/pages/ai-services.adoc index 849877d88..723f667f2 100644 --- a/docs/modules/ROOT/pages/ai-services.adoc +++ b/docs/modules/ROOT/pages/ai-services.adoc @@ -275,6 +275,15 @@ A document retriever fetches data from an external source and provides it to the This guidance aims to cover all crucial aspects of designing AI services with Quarkus, ensuring robust and efficient interactions with LLMs. +== Moderation + +By default, @RegisterAiService annotated interfaces don't moderate content. However, users can opt-in to having the LLM moderate +content by annotating the method with `@Moderate`. + +For this to work, a `dev.langchain4j.model.moderation.ModerationModel` must be configured, +which the extension does automatically if `@RegisterAiService(moderationModelSupplier = RegisterAiService.BeanModerationModelSupplier.class)` +is used. + == Observability Observability is built into services created via `@RegisterAiService` and is provided in the following form: diff --git a/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/ModerationModelTest.java b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/ModerationModelTest.java new file mode 100644 index 000000000..b2b5f1e86 --- /dev/null +++ b/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/ModerationModelTest.java @@ -0,0 +1,176 @@ +package org.acme.examples.aiservices; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.options; +import static io.quarkiverse.langchain4j.openai.test.WiremockUtils.DEFAULT_TOKEN; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.Map; + +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.tomakehurst.wiremock.WireMockServer; + +import dev.langchain4j.service.Moderate; +import dev.langchain4j.service.ModerationException; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.openai.test.WiremockUtils; +import io.quarkus.test.QuarkusUnitTest; + +public class ModerationModelTest { + + private static final int WIREMOCK_PORT = 8089; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer( + () -> ShrinkWrap.create(JavaArchive.class).addClasses(WiremockUtils.class, MessageAssertUtils.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever") + .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url", "http://localhost:" + WIREMOCK_PORT + "/v1"); + private static final TypeReference> MAP_TYPE_REF = new TypeReference<>() { + }; + + static WireMockServer wireMockServer; + + static ObjectMapper mapper; + + @BeforeAll + static void beforeAll() { + wireMockServer = new WireMockServer(options().port(WIREMOCK_PORT)); + wireMockServer.start(); + + mapper = new ObjectMapper(); + } + + @AfterAll + static void afterAll() { + wireMockServer.stop(); + } + + @BeforeEach + void setup() { + wireMockServer.resetAll(); + wireMockServer.stubFor(WiremockUtils.defaultChatCompletionsStub()); + } + + @RegisterAiService(moderationModelSupplier = RegisterAiService.BeanModerationModelSupplier.class) + interface ChatWithModeration { + + @Moderate + String chat(String message); + } + + @Inject + ChatWithModeration chatWithModeration; + + @Test + @ActivateRequestContext + void should_throw_when_text_is_flagged() { + wireMockServer.stubFor(WiremockUtils.moderationMapping(DEFAULT_TOKEN) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "id": "modr-8Bmx2bYNsgzuAsSuxaQRDCMKHgJbC", + "model": "text-moderation-006", + "results": [ + { + "flagged": true, + "categories": { + "sexual": false, + "hate": true, + "harassment": false, + "self-harm": false, + "sexual/minors": false, + "hate/threatening": true, + "violence/graphic": false, + "self-harm/intent": false, + "self-harm/instructions": false, + "harassment/threatening": false, + "violence": false + }, + "category_scores": { + "sexual": 0.0001485530665377155, + "hate": 0.00004570276360027492, + "harassment": 0.00006113418203312904, + "self-harm": 5.4490744361146426e-8, + "sexual/minors": 6.557503979820467e-7, + "hate/threatening": 7.536454127432535e-9, + "violence/graphic": 2.776141343474592e-7, + "self-harm/intent": 9.653235544249128e-9, + "self-harm/instructions": 1.2119762970996817e-9, + "harassment/threatening": 5.06949959344638e-7, + "violence": 0.000026839805286726914 + } + } + ] + } + """))); + + assertThatThrownBy(() -> chatWithModeration.chat("I WILL KILL YOU!!!")) + .isExactlyInstanceOf(ModerationException.class) + .hasMessage("Text \"" + "I WILL KILL YOU!!!" + "\" violates content policy"); + } + + @Test + @ActivateRequestContext + void should_not_throw_when_text_is_not_flagged() { + wireMockServer.stubFor(WiremockUtils.moderationMapping(DEFAULT_TOKEN) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "id": "modr-8Bmx2bYNsgzuAsSuxaQRDCMKHgJbC", + "model": "text-moderation-006", + "results": [ + { + "flagged": false, + "categories": { + "sexual": false, + "hate": true, + "harassment": false, + "self-harm": false, + "sexual/minors": false, + "hate/threatening": false, + "violence/graphic": false, + "self-harm/intent": false, + "self-harm/instructions": false, + "harassment/threatening": false, + "violence": false + }, + "category_scores": { + "sexual": 0.0001485530665377155, + "hate": 0.00004570276360027492, + "harassment": 0.00006113418203312904, + "self-harm": 5.4490744361146426e-8, + "sexual/minors": 6.557503979820467e-7, + "hate/threatening": 7.536454127432535e-9, + "violence/graphic": 2.776141343474592e-7, + "self-harm/intent": 9.653235544249128e-9, + "self-harm/instructions": 1.2119762970996817e-9, + "harassment/threatening": 5.06949959344638e-7, + "violence": 0.000026839805286726914 + } + } + ] + } + """))); + + String result = chatWithModeration.chat("I will hug you"); + assertThat(result).isNotBlank(); + } +}