Skip to content

Commit

Permalink
Add launcher args for the compile max batch size and max LoRA rank
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Nov 20, 2024
1 parent d6caec4 commit 0b489b9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
25 changes: 25 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ struct Args {
#[clap(long, env, value_enum)]
compile: bool,

/// The maximum batch size past which CUDA graphs are disabled.
// NOTE: must be `///` (doc comment), not `//` — clap's derive API only
// surfaces doc comments as `--help` text, matching the sibling
// `speculative_tokens` field below.
#[clap(default_value = "128", long, env)]
compile_max_batch_size: usize,

/// The maximum adapter rank (LoRA) past which CUDA graphs are disabled.
#[clap(default_value = "64", long, env)]
compile_max_rank: usize,

/// The number of speculative tokens to generate in the model per step.
/// Defaults to 0, meaning no speculative decoding.
#[clap(long, env)]
Expand Down Expand Up @@ -641,6 +649,8 @@ fn shard_manager(
adapter_source: String,
quantize: Option<Quantization>,
compile: bool,
compile_max_batch_size: usize,
compile_max_rank: usize,
speculative_tokens: Option<usize>,
speculation_max_batch_size: usize,
preloaded_adapter_ids: Vec<String>,
Expand Down Expand Up @@ -807,6 +817,17 @@ fn shard_manager(
envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into()));
}

// Compile max batch size and rank
envs.push((
"LORAX_COMPILE_MAX_BATCH_SIZE".into(),
compile_max_batch_size.to_string().into(),
));

envs.push((
"LORAX_COMPILE_MAX_RANK".into(),
compile_max_rank.to_string().into(),
));

// Speculative decoding max batch size
envs.push((
"LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
Expand Down Expand Up @@ -1261,6 +1282,8 @@ fn spawn_shards(
let otlp_endpoint = args.otlp_endpoint.clone();
let quantize = args.quantize;
let compile = args.compile;
let compile_max_batch_size = args.compile_max_batch_size;
let compile_max_rank = args.compile_max_rank;
let speculative_tokens = args.speculative_tokens;
let speculation_max_batch_size = args.speculation_max_batch_size;
let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();
Expand Down Expand Up @@ -1289,6 +1312,8 @@ fn spawn_shards(
adapter_source,
quantize,
compile,
compile_max_batch_size,
compile_max_rank,
speculative_tokens,
speculation_max_batch_size,
preloaded_adapter_ids,
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from lorax_server.models.model import Model


MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 96))
MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128))
MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

SLOT_PAD_VALUE = -1
Expand Down

0 comments on commit 0b489b9

Please sign in to comment.