Skip to content

Commit

Permalink
Merge pull request #1021 from mariofusco/out_guard_with_result
Browse files Browse the repository at this point in the history
  • Loading branch information
cescoffier authored Nov 3, 2024
2 parents 5faca9b + 4ac0106 commit ed12693
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.test.guardrails;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -28,6 +29,7 @@
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkus.test.QuarkusUnitTest;

Expand Down Expand Up @@ -78,6 +80,20 @@ void testThatRetryRestartTheChain() {
assertThat(firstGuardrail.lastAccess()).isLessThan(secondGuardrail.lastAccess());
}

@Test
@ActivateRequestContext
void testThatRewritesTheOutputTwiceInTheChain() {
assertThat(aiService.rewritingSuccess("1", "foo")).isEqualTo("Hi!,1,2");
}

@Test
@ActivateRequestContext
void testThatRepromptAfterRewriteIsNotAllowed() {
assertThatExceptionOfType(GuardrailException.class)
.isThrownBy(() -> aiService.repromptAfterRewrite("1", "foo"))
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
}

@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

Expand All @@ -90,6 +106,12 @@ public interface MyAiService {
@OutputGuardrails({ FirstGuardrail.class, FailingGuardrail.class, SecondGuardrail.class })
String failingFirstTwo(@MemoryId String mem, @UserMessage String message);

@OutputGuardrails({ FirstRewritingGuardrail.class, SecondRewritingGuardrail.class })
String rewritingSuccess(@MemoryId String mem, @UserMessage String message);

@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);

}

@RequestScoped
Expand Down Expand Up @@ -164,6 +186,42 @@ public int spy() {
}
}

@RequestScoped
public static class FirstRewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",1");
}
}

@RequestScoped
public static class SecondRewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",2");
}
}

@RequestScoped
public static class RepromptingGuardrail implements OutputGuardrail {

private boolean firstCall = true;

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
if (firstCall) {
firstCall = false;
String text = responseFromLLM.text();
return reprompt("Wrong message", text + ", " + text);
}
return success();
}
}

public static class MyChatModelSupplier implements Supplier<ChatLanguageModel> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ void testFatalExceptionWithPassThroughAccumulator() {
assertThat(fatal.spy()).isEqualTo(1);
}

@Test
@ActivateRequestContext
void testRewritingWhileStreamingIsNotAllowed() {
assertThatThrownBy(() -> aiService.rewriting("1").collect().asList().await().indefinitely())
.isInstanceOf(GuardrailException.class)
.hasMessageContaining("Attempting to rewrite the LLM output while streaming is not allowed");
}

@RegisterAiService(streamingChatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

Expand Down Expand Up @@ -187,6 +195,9 @@ public interface MyAiService {
@OutputGuardrailAccumulator(PassThroughAccumulator.class)
Multi<String> fatalWithPassThroughAccumulator(@MemoryId String mem);

@UserMessage("Say Hi!")
@OutputGuardrails({ RewritingGuardrail.class })
Multi<String> rewriting(@MemoryId String mem);
}

@RequestScoped
Expand Down Expand Up @@ -272,6 +283,16 @@ public int spy() {
}
}

@RequestScoped
public static class RewritingGuardrail implements OutputGuardrail {

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",1");
}
}

public static class MyChatModelSupplier implements Supplier<StreamingChatLanguageModel> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,12 @@ public interface GuardrailParams {
* @return the augmentation result, can be {@code null}
*/
AugmentationResult augmentationResult();

/**
* Recreate this guardrail param with the given input or output text.
*
* @param text The text of the rewritten param.
* @return A clone of this guardrail params with the given input or output text.
*/
GuardrailParams withText(String text);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ enum Result {
* A successful validation.
*/
SUCCESS,
/**
* A successful validation with a specific result.
*/
SUCCESS_WITH_RESULT,
/**
* A failed validation not preventing the subsequent validations eventually registered to be evaluated.
*/
Expand All @@ -27,6 +31,18 @@ enum Result {

boolean isSuccess();

default boolean isRewrittenResult() {
return false;
}

default GuardrailResult<GR> blockRetry() {
throw new UnsupportedOperationException();
}

default String successfulResult() {
throw new UnsupportedOperationException();
}

boolean isFatal();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,9 @@
public record InputGuardrailParams(UserMessage userMessage, ChatMemory memory,
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {

@Override
public InputGuardrailParams withText(String text) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ default OutputGuardrailResult success() {
return OutputGuardrailResult.success();
}

/**
* @return The result of a successful output guardrail validation with a specific result.
* @param successfulResult The successful result.
*/
default OutputGuardrailResult successWith(String successfulResult) {
return OutputGuardrailResult.successWith(successfulResult);
}

/**
* @param message A message describing the failure.
* @return The result of a failed output guardrail validation.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package io.quarkiverse.langchain4j.guardrails;

import java.util.List;
import java.util.Map;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.rag.AugmentationResult;
Expand All @@ -18,4 +20,11 @@
public record OutputGuardrailParams(AiMessage responseFromLLM, ChatMemory memory,
AugmentationResult augmentationResult, String userMessageTemplate,
Map<String, Object> variables) implements GuardrailParams {

@Override
public OutputGuardrailParams withText(String text) {
List<ToolExecutionRequest> tools = responseFromLLM.toolExecutionRequests();
AiMessage aiMessage = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
return new OutputGuardrailParams(aiMessage, memory, augmentationResult, userMessageTemplate, variables);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,54 @@
* @param result The result of the output guardrail validation.
* @param failures The list of failures, empty if the validation succeeded.
*/
public record OutputGuardrailResult(Result result, List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {
public record OutputGuardrailResult(Result result, String successfulResult,
List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {

private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();

private OutputGuardrailResult() {
this(Result.SUCCESS, Collections.emptyList());
this(Result.SUCCESS, null, Collections.emptyList());
}

private OutputGuardrailResult(String successfulResult) {
this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
}

OutputGuardrailResult(List<Failure> failures, boolean fatal) {
this(fatal ? Result.FATAL : Result.FAILURE, failures);
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
}

public static OutputGuardrailResult success() {
return SUCCESS;
}

public static OutputGuardrailResult successWith(String successfulResult) {
return new OutputGuardrailResult(successfulResult);
}

public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
return new OutputGuardrailResult((List<Failure>) failures, false);
}

@Override
public boolean isSuccess() {
return result == Result.SUCCESS;
return result == Result.SUCCESS || result == Result.SUCCESS_WITH_RESULT;
}

@Override
public boolean isRewrittenResult() {
return result == Result.SUCCESS_WITH_RESULT;
}

public boolean isRetry() {
return !isSuccess() && failures.stream().anyMatch(Failure::retry);
}

public OutputGuardrailResult blockRetry() {
failures().set(0, failures().get(0).blockRetry());
return this;
}

public String getReprompt() {
if (!isSuccess()) {
for (Failure failure : failures) {
Expand Down Expand Up @@ -97,6 +116,13 @@ public Failure withGuardrailClass(Class<? extends Guardrail> guardrailClass) {
return new Failure(message(), cause(), guardrailClass, retry, reprompt);
}

public Failure blockRetry() {
return retry
? new Failure("Retry or reprompt is not allowed after a rewritten output", cause(), guardrailClass, false,
reprompt)
: this;
}

@Override
public String toString() {
return "The guardrail " + guardrailClass.getName() + " failed with this message: " + message;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
throw new GuardrailsSupport.GuardrailRetryException();
}
} else {
if (result.isRewrittenResult()) {
throw new GuardrailException(
"Attempting to rewrite the LLM output while streaming is not allowed");
}
return chunk;
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jakarta.enterprise.inject.spi.CDI;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.UserMessage;
Expand Down Expand Up @@ -57,8 +58,9 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
if (max <= 0) {
max = 1;
}

OutputGuardrailResult result = null;
while (attempt < max) {
OutputGuardrailResult result;
try {
result = invokeOutputGuardRails(methodCreateInfo, output);
} catch (Exception e) {
Expand Down Expand Up @@ -97,9 +99,20 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
if (attempt == max) {
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
}

if (result.isRewrittenResult()) {
response = rewriteResponseWithText(response, result.successfulResult());
}

return response;
}

public static Response<AiMessage> rewriteResponseWithText(Response<AiMessage> response, String text) {
List<ToolExecutionRequest> tools = response.content().toolExecutionRequests();
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
}

@SuppressWarnings("unchecked")
private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo,
OutputGuardrailParams params) {
Expand Down Expand Up @@ -160,25 +173,28 @@ private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams p
for (Class<? extends Guardrail> bean : classes) {
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
if (result.isFatal()) {
return result;
return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
}
if (result.isRewrittenResult()) {
params = params.withText(result.successfulResult());
}
accumulatedResults = compose(accumulatedResults, result, producer);
}

return accumulatedResults;
}

private static <GR extends GuardrailResult> GR compose(GR first, GR second,
private static <GR extends GuardrailResult> GR compose(GR oldResult, GR newResult,
Function<List<? extends GuardrailResult.Failure>, GR> producer) {
if (first.isSuccess()) {
return second;
if (oldResult.isSuccess()) {
return newResult;
}
if (second.isSuccess()) {
return first;
if (newResult.isSuccess()) {
return oldResult;
}
List<? extends GuardrailResult.Failure> failures = new ArrayList<>();
failures.addAll(first.failures());
failures.addAll(second.failures());
failures.addAll(oldResult.failures());
failures.addAll(newResult.failures());
return producer.apply(failures);
}

Expand Down

0 comments on commit ed12693

Please sign in to comment.