Skip to content

GH-2518: Remove requestOptions from observation context objects #2896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -240,7 +239,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AnthropicApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -300,7 +299,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,6 +51,7 @@
* @author Mark Pollack
* @author Christian Tzolov
* @author Thomas Vitale
* @author Soby Chacko
* @since 1.0.0
*/
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
Expand Down Expand Up @@ -124,12 +125,15 @@ public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
.from(this.defaultOptions)
.merge(embeddingRequest.getOptions())
.build();
EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequest.getInstructions());

EmbeddingRequest embeddingRequestWithMergedOptions = new EmbeddingRequest(embeddingRequest.getInstructions(),
options);

EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequestWithMergedOptions.getInstructions());

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(embeddingRequest)
.embeddingRequest(embeddingRequestWithMergedOptions)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(options)
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.BEDROCK_CONVERSE.value())
.requestOptions(prompt.getOptions())
.build();

ChatResponse chatResponse = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -647,7 +646,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(AiProvider.BEDROCK_CONVERSE.value())
.requestOptions(prompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,6 @@ public ChatResponse call(Prompt prompt) {
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(requestPrompt)
.provider(MiniMaxApiConstants.PROVIDER_NAME)
.requestOptions(requestPrompt.getOptions())
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -334,7 +333,6 @@ public Flux<ChatResponse> stream(Prompt prompt) {
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(requestPrompt)
.provider(MiniMaxApiConstants.PROVIDER_NAME)
.requestOptions(requestPrompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,13 +43,15 @@
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* MiniMax Embedding Model implementation.
*
* @author Geng Rong
* @author Thomas Vitale
* @since 1.0.0 M1
* @author Soby Chacko
* @since 1.0.0
*/
public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {

Expand Down Expand Up @@ -149,14 +151,15 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
MiniMaxEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);

EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);

MiniMaxApi.EmbeddingRequest apiRequest = new MiniMaxApi.EmbeddingRequest(request.getInstructions(),
requestOptions.getModel());
embeddingRequest.getOptions().getModel());

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(MiniMaxApiConstants.PROVIDER_NAME)
.requestOptions(requestOptions)
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down Expand Up @@ -188,26 +191,24 @@ private DefaultUsage getDefaultUsage(MiniMaxApi.EmbeddingList apiEmbeddingList)
return new DefaultUsage(0, 0, apiEmbeddingList.totalTokens());
}

/**
* Merge runtime and default {@link EmbeddingOptions} to compute the final options to
* use in the request.
*/
private MiniMaxEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions,
MiniMaxEmbeddingOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class,
EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
// Process runtime options
MiniMaxEmbeddingOptions runtimeOptions = null;
if (embeddingRequest.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
MiniMaxEmbeddingOptions.class);
}

// Define request options by merging runtime options and default options
MiniMaxEmbeddingOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
MiniMaxEmbeddingOptions.class);

var optionBuilder = MiniMaxEmbeddingOptions.builder();
if (runtimeOptionsForProvider != null && runtimeOptionsForProvider.getModel() != null) {
optionBuilder.model(runtimeOptionsForProvider.getModel());
}
else if (defaultOptions.getModel() != null) {
optionBuilder.model(defaultOptions.getModel());
// Validate request options
if (!StringUtils.hasText(requestOptions.getModel())) {
throw new IllegalArgumentException("model cannot be null or empty");
}
else {
optionBuilder.model(MiniMaxApi.DEFAULT_EMBEDDING_MODEL);
}
return optionBuilder.build();

return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
}

public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,6 +29,7 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
Expand Down Expand Up @@ -57,6 +58,7 @@

/**
* @author Geng Rong
* @author Soby Chacko
*/
@SuppressWarnings("unchecked")
@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -150,8 +152,9 @@ public void miniMaxEmbeddingTransientError() {
.willThrow(new TransientAiException("Transient Error 2"))
.willReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));

EmbeddingOptions options = MiniMaxEmbeddingOptions.builder().model("model").build();
var result = this.embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
Expand All @@ -163,8 +166,9 @@ public void miniMaxEmbeddingTransientError() {
public void miniMaxEmbeddingNonTransientError() {
given(this.miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
.willThrow(new RuntimeException("Non Transient Error"));
EmbeddingOptions options = MiniMaxEmbeddingOptions.builder().model("model").build();
assertThrows(RuntimeException.class, () -> this.embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), options)));
}

private class TestRetryListener implements RetryListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MistralAiApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -260,7 +259,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(MistralAiApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
var apiRequest = createRequest(embeddingRequest);

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.embeddingRequest(embeddingRequest)
.provider(MistralAiApi.PROVIDER_NAME)
.requestOptions(embeddingRequest.getOptions())
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
Expand Down Expand Up @@ -83,13 +84,15 @@ public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions option
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), this.options);
List<EmbedTextRequest> embedTextRequests = createRequests(request.getInstructions(), runtimeOptions);

EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);

List<EmbedTextRequest> embedTextRequests = createRequests(embeddingRequest.getInstructions(),
(OCIEmbeddingOptions) embeddingRequest.getOptions());

EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.embeddingRequest(embeddingRequest)
.provider(AiProvider.OCI_GENAI.value())
.requestOptions(runtimeOptions)
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down Expand Up @@ -158,6 +161,26 @@ private OCIEmbeddingOptions mergeOptions(EmbeddingOptions embeddingOptions, OCIE
return defaultOptions;
}

/**
 * Resolve the effective options for an embedding call by merging the per-request
 * {@link EmbeddingOptions} (highest precedence) with this model's configured
 * {@code options}, returning a new {@link EmbeddingRequest} that carries the result.
 * @param embeddingRequest the incoming request, whose options may be {@code null}
 * @return a request with the same instructions and the fully merged options
 * @throws IllegalArgumentException if no model name is present after merging
 */
EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
	// Convert any runtime options to the provider-specific type so they can
	// override the configured defaults field by field.
	OCIEmbeddingOptions perRequestOptions = (embeddingRequest.getOptions() == null) ? null
			: ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
					OCIEmbeddingOptions.class);

	OCIEmbeddingOptions mergedOptions = ModelOptionsUtils.merge(perRequestOptions, this.options,
			OCIEmbeddingOptions.class);

	// A model must be resolvable from either the request or the defaults.
	if (!StringUtils.hasText(mergedOptions.getModel())) {
		throw new IllegalArgumentException("model cannot be null or empty");
	}

	return new EmbeddingRequest(embeddingRequest.getInstructions(), mergedOptions);
}

private float[] toFloats(List<Float> embedding) {
float[] floats = new float[embedding.size()];
for (int i = 0; i < embedding.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.oci.ServingModeHelper;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -104,10 +105,10 @@ public OCICohereChatModel(GenerativeAiInference genAi, OCICohereChatOptions opti

@Override
public ChatResponse call(Prompt prompt) {
Prompt requestPrompt = this.buildRequestPrompt(prompt);
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.prompt(requestPrompt)
.provider(AiProvider.OCI_GENAI.value())
.requestOptions(prompt.getOptions() != null ? prompt.getOptions() : this.defaultOptions)
.build();

return ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand All @@ -120,6 +121,21 @@ public ChatResponse call(Prompt prompt) {
});
}

/**
 * Build the prompt actually sent to the model: per-call {@link ChatOptions}
 * (highest precedence) are merged with this model's default options, and a new
 * {@link Prompt} carrying the merged options is returned.
 * @param prompt the incoming prompt, whose options may be {@code null}
 * @return a prompt with the same instructions and the fully merged options
 */
Prompt buildRequestPrompt(Prompt prompt) {
	// Convert any per-call options to the provider-specific type so they can
	// override the configured defaults field by field.
	OCICohereChatOptions perCallOptions = (prompt.getOptions() == null) ? null
			: ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, OCICohereChatOptions.class);

	OCICohereChatOptions mergedOptions = ModelOptionsUtils.merge(perCallOptions, this.defaultOptions,
			OCICohereChatOptions.class);

	return new Prompt(prompt.getInstructions(), mergedOptions);
}

@Override
public ChatOptions getDefaultOptions() {
return OCICohereChatOptions.fromOptions(this.defaultOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OllamaApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -296,7 +295,6 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCh
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OllamaApi.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ public EmbeddingResponse call(EmbeddingRequest request) {
var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(OllamaApi.PROVIDER_NAME)
.requestOptions(embeddingRequest.getOptions())
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
Expand Down Expand Up @@ -289,7 +288,6 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.requestOptions(prompt.getOptions())
.build();

Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
Expand Down
Loading