From ffa9cbe22efbff58ae9f375b38a4a5db3da536d6 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Fri, 25 Apr 2025 21:56:09 -0400 Subject: [PATCH] GH-2518: Remove requestOptions from observation context objects Fixes: #2518 Issue: https://github.com/spring-projects/spring-ai/issues/2518 This commit removes the deprecated requestOptions field from ChatModelObservationContext and EmbeddingModelObservationContext classes. Instead of passing options separately, the code now retrieves them directly from the request objects (prompt.getOptions() or embeddingRequest.getOptions()). Key changes: - Removed requestOptions parameter from observation context builders - Updated all model implementations to stop passing options separately - Fixed EmbeddingRequest handling in several model implementations - Added buildEmbeddingRequest method in models to properly merge options This change simplifies the API and removes duplication, as options are already available in the request objects themselves. Signed-off-by: Soby Chacko --- .../ai/anthropic/AnthropicChatModel.java | 2 - .../ai/azure/openai/AzureOpenAiChatModel.java | 2 - .../openai/AzureOpenAiEmbeddingModel.java | 12 +++-- .../converse/BedrockProxyChatModel.java | 2 - .../ai/minimax/MiniMaxChatModel.java | 2 - .../ai/minimax/MiniMaxEmbeddingModel.java | 45 ++++++++-------- .../ai/minimax/api/MiniMaxRetryTests.java | 10 ++-- .../ai/mistralai/MistralAiChatModel.java | 2 - .../ai/mistralai/MistralAiEmbeddingModel.java | 3 +- .../ai/oci/OCIEmbeddingModel.java | 33 ++++++++++-- .../ai/oci/cohere/OCICohereChatModel.java | 20 ++++++- .../ai/ollama/OllamaChatModel.java | 2 - .../ai/ollama/OllamaEmbeddingModel.java | 1 - .../ai/openai/OpenAiChatModel.java | 2 - .../ai/openai/OpenAiEmbeddingModel.java | 3 +- .../TransformersEmbeddingModel.java | 1 - .../text/VertexAiTextEmbeddingModel.java | 35 +++++++----- .../text/VertexAiTextEmbeddingRetryTests.java | 7 ++- .../gemini/VertexAiGeminiChatModel.java | 2 - .../ai/zhipuai/ZhiPuAiChatModel.java | 2 - .../ai/zhipuai/ZhiPuAiEmbeddingModel.java | 43 +++++++-------- .../ai/zhipuai/api/ZhiPuAiRetryTests.java | 8 +-- .../ChatModelObservationContext.java | 28 +--------- ...DefaultChatModelObservationConvention.java | 54 +++++++++++-------- ...ltEmbeddingModelObservationConvention.java | 15 +++--- .../EmbeddingModelObservationContext.java | 33 ++---------- ...ModelCompletionObservationFilterTests.java | 15 +++--- ...odelCompletionObservationHandlerTests.java | 6 +-- ...ChatModelMeterObservationHandlerTests.java | 9 ++-- .../ChatModelObservationContextTests.java | 19 ++----- ...elPromptContentObservationFilterTests.java | 15 +++--- ...lPromptContentObservationHandlerTests.java | 6 +-- ...ltChatModelObservationConventionTests.java | 31 +++++------ ...eddingModelObservationConventionTests.java | 28 +++++----- ...dingModelMeterObservationHandlerTests.java | 10 ++-- ...EmbeddingModelObservationContextTests.java | 21 +++----- 36 files changed, 248 insertions(+), 281 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 72e327ed623..f36a0eb11db 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -177,7 +177,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AnthropicApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -240,7 +239,6 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AnthropicApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index e85f9e03342..bed72e982ae 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -245,7 +245,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.AZURE_OPENAI.value()) - .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -300,7 +299,6 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.AZURE_OPENAI.value()) - .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index a5f5b335781..c63ed598ba8 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,6 +51,7 @@ * @author Mark Pollack * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko * @since 1.0.0 */ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { @@ -124,12 +125,15 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) { .from(this.defaultOptions) .merge(embeddingRequest.getOptions()) .build(); - EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequest.getInstructions()); + + EmbeddingRequest embeddingRequestWithMergedOptions = new EmbeddingRequest(embeddingRequest.getInstructions(), + options); + + EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequestWithMergedOptions.getInstructions()); var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(embeddingRequest) + .embeddingRequest(embeddingRequestWithMergedOptions) .provider(AiProvider.AZURE_OPENAI.value()) - .requestOptions(options) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 1cb4b2e547f..0f4d136a140 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -220,7 +220,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.BEDROCK_CONVERSE.value()) - .requestOptions(prompt.getOptions()) .build(); ChatResponse chatResponse = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -647,7 +646,6 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(AiProvider.BEDROCK_CONVERSE.value()) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 740666389c0..e5a774cacf9 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -241,7 +241,6 @@ public ChatResponse call(Prompt prompt) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(requestPrompt) .provider(MiniMaxApiConstants.PROVIDER_NAME) - .requestOptions(requestPrompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -334,7 +333,6 @@ public Flux stream(Prompt prompt) { final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(requestPrompt) .provider(MiniMaxApiConstants.PROVIDER_NAME) - .requestOptions(requestPrompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java index fec3b0c310b..9a5983785f7 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,13 +43,15 @@ import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * MiniMax Embedding Model implementation. * * @author Geng Rong * @author Thomas Vitale - * @since 1.0.0 M1 + * @author Soby Chacko + * @since 1.0.0 */ public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel { @@ -149,14 +151,15 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - MiniMaxEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions); + + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + MiniMaxApi.EmbeddingRequest apiRequest = new MiniMaxApi.EmbeddingRequest(request.getInstructions(), - requestOptions.getModel()); + embeddingRequest.getOptions().getModel()); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(MiniMaxApiConstants.PROVIDER_NAME) - .requestOptions(requestOptions) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -188,26 +191,24 @@ private DefaultUsage getDefaultUsage(MiniMaxApi.EmbeddingList apiEmbeddingList) return new DefaultUsage(0, 0, apiEmbeddingList.totalTokens()); } - /** - * Merge runtime and default {@link EmbeddingOptions} to compute the final options to - * use in the request. - */ - private MiniMaxEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, - MiniMaxEmbeddingOptions defaultOptions) { - var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + MiniMaxEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + MiniMaxEmbeddingOptions.class); + } + + // Define request options by merging runtime options and default options + MiniMaxEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, MiniMaxEmbeddingOptions.class); - var optionBuilder = MiniMaxEmbeddingOptions.builder(); - if (runtimeOptionsForProvider != null && runtimeOptionsForProvider.getModel() != null) { - optionBuilder.model(runtimeOptionsForProvider.getModel()); - } - else if (defaultOptions.getModel() != null) { - optionBuilder.model(defaultOptions.getModel()); + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); } - else { - optionBuilder.model(MiniMaxApi.DEFAULT_EMBEDDING_MODEL); - } - return optionBuilder.build(); + + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java index 5e860f33a86..9f165e27c0d 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/api/MiniMaxRetryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.minimax.MiniMaxEmbeddingModel; @@ -57,6 +58,7 @@ /** * @author Geng Rong + * @author Soby Chacko */ @SuppressWarnings("unchecked") @ExtendWith(MockitoExtension.class) @@ -150,8 +152,9 @@ public void miniMaxEmbeddingTransientError() { .willThrow(new TransientAiException("Transient Error 2")) .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + EmbeddingOptions options = MiniMaxEmbeddingOptions.builder().model("model").build(); var result = this.embeddingModel - .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); @@ -163,8 +166,9 @@ public void miniMaxEmbeddingTransientError() { public void miniMaxEmbeddingNonTransientError() { given(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class))) .willThrow(new RuntimeException("Non Transient Error")); + EmbeddingOptions options = MiniMaxEmbeddingOptions.builder().model("model").build(); assertThrows(RuntimeException.class, () -> this.embeddingModel - .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options))); } private class TestRetryListener implements RetryListener { diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 7b8a3ee9136..738e2640dcf 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -187,7 +187,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -260,7 +259,6 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 0908f57c534..347128f0ea0 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -116,9 +116,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { var apiRequest = createRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(request) + .embeddingRequest(embeddingRequest) .provider(MistralAiApi.PROVIDER_NAME) - .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java index 81b02107c79..8ed9299a7f7 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -43,6 +43,7 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the @@ -83,13 +84,15 @@ public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions option @Override public EmbeddingResponse call(EmbeddingRequest request) { Assert.notEmpty(request.getInstructions(), "At least one text is required!"); - OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), this.options); - List embedTextRequests = createRequests(request.getInstructions(), runtimeOptions); + + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + + List embedTextRequests = createRequests(embeddingRequest.getInstructions(), + (OCIEmbeddingOptions) embeddingRequest.getOptions()); EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder() - .embeddingRequest(request) + .embeddingRequest(embeddingRequest) .provider(AiProvider.OCI_GENAI.value()) - .requestOptions(runtimeOptions) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -158,6 +161,26 @@ private OCIEmbeddingOptions mergeOptions(EmbeddingOptions embeddingOptions, OCIE return defaultOptions; } + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + OCIEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + OCIEmbeddingOptions.class); + } + + // Define request options by merging runtime options and default options + OCIEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.options, + OCIEmbeddingOptions.class); + + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); + } + + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); + } + private float[] toFloats(List embedding) { float[] floats = new float[embedding.size()]; for (int i = 0; i < embedding.size(); i++) { diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java index d3a23b5d9d5..462a9773ab4 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatModel.java @@ -53,6 +53,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.oci.ServingModeHelper; import org.springframework.util.Assert; @@ -104,10 +105,10 @@ public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions opti @Override public ChatResponse call(Prompt prompt) { + Prompt requestPrompt = this.buildRequestPrompt(prompt); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(prompt) + .prompt(requestPrompt) .provider(AiProvider.OCI_GENAI.value()) - .requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions) .build(); return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -120,6 +121,21 @@ public ChatResponse call(Prompt prompt) { }); } + Prompt buildRequestPrompt(Prompt prompt) { + // Process runtime options + OCICohereChatOptions runtimeOptions = null; + if (prompt.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, + OCICohereChatOptions.class); + } + + // Define request options by merging runtime options and default options + OCICohereChatOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + OCICohereChatOptions.class); + + return new Prompt(prompt.getInstructions(), requestOptions); + } + @Override public ChatOptions getDefaultOptions() { return OCICohereChatOptions.fromOptions(this.defaultOptions); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index bc886108fc8..051dabde9b0 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -225,7 +225,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -296,7 +295,6 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java index da0408782e6..4a5710c9aed 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java @@ -113,7 +113,6 @@ public EmbeddingResponse call(EmbeddingRequest request) { var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(OllamaApi.PROVIDER_NAME) - .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 87636a18af2..5fbb8a283be 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -186,7 +186,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -289,7 +288,6 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 2ac56916fc4..47c06ac5a72 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -156,9 +156,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { OpenAiApi.EmbeddingRequest> apiRequest = createRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(request) + .embeddingRequest(embeddingRequest) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java index 5fd6f4a8bf9..a7324cad72b 100644 --- a/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java +++ b/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java @@ -289,7 +289,6 @@ public EmbeddingResponse call(EmbeddingRequest request) { var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(AiProvider.ONNX.value()) - .requestOptions(request.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 836137095c4..4bef9d1145b 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -35,6 +35,7 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.AbstractEmbeddingModel; import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; @@ -59,6 +60,7 @@ * @author Christian Tzolov * @author Mark Pollack * @author Rodrigo Malara + * @author Soby Chacko * @since 1.0.0 */ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel { @@ -117,12 +119,11 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - final VertexAiTextEmbeddingOptions finalOptions = mergedOptions(request); + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(request) + .embeddingRequest(embeddingRequest) .provider(AiProvider.VERTEX_AI.value()) - .requestOptions(finalOptions) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -131,10 +132,11 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observe(() -> { try (PredictionServiceClient client = createPredictionServiceClient()) { - EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); + EmbeddingOptions options = embeddingRequest.getOptions(); + EndpointName endpointName = this.connectionDetails.getEndpointName(options.getModel()); PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, - finalOptions); + (VertexAiTextEmbeddingOptions) options); PredictResponse embeddingResponse = this.retryTemplate .execute(context -> getPredictResponse(client, predictRequestBuilder)); @@ -155,7 +157,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { embeddingList.add(new Embedding(vectorValues, index++)); } EmbeddingResponse response = new EmbeddingResponse(embeddingList, - generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); + generateResponseMetadata(options.getModel(), totalTokenCount)); observationContext.setResponse(response); @@ -164,17 +166,24 @@ public EmbeddingResponse call(EmbeddingRequest request) { }); } - private VertexAiTextEmbeddingOptions mergedOptions(EmbeddingRequest request) { + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + VertexAiTextEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + VertexAiTextEmbeddingOptions.class); + } - VertexAiTextEmbeddingOptions mergedOptions = this.defaultOptions; + // Define request options by merging runtime options and default options + VertexAiTextEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, + VertexAiTextEmbeddingOptions.class); - if (request.getOptions() != null) { - var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build(); - mergedOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy, - VertexAiTextEmbeddingOptions.class); + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); } - return mergedOptions; + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName, diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java index 3430791d52a..088c87bc75a 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingRetryTests.java @@ -30,6 +30,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.retry.RetryUtils; @@ -116,7 +117,8 @@ public void vertexAiEmbeddingTransientError() { .willThrow(new TransientAiException("Transient Error 2")) .willReturn(mockResponse); - EmbeddingResponse result = this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null)); + EmbeddingOptions options = VertexAiTextEmbeddingOptions.builder().model("model").build(); + EmbeddingResponse result = this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), options)); assertThat(result).isNotNull(); assertThat(result.getResults()).hasSize(1); @@ -132,8 +134,9 @@ public void vertexAiEmbeddingNonTransientError() { // Setup the mock PredictionServiceClient to throw a non-transient error given(this.mockPredictionServiceClient.predict(any())).willThrow(new RuntimeException("Non Transient Error")); + EmbeddingOptions options = VertexAiTextEmbeddingOptions.builder().model("model").build(); // Assert that a RuntimeException is thrown and not retried - assertThatThrownBy(() -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), null))) + assertThatThrownBy(() -> this.embeddingModel.call(new EmbeddingRequest(List.of("text1", "text2"), options))) .isInstanceOf(RuntimeException.class); // Verify that predict was called only once (no retries for non-transient errors) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 0ecd838480b..36fcfd649e6 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -366,7 +366,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(VertexAiGeminiConstants.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -478,7 +477,6 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) .provider(VertexAiGeminiConstants.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 57a9527ca6f..408666fdc34 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -242,7 +242,6 @@ public ChatResponse call(Prompt prompt) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(requestPrompt) .provider(ZhiPuApiConstants.PROVIDER_NAME) - .requestOptions(prompt.getOptions()) .build(); ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION @@ -319,7 +318,6 @@ public Flux stream(Prompt prompt) { final ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(requestPrompt) .provider(ZhiPuApiConstants.PROVIDER_NAME) - .requestOptions(buildRequestOptions(request)) .build(); Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation( diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index f310f65cc97..f8c3f620529 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,15 +41,16 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * ZhiPuAI Embedding Model implementation. * * @author Geng Rong - * @since 1.0.0 M1 + * @author Soby Chacko + * @since 1.0.0 */ public class ZhiPuAiEmbeddingModel extends AbstractEmbeddingModel { @@ -153,12 +154,12 @@ public EmbeddingResponse call(EmbeddingRequest request) { logger.warn( "ZhiPu Embedding does not support batch embedding. Will make multiple API calls to embed(Document)"); } - ZhiPuAiEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions); + + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(request) + .embeddingRequest(embeddingRequest) .provider(ZhiPuApiConstants.PROVIDER_NAME) - .requestOptions(requestOptions) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -170,7 +171,7 @@ public EmbeddingResponse call(EmbeddingRequest request) { var totalUsage = new ZhiPuAiApi.Usage(0, 0, 0); for (String inputContent : request.getInstructions()) { - var apiRequest = createEmbeddingRequest(inputContent, requestOptions); + var apiRequest = createEmbeddingRequest(inputContent, embeddingRequest.getOptions()); ZhiPuAiApi.EmbeddingList response = this.retryTemplate .execute(ctx -> this.zhiPuAiApi.embeddings(apiRequest).getBody()); @@ -210,24 +211,24 @@ private DefaultUsage getDefaultUsage(ZhiPuAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } - /** - * Merge runtime and default {@link EmbeddingOptions} to compute the final options to - * use in the request. - */ - private ZhiPuAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, - ZhiPuAiEmbeddingOptions defaultOptions) { - var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, + EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + ZhiPuAiEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + ZhiPuAiEmbeddingOptions.class); + } + + // Define request options by merging runtime options and default options + ZhiPuAiEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, ZhiPuAiEmbeddingOptions.class); - if (runtimeOptionsForProvider == null) { - return defaultOptions; + // Validate request options + if (!StringUtils.hasText(requestOptions.getModel())) { + throw new IllegalArgumentException("model cannot be null or empty"); } - return ZhiPuAiEmbeddingOptions.builder() - .model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel())) - .dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(), - defaultOptions.getDimensions())) - .build(); + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } private ZhiPuAiApi.EmbeddingRequest createEmbeddingRequest(String text, EmbeddingOptions requestOptions) { diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java index 3ef3225e69d..b78db162096 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiRetryTests.java @@ -29,6 +29,7 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.image.ImageMessage; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.retry.RetryUtils; @@ -164,9 +165,9 @@ public void zhiPuAiEmbeddingTransientError() { .willThrow(new TransientAiException("Transient Error 1")) .willThrow(new TransientAiException("Transient Error 2")) .willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); - + EmbeddingOptions options = ZhiPuAiEmbeddingOptions.builder().model("model").build(); var result = this.embeddingModel - .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options)); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f }); @@ -178,8 +179,9 @@ public void zhiPuAiEmbeddingTransientError() { public void zhiPuAiEmbeddingNonTransientError() { given(this.zhiPuAiApi.embeddings(isA(EmbeddingRequest.class))) .willThrow(new RuntimeException("Non Transient Error")); + EmbeddingOptions options = ZhiPuAiEmbeddingOptions.builder().model("model").build(); assertThrows(RuntimeException.class, () -> this.embeddingModel - .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options))); } @Test diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java index 64689201abb..819edec419f 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/ChatModelObservationContext.java @@ -32,35 +32,21 @@ */ public class ChatModelObservationContext extends ModelObservationContext { - private final ChatOptions requestOptions; - - ChatModelObservationContext(Prompt prompt, String provider, ChatOptions requestOptions) { + ChatModelObservationContext(Prompt prompt, String provider) { super(prompt, AiOperationMetadata.builder().operationType(AiOperationType.CHAT.value()).provider(provider).build()); - Assert.notNull(requestOptions, "requestOptions cannot be null"); - this.requestOptions = requestOptions; } public static Builder builder() { return new Builder(); } - /** - * @deprecated Use {@link #getRequest().getOptions()} instead. - */ - @Deprecated(forRemoval = true) - public ChatOptions getRequestOptions() { - return this.requestOptions; - } - public static final class Builder { private Prompt prompt; private String provider; - private ChatOptions requestOptions; - private Builder() { } @@ -74,18 +60,8 @@ public Builder provider(String provider) { return this; } - /** - * @deprecated ChatOptions are passed in the Prompt object and should not be set - * separately anymore. - */ - @Deprecated(forRemoval = true) - public Builder requestOptions(ChatOptions requestOptions) { - this.requestOptions = requestOptions; - return this; - } - public ChatModelObservationContext build() { - return new ChatModelObservationContext(this.prompt, this.provider, this.requestOptions); + return new ChatModelObservationContext(this.prompt, this.provider); } } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java index 0a2b16bb19e..8ebb8c2b0ab 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConvention.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import io.micrometer.common.KeyValue; import io.micrometer.common.KeyValues; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -29,6 +30,7 @@ * Default conventions to populate observations for chat model operations. * * @author Thomas Vitale + * @author Soby Chacko * @since 1.0.0 */ public class DefaultChatModelObservationConvention implements ChatModelObservationConvention { @@ -48,9 +50,9 @@ public String getName() { @Override public String getContextualName(ChatModelObservationContext context) { - if (StringUtils.hasText(context.getRequestOptions().getModel())) { - return "%s %s".formatted(context.getOperationMetadata().operationType(), - context.getRequestOptions().getModel()); + ChatOptions options = context.getRequest().getOptions(); + if (StringUtils.hasText(options.getModel())) { + return "%s %s".formatted(context.getOperationMetadata().operationType(), options.getModel()); } return context.getOperationMetadata().operationType(); } @@ -72,9 +74,10 @@ protected KeyValue aiProvider(ChatModelObservationContext context) { } protected KeyValue requestModel(ChatModelObservationContext context) { - if (StringUtils.hasText(context.getRequestOptions().getModel())) { + ChatOptions options = context.getRequest().getOptions(); + if (StringUtils.hasText(options.getModel())) { return KeyValue.of(ChatModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, - context.getRequestOptions().getModel()); + options.getModel()); } return REQUEST_MODEL_NONE; } @@ -111,40 +114,42 @@ public KeyValues getHighCardinalityKeyValues(ChatModelObservationContext context // Request protected KeyValues requestFrequencyPenalty(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getFrequencyPenalty() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getFrequencyPenalty() != null) { return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_FREQUENCY_PENALTY.asString(), - String.valueOf(context.getRequestOptions().getFrequencyPenalty())); + String.valueOf(options.getFrequencyPenalty())); } return keyValues; } protected KeyValues requestMaxTokens(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getMaxTokens() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getMaxTokens() != null) { return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_MAX_TOKENS.asString(), - String.valueOf(context.getRequestOptions().getMaxTokens())); + String.valueOf(options.getMaxTokens())); } return keyValues; } protected KeyValues requestPresencePenalty(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getPresencePenalty() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getPresencePenalty() != null) { return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_PRESENCE_PENALTY.asString(), - String.valueOf(context.getRequestOptions().getPresencePenalty())); + String.valueOf(options.getPresencePenalty())); } return keyValues; } protected KeyValues requestStopSequences(KeyValues keyValues, ChatModelObservationContext context) { - if (!CollectionUtils.isEmpty(context.getRequestOptions().getStopSequences())) { + ChatOptions options = context.getRequest().getOptions(); + if (!CollectionUtils.isEmpty(options.getStopSequences())) { StringJoiner stopSequencesJoiner = new StringJoiner(", ", "[", "]"); - context.getRequestOptions() - .getStopSequences() - .forEach(value -> stopSequencesJoiner.add("\"" + value + "\"")); + options.getStopSequences().forEach(value -> stopSequencesJoiner.add("\"" + value + "\"")); KeyValue.of(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES, - context.getRequestOptions().getStopSequences(), Objects::nonNull); + options.getStopSequences(), Objects::nonNull); return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_STOP_SEQUENCES.asString(), stopSequencesJoiner.toString()); @@ -153,26 +158,29 @@ protected KeyValues requestStopSequences(KeyValues keyValues, ChatModelObservati } protected KeyValues requestTemperature(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getTemperature() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getTemperature() != null) { return keyValues.and( ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TEMPERATURE.asString(), - String.valueOf(context.getRequestOptions().getTemperature())); + String.valueOf(options.getTemperature())); } return keyValues; } protected KeyValues requestTopK(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getTopK() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getTopK() != null) { return keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_K.asString(), - String.valueOf(context.getRequestOptions().getTopK())); + String.valueOf(options.getTopK())); } return keyValues; } protected KeyValues requestTopP(KeyValues keyValues, ChatModelObservationContext context) { - if (context.getRequestOptions().getTopP() != null) { + ChatOptions options = context.getRequest().getOptions(); + if (options.getTopP() != null) { return keyValues.and(ChatModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_TOP_P.asString(), - String.valueOf(context.getRequestOptions().getTopP())); + String.valueOf(options.getTopP())); } return keyValues; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java b/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java index 6949f0e000d..97d5146a66a 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConvention.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ * Default conventions to populate observations for embedding model operations. * * @author Thomas Vitale + * @author Soby Chacko * @since 1.0.0 */ public class DefaultEmbeddingModelObservationConvention implements EmbeddingModelObservationConvention { @@ -44,9 +45,9 @@ public String getName() { @Override public String getContextualName(EmbeddingModelObservationContext context) { - if (StringUtils.hasText(context.getRequestOptions().getModel())) { + if (StringUtils.hasText(context.getRequest().getOptions().getModel())) { return "%s %s".formatted(context.getOperationMetadata().operationType(), - context.getRequestOptions().getModel()); + context.getRequest().getOptions().getModel()); } return context.getOperationMetadata().operationType(); } @@ -68,9 +69,9 @@ protected KeyValue aiProvider(EmbeddingModelObservationContext context) { } protected KeyValue requestModel(EmbeddingModelObservationContext context) { - if (StringUtils.hasText(context.getRequestOptions().getModel())) { + if (StringUtils.hasText(context.getRequest().getOptions().getModel())) { return KeyValue.of(EmbeddingModelObservationDocumentation.LowCardinalityKeyNames.REQUEST_MODEL, - context.getRequestOptions().getModel()); + context.getRequest().getOptions().getModel()); } return REQUEST_MODEL_NONE; } @@ -98,10 +99,10 @@ public KeyValues getHighCardinalityKeyValues(EmbeddingModelObservationContext co // Request protected KeyValues requestEmbeddingDimension(KeyValues keyValues, EmbeddingModelObservationContext context) { - if (context.getRequestOptions().getDimensions() != null) { + if (context.getRequest().getOptions().getDimensions() != null) { return keyValues .and(EmbeddingModelObservationDocumentation.HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS - .asString(), String.valueOf(context.getRequestOptions().getDimensions())); + .asString(), String.valueOf(context.getRequest().getOptions().getDimensions())); } return keyValues; } diff --git a/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java b/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java index 07bc7b0edcd..bc35c72533b 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,49 +22,34 @@ import org.springframework.ai.model.observation.ModelObservationContext; import org.springframework.ai.observation.AiOperationMetadata; import org.springframework.ai.observation.conventions.AiOperationType; -import org.springframework.util.Assert; /** * Context used to store metadata for embedding model exchanges. * * @author Thomas Vitale + * @author Soby Chacko * @since 1.0.0 */ public class EmbeddingModelObservationContext extends ModelObservationContext { - private final EmbeddingOptions requestOptions; - - EmbeddingModelObservationContext(EmbeddingRequest embeddingRequest, String provider, - EmbeddingOptions requestOptions) { + EmbeddingModelObservationContext(EmbeddingRequest embeddingRequest, String provider) { super(embeddingRequest, AiOperationMetadata.builder() .operationType(AiOperationType.EMBEDDING.value()) .provider(provider) .build()); - Assert.notNull(requestOptions, "requestOptions cannot be null"); - this.requestOptions = requestOptions; } public static Builder builder() { return new Builder(); } - /** - * @deprecated Use {@link #getRequest().getOptions()} instead. - */ - @Deprecated(forRemoval = true) - public EmbeddingOptions getRequestOptions() { - return this.requestOptions; - } - public static final class Builder { private EmbeddingRequest embeddingRequest; private String provider; - private EmbeddingOptions requestOptions; - private Builder() { } @@ -78,18 +63,8 @@ public Builder provider(String provider) { return this; } - /** - * @deprecated EmbeddingOptions are passed in the EmbeddingRequest object and - * should not be set separately anymore. - */ - @Deprecated(forRemoval = true) - public Builder requestOptions(EmbeddingOptions requestOptions) { - this.requestOptions = requestOptions; - return this; - } - public EmbeddingModelObservationContext build() { - return new EmbeddingModelObservationContext(this.embeddingRequest, this.provider, this.requestOptions); + return new EmbeddingModelObservationContext(this.embeddingRequest, this.provider); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java index a1cbda26f3f..8471b9e7309 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,9 +51,8 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { @Test void whenEmptyResponseThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var actualContext = this.observationFilter.map(expectedContext); @@ -63,9 +62,8 @@ void whenEmptyResponseThenReturnOriginalContext() { @Test void whenEmptyCompletionThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); expectedContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); var actualContext = this.observationFilter.map(expectedContext); @@ -76,9 +74,8 @@ void whenEmptyCompletionThenReturnOriginalContext() { @Test void whenCompletionWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); originalContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); @@ -88,8 +85,8 @@ void whenCompletionWithTextThenAugmentContext() { .of(HighCardinalityKeyNames.COMPLETION.asString(), "[\"say please\", \"seriously, say please\"]")); } - private Prompt generatePrompt() { - return new Prompt("supercalifragilisticexpialidocious"); + private Prompt generatePrompt(ChatOptions chatOptions) { + return new Prompt("supercalifragilisticexpialidocious", chatOptions); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java index 8e8218acf78..768f2d8d38a 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelCompletionObservationHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,9 +47,9 @@ class ChatModelCompletionObservationHandlerTests { @Test void whenCompletionWithTextThenSpanEvent() { var observationContext = ChatModelObservationContext.builder() - .prompt(new Prompt("supercalifragilisticexpialidocious")) + .prompt(new Prompt("supercalifragilisticexpialidocious", + ChatOptions.builder().model("spoonful-of-sugar").build())) .provider("mary-poppins") - .requestOptions(ChatOptions.builder().model("spoonful-of-sugar").build()) .build(); observationContext.setResponse(new ChatResponse(List.of(new Generation(new AssistantMessage("say please")), new Generation(new AssistantMessage("seriously, say please"))))); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java index 06acfb0d218..fa8e415d1af 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelMeterObservationHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,14 +95,13 @@ void shouldCreateAllMetersDuringAnObservation() { private ChatModelObservationContext generateObservationContext() { return ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); } - private Prompt generatePrompt() { - return new Prompt("hello"); + private Prompt generatePrompt(ChatOptions chatOptions) { + return new Prompt("hello", chatOptions); } static class TestUsage implements Usage { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java index 3e52f6a9fa9..37f3dc85836 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelObservationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,26 +34,15 @@ class ChatModelObservationContextTests { @Test void whenMandatoryRequestOptionsThenReturn() { var observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("supermodel").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("supermodel").build()) .build(); assertThat(observationContext).isNotNull(); } - @Test - void whenRequestOptionsIsNullThenThrow() { - assertThatThrownBy(() -> ChatModelObservationContext.builder() - .prompt(generatePrompt()) - .provider("superprovider") - .requestOptions(null) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("requestOptions cannot be null"); - } - - private Prompt generatePrompt() { - return new Prompt("hello"); + private Prompt generatePrompt(ChatOptions chatOptions) { + return new Prompt("hello", chatOptions); } } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java index c05dd3ef9aa..638f1320b00 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,9 +50,8 @@ void whenNotSupportedObservationContextThenReturnOriginalContext() { @Test void whenEmptyPromptThenReturnOriginalContext() { var expectedContext = ChatModelObservationContext.builder() - .prompt(new Prompt(List.of())) + .prompt(new Prompt(List.of(), ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var actualContext = this.observationFilter.map(expectedContext); @@ -62,9 +61,8 @@ void whenEmptyPromptThenReturnOriginalContext() { @Test void whenPromptWithTextThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() - .prompt(new Prompt("supercalifragilisticexpialidocious")) + .prompt(new Prompt("supercalifragilisticexpialidocious", ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); @@ -75,10 +73,11 @@ void whenPromptWithTextThenAugmentContext() { @Test void whenPromptWithMessagesThenAugmentContext() { var originalContext = ChatModelObservationContext.builder() - .prompt(new Prompt(List.of(new SystemMessage("you're a chimney sweep"), - new UserMessage("supercalifragilisticexpialidocious")))) + .prompt(new Prompt( + List.of(new SystemMessage("you're a chimney sweep"), + new UserMessage("supercalifragilisticexpialidocious")), + ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); var augmentedContext = this.observationFilter.map(originalContext); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java index ab90a855100..c76b524705a 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/ChatModelPromptContentObservationHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,9 +42,9 @@ class ChatModelPromptContentObservationHandlerTests { @Test void whenPromptWithTextThenSpanEvent() { var observationContext = ChatModelObservationContext.builder() - .prompt(new Prompt("supercalifragilisticexpialidocious")) + .prompt(new Prompt("supercalifragilisticexpialidocious", + ChatOptions.builder().model("spoonful-of-sugar").build())) .provider("mary-poppins") - .requestOptions(ChatOptions.builder().model("spoonful-of-sugar").build()) .build(); var sdkTracer = SdkTracerProvider.builder().build().get("test"); var otelTracer = new OtelTracer(sdkTracer, new OtelCurrentTraceContext(), null); diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java index c70e8b8a0cf..5629a1de463 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/observation/DefaultChatModelObservationConventionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,9 +55,8 @@ void shouldHaveName() { @Test void contextualNameWhenModelIsDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat mistral"); } @@ -65,9 +64,8 @@ void contextualNameWhenModelIsDefined() { @Test void contextualNameWhenModelIsNotDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("chat"); } @@ -75,9 +73,8 @@ void contextualNameWhenModelIsNotDefined() { @Test void supportsOnlyChatModelObservationContext() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); @@ -86,9 +83,8 @@ void supportsOnlyChatModelObservationContext() { @Test void shouldHaveLowCardinalityKeyValuesWhenDefined() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().model("mistral").build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().model("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), "chat"), @@ -99,9 +95,7 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { @Test void shouldHaveKeyValuesWhenDefinedAndResponse() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) - .provider("superprovider") - .requestOptions(ChatOptions.builder() + .prompt(generatePrompt(ChatOptions.builder() .model("mistral") .frequencyPenalty(0.8) .maxTokens(200) @@ -110,7 +104,8 @@ void shouldHaveKeyValuesWhenDefinedAndResponse() { .temperature(0.5) .topK(1) .topP(0.9) - .build()) + .build())) + .provider("superprovider") .build(); observationContext.setResponse(new ChatResponse( List.of(new Generation(new AssistantMessage("response"), @@ -136,9 +131,8 @@ void shouldHaveKeyValuesWhenDefinedAndResponse() { @Test void shouldNotHaveKeyValuesWhenMissing() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)) .contains(KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), KeyValue.NONE_VALUE)) @@ -162,9 +156,8 @@ void shouldNotHaveKeyValuesWhenMissing() { @Test void shouldNotHaveKeyValuesWhenEmptyValues() { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() - .prompt(generatePrompt()) + .prompt(generatePrompt(ChatOptions.builder().stopSequences(List.of()).build())) .provider("superprovider") - .requestOptions(ChatOptions.builder().stopSequences(List.of()).build()) .build(); observationContext.setResponse(new ChatResponse( List.of(new Generation(new AssistantMessage("response"), @@ -178,8 +171,8 @@ void shouldNotHaveKeyValuesWhenEmptyValues() { HighCardinalityKeyNames.RESPONSE_ID.asString()); } - private Prompt generatePrompt() { - return new Prompt("Who let the dogs out?"); + private Prompt generatePrompt(ChatOptions chatOptions) { + return new Prompt("Who let the dogs out?", chatOptions); } static class TestUsage implements Usage { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java index 7ed12ac161d..ba5c6467da6 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/DefaultEmbeddingModelObservationConventionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,11 @@ import io.micrometer.common.KeyValue; import io.micrometer.observation.Observation; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -52,9 +54,8 @@ void shouldHaveName() { @Test void contextualNameWhenModelIsDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("embedding mistral"); } @@ -62,9 +63,8 @@ void contextualNameWhenModelIsDefined() { @Test void contextualNameWhenModelIsNotDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().build()) .build(); assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("embedding"); } @@ -72,9 +72,9 @@ void contextualNameWhenModelIsNotDefined() { @Test void supportsOnlyEmbeddingModelObservationContext() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest( + generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("supermodel").build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()) .build(); assertThat(this.observationConvention.supportsContext(observationContext)).isTrue(); assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse(); @@ -83,9 +83,8 @@ void supportsOnlyEmbeddingModelObservationContext() { @Test void shouldHaveLowCardinalityKeyValuesWhenDefined() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains( KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), "embedding"), @@ -96,9 +95,9 @@ void shouldHaveLowCardinalityKeyValuesWhenDefined() { @Test void shouldHaveLowCardinalityKeyValuesWhenDefinedAndResponse() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest( + EmbeddingOptionsBuilder.builder().withModel("mistral").withDimensions(1492).build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").withDimensions(1492).build()) .build(); observationContext.setResponse(new EmbeddingResponse(List.of(), new EmbeddingResponseMetadata("mistral-42", new TestUsage(), Map.of()))); @@ -113,9 +112,8 @@ void shouldHaveLowCardinalityKeyValuesWhenDefinedAndResponse() { @Test void shouldNotHaveKeyValuesWhenMissing() { EmbeddingModelObservationContext observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().build()) .build(); assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)) .contains(KeyValue.of(LowCardinalityKeyNames.REQUEST_MODEL.asString(), KeyValue.NONE_VALUE)) @@ -128,8 +126,8 @@ void shouldNotHaveKeyValuesWhenMissing() { HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString()); } - private EmbeddingRequest generateEmbeddingRequest() { - return new EmbeddingRequest(List.of(), EmbeddingOptionsBuilder.builder().build()); + private EmbeddingRequest generateEmbeddingRequest(EmbeddingOptions embeddingOptions) { + return new EmbeddingRequest(List.of(), embeddingOptions); } static class TestUsage implements Usage { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java index b7b1c65f345..dada880a1a9 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelMeterObservationHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; @@ -92,14 +93,13 @@ void shouldCreateAllMetersDuringAnObservation() { private EmbeddingModelObservationContext generateObservationContext() { return EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest(generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("mistral").build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("mistral").build()) .build(); } - private EmbeddingRequest generateEmbeddingRequest() { - return new EmbeddingRequest(List.of(), EmbeddingOptionsBuilder.builder().build()); + private EmbeddingRequest generateEmbeddingRequest(EmbeddingOptions embeddingOptions) { + return new EmbeddingRequest(List.of(), embeddingOptions); } static class TestUsage implements Usage { diff --git a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java index 0678fe26ad4..780e881a43b 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/embedding/observation/EmbeddingModelObservationContextTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.embedding.EmbeddingOptions; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; @@ -36,26 +37,16 @@ class EmbeddingModelObservationContextTests { @Test void whenMandatoryRequestOptionsThenReturn() { var observationContext = EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) + .embeddingRequest( + generateEmbeddingRequest(EmbeddingOptionsBuilder.builder().withModel("supermodel").build())) .provider("superprovider") - .requestOptions(EmbeddingOptionsBuilder.builder().withModel("supermodel").build()) .build(); assertThat(observationContext).isNotNull(); } - @Test - void whenRequestOptionsIsNullThenThrow() { - assertThatThrownBy(() -> EmbeddingModelObservationContext.builder() - .embeddingRequest(generateEmbeddingRequest()) - .provider("superprovider") - .requestOptions(null) - .build()).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("requestOptions cannot be null"); - } - - private EmbeddingRequest generateEmbeddingRequest() { - return new EmbeddingRequest(List.of(), EmbeddingOptionsBuilder.builder().build()); + private EmbeddingRequest generateEmbeddingRequest(EmbeddingOptions embeddingOptions) { + return new EmbeddingRequest(List.of(), embeddingOptions); } }