diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java index b353f817e..d42dafe47 100644 --- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java +++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java @@ -94,6 +94,12 @@ void testThatRepromptAfterRewriteIsNotAllowed() { .withMessageContaining("Retry or reprompt is not allowed after a rewritten output"); } + @Test + @ActivateRequestContext + void testThatRewritesTheOutputWithAResult() { + assertThat(aiService.rewritingSuccessWithResult("1", "foo")).isSameAs(RewritingGuardrailWithResult.RESULT); + } + @RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class) public interface MyAiService { @@ -112,6 +118,9 @@ public interface MyAiService { @OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class }) String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message); + @OutputGuardrails({ FirstRewritingGuardrail.class, RewritingGuardrailWithResult.class }) + Integer rewritingSuccessWithResult(@MemoryId String mem, @UserMessage String message); + } @RequestScoped @@ -206,6 +215,18 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) { } } + @RequestScoped + public static class RewritingGuardrailWithResult implements OutputGuardrail { + + static final Integer RESULT = 1_000; + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String text = responseFromLLM.text(); + return successWith(text + ",2", RESULT); + } + } + @RequestScoped public static class RepromptingGuardrail implements OutputGuardrail { diff --git a/core/runtime/pom.xml b/core/runtime/pom.xml index 603472071..62b87b67a 100644 --- a/core/runtime/pom.xml +++ b/core/runtime/pom.xml @@ -140,6 +140,11 @@ junit-jupiter test + + io.quarkus + quarkus-junit5 + test + diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java new file mode 100644 index 000000000..488c832ca --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java @@ -0,0 +1,60 @@ +package io.quarkiverse.langchain4j.guardrails; + +import jakarta.inject.Inject; + +import org.jboss.logging.Logger; + +import com.fasterxml.jackson.core.type.TypeReference; + +import dev.langchain4j.data.message.AiMessage; + +public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail { + + @Inject + Logger logger; + + @Inject + JsonGuardrailsUtils jsonGuardrailsUtils; + + protected AbstractJsonExtractorOutputGuardrail() { + if (getOutputClass() == null && getOutputType() == null) { + throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented"); + } + } + + @Override + public OutputGuardrailResult validate(AiMessage responseFromLLM) { + String llmResponse = responseFromLLM.text(); + logger.debugf("LLM output: %s", llmResponse); + + Object result = deserialize(llmResponse); + if (result != null) { + return successWith(llmResponse, result); + } + + String json = jsonGuardrailsUtils.trimNonJson(llmResponse); + if (json != null) { + result = deserialize(json); + if (result != null) { + return successWith(json, result); + } + } + + return reprompt("Invalid JSON", + "Make sure you return a valid JSON object following " + + "the specified format"); + } + + protected Object deserialize(String llmResponse) { + return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass()) + : jsonGuardrailsUtils.deserialize(llmResponse, getOutputType()); + } + + protected Class getOutputClass() { + return null; + } + + protected TypeReference getOutputType() { + return null; + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java index f731a17ea..965860f31 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java @@ -31,7 +31,7 @@ enum Result { boolean isSuccess(); - default boolean isRewrittenResult() { + default boolean hasRewrittenResult() { return false; } @@ -39,7 +39,11 @@ default GuardrailResult blockRetry() { throw new UnsupportedOperationException(); } - default String successfulResult() { + default String successfulText() { + throw new UnsupportedOperationException(); + } + + default Object successfulResult() { throw new UnsupportedOperationException(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java new file mode 100644 index 000000000..4c09b535e --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java @@ -0,0 +1,47 @@ +package io.quarkiverse.langchain4j.guardrails; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +@ApplicationScoped +class JsonGuardrailsUtils { + + @Inject + ObjectMapper objectMapper; + + private JsonGuardrailsUtils() { + } + + String trimNonJson(String llmResponse) { + int jsonMapStart = llmResponse.indexOf('{'); + int jsonListStart = llmResponse.indexOf('['); + if (jsonMapStart < 0 && jsonListStart < 0) { + return null; + } + boolean isJsonMap = jsonMapStart >= 0 && (jsonMapStart < jsonListStart || jsonListStart < 0); + + int jsonStart = isJsonMap ? jsonMapStart : jsonListStart; + int jsonEnd = isJsonMap ? llmResponse.lastIndexOf('}') : llmResponse.lastIndexOf(']'); + return jsonEnd >= 0 && jsonStart < jsonEnd ? llmResponse.substring(jsonStart, jsonEnd + 1) : null; + } + + T deserialize(String json, Class expectedOutputClass) { + try { + return objectMapper.readValue(json, expectedOutputClass); + } catch (JsonProcessingException e) { + return null; + } + } + + T deserialize(String json, TypeReference expectedOutputType) { + try { + return objectMapper.readValue(json, expectedOutputType); + } catch (JsonProcessingException e) { + return null; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java index 762b5478f..3d3489f3f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java @@ -51,11 +51,20 @@ default OutputGuardrailResult success() { } /** - * @return The result of a successful output guardrail validation with a specific result. - * @param successfulResult The successful result. + * @return The result of a successful output guardrail validation with a specific text. + * @param successfulText The text of the successful result. */ - default OutputGuardrailResult successWith(String successfulResult) { - return OutputGuardrailResult.successWith(successfulResult); + default OutputGuardrailResult successWith(String successfulText) { + return OutputGuardrailResult.successWith(successfulText); + } + + /** + * @return The result of a successful output guardrail validation with a specific text. + * @param successfulText The text of the successful result. + * @param successfulResult The object generated by this successful result. + */ + default OutputGuardrailResult successWith(String successfulText, Object successfulResult) { + return OutputGuardrailResult.successWith(successfulText, successfulResult); } /** diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java index 2b139fa0d..28d85dc63 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java @@ -10,29 +10,37 @@ * @param result The result of the output guardrail validation. * @param failures The list of failures, empty if the validation succeeded. */ -public record OutputGuardrailResult(Result result, String successfulResult, +public record OutputGuardrailResult(Result result, String successfulText, Object successfulResult, List failures) implements GuardrailResult { private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult(); private OutputGuardrailResult() { - this(Result.SUCCESS, null, Collections.emptyList()); + this(Result.SUCCESS, null, null, Collections.emptyList()); } - private OutputGuardrailResult(String successfulResult) { - this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList()); + private OutputGuardrailResult(String successfulText) { + this(Result.SUCCESS_WITH_RESULT, successfulText, null, Collections.emptyList()); + } + + private OutputGuardrailResult(String successfulText, Object successfulResult) { + this(Result.SUCCESS_WITH_RESULT, successfulText, successfulResult, Collections.emptyList()); } OutputGuardrailResult(List failures, boolean fatal) { - this(fatal ? Result.FATAL : Result.FAILURE, null, failures); + this(fatal ? Result.FATAL : Result.FAILURE, null, null, failures); } public static OutputGuardrailResult success() { return SUCCESS; } - public static OutputGuardrailResult successWith(String successfulResult) { - return new OutputGuardrailResult(successfulResult); + public static OutputGuardrailResult successWith(String successfulText) { + return new OutputGuardrailResult(successfulText); + } + + public static OutputGuardrailResult successWith(String successfulText, Object successfulResult) { + return new OutputGuardrailResult(successfulText, successfulResult); } public static OutputGuardrailResult failure(List failures) { @@ -45,7 +53,7 @@ public boolean isSuccess() { } @Override - public boolean isRewrittenResult() { + public boolean hasRewrittenResult() { return result == Result.SUCCESS_WITH_RESULT; } @@ -88,7 +96,7 @@ public OutputGuardrailResult validatedBy(Class guardrailCla @Override public String toString() { if (isSuccess()) { - return "success"; + return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success"; } return failures.stream().map(Failure::toString).collect(Collectors.joining(", ")); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 111af826c..348d0591d 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -2,6 +2,7 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Exceptions.runtime; +import static dev.langchain4j.model.output.TokenUsage.sum; import static dev.langchain4j.service.AiServices.removeToolMessages; import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded; import static io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil.hasResponseSchema; @@ -295,7 +296,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, throw new GuardrailsSupport.GuardrailRetryException(); } } else { - if (result.isRewrittenResult()) { + if (result.hasRewrittenResult()) { throw new GuardrailException( "Attempting to rewrite the LLM output while streaming is not allowed"); } @@ -367,7 +368,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage, audit.addLLMToApplicationMessage(response); } - tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage()); + tokenUsageAccumulator = sum(tokenUsageAccumulator, response.tokenUsage()); } String userMessageTemplate = methodCreateInfo.getUserMessageTemplate(); @@ -380,7 +381,13 @@ private List messagesToSend(ChatMessage augmentedUserMessage, // everything worked as expected so let's commit the messages chatMemory.commit(); - response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason()); + Object guardrailResult = response.metadata().get(OutputGuardrailResult.class.getName()); + if (guardrailResult != null && isTypeOf(returnType, guardrailResult.getClass())) { + return guardrailResult; + } + + response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason(), response.metadata()); + if (isResult(returnType)) { var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType)); return Result.builder() @@ -389,9 +396,9 @@ private List messagesToSend(ChatMessage augmentedUserMessage, .sources(augmentationResult == null ? null : augmentationResult.contents()) .finishReason(response.finishReason()) .build(); - } else { - return SERVICE_OUTPUT_PARSER.parse(response, returnType); } + + return SERVICE_OUTPUT_PARSER.parse(response, returnType); } private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context, diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java index c4bb7315b..0fc62a060 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java @@ -2,7 +2,10 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.function.Function; import jakarta.enterprise.inject.spi.CDI; @@ -100,17 +103,22 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries"); } - if (result.isRewrittenResult()) { - response = rewriteResponseWithText(response, result.successfulResult()); + if (result.hasRewrittenResult()) { + response = rewriteResponse(response, result); } return response; } - public static Response rewriteResponseWithText(Response response, String text) { + public static Response rewriteResponse(Response response, OutputGuardrailResult result) { List tools = response.content().toolExecutionRequests(); - AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text); - return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata()); + AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(result.successfulText(), tools) + : new AiMessage(result.successfulText()); + Map metadata = response.metadata(); + if (result.successfulResult() != null) { + metadata.put(OutputGuardrailResult.class.getName(), result.successfulResult()); + } + return new Response<>(content, response.tokenUsage(), response.finishReason(), metadata); } @SuppressWarnings("unchecked") @@ -173,10 +181,10 @@ private static GR guardrailResult(GuardrailParams p for (Class bean : classes) { GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean); if (result.isFatal()) { - return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result; + return accumulatedResults.hasRewrittenResult() ? (GR) result.blockRetry() : result; } - if (result.isRewrittenResult()) { - params = params.withText(result.successfulResult()); + if (result.hasRewrittenResult()) { + params = params.withText(result.successfulText()); } accumulatedResults = compose(accumulatedResults, result, producer); } diff --git a/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java b/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java new file mode 100644 index 000000000..e6cd37594 --- /dev/null +++ b/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java @@ -0,0 +1,67 @@ +package io.quarkiverse.langchain4j.guardrails; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.quarkus.test.junit.QuarkusTest; + +@QuarkusTest +class JsonGuardrailsUtilsTest { + + @Inject + JsonGuardrailsUtils jsonGuardrailsUtils; + + record Person(String firstName, String lastName, int age) { + } + + @Test + public void testJsonMapExtraction() { + String input = "Here is some text before the JSON part: {\"key\": \"value\", \"nested\": {\"innerKey\": 42}} and some text after."; + String json = jsonGuardrailsUtils.trimNonJson(input); + assertEquals("{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}", json); + } + + @Test + public void testJsonListExtraction() { + String input = "Here is some text before the JSON part: [{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}, {\"key\": \"value\", \"nested\": {\"innerKey\": 42}}] and some text after."; + String json = jsonGuardrailsUtils.trimNonJson(input); + assertEquals( + "[{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}, {\"key\": \"value\", \"nested\": {\"innerKey\": 42}}]", + json); + } + + @Test + public void testJsonValidation() { + String input = "{\"firstName\": \"Mario\", \"lastName\": \"Fusco\", \"age\": 50} Mario turned 50 a few days ago."; + String json = jsonGuardrailsUtils.trimNonJson(input); + Person person = jsonGuardrailsUtils.deserialize(json, Person.class); + assertEquals("Mario", person.firstName); + assertEquals("Fusco", person.lastName); + assertEquals(50, person.age); + } + + @Test + public void testJsonListValidation() { + String input = "Let me introduce you [{\"firstName\": \"Mario\", \"lastName\": \"Fusco\", \"age\": 50}, {\"firstName\": \"Sofia\", \"lastName\": \"Fusco\", \"age\": 13}] Mario and his daughter."; + String json = jsonGuardrailsUtils.trimNonJson(input); + List family = jsonGuardrailsUtils.deserialize(json, new TypeReference<>() { + }); + + Person dad = family.get(0); + assertEquals("Mario", dad.firstName); + assertEquals("Fusco", dad.lastName); + assertEquals(50, dad.age); + + Person daughter = family.get(1); + assertEquals("Sofia", daughter.firstName); + assertEquals("Fusco", daughter.lastName); + assertEquals(13, daughter.age); + } +}