Skip to content

Commit

Permalink
Allow rewriting of user messages from input guardrails
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Nov 26, 2024
1 parent 8a29c07 commit 6bb2352
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;
import java.util.function.Supplier;

import jakarta.enterprise.context.RequestScoped;
import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;

public class InputGuardrailRewritingTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class, MessageTruncatingGuardrail.class, EchoChatModel.class,
MyChatModelSupplier.class, MyMemoryProviderSupplier.class));

@Inject
MyAiService aiService;

@Test
@ActivateRequestContext
void testRewriting() {
assertEquals(MessageTruncatingGuardrail.MAX_LENGTH, aiService.test("first prompt", "second prompt").length());
}

@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

@UserMessage("Given {first} and {second} do something")
@InputGuardrails(MessageTruncatingGuardrail.class)
String test(String first, String second);

}

@RequestScoped
public static class MessageTruncatingGuardrail implements InputGuardrail {

static final int MAX_LENGTH = 20;

@Override
public InputGuardrailResult validate(dev.langchain4j.data.message.UserMessage um) {
String text = um.singleText();
return successWith(text.substring(0, MAX_LENGTH));
}
}

public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {

@Override
public ChatLanguageModel get() {
return new EchoChatModel();
}
}

public static class EchoChatModel implements ChatLanguageModel {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return new Response<>(new AiMessage(((dev.langchain4j.data.message.UserMessage) messages.get(0)).singleText()));
}
}

public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return new ChatMemoryProvider() {
@Override
public ChatMemory get(Object memoryId) {
return new NoopChatMemory();
}
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ enum Result {
FATAL
}

boolean isSuccess();
Result getResult();

default boolean isSuccess() {
return getResult() == Result.SUCCESS || getResult() == Result.SUCCESS_WITH_RESULT;
}

default boolean hasRewrittenResult() {
return false;
return getResult() == Result.SUCCESS_WITH_RESULT;
}

default GuardrailResult<GR> blockRetry() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ default InputGuardrailResult success() {
return InputGuardrailResult.success();
}

/**
* @return The result of a successful input guardrail validation with a specific text.
* @param successfulText The text of the successful result.
*/
default InputGuardrailResult successWith(String successfulText) {
return InputGuardrailResult.successWith(successfulText);
}

/**
* @param message A message describing the failure.
* @return The result of a failed input guardrail validation.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.List;
import java.util.Map;

import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
Expand All @@ -21,6 +25,14 @@ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,

@Override
public InputGuardrailParams withText(String text) {
throw new UnsupportedOperationException();
return new InputGuardrailParams(rewriteUserMessage(userMessage, text), memory, augmentationResult, userMessageTemplate,
variables);
}

public static UserMessage rewriteUserMessage(UserMessage userMessage, String text) {
List<Content> rewrittenContent = userMessage.contents().stream()
.map(c -> c.type() == ContentType.TEXT ? new TextContent(text) : c).toList();
return userMessage.name() == null ? new UserMessage(rewrittenContent)
: new UserMessage(userMessage.name(), rewrittenContent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,38 @@
* @param result The result of the input guardrail validation.
* @param failures The list of failures, empty if the validation succeeded.
*/
public record InputGuardrailResult(Result result, List<Failure> failures) implements GuardrailResult<InputGuardrailResult> {
public record InputGuardrailResult(Result result, String successfulText,
List<Failure> failures) implements GuardrailResult<InputGuardrailResult> {

private static final InputGuardrailResult SUCCESS = new InputGuardrailResult();

private InputGuardrailResult() {
this(Result.SUCCESS, Collections.emptyList());
this(Result.SUCCESS, null, Collections.emptyList());
}

private InputGuardrailResult(String successfulText) {
this(Result.SUCCESS_WITH_RESULT, successfulText, Collections.emptyList());
}

InputGuardrailResult(List<Failure> failures, boolean fatal) {
this(fatal ? Result.FATAL : Result.FAILURE, failures);
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
}

public static InputGuardrailResult success() {
return InputGuardrailResult.SUCCESS;
}

public static InputGuardrailResult successWith(String successfulText) {
return new InputGuardrailResult(successfulText);
}

public static InputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
return new InputGuardrailResult((List<Failure>) failures, false);
}

@Override
public boolean isSuccess() {
return result == Result.SUCCESS;
public Result getResult() {
return result;
}

@Override
Expand All @@ -54,7 +63,7 @@ public InputGuardrailResult validatedBy(Class<? extends Guardrail> guardrailClas
@Override
public String toString() {
if (isSuccess()) {
return "success";
return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success";
}
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,8 @@ public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failu
}

@Override
public boolean isSuccess() {
return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT;
}

@Override
public boolean hasRewrittenResult() {
return result == Result.SUCCESS_WITH_RESULT;
public Result getResult() {
return result;
}

public boolean isRetry() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,10 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
ChatMessage augmentedUserMessage = ar.chatMessage();

ChatMemory memory = context.chatMemory(memoryId);
GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage,
UserMessage guardrailsMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo,
(UserMessage) augmentedUserMessage,
memory, ar, templateVariables);
List<ChatMessage> messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed);
List<ChatMessage> messagesToSend = messagesToSend(guardrailsMessage, needsMemorySeed);
var stream = new TokenStreamMulti(messagesToSend, effectiveToolSpecifications,
finalToolExecutors, ar.contents(), context, memoryId,
methodCreateInfo.isSwitchToWorkerThread());
Expand All @@ -223,25 +224,19 @@ public Flow.Publisher<?> apply(AugmentationResult ar) {
templateVariables)));
}

private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
boolean needsMemorySeed) {
List<ChatMessage> messagesToSend;
ChatMemory chatMemory;
if (context.hasChatMemory()) {
chatMemory = context.chatMemory(memoryId);
messagesToSend = createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage,
chatMemory, needsMemorySeed, context, methodCreateInfo);
} else {
messagesToSend = createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage,
needsMemorySeed, context, methodCreateInfo);
}
return messagesToSend;
return context.hasChatMemory()
? createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage,
context.chatMemory(memoryId), needsMemorySeed, context, methodCreateInfo)
: createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage,
needsMemorySeed, context, methodCreateInfo);
}
});
}
}

GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
userMessage = GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage,
context.hasChatMemory() ? context.chatMemory(memoryId) : null,
augmentationResult, templateVariables);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.runtime.aiservice;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static io.quarkiverse.langchain4j.guardrails.InputGuardrailParams.rewriteUserMessage;

import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -32,7 +33,7 @@

public class GuardrailsSupport {

public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
public static UserMessage invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
ChatMemory chatMemory, AugmentationResult augmentationResult, Map<String, Object> templateVariables) {
InputGuardrailResult result;
try {
Expand All @@ -48,6 +49,11 @@ public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateI
if (!result.isSuccess()) {
throw new GuardrailException(result.toString(), result.getFirstFailureException());
}

if (result.hasRewrittenResult()) {
userMessage = rewriteUserMessage(userMessage, result.successfulText());
}
return userMessage;
}

public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateInfo methodCreateInfo,
Expand Down

0 comments on commit 6bb2352

Please sign in to comment.