From 6bb23522faba49f523c833dfcb24a8e534f4e781 Mon Sep 17 00:00:00 2001 From: mariofusco Date: Fri, 15 Nov 2024 16:04:00 +0100 Subject: [PATCH] Allow rewriting of user messages from input guardrails --- .../InputGuardrailRewritingTest.java | 96 +++++++++++++++++++ .../guardrails/GuardrailResult.java | 8 +- .../guardrails/InputGuardrail.java | 8 ++ .../guardrails/InputGuardrailParams.java | 14 ++- .../guardrails/InputGuardrailResult.java | 21 ++-- .../guardrails/OutputGuardrailResult.java | 9 +- .../AiServiceMethodImplementationSupport.java | 25 ++--- .../runtime/aiservice/GuardrailsSupport.java | 8 +- 8 files changed, 157 insertions(+), 32 deletions(-) create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java new file mode 100644 index 000000000..19d4189f1 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailRewritingTest.java @@ -0,0 +1,96 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +public class InputGuardrailRewritingTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, MessageTruncatingGuardrail.class, EchoChatModel.class, + MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + + @Inject + MyAiService aiService; + + @Test + @ActivateRequestContext + void testRewriting() { + assertEquals(MessageTruncatingGuardrail.MAX_LENGTH, aiService.test("first prompt", "second prompt").length()); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @UserMessage("Given {first} and {second} do something") + @InputGuardrails(MessageTruncatingGuardrail.class) + String test(String first, String second); + + } + + @RequestScoped + public static class MessageTruncatingGuardrail implements InputGuardrail { + + static final int MAX_LENGTH = 20; + + @Override + public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) { + String text = um.singleText(); + return successWith(text.substring(0, MAX_LENGTH)); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatLanguageModel get() { + return new EchoChatModel(); + } + } + + public static class EchoChatModel implements ChatLanguageModel { + + @Override + public Response generate(List messages) { + return new Response<>(new AiMessage(((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText())); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return new ChatMemoryProvider() { + @Override + public ChatMemory get(Object memoryId) { + return new NoopChatMemory(); + } + }; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java index 965860f31..b37601ab8 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java @@ -29,10 +29,14 @@ enum Result { FATAL } - boolean isSuccess(); + Result getResult(); + + default boolean isSuccess() { + return getResult() == Result.SUCCESS || getResult() == Result.SUCCESS_WITH_RESULT; + } default boolean hasRewrittenResult() { - return false; + return getResult() == Result.SUCCESS_WITH_RESULT; } default GuardrailResult blockRetry() { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java index 46e32649f..6c15d4f08 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java @@ -47,6 +47,14 @@ default InputGuardrailResult success() { return InputGuardrailResult.success(); } + /** + * @return The result of a successful input guardrail validation with a specific text. + * @param successfulText The text of the successful result. + */ + default InputGuardrailResult successWith(String successfulText) { + return InputGuardrailResult.successWith(successfulText); + } + /** * @param message A message describing the failure. * @return The result of a failed input guardrail validation. diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java index 62bdcbfca..f1e4d5c60 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java @@ -1,7 +1,11 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.List; import java.util.Map; +import dev.langchain4j.data.message.Content; +import dev.langchain4j.data.message.ContentType; +import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -21,6 +25,14 @@ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, @Override public InputGuardrailParams withText(String text) { - throw new UnsupportedOperationException(); + return new InputGuardrailParams(rewriteUserMessage(userMessage, text), memory, augmentationResult, userMessageTemplate, + variables); + } + + public static UserMessage rewriteUserMessage(UserMessage userMessage, String text) { + List rewrittenContent = userMessage.contents().stream() + .map(c -> c.type() == ContentType.TEXT ? new TextContent(text) : c).toList(); + return userMessage.name() == null ? new UserMessage(rewrittenContent) + : new UserMessage(userMessage.name(), rewrittenContent); } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java index 8c56b7953..ea00a0dce 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailResult.java @@ -10,29 +10,38 @@ * @param result The result of the input guardrail validation. * @param failures The list of failures, empty if the validation succeeded. */ -public record InputGuardrailResult(Result result, List failures) implements GuardrailResult { +public record InputGuardrailResult(Result result, String successfulText, + List failures) implements GuardrailResult { private static final InputGuardrailResult SUCCESS = new InputGuardrailResult(); private InputGuardrailResult() { - this(Result.SUCCESS, Collections.emptyList()); + this(Result.SUCCESS, null, Collections.emptyList()); + } + + private InputGuardrailResult(String successfulText) { + this(Result.SUCCESS_WITH_RESULT, successfulText, Collections.emptyList()); } InputGuardrailResult(List failures, boolean fatal) { - this(fatal ? Result.FATAL : Result.FAILURE, failures); + this(fatal ? Result.FATAL : Result.FAILURE, null, failures); } public static InputGuardrailResult success() { return InputGuardrailResult.SUCCESS; } + public static InputGuardrailResult successWith(String successfulText) { + return new InputGuardrailResult(successfulText); + } + public static InputGuardrailResult failure(List failures) { return new InputGuardrailResult((List) failures, false); } @Override - public boolean isSuccess() { - return result == Result.SUCCESS; + public Result getResult() { + return result; } @Override @@ -54,7 +63,7 @@ public InputGuardrailResult validatedBy(Class guardrailClas @Override public String toString() { if (isSuccess()) { - return "success"; + return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success"; } return failures.stream().map(Failure::toString).collect(Collectors.joining(", ")); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java index 28d85dc63..f3dabc2d0 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java @@ -48,13 +48,8 @@ public static OutputGuardrailResult failure(List apply(AugmentationResult ar) { ChatMessage augmentedUserMessage = ar.chatMessage(); ChatMemory memory = context.chatMemory(memoryId); - GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage, + UserMessage guardrailsMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, + (UserMessage) augmentedUserMessage, memory, ar, templateVariables); - List messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed); + List messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed); var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, finalToolExecutors, ar.contents(), context, memoryId, methodCreateInfo.isSwitchToWorkerThread()); @@ -223,25 +224,19 @@ public Flow.Publisher apply(AugmentationResult ar) { templateVariables))); } - private List messagesToSend(ChatMessage augmentedUserMessage, + private List messagesToSend(UserMessage augmentedUserMessage, boolean needsMemorySeed) { - List messagesToSend; - ChatMemory chatMemory; - if (context.hasChatMemory()) { - chatMemory = context.chatMemory(memoryId); - messagesToSend = createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage, - chatMemory, needsMemorySeed, context, methodCreateInfo); - } else { - messagesToSend = createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage, - needsMemorySeed, context, methodCreateInfo); - } - return messagesToSend; + return context.hasChatMemory() + ? createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage, + context.chatMemory(memoryId), needsMemorySeed, context, methodCreateInfo) + : createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage, + needsMemorySeed, context, methodCreateInfo); } }); } } - GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, + userMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, context.hasChatMemory() ? context.chatMemory(memoryId) : null, augmentationResult, templateVariables); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index 0fc62a060..22434a881 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.runtime.aiservice; import static dev.langchain4j.data.message.UserMessage.userMessage; +import static io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.rewriteUserMessage; import java.util.ArrayList; import java.util.Collections; @@ -32,7 +33,7 @@ public class GuardrailsSupport { - public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, + public static UserMessage invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, ChatMemory chatMemory, AugmentationResult augmentationResult, Map templateVariables) { InputGuardrailResult result; try { @@ -48,6 +49,11 @@ public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateI if (!result.isSuccess()) { throw new GuardrailException(result.toString(), result.getFirstFailureException()); } + + if (result.hasRewrittenResult()) { + userMessage = rewriteUserMessage(userMessage, result.successfulText()); + } + return userMessage; } public static Response invokeOutputGuardrails(AiServiceMethodCreateInfo methodCreateInfo,