Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to rewrite LLM result in an OutputGuardrail #1021

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -28,6 +29,7 @@
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -78,6 +80,20 @@ void testThatRetryRestartTheChain() {
assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
}

@Test
@ActivateRequestContext
void testThatRewritesTheOutputTwiceInTheChain() {
assertThat(aiService.rewritingSuccess("1", "foo")).isEqualTo("Hi!,1,2");
}

@Test
@ActivateRequestContext
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question for @geoand - do you know if we can use @ActivateRequestContext on the class itself?

void testThatRepromptAfterRewriteIsNotAllowed() {
assertThatExceptionOfType(GuardrailException.class)
.isThrownBy(() -> aiService.repromptAfterRewrite("1", "foo"))
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
}

@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

Expand All @@ -90,6 +106,12 @@ public interface MyAiService {
@OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
String failingFirstTwo(@MemoryId String mem, @UserMessage String message);

@OutputGuardrails({ FirstRewritingGuardrail.class, SecondRewritingGuardrail.class })
String rewritingSuccess(@MemoryId String mem, @UserMessage String message);

@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);

}

@RequestScoped
Expand Down Expand Up @@ -164,6 +186,42 @@ public int spy() {
}
}

@RequestScoped
public static class FirstRewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",1");
}
}

@RequestScoped
public static class SecondRewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",2");
}
}

@RequestScoped
public static class RepromptingGuardrail implements OutputGuardrail {

private boolean firstCall = true;

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
if (firstCall) {
firstCall = false;
String text = responseFromLLM.text();
return reprompt("Wrong message", text + ", " + text);
}
return success();
}
}

public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ void testFatalExceptionWithPassThroughAccumulator() {
assertThat(fatal.spy()).isEqualTo(1);
}

@Test
@ActivateRequestContext
void testRewritingWhileStreamingIsNotAllowed() {
assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely())
.isInstanceOf(GuardrailException.class)
.hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed");
}

@RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

Expand Down Expand Up @@ -187,6 +195,9 @@ public interface MyAiService {
@OutputGuardrailAccumulator(PassThroughAccumulator.class)
Multi<String> fatalWithPassThroughAccumulator(@MemoryId String mem);

@UserMessage("Say Hi!")
@OutputGuardrails({ RewritingGuardrail.class })
Multi<String> rewriting(@MemoryId String mem);
}

@RequestScoped
Expand Down Expand Up @@ -272,6 +283,16 @@ public int spy() {
}
}

@RequestScoped
public static class RewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",1");
}
}

public static class MyChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,12 @@ public interface GuardrailParams {
* @return the augmentation result, can be {@code null}
*/
AugmentationResult augmentationResult();

/**
* Recreate this guardrail param with the given input or output text.
*
* @param text The text of the rewritten param.
* @return A clone of this guardrail params with the given input or output text.
*/
GuardrailParams withText(String text);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ enum Result {
* A successful validation.
*/
SUCCESS,
/**
* A successful validation with a specific result.
*/
SUCCESS_WITH_RESULT,
/**
* A failed validation not preventing the subsequent validations eventually registered to be evaluated.
*/
Expand All @@ -27,6 +31,18 @@ enum Result {

boolean isSuccess();

default boolean isRewrittenResult() {
return false;
}

default GuardrailResult<GR> blockRetry() {
throw new UnsupportedOperationException();
}

default String successfulResult() {
throw new UnsupportedOperationException();
}

boolean isFatal();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,9 @@
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {

@Override
public InputGuardrailParams withText(String text) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ default OutputGuardrailResult success() {
return OutputGuardrailResult.success();
}

/**
* @return The result of a successful output guardrail validation with a specific result.
* @param successfulResult The successful result.
*/
default OutputGuardrailResult successWith(String successfulResult) {
return OutputGuardrailResult.successWith(successfulResult);
}

/**
* @param message A message describing the failure.
* @return The result of a failed output guardrail validation.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.List;
import java.util.Map;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
Expand All @@ -18,4 +20,11 @@
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {

@Override
public OutputGuardrailParams withText(String text) {
List<ToolExecutionRequest> tools = responseFromLLM.toolExecutionRequests();
AiMessage aiMessage = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
return new OutputGuardrailParams(aiMessage, memory, augmentationResult, userMessageTemplate, variables);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,54 @@
* @param result The result of the output guardrail validation.
* @param failures The list of failures, empty if the validation succeeded.
*/
public record OutputGuardrailResult(Result result, List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {
public record OutputGuardrailResult(Result result, String successfulResult,
List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {

private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();

private OutputGuardrailResult() {
this(Result.SUCCESS, Collections.emptyList());
this(Result.SUCCESS, null, Collections.emptyList());
}

private OutputGuardrailResult(String successfulResult) {
this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
}

OutputGuardrailResult(List<Failure> failures, boolean fatal) {
this(fatal ? Result.FATAL : Result.FAILURE, failures);
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
}

public static OutputGuardrailResult success() {
return SUCCESS;
}

public static OutputGuardrailResult successWith(String successfulResult) {
return new OutputGuardrailResult(successfulResult);
}

public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
return new OutputGuardrailResult((List<Failure>) failures, false);
}

@Override
public boolean isSuccess() {
return result == Result.SUCCESS;
return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT;
}

@Override
public boolean isRewrittenResult() {
return result == Result.SUCCESS_WITH_RESULT;
}

public boolean isRetry() {
return !isSuccess() && failures.stream().anyMatch(Failure::retry);
}

public OutputGuardrailResult blockRetry() {
failures().set(0, failures().get(0).blockRetry());
return this;
}

public String getReprompt() {
if (!isSuccess()) {
for (Failure failure : failures) {
Expand Down Expand Up @@ -97,6 +116,13 @@ public Failure withGuardrailClass(Class<? extends Guardrail> guardrailClass) {
return new Failure(message(), cause(), guardrailClass, retry, reprompt);
}

public Failure blockRetry() {
return retry
? new Failure("Retry or reprompt is not allowed after a rewritten output", cause(), guardrailClass, false,
reprompt)
: this;
}

@Override
public String toString() {
return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
throw new GuardrailsSupport.GuardrailRetryException();
}
} else {
if (result.isRewrittenResult()) {
throw new GuardrailException(
"Attempting to rewrite the LLM output while streaming is not allowed");
}
return chunk;
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jakarta.enterprise.inject.spi.CDI;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
Expand Down Expand Up @@ -57,8 +58,9 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
if (max <= 0) {
max = 1;
}

OutputGuardrailResult result = null;
while (attempt < max) {
OutputGuardrailResult result;
try {
result = invokeOutputGuardRails(methodCreateInfo, output);
} catch (Exception e) {
Expand Down Expand Up @@ -97,9 +99,20 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
if (attempt == max) {
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
}

if (result.isRewrittenResult()) {
response = rewriteResponseWithText(response, result.successfulResult());
}

return response;
}

public static Response<AiMessage> rewriteResponseWithText(Response<AiMessage> response, String text) {
List<ToolExecutionRequest> tools = response.content().toolExecutionRequests();
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
}

@SuppressWarnings("unchecked")
private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo,
OutputGuardrailParams params) {
Expand Down Expand Up @@ -160,25 +173,28 @@ private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams p
for (Class<? extends Guardrail> bean : classes) {
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
if (result.isFatal()) {
return result;
return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
}
if (result.isRewrittenResult()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't remember if this method is invoked when using streamed responses. Streams make things slightly more convoluted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this method is used only for streamed response. I'm keeping this rewriting here regardless, but now if I find that this rewriting happened while streaming I throw an exception as discussed.

params = params.withText(result.successfulResult());
}
accumulatedResults = compose(accumulatedResults, result, producer);
}

return accumulatedResults;
}

private static <GR extends GuardrailResult> GR compose(GR first, GR second,
private static <GR extends GuardrailResult> GR compose(GR oldResult, GR newResult,
Function<List<? extends GuardrailResult.Failure>, GR> producer) {
if (first.isSuccess()) {
return second;
if (oldResult.isSuccess()) {
return newResult;
}
if (second.isSuccess()) {
return first;
if (newResult.isSuccess()) {
return oldResult;
}
List<? extends GuardrailResult.Failure> failures = new ArrayList<>();
failures.addAll(first.failures());
failures.addAll(second.failures());
failures.addAll(oldResult.failures());
failures.addAll(newResult.failures());
return producer.apply(failures);
}

Expand Down