add flag and allow lazy loading for graph
magdyksaleh committed Dec 16, 2024
1 parent 2b271d1 commit 12e530a
Showing 2 changed files with 24 additions and 2 deletions.
13 changes: 13 additions & 0 deletions launcher/src/main.rs
@@ -353,6 +353,10 @@ struct Args {
#[clap(default_value = "64", long, env)]
compile_max_rank: usize,

/// The initial batch size for CUDA graph compilation of the model.
#[clap(default_value = "32", long, env)]
compile_batch_size: usize,

/// The number of speculative tokens to generate in the model per step.
/// Defaults to 0, meaning no speculative decoding.
#[clap(long, env)]
@@ -654,6 +658,7 @@ fn shard_manager(
compile: bool,
compile_max_batch_size: usize,
compile_max_rank: usize,
compile_batch_size: usize,
speculative_tokens: Option<usize>,
speculation_max_batch_size: usize,
preloaded_adapter_ids: Vec<String>,
@@ -832,6 +837,12 @@ fn shard_manager(
compile_max_rank.to_string().into(),
));

// Compile initial batch size
envs.push((
"LORAX_COMPILE_BATCH_SIZE".into(),
compile_batch_size.to_string().into(),
));

// Speculative decoding max batch size
envs.push((
"LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
@@ -1294,6 +1305,7 @@ fn spawn_shards(
let compile = args.compile;
let compile_max_batch_size = args.compile_max_batch_size;
let compile_max_rank = args.compile_max_rank;
let compile_batch_size = args.compile_batch_size;
let speculative_tokens = args.speculative_tokens;
let speculation_max_batch_size = args.speculation_max_batch_size;
let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();
@@ -1325,6 +1337,7 @@ fn spawn_shards(
compile,
compile_max_batch_size,
compile_max_rank,
compile_batch_size,
speculative_tokens,
speculation_max_batch_size,
preloaded_adapter_ids,
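On the launcher side this change is a simple pass-through: the new --compile-batch-size flag (default 32) is exported to every shard as the LORAX_COMPILE_BATCH_SIZE environment variable. The warmup path is not shown in this diff, but the commit message ("allow lazy loading for graph") suggests that only batch sizes up to this value are traced eagerly at startup, while larger sizes up to MAX_BATCH_SIZE are traced on demand. A minimal sketch under that assumption follows; CUDA_GRAPH_BATCH_SIZES and trace_graph are hypothetical names, not part of the LoRAX codebase.

import os

# Defaults mirror server/lorax_server/utils/graph.py in this commit.
MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))

# Hypothetical ascending list of batch sizes the server could trace (illustrative only).
CUDA_GRAPH_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]


def warmup_graphs(trace_graph):
    # Eagerly trace only the small sizes; anything above COMPILE_BATCH_SIZE
    # (but still <= MAX_BATCH_SIZE) is left to be traced lazily at request time.
    for batch_size in CUDA_GRAPH_BATCH_SIZES:
        if batch_size > COMPILE_BATCH_SIZE:
            break
        trace_graph(batch_size)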
13 changes: 11 additions & 2 deletions server/lorax_server/utils/graph.py
@@ -26,7 +26,8 @@
from lorax_server.models.model import Model


MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128))
MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

SLOT_PAD_VALUE = -1
@@ -472,6 +473,7 @@ def __init__(
self.sliding_window_blocks = sliding_window_blocks
self.layer_to_lora_weights = layer_to_lora_weights
self.punica_wrapper = punica_wrapper
self.batch_size = COMPILE_BATCH_SIZE

def can_use_graph(
self,
@@ -603,7 +605,13 @@ def forward(

key = (batch_size, max_rank)
graph = self.cache.get(key)
if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()):
if (
graph is None
or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names())
# This is the case where COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE so
# we just retrace the graph for that new size
or batch_size > self.batch_size
):
current_traced_adapter_layer_names = (
graph.input_state.traced_adapter_layer_names if graph is not None else set()
)
@@ -631,6 +639,7 @@
self.punica_wrapper,
)
self.cache[key] = graph
self.batch_size = batch_size

output_states = graph.forward(
input_ids=input_ids,
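Taken together, the graph.py changes make graph tracing lazy with respect to batch size: the cache starts out assuming only COMPILE_BATCH_SIZE has been traced, and any request with COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE triggers a retrace for that size, after which self.batch_size is raised so later requests of the same size hit the cache. Below is a simplified, self-contained sketch of that decision; LazyGraphCache and trace are stand-in names, and the real tracing call takes many more arguments (input state, adapter data, the punica wrapper, and so on).

class LazyGraphCache:
    # Simplified stand-in for the real GraphCache; real keys are also (batch_size, max_rank).

    def __init__(self, compile_batch_size, max_batch_size):
        self.cache = {}
        self.batch_size = compile_batch_size  # largest batch size assumed traced so far
        self.max_batch_size = max_batch_size

    def get_graph(self, batch_size, max_rank, layer_names, trace):
        assert batch_size <= self.max_batch_size
        key = (batch_size, max_rank)
        graph = self.cache.get(key)
        if (
            graph is None
            or not graph.traced_adapter_layer_names.issuperset(layer_names)
            # COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE: retrace for the new size
            or batch_size > self.batch_size
        ):
            graph = trace(batch_size, max_rank, layer_names)
            self.cache[key] = graph
            self.batch_size = batch_size
        return graph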
