diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d36ec0690..b1420f49e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -210,6 +210,11 @@ struct Args { #[clap(long, env)] preloaded_adapter_source: Option, + /// The API token to use when fetching adapters from pbase. + /// If specified, will set the environment variable PREDIBASE_API_TOKEN. + #[clap(long, env)] + predibase_api_token: Option, + /// The dtype to be forced upon the model. This option cannot be used with `--quantize`. #[clap(long, env, value_enum)] dtype: Option, @@ -461,6 +466,7 @@ fn shard_manager( speculative_tokens: Option, preloaded_adapter_ids: Vec, preloaded_adapter_source: Option, + predibase_api_token: Option, dtype: Option, trust_remote_code: bool, uds_path: String, @@ -493,6 +499,9 @@ fn shard_manager( fs::remove_file(uds).unwrap(); } + // Copy current process env + let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); + // Process args let mut shard_args = vec![ "serve".to_string(), @@ -552,6 +561,13 @@ fn shard_manager( shard_args.push(preloaded_adapter_source); } + if let Some(predibase_api_token) = predibase_api_token { + envs.push(( + "PREDIBASE_API_TOKEN".into(), + predibase_api_token.to_string().into(), + )); + } + if let Some(dtype) = dtype { shard_args.push("--dtype".to_string()); shard_args.push(dtype.to_string()) @@ -569,9 +585,6 @@ fn shard_manager( shard_args.push(otlp_endpoint); } - // Copy current process env - let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - // Torch Distributed Env vars envs.push(("RANK".into(), rank.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -1030,6 +1043,7 @@ fn spawn_shards( let speculative_tokens = args.speculative_tokens; let preloaded_adapter_ids = args.preloaded_adapter_ids.clone(); let preloaded_adapter_source = args.preloaded_adapter_source.clone(); + let predibase_api_token = args.predibase_api_token.clone(); let dtype = args.dtype; let trust_remote_code = args.trust_remote_code; let master_port = args.master_port; @@ -1052,6 +1066,7 @@ fn spawn_shards( speculative_tokens, preloaded_adapter_ids, preloaded_adapter_source, + predibase_api_token, dtype, trust_remote_code, uds_path,