From 23bf2a8b058603da5c378769ee5e8f5ca6f0cc68 Mon Sep 17 00:00:00 2001 From: Eric Deandrea Date: Fri, 22 Nov 2024 16:07:21 -0500 Subject: [PATCH] Migrate to the JsonSchemaElement API Closes #1054 --- .../langchain4j/deployment/ToolProcessor.java | 184 +++++++++++------- .../JsonArraySchemaObjectSubstitution.java | 28 +++ .../JsonBooleanSchemaObjectSubstitution.java | 26 +++ .../JsonEnumSchemaObjectSubstitution.java | 29 +++ .../JsonIntegerSchemaObjectSubstitution.java | 26 +++ .../JsonNumberSchemaObjectSubstitution.java | 26 +++ .../JsonObjectSchemaObjectSubstitution.java | 36 ++++ ...JsonReferenceSchemaObjectSubstitution.java | 24 +++ .../JsonStringSchemaObjectSubstitution.java | 26 +++ .../ToolParametersObjectSubstitution.java | 12 ++ .../AssistantWithToolsResource.java | 12 +- .../langchain4j/jlama/JlamaModel.java | 7 +- .../langchain4j/ollama/MessageMapper.java | 16 +- .../watsonx/bean/TextChatMessage.java | 14 +- 14 files changed, 364 insertions(+), 102 deletions(-) create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java create mode 100644 core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 6493730f9..d50160c32 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -1,30 +1,22 @@ package io.quarkiverse.langchain4j.deployment; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.BOOLEAN; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.NUMBER; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.OBJECT; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.description; -import static dev.langchain4j.agent.tool.JsonSchemaProperty.enums; import static io.quarkiverse.langchain4j.deployment.DotNames.BLOCKING; import static io.quarkiverse.langchain4j.deployment.DotNames.COMPLETION_STAGE; import static io.quarkiverse.langchain4j.deployment.DotNames.MULTI; import static io.quarkiverse.langchain4j.deployment.DotNames.NON_BLOCKING; import static io.quarkiverse.langchain4j.deployment.DotNames.RUN_ON_VIRTUAL_THREAD; import static io.quarkiverse.langchain4j.deployment.DotNames.UNI; -import static java.util.Arrays.stream; -import static java.util.stream.Collectors.toList; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -45,17 +37,32 @@ import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.Opcodes; -import dev.langchain4j.agent.tool.JsonSchemaProperty; import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.agent.tool.ToolMemoryId; -import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.deployment.items.ToolMethodBuildItem; import io.quarkiverse.langchain4j.runtime.ToolsRecorder; import io.quarkiverse.langchain4j.runtime.prompt.Mappable; +import io.quarkiverse.langchain4j.runtime.tool.JsonArraySchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonBooleanSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonEnumSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonIntegerSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonNumberSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonObjectSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonReferenceSchemaObjectSubstitution; +import io.quarkiverse.langchain4j.runtime.tool.JsonStringSchemaObjectSubstitution; import io.quarkiverse.langchain4j.runtime.tool.ToolInvoker; import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo; -import io.quarkiverse.langchain4j.runtime.tool.ToolParametersObjectSubstitution; import io.quarkiverse.langchain4j.runtime.tool.ToolSpanWrapper; import io.quarkiverse.langchain4j.runtime.tool.ToolSpecificationObjectSubstitution; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; @@ -87,6 +94,7 @@ public class ToolProcessor { private static final DotName TOOL_MEMORY_ID = DotName.createSimple(ToolMemoryId.class); private static final DotName P = DotName.createSimple(dev.langchain4j.agent.tool.P.class); + private static final DotName DESCRIPTION = DotName.createSimple(Description.class); private static final MethodDescriptor METHOD_METADATA_CTOR = MethodDescriptor .ofConstructor(ToolInvoker.MethodMetadata.class, boolean.class, Map.class, Integer.class); private static final MethodDescriptor HASHMAP_CTOR = MethodDescriptor.ofConstructor(HashMap.class); @@ -209,6 +217,9 @@ public void handleTools( .name(toolName) .description(toolDescription); + var properties = new LinkedHashMap(toolMethod.parametersCount()); + var required = new ArrayList(toolMethod.parametersCount()); + MethodParameterInfo memoryIdParameter = null; for (MethodParameterInfo parameter : toolMethod.parameters()) { if (parameter.hasAnnotation(TOOL_MEMORY_ID)) { @@ -216,15 +227,22 @@ public void handleTools( continue; } - AnnotationInstance pInstance = parameter.annotation(P); - if (pInstance != null && pInstance.value("required") != null - && !pInstance.value("required").asBoolean()) { - builder.addOptionalParameter(parameter.name(), toJsonSchemaProperties(parameter, index)); - } else { - builder.addParameter(parameter.name(), toJsonSchemaProperties(parameter, index)); + var pInstance = parameter.annotation(P); + var jsonSchemaElement = toJsonSchemaElement(parameter, index); + properties.put(parameter.name(), jsonSchemaElement); + + if ((pInstance == null) + || ((pInstance.value("required") != null) && pInstance.value("required").asBoolean())) { + required.add(parameter.name()); } } + builder.parameters( + JsonObjectSchema.builder() + .properties(properties) + .required(required) + .build()); + Map nameToParamPosition = toolMethod.parameters().stream().collect( Collectors.toMap(MethodParameterInfo::name, i -> Integer.valueOf(i.position()))); @@ -324,8 +342,23 @@ public ToolsMetadataBuildItem filterOutRemovedTools( if (beforeRemoval != null) { recorderContext.registerSubstitution(ToolSpecification.class, ToolSpecificationObjectSubstitution.Serialized.class, ToolSpecificationObjectSubstitution.class); - recorderContext.registerSubstitution(ToolParameters.class, ToolParametersObjectSubstitution.Serialized.class, - ToolParametersObjectSubstitution.class); + recorderContext.registerSubstitution(JsonArraySchema.class, JsonArraySchemaObjectSubstitution.Serialized.class, + JsonArraySchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonBooleanSchema.class, JsonBooleanSchemaObjectSubstitution.Serialized.class, + JsonBooleanSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonEnumSchema.class, JsonEnumSchemaObjectSubstitution.Serialized.class, + JsonEnumSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonIntegerSchema.class, JsonIntegerSchemaObjectSubstitution.Serialized.class, + JsonIntegerSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonNumberSchema.class, JsonNumberSchemaObjectSubstitution.Serialized.class, + JsonNumberSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonObjectSchema.class, JsonObjectSchemaObjectSubstitution.Serialized.class, + JsonObjectSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonReferenceSchema.class, + JsonReferenceSchemaObjectSubstitution.Serialized.class, + JsonReferenceSchemaObjectSubstitution.class); + recorderContext.registerSubstitution(JsonStringSchema.class, JsonStringSchemaObjectSubstitution.Serialized.class, + JsonStringSchemaObjectSubstitution.class); Map> metadataWithoutRemovedBeans = beforeRemoval.getMetadata().entrySet() .stream() .filter(entry -> validationPhase.getContext().removedBeans().stream() @@ -469,16 +502,14 @@ private String generateArgumentMapper(MethodInfo methodInfo, ClassOutput classOu return implClassName; } - private Iterable toJsonSchemaProperties(MethodParameterInfo parameter, IndexView index) { + private JsonSchemaElement toJsonSchemaElement(MethodParameterInfo parameter, IndexView index) { Type type = parameter.type(); - AnnotationInstance pInstance = parameter.annotation(P); + String description = descriptionFrom(parameter); - JsonSchemaProperty description = pInstance == null ? null : description(pInstance.value().asString()); - - return toJsonSchemaProperties(type, index, description); + return toJsonSchemaElement(type, index, description); } - private Iterable toJsonSchemaProperties(Type type, IndexView index, JsonSchemaProperty description) { + private JsonSchemaElement toJsonSchemaElement(Type type, IndexView index, String description) { DotName typeName = type.name(); if (type.kind() == Type.Kind.WILDCARD_TYPE) { @@ -487,18 +518,18 @@ private Iterable toJsonSchemaProperties(Type type, IndexView boundType = type.asWildcardType().superBound(); } if (boundType != null) { - return toJsonSchemaProperties(boundType, index, description); + return toJsonSchemaElement(boundType, index, description); } else { throw new IllegalArgumentException("Unsupported wildcard type with no bounds: " + type); } } if (DotNames.STRING.equals(typeName) || DotNames.CHARACTER.equals(typeName) || DotNames.PRIMITIVE_CHAR.equals(typeName)) { - return removeNulls(STRING, description); + return JsonStringSchema.builder().description(description).build(); } if (DotNames.BOOLEAN.equals(typeName) || DotNames.PRIMITIVE_BOOLEAN.equals(typeName)) { - return removeNulls(BOOLEAN, description); + return JsonBooleanSchema.builder().description(description).build(); } if (DotNames.BYTE.equals(typeName) || DotNames.PRIMITIVE_BYTE.equals(typeName) @@ -506,14 +537,14 @@ private Iterable toJsonSchemaProperties(Type type, IndexView || DotNames.INTEGER.equals(typeName) || DotNames.PRIMITIVE_INT.equals(typeName) || DotNames.LONG.equals(typeName) || DotNames.PRIMITIVE_LONG.equals(typeName) || DotNames.BIG_INTEGER.equals(typeName)) { - return removeNulls(INTEGER, description); + return JsonIntegerSchema.builder().description(description).build(); } // TODO put constraints on min and max? if (DotNames.FLOAT.equals(typeName) || DotNames.PRIMITIVE_FLOAT.equals(typeName) || DotNames.DOUBLE.equals(typeName) || DotNames.PRIMITIVE_DOUBLE.equals(typeName) || DotNames.BIG_DECIMAL.equals(typeName)) { - return removeNulls(NUMBER, description); + return JsonNumberSchema.builder().description(description).build(); } // TODO something else? @@ -524,49 +555,40 @@ private Iterable toJsonSchemaProperties(Type type, IndexView Type elementType = parameterizedType != null ? parameterizedType.arguments().get(0) : type.asArrayType().component(); - Iterable elementProperties = toJsonSchemaProperties(elementType, index, null); - - JsonSchemaProperty itemsSchema; - if (isComplexType(elementType)) { - Map fieldDescription = new HashMap<>(); - - for (JsonSchemaProperty fieldProperty : elementProperties) { - fieldDescription.put(fieldProperty.key(), fieldProperty.value()); - } - itemsSchema = JsonSchemaProperty.from("items", fieldDescription); - } else { - itemsSchema = JsonSchemaProperty.items(elementProperties.iterator().next()); - } + JsonSchemaElement element = toJsonSchemaElement(elementType, index, null); - return removeNulls(ARRAY, itemsSchema, description); + return JsonArraySchema.builder().description(description).items(element).build(); } if (isEnum(type, index)) { - return removeNulls(STRING, enums(enumConstants(type)), description); + var enums = Arrays.stream(enumConstants(type)) + .filter(e -> e.getClass().isEnum()) + .map(e -> ((Enum) e).name()) + .toList(); + + return JsonEnumSchema.builder() + .enumValues(enums) + .description(Optional.ofNullable(description).orElseGet(() -> descriptionFrom(type))) + .build(); } - if (type.kind() == Type.Kind.CLASS) { - Map properties = new HashMap<>(); - ClassInfo classInfo = index.getClassByName(type.name()); + if (isComplexType(type)) { + var builder = JsonObjectSchema.builder() + .description(Optional.ofNullable(description).orElseGet(() -> descriptionFrom(type))); - List required = new ArrayList<>(); - if (classInfo != null) { - for (FieldInfo field : classInfo.fields()) { - String fieldName = field.name(); - - Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); - Map fieldDescription = new HashMap<>(); - - for (JsonSchemaProperty fieldProperty : fieldSchema) { - fieldDescription.put(fieldProperty.key(), fieldProperty.value()); - } + Optional.ofNullable(index.getClassByName(type.name())) + .map(ClassInfo::fields) + .orElseGet(List::of) + .forEach(field -> { + var fieldName = field.name(); + var fieldType = field.type(); + var fieldDescription = descriptionFrom(field); + var fieldSchema = toJsonSchemaElement(fieldType, index, fieldDescription); - properties.put(fieldName, fieldDescription); - } - } + builder.addProperty(fieldName, fieldSchema); + }); - JsonSchemaProperty objectSchema = JsonSchemaProperty.from("properties", properties); - return removeNulls(OBJECT, objectSchema, JsonSchemaProperty.from("required", required), description); + return builder.build(); } throw new IllegalArgumentException("Unsupported type: " + type); @@ -576,12 +598,6 @@ private boolean isComplexType(Type type) { return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } - private Iterable removeNulls(JsonSchemaProperty... properties) { - return stream(properties) - .filter(Objects::nonNull) - .collect(toList()); - } - private boolean isEnum(Type returnType, IndexView index) { if (returnType.kind() != Type.Kind.CLASS) { return false; @@ -590,6 +606,28 @@ private boolean isEnum(Type returnType, IndexView index) { return maybeEnum != null && maybeEnum.isEnum(); } + private static String descriptionFrom(String[] description) { + return (description != null) ? String.join(" ", description) : null; + } + + private static String descriptionFrom(Type type) { + return Optional.ofNullable(type.annotation(DESCRIPTION)) + .map(annotationInstance -> descriptionFrom(annotationInstance.value().asStringArray())) + .orElse(null); + } + + private static String descriptionFrom(FieldInfo field) { + return Optional.ofNullable(field.annotation(DESCRIPTION)) + .map(annotationInstance -> descriptionFrom(annotationInstance.value().asStringArray())) + .orElse(null); + } + + private static String descriptionFrom(MethodParameterInfo parameter) { + return Optional.ofNullable(parameter.annotation(P)) + .map(p -> p.value().asString()) + .orElse(null); + } + private static Object[] enumConstants(Type type) { return JandexUtil.load(type, Thread.currentThread().getContextClassLoader()).getEnumConstants(); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java new file mode 100644 index 000000000..b4e5a69f0 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonArraySchemaObjectSubstitution.java @@ -0,0 +1,28 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonArraySchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonArraySchema obj) { + return new Serialized(obj.description(), obj.items()); + } + + @Override + public JsonArraySchema deserialize(Serialized obj) { + return JsonArraySchema.builder() + .description(obj.description) + .items(obj.items) + .build(); + } + + public record Serialized(String description, JsonSchemaElement items) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java new file mode 100644 index 000000000..c69eb773f --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonBooleanSchemaObjectSubstitution.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonBooleanSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonBooleanSchema obj) { + return new Serialized(obj.description()); + } + + @Override + public JsonBooleanSchema deserialize(Serialized obj) { + return JsonBooleanSchema.builder() + .description(obj.description) + .build(); + } + + public record Serialized(String description) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java new file mode 100644 index 000000000..cff0ef32d --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonEnumSchemaObjectSubstitution.java @@ -0,0 +1,29 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import java.util.List; + +import dev.langchain4j.model.chat.request.json.JsonEnumSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonEnumSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonEnumSchema obj) { + return new Serialized(obj.description(), obj.enumValues()); + } + + @Override + public JsonEnumSchema deserialize(Serialized obj) { + return JsonEnumSchema.builder() + .description(obj.description) + .enumValues(obj.enumValues) + .build(); + } + + public record Serialized(String description, List enumValues) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java new file mode 100644 index 000000000..34a363397 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonIntegerSchemaObjectSubstitution.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonIntegerSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonIntegerSchema obj) { + return new Serialized(obj.description()); + } + + @Override + public JsonIntegerSchema deserialize(Serialized obj) { + return JsonIntegerSchema.builder() + .description(obj.description) + .build(); + } + + public record Serialized(String description) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java new file mode 100644 index 000000000..3c7bf1295 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonNumberSchemaObjectSubstitution.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonNumberSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonNumberSchema obj) { + return new Serialized(obj.description()); + } + + @Override + public JsonNumberSchema deserialize(Serialized obj) { + return JsonNumberSchema.builder() + .description(obj.description) + .build(); + } + + public record Serialized(String description) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java new file mode 100644 index 000000000..43057f5fe --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonObjectSchemaObjectSubstitution.java @@ -0,0 +1,36 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import java.util.List; +import java.util.Map; + +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonObjectSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonObjectSchema obj) { + return new Serialized(obj.description(), obj.properties(), obj.required(), obj.additionalProperties(), + obj.definitions()); + } + + @Override + public JsonObjectSchema deserialize(Serialized obj) { + return JsonObjectSchema.builder() + .description(obj.description) + .properties(obj.properties) + .required(obj.required) + .additionalProperties(obj.additionalProperties) + .definitions(obj.definitions) + .build(); + } + + public record Serialized(String description, Map properties, List required, + Boolean additionalProperties, Map definitions) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java new file mode 100644 index 000000000..8b47f85eb --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonReferenceSchemaObjectSubstitution.java @@ -0,0 +1,24 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonReferenceSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public class JsonReferenceSchemaObjectSubstitution + implements ObjectSubstitution { + public Serialized serialize(JsonReferenceSchema obj) { + return new Serialized(obj.reference()); + } + + public JsonReferenceSchema deserialize(Serialized obj) { + return JsonReferenceSchema.builder() + .reference(obj.reference) + .build(); + } + + public record Serialized(String reference) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java new file mode 100644 index 000000000..bbaa7a7df --- /dev/null +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/JsonStringSchemaObjectSubstitution.java @@ -0,0 +1,26 @@ +package io.quarkiverse.langchain4j.runtime.tool; + +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import io.quarkus.runtime.ObjectSubstitution; +import io.quarkus.runtime.annotations.RecordableConstructor; + +public final class JsonStringSchemaObjectSubstitution + implements ObjectSubstitution { + @Override + public Serialized serialize(JsonStringSchema obj) { + return new Serialized(obj.description()); + } + + @Override + public JsonStringSchema deserialize(Serialized obj) { + return JsonStringSchema.builder() + .description(obj.description) + .build(); + } + + public record Serialized(String description) { + @RecordableConstructor + public Serialized { + } + } +} diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java index 9f6838207..8c12cd3dd 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/tool/ToolParametersObjectSubstitution.java @@ -7,6 +7,18 @@ import io.quarkus.runtime.ObjectSubstitution; import io.quarkus.runtime.annotations.RecordableConstructor; +/** + * @deprecated + * @see JsonArraySchemaObjectSubstitution + * @see JsonBooleanSchemaObjectSubstitution + * @see JsonEnumSchemaObjectSubstitution + * @see JsonIntegerSchemaObjectSubstitution + * @see JsonNumberSchemaObjectSubstitution + * @see JsonObjectSchemaObjectSubstitution + * @see JsonReferenceSchemaObjectSubstitution + * @see JsonStringSchemaObjectSubstitution + */ +@Deprecated(forRemoval = true) public class ToolParametersObjectSubstitution implements ObjectSubstitution { diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index cdec3a989..260f44032 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -12,10 +12,12 @@ import org.jboss.resteasy.reactive.RestQuery; +import dev.langchain4j.agent.tool.P; import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; +import dev.langchain4j.model.output.structured.Description; import io.quarkiverse.langchain4j.RegisterAiService; @Path("assistant-with-tool") @@ -27,8 +29,12 @@ public AssistantWithToolsResource(Assistant assistant) { this.assistant = assistant; } + @Description("Some test data") public static class TestData { + @Description("The foo field") String foo; + + @Description("The bar field") Integer bar; Double baz; @@ -54,8 +60,8 @@ public interface Assistant { public static class Calculator { @Tool("Calculates the length of a string") - int stringLength(String s) { - return s.length(); + int stringLength(@P(value = "The string to compute the length of", required = false) String s) { + return (s == null) ? 0 : s.length(); } @Tool("Calculates the sum of two numbers") @@ -80,7 +86,7 @@ public TestData evaluateTestObject(List data) { } @Tool("Calculates all factors of the provided integer.") - List getFactors(int x) { + List getFactors(@P("The integer to get factor") int x) { return java.util.stream.IntStream.rangeClosed(1, x) .filter(i -> x % i == 0) .boxed() diff --git a/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java b/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java index e2a71b74d..e195edf3e 100644 --- a/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java +++ b/model-providers/jlama/runtime/src/main/java/io/quarkiverse/langchain4j/jlama/JlamaModel.java @@ -127,12 +127,7 @@ static Tool toTool(ToolSpecification toolSpecification) { .name(toolSpecification.name()) .description(toolSpecification.description()); - if (toolSpecification.toolParameters() != null) { - for (Map.Entry> p : toolSpecification.toolParameters().properties().entrySet()) { - builder.addParameter(p.getKey(), p.getValue(), - toolSpecification.toolParameters().required().contains(p.getKey())); - } - } else if (toolSpecification.parameters() != null) { + if (toolSpecification.parameters() != null) { for (Map.Entry p : toolSpecification.parameters().properties().entrySet()) { builder.addParameter(p.getKey(), JsonSchemaElementHelper.toMap(p.getValue()), toolSpecification.parameters().required().contains(p.getKey())); diff --git a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java index aea73737c..a0b017a9e 100644 --- a/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java +++ b/model-providers/ollama/runtime/src/main/java/io/quarkiverse/langchain4j/ollama/MessageMapper.java @@ -16,7 +16,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import dev.langchain4j.agent.tool.ToolExecutionRequest; -import dev.langchain4j.agent.tool.ToolParameters; import dev.langchain4j.agent.tool.ToolSpecification; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.ChatMessage; @@ -142,23 +141,12 @@ static List toTools(Collection toolSpecifications) { } private static Tool toTool(ToolSpecification toolSpecification) { - Tool.Function.Parameters functionParameters; - if (toolSpecification.toolParameters() != null) { - functionParameters = toFunctionParameters(toolSpecification.toolParameters()); - } else { - functionParameters = toFunctionParameters(toolSpecification.parameters()); - } + Tool.Function.Parameters functionParameters = toFunctionParameters(toolSpecification.parameters()); + return new Tool(Tool.Type.FUNCTION, new Tool.Function(toolSpecification.name(), toolSpecification.description(), functionParameters)); } - private static Tool.Function.Parameters toFunctionParameters(ToolParameters toolParameters) { - if (toolParameters == null) { - return Tool.Function.Parameters.empty(); - } - return Tool.Function.Parameters.objectType(toolParameters.properties(), toolParameters.required()); - } - private static Tool.Function.Parameters toFunctionParameters(JsonObjectSchema parameters) { if (parameters == null) { return Tool.Function.Parameters.empty(); diff --git a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java index 55e365f1f..091d86f36 100644 --- a/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java +++ b/model-providers/watsonx/runtime/src/main/java/io/quarkiverse/langchain4j/watsonx/bean/TextChatMessage.java @@ -16,6 +16,7 @@ import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageAssistant; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageSystem; import io.quarkiverse.langchain4j.watsonx.bean.TextChatMessage.TextChatMessageTool; @@ -174,7 +175,7 @@ public static TextChatMessageTool of(ToolExecutionResultMessage toolExecutionRes /** * Creates a {@link TextChatMessageTool}. * - * @param message the content of the message tool. + * @param content the content of the message tool. * @param toolCallId the unique identifier of the message tool. * @return the created {@link TextChatMessageTool}. */ @@ -219,15 +220,16 @@ public record TextChatParameterFunction(String name, String description, Map