Skip to content

Commit

Permalink
Add AbstractJsonExtractorOutputGuardrail to guardrails docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Nov 14, 2024
1 parent 74b7c12 commit 9f6b857
Showing 1 changed file with 78 additions and 32 deletions.
110 changes: 78 additions & 32 deletions docs/modules/ROOT/pages/guardrails.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -518,62 +518,108 @@ It may happen that the output generated by the LLM is not completely satisfying,
{"name":"Alex", age:18} Alex is 18 since he became an adult a few days ago.
----

In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method as in the following example.
In this situation it is better to try to programmatically trim the json part of the response and check if we can deserialize a valid Person object out of it, before trying to reprompt the LLM again. If the programmatic extraction of the json string from the partially hallucinated LLM output succeeds, it is possible to propagate the rewritten output through the `successWith` method.

This scenario is so common that it is already provided an abstract class implementing the `OutputGuardrail` interface and performing this programmatic json sanitization out-of-the-box.

[source,java]
----
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.jboss.logging.Logger;
import com.fasterxml.jackson.core.type.TypeReference;
import dev.langchain4j.data.message.AiMessage;
@ApplicationScoped
public class ValidJsonOutputGuardrail implements OutputGuardrail {
private static final ObjectMapper MAPPER = new ObjectMapper();
public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail {
@Inject
Logger logger;
@Inject
JsonGuardrailsUtils jsonGuardrailsUtils;
protected AbstractJsonExtractorOutputGuardrail() {
if (getOutputClass() == null && getOutputType() == null) {
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented");
}
}
@Override
public OutputGuardrailResult validate(AiMessage responseFromLLM) {
String llmResponse = responseFromLLM.text();
logger.infof("LLM output: %s", llmResponse);
logger.debugf("LLM output: %s", llmResponse);
if (validateJson(llmResponse, Person.class)) {
return success();
Object result = deserialize(llmResponse);
if (result != null) {
return successWith(llmResponse, result);
}
String json = trimNonJson(llmResponse);
if (json != null && validateJson(json, Person.class)) {
return successWith(json);
String json = jsonGuardrailsUtils.trimNonJson(llmResponse);
if (json != null) {
result = deserialize(json);
if (result != null) {
return successWith(json, result);
}
}
return reprompt("Invalid JSON",
"Make sure you return a valid JSON object following "
+ "the specified format");
"Make sure you return a valid JSON object following "
+ "the specified format");
}
private static String trimNonJson(String llmResponse) {
int jsonStart = llmResponse.indexOf("{");
int jsonEnd = llmResponse.indexOf("}");
if (jsonStart >= 0 && jsonEnd >= 0 && jsonStart < jsonEnd) {
return llmResponse.substring(jsonStart + 1, jsonEnd);
}
protected Object deserialize(String llmResponse) {
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass())
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType());
}
protected Class<?> getOutputClass() {
return null;
}
private static boolean validateJson(String json, Class<?> expectedOutputClass) {
try {
MAPPER.readValue(json, expectedOutputClass);
return true;
} catch (JsonProcessingException e) {
return false;
}
protected TypeReference<?> getOutputType() {
return null;
}
}
----

This implementation, first tries to deserialize the LLM response into the expected class to be returned by the data extraction. If this doesn't succeed it tries to trim away the non-json part of the response and perform the deserialization again. Note that in both case together with the json response, either the original LLM one or the one programmatically trimmed, the `successWith` method also returns the resulting deserialized object, so that it could be used directly as the final response of the data extraction, instead of uselessly having to execute a second deserialization. In case that both these attempts of deserialization fail then the `OutputGuardrail` perform a reprompt, hoping that the LLM will finally produce a valid json string.

In this way if for example there is an AI service trying to extract the data of a customer from the user prompts like the following

[source,java]
----
@RegisterAiService
public interface CustomerExtractor {
@UserMessage("Extract information about a customer from this text '{text}'. The response must contain only the JSON with customer's data and without any other sentence.")
@OutputGuardrails(CustomerExtractionOutputGuardrail.class)
Customer extractData(String text);
}
----

it is possible to use with it an `OutputGuardrail` that sanitizes the json LLM response by simply extending the former abstract class and declaring which is the expected output class of the data extraction.

[source,java]
----
@ApplicationScoped
public class CustomerExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
@Override
protected Class<?> getOutputClass() {
return Customer.class;
}
}
----

Note that if the data extraction requires a generified Java type, like a `List<Customer>`, it is conversely necessary to extend the `getOutputType` and return a Jackson's `TypeReference` as it follows:

[source,java]
----
@ApplicationScoped
public class CustomersExtractionOutputGuardrail extends AbstractJsonExtractorOutputGuardrail {
@Override
protected TypeReference<?> getOutputType() {
return new TypeReference<List<Customer>>() {};
}
}
----

0 comments on commit 9f6b857

Please sign in to comment.