diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index f98da332ba0..a99bbffbacb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -389,7 +389,7 @@ void customTemplateRendererWithCall() { .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() + "") .param("format", outputConverter.getFormat())) - .templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) + .templateRenderer(StTemplateRenderer.builder().startDelimiterToken("<").endDelimiterToken(">").build()) .call() .content(); // @formatter:on @@ -413,7 +413,7 @@ void customTemplateRendererWithCallAndAdvisor() { .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() + "") .param("format", outputConverter.getFormat())) - .templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) + .templateRenderer(StTemplateRenderer.builder().startDelimiterToken("<").endDelimiterToken(">").build()) .call() .content(); // @formatter:on @@ -438,7 +438,7 @@ void customTemplateRendererWithStream() { .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() + "") .param("format", outputConverter.getFormat())) - .templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) + .templateRenderer(StTemplateRenderer.builder().startDelimiterToken("<").endDelimiterToken(">").build()) .stream() .chatResponse(); @@ -474,7 +474,7 @@ void customTemplateRendererWithStreamAndAdvisor() { .text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator() + "") .param("format", outputConverter.getFormat())) - .templateRenderer(StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()) + .templateRenderer(StTemplateRenderer.builder().startDelimiterToken("<").endDelimiterToken(">").build()) .stream() .chatResponse(); diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java index 9d4d4962069..eb5d2eba6a4 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientUtilsTests.java @@ -342,8 +342,8 @@ void whenCustomTemplateRendererIsProvidedThenItIsUsedForRendering() { String systemText = "Instructions "; Map systemParams = Map.of("name", "Spring AI"); TemplateRenderer customRenderer = StTemplateRenderer.builder() - .startDelimiterToken('<') - .endDelimiterToken('>') + .startDelimiterToken("<") + .endDelimiterToken(">") .build(); ChatModel chatModel = mock(ChatModel.class); DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient diff --git a/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java b/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java index 3780b948a09..08b2bd62c0a 100644 --- a/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java +++ b/spring-ai-template-st/src/main/java/org/springframework/ai/template/st/StTemplateRenderer.java @@ -1,24 +1,11 @@ -/* - * Copyright 2023-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package org.springframework.ai.template.st; +import java.util.Collections; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.antlr.runtime.Token; import org.antlr.runtime.TokenStream; @@ -26,6 +13,7 @@ import org.slf4j.LoggerFactory; import org.stringtemplate.v4.ST; import org.stringtemplate.v4.compiler.Compiler; +import org.stringtemplate.v4.compiler.STException; import org.stringtemplate.v4.compiler.STLexer; import org.springframework.ai.template.TemplateRenderer; @@ -57,106 +45,125 @@ public class StTemplateRenderer implements TemplateRenderer { private static final String VALIDATION_MESSAGE = "Not all variables were replaced in the template. Missing variable names are: %s."; - private static final char DEFAULT_START_DELIMITER_TOKEN = '{'; + private static final String DEFAULT_START_DELIMITER = "{"; - private static final char DEFAULT_END_DELIMITER_TOKEN = '}'; + private static final String DEFAULT_END_DELIMITER = "}"; private static final ValidationMode DEFAULT_VALIDATION_MODE = ValidationMode.THROW; private static final boolean DEFAULT_VALIDATE_ST_FUNCTIONS = false; - private final char startDelimiterToken; + private final String startDelimiterToken; - private final char endDelimiterToken; + private final String endDelimiterToken; private final ValidationMode validationMode; private final boolean validateStFunctions; /** - * Constructs a new {@code StTemplateRenderer} with the specified delimiter tokens, - * validation mode, and function validation flag. - * @param startDelimiterToken the character used to denote the start of a template - * variable (e.g., '{') - * @param endDelimiterToken the character used to denote the end of a template - * variable (e.g., '}') - * @param validationMode the mode to use for template variable validation; must not be - * null - * @param validateStFunctions whether to validate StringTemplate functions in the - * template + * Constructs a StTemplateRenderer with custom delimiters, validation mode, and + * function validation flag. + * @param startDelimiterToken Multi-character start delimiter (non-null/non-empty) + * @param endDelimiterToken Multi-character end delimiter (non-null/non-empty) + * @param validationMode Mode for handling missing variables (non-null) + * @param validateStFunctions Whether to treat ST built-in functions as variables */ - public StTemplateRenderer(char startDelimiterToken, char endDelimiterToken, ValidationMode validationMode, + public StTemplateRenderer(String startDelimiterToken, String endDelimiterToken, ValidationMode validationMode, boolean validateStFunctions) { - Assert.notNull(validationMode, "validationMode cannot be null"); + Assert.notNull(validationMode, "validationMode must not be null"); + Assert.hasText(startDelimiterToken, "startDelimiterToken must not be null or empty"); + Assert.hasText(endDelimiterToken, "endDelimiterToken must not be null or empty"); + this.startDelimiterToken = startDelimiterToken; this.endDelimiterToken = endDelimiterToken; this.validationMode = validationMode; this.validateStFunctions = validateStFunctions; } + /** + * Renders the template by first converting custom delimiters to ST's native format, + * then replacing variables. + * @param template Template string with variables (non-null/non-empty) + * @param variables Map of variable names to values (non-null, keys must not be null) + * @return Rendered string with variables replaced + */ @Override public String apply(String template, Map variables) { - Assert.hasText(template, "template cannot be null or empty"); - Assert.notNull(variables, "variables cannot be null"); - Assert.noNullElements(variables.keySet(), "variables keys cannot be null"); + Assert.hasText(template, "template must not be null or empty"); + Assert.notNull(variables, "variables must not be null"); + Assert.noNullElements(variables.keySet(), "variables keys must not contain null"); - ST st = createST(template); - for (Map.Entry entry : variables.entrySet()) { - st.add(entry.getKey(), entry.getValue()); + try { + String processedTemplate = preprocessTemplate(template); + ST st = new ST(processedTemplate, '{', '}'); + variables.forEach(st::add); + + if (validationMode != ValidationMode.NONE) { + validate(st, variables); + } + + return st.render(); } - if (this.validationMode != ValidationMode.NONE) { - validate(st, variables); + catch (STException e) { + throw new IllegalArgumentException("Failed to render template", e); } - return st.render(); } - private ST createST(String template) { - try { - return new ST(template, this.startDelimiterToken, this.endDelimiterToken); - } - catch (Exception ex) { - throw new IllegalArgumentException("The template string is not valid.", ex); + /** + * Converts custom delimiter-wrapped variables (e.g., ) to ST's native format + * ({name}). + */ + private String preprocessTemplate(String template) { + if ("{".equals(startDelimiterToken) && "}".equals(endDelimiterToken)) { + return template; } + String escapedStart = Pattern.quote(startDelimiterToken); + String escapedEnd = Pattern.quote(endDelimiterToken); + String variablePattern = escapedStart + "([a-zA-Z_][a-zA-Z0-9_]*)" + escapedEnd; + return template.replaceAll(variablePattern, "{$1}"); } /** - * Validates that all required template variables are provided in the model. Returns - * the set of missing variables for further handling or logging. - * @param st the StringTemplate instance - * @param templateVariables the provided variables - * @return set of missing variable names, or empty set if none are missing + * Validates that all template variables have been provided in the variables map. */ - private Set validate(ST st, Map templateVariables) { + private void validate(ST st, Map templateVariables) { Set templateTokens = getInputVariables(st); - Set modelKeys = templateVariables != null ? templateVariables.keySet() : new HashSet<>(); + Set modelKeys = templateVariables != null ? templateVariables.keySet() : Collections.emptySet(); Set missingVariables = new HashSet<>(templateTokens); missingVariables.removeAll(modelKeys); if (!missingVariables.isEmpty()) { - if (this.validationMode == ValidationMode.WARN) { - logger.warn(VALIDATION_MESSAGE.formatted(missingVariables)); + String message = VALIDATION_MESSAGE.formatted(missingVariables); + if (validationMode == ValidationMode.WARN) { + logger.warn(message); } - else if (this.validationMode == ValidationMode.THROW) { - throw new IllegalStateException(VALIDATION_MESSAGE.formatted(missingVariables)); + else if (validationMode == ValidationMode.THROW) { + throw new IllegalStateException(message); } } - return missingVariables; } + /** + * Extracts variable names from the template using ST's token stream and regex + * validation. + */ private Set getInputVariables(ST st) { - TokenStream tokens = st.impl.tokens; Set inputVariables = new HashSet<>(); + TokenStream tokens = st.impl.tokens; boolean isInsideList = false; + Set stKeywords = Set.of("if", "elseif", "else", "endif", "for", "endfor", "while", "endwhile", "switch", + "endswitch", "case", "default"); + for (int i = 0; i < tokens.size(); i++) { Token token = tokens.get(i); - // Handle list variables with option (e.g., {items; separator=", "}) if (token.getType() == STLexer.LDELIM && i + 1 < tokens.size() && tokens.get(i + 1).getType() == STLexer.ID) { if (i + 2 < tokens.size() && tokens.get(i + 2).getType() == STLexer.COLON) { String text = tokens.get(i + 1).getText(); - if (!Compiler.funcs.containsKey(text) || this.validateStFunctions) { + if ((!Compiler.funcs.containsKey(text) || validateStFunctions) && !stKeywords.contains(text)) { inputVariables.add(text); isInsideList = true; } @@ -165,34 +172,49 @@ private Set getInputVariables(ST st) { else if (token.getType() == STLexer.RDELIM) { isInsideList = false; } - // Only add IDs that are not function calls (i.e., not immediately followed by else if (!isInsideList && token.getType() == STLexer.ID) { boolean isFunctionCall = (i + 1 < tokens.size() && tokens.get(i + 1).getType() == STLexer.LPAREN); boolean isDotProperty = (i > 0 && tokens.get(i - 1).getType() == STLexer.DOT); - // Only add as variable if: - // - Not a function call - // - Not a built-in function used as property (unless validateStFunctions) - if (!isFunctionCall && (!Compiler.funcs.containsKey(token.getText()) || this.validateStFunctions - || !(isDotProperty && Compiler.funcs.containsKey(token.getText())))) { - inputVariables.add(token.getText()); + String tokenText = token.getText(); + if (!isFunctionCall && (!Compiler.funcs.containsKey(tokenText) || validateStFunctions + || !(isDotProperty && Compiler.funcs.containsKey(tokenText)))) { + if (!stKeywords.contains(tokenText)) { + inputVariables.add(tokenText); + } } } } + + Pattern varPattern = Pattern.compile(Pattern.quote("{") + "([a-zA-Z_][a-zA-Z0-9_]*)" + Pattern.quote("}")); + Matcher matcher = varPattern.matcher(st.impl.template); + while (matcher.find()) { + String var = matcher.group(1); + if (!stKeywords.contains(var)) { + inputVariables.add(var); + } + } + + Set localVariables = Set.of("it", "item", "index", "key", "value"); + inputVariables.removeAll(localVariables); + return inputVariables; } + /** + * Creates a builder for configuring StTemplateRenderer instances. + */ public static Builder builder() { return new Builder(); } /** - * Builder for configuring and creating {@link StTemplateRenderer} instances. + * Builder for fluent configuration of StTemplateRenderer. */ public static final class Builder { - private char startDelimiterToken = DEFAULT_START_DELIMITER_TOKEN; + private String startDelimiterToken = DEFAULT_START_DELIMITER; - private char endDelimiterToken = DEFAULT_END_DELIMITER_TOKEN; + private String endDelimiterToken = DEFAULT_END_DELIMITER; private ValidationMode validationMode = DEFAULT_VALIDATION_MODE; @@ -202,33 +224,23 @@ private Builder() { } /** - * Sets the character used as the start delimiter for template expressions. - * Default is '{'. - * @param startDelimiterToken The start delimiter character. - * @return This builder instance for chaining. + * Sets the multi-character start delimiter (e.g., "{{" or "<"). */ - public Builder startDelimiterToken(char startDelimiterToken) { + public Builder startDelimiterToken(String startDelimiterToken) { this.startDelimiterToken = startDelimiterToken; return this; } /** - * Sets the character used as the end delimiter for template expressions. Default - * is '}'. - * @param endDelimiterToken The end delimiter character. - * @return This builder instance for chaining. + * Sets the multi-character end delimiter (e.g., "}}" or ">"). */ - public Builder endDelimiterToken(char endDelimiterToken) { + public Builder endDelimiterToken(String endDelimiterToken) { this.endDelimiterToken = endDelimiterToken; return this; } /** - * Sets the validation mode to control behavior when the provided variables do not - * match the variables required by the template. Default is - * {@link ValidationMode#THROW}. - * @param validationMode The desired validation mode. - * @return This builder instance for chaining. + * Sets the validation mode for missing variables. */ public Builder validationMode(ValidationMode validationMode) { this.validationMode = validationMode; @@ -236,17 +248,7 @@ public Builder validationMode(ValidationMode validationMode) { } /** - * Configures the renderer to support StringTemplate's built-in functions during - * validation. - *

- * When enabled (set to true), identifiers in the template that match known ST - * function names (e.g., "first", "rest", "length") will not be treated as - * required input variables during validation. - *

- * When disabled (default, false), these identifiers are treated like regular - * variables and must be provided in the input map if validation is enabled - * ({@link ValidationMode#WARN} or {@link ValidationMode#THROW}). - * @return This builder instance for chaining. + * Enables validation of ST built-in functions (treats them as variables). */ public Builder validateStFunctions() { this.validateStFunctions = true; @@ -254,13 +256,10 @@ public Builder validateStFunctions() { } /** - * Builds and returns a new {@link StTemplateRenderer} instance with the - * configured settings. - * @return A configured {@link StTemplateRenderer}. + * Builds the configured StTemplateRenderer instance. */ public StTemplateRenderer build() { - return new StTemplateRenderer(this.startDelimiterToken, this.endDelimiterToken, this.validationMode, - this.validateStFunctions); + return new StTemplateRenderer(startDelimiterToken, endDelimiterToken, validationMode, validateStFunctions); } } diff --git a/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java b/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java index 1dd548c5c0e..818d1b6e4f2 100644 --- a/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java +++ b/spring-ai-template-st/src/test/java/org/springframework/ai/template/st/StTemplateRendererTests.java @@ -38,15 +38,15 @@ class StTemplateRendererTests { void shouldNotAcceptNullValidationMode() { assertThatThrownBy(() -> StTemplateRenderer.builder().validationMode(null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("validationMode cannot be null"); + .hasMessageContaining("validationMode must not be null"); } @Test void shouldUseDefaultValuesWhenUsingBuilder() { StTemplateRenderer renderer = StTemplateRenderer.builder().build(); - assertThat(ReflectionTestUtils.getField(renderer, "startDelimiterToken")).isEqualTo('{'); - assertThat(ReflectionTestUtils.getField(renderer, "endDelimiterToken")).isEqualTo('}'); + assertThat(ReflectionTestUtils.getField(renderer, "startDelimiterToken")).isEqualTo("{"); + assertThat(ReflectionTestUtils.getField(renderer, "endDelimiterToken")).isEqualTo("}"); assertThat(ReflectionTestUtils.getField(renderer, "validationMode")).isEqualTo(ValidationMode.THROW); } @@ -80,14 +80,14 @@ void shouldNotRenderEmptyTemplate() { Map variables = new HashMap<>(); assertThatThrownBy(() -> renderer.apply("", variables)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("template cannot be null or empty"); + .hasMessageContaining("template must not be null or empty"); } @Test void shouldNotAcceptNullVariables() { StTemplateRenderer renderer = StTemplateRenderer.builder().build(); assertThatThrownBy(() -> renderer.apply("Hello!", null)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("variables cannot be null"); + .hasMessageContaining("variables must not be null"); } @Test @@ -98,7 +98,7 @@ void shouldNotAcceptVariablesWithNullKeySet() { variables.put(null, "Spring AI"); assertThatThrownBy(() -> renderer.apply(template, variables)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("variables keys cannot be null"); + .hasMessageContaining("variables keys must not contain null"); } @Test @@ -108,7 +108,7 @@ void shouldThrowExceptionForInvalidTemplateSyntax() { variables.put("name", "Spring AI"); assertThatThrownBy(() -> renderer.apply("Hello {name!", variables)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("The template string is not valid."); + .hasMessageContaining("Failed to render template"); } @Test @@ -148,8 +148,8 @@ void shouldRenderWithoutValidationInNoneMode() { @Test void shouldRenderWithCustomDelimiters() { StTemplateRenderer renderer = StTemplateRenderer.builder() - .startDelimiterToken('<') - .endDelimiterToken('>') + .startDelimiterToken("<") + .endDelimiterToken(">") .build(); Map variables = new HashMap<>(); variables.put("name", "Spring AI"); @@ -159,11 +159,40 @@ void shouldRenderWithCustomDelimiters() { assertThat(result).isEqualTo("Hello Spring AI!"); } + @Test + void shouldRenderWithDoubleAngleBracketDelimiters() { + StTemplateRenderer renderer = StTemplateRenderer.builder() + .startDelimiterToken("<<") + .endDelimiterToken(">>") + .build(); + + Map variables = new HashMap<>(); + variables.put("name", "Spring AI"); + + String result = renderer.apply("Hello <>!", variables); + + assertThat(result).isEqualTo("Hello Spring AI!"); + } + + @Test + void shouldHandleDoubleCurlyBracesAsDelimiters() { + StTemplateRenderer renderer = StTemplateRenderer.builder() + .startDelimiterToken("{{") + .endDelimiterToken("}}") + .build(); + Map variables = new HashMap<>(); + variables.put("name", "Spring AI"); + + String result = renderer.apply("Hello {{name}}!", variables); + + assertThat(result).isEqualTo("Hello Spring AI!"); + } + @Test void shouldHandleSpecialCharactersAsDelimiters() { StTemplateRenderer renderer = StTemplateRenderer.builder() - .startDelimiterToken('$') - .endDelimiterToken('$') + .startDelimiterToken("$") + .endDelimiterToken("$") .build(); Map variables = new HashMap<>(); variables.put("name", "Spring AI"); @@ -287,9 +316,11 @@ void shouldHandleObjectVariables() { */ @Test void shouldRenderTemplateWithBuiltInFunctions() { - StTemplateRenderer renderer = StTemplateRenderer.builder().build(); + StTemplateRenderer renderer = StTemplateRenderer.builder().validationMode(ValidationMode.THROW).build(); + Map variables = new HashMap<>(); variables.put("memory", "you are a helpful assistant"); + String template = "{if(strlen(memory))}Hello!{endif}"; String result = renderer.apply(template, variables);