From a4f0e75b1812fe67390314ffe7053e4f5bb8d55a Mon Sep 17 00:00:00 2001 From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com> Date: Mon, 5 Feb 2024 17:46:02 -0500 Subject: [PATCH] Set default adapter source (#223) --- launcher/src/main.rs | 15 +++++++++------ router/src/adapter.rs | 8 +++----- router/src/main.rs | 4 ++++ router/src/server.rs | 7 +++++++ 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e50f4a120..0f3c673ff 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -105,8 +105,8 @@ struct Args { /// or it can be a local directory containing the necessary files /// as saved by `save_pretrained(...)` methods of transformers. /// Should be compatible with the model specified in `model_id`. - #[clap(default_value = "", long, env)] - adapter_id: String, + #[clap(long, env)] + adapter_id: Option<String>, /// The source of the model to load. /// Can be `hub` or `s3`. @@ -115,7 +115,7 @@ struct Args { #[clap(default_value = "hub", long, env)] source: String, - /// The source of the model to load. + /// The source of the static adapter to load. /// Can be `hub` or `s3` or `pbase` /// `hub` will load the model from the huggingface hub. /// `s3` will load the model from the predibase S3 bucket. 
@@ -764,9 +764,10 @@ fn download_convert_model( download_args.push(revision.to_string()) } - if !args.adapter_id.is_empty() { + // check if option has a value + if let Some(adapter_id) = &args.adapter_id { download_args.push("--adapter-id".to_string()); - download_args.push(args.adapter_id.clone()); + download_args.push(adapter_id.to_string()); } // Copy current process env @@ -877,7 +878,7 @@ fn spawn_shards( // Start shard processes for rank in 0..num_shard { let model_id = args.model_id.clone(); - let adapter_id = args.adapter_id.clone(); + let adapter_id = args.adapter_id.clone().unwrap_or_default(); let revision = args.revision.clone(); let source: String = args.source.clone(); let adapter_source: String = args.adapter_source.clone(); @@ -996,6 +997,8 @@ fn spawn_webserver( format!("{}-0", args.shard_uds_path), "--tokenizer-name".to_string(), args.model_id, + "--adapter-source".to_string(), + args.adapter_source, ]; // Model optional max batch total tokens diff --git a/router/src/adapter.rs b/router/src/adapter.rs index 4f0e5d9a9..968e59dc2 100644 --- a/router/src/adapter.rs +++ b/router/src/adapter.rs @@ -2,15 +2,13 @@ use std::hash; use crate::AdapterParameters; +use crate::server::DEFAULT_ADAPTER_SOURCE; + /// "adapter ID" for the base model. The base model does not have an adapter ID, /// but we reason about it in the same way. This must match the base model ID /// used in the Python server. pub const BASE_MODEL_ADAPTER_ID: &str = "__base_model__"; -/// default adapter source. 
One TODO is to figure out how to do this -/// from within the proto definition, or lib.rs -pub const DEFAULT_ADAPTER_SOURCE: &str = "hub"; - #[derive(Debug, Clone)] pub(crate) struct Adapter { /// adapter parameters @@ -85,7 +83,7 @@ pub(crate) fn extract_adapter_params( } let mut adapter_source = adapter_source.clone(); if adapter_source.is_none() { - adapter_source = Some(DEFAULT_ADAPTER_SOURCE.to_string()); + adapter_source = Some(DEFAULT_ADAPTER_SOURCE.get().unwrap().to_string()); } let adapter_parameters = adapter_parameters.clone().unwrap_or(AdapterParameters { diff --git a/router/src/main.rs b/router/src/main.rs index 7bac183cf..158f67f95 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -75,6 +75,8 @@ struct Args { ngrok_authtoken: Option<String>, #[clap(long, env)] ngrok_edge: Option<String>, + #[clap(default_value = "hub", long, env)] + adapter_source: String, } fn main() -> Result<(), RouterError> { @@ -108,6 +110,7 @@ fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, + adapter_source, } = args; // Validate args @@ -323,6 +326,7 @@ fn main() -> Result<(), RouterError> { ngrok, ngrok_authtoken, ngrok_edge, + adapter_source, ) .await?; Ok(()) diff --git a/router/src/server.rs b/router/src/server.rs index 88400fb89..66f3fadb6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -34,6 +34,7 @@ use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; static MODEL_ID: OnceCell<String> = OnceCell::new(); +pub static DEFAULT_ADAPTER_SOURCE: OnceCell<String> = OnceCell::new(); /// Generate tokens if `stream == false` or a stream of token if `stream == true` #[utoipa::path( @@ -712,6 +713,7 @@ pub async fn run( ngrok: bool, ngrok_authtoken: Option<String>, ngrok_edge: Option<String>, + adapter_source: String, ) -> Result<(), axum::BoxError> { // OpenAPI documentation #[derive(OpenApi)] @@ -874,6 +876,11 @@ pub async fn run( MODEL_ID.set(model_id.clone()).unwrap_or_else(|_| { panic!("MODEL_ID was already set!"); }); + DEFAULT_ADAPTER_SOURCE + 
.set(adapter_source.clone()) + .unwrap_or_else(|_| { + panic!("DEFAULT_ADAPTER_SOURCE was already set!"); + }); // Create router let app = Router::new()