From 9f6b857ac5d15f89cf5d1af27df3142921c78bf8 Mon Sep 17 00:00:00 2001 From: mariofusco Date: Thu, 14 Nov 2024 18:52:34 +0100 Subject: [PATCH] Add AbstractJsonExtractorOutputGuardrail to guardrails docs --- docs/modules/ROOT/pages/guardrails.adoc | 110 +++++++++++++++++------- 1 file changed, 78 insertions(+), 32 deletions(-) diff --git a/docs/modules/ROOT/pages/guardrails.adoc b/docs/modules/ROOT/pages/guardrails.adoc index dbd997407..1972e18a1 100644 --- a/docs/modules/ROOT/pages/guardrails.adoc +++ b/docs/modules/ROOT/pages/guardrails.adoc @@ -518,62 +518,108 @@ It may happen that the output generated by the LLM is not completely satisfying, {"name":"Alex", age:18} Alex is 18 since he became an adult a few days ago. ---- -In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method as in the following example. +In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method. + +This scenario is so common that it is already provided an abstract class implementing the `OutputGuardrail` interface and performing this programmatic json sanitization out-of-the-box. [source,java] ---- -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import dev.langchain4j.data.message.AiMessage; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; -import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult; -import jakarta.enterprise.context.ApplicationScoped; import jakarta.inject.Inject; import org.jboss.logging.Logger; +import com.fasterxml.jackson.core.type.TypeReference; +import dev.langchain4j.data.message.AiMessage; -@ApplicationScoped -public class ValidJsonOutputGuardrail implements OutputGuardrail { - - private static final ObjectMapper MAPPER = new ObjectMapper(); +public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail { @Inject Logger logger; + @Inject + JsonGuardrailsUtils jsonGuardrailsUtils; + + protected AbstractJsonExtractorOutputGuardrail() { + if (getOutputClass() == null && getOutputType() == null) { + throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented"); + } + } + @Override public OutputGuardrailResult validate(AiMessage responseFromLLM) { String llmResponse = responseFromLLM.text(); - logger.infof("LLM output: %s", llmResponse); + logger.debugf("LLM output: %s", llmResponse); - if (validateJson(llmResponse, Person.class)) { - return success(); + Object result = deserialize(llmResponse); + if (result != null) { + return successWith(llmResponse, result); } - String json = trimNonJson(llmResponse); - if (json != null && validateJson(json, Person.class)) { - return successWith(json); + String json = jsonGuardrailsUtils.trimNonJson(llmResponse); + if (json != null) { + result = deserialize(json); + if (result != null) { + return successWith(json, result); + } } return reprompt("Invalid JSON", - "Make sure you return a valid JSON object following " - + "the specified format"); + "Make sure you return a valid JSON object following " + + "the specified format"); } - private static String trimNonJson(String llmResponse) { - int jsonStart = llmResponse.indexOf("{"); - int jsonEnd = llmResponse.indexOf("}"); - if (jsonStart >= 0 && jsonEnd >= 0 && jsonStart < jsonEnd) { - return llmResponse.substring(jsonStart + 1, jsonEnd); - } + protected Object deserialize(String llmResponse) { + return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass()) + : jsonGuardrailsUtils.deserialize(llmResponse, getOutputType()); + } + + protected Class getOutputClass() { return null; } - private static boolean validateJson(String json, Class expectedOutputClass) { - try { - MAPPER.readValue(json, expectedOutputClass); - return true; - } catch (JsonProcessingException e) { - return false; - } + protected TypeReference getOutputType() { + return null; + } +} +---- + +This implementation, first tries to deserialize the LLM response into the expected class to be returned by the data extraction. If this doesn't succeed it tries to trim away the non-json part of the response and perform the deserialization again. Note that in both case together with the json response, either the original LLM one or the one programmatically trimmed, the `successWith` method also returns the resulting deserialized object, so that it could be used directly as the final response of the data extraction, instead of uselessly having to execute a second deserialization. In case that both these attempts of deserialization fail then the `OutputGuardrail` perform a reprompt, hoping that the LLM will finally produce a valid json string. + +In this way if for example there is an AI service trying to extract the data of a customer from the user prompts like the following + +[source,java] +---- +@RegisterAiService +public interface CustomerExtractor { + + @UserMessage("Extract information about a customer from this text '{text}'. The response must contain only the JSON with customer's data and without any other sentence.") + @OutputGuardrails(CustomerExtractionOutputGuardrail.class) + Customer extractData(String text); +} +---- + +it is possible to use with it an `OutputGuardrail` that sanitizes the json LLM response by simply extending the former abstract class and declaring which is the expected output class of the data extraction. + +[source,java] +---- +@ApplicationScoped +public class CustomerExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail { + + @Override + protected Class getOutputClass() { + return Customer.class; + } +} +---- + +Note that if the data extraction requires a generified Java type, like a `List`, it is conversely necessary to extend the `getOutputType` and return a Jackson's `TypeReference` as it follows: + +[source,java] +---- +@ApplicationScoped +public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail { + + @Override + protected TypeReference getOutputType() { + return new TypeReference>() {}; } } ----