diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0e7a63e47..9f6bccad0 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -345,6 +345,14 @@ struct Args { #[clap(long, env, value_enum)] compile: bool, + // The maximum batch size past which CUDA graphs are disabled. + #[clap(default_value = "128", long, env)] + compile_max_batch_size: usize, + + // The maximum adapter rank (LoRA) past which CUDA graphs are disabled. + #[clap(default_value = "64", long, env)] + compile_max_rank: usize, + /// The number of speculative tokens to generate in the model per step. /// Defaults to 0, meaning no speculative decoding. #[clap(long, env)] @@ -641,6 +649,8 @@ fn shard_manager( adapter_source: String, quantize: Option, compile: bool, + compile_max_batch_size: usize, + compile_max_rank: usize, speculative_tokens: Option, speculation_max_batch_size: usize, preloaded_adapter_ids: Vec, @@ -807,6 +817,17 @@ fn shard_manager( envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into())); } + // Compile max batch size and rank + envs.push(( + "LORAX_COMPILE_MAX_BATCH_SIZE".into(), + compile_max_batch_size.to_string().into(), + )); + + envs.push(( + "LORAX_COMPILE_MAX_RANK".into(), + compile_max_rank.to_string().into(), + )); + // Speculative decoding max batch size envs.push(( "LORAX_SPECULATION_MAX_BATCH_SIZE".into(), @@ -1261,6 +1282,8 @@ fn spawn_shards( let otlp_endpoint = args.otlp_endpoint.clone(); let quantize = args.quantize; let compile = args.compile; + let compile_max_batch_size = args.compile_max_batch_size; + let compile_max_rank = args.compile_max_rank; let speculative_tokens = args.speculative_tokens; let speculation_max_batch_size = args.speculation_max_batch_size; let preloaded_adapter_ids = args.preloaded_adapter_ids.clone(); @@ -1289,6 +1312,8 @@ fn spawn_shards( adapter_source, quantize, compile, + compile_max_batch_size, + compile_max_rank, speculative_tokens, speculation_max_batch_size, preloaded_adapter_ids, diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 8baaed757..9a73163bf 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -26,7 +26,7 @@ from lorax_server.models.model import Model -MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 96)) +MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128)) MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64)) SLOT_PAD_VALUE = -1