Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Dec 13, 2024
1 parent c7c56dd commit 9e288b4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
20 changes: 12 additions & 8 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,12 @@ impl Infer {
err
})?;

let embed_params = request.parameters.unwrap_or_default();

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
embed_params.adapter_id.clone(),
embed_params.adapter_source.clone(),
embed_params.adapter_parameters.clone(),
);

let adapter_idx;
Expand All @@ -520,7 +522,7 @@ impl Infer {
}
}

let api_token = request.parameters.api_token.clone();
let api_token = embed_params.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
Expand Down Expand Up @@ -875,10 +877,12 @@ impl Infer {
err
})?;

let embed_params = request.parameters.clone().unwrap_or_default();

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
embed_params.adapter_id.clone(),
embed_params.adapter_source.clone(),
embed_params.adapter_parameters.clone(),
);

let adapter_idx;
Expand All @@ -894,7 +898,7 @@ impl Infer {
}
}

let api_token = request.parameters.api_token.clone();
let api_token = embed_params.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
Expand Down
26 changes: 14 additions & 12 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1144,20 +1144,22 @@ pub(crate) struct EmbedParameters {
pub api_token: Option<String>,
}

fn default_embed_parameters() -> EmbedParameters {
EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
impl Default for EmbedParameters {
fn default() -> Self {
Self {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
}
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct EmbedRequest {
inputs: String,
#[serde(default = "default_embed_parameters")]
pub parameters: EmbedParameters,
#[serde(default)]
pub parameters: Option<EmbedParameters>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -1192,8 +1194,8 @@ struct CompatEmbedRequest {
dimensions: Option<i32>,
#[allow(dead_code)]
user: Option<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
#[serde(default)]
parameters: Option<EmbedParameters>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -1221,8 +1223,8 @@ struct BatchClassifyRequest {
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct BatchEmbedRequest {
inputs: Vec<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
#[serde(default)]
parameters: Option<EmbedParameters>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
4 changes: 2 additions & 2 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,12 @@ async fn health(
if health.shard_info().supports_embeddings {
let embed_request = EmbedRequest {
inputs: "San Francisco".to_string(),
parameters: EmbedParameters {
parameters: Some(EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
},
}),
};
match infer.embed(embed_request).await {
Ok(_) => {}
Expand Down

0 comments on commit 9e288b4

Please sign in to comment.