Skip to content

Commit

Permalink
946 - Add prompt template and variables to output guardrails
Browse files Browse the repository at this point in the history
  • Loading branch information
dennysfredericci committed Oct 26, 2024
1 parent 180d25e commit 5f55595
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
Expand Down Expand Up @@ -57,8 +58,7 @@ void shouldWorkWithMemoryId() {
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
"memoryId", "memory-id-001",
"it", "memory-id-001" // is this correct?
));
"it", "memory-id-001"));
}

@Test
Expand Down Expand Up @@ -140,11 +140,7 @@ void shouldWorkWithMemoryIdAndOneItemFromList() {
@Test
@ActivateRequestContext
void shouldWorkWithNoUserMessage() {
// This is a special case where the UserMessage annotation is not present
// The prompt template doesn't exist in this case
// But the current implementation use the parameter name as prompt template
// Not sure if this is the correct behavior, should we always have @UserMessage?
// I need some thoughts on this case
// UserMessage annotation is not provided, then no user message template should be available
aiService.saySomething("Is this a parameter or a prompt?");
assertThat(guardrailValidation.spyUserMessageTemplate()).isNull();
assertThat(guardrailValidation.spyVariables()).isEmpty();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import jakarta.enterprise.context.RequestScoped;
import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;

public class OutputGuardrailPromptTemplateTest {

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyAiService.class, MyAiService.class, GuardrailValidation.class,
MyChatModel.class, MyChatModelSupplier.class, MyMemoryProviderSupplier.class));
@Inject
MyAiService aiService;

@Inject
GuardrailValidation guardrailValidation;

@Test
@ActivateRequestContext
void shouldWorkNoParameters() {
aiService.getJoke();
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me a joke");
assertThat(guardrailValidation.spyVariables()).isEmpty();
}

@Test
@ActivateRequestContext
void shouldWorkWithMemoryId() {
aiService.getAnotherJoke("memory-id-001");
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me another joke");
assertThat(guardrailValidation.spyVariables()).containsExactlyInAnyOrderEntriesOf(Map.of(
"memoryId", "memory-id-001",
"it", "memory-id-001"));
}

@Test
@ActivateRequestContext
void shouldWorkWithNoMemoryIdAndOneParameter() {
aiService.sayHiToMyFriendNoMemory("Rambo");
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"friend", "Rambo",
"it", "Rambo"));
}

@Test
@ActivateRequestContext
void shouldWorkWithMemoryIdAndOneParameter() {
aiService.sayHiToMyFriend("1", "Chuck Norris");
assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Say hi to my friend {friend}!");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"friend", "Chuck Norris",
"mem", "1"));
}

@Test
@ActivateRequestContext
void shouldWorkWithNoMemoryIdAndThreeParameters() {
aiService.sayHiToMyFriends("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone");
assertThat(guardrailValidation.spyUserMessageTemplate())
.isEqualTo("Tell me something about {topic1}, {topic2}, {topic3}!");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"topic1", "Chuck Norris",
"topic2", "Jean-Claude Van Damme",
"topic3", "Silvester Stallone"));
}

@Test
@ActivateRequestContext
void shouldWorkWithNoMemoryIdAndList() {
aiService.sayHiToMyFriends(List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));

assertThat(guardrailValidation.spyUserMessageTemplate()).isEqualTo("Tell me something about {topics}!");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
"it", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone")));
}

@Test
@ActivateRequestContext
void shouldWorkWithMemoryIdAndList() {
aiService.sayHiToMyFriends("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));

assertThat(guardrailValidation.spyUserMessageTemplate())
.isEqualTo("Tell me something about {topics}! This is my memory id: {memoryId}");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
"memoryId", "memory-id-007"));
}

@Test
@ActivateRequestContext
void shouldWorkWithMemoryIdAndOneItemFromList() {
aiService.sayHiToMyFriend("memory-id-007", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"));

assertThat(guardrailValidation.spyUserMessageTemplate())
.isEqualTo("Tell me something about {topics[0]}! This is my memory id: {memoryId}");
assertThat(guardrailValidation.spyVariables())
.containsExactlyInAnyOrderEntriesOf(Map.of(
"topics", List.of("Chuck Norris", "Jean-Claude Van Damme", "Silvester Stallone"),
"memoryId", "memory-id-007"));
}

@Test
@ActivateRequestContext
void shouldWorkWithNoUserMessage() {
// UserMessage annotation is not provided, then no user message template should be available
aiService.saySomething("Is this a parameter or a prompt?");
assertThat(guardrailValidation.spyUserMessageTemplate()).isNull();
assertThat(guardrailValidation.spyVariables()).isEmpty();
}

@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

@OutputGuardrails(GuardrailValidation.class)
@UserMessage("Tell me a joke")
String getJoke();

@UserMessage("Tell me another joke")
@OutputGuardrails(GuardrailValidation.class)
String getAnotherJoke(@MemoryId String memoryId);

@UserMessage("Say hi to my friend {friend}!")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriendNoMemory(String friend);

@UserMessage("Say hi to my friend {friend}!")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriend(@MemoryId String mem, String friend);

@UserMessage("Tell me something about {topic1}, {topic2}, {topic3}!")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriends(String topic1, String topic2, String topic3);

@UserMessage("Tell me something about {topics}!")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriends(List<String> topics);

@UserMessage("Tell me something about {topics}! This is my memory id: {memoryId}")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriends(@MemoryId String memoryId, List<String> topics);

@UserMessage("Tell me something about {topics[0]}! This is my memory id: {memoryId}")
@OutputGuardrails(GuardrailValidation.class)
String sayHiToMyFriend(@MemoryId String memoryId, List<String> topics);

@OutputGuardrails(GuardrailValidation.class)
String saySomething(String isThisAPromptOrAParameter);

}

@RequestScoped
public static class GuardrailValidation implements OutputGuardrail {

OutputGuardrailParams params;

public OutputGuardrailResult validate(OutputGuardrailParams params) {
this.params = params;
return success();
}

public String spyUserMessageTemplate() {
return params.userMessageTemplate();
}

public Map<String, Object> spyVariables() {
return params.variables();
}
}

public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {

@Override
public ChatLanguageModel get() {
return new MyChatModel();
}
}

public static class MyChatModel implements ChatLanguageModel {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return new Response<>(new AiMessage("Hi!"));
}
}

public static class MyMemoryProviderSupplier implements Supplier<ChatMemoryProvider> {
@Override
public ChatMemoryProvider get() {
return memoryId -> new NoopChatMemory();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.Arrays;
import java.util.Map;

import dev.langchain4j.data.message.UserMessage;
import io.smallrye.common.annotation.Experimental;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.Map;

import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
Expand All @@ -10,7 +12,10 @@
* @param userMessage the user message, cannot be {@code null}
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
* @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided.
* @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty
*/
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.Map;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
Expand All @@ -10,7 +12,10 @@
* @param responseFromLLM the response from the LLM
* @param memory the memory, can be {@code null} or empty
* @param augmentationResult the augmentation result, can be {@code null}
* @param userMessageTemplate the user message template, can be {@code null} when @UserMessage is not provided.
* @param variables the variable to be used with userMessageTemplate, can be {@code null} or empty
*/
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
AugmentationResult augmentationResult) implements GuardrailParams {
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ public OutputTokenAccumulator getOutputTokenAccumulator() {
return accumulator;
}

public String getUserMessageTemplate() {
Optional<String> userMessageTemplateOpt = this.getUserMessageInfo().template()
.flatMap(AiServiceMethodCreateInfo.TemplateInfo::text);

return userMessageTemplateOpt.orElse(null);
}

public record UserMessageInfo(Optional<TemplateInfo> template,
Optional<Integer> paramPosition,
Optional<Integer> userNameParamPosition,
Expand Down
Loading

0 comments on commit 5f55595

Please sign in to comment.