From 7ca481b9e60b2d93243fff9c1b570fb1186c51d5 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 19 Nov 2024 09:57:44 -0800 Subject: [PATCH 1/2] Add cli arg --speculation-max-batch-size --- launcher/src/main.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 26a2e4aa9..df707c00f 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -350,6 +350,10 @@ struct Args { #[clap(long, env)] speculative_tokens: Option, + // The maximum batch size past which speculative decoding is disabled. + #[clap(long, env)] + speculation_max_batch_size: Option, + /// The list of adapter ids to preload during initialization (to avoid cold start times). #[clap(long, env)] preloaded_adapter_ids: Vec, @@ -638,6 +642,7 @@ fn shard_manager( quantize: Option, compile: bool, speculative_tokens: Option, + speculation_max_batch_size: Option, preloaded_adapter_ids: Vec, preloaded_adapter_source: Option, predibase_api_token: Option, @@ -802,6 +807,14 @@ fn shard_manager( envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into())); } + // Speculative decoding max batch size + if let Some(speculation_max_batch_size) = speculation_max_batch_size { + envs.push(( + "LORAX_SPECULATION_MAX_BATCH_SIZE".into(), + speculation_max_batch_size.to_string().into(), + )); + } + // Backend if backend == Backend::FlashInfer { envs.push(("FLASH_INFER".into(), "1".into())); @@ -1244,6 +1257,7 @@ fn spawn_shards( let quantize = args.quantize; let compile = args.compile; let speculative_tokens = args.speculative_tokens; + let speculation_max_batch_size = args.speculation_max_batch_size; let preloaded_adapter_ids = args.preloaded_adapter_ids.clone(); let preloaded_adapter_source = args.preloaded_adapter_source.clone(); let predibase_api_token = args.predibase_api_token.clone(); @@ -1271,6 +1285,7 @@ fn spawn_shards( quantize, compile, speculative_tokens, + speculation_max_batch_size, preloaded_adapter_ids, preloaded_adapter_source, predibase_api_token, From 0d3778df1d127442352cd248a7aabd54ff4591ba Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Tue, 19 Nov 2024 09:59:40 -0800 Subject: [PATCH 2/2] Add default --- launcher/src/main.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index df707c00f..748680a6a 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -351,8 +351,8 @@ struct Args { speculative_tokens: Option, // The maximum batch size past which speculative decoding is disabled. - #[clap(long, env)] - speculation_max_batch_size: Option, + #[clap(default_value = "32", long, env)] + speculation_max_batch_size: usize, /// The list of adapter ids to preload during initialization (to avoid cold start times). #[clap(long, env)] @@ -642,7 +642,7 @@ fn shard_manager( quantize: Option, compile: bool, speculative_tokens: Option, - speculation_max_batch_size: Option, + speculation_max_batch_size: usize, preloaded_adapter_ids: Vec, preloaded_adapter_source: Option, predibase_api_token: Option, @@ -808,12 +808,10 @@ fn shard_manager( } // Speculative decoding max batch size - if let Some(speculation_max_batch_size) = speculation_max_batch_size { - envs.push(( - "LORAX_SPECULATION_MAX_BATCH_SIZE".into(), - speculation_max_batch_size.to_string().into(), - )); - } + envs.push(( + "LORAX_SPECULATION_MAX_BATCH_SIZE".into(), + speculation_max_batch_size.to_string().into(), + )); // Backend if backend == Backend::FlashInfer {