Skip to content

Commit

Permalink
Merge pull request #1147 from quarkiverse/#1143
Browse files Browse the repository at this point in the history
Add support for structured output in OpenAI
  • Loading branch information
geoand authored Dec 10, 2024
2 parents b779f6b + 4751df3 commit cefbc8c
Show file tree
Hide file tree
Showing 18 changed files with 310 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.IGNORE;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.OPTIONAL_DENY;
import static io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema;
import static io.quarkiverse.langchain4j.runtime.types.TypeUtil.isMulti;
import static io.quarkus.arc.processor.DotNames.NAMED;

import java.io.IOException;
Expand Down Expand Up @@ -61,7 +63,9 @@
import org.objectweb.asm.tree.analysis.AnalyzerException;

import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.service.Moderate;
import dev.langchain4j.service.output.JsonSchemas;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.ModelName;
import io.quarkiverse.langchain4j.RegisterAiService;
Expand Down Expand Up @@ -117,6 +121,7 @@
import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem;
import io.quarkus.deployment.builditem.nativeimage.ServiceProviderBuildItem;
import io.quarkus.deployment.metrics.MetricsCapabilityBuildItem;
import io.quarkus.deployment.recording.RecorderContext;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
import io.quarkus.gizmo.FieldDescriptor;
Expand Down Expand Up @@ -922,6 +927,7 @@ public void markIgnoredAnnotations(BuildProducer<MethodParameterIgnoredAnnotatio
public void handleAiServices(
LangChain4jBuildConfig config,
AiServicesRecorder recorder,
RecorderContext recorderContext,
CombinedIndexBuildItem indexBuildItem,
List<DeclarativeAiServiceBuildItem> declarativeAiServiceItems,
List<MethodParameterAllowedAnnotationsBuildItem> methodParameterAllowedAnnotationsItems,
Expand Down Expand Up @@ -1178,6 +1184,7 @@ public void handleAiServices(

}

registerJsonSchema(recorderContext);
recorder.setMetadata(perClassMetadata);
}

Expand Down Expand Up @@ -1246,16 +1253,18 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(

// TODO give user ability to provide custom OutputParser
String outputFormatInstructions = "";
if (generateResponseSchema && !returnType.equals(Multi.class))
Optional<JsonSchema> structuredOutputSchema = Optional.empty();
if (!returnType.equals(Multi.class)) {
outputFormatInstructions = SERVICE_OUTPUT_PARSER.outputFormatInstructions(returnType);
}

List<TemplateParameterInfo> templateParams = gatherTemplateParamInfo(params, allowedPredicates, ignoredPredicates);
Optional<AiServiceMethodCreateInfo.TemplateInfo> systemMessageInfo = gatherSystemMessageInfo(method, templateParams);
AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = gatherUserMessageInfo(method, templateParams);

AiServiceMethodCreateInfo.ResponseSchemaInfo responseSchemaInfo = ResponseSchemaInfo.of(generateResponseSchema,
systemMessageInfo,
userMessageInfo.template(), outputFormatInstructions);
userMessageInfo.template(), outputFormatInstructions, jsonSchemaFrom(returnType));

if (!generateResponseSchema && responseSchemaInfo.isInSystemMessage())
throw new RuntimeException(
Expand Down Expand Up @@ -1293,6 +1302,13 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
}

/**
 * Derives the optional JSON schema for the given return type.
 *
 * @param returnType the type to derive a schema from
 * @return {@link Optional#empty()} when {@code isMulti(returnType)} is true; otherwise the result of
 *         {@code JsonSchemas.jsonSchemaFrom(returnType)}, passed through unchanged
 */
private Optional<JsonSchema> jsonSchemaFrom(java.lang.reflect.Type returnType) {
    // No schema lookup is attempted for types that isMulti(...) accepts.
    return isMulti(returnType)
            ? Optional.empty()
            : JsonSchemas.jsonSchemaFrom(returnType);
}

private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
List<String> methodToolClassNames) {
List<String> allTools = new ArrayList<>(methodToolClassNames);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.quarkiverse.langchain4j.deployment;

import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import io.quarkiverse.langchain4j.runtime.substitution.JsonArraySchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonBooleanSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonEnumSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonIntegerSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonNumberSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonObjectSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonReferenceSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.substitution.JsonStringSchemaObjectSubstitution;
import io.quarkus.deployment.recording.RecorderContext;

/**
 * Package-private helper holding the shared registration of recorder object substitutions
 * for the JSON schema classes, so that build steps can register them with a single call.
 */
final class ObjectSubstitutionUtil {

    // Static-only helper: instantiation is not allowed.
    private ObjectSubstitutionUtil() {
    }

    /**
     * Registers one object substitution per JSON schema class on the supplied recorder context.
     * Each registration pairs the schema class with its {@code Serialized} form and the
     * substitution class that converts between the two.
     *
     * @param recorderContext the recorder context that receives the registrations
     */
    static void registerJsonSchema(RecorderContext recorderContext) {
        // Top-level schema wrapper.
        recorderContext.registerSubstitution(
                JsonSchema.class,
                JsonSchemaObjectSubstitution.Serialized.class,
                JsonSchemaObjectSubstitution.class);

        // Individual schema element types.
        recorderContext.registerSubstitution(
                JsonArraySchema.class,
                JsonArraySchemaObjectSubstitution.Serialized.class,
                JsonArraySchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonBooleanSchema.class,
                JsonBooleanSchemaObjectSubstitution.Serialized.class,
                JsonBooleanSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonEnumSchema.class,
                JsonEnumSchemaObjectSubstitution.Serialized.class,
                JsonEnumSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonIntegerSchema.class,
                JsonIntegerSchemaObjectSubstitution.Serialized.class,
                JsonIntegerSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonNumberSchema.class,
                JsonNumberSchemaObjectSubstitution.Serialized.class,
                JsonNumberSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonObjectSchema.class,
                JsonObjectSchemaObjectSubstitution.Serialized.class,
                JsonObjectSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonReferenceSchema.class,
                JsonReferenceSchemaObjectSubstitution.Serialized.class,
                JsonReferenceSchemaObjectSubstitution.class);
        recorderContext.registerSubstitution(
                JsonStringSchema.class,
                JsonStringSchemaObjectSubstitution.Serialized.class,
                JsonStringSchemaObjectSubstitution.class);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING;
import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD;
import static io.quarkiverse.langchain4j.deployment.DotNames.UNI;
import static io.quarkiverse.langchain4j.deployment.ObjectSubstitutionUtil.registerJsonSchema;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
Expand Down Expand Up @@ -46,21 +47,12 @@
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonReferenceSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import dev.langchain4j.model.output.structured.Description;
import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkiverse.langchain4j.runtime.tool.JsonArraySchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonBooleanSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonEnumSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonIntegerSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonNumberSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonObjectSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonReferenceSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.JsonStringSchemaObjectSubstitution;
import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper;
Expand Down Expand Up @@ -342,23 +334,7 @@ public ToolsMetadataBuildItem filterOutRemovedTools(
if (beforeRemoval != null) {
recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class,
ToolSpecificationObjectSubstitution.class);
recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class,
JsonArraySchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class,
JsonBooleanSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class,
JsonEnumSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class,
JsonIntegerSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class,
JsonNumberSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class,
JsonObjectSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonReferenceSchema.class,
JsonReferenceSchemaObjectSubstitution.Serialized.class,
JsonReferenceSchemaObjectSubstitution.class);
recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class,
JsonStringSchemaObjectSubstitution.class);
registerJsonSchema(recorderContext);
Map<String, List<ToolMethodCreateInfo>> metadataWithoutRemovedBeans = beforeRemoval.getMetadata().entrySet()
.stream()
.filter(entry -> validationPhase.getContext().removedBeans().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.eclipse.microprofile.config.ConfigProvider;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.model.chat.request.json.JsonSchema;
import dev.langchain4j.service.tool.ToolExecutor;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
Expand Down Expand Up @@ -370,11 +371,12 @@ public record SpanInfo(String name) {
}

public record ResponseSchemaInfo(boolean enabled, boolean isInSystemMessage, Optional<Boolean> isInUserMessage,
String outputFormatInstructions) {
String outputFormatInstructions, Optional<JsonSchema> structuredOutputSchema) {

public static ResponseSchemaInfo of(boolean enabled, Optional<TemplateInfo> systemMessageInfo,
Optional<TemplateInfo> userMessageInfo,
String outputFormatInstructions) {
String outputFormatInstructions,
Optional<JsonSchema> structuredOutputSchema) {

boolean systemMessage = systemMessageInfo.flatMap(TemplateInfo::text)
.map(text -> text.contains(ResponseSchemaUtil.placeholder()))
Expand All @@ -385,7 +387,8 @@ public static ResponseSchemaInfo of(boolean enabled, Optional<TemplateInfo> syst
userMessage = Optional.of(userMessageInfo.get().text.get().contains(ResponseSchemaUtil.placeholder()));
}

return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions);
return new ResponseSchemaInfo(enabled, systemMessage, userMessage, outputFormatInstructions,
structuredOutputSchema);
}
}
}
Loading

0 comments on commit cefbc8c

Please sign in to comment.