diff --git a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java
index b353f817e..d42dafe47 100644
--- a/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java
+++ b/core/deployment/src/test/java/io/quarkiverse/langchain4j/test/guardrails/OutputGuardrailChainTest.java
@@ -94,6 +94,12 @@ void testThatRepromptAfterRewriteIsNotAllowed() {
.withMessageContaining("Retry or reprompt is not allowed after a rewritten output");
}
+ @Test
+ @ActivateRequestContext
+ void testThatRewritesTheOutputWithAResult() {
+ assertThat(aiService.rewritingSuccessWithResult("1", "foo")).isSameAs(RewritingGuardrailWithResult.RESULT);
+ }
+
@RegisterAiService(chatLanguageModelSupplier = MyChatModelSupplier.class, chatMemoryProviderSupplier = MyMemoryProviderSupplier.class)
public interface MyAiService {
@@ -112,6 +118,9 @@ public interface MyAiService {
@OutputGuardrails({ FirstRewritingGuardrail.class, RepromptingGuardrail.class })
String repromptAfterRewrite(@MemoryId String mem, @UserMessage String message);
+ @OutputGuardrails({ FirstRewritingGuardrail.class, RewritingGuardrailWithResult.class })
+ Integer rewritingSuccessWithResult(@MemoryId String mem, @UserMessage String message);
+
}
@RequestScoped
@@ -206,6 +215,18 @@ public OutputGuardrailResult validate(AiMessage responseFromLLM) {
}
}
+ @RequestScoped
+ public static class RewritingGuardrailWithResult implements OutputGuardrail {
+
+ static final Integer RESULT = 1_000;
+
+ @Override
+ public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+ String text = responseFromLLM.text();
+ return successWith(text + ",2", RESULT);
+ }
+ }
+
@RequestScoped
public static class RepromptingGuardrail implements OutputGuardrail {
diff --git a/core/runtime/pom.xml b/core/runtime/pom.xml
index 603472071..62b87b67a 100644
--- a/core/runtime/pom.xml
+++ b/core/runtime/pom.xml
@@ -140,6 +140,11 @@
junit-jupiter
test
+
+ io.quarkus
+ quarkus-junit5
+ test
+
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java
new file mode 100644
index 000000000..488c832ca
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java
@@ -0,0 +1,60 @@
+package io.quarkiverse.langchain4j.guardrails;
+
+import jakarta.inject.Inject;
+
+import org.jboss.logging.Logger;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+
+import dev.langchain4j.data.message.AiMessage;
+
+public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {
+
+ @Inject
+ Logger logger;
+
+ @Inject
+ JsonGuardrailsUtils jsonGuardrailsUtils;
+
+ protected AbstractJsonExtractorOutputGuardrail() {
+ if (getOutputClass() == null && getOutputType() == null) {
+ throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
+ }
+ }
+
+ @Override
+ public OutputGuardrailResult validate(AiMessage responseFromLLM) {
+ String llmResponse = responseFromLLM.text();
+ logger.debugf("LLM output: %s", llmResponse);
+
+ Object result = deserialize(llmResponse);
+ if (result != null) {
+ return successWith(llmResponse, result);
+ }
+
+ String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
+ if (json != null) {
+ result = deserialize(json);
+ if (result != null) {
+ return successWith(json, result);
+ }
+ }
+
+ return reprompt("Invalid JSON",
+ "Make sure you return a valid JSON object following "
+ + "the specified format");
+ }
+
+ protected Object deserialize(String llmResponse) {
+ return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
+ : jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
+ }
+
+ protected Class> getOutputClass() {
+ return null;
+ }
+
+ protected TypeReference> getOutputType() {
+ return null;
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java
index f731a17ea..965860f31 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/GuardrailResult.java
@@ -31,7 +31,7 @@ enum Result {
boolean isSuccess();
- default boolean isRewrittenResult() {
+ default boolean hasRewrittenResult() {
return false;
}
@@ -39,7 +39,11 @@ default GuardrailResult blockRetry() {
throw new UnsupportedOperationException();
}
- default String successfulResult() {
+ default String successfulText() {
+ throw new UnsupportedOperationException();
+ }
+
+ default Object successfulResult() {
throw new UnsupportedOperationException();
}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java
new file mode 100644
index 000000000..4c09b535e
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java
@@ -0,0 +1,47 @@
+package io.quarkiverse.langchain4j.guardrails;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.inject.Inject;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+@ApplicationScoped
+class JsonGuardrailsUtils {
+
+ @Inject
+ ObjectMapper objectMapper;
+
+ private JsonGuardrailsUtils() {
+ }
+
+ String trimNonJson(String llmResponse) {
+ int jsonMapStart = llmResponse.indexOf('{');
+ int jsonListStart = llmResponse.indexOf('[');
+ if (jsonMapStart < 0 && jsonListStart < 0) {
+ return null;
+ }
+ boolean isJsonMap = jsonMapStart >= 0 && (jsonMapStart < jsonListStart || jsonListStart < 0);
+
+ int jsonStart = isJsonMap ? jsonMapStart : jsonListStart;
+ int jsonEnd = isJsonMap ? llmResponse.lastIndexOf('}') : llmResponse.lastIndexOf(']');
+ return jsonEnd >= 0 && jsonStart < jsonEnd ? llmResponse.substring(jsonStart, jsonEnd + 1) : null;
+ }
+
+ T deserialize(String json, Class expectedOutputClass) {
+ try {
+ return objectMapper.readValue(json, expectedOutputClass);
+ } catch (JsonProcessingException e) {
+ return null;
+ }
+ }
+
+ T deserialize(String json, TypeReference expectedOutputType) {
+ try {
+ return objectMapper.readValue(json, expectedOutputType);
+ } catch (JsonProcessingException e) {
+ return null;
+ }
+ }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java
index 762b5478f..3d3489f3f 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrail.java
@@ -51,11 +51,20 @@ default OutputGuardrailResult success() {
}
/**
- * @return The result of a successful output guardrail validation with a specific result.
- * @param successfulResult The successful result.
+ * @return The result of a successful output guardrail validation with a specific text.
+ * @param successfulText The text of the successful result.
*/
- default OutputGuardrailResult successWith(String successfulResult) {
- return OutputGuardrailResult.successWith(successfulResult);
+ default OutputGuardrailResult successWith(String successfulText) {
+ return OutputGuardrailResult.successWith(successfulText);
+ }
+
+ /**
+ * @return The result of a successful output guardrail validation with a specific text.
+ * @param successfulText The text of the successful result.
+ * @param successfulResult The object generated by this successful result.
+ */
+ default OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
+ return OutputGuardrailResult.successWith(successfulText, successfulResult);
}
/**
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java
index 2b139fa0d..28d85dc63 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/OutputGuardrailResult.java
@@ -10,29 +10,37 @@
* @param result The result of the output guardrail validation.
* @param failures The list of failures, empty if the validation succeeded.
*/
-public record OutputGuardrailResult(Result result, String successfulResult,
+public record OutputGuardrailResult(Result result, String successfulText, Object successfulResult,
List failures) implements GuardrailResult {
private static final OutputGuardrailResult SUCCESS = new OutputGuardrailResult();
private OutputGuardrailResult() {
- this(Result.SUCCESS, null, Collections.emptyList());
+ this(Result.SUCCESS, null, null, Collections.emptyList());
}
- private OutputGuardrailResult(String successfulResult) {
- this(Result.SUCCESS_WITH_RESULT, successfulResult, Collections.emptyList());
+ private OutputGuardrailResult(String successfulText) {
+ this(Result.SUCCESS_WITH_RESULT, successfulText, null, Collections.emptyList());
+ }
+
+ private OutputGuardrailResult(String successfulText, Object successfulResult) {
+ this(Result.SUCCESS_WITH_RESULT, successfulText, successfulResult, Collections.emptyList());
}
OutputGuardrailResult(List failures, boolean fatal) {
- this(fatal ? Result.FATAL : Result.FAILURE, null, failures);
+ this(fatal ? Result.FATAL : Result.FAILURE, null, null, failures);
}
public static OutputGuardrailResult success() {
return SUCCESS;
}
- public static OutputGuardrailResult successWith(String successfulResult) {
- return new OutputGuardrailResult(successfulResult);
+ public static OutputGuardrailResult successWith(String successfulText) {
+ return new OutputGuardrailResult(successfulText);
+ }
+
+ public static OutputGuardrailResult successWith(String successfulText, Object successfulResult) {
+ return new OutputGuardrailResult(successfulText, successfulResult);
}
public static OutputGuardrailResult failure(List extends GuardrailResult.Failure> failures) {
@@ -45,7 +53,7 @@ public boolean isSuccess() {
}
@Override
- public boolean isRewrittenResult() {
+ public boolean hasRewrittenResult() {
return result == Result.SUCCESS_WITH_RESULT;
}
@@ -88,7 +96,7 @@ public OutputGuardrailResult validatedBy(Class extends Guardrail> guardrailCla
@Override
public String toString() {
if (isSuccess()) {
- return "success";
+ return hasRewrittenResult() ? "Success with '" + successfulText + "'" : "Success";
}
return failures.stream().map(Failure::toString).collect(Collectors.joining(", "));
}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
index 111af826c..348d0591d 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java
@@ -2,6 +2,7 @@
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Exceptions.runtime;
+import static dev.langchain4j.model.output.TokenUsage.sum;
import static dev.langchain4j.service.AiServices.removeToolMessages;
import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded;
import static io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil.hasResponseSchema;
@@ -295,7 +296,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage,
throw new GuardrailsSupport.GuardrailRetryException();
}
} else {
- if (result.isRewrittenResult()) {
+ if (result.hasRewrittenResult()) {
throw new GuardrailException(
"Attempting to rewrite the LLM output while streaming is not allowed");
}
@@ -367,7 +368,7 @@ private List messagesToSend(ChatMessage augmentedUserMessage,
audit.addLLMToApplicationMessage(response);
}
- tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
+ tokenUsageAccumulator = sum(tokenUsageAccumulator, response.tokenUsage());
}
String userMessageTemplate = methodCreateInfo.getUserMessageTemplate();
@@ -380,7 +381,13 @@ private List messagesToSend(ChatMessage augmentedUserMessage,
// everything worked as expected so let's commit the messages
chatMemory.commit();
- response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
+ Object guardrailResult = response.metadata().get(OutputGuardrailResult.class.getName());
+ if (guardrailResult != null && isTypeOf(returnType, guardrailResult.getClass())) {
+ return guardrailResult;
+ }
+
+ response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason(), response.metadata());
+
if (isResult(returnType)) {
var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType));
return Result.builder()
@@ -389,9 +396,9 @@ private List messagesToSend(ChatMessage augmentedUserMessage,
.sources(augmentationResult == null ? null : augmentationResult.contents())
.finishReason(response.finishReason())
.build();
- } else {
- return SERVICE_OUTPUT_PARSER.parse(response, returnType);
}
+
+ return SERVICE_OUTPUT_PARSER.parse(response, returnType);
}
private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java
index c4bb7315b..0fc62a060 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/GuardrailsSupport.java
@@ -2,7 +2,10 @@
import static dev.langchain4j.data.message.UserMessage.userMessage;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
import java.util.function.Function;
import jakarta.enterprise.inject.spi.CDI;
@@ -100,17 +103,22 @@ public static Response invokeOutputGuardrails(AiServiceMethodCreateIn
throw new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries");
}
- if (result.isRewrittenResult()) {
- response = rewriteResponseWithText(response, result.successfulResult());
+ if (result.hasRewrittenResult()) {
+ response = rewriteResponse(response, result);
}
return response;
}
- public static Response rewriteResponseWithText(Response response, String text) {
+ public static Response rewriteResponse(Response response, OutputGuardrailResult result) {
List tools = response.content().toolExecutionRequests();
- AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(text, tools) : new AiMessage(text);
- return new Response<>(content, response.tokenUsage(), response.finishReason(), response.metadata());
+ AiMessage content = tools != null && !tools.isEmpty() ? new AiMessage(result.successfulText(), tools)
+ : new AiMessage(result.successfulText());
+ Map metadata = response.metadata();
+ if (result.successfulResult() != null) {
+ metadata.put(OutputGuardrailResult.class.getName(), result.successfulResult());
+ }
+ return new Response<>(content, response.tokenUsage(), response.finishReason(), metadata);
}
@SuppressWarnings("unchecked")
@@ -173,10 +181,10 @@ private static GR guardrailResult(GuardrailParams p
for (Class extends Guardrail> bean : classes) {
GR result = (GR) CDI.current().select(bean).get().validate(params).validatedBy(bean);
if (result.isFatal()) {
- return accumulatedResults.isRewrittenResult() ? (GR) result.blockRetry() : result;
+ return accumulatedResults.hasRewrittenResult() ? (GR) result.blockRetry() : result;
}
- if (result.isRewrittenResult()) {
- params = params.withText(result.successfulResult());
+ if (result.hasRewrittenResult()) {
+ params = params.withText(result.successfulText());
}
accumulatedResults = compose(accumulatedResults, result, producer);
}
diff --git a/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java b/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java
new file mode 100644
index 000000000..e6cd37594
--- /dev/null
+++ b/core/runtime/src/test/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtilsTest.java
@@ -0,0 +1,67 @@
+package io.quarkiverse.langchain4j.guardrails;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import java.util.List;
+
+import jakarta.inject.Inject;
+
+import org.junit.jupiter.api.Test;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+
+import io.quarkus.test.junit.QuarkusTest;
+
+@QuarkusTest
+class JsonGuardrailsUtilsTest {
+
+ @Inject
+ JsonGuardrailsUtils jsonGuardrailsUtils;
+
+ record Person(String firstName, String lastName, int age) {
+ }
+
+ @Test
+ public void testJsonMapExtraction() {
+ String input = "Here is some text before the JSON part: {\"key\": \"value\", \"nested\": {\"innerKey\": 42}} and some text after.";
+ String json = jsonGuardrailsUtils.trimNonJson(input);
+ assertEquals("{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}", json);
+ }
+
+ @Test
+ public void testJsonListExtraction() {
+ String input = "Here is some text before the JSON part: [{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}, {\"key\": \"value\", \"nested\": {\"innerKey\": 42}}] and some text after.";
+ String json = jsonGuardrailsUtils.trimNonJson(input);
+ assertEquals(
+ "[{\"key\": \"value\", \"nested\": {\"innerKey\": 42}}, {\"key\": \"value\", \"nested\": {\"innerKey\": 42}}]",
+ json);
+ }
+
+ @Test
+ public void testJsonValidation() {
+ String input = "{\"firstName\": \"Mario\", \"lastName\": \"Fusco\", \"age\": 50} Mario turned 50 a few days ago.";
+ String json = jsonGuardrailsUtils.trimNonJson(input);
+ Person person = jsonGuardrailsUtils.deserialize(json, Person.class);
+ assertEquals("Mario", person.firstName);
+ assertEquals("Fusco", person.lastName);
+ assertEquals(50, person.age);
+ }
+
+ @Test
+ public void testJsonListValidation() {
+ String input = "Let me introduce you [{\"firstName\": \"Mario\", \"lastName\": \"Fusco\", \"age\": 50}, {\"firstName\": \"Sofia\", \"lastName\": \"Fusco\", \"age\": 13}] Mario and his daughter.";
+ String json = jsonGuardrailsUtils.trimNonJson(input);
+ List family = jsonGuardrailsUtils.deserialize(json, new TypeReference<>() {
+ });
+
+ Person dad = family.get(0);
+ assertEquals("Mario", dad.firstName);
+ assertEquals("Fusco", dad.lastName);
+ assertEquals(50, dad.age);
+
+ Person daughter = family.get(1);
+ assertEquals("Sofia", daughter.firstName);
+ assertEquals("Fusco", daughter.lastName);
+ assertEquals(13, daughter.age);
+ }
+}