diff --git a/router/src/infer.rs b/router/src/infer.rs index 816eee56..a142793d 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -501,10 +501,12 @@ impl Infer { err })?; + let embed_params = request.parameters.unwrap_or_default(); + let (adapter_source, adapter_parameters) = extract_adapter_params( - request.parameters.adapter_id.clone(), - request.parameters.adapter_source.clone(), - request.parameters.adapter_parameters.clone(), + embed_params.adapter_id.clone(), + embed_params.adapter_source.clone(), + embed_params.adapter_parameters.clone(), ); let adapter_idx; @@ -520,7 +522,7 @@ impl Infer { } } - let api_token = request.parameters.api_token.clone(); + let api_token = embed_params.api_token.clone(); let adapter = Adapter::new( adapter_parameters, adapter_source.unwrap(), @@ -875,10 +877,12 @@ impl Infer { err })?; + let embed_params = request.parameters.clone().unwrap_or_default(); + let (adapter_source, adapter_parameters) = extract_adapter_params( - request.parameters.adapter_id.clone(), - request.parameters.adapter_source.clone(), - request.parameters.adapter_parameters.clone(), + embed_params.adapter_id.clone(), + embed_params.adapter_source.clone(), + embed_params.adapter_parameters.clone(), ); let adapter_idx; @@ -894,7 +898,7 @@ impl Infer { } } - let api_token = request.parameters.api_token.clone(); + let api_token = embed_params.api_token.clone(); let adapter = Adapter::new( adapter_parameters, adapter_source.unwrap(), diff --git a/router/src/lib.rs b/router/src/lib.rs index fd5b1049..c3cf2ced 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1144,20 +1144,22 @@ pub(crate) struct EmbedParameters { pub api_token: Option, } -fn default_embed_parameters() -> EmbedParameters { - EmbedParameters { - adapter_id: None, - adapter_source: None, - adapter_parameters: None, - api_token: None, +impl Default for EmbedParameters { + fn default() -> Self { + Self { + adapter_id: None, + adapter_source: None, + adapter_parameters: None, + api_token: None, + } } } #[derive(Clone, Debug, Deserialize, ToSchema)] struct EmbedRequest { inputs: String, - #[serde(default = "default_embed_parameters")] - pub parameters: EmbedParameters, + #[serde(default)] + pub parameters: Option, } #[derive(Serialize, ToSchema)] @@ -1192,8 +1194,8 @@ struct CompatEmbedRequest { dimensions: Option, #[allow(dead_code)] user: Option, - #[serde(default = "default_embed_parameters")] - parameters: EmbedParameters, + #[serde(default)] + parameters: Option, } #[derive(Serialize, ToSchema)] @@ -1221,8 +1223,8 @@ struct BatchClassifyRequest { #[derive(Clone, Debug, Deserialize, ToSchema)] struct BatchEmbedRequest { inputs: Vec, - #[serde(default = "default_embed_parameters")] - parameters: EmbedParameters, + #[serde(default)] + parameters: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/router/src/server.rs b/router/src/server.rs index e66b1d11..405bcff8 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -486,12 +486,12 @@ async fn health( if health.shard_info().supports_embeddings { let embed_request = EmbedRequest { inputs: "San Francisco".to_string(), - parameters: EmbedParameters { + parameters: Some(EmbedParameters { adapter_id: None, adapter_source: None, adapter_parameters: None, api_token: None, - }, + }), }; match infer.embed(embed_request).await { Ok(_) => {}