diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 6493730f9..3421e4c3e 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -1,30 +1,22 @@ package io.quarkiverse.langchain4j.deployment; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.BOOLEAN; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.NUMBER; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.description; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums; import static io.quarkiverse.langchain4j.deployment.DotNames.BLOCKING; import static io.quarkiverse.langchain4j.deployment.DotNames.COMPLETION_STAGE; import static io.quarkiverse.langchain4j.deployment.DotNames.MULTI; import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING; import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD; import static io.quarkiverse.langchain4j.deployment.DotNames.UNI; -import static java.util.Arrays.stream; -import static java.util.stream.Collectors.toList; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -45,17 +37,33 @@ import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.Opcodes; -import dev.langchain4j.agent.tool.JsonSchemaProperty; import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.agent.tool.ToolMemoryId; -import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; +import io.quarkiverse.langchain4j.runtime.tool.JsonArraySchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonBooleanSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonEnumSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonIntegerSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonNumberSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonObjectSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonReferenceSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonSchemaElementObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonStringSchemaObjectSubstitution; import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; -import io.quarkiverse.langchain4j.runtime.tool.ToolParametersObjectSubstitution; import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper; import io.quarkiverse.langchain4j.runtime.tool.ToolSpecificationObjectSubstitution; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; @@ -87,6 +95,7 @@ public class ToolProcessor { private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class); private static final DotName P = DotName.createSimple(dev.langchain4j.agent.tool.P.class); + private static final DotName DESCRIPTION = DotName.createSimple(Description.class); private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor .ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class); private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class); @@ -209,6 +218,9 @@ public void handleTools( .name(toolName) .description(toolDescription); + var properties = new LinkedHashMap(toolMethod.parametersCount()); + var required = new ArrayList(toolMethod.parametersCount()); + MethodParameterInfo memoryIdParameter = null; for (MethodParameterInfo parameter : toolMethod.parameters()) { if (parameter.hasAnnotation(TOOL_MEMORY_ID)) { @@ -216,15 +228,22 @@ public void handleTools( continue; } - AnnotationInstance pInstance = parameter.annotation(P); - if (pInstance != null && pInstance.value("required") != null - && !pInstance.value("required").asBoolean()) { - builder.addOptionalParameter(parameter.name(), toJsonSchemaProperties(parameter, index)); - } else { - builder.addParameter(parameter.name(), toJsonSchemaProperties(parameter, index)); + var pInstance = parameter.annotation(P); + var jsonSchemaElement = toJsonSchemaElement(parameter, index); + properties.put(parameter.name(), jsonSchemaElement); + + if ((pInstance == null) + || ((pInstance.value("required") != null) && pInstance.value("required").asBoolean())) { + required.add(parameter.name()); } } + builder.parameters( + JsonObjectSchema.builder() + .properties(properties) + .required(required) + .build()); + Map nameToParamPosition = toolMethod.parameters().stream().collect( Collectors.toMap(MethodParameterInfo::name, i -> Integer.valueOf(i.position()))); @@ -324,8 +343,25 @@ public ToolsMetadataBuildItem filterOutRemovedTools( if (beforeRemoval != null) { recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class, ToolSpecificationObjectSubstitution.class); - recorderContext.registerSubstitution(ToolParameters.class, ToolParametersObjectSubstitution.Serialized.class, - ToolParametersObjectSubstitution.class); + recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class, + JsonArraySchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class, + JsonBooleanSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class, + JsonEnumSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class, + JsonIntegerSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class, + JsonNumberSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class, + JsonObjectSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonReferenceSchema.class, + JsonReferenceSchemaObjectSubstitution.Serialized.class, + JsonReferenceSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonSchemaElement.class, JsonSchemaElementObjectSubstitution.Serialized.class, + JsonSchemaElementObjectSubstitution.class); + recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class, + JsonStringSchemaObjectSubstitution.class); Map> metadataWithoutRemovedBeans = beforeRemoval.getMetadata().entrySet() .stream() .filter(entry -> validationPhase.getContext().removedBeans().stream() @@ -469,16 +505,14 @@ private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOu return implClassName; } - private Iterable toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) { + private JsonSchemaElement toJsonSchemaElement(MethodParameterInfo parameter, IndexView index) { Type type = parameter.type(); - AnnotationInstance pInstance = parameter.annotation(P); + String description = descriptionFrom(parameter); - JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString()); - - return toJsonSchemaProperties(type, index, description); + return toJsonSchemaElement(type, index, description); } - private Iterable toJsonSchemaProperties(Type type, IndexView index, JsonSchemaProperty description) { + private JsonSchemaElement toJsonSchemaElement(Type type, IndexView index, String description) { DotName typeName = type.name(); if (type.kind() == Type.Kind.WILDCARD_TYPE) { @@ -487,18 +521,18 @@ private Iterable toJsonSchemaProperties(Type type, IndexView boundType = type.asWildcardType().superBound(); } if (boundType != null) { - return toJsonSchemaProperties(boundType, index, description); + return toJsonSchemaElement(boundType, index, description); } else { throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type); } } if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName) || DotNames.PRIMITIVE_CHAR.equals(typeName)) { - return removeNulls(STRING, description); + return JsonStringSchema.builder().description(description).build(); } if (DotNames.BOOLEAN.equals(typeName) || DotNames.PRIMITIVE_BOOLEAN.equals(typeName)) { - return removeNulls(BOOLEAN, description); + return JsonBooleanSchema.builder().description(description).build(); } if (DotNames.BYTE.equals(typeName) || DotNames.PRIMITIVE_BYTE.equals(typeName) @@ -506,14 +540,14 @@ private Iterable toJsonSchemaProperties(Type type, IndexView || DotNames.INTEGER.equals(typeName) || DotNames.PRIMITIVE_INT.equals(typeName) || DotNames.LONG.equals(typeName) || DotNames.PRIMITIVE_LONG.equals(typeName) || DotNames.BIG_INTEGER.equals(typeName)) { - return removeNulls(INTEGER, description); + return JsonIntegerSchema.builder().description(description).build(); } // TODO put constraints on min and max? if (DotNames.FLOAT.equals(typeName) || DotNames.PRIMITIVE_FLOAT.equals(typeName) || DotNames.DOUBLE.equals(typeName) || DotNames.PRIMITIVE_DOUBLE.equals(typeName) || DotNames.BIG_DECIMAL.equals(typeName)) { - return removeNulls(NUMBER, description); + return JsonNumberSchema.builder().description(description).build(); } // TODO something else? @@ -524,49 +558,40 @@ private Iterable toJsonSchemaProperties(Type type, IndexView Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0) : type.asArrayType().component(); - Iterable elementProperties = toJsonSchemaProperties(elementType, index, null); - - JsonSchemaProperty itemsSchema; - if (isComplexType(elementType)) { - Map fieldDescription = new HashMap<>(); - - for (JsonSchemaProperty fieldProperty : elementProperties) { - fieldDescription.put(fieldProperty.key(), fieldProperty.value()); - } - itemsSchema = JsonSchemaProperty.from("items", fieldDescription); - } else { - itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next()); - } + JsonSchemaElement element = toJsonSchemaElement(elementType, index, null); - return removeNulls(ARRAY, itemsSchema, description); + return JsonArraySchema.builder().description(description).items(element).build(); } if (isEnum(type, index)) { - return removeNulls(STRING, enums(enumConstants(type)), description); + var enums = Arrays.stream(enumConstants(type)) + .filter(e -> e.getClass().isEnum()) + .map(e -> ((Enum) e).name()) + .toList(); + + return JsonEnumSchema.builder() + .enumValues(enums) + .description(Optional.ofNullable(description).orElseGet(() -> descriptionFrom(type))) + .build(); } - if (type.kind() == Type.Kind.CLASS) { - Map properties = new HashMap<>(); - ClassInfo classInfo = index.getClassByName(type.name()); + if (isComplexType(type)) { + var builder = JsonObjectSchema.builder() + .description(Optional.ofNullable(description).orElseGet(() -> descriptionFrom(type))); - List required = new ArrayList<>(); - if (classInfo != null) { - for (FieldInfo field : classInfo.fields()) { - String fieldName = field.name(); - - Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); - Map fieldDescription = new HashMap<>(); - - for (JsonSchemaProperty fieldProperty : fieldSchema) { - fieldDescription.put(fieldProperty.key(), fieldProperty.value()); - } + Optional.ofNullable(index.getClassByName(type.name())) + .map(ClassInfo::fields) + .orElseGet(List::of) + .forEach(field -> { + var fieldName = field.name(); + var fieldType = field.type(); + var fieldDescription = descriptionFrom(field); + var fieldSchema = toJsonSchemaElement(fieldType, index, fieldDescription); - properties.put(fieldName, fieldDescription); - } - } + builder.addProperty(fieldName, fieldSchema); + }); - JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties); - return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description); + return builder.build(); } throw new IllegalArgumentException("Unsupported type: " + type); @@ -576,12 +601,6 @@ private boolean isComplexType(Type type) { return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } - private Iterable removeNulls(JsonSchemaProperty... properties) { - return stream(properties) - .filter(Objects::nonNull) - .collect(toList()); - } - private boolean isEnum(Type returnType, IndexView index) { if (returnType.kind() != Type.Kind.CLASS) { return false; @@ -590,6 +609,28 @@ private boolean isEnum(Type returnType, IndexView index) { return maybeEnum != null && maybeEnum.isEnum(); } + private static String descriptionFrom(String[] description) { + return (description != null) ? String.join(" ", description) : null; + } + + private static String descriptionFrom(Type type) { + return Optional.ofNullable(type.annotation(DESCRIPTION)) + .map(annotationInstance -> descriptionFrom(annotationInstance.value().asStringArray())) + .orElse(null); + } + + private static String descriptionFrom(FieldInfo field) { + return Optional.ofNullable(field.annotation(DESCRIPTION)) + .map(annotationInstance -> descriptionFrom(annotationInstance.value().asStringArray())) + .orElse(null); + } + + private static String descriptionFrom(MethodParameterInfo parameter) { + return Optional.ofNullable(parameter.annotation(P)) + .map(p -> p.value().asString()) + .orElse(null); + } + private static Object[] enumConstants(Type type) { return JandexUtil.load(type, Thread.currentThread().getContextClassLoader()).getEnumConstants(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java new file mode 100644 index 000000000..1a12644ae --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java @@ -0,0 +1,41 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonArraySchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonArraySchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var a = (JsonArraySchema) obj; + return new Serialized(a.description(), a.items()); + } + + @Override + public JsonArraySchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var a = (JsonArraySchemaObjectSubstitution.Serialized) obj; + return JsonArraySchema.builder() + .description(a.description) + .items(a.items) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + private final JsonSchemaElement items; + + @RecordableConstructor + public Serialized(String description, JsonSchemaElement items) { + this.description = description; + this.items = items; + } + + public String getDescription() { + return description; + } + + public JsonSchemaElement getItems() { + return items; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java new file mode 100644 index 000000000..4b50af96d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonBooleanSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonBooleanSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var b = (JsonBooleanSchema) obj; + return new Serialized(b.description()); + } + + @Override + public JsonBooleanSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var b = (JsonBooleanSchemaObjectSubstitution.Serialized) obj; + return JsonBooleanSchema.builder() + .description(b.description) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + + @RecordableConstructor + public Serialized(String description) { + this.description = description; + } + + public String getDescription() { + return description; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java new file mode 100644 index 000000000..ffdaac87e --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java @@ -0,0 +1,43 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import java.util.List; + +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonEnumSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonEnumSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var e = (JsonEnumSchema) obj; + return new Serialized(e.description(), e.enumValues()); + } + + @Override + public JsonEnumSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var e = (JsonEnumSchemaObjectSubstitution.Serialized) obj; + return JsonEnumSchema.builder() + .description(e.description) + .enumValues(e.enumValues) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + private final List enumValues; + + @RecordableConstructor + public Serialized(String description, List enumValues) { + this.description = description; + this.enumValues = enumValues; + } + + public String getDescription() { + return description; + } + + public List getEnumValues() { + return enumValues; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java new file mode 100644 index 000000000..d48223f54 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonIntegerSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonIntegerSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var i = (JsonIntegerSchema) obj; + return new Serialized(i.description()); + } + + @Override + public JsonIntegerSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var i = (JsonIntegerSchemaObjectSubstitution.Serialized) obj; + return JsonIntegerSchema.builder() + .description(i.description) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + + @RecordableConstructor + public Serialized(String description) { + this.description = description; + } + + public String getDescription() { + return description; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java new file mode 100644 index 000000000..ca9f45816 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonNumberSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonNumberSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var n = (JsonNumberSchema) obj; + return new Serialized(n.description()); + } + + @Override + public JsonNumberSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var n = (JsonNumberSchemaObjectSubstitution.Serialized) obj; + return JsonNumberSchema.builder() + .description(n.description) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + + @RecordableConstructor + public Serialized(String description) { + this.description = description; + } + + public String getDescription() { + return description; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java new file mode 100644 index 000000000..810686be5 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java @@ -0,0 +1,66 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import java.util.List; +import java.util.Map; + +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonObjectSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonObjectSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var o = (JsonObjectSchema) obj; + return new Serialized(o.description(), o.properties(), o.required(), o.additionalProperties(), o.definitions()); + } + + @Override + public JsonObjectSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var o = (JsonObjectSchemaObjectSubstitution.Serialized) obj; + return JsonObjectSchema.builder() + .description(o.description) + .properties(o.properties) + .required(o.required) + .additionalProperties(o.additionalProperties) + .definitions(o.definitions) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + private final Map properties; + private final List required; + private final Boolean additionalProperties; + private final Map definitions; + + @RecordableConstructor + public Serialized(String description, Map properties, List required, + Boolean additionalProperties, Map definitions) { + this.description = description; + this.properties = properties; + this.required = required; + this.additionalProperties = additionalProperties; + this.definitions = definitions; + } + + public String getDescription() { + return description; + } + + public Map getProperties() { + return properties; + } + + public List getRequired() { + return required; + } + + public Boolean getAdditionalProperties() { + return additionalProperties; + } + + public Map getDefinitions() { + return definitions; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java new file mode 100644 index 000000000..2182214c9 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java @@ -0,0 +1,32 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonReferenceSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + public JsonReferenceSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var r = (JsonReferenceSchema) obj; + return new Serialized(r.reference()); + } + + public JsonReferenceSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var r = (JsonReferenceSchemaObjectSubstitution.Serialized) obj; + return JsonReferenceSchema.builder() + .reference(r.reference) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String reference; + + @RecordableConstructor + public Serialized(String reference) { + this.reference = reference; + } + + public String getReference() { + return reference; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonSchemaElementObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonSchemaElementObjectSubstitution.java new file mode 100644 index 000000000..82c9d3d2a --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonSchemaElementObjectSubstitution.java @@ -0,0 +1,86 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import io.quarkus.runtime.ObjectSubstitution; + +public sealed class JsonSchemaElementObjectSubstitution + implements ObjectSubstitution + permits JsonArraySchemaObjectSubstitution, + JsonBooleanSchemaObjectSubstitution, + JsonEnumSchemaObjectSubstitution, + JsonIntegerSchemaObjectSubstitution, + JsonNumberSchemaObjectSubstitution, + JsonObjectSchemaObjectSubstitution, + JsonReferenceSchemaObjectSubstitution, + JsonStringSchemaObjectSubstitution { + + // Using ConcurrentHashMap in case multiple threads are using this class at the same time + // Not sure if this will ever happen + private final Map, JsonSchemaElementObjectSubstitution> substitutions = new ConcurrentHashMap<>(8); + + @Override + public JsonSchemaElementObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + return getSubstitution(obj.getClass()).serialize(obj); + } + + @Override + public JsonSchemaElement deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + return getSubstitution(obj.getClass()).deserialize(obj); + } + + private JsonSchemaElementObjectSubstitution getSubstitution(Class clazz) { + return this.substitutions.computeIfAbsent(clazz, c -> { + if (JsonArraySchema.class.isAssignableFrom(c) + || JsonArraySchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonArraySchemaObjectSubstitution(); + } else if (JsonBooleanSchema.class.isAssignableFrom(c) + || JsonBooleanSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonBooleanSchemaObjectSubstitution(); + } else if (JsonEnumSchema.class.isAssignableFrom(c) + || JsonEnumSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonEnumSchemaObjectSubstitution(); + } else if (JsonIntegerSchema.class.isAssignableFrom(c) + || JsonIntegerSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonIntegerSchemaObjectSubstitution(); + } else if (JsonNumberSchema.class.isAssignableFrom(c) + || JsonNumberSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonNumberSchemaObjectSubstitution(); + } else if (JsonObjectSchema.class.isAssignableFrom(c) + || JsonObjectSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonObjectSchemaObjectSubstitution(); + } else if (JsonReferenceSchema.class.isAssignableFrom(c) + || JsonReferenceSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonReferenceSchemaObjectSubstitution(); + } else if (JsonStringSchema.class.isAssignableFrom(c) + || JsonStringSchemaObjectSubstitution.Serialized.class.isAssignableFrom(c)) { + return new JsonStringSchemaObjectSubstitution(); + } + + // Handle unsupported types + throw new IllegalArgumentException("Unsupported type: %s".formatted(c.getName())); + }); + } + + public static sealed class Serialized + permits JsonArraySchemaObjectSubstitution.Serialized, + JsonBooleanSchemaObjectSubstitution.Serialized, + JsonEnumSchemaObjectSubstitution.Serialized, + JsonIntegerSchemaObjectSubstitution.Serialized, + JsonNumberSchemaObjectSubstitution.Serialized, + JsonObjectSchemaObjectSubstitution.Serialized, + JsonReferenceSchemaObjectSubstitution.Serialized, + JsonStringSchemaObjectSubstitution.Serialized { + + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java new file mode 100644 index 000000000..f67140600 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java @@ -0,0 +1,34 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonStringSchemaObjectSubstitution extends JsonSchemaElementObjectSubstitution { + @Override + public JsonStringSchemaObjectSubstitution.Serialized serialize(JsonSchemaElement obj) { + var s = (JsonStringSchema) obj; + return new Serialized(s.description()); + } + + @Override + public JsonStringSchema deserialize(JsonSchemaElementObjectSubstitution.Serialized obj) { + var s = (JsonStringSchemaObjectSubstitution.Serialized) obj; + return JsonStringSchema.builder() + .description(s.description) + .build(); + } + + public static final class Serialized extends JsonSchemaElementObjectSubstitution.Serialized { + private final String description; + + @RecordableConstructor + public Serialized(String description) { + this.description = description; + } + + public String getDescription() { + return description; + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java index 9f6838207..d46a8be89 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java @@ -7,6 +7,11 @@ import io.quarkus.runtime.ObjectSubstitution; import io.quarkus.runtime.annotations.RecordableConstructor; +/** + * @deprecated + * @see JsonSchemaElementObjectSubstitution + */ +@Deprecated(forRemoval = true) public class ToolParametersObjectSubstitution implements ObjectSubstitution { diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index cdec3a989..260f44032 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -12,10 +12,12 @@ import org.jboss.resteasy.reactive.RestQuery; +import dev.langchain4j.agent.tool.P; import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.RegisterAiService; @Path("assistant-with-tool") @@ -27,8 +29,12 @@ public AssistantWithToolsResource(Assistant assistant) { this.assistant = assistant; } + @Description("Some test data") public static class TestData { + @Description("The foo field") String foo; + + @Description("The bar field") Integer bar; Double baz; @@ -54,8 +60,8 @@ public interface Assistant { public static class Calculator { @Tool("Calculates the length of a string") - int stringLength(String s) { - return s.length(); + int stringLength(@P(value = "The string to compute the length of", required = false) String s) { + return (s == null) ? 0 : s.length(); } @Tool("Calculates the sum of two numbers") @@ -80,7 +86,7 @@ public TestData evaluateTestObject(List data) { } @Tool("Calculates all factors of the provided integer.") - List getFactors(int x) { + List getFactors(@P("The integer to get factor") int x) { return java.util.stream.IntStream.rangeClosed(1, x) .filter(i -> x % i == 0) .boxed()