From 180d25e92ae86d748058568f74bc5e61efd31803 Mon Sep 17 00:00:00 2001 From: Dennys Fredericci Date: Sat, 19 Oct 2024 20:37:42 +0200 Subject: [PATCH 1/3] 946 - Add prompt template and variables to input guardrails --- core/deployment/pom.xml | 6 +- .../InputGuardrailPromptTemplateTest.java | 238 ++++++++++++++++++ .../guardrails/InputGuardrail.java | 1 + .../AiServiceMethodImplementationSupport.java | 37 ++- .../runtime/aiservice/GuardrailsSupport.java | 13 +- 5 files changed, 279 insertions(+), 16 deletions(-) create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java diff --git a/core/deployment/pom.xml b/core/deployment/pom.xml index e12725d97..3153acca2 100644 --- a/core/deployment/pom.xml +++ b/core/deployment/pom.xml @@ -63,7 +63,11 @@ quarkus-junit5-internal test - + + io.quarkus + quarkus-websockets-next + test + org.assertj assertj-core diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java new file mode 100644 index 000000000..d07bc5cc5 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java @@ -0,0 +1,238 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.InputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +public class InputGuardrailPromptTemplateTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + @Inject + MyAiService aiService; + + @Inject + GuardrailValidation guardrailValidation; + + @Test + @ActivateRequestContext + void shouldWorkNoParameters() { + aiService.getJoke(); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke"); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryId() { + aiService.getAnotherJoke("memory-id-001"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke"); + assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of( + "memoryId", "memory-id-001", + "it", "memory-id-001" // is this correct? + )); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndOneParameter() { + aiService.sayHiToMyFriendNoMemory("Rambo"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Rambo", + "it", "Rambo")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneParameter() { + aiService.sayHiToMyFriend("1", "Chuck Norris"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Chuck Norris", + "mem", "1")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndThreeParameters() { + aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"); + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topic1", "Chuck Norris", + "topic2", "Jean-Claude Van Damme", + "topic3", "Silvester Stallone")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndList() { + aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + assertThat(guardrailValidation.spyUserMessageText()) + .isEqualTo("Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]!"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"))); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndList() { + aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + assertThat(guardrailValidation.spyUserMessageText()).isEqualTo( + "Tell me something about [Chuck Norris, Jean-Claude Van Damme, Silvester Stallone]! This is my memory id: memory-id-007"); + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneItemFromList() { + aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + assertThat(guardrailValidation.spyUserMessageText()) + .isEqualTo("Tell me something about Chuck Norris! This is my memory id: memory-id-007"); + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoUserMessage() { + // This is a special case where the UserMessage annotation is not present + // The prompt template doesn't exist in this case + // But the current implementation use the parameter name as prompt template + // Not sure if this is the correct behavior, should we always have @UserMessage? + // I need some thoughts on this case + aiService.saySomething("Is this a parameter or a prompt?"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isNull(); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @InputGuardrails(GuardrailValidation.class) + @UserMessage("Tell me a joke") + String getJoke(); + + @UserMessage("Tell me another joke") + @InputGuardrails(GuardrailValidation.class) + String getAnotherJoke(@MemoryId String memoryId); + + @UserMessage("Say hi to my friend {friend}!") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriendNoMemory(String friend); + + @UserMessage("Say hi to my friend {friend}!") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String mem, String friend); + + @UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(String topic1, String topic2, String topic3); + + @UserMessage("Tell me something about {topics}!") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(List topics); + + @UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(@MemoryId String memoryId, List topics); + + @UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}") + @InputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String memoryId, List topics); + + @InputGuardrails(GuardrailValidation.class) + String saySomething(String isThisAPromptOrAParameter); + + } + + @RequestScoped + public static class GuardrailValidation implements InputGuardrail { + + InputGuardrailParams params; + + public InputGuardrailResult validate(InputGuardrailParams params) { + this.params = params; + return success(); + } + + public String spyUserMessageTemplate() { + return params.userMessageTemplate(); + } + + public String spyUserMessageText() { + return params.userMessage().singleText(); + } + + public Map spyVariables() { + return params.variables(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatLanguageModel { + + @Override + public Response generate(List messages) { + return new Response<>(new AiMessage("Hi!")); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return memoryId -> new NoopChatMemory(); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java index 46e32649f..2b079c67d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java @@ -1,6 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; import java.util.Arrays; +import java.util.Map; import dev.langchain4j.data.message.UserMessage; import io.smallrye.common.annotation.Experimental; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 8ad728780..526a82ec9 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -146,10 +146,12 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob Optional systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs, context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList()); UserMessage userMessage = prepareUserMessage(context, methodCreateInfo, methodArgs); + Map templateParams = getTemplateParams(methodArgs, methodCreateInfo.getUserMessageInfo()); Type returnType = methodCreateInfo.getReturnType(); if (isImage(returnType) || isResultImage(returnType)) { - return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType); + return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType, + templateParams); } if (audit != null) { @@ -203,8 +205,9 @@ public AugmentationResult get() { @Override public Flow.Publisher apply(AugmentationResult ar) { ChatMessage augmentedUserMessage = ar.chatMessage(); + GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage, - context.chatMemory(memoryId), ar); + context.chatMemory(memoryId), ar, templateParams); List messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed); return new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, finalToolExecutors, ar.contents(), context, memoryId); @@ -230,7 +233,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, context.hasChatMemory() ? context.chatMemory(memoryId) : null, - augmentationResult); + augmentationResult, templateParams); CommittableChatMemory chatMemory; List messagesToSend; @@ -379,7 +382,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context, Audit audit, Optional systemMessage, UserMessage userMessage, - Object memoryId, Type returnType) { + Object memoryId, Type returnType, Map templateParams) { String imagePrompt; if (systemMessage.isPresent()) { imagePrompt = systemMessage.get().text() + "\n" + userMessage.singleText(); @@ -397,7 +400,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC // TODO: we can only support input guardrails for now as it is tied to AiMessage GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, context.hasChatMemory() ? context.chatMemory(memoryId) : null, - augmentationResult); + augmentationResult, templateParams); Response imageResponse = context.imageModel.generate(imagePrompt); if (audit != null) { @@ -589,12 +592,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic if (userMessageInfo.template().isPresent()) { AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get(); - Map templateParams = new HashMap<>(); - Map nameToParamPosition = templateInfo.nameToParamPosition(); - for (var entry : nameToParamPosition.entrySet()) { - Object value = transformTemplateParamValue(methodArgs[entry.getValue()]); - templateParams.put(entry.getKey(), value); - } + Map templateParams = getTemplateParams(methodArgs, userMessageInfo); String templateText; if (templateInfo.text().isPresent()) { templateText = templateInfo.text().get(); @@ -642,6 +640,23 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic } } + private static Map getTemplateParams(Object[] methodArgs, + AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo) { + Map templateParams = new HashMap<>(); + + if (userMessageInfo.template().isPresent()) { + AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get(); + Map nameToParamPosition = templateInfo.nameToParamPosition(); + + for (var entry : nameToParamPosition.entrySet()) { + Object value = transformTemplateParamValue(methodArgs[entry.getValue()]); + templateParams.put(entry.getKey(), value); + } + } + + return templateParams; + } + private static UserMessage createUserMessage(String name, ImageContent imageContent, String text) { if (name == null) { if (imageContent == null) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index b1240f8ff..17b72fa69 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -2,8 +2,7 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; -import java.util.ArrayList; -import java.util.List; +import java.util.*; import java.util.function.Function; import jakarta.enterprise.inject.spi.CDI; @@ -30,11 +29,17 @@ public class GuardrailsSupport { public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, - ChatMemory chatMemory, AugmentationResult augmentationResult) { + ChatMemory chatMemory, AugmentationResult augmentationResult, Map templateParams) { InputGuardrailResult result; try { + + Optional userMessageTemplateOpt = methodCreateInfo.getUserMessageInfo().template() + .flatMap(AiServiceMethodCreateInfo.TemplateInfo::text); + + String userMessageTemplate = userMessageTemplateOpt.orElse(null); + result = invokeInputGuardRails(methodCreateInfo, - new InputGuardrailParams(userMessage, chatMemory, augmentationResult)); + new InputGuardrailParams(userMessage, chatMemory, augmentationResult, promptTemplate, templateParams)); } catch (Exception e) { throw new GuardrailException(e.getMessage(), e); } From 5f55595cf830b8bdd1201f4d7334a174b81c5b03 Mon Sep 17 00:00:00 2001 From: Dennys Fredericci Date: Mon, 21 Oct 2024 22:37:38 +0200 Subject: [PATCH 2/3] 946 - Add prompt template and variables to output guardrails --- .../InputGuardrailPromptTemplateTest.java | 10 +- .../OutputGuardrailPromptTemplateTest.java | 227 ++++++++++++++++++ .../guardrails/InputGuardrail.java | 1 - .../guardrails/InputGuardrailParams.java | 7 +- .../guardrails/OutputGuardrailParams.java | 7 +- .../aiservice/AiServiceMethodCreateInfo.java | 7 + .../AiServiceMethodImplementationSupport.java | 34 +-- .../runtime/aiservice/GuardrailsSupport.java | 12 +- 8 files changed, 273 insertions(+), 32 deletions(-) create mode 100644 core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java index d07bc5cc5..92ff09b53 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java @@ -24,6 +24,7 @@ import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; +import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams; import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult; import io.quarkiverse.langchain4j.guardrails.InputGuardrails; import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; @@ -57,8 +58,7 @@ void shouldWorkWithMemoryId() { assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke"); assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of( "memoryId", "memory-id-001", - "it", "memory-id-001" // is this correct? - )); + "it", "memory-id-001")); } @Test @@ -140,11 +140,7 @@ void shouldWorkWithMemoryIdAndOneItemFromList() { @Test @ActivateRequestContext void shouldWorkWithNoUserMessage() { - // This is a special case where the UserMessage annotation is not present - // The prompt template doesn't exist in this case - // But the current implementation use the parameter name as prompt template - // Not sure if this is the correct behavior, should we always have @UserMessage? - // I need some thoughts on this case + // UserMessage annotation is not provided, then no user message template should be available aiService.saySomething("Is this a parameter or a prompt?"); assertThat(guardrailValidation.spyUserMessageTemplate()).isNull(); assertThat(guardrailValidation.spyVariables()).isEmpty(); diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java new file mode 100644 index 000000000..3c170f17b --- /dev/null +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java @@ -0,0 +1,227 @@ +package io.quarkiverse.langchain4j.test.guardrails; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import jakarta.enterprise.context.RequestScoped; +import jakarta.enterprise.context.control.ActivateRequestContext; +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.memory.chat.ChatMemoryProvider; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.UserMessage; +import io.quarkiverse.langchain4j.RegisterAiService; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; +import io.quarkiverse.langchain4j.guardrails.OutputGuardrails; +import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory; +import io.quarkus.test.QuarkusUnitTest; + +public class OutputGuardrailPromptTemplateTest { + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class) + .addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class, + MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class)); + @Inject + MyAiService aiService; + + @Inject + GuardrailValidation guardrailValidation; + + @Test + @ActivateRequestContext + void shouldWorkNoParameters() { + aiService.getJoke(); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke"); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryId() { + aiService.getAnotherJoke("memory-id-001"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke"); + assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of( + "memoryId", "memory-id-001", + "it", "memory-id-001")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndOneParameter() { + aiService.sayHiToMyFriendNoMemory("Rambo"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Rambo", + "it", "Rambo")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneParameter() { + aiService.sayHiToMyFriend("1", "Chuck Norris"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "friend", "Chuck Norris", + "mem", "1")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndThreeParameters() { + aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"); + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topic1", "Chuck Norris", + "topic2", "Jean-Claude Van Damme", + "topic3", "Silvester Stallone")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoMemoryIdAndList() { + aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"))); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndList() { + aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithMemoryIdAndOneItemFromList() { + aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")); + + assertThat(guardrailValidation.spyUserMessageTemplate()) + .isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}"); + assertThat(guardrailValidation.spyVariables()) + .containsExactlyInAnyOrderEntriesOf(Map.of( + "topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"), + "memoryId", "memory-id-007")); + } + + @Test + @ActivateRequestContext + void shouldWorkWithNoUserMessage() { + // UserMessage annotation is not provided, then no user message template should be available + aiService.saySomething("Is this a parameter or a prompt?"); + assertThat(guardrailValidation.spyUserMessageTemplate()).isNull(); + assertThat(guardrailValidation.spyVariables()).isEmpty(); + } + + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) + public interface MyAiService { + + @OutputGuardrails(GuardrailValidation.class) + @UserMessage("Tell me a joke") + String getJoke(); + + @UserMessage("Tell me another joke") + @OutputGuardrails(GuardrailValidation.class) + String getAnotherJoke(@MemoryId String memoryId); + + @UserMessage("Say hi to my friend {friend}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriendNoMemory(String friend); + + @UserMessage("Say hi to my friend {friend}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String mem, String friend); + + @UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(String topic1, String topic2, String topic3); + + @UserMessage("Tell me something about {topics}!") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(List topics); + + @UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriends(@MemoryId String memoryId, List topics); + + @UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}") + @OutputGuardrails(GuardrailValidation.class) + String sayHiToMyFriend(@MemoryId String memoryId, List topics); + + @OutputGuardrails(GuardrailValidation.class) + String saySomething(String isThisAPromptOrAParameter); + + } + + @RequestScoped + public static class GuardrailValidation implements OutputGuardrail { + + OutputGuardrailParams params; + + public OutputGuardrailResult validate(OutputGuardrailParams params) { + this.params = params; + return success(); + } + + public String spyUserMessageTemplate() { + return params.userMessageTemplate(); + } + + public Map spyVariables() { + return params.variables(); + } + } + + public static class MyChatModelSupplier implements Supplier { + + @Override + public ChatLanguageModel get() { + return new MyChatModel(); + } + } + + public static class MyChatModel implements ChatLanguageModel { + + @Override + public Response generate(List messages) { + return new Response<>(new AiMessage("Hi!")); + } + } + + public static class MyMemoryProviderSupplier implements Supplier { + @Override + public ChatMemoryProvider get() { + return memoryId -> new NoopChatMemory(); + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java index 2b079c67d..46e32649f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrail.java @@ -1,7 +1,6 @@ package io.quarkiverse.langchain4j.guardrails; import java.util.Arrays; -import java.util.Map; import dev.langchain4j.data.message.UserMessage; import io.smallrye.common.annotation.Experimental; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java index a3cda8037..b27371263 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.Map; + import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -10,7 +12,10 @@ * @param userMessage the user message, cannot be {@code null} * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} + * @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided. + * @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty */ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, - AugmentationResult augmentationResult) implements GuardrailParams { + AugmentationResult augmentationResult, String userMessageTemplate, + Map variables) implements GuardrailParams { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java index ee0c960cb..3bc39f2cd 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.guardrails; +import java.util.Map; + import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.rag.AugmentationResult; @@ -10,7 +12,10 @@ * @param responseFromLLM the response from the LLM * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} + * @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided. + * @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty */ public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, - AugmentationResult augmentationResult) implements GuardrailParams { + AugmentationResult augmentationResult, String userMessageTemplate, + Map variables) implements GuardrailParams { } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index 3d7545ed0..dc095ec2b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -193,6 +193,13 @@ public OutputTokenAccumulator getOutputTokenAccumulator() { return accumulator; } + public String getUserMessageTemplate() { + Optional userMessageTemplateOpt = this.getUserMessageInfo().template() + .flatMap(AiServiceMethodCreateInfo.TemplateInfo::text); + + return userMessageTemplateOpt.orElse(null); + } + public record UserMessageInfo(Optional template, Optional paramPosition, Optional userNameParamPosition, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 526a82ec9..97e48f177 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -146,12 +146,12 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob Optional systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs, context.hasChatMemory() ? context.chatMemory(memoryId).messages() : Collections.emptyList()); UserMessage userMessage = prepareUserMessage(context, methodCreateInfo, methodArgs); - Map templateParams = getTemplateParams(methodArgs, methodCreateInfo.getUserMessageInfo()); + Map templateVariables = getTemplateVariables(methodArgs, methodCreateInfo.getUserMessageInfo()); Type returnType = methodCreateInfo.getReturnType(); if (isImage(returnType) || isResultImage(returnType)) { return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType, - templateParams); + templateVariables); } if (audit != null) { @@ -207,7 +207,7 @@ public Flow.Publisher apply(AugmentationResult ar) { ChatMessage augmentedUserMessage = ar.chatMessage(); GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, (UserMessage) augmentedUserMessage, - context.chatMemory(memoryId), ar, templateParams); + context.chatMemory(memoryId), ar, templateVariables); List messagesToSend = messagesToSend(augmentedUserMessage, needsMemorySeed); return new TokenStreamMulti(messagesToSend, effectiveToolSpecifications, finalToolExecutors, ar.contents(), context, memoryId); @@ -233,7 +233,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, context.hasChatMemory() ? context.chatMemory(memoryId) : null, - augmentationResult, templateParams); + augmentationResult, templateVariables); CommittableChatMemory chatMemory; List messagesToSend; @@ -272,7 +272,8 @@ private List messagesToSend(ChatMessage augmentedUserMessage, OutputGuardrailResult result; try { result = GuardrailsSupport.invokeOutputGuardrailsForStream(methodCreateInfo, - new OutputGuardrailParams(AiMessage.from(chunk), chatMemory, actualAugmentationResult)); + new OutputGuardrailParams(AiMessage.from(chunk), chatMemory, actualAugmentationResult, + methodCreateInfo.getUserMessageTemplate(), templateVariables)); } catch (Exception e) { throw new GuardrailException(e.getMessage(), e); } @@ -359,9 +360,12 @@ private List messagesToSend(ChatMessage augmentedUserMessage, tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage()); } + String userMessageTemplate = methodCreateInfo.getUserMessageTemplate(); + response = GuardrailsSupport.invokeOutputGuardrails(methodCreateInfo, chatMemory, context.chatModel, response, toolSpecifications, - new OutputGuardrailParams(response.content(), chatMemory, augmentationResult)); + new OutputGuardrailParams(response.content(), chatMemory, augmentationResult, userMessageTemplate, + templateVariables)); // everything worked as expected so let's commit the messages chatMemory.commit(); @@ -382,7 +386,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context, Audit audit, Optional systemMessage, UserMessage userMessage, - Object memoryId, Type returnType, Map templateParams) { + Object memoryId, Type returnType, Map templateVariables) { String imagePrompt; if (systemMessage.isPresent()) { imagePrompt = systemMessage.get().text() + "\n" + userMessage.singleText(); @@ -400,7 +404,7 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC // TODO: we can only support input guardrails for now as it is tied to AiMessage GuardrailsSupport.invokeInputGuardrails(methodCreateInfo, userMessage, context.hasChatMemory() ? context.chatMemory(memoryId) : null, - augmentationResult, templateParams); + augmentationResult, templateVariables); Response imageResponse = context.imageModel.generate(imagePrompt); if (audit != null) { @@ -592,7 +596,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic if (userMessageInfo.template().isPresent()) { AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get(); - Map templateParams = getTemplateParams(methodArgs, userMessageInfo); + Map templateVariables = getTemplateVariables(methodArgs, userMessageInfo); String templateText; if (templateInfo.text().isPresent()) { templateText = templateInfo.text().get(); @@ -615,9 +619,9 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic } // we do not need to apply the instructions as they have already been added to the template text at build time - templateParams.put(ResponseSchemaUtil.templateParam(), + templateVariables.put(ResponseSchemaUtil.templateParam(), createInfo.getResponseSchemaInfo().outputFormatInstructions()); - Prompt prompt = PromptTemplate.from(templateText).apply(templateParams); + Prompt prompt = PromptTemplate.from(templateText).apply(templateVariables); return createUserMessage(userName, imageContent, prompt.text()); } else if (userMessageInfo.paramPosition().isPresent()) { @@ -640,9 +644,9 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic } } - private static Map getTemplateParams(Object[] methodArgs, + private static Map getTemplateVariables(Object[] methodArgs, AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo) { - Map templateParams = new HashMap<>(); + Map variables = new HashMap<>(); if (userMessageInfo.template().isPresent()) { AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.template().get(); @@ -650,11 +654,11 @@ private static Map getTemplateParams(Object[] methodArgs, for (var entry : nameToParamPosition.entrySet()) { Object value = transformTemplateParamValue(methodArgs[entry.getValue()]); - templateParams.put(entry.getKey(), value); + variables.put(entry.getKey(), value); } } - return templateParams; + return variables; } private static UserMessage createUserMessage(String name, ImageContent imageContent, String text) { diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index 17b72fa69..e8d0f8b2c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -29,17 +29,15 @@ public class GuardrailsSupport { public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage, - ChatMemory chatMemory, AugmentationResult augmentationResult, Map templateParams) { + ChatMemory chatMemory, AugmentationResult augmentationResult, Map templateVariables) { InputGuardrailResult result; try { - Optional userMessageTemplateOpt = methodCreateInfo.getUserMessageInfo().template() - .flatMap(AiServiceMethodCreateInfo.TemplateInfo::text); - - String userMessageTemplate = userMessageTemplateOpt.orElse(null); + String userMessageTemplate = methodCreateInfo.getUserMessageTemplate(); result = invokeInputGuardRails(methodCreateInfo, - new InputGuardrailParams(userMessage, chatMemory, augmentationResult, promptTemplate, templateParams)); + new InputGuardrailParams(userMessage, chatMemory, augmentationResult, userMessageTemplate, + templateVariables)); } catch (Exception e) { throw new GuardrailException(e.getMessage(), e); } @@ -90,7 +88,7 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn } attempt++; output = new OutputGuardrailParams(response.content(), output.memory(), - output.augmentationResult()); + output.augmentationResult(), output.userMessageTemplate(), output.variables()); } else { break; } From 54116eeb74be918662e8408cebc8c198665a14e5 Mon Sep 17 00:00:00 2001 From: Dennys Fredericci Date: Wed, 30 Oct 2024 16:52:36 +0100 Subject: [PATCH 3/3] 946 - User message template and variables cannot be null. --- .../test/guardrails/InputGuardrailPromptTemplateTest.java | 2 +- .../test/guardrails/OutputGuardrailPromptTemplateTest.java | 2 +- .../langchain4j/guardrails/InputGuardrailParams.java | 4 ++-- .../langchain4j/guardrails/OutputGuardrailParams.java | 4 ++-- .../runtime/aiservice/AiServiceMethodCreateInfo.java | 4 +++- .../aiservice/AiServiceMethodImplementationSupport.java | 5 +++-- .../langchain4j/runtime/aiservice/GuardrailsSupport.java | 2 +- 7 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java index 92ff09b53..921d77b0c 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/InputGuardrailPromptTemplateTest.java @@ -142,7 +142,7 @@ void shouldWorkWithMemoryIdAndOneItemFromList() { void shouldWorkWithNoUserMessage() { // UserMessage annotation is not provided, then no user message template should be available aiService.saySomething("Is this a parameter or a prompt?"); - assertThat(guardrailValidation.spyUserMessageTemplate()).isNull(); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEmpty(); assertThat(guardrailValidation.spyVariables()).isEmpty(); } diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java index 3c170f17b..1bf2ea7d4 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailPromptTemplateTest.java @@ -139,7 +139,7 @@ void shouldWorkWithMemoryIdAndOneItemFromList() { void shouldWorkWithNoUserMessage() { // UserMessage annotation is not provided, then no user message template should be available aiService.saySomething("Is this a parameter or a prompt?"); - assertThat(guardrailValidation.spyUserMessageTemplate()).isNull(); + assertThat(guardrailValidation.spyUserMessageTemplate()).isEmpty(); assertThat(guardrailValidation.spyVariables()).isEmpty(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java index b27371263..1900d27b1 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/InputGuardrailParams.java @@ -12,8 +12,8 @@ * @param userMessage the user message, cannot be {@code null} * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} - * @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided. - * @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty + * @param userMessageTemplate the user message template, cannot be {@code null} + * @param variables the variable to be used with userMessageTemplate, cannot be {@code null} */ public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java index 3bc39f2cd..0162c5f5a 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailParams.java @@ -12,8 +12,8 @@ * @param responseFromLLM the response from the LLM * @param memory the memory, can be {@code null} or empty * @param augmentationResult the augmentation result, can be {@code null} - * @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided. - * @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty + * @param userMessageTemplate the user message template, cannot be {@code null} + * @param variables the variable to be used with userMessageTemplate, cannot be {@code null} */ public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory, AugmentationResult augmentationResult, String userMessageTemplate, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index dc095ec2b..e1c26af15 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -1,5 +1,7 @@ package io.quarkiverse.langchain4j.runtime.aiservice; +import static org.apache.commons.lang3.StringUtils.EMPTY; + import java.lang.reflect.Type; import java.util.List; import java.util.Map; @@ -197,7 +199,7 @@ public String getUserMessageTemplate() { Optional userMessageTemplateOpt = this.getUserMessageInfo().template() .flatMap(AiServiceMethodCreateInfo.TemplateInfo::text); - return userMessageTemplateOpt.orElse(null); + return userMessageTemplateOpt.orElse(EMPTY); } public record UserMessageInfo(Optional template, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 97e48f177..f484949aa 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -273,7 +273,8 @@ private List messagesToSend(ChatMessage augmentedUserMessage, try { result = GuardrailsSupport.invokeOutputGuardrailsForStream(methodCreateInfo, new OutputGuardrailParams(AiMessage.from(chunk), chatMemory, actualAugmentationResult, - methodCreateInfo.getUserMessageTemplate(), templateVariables)); + methodCreateInfo.getUserMessageTemplate(), + Collections.unmodifiableMap(templateVariables))); } catch (Exception e) { throw new GuardrailException(e.getMessage(), e); } @@ -365,7 +366,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, response = GuardrailsSupport.invokeOutputGuardrails(methodCreateInfo, chatMemory, context.chatModel, response, toolSpecifications, new OutputGuardrailParams(response.content(), chatMemory, augmentationResult, userMessageTemplate, - templateVariables)); + Collections.unmodifiableMap(templateVariables))); // everything worked as expected so let's commit the messages chatMemory.commit(); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index e8d0f8b2c..477a2df1b 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -37,7 +37,7 @@ public static void invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateI result = invokeInputGuardRails(methodCreateInfo, new InputGuardrailParams(userMessage, chatMemory, augmentationResult, userMessageTemplate, - templateVariables)); + Collections.unmodifiableMap(templateVariables))); } catch (Exception e) { throw new GuardrailException(e.getMessage(), e); }