diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/support/ToolDefinitions.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/support/ToolDefinitions.java index 68d4646333a..e8b0e55c283 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/support/ToolDefinitions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/support/ToolDefinitions.java @@ -17,6 +17,8 @@ package org.springframework.ai.tool.support; import java.lang.reflect.Method; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; @@ -33,10 +35,16 @@ *

* * @author Mark Pollack + * @author Seol-JY * @since 1.0.0 */ public final class ToolDefinitions { + /** + * Cache for tool definitions. Key is the Method instance. + */ + private static final Map toolDefinitionCache = new ConcurrentHashMap<>(256); + private ToolDefinitions() { // prevents instantiation. } @@ -56,7 +64,16 @@ public static DefaultToolDefinition.Builder builder(Method method) { * Create a default {@link ToolDefinition} instance from a {@link Method}. */ public static ToolDefinition from(Method method) { - return builder(method).build(); + Assert.notNull(method, "method cannot be null"); + return toolDefinitionCache.computeIfAbsent(method, ToolDefinitions::createToolDefinition); + } + + private static ToolDefinition createToolDefinition(Method method) { + return DefaultToolDefinition.builder() + .name(ToolUtils.getToolName(method)) + .description(ToolUtils.getToolDescription(method)) + .inputSchema(JsonSchemaGenerator.generateForMethodInput(method)) + .build(); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java b/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java index c66b9af2733..e7e7b7dd59a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/util/json/schema/JsonSchemaGenerator.java @@ -20,7 +20,10 @@ import java.lang.reflect.Parameter; import java.lang.reflect.Type; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import com.fasterxml.jackson.annotation.JsonProperty; @@ -67,8 +70,11 @@ * If none of these annotations are present, the default behavior is to consider the * property as required and not to include a description. *

+ * This class provides caching for method input schema generation to improve performance + * when the same method signatures are processed multiple times. * * @author Thomas Vitale + * @author Seol-JY * @since 1.0.0 */ public final class JsonSchemaGenerator { @@ -81,6 +87,8 @@ public final class JsonSchemaGenerator { */ private static final boolean PROPERTY_REQUIRED_BY_DEFAULT = true; + private static final Map methodSchemaCache = new ConcurrentHashMap<>(256); + private static final SchemaGenerator TYPE_SCHEMA_GENERATOR; private static final SchemaGenerator SUBTYPE_SCHEMA_GENERATOR; @@ -116,8 +124,82 @@ private JsonSchemaGenerator() { /** * Generate a JSON Schema for a method's input parameters. + * + *

+ * This method uses caching to improve performance when the same method signature is + * processed multiple times. The cache key includes method signature and schema + * options to ensure correct cache hits. + * @param method the method to generate schema for + * @param schemaOptions options for schema generation + * @return JSON Schema as a string + * @throws IllegalArgumentException if method is null */ public static String generateForMethodInput(Method method, SchemaOption... schemaOptions) { + Assert.notNull(method, "method cannot be null"); + + String cacheKey = buildMethodCacheKey(method, schemaOptions); + return methodSchemaCache.computeIfAbsent(cacheKey, key -> generateMethodSchemaInternal(method, schemaOptions)); + } + + /** + * Generate a JSON Schema for a class type. + */ + public static String generateForType(Type type, SchemaOption... schemaOptions) { + Assert.notNull(type, "type cannot be null"); + ObjectNode schema = TYPE_SCHEMA_GENERATOR.generateSchema(type); + if ((type == Void.class) && !schema.has("properties")) { + schema.putObject("properties"); + } + processSchemaOptions(schemaOptions, schema); + return schema.toPrettyString(); + } + + /** + * Build cache key for method input schema generation. + * + *

+ * The cache key includes: + *

+ * @param method the method + * @param schemaOptions schema generation options + * @return unique cache key + */ + private static String buildMethodCacheKey(Method method, SchemaOption... schemaOptions) { + StringBuilder keyBuilder = new StringBuilder(256); + + // Class name + keyBuilder.append(method.getDeclaringClass().getName()); + keyBuilder.append('#'); + + // Method name + keyBuilder.append(method.getName()); + keyBuilder.append('('); + + // Parameter types (including generic information) + Type[] parameterTypes = method.getGenericParameterTypes(); + for (int i = 0; i < parameterTypes.length; i++) { + if (i > 0) { + keyBuilder.append(','); + } + keyBuilder.append(parameterTypes[i].getTypeName()); + } + keyBuilder.append(')'); + + // Schema options + if (schemaOptions.length > 0) { + keyBuilder.append(':'); + keyBuilder.append(Arrays.toString(schemaOptions)); + } + + return keyBuilder.toString(); + } + + private static String generateMethodSchemaInternal(Method method, SchemaOption... schemaOptions) { ObjectNode schema = JsonParser.getObjectMapper().createObjectNode(); schema.put("$schema", SchemaVersion.DRAFT_2020_12.getIdentifier()); schema.put("type", "object"); @@ -155,19 +237,6 @@ public static String generateForMethodInput(Method method, SchemaOption... schem return schema.toPrettyString(); } - /** - * Generate a JSON Schema for a class type. - */ - public static String generateForType(Type type, SchemaOption... schemaOptions) { - Assert.notNull(type, "type cannot be null"); - ObjectNode schema = TYPE_SCHEMA_GENERATOR.generateSchema(type); - if ((type == Void.class) && !schema.has("properties")) { - schema.putObject("properties"); - } - processSchemaOptions(schemaOptions, schema); - return schema.toPrettyString(); - } - private static void processSchemaOptions(SchemaOption[] schemaOptions, ObjectNode schema) { if (Stream.of(schemaOptions) .noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/support/ToolDefinitionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/support/ToolDefinitionsTests.java new file mode 100644 index 00000000000..94c5e6657b7 --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/support/ToolDefinitionsTests.java @@ -0,0 +1,276 @@ +/* + * Copyright 2024-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.support; + +import java.lang.reflect.Method; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.annotation.ToolParam; +import org.springframework.ai.tool.definition.DefaultToolDefinition; +import org.springframework.ai.tool.definition.ToolDefinition; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link ToolDefinitions}. + * + * @author Seol-JY + */ +class ToolDefinitionsTests { + + static class TestToolClass { + + @Tool(name = "getCurrentWeather", description = "Get current weather information") + public String getCurrentWeather(@ToolParam(description = "The city name") String city, + @ToolParam(description = "Temperature unit") String unit) { + return "Weather data for " + city + " in " + unit; + } + + @Tool(description = "Calculate sum of two numbers") + public int calculateSum(int a, int b) { + return a + b; + } + + public String nonToolMethod(String input) { + return "Not a tool method"; + } + + @Tool(description = "Process person data") + public String processPerson(PersonData person) { + return "Processing: " + person.name(); + } + + } + + record PersonData(@JsonProperty("full_name") @JsonPropertyDescription("The person's full name") String name, + @JsonPropertyDescription("The person's age") int age) { + } + + @Test + void builderShouldCreateValidBuilderForToolMethod() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + + DefaultToolDefinition.Builder builder = ToolDefinitions.builder(method); + ToolDefinition toolDefinition = builder.build(); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("getCurrentWeather"); + assertThat(toolDefinition.description()).isEqualTo("Get current weather information"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + assertThat(toolDefinition.inputSchema().toString()).contains("city").contains("unit"); + } + + @Test + void builderShouldCreateValidBuilderForMethodWithoutNameAnnotation() throws Exception { + Method method = TestToolClass.class.getMethod("calculateSum", int.class, int.class); + + DefaultToolDefinition.Builder builder = ToolDefinitions.builder(method); + ToolDefinition toolDefinition = builder.build(); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("calculateSum"); + assertThat(toolDefinition.description()).isEqualTo("Calculate sum of two numbers"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + } + + @Test + void builderShouldCreateValidBuilderForMethodWithComplexParameter() throws Exception { + Method method = TestToolClass.class.getMethod("processPerson", PersonData.class); + + DefaultToolDefinition.Builder builder = ToolDefinitions.builder(method); + ToolDefinition toolDefinition = builder.build(); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("processPerson"); + assertThat(toolDefinition.description()).isEqualTo("Process person data"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + assertThat(toolDefinition.inputSchema().toString()).contains("full_name").contains("age"); + } + + @Test + void builderShouldThrowExceptionWhenMethodIsNull() { + assertThatThrownBy(() -> ToolDefinitions.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("method cannot be null"); + } + + @Test + void fromShouldCreateValidToolDefinition() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + + ToolDefinition toolDefinition = ToolDefinitions.from(method); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("getCurrentWeather"); + assertThat(toolDefinition.description()).isEqualTo("Get current weather information"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + } + + @Test + void fromShouldCreateConsistentToolDefinitions() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + + ToolDefinition toolDefinition1 = ToolDefinitions.from(method); + ToolDefinition toolDefinition2 = ToolDefinitions.from(method); + + assertThat(toolDefinition1).isEqualTo(toolDefinition2); + assertThat(toolDefinition1.name()).isEqualTo(toolDefinition2.name()); + assertThat(toolDefinition1.description()).isEqualTo(toolDefinition2.description()); + assertThat(toolDefinition1.inputSchema()).isEqualTo(toolDefinition2.inputSchema()); + } + + @Test + void fromShouldThrowExceptionWhenMethodIsNull() { + assertThatThrownBy(() -> ToolDefinitions.from(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("method cannot be null"); + } + + @Test + void fromShouldReturnSameInstanceForSameMethod() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + + ToolDefinition toolDefinition1 = ToolDefinitions.from(method); + ToolDefinition toolDefinition2 = ToolDefinitions.from(method); + + assertThat(toolDefinition1).isSameAs(toolDefinition2); + } + + @Test + void fromShouldReturnDifferentInstancesForDifferentMethods() throws Exception { + Method method1 = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + Method method2 = TestToolClass.class.getMethod("calculateSum", int.class, int.class); + + ToolDefinition toolDefinition1 = ToolDefinitions.from(method1); + ToolDefinition toolDefinition2 = ToolDefinitions.from(method2); + + assertThat(toolDefinition1).isNotSameAs(toolDefinition2); + assertThat(toolDefinition1.name()).isNotEqualTo(toolDefinition2.name()); + } + + @Test + void cachingShouldBeThreadSafe() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + ExecutorService executor = Executors.newFixedThreadPool(10); + int numberOfTasks = 100; + + CompletableFuture[] futures = new CompletableFuture[numberOfTasks]; + for (int i = 0; i < numberOfTasks; i++) { + futures[i] = CompletableFuture.supplyAsync(() -> ToolDefinitions.from(method), executor); + } + + CompletableFuture allFutures = CompletableFuture.allOf(futures); + allFutures.get(5, TimeUnit.SECONDS); + + ToolDefinition first = futures[0].get(); + for (int i = 1; i < numberOfTasks; i++) { + assertThat(futures[i].get()).isSameAs(first); + } + + executor.shutdown(); + } + + @Test + void fromShouldHandleMethodWithNoParameters() throws Exception { + class NoParamToolClass { + + @Tool(description = "Get system status") + public String getSystemStatus() { + return "OK"; + } + + } + Method method = NoParamToolClass.class.getMethod("getSystemStatus"); + + ToolDefinition toolDefinition = ToolDefinitions.from(method); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("getSystemStatus"); + assertThat(toolDefinition.description()).isEqualTo("Get system status"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + } + + @Test + void fromShouldHandleMethodWithMultipleComplexParameters() throws Exception { + class ComplexParamToolClass { + + @Tool(description = "Process complex data") + public String processComplexData(PersonData person, String[] tags, int priority) { + return "Processed"; + } + + } + Method method = ComplexParamToolClass.class.getMethod("processComplexData", PersonData.class, String[].class, + int.class); + + ToolDefinition toolDefinition = ToolDefinitions.from(method); + + assertThat(toolDefinition).isNotNull(); + assertThat(toolDefinition.name()).isEqualTo("processComplexData"); + assertThat(toolDefinition.description()).isEqualTo("Process complex data"); + assertThat(toolDefinition.inputSchema()).isNotNull(); + } + + @Test + void fromShouldHandleOverloadedMethods() throws Exception { + class OverloadedToolClass { + + @Tool(description = "Process string") + public String process(String input) { + return "String: " + input; + } + + @Tool(description = "Process number") + public String process(int input) { + return "Number: " + input; + } + + } + + Method stringMethod = OverloadedToolClass.class.getMethod("process", String.class); + Method intMethod = OverloadedToolClass.class.getMethod("process", int.class); + + ToolDefinition stringToolDefinition = ToolDefinitions.from(stringMethod); + ToolDefinition intToolDefinition = ToolDefinitions.from(intMethod); + + assertThat(stringToolDefinition).isNotSameAs(intToolDefinition); + assertThat(stringToolDefinition.name()).isEqualTo("process"); + assertThat(intToolDefinition.name()).isEqualTo("process"); + assertThat(stringToolDefinition.description()).isEqualTo("Process string"); + assertThat(intToolDefinition.description()).isEqualTo("Process number"); + } + + @Test + void builderAndFromShouldProduceEquivalentResults() throws Exception { + Method method = TestToolClass.class.getMethod("getCurrentWeather", String.class, String.class); + + ToolDefinition fromBuilder = ToolDefinitions.builder(method).build(); + ToolDefinition fromMethod = ToolDefinitions.from(method); + + assertThat(fromBuilder.name()).isEqualTo(fromMethod.name()); + assertThat(fromBuilder.description()).isEqualTo(fromMethod.description()); + assertThat(fromBuilder.inputSchema()).isEqualTo(fromMethod.inputSchema()); + } + +} diff --git a/spring-ai-model/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java b/spring-ai-model/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java index 243ec73bbfb..7a21490fcc0 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/util/json/JsonSchemaGeneratorTests.java @@ -402,6 +402,48 @@ void generateSchemaForMethodWithToolContext() throws Exception { assertThat(schema).isEqualToIgnoringWhitespace(expectedJsonSchema); } + // CACHING TESTS + + @Test + void cacheMethodSchemaGeneration() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String firstResult = JsonSchemaGenerator.generateForMethodInput(method); + String secondResult = JsonSchemaGenerator.generateForMethodInput(method); + + assertThat(firstResult).isEqualTo(secondResult); + assertThat(firstResult).isSameAs(secondResult); + } + + @Test + void cacheKeyIncludesSchemaOptions() throws Exception { + Method method = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + + String defaultResult = JsonSchemaGenerator.generateForMethodInput(method); + String upperCaseResult = JsonSchemaGenerator.generateForMethodInput(method, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES); + + assertThat(defaultResult).isNotEqualTo(upperCaseResult); + } + + @Test + void cacheDistinguishesDifferentMethods() throws Exception { + Method simpleMethod = TestMethods.class.getDeclaredMethod("simpleMethod", String.class, int.class); + Method annotatedMethod = TestMethods.class.getDeclaredMethod("annotatedMethod", String.class, String.class); + + String simpleResult = JsonSchemaGenerator.generateForMethodInput(simpleMethod); + String annotatedResult = JsonSchemaGenerator.generateForMethodInput(annotatedMethod); + + assertThat(simpleResult).isNotEqualTo(annotatedResult); + } + + @Test + void throwExceptionWhenMethodIsNull() { + assertThatThrownBy(() -> JsonSchemaGenerator.generateForMethodInput(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("method cannot be null"); + } + // TYPES @Test