Skip to content

Commit

Permalink
Provide an abstract output guardrails for json data extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Nov 13, 2024
1 parent 6efb975 commit 78fcaf4
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ void testThatRepromptAfterRewriteIsNotAllowed() {
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
}

@Test
@ActivateRequestContext
void testThatRewritesTheOutputWithAResult() {
assertThat(aiService.rewritingSuccessWithResult("1", "foo")).isSameAs(RewritingGuardrailWithResult.RESULT);
}

@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {

Expand All @@ -112,6 +118,9 @@ public interface MyAiService {
@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);

@OutputGuardrails({ FirstRewritingGuardrail.class, RewritingGuardrailWithResult.class })
Integer rewritingSuccessWithResult(@MemoryId String mem, @UserMessage String message);

}

@RequestScoped
Expand Down Expand Up @@ -206,6 +215,18 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) {
}
}

@RequestScoped
public static class RewritingGuardrailWithResult implements OutputGuardrail {

static final Integer RESULT = 1_000;

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String text = responseFromLLM.text();
return successWith(text + ",2", RESULT);
}
}

@RequestScoped
public static class RepromptingGuardrail implements OutputGuardrail {

Expand Down
5 changes: 5 additions & 0 deletions core/runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,11 @@
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.quarkiverse.langchain4j.guardrails;

import jakarta.inject.Inject;

import org.jboss.logging.Logger;

import com.fasterxml.jackson.core.type.TypeReference;

import dev.langchain4j.data.message.AiMessage;

public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {

@Inject
Logger logger;

@Inject
JsonGuardrailsUtils jsonGuardrailsUtils;

protected AbstractJsonExtractorOutputGuardrail() {
if (getOutputClass() == null && getOutputType() == null) {
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
}
}

@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String llmResponse = responseFromLLM.text();
logger.debugf("LLM output: %s", llmResponse);

Object result = deserialize(llmResponse);
if (result != null) {
return successWith(llmResponse, result);
}

String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
if (json != null) {
result = deserialize(json);
if (result != null) {
return successWith(json, result);
}
}

return reprompt("Invalid JSON",
"Make sure you return a valid JSON object following "
+ "the specified format");
}

protected Object deserialize(String llmResponse) {
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
}

protected Class<?> getOutputClass() {
return null;
}

protected TypeReference<?> getOutputType() {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@ enum Result {

boolean isSuccess();

default boolean isRewrittenResult() {
default boolean hasRewrittenResult() {
return false;
}

default GuardrailResult<GR> blockRetry() {
throw new UnsupportedOperationException();
}

default String successfulResult() {
default String successfulText() {
throw new UnsupportedOperationException();
}

default Object successfulResult() {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.quarkiverse.langchain4j.guardrails;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

@ApplicationScoped
class JsonGuardrailsUtils {

@Inject
ObjectMapper objectMapper;

private JsonGuardrailsUtils() {
}

String trimNonJson(String llmResponse) {
int jsonMapStart = llmResponse.indexOf('{');
int jsonListStart = llmResponse.indexOf('[');
if (jsonMapStart < 0 && jsonListStart < 0) {
return null;
}
boolean isJsonMap = jsonMapStart >= 0 && (jsonMapStart < jsonListStart || jsonListStart < 0);

int jsonStart = isJsonMap ? jsonMapStart : jsonListStart;
int jsonEnd = isJsonMap ? llmResponse.lastIndexOf('}') : llmResponse.lastIndexOf(']');
return jsonEnd >= 0 && jsonStart < jsonEnd ? llmResponse.substring(jsonStart, jsonEnd + 1) : null;
}

<T> T deserialize(String json, Class<T> expectedOutputClass) {
try {
return objectMapper.readValue(json, expectedOutputClass);
} catch (JsonProcessingException e) {
return null;
}
}

<T> T deserialize(String json, TypeReference<T> expectedOutputType) {
try {
return objectMapper.readValue(json, expectedOutputType);
} catch (JsonProcessingException e) {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,20 @@ default OutputGuardrailResult success() {
}

/**
* @return The result of a successful output guardrail validation with a specific result.
* @param successfulResult The successful result.
* @return The result of a successful output guardrail validation with a specific text.
* @param successfulText The text of the successful result.
*/
default OutputGuardrailResult successWith(String successfulResult) {
return OutputGuardrailResult.successWith(successfulResult);
default OutputGuardrailResult successWith(String successfulText) {
return OutputGuardrailResult.successWith(successfulText);
}

/**
* @return The result of a successful output guardrail validation with a specific text.
* @param successfulText The text of the successful result.
* @param successfulResult The object generated by this successful result.
*/
default OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
return OutputGuardrailResult.successWith(successfulText, successfulResult);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,37 @@
* @param result The result of the output guardrail validation.
* @param failures The list of failures, empty if the validation succeeded.
*/
public record OutputGuardrailResult(Result result, String successfulResult,
public record OutputGuardrailResult(Result result, String successfulText, Object successfulResult,
List<Failure> failures) implements GuardrailResult<OutputGuardrailResult> {

private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();

private OutputGuardrailResult() {
this(Result.SUCCESS, null, Collections.emptyList());
this(Result.SUCCESS, null, null, Collections.emptyList());
}

private OutputGuardrailResult(String successfulResult) {
this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
private OutputGuardrailResult(String successfulText) {
this(Result.SUCCESS_WITH_RESULT, successfulText, null, Collections.emptyList());
}

private OutputGuardrailResult(String successfulText, Object successfulResult) {
this(Result.SUCCESS_WITH_RESULT, successfulText, successfulResult, Collections.emptyList());
}

OutputGuardrailResult(List<Failure> failures, boolean fatal) {
this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
this(fatal ? Result.FATAL : Result.FAILURE, null, null, failures);
}

public static OutputGuardrailResult success() {
return SUCCESS;
}

public static OutputGuardrailResult successWith(String successfulResult) {
return new OutputGuardrailResult(successfulResult);
public static OutputGuardrailResult successWith(String successfulText) {
return new OutputGuardrailResult(successfulText);
}

public static OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
return new OutputGuardrailResult(successfulText, successfulResult);
}

public static OutputGuardrailResult failure(List<? extends GuardrailResult.Failure> failures) {
Expand All @@ -45,7 +53,7 @@ public boolean isSuccess() {
}

@Override
public boolean isRewrittenResult() {
public boolean hasRewrittenResult() {
return result == Result.SUCCESS_WITH_RESULT;
}

Expand Down Expand Up @@ -88,7 +96,7 @@ public OutputGuardrailResult validatedBy(Class<? extends Guardrail> guardrailCla
@Override
public String toString() {
if (isSuccess()) {
return "success";
return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success";
}
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.model.output.TokenUsage.sum;
import static dev.langchain4j.service.AiServices.removeToolMessages;
import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded;
import static io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil.hasResponseSchema;
Expand Down Expand Up @@ -295,7 +296,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
throw new GuardrailsSupport.GuardrailRetryException();
}
} else {
if (result.isRewrittenResult()) {
if (result.hasRewrittenResult()) {
throw new GuardrailException(
"Attempting to rewrite the LLM output while streaming is not allowed");
}
Expand Down Expand Up @@ -367,7 +368,7 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
audit.addLLMToApplicationMessage(response);
}

tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
tokenUsageAccumulator = sum(tokenUsageAccumulator, response.tokenUsage());
}

String userMessageTemplate = methodCreateInfo.getUserMessageTemplate();
Expand All @@ -380,7 +381,13 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
// everything worked as expected so let's commit the messages
chatMemory.commit();

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
Object guardrailResult = response.metadata().get(OutputGuardrailResult.class.getName());
if (guardrailResult != null && isTypeOf(returnType, guardrailResult.getClass())) {
return guardrailResult;
}

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason(), response.metadata());

if (isResult(returnType)) {
var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType));
return Result.builder()
Expand All @@ -389,9 +396,9 @@ private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage,
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(response.finishReason())
.build();
} else {
return SERVICE_OUTPUT_PARSER.parse(response, returnType);
}

return SERVICE_OUTPUT_PARSER.parse(response, returnType);
}

private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import static dev.langchain4j.data.message.UserMessage.userMessage;

import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import jakarta.enterprise.inject.spi.CDI;
Expand Down Expand Up @@ -100,17 +103,22 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
}

if (result.isRewrittenResult()) {
response = rewriteResponseWithText(response, result.successfulResult());
if (result.hasRewrittenResult()) {
response = rewriteResponse(response, result);
}

return response;
}

public static Response<AiMessage> rewriteResponseWithText(Response<AiMessage> response, String text) {
public static Response<AiMessage> rewriteResponse(Response<AiMessage> response, OutputGuardrailResult result) {
List<ToolExecutionRequest> tools = response.content().toolExecutionRequests();
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(result.successfulText(), tools)
: new AiMessage(result.successfulText());
Map<String, Object> metadata = response.metadata();
if (result.successfulResult() != null) {
metadata.put(OutputGuardrailResult.class.getName(), result.successfulResult());
}
return new Response<>(content, response.tokenUsage(), response.finishReason(), metadata);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -173,10 +181,10 @@ private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams p
for (Class<? extends Guardrail> bean : classes) {
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
if (result.isFatal()) {
return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
return accumulatedResults.hasRewrittenResult() ? (GR) result.blockRetry() : result;
}
if (result.isRewrittenResult()) {
params = params.withText(result.successfulResult());
if (result.hasRewrittenResult()) {
params = params.withText(result.successfulText());
}
accumulatedResults = compose(accumulatedResults, result, producer);
}
Expand Down
Loading

0 comments on commit 78fcaf4

Please sign in to comment.