diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java index a57eab0af..48f9f5196 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/AiServicesProcessor.java @@ -14,6 +14,8 @@ import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.IGNORE; import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.OPTIONAL_DENY; +import static io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema; +import static io.quarkiverse.langchain4j.runtime.types.TypeUtil.isMulti; import static io.quarkus.arc.processor.DotNames.NAMED; import java.io.IOException; @@ -61,7 +63,9 @@ import org.objectweb.asm.tree.analysis.AnalyzerException; import dev.langchain4j.exception.IllegalConfigurationException; +import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.service.Moderate; +import dev.langchain4j.service.output.JsonSchemas; import dev.langchain4j.service.output.ServiceOutputParser; import io.quarkiverse.langchain4j.ModelName; import io.quarkiverse.langchain4j.RegisterAiService; @@ -117,6 +121,7 @@ import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem; import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem; +import io.quarkus.deployment.recording.RecorderContext; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; import io.quarkus.gizmo.FieldDescriptor; @@ -922,6 +927,7 @@ public void markIgnoredAnnotations(BuildProducer declarativeAiServiceItems, List 
methodParameterAllowedAnnotationsItems, @@ -1178,6 +1184,7 @@ public void handleAiServices( } + registerJsonSchema(recorderContext); recorder.setMetadata(perClassMetadata); } @@ -1246,8 +1253,10 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( // TODO give user ability to provide custom OutputParser String outputFormatInstructions = ""; - if (generateResponseSchema && !returnType.equals(Multi.class)) + Optional structuredOutputSchema = Optional.empty(); + if (!returnType.equals(Multi.class)) { outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType); + } List templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates); Optional systemMessageInfo = gatherSystemMessageInfo(method, templateParams); @@ -1255,7 +1264,7 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema, systemMessageInfo, - userMessageInfo.template(), outputFormatInstructions); + userMessageInfo.template(), outputFormatInstructions, jsonSchemaFrom(returnType)); if (!generateResponseSchema && responseSchemaInfo.isInSystemMessage()) throw new RuntimeException( @@ -1293,6 +1302,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata( inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName); } + private Optional jsonSchemaFrom(java.lang.reflect.Type returnType) { + if (isMulti(returnType)) { + return Optional.empty(); + } + return JsonSchemas.jsonSchemaFrom(returnType); + } + private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List tools, List methodToolClassNames) { List allTools = new ArrayList<>(methodToolClassNames); diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ObjectSubstitutionUtil.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ObjectSubstitutionUtil.java new file mode 100644 index 
000000000..e0ed5ef2c --- /dev/null +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ObjectSubstitutionUtil.java @@ -0,0 +1,49 @@ +package io.quarkiverse.langchain4j.deployment; + +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import io.quarkiverse.langchain4j.runtime.substitution.JsonArraySchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonBooleanSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonEnumSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonIntegerSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonNumberSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonObjectSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonReferenceSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.substitution.JsonStringSchemaObjectSubstitution; +import io.quarkus.deployment.recording.RecorderContext; + +final class ObjectSubstitutionUtil { + + private ObjectSubstitutionUtil() { + } + + static void registerJsonSchema(RecorderContext recorderContext) { + recorderContext.registerSubstitution(JsonSchema.class, JsonSchemaObjectSubstitution.Serialized.class, + JsonSchemaObjectSubstitution.class); + 
recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class, + JsonArraySchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class, + JsonBooleanSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class, + JsonEnumSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class, + JsonIntegerSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class, + JsonNumberSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class, + JsonObjectSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonReferenceSchema.class, + JsonReferenceSchemaObjectSubstitution.Serialized.class, + JsonReferenceSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class, + JsonStringSchemaObjectSubstitution.class); + } +} diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index d50160c32..dcc9217d2 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -6,6 +6,7 @@ import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING; import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD; import static io.quarkiverse.langchain4j.deployment.DotNames.UNI; +import static 
io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema; import java.lang.reflect.Modifier; import java.util.ArrayList; @@ -46,21 +47,12 @@ import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; import dev.langchain4j.model.chat.request.json.JsonNumberSchema; import dev.langchain4j.model.chat.request.json.JsonObjectSchema; -import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; import dev.langchain4j.model.chat.request.json.JsonSchemaElement; import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; -import io.quarkiverse.langchain4j.runtime.tool.JsonArraySchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonBooleanSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonEnumSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonIntegerSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonNumberSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonObjectSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonReferenceSchemaObjectSubstitution; -import io.quarkiverse.langchain4j.runtime.tool.JsonStringSchemaObjectSubstitution; import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper; @@ -342,23 +334,7 @@ public ToolsMetadataBuildItem filterOutRemovedTools( if (beforeRemoval != null) { recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class, ToolSpecificationObjectSubstitution.class); - recorderContext.registerSubstitution(JsonArraySchema.class, 
JsonArraySchemaObjectSubstitution.Serialized.class, - JsonArraySchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class, - JsonBooleanSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class, - JsonEnumSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class, - JsonIntegerSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class, - JsonNumberSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class, - JsonObjectSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonReferenceSchema.class, - JsonReferenceSchemaObjectSubstitution.Serialized.class, - JsonReferenceSchemaObjectSubstitution.class); - recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class, - JsonStringSchemaObjectSubstitution.class); + registerJsonSchema(recorderContext); Map> metadataWithoutRemovedBeans = beforeRemoval.getMetadata().entrySet() .stream() .filter(entry -> validationPhase.getContext().removedBeans().stream() diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java index dca8cadf2..4321739e3 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodCreateInfo.java @@ -11,6 +11,7 @@ import org.eclipse.microprofile.config.ConfigProvider; import 
dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.model.chat.request.json.JsonSchema; import dev.langchain4j.service.tool.ToolExecutor; import io.quarkiverse.langchain4j.guardrails.InputGuardrail; import io.quarkiverse.langchain4j.guardrails.OutputGuardrail; @@ -370,11 +371,12 @@ public record SpanInfo(String name) { } public record ResponseSchemaInfo(boolean enabled, boolean isInSystemMessage, Optional isInUserMessage, - String outputFormatInstructions) { + String outputFormatInstructions, Optional structuredOutputSchema) { public static ResponseSchemaInfo of(boolean enabled, Optional systemMessageInfo, Optional userMessageInfo, - String outputFormatInstructions) { + String outputFormatInstructions, + Optional structuredOutputSchema) { boolean systemMessage = systemMessageInfo.flatMap(TemplateInfo::text) .map(text -> text.contains(ResponseSchemaUtil.placeholder())) @@ -385,7 +387,8 @@ public static ResponseSchemaInfo of(boolean enabled, Optional syst userMessage = Optional.of(userMessageInfo.get().text.get().contains(ResponseSchemaUtil.placeholder())); } - return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions); + return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions, + structuredOutputSchema); } } } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java index 1d9ebfabb..1c4d31e1c 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/aiservice/AiServiceMethodImplementationSupport.java @@ -2,6 +2,8 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Exceptions.runtime; +import static 
dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA; +import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON; import static dev.langchain4j.model.output.TokenUsage.sum; import static dev.langchain4j.service.AiServices.removeToolMessages; import static dev.langchain4j.service.AiServices.verifyModerationIfNeeded; @@ -42,6 +44,10 @@ import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.memory.ChatMemory; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.response.ChatResponse; import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.structured.StructuredPrompt; @@ -70,6 +76,7 @@ import io.quarkiverse.langchain4j.runtime.ContextLocals; import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser; import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil; +import io.quarkiverse.langchain4j.runtime.types.TypeUtil; import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider; import io.smallrye.common.vertx.VertxContext; import io.smallrye.mutiny.Multi; @@ -148,11 +155,14 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob Object memoryId = memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null); Optional systemMessage = prepareSystemMessage(methodCreateInfo, methodArgs, context.hasChatMemory() ? 
context.chatMemory(memoryId).messages() : Collections.emptyList()); - UserMessage userMessage = prepareUserMessage(context, methodCreateInfo, methodArgs); + + boolean supportsJsonSchema = supportsJsonSchema(context); + + UserMessage userMessage = prepareUserMessage(context, methodCreateInfo, methodArgs, supportsJsonSchema); Map templateVariables = getTemplateVariables(methodArgs, methodCreateInfo.getUserMessageInfo()); Type returnType = methodCreateInfo.getReturnType(); - if (isImage(returnType) || isResultImage(returnType)) { + if (TypeUtil.isImage(returnType) || TypeUtil.isResultImage(returnType)) { return doImplementGenerateImage(methodCreateInfo, context, audit, systemMessage, userMessage, memoryId, returnType, templateVariables); } @@ -191,7 +201,7 @@ private static Object doImplement(AiServiceMethodCreateInfo methodCreateInfo, Ob Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory); AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata); - if (!isMulti(returnType)) { + if (!TypeUtil.isMulti(returnType)) { augmentationResult = context.retrievalAugmentor.augment(augmentationRequest); userMessage = (UserMessage) augmentationResult.chatMessage(); } else { @@ -254,7 +264,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, methodCreateInfo); } - if (isTokenStream(returnType)) { + if (TypeUtil.isTokenStream(returnType)) { // TODO Indicate the output guardrails cannot be used when using token stream. 
chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse return new AiServiceTokenStream(messagesToSend, toolSpecifications, toolExecutors, @@ -263,7 +273,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, var actualAugmentationResult = augmentationResult; var actualUserMessage = userMessage; - if (isMulti(returnType)) { + if (TypeUtil.isMulti(returnType)) { chatMemory.commit(); // for streaming cases, we really have to commit because all alternatives are worse if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) { var stream = new TokenStreamMulti(messagesToSend, toolSpecifications, toolExecutors, @@ -326,9 +336,33 @@ private List messagesToSend(UserMessage augmentedUserMessage, log.debug("Attempting to obtain AI response"); - Response response = toolSpecifications == null - ? context.chatModel.generate(messagesToSend) - : context.chatModel.generate(messagesToSend, toolSpecifications); + Optional jsonSchema = Optional.empty(); + if (supportsJsonSchema) { + jsonSchema = methodCreateInfo.getResponseSchemaInfo().structuredOutputSchema(); + } + + Response response; + if (jsonSchema.isPresent()) { + ChatRequest chatRequest = ChatRequest.builder() + .messages(messagesToSend) + .toolSpecifications(toolSpecifications) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(jsonSchema.get()) + .build()) + .build(); + + ChatResponse chatResponse = context.chatModel.chat(chatRequest); + response = new Response<>( + chatResponse.aiMessage(), + chatResponse.tokenUsage(), + chatResponse.finishReason()); + } else { + response = toolSpecifications == null + ? 
context.chatModel.generate(messagesToSend) + : context.chatModel.generate(messagesToSend, toolSpecifications); + } + log.debug("AI response obtained"); if (audit != null) { audit.addLLMToApplicationMessage(response); @@ -391,7 +425,7 @@ private List messagesToSend(UserMessage augmentedUserMessage, chatMemory.commit(); Object guardrailResult = response.metadata().get(OutputGuardrailResult.class.getName()); - if (guardrailResult != null && isTypeOf(returnType, guardrailResult.getClass())) { + if (guardrailResult != null && TypeUtil.isTypeOf(returnType, guardrailResult.getClass())) { return guardrailResult; } @@ -399,8 +433,9 @@ private List messagesToSend(UserMessage augmentedUserMessage, var responseAugmenterParam = new ResponseAugmenterParams(userMessage, chatMemory, augmentationResult, userMessageTemplate, templateVariables); - if (isResult(returnType)) { - var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, resultTypeParam((ParameterizedType) returnType)); + if (TypeUtil.isResult(returnType)) { + var parsedResponse = SERVICE_OUTPUT_PARSER.parse(response, + TypeUtil.resultTypeParam((ParameterizedType) returnType)); parsedResponse = ResponseAugmenterSupport.invoke(parsedResponse, methodCreateInfo, responseAugmenterParam); return Result.builder() .content(parsedResponse) @@ -441,9 +476,9 @@ private static Object doImplementGenerateImage(AiServiceMethodCreateInfo methodC audit.onCompletion(imageResponse.content()); } - if (isImage(returnType)) { + if (TypeUtil.isImage(returnType)) { return imageResponse.content(); - } else if (isResultImage(returnType)) { + } else if (TypeUtil.isResultImage(returnType)) { return Result.builder() .content(imageResponse) .tokenUsage(imageResponse.tokenUsage()) @@ -511,44 +546,9 @@ private static List createMessagesToSendForNoMemory(Optional type"); - } - return returnType.getActualTypeArguments()[0]; - } - - private static boolean isImage(Type returnType) { - return isTypeOf(returnType, Image.class); - } - - private static 
boolean isResultImage(Type returnType) { - if (!isImage(returnType)) { - return false; - } - return isImage(resultTypeParam((ParameterizedType) returnType)); - } - - private static boolean isTypeOf(Type type, Class clazz) { - if (type instanceof Class) { - return type.equals(clazz); - } - if (type instanceof ParameterizedType pt) { - return isTypeOf(pt.getRawType(), clazz); - } - throw new IllegalStateException("Unsupported return type " + type); + private static boolean supportsJsonSchema(AiServiceContext context) { + return context.chatModel != null + && context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA); } private static Future triggerModerationIfNeeded(AiServiceContext context, @@ -594,7 +594,7 @@ private static Optional prepareSystemMessage(AiServiceMethodCreat } private static UserMessage prepareUserMessage(AiServiceContext context, AiServiceMethodCreateInfo createInfo, - Object[] methodArgs) { + Object[] methodArgs, boolean supportsJsonSchema) { AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = createInfo.getUserMessageInfo(); String userName = null; @@ -644,7 +644,7 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic } // No response schema placeholder found in the @SystemMessage and @UserMessage, concat it to the UserMessage. - if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema) { + if (!createInfo.getResponseSchemaInfo().isInSystemMessage() && !hasResponseSchema && !supportsJsonSchema) { templateText = templateText.concat(ResponseSchemaUtil.placeholder()); } @@ -664,10 +664,9 @@ private static UserMessage prepareUserMessage(AiServiceContext context, AiServic + paramIndex + " is null"); } - // TODO: Understand how to enable the {response_schema} for the @StructuredPrompt. 
String text = toString(argValue); return createUserMessage(userName, imageContent, - text.concat(createInfo.getResponseSchemaInfo().outputFormatInstructions())); + text.concat(supportsJsonSchema ? "" : createInfo.getResponseSchemaInfo().outputFormatInstructions())); } else { throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "'. Please contact the maintainers"); diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonArraySchemaObjectSubstitution.java similarity index 94% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonArraySchemaObjectSubstitution.java index b4e5a69f0..7271ba14f 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonArraySchemaObjectSubstitution.java @@ -1,4 +1,4 @@ -package io.quarkiverse.langchain4j.runtime.tool; +package io.quarkiverse.langchain4j.runtime.substitution; import dev.langchain4j.model.chat.request.json.JsonArraySchema; import dev.langchain4j.model.chat.request.json.JsonSchemaElement; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonBooleanSchemaObjectSubstitution.java similarity index 93% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonBooleanSchemaObjectSubstitution.java index c69eb773f..6298dbc90 100644 --- 
a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonBooleanSchemaObjectSubstitution.java @@ -1,4 +1,4 @@ -package io.quarkiverse.langchain4j.runtime.tool; +package io.quarkiverse.langchain4j.runtime.substitution; import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; import io.quarkus.runtime.ObjectSubstitution; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonEnumSchemaObjectSubstitution.java similarity index 93% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonEnumSchemaObjectSubstitution.java index cff0ef32d..2e629b70e 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonEnumSchemaObjectSubstitution.java @@ -1,4 +1,4 @@ -package io.quarkiverse.langchain4j.runtime.tool; +package io.quarkiverse.langchain4j.runtime.substitution; import java.util.List; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonIntegerSchemaObjectSubstitution.java similarity index 93% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonIntegerSchemaObjectSubstitution.java index 34a363397..392c88c7e 100644 --- 
a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonIntegerSchemaObjectSubstitution.java @@ -1,4 +1,4 @@ -package io.quarkiverse.langchain4j.runtime.tool; +package io.quarkiverse.langchain4j.runtime.substitution; import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; import io.quarkus.runtime.ObjectSubstitution; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonNumberSchemaObjectSubstitution.java similarity index 93% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonNumberSchemaObjectSubstitution.java index 3c7bf1295..d2dfd7995 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonNumberSchemaObjectSubstitution.java @@ -1,4 +1,4 @@ -package io.quarkiverse.langchain4j.runtime.tool; +package io.quarkiverse.langchain4j.runtime.substitution; import dev.langchain4j.model.chat.request.json.JsonNumberSchema; import io.quarkus.runtime.ObjectSubstitution; diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonObjectSchemaObjectSubstitution.java similarity index 95% rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonObjectSchemaObjectSubstitution.java index 
43057f5fe..3bee609bd 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonObjectSchemaObjectSubstitution.java
@@ -1,4 +1,4 @@
-package io.quarkiverse.langchain4j.runtime.tool;
+package io.quarkiverse.langchain4j.runtime.substitution;
 
 import java.util.List;
 import java.util.Map;
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonReferenceSchemaObjectSubstitution.java
similarity index 93%
rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java
rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonReferenceSchemaObjectSubstitution.java
index 8b47f85eb..ae3121546 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonReferenceSchemaObjectSubstitution.java
@@ -1,4 +1,4 @@
-package io.quarkiverse.langchain4j.runtime.tool;
+package io.quarkiverse.langchain4j.runtime.substitution;
 
 import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
 import io.quarkus.runtime.ObjectSubstitution;
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonSchemaObjectSubstitution.java
new file mode 100644
index 000000000..31264a19b
--- /dev/null
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonSchemaObjectSubstitution.java
@@ -0,0 +1,34 @@
+package io.quarkiverse.langchain4j.runtime.substitution;
+
+import dev.langchain4j.model.chat.request.json.JsonSchema;
+import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
+import io.quarkus.runtime.ObjectSubstitution;
+import io.quarkus.runtime.annotations.RecordableConstructor;
+
+/**
+ * Converts a {@link JsonSchema} to and from the recordable {@link Serialized} form.
+ */
+public class JsonSchemaObjectSubstitution
+        implements ObjectSubstitution<JsonSchema, JsonSchemaObjectSubstitution.Serialized> {
+    /** Copies the schema's name and root element into a {@link Serialized} record. */
+    @Override
+    public Serialized serialize(JsonSchema obj) {
+        return new Serialized(obj.name(), obj.rootElement());
+    }
+
+    /** Rebuilds a {@link JsonSchema} from the recorded name and root element. */
+    @Override
+    public JsonSchema deserialize(Serialized obj) {
+        return JsonSchema.builder()
+                .name(obj.name)
+                .rootElement(obj.rootElement)
+                .build();
+    }
+
+    /** Recordable carrier for the two properties a {@link JsonSchema} is rebuilt from. */
+    public record Serialized(String name, JsonSchemaElement rootElement) {
+        @RecordableConstructor
+        public Serialized {
+        }
+    }
+}
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonStringSchemaObjectSubstitution.java
similarity index 93%
rename from core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java
rename to core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonStringSchemaObjectSubstitution.java
index bbaa7a7df..6f7455375 100644
--- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java
+++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/substitution/JsonStringSchemaObjectSubstitution.java
@@ -1,4 +1,4 @@
-package io.quarkiverse.langchain4j.runtime.tool;
+package io.quarkiverse.langchain4j.runtime.substitution;
 
 import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 import io.quarkus.runtime.ObjectSubstitution;
diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/types/TypeUtil.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/types/TypeUtil.java
new file mode 100644
index 000000000..400d704af
--- /dev/null
+++
b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/types/TypeUtil.java
@@ -0,0 +1,82 @@
+package io.quarkiverse.langchain4j.runtime.types;
+
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
+
+import dev.langchain4j.data.image.Image;
+import dev.langchain4j.service.Result;
+import dev.langchain4j.service.TokenStream;
+import io.smallrye.mutiny.Multi;
+
+/**
+ * Static predicates that match a reflective {@link Type} against a known class,
+ * comparing the raw type when the {@link Type} is parameterized.
+ */
+public final class TypeUtil {
+
+    private TypeUtil() {
+    }
+
+    /** Returns whether {@code returnType} is {@link TokenStream}, raw or parameterized. */
+    public static boolean isTokenStream(Type returnType) {
+        return isTypeOf(returnType, TokenStream.class);
+    }
+
+    /** Returns whether {@code returnType} is {@link Multi}, raw or parameterized. */
+    public static boolean isMulti(Type returnType) {
+        return isTypeOf(returnType, Multi.class);
+    }
+
+    /** Returns whether {@code returnType} is {@link Result}, raw or parameterized. */
+    public static boolean isResult(Type returnType) {
+        return isTypeOf(returnType, Result.class);
+    }
+
+    /**
+     * Returns the first type argument of a parameterized {@link Result}.
+     *
+     * @throws IllegalStateException if the raw type of {@code returnType} is not {@link Result}
+     */
+    public static Type resultTypeParam(ParameterizedType returnType) {
+        if (!isTypeOf(returnType, Result.class)) {
+            throw new IllegalStateException("Can only be called with Result type");
+        }
+        return returnType.getActualTypeArguments()[0];
+    }
+
+    /** Returns whether {@code returnType} is {@link Image}, raw or parameterized. */
+    public static boolean isImage(Type returnType) {
+        return isTypeOf(returnType, Image.class);
+    }
+
+    /**
+     * Returns whether {@code returnType} is a {@link Result} whose type argument is {@link Image}.
+     * A raw {@code Result} exposes no type argument and therefore yields {@code false}.
+     */
+    public static boolean isResultImage(Type returnType) {
+        // Only a Result can qualify; anything else, a bare Image included, is rejected here.
+        if (!isResult(returnType)) {
+            return false;
+        }
+        // The pattern match also filters out a raw Result class, which has no type argument.
+        if (!(returnType instanceof ParameterizedType parameterizedResult)) {
+            return false;
+        }
+        return isImage(resultTypeParam(parameterizedResult));
+    }
+
+    /**
+     * Compares {@code type}, or its raw type when parameterized, with {@code clazz}.
+     *
+     * @throws IllegalStateException if {@code type} is neither a {@link Class} nor a {@link ParameterizedType}
+     */
+    public static boolean isTypeOf(Type type, Class<?> clazz) {
+        if (type instanceof Class) {
+            return type.equals(clazz);
+        }
+        if (type instanceof ParameterizedType pt) {
+            return isTypeOf(pt.getRawType(), clazz);
+        }
+        throw new IllegalStateException("Unsupported return type " + type);
+    }
+}
diff --git a/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/StructuredOutputResponseTest.java b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/StructuredOutputResponseTest.java
new file mode 100644
index 000000000..22b427c28
--- /dev/null
+++
b/model-providers/openai/openai-vanilla/deployment/src/test/java/org/acme/examples/aiservices/StructuredOutputResponseTest.java
@@ -0,0 +1,88 @@
+package org.acme.examples.aiservices;
+
+import static java.time.Month.JULY;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.InstanceOfAssertFactories.map;
+
+import java.io.IOException;
+import java.time.LocalDate;
+import java.util.Map;
+
+import jakarta.enterprise.context.ApplicationScoped;
+import jakarta.inject.Inject;
+
+import org.jboss.shrinkwrap.api.ShrinkWrap;
+import org.jboss.shrinkwrap.api.spec.JavaArchive;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import dev.langchain4j.service.UserMessage;
+import io.quarkiverse.langchain4j.RegisterAiService;
+import io.quarkiverse.langchain4j.openai.testing.internal.OpenAiBaseTest;
+import io.quarkiverse.langchain4j.testing.internal.WiremockAware;
+import io.quarkus.test.QuarkusUnitTest;
+
+/**
+ * Checks that, with the OpenAI response format set to {@code json_schema}, an AI service method returning a
+ * record sends a {@code response_format} named after that record and maps the mocked JSON reply onto it.
+ */
+public class StructuredOutputResponseTest extends OpenAiBaseTest {
+
+    @RegisterExtension
+    static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
+            .setArchiveProducer(
+                    () -> ShrinkWrap.create(JavaArchive.class))
+            .overrideRuntimeConfigKey("quarkus.langchain4j.openai.api-key", "whatever")
+            .overrideRuntimeConfigKey("quarkus.langchain4j.openai.chat-model.response-format", "json_schema")
+            .overrideRuntimeConfigKey("quarkus.langchain4j.openai.chat-model.strict-json-schema", "true")
+            .overrideRuntimeConfigKey("quarkus.langchain4j.openai.base-url",
+                    WiremockAware.wiremockUrlForConfig("/v1"));
+
+    record Person(String firstName, String lastName, LocalDate birthDate) {
+    }
+
+    @RegisterAiService
+    @ApplicationScoped
+    interface PersonExtractor {
+
+        @UserMessage("Extract information about a person from {{it}}")
+        Person extractPersonFrom(String text);
+    }
+
+    @Inject
+    PersonExtractor personExtractor;
+
+    @Test
+    public void testPojo() throws IOException {
+        setChatCompletionMessageContent(
+                // this is supposed to be a string inside a json string hence all the escaping...
+                "{\\n\\\"firstName\\\": \\\"John\\\",\\n\\\"lastName\\\": \\\"Doe\\\",\\n\\\"birthDate\\\": \\\"1968-07-04\\\"\\n}");
+
+        String text = "In 1968, amidst the fading echoes of Independence Day, "
+                + "a child named John arrived under the calm evening sky. "
+                + "This newborn, bearing the surname Doe, marked the start of a new journey.";
+
+        Person result = personExtractor.extractPersonFrom(text);
+
+        assertThat(result.firstName()).isEqualTo("John");
+        assertThat(result.lastName()).isEqualTo("Doe");
+        assertThat(result.birthDate()).isEqualTo(LocalDate.of(1968, JULY, 4));
+
+        Map<String, Object> requestAsMap = getRequestAsMap();
+        assertSingleRequestMessage(requestAsMap,
+                "Extract information about a person from In 1968, amidst the fading echoes of Independence Day, "
+                        + "a child named John arrived under the calm evening sky. This newborn, bearing the surname Doe, "
+                        + "marked the start of a new journey.");
+        assertThat(requestAsMap).hasEntrySatisfying("response_format", (v) -> {
+            assertThat(v).asInstanceOf(map(String.class, Object.class)).satisfies(responseFormatMap -> {
+                assertThat(responseFormatMap).containsEntry("type", "json_schema");
+                assertThat(responseFormatMap).extracting("json_schema").satisfies(js -> {
+                    assertThat(js).asInstanceOf(map(String.class, Object.class)).satisfies(jsonSchemaMap -> {
+                        assertThat(jsonSchemaMap).containsEntry("name", "Person").containsKey("schema");
+                    });
+                });
+            });
+        });
+
+    }
+}
diff --git a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java
index 625315a7f..e7309948b 100644
--- a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java
+++
b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/OpenAiRecorder.java @@ -85,6 +85,7 @@ public Function, ChatLanguageModel .presencePenalty(chatModelConfig.presencePenalty()) .frequencyPenalty(chatModelConfig.frequencyPenalty()) .responseFormat(chatModelConfig.responseFormat().orElse(null)) + .strictJsonSchema(chatModelConfig.strictJsonSchema().orElse(null)) .stop(chatModelConfig.stop().orElse(null)); openAiConfig.organizationId().ifPresent(builder::organizationId); diff --git a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java index 292aad127..c109d9fbc 100644 --- a/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java +++ b/model-providers/openai/openai-vanilla/runtime/src/main/java/io/quarkiverse/langchain4j/openai/runtime/config/ChatModelConfig.java @@ -76,6 +76,11 @@ public interface ChatModelConfig { */ Optional responseFormat(); + /** + * Whether responses follow JSON Schema for Structured Outputs + */ + Optional strictJsonSchema(); + /** * The list of stop words to use. *