Output guardrails should support structured output
Fixes #1200
edeandrea committed Jan 7, 2025
1 parent 9314bcc commit 7973ead
Showing 2 changed files with 55 additions and 43 deletions.
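For context, the services this change targets declare a POJO return type (structured output) and attach an output guardrail. A minimal sketch follows, assuming the quarkus-langchain4j @RegisterAiService and @OutputGuardrails annotations; TriageService, TriageResult, and TriageGuardrail are hypothetical names. Before this commit, the guardrail retry path in the second file called chatModel.generate(...) directly, so a retried call lost the JSON-schema response format of the original request; the fix routes both paths through a shared executeRequest(...) helper.

import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrails;

// Hypothetical AI service: the POJO return type asks the model for structured
// (JSON-schema) output when the provider supports it, and the output guardrail
// validates every response, including retried ones.
@RegisterAiService
public interface TriageService {

    record TriageResult(String category, double confidence) {}

    @UserMessage("Classify the following support ticket: {ticket}")
    @OutputGuardrails(TriageGuardrail.class)
    TriageResult triage(String ticket);
}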
File 1 of 2:
@@ -44,10 +44,10 @@
 import dev.langchain4j.data.message.ToolExecutionResultMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
 import dev.langchain4j.model.chat.request.ResponseFormat;
 import dev.langchain4j.model.chat.request.json.JsonSchema;
-import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.input.structured.StructuredPrompt;
@@ -337,32 +337,7 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
 
         log.debug("Attempting to obtain AI response");
 
-        Optional<JsonSchema> jsonSchema = Optional.empty();
-        if (supportsJsonSchema) {
-            jsonSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema();
-        }
-
-        Response<AiMessage> response;
-        if (jsonSchema.isPresent()) {
-            ChatRequest chatRequest = ChatRequest.builder()
-                    .messages(messagesToSend)
-                    .toolSpecifications(toolSpecifications)
-                    .responseFormat(ResponseFormat.builder()
-                            .type(JSON)
-                            .jsonSchema(jsonSchema.get())
-                            .build())
-                    .build();
-
-            ChatResponse chatResponse = context.chatModel.chat(chatRequest);
-            response = new Response<>(
-                    chatResponse.aiMessage(),
-                    chatResponse.tokenUsage(),
-                    chatResponse.finishReason());
-        } else {
-            response = toolSpecifications == null
-                    ? context.chatModel.generate(messagesToSend)
-                    : context.chatModel.generate(messagesToSend, toolSpecifications);
-        }
+        var response = executeRequest(context, methodCreateInfo, messagesToSend, toolSpecifications);
 
         log.debug("AI response obtained");
         if (audit != null) {
Expand Down Expand Up @@ -450,6 +425,46 @@ private List<ChatMessage> messagesToSend(UserMessage augmentedUserMessage,
methodCreateInfo, responseAugmenterParam);
}

private static Response<AiMessage> executeRequest(JsonSchema jsonSchema, List<ChatMessage> messagesToSend,
ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) {
var chatRequest = ChatRequest.builder()
.messages(messagesToSend)
.toolSpecifications(toolSpecifications)
.responseFormat(
ResponseFormat.builder()
.type(JSON)
.jsonSchema(jsonSchema)
.build())
.build();

var response = chatModel.chat(chatRequest);

return new Response<>(
response.aiMessage(),
response.tokenUsage(),
response.finishReason());
}

private static Response<AiMessage> executeRequest(List<ChatMessage> messagesToSend, ChatLanguageModel chatModel,
List<ToolSpecification> toolSpecifications) {
return (toolSpecifications == null) ? chatModel.generate(messagesToSend)
: chatModel.generate(messagesToSend, toolSpecifications);
}

static Response<AiMessage> executeRequest(AiServiceMethodCreateInfo methodCreateInfo, List<ChatMessage> messagesToSend,
ChatLanguageModel chatModel, List<ToolSpecification> toolSpecifications) {
var jsonSchema = supportsJsonSchema(chatModel) ? methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema()
: Optional.<JsonSchema> empty();

return jsonSchema.isPresent() ? executeRequest(jsonSchema.get(), messagesToSend, chatModel, toolSpecifications)
: executeRequest(messagesToSend, chatModel, toolSpecifications);
}

static Response<AiMessage> executeRequest(QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo,
List<ChatMessage> messagesToSend, List<ToolSpecification> toolSpecifications) {
return executeRequest(methodCreateInfo, messagesToSend, context.chatModel, toolSpecifications);
}

private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodCreateInfo, QuarkusAiServiceContext context,
Audit audit, Optional<SystemMessage> systemMessage, UserMessage userMessage,
Object memoryId, Type returnType, Map<String, Object> templateVariables) {
Expand Down Expand Up @@ -547,9 +562,12 @@ private static List<ChatMessage> createMessagesToSendForNoMemory(Optional<System
return result;
}

private static boolean supportsJsonSchema(ChatLanguageModel chatModel) {
return (chatModel != null) && chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
}

private static boolean supportsJsonSchema(AiServiceContext context) {
return context.chatModel != null
&& context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
return supportsJsonSchema(context.chatModel);
}

private static Future<Moderation> triggerModerationIfNeeded(AiServiceContext context,
File 2 of 2:
@@ -81,24 +81,18 @@ public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateIn
             if (!result.isSuccess()) {
                 if (!result.isRetry()) {
                     throw new GuardrailException(result.toString(), result.getFirstFailureException());
-                } else if (result.getReprompt() != null) {
-                    // Retry with re-prompting
-                    chatMemory.add(userMessage(result.getReprompt()));
-                    if (toolSpecifications == null) {
-                        response = chatModel.generate(chatMemory.messages());
-                    } else {
-                        response = chatModel.generate(chatMemory.messages(), toolSpecifications);
-                    }
-                    chatMemory.add(response.content());
                 } else {
-                    // Retry without re-prompting
-                    if (toolSpecifications == null) {
-                        response = chatModel.generate(chatMemory.messages());
-                    } else {
-                        response = chatModel.generate(chatMemory.messages(), toolSpecifications);
+                    // Retry
+                    if (result.getReprompt() != null) {
+                        // Retry with reprompting
+                        chatMemory.add(userMessage(result.getReprompt()));
                     }
+
+                    response = AiServiceMethodImplementationSupport.executeRequest(methodCreateInfo, chatMemory.messages(),
+                            chatModel, toolSpecifications);
                     chatMemory.add(response.content());
                 }
+
                 attempt++;
                 output = new OutputGuardrailParams(response.content(), output.memory(),
                         output.augmentationResult(), output.userMessageTemplate(), output.variables());
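To see the retry path that the second file now delegates to AiServiceMethodImplementationSupport.executeRequest(...), here is a minimal sketch of an output guardrail that asks for a reprompt. It assumes the quarkus-langchain4j OutputGuardrail interface with its success() and reprompt() helper methods; TriageGuardrail is the hypothetical class from the sketch above. After this commit, the request issued for the retry goes through the shared executeRequest(...) helper, so it carries the same JSON-schema response format as the original call.

import dev.langchain4j.data.message.AiMessage;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import jakarta.enterprise.context.ApplicationScoped;

// Hypothetical guardrail: rejects blank output and reprompts the model.
@ApplicationScoped
public class TriageGuardrail implements OutputGuardrail {

    @Override
    public OutputGuardrailResult validate(AiMessage responseFromLLM) {
        String text = responseFromLLM.text();
        if (text == null || text.isBlank()) {
            // reprompt(...) triggers a retry; with this fix the retried request
            // keeps the structured-output response format.
            return reprompt("Empty response", "Return a non-empty JSON object matching the requested schema");
        }
        return success();
    }
}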
