-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Provide an abstract output guardrails for json data extraction
- Loading branch information
1 parent
6efb975
commit 78fcaf4
Showing
10 changed files
with
265 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
...main/java/io/quarkiverse/langchain4j/guardrails/AbstractJsonExtractorOutputGuardrail.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package io.quarkiverse.langchain4j.guardrails; | ||
|
||
import jakarta.inject.Inject; | ||
|
||
import org.jboss.logging.Logger; | ||
|
||
import com.fasterxml.jackson.core.type.TypeReference; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
|
||
public abstract class AbstractJsonExtractorOutputGuardrail implements OutputGuardrail { | ||
|
||
@Inject | ||
Logger logger; | ||
|
||
@Inject | ||
JsonGuardrailsUtils jsonGuardrailsUtils; | ||
|
||
protected AbstractJsonExtractorOutputGuardrail() { | ||
if (getOutputClass() == null && getOutputType() == null) { | ||
throw new IllegalArgumentException("Either getOutputClass() or getOutputType() must be implemented"); | ||
} | ||
} | ||
|
||
@Override | ||
public OutputGuardrailResult validate(AiMessage responseFromLLM) { | ||
String llmResponse = responseFromLLM.text(); | ||
logger.debugf("LLM output: %s", llmResponse); | ||
|
||
Object result = deserialize(llmResponse); | ||
if (result != null) { | ||
return successWith(llmResponse, result); | ||
} | ||
|
||
String json = jsonGuardrailsUtils.trimNonJson(llmResponse); | ||
if (json != null) { | ||
result = deserialize(json); | ||
if (result != null) { | ||
return successWith(json, result); | ||
} | ||
} | ||
|
||
return reprompt("Invalid JSON", | ||
"Make sure you return a valid JSON object following " | ||
+ "the specified format"); | ||
} | ||
|
||
protected Object deserialize(String llmResponse) { | ||
return getOutputClass() != null ? jsonGuardrailsUtils.deserialize(llmResponse, getOutputClass()) | ||
: jsonGuardrailsUtils.deserialize(llmResponse, getOutputType()); | ||
} | ||
|
||
protected Class<?> getOutputClass() { | ||
return null; | ||
} | ||
|
||
protected TypeReference<?> getOutputType() { | ||
return null; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
core/runtime/src/main/java/io/quarkiverse/langchain4j/guardrails/JsonGuardrailsUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package io.quarkiverse.langchain4j.guardrails; | ||
|
||
import jakarta.enterprise.context.ApplicationScoped; | ||
import jakarta.inject.Inject; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
|
||
@ApplicationScoped | ||
class JsonGuardrailsUtils { | ||
|
||
@Inject | ||
ObjectMapper objectMapper; | ||
|
||
private JsonGuardrailsUtils() { | ||
} | ||
|
||
String trimNonJson(String llmResponse) { | ||
int jsonMapStart = llmResponse.indexOf('{'); | ||
int jsonListStart = llmResponse.indexOf('['); | ||
if (jsonMapStart < 0 && jsonListStart < 0) { | ||
return null; | ||
} | ||
boolean isJsonMap = jsonMapStart >= 0 && (jsonMapStart < jsonListStart || jsonListStart < 0); | ||
|
||
int jsonStart = isJsonMap ? jsonMapStart : jsonListStart; | ||
int jsonEnd = isJsonMap ? llmResponse.lastIndexOf('}') : llmResponse.lastIndexOf(']'); | ||
return jsonEnd >= 0 && jsonStart < jsonEnd ? llmResponse.substring(jsonStart, jsonEnd + 1) : null; | ||
} | ||
|
||
<T> T deserialize(String json, Class<T> expectedOutputClass) { | ||
try { | ||
return objectMapper.readValue(json, expectedOutputClass); | ||
} catch (JsonProcessingException e) { | ||
return null; | ||
} | ||
} | ||
|
||
<T> T deserialize(String json, TypeReference<T> expectedOutputType) { | ||
try { | ||
return objectMapper.readValue(json, expectedOutputType); | ||
} catch (JsonProcessingException e) { | ||
return null; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.