diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index df707c00f..748680a6a 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -351,8 +351,8 @@ struct Args {
     speculative_tokens: Option<usize>,
 
     // The maximum batch size past which speculative decoding is disabled.
-    #[clap(long, env)]
-    speculation_max_batch_size: Option<usize>,
+    #[clap(default_value = "32", long, env)]
+    speculation_max_batch_size: usize,
 
     /// The list of adapter ids to preload during initialization (to avoid cold start times).
     #[clap(long, env)]
@@ -642,7 +642,7 @@ fn shard_manager(
     quantize: Option<Quantization>,
     compile: bool,
     speculative_tokens: Option<usize>,
-    speculation_max_batch_size: Option<usize>,
+    speculation_max_batch_size: usize,
     preloaded_adapter_ids: Vec<String>,
     preloaded_adapter_source: Option<String>,
     predibase_api_token: Option<String>,
@@ -808,12 +808,10 @@ fn shard_manager(
     }
 
     // Speculative decoding max batch size
-    if let Some(speculation_max_batch_size) = speculation_max_batch_size {
-        envs.push((
-            "LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
-            speculation_max_batch_size.to_string().into(),
-        ));
-    }
+    envs.push((
+        "LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
+        speculation_max_batch_size.to_string().into(),
+    ));
 
     // Backend
     if backend == Backend::FlashInfer {