Memory estimate fixes (#720)
magdyksaleh authored Dec 17, 2024
1 parent ddd1c2d commit 6dfb215
Showing 5 changed files with 43 additions and 25 deletions.
34 changes: 15 additions & 19 deletions .pre-commit-config.yaml
@@ -1,20 +1,16 @@
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.4.0 # Use the latest revision
-    hooks:
-      - id: trailing-whitespace
-      - id: end-of-file-fixer
-      - id: check-yaml
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
-      - id: black
-        name: Format code
-        args:
-          - --line-length=120
-  - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
-    hooks:
-      - id: flake8
-        name: flake8
-        args: ['--max-line-length=120']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0 # Use the latest revision
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.8.3
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format
15 changes: 14 additions & 1 deletion launcher/src/main.rs
@@ -353,6 +353,10 @@ struct Args {
     #[clap(default_value = "64", long, env)]
     compile_max_rank: usize,

+    // The initial batch size for model CUDA compilations
+    #[clap(default_value = "32", long, env)]
+    compile_batch_size: usize,
+
     /// The number of speculative tokens to generate in the model per step.
     /// Defaults to 0, meaning no speculative decoding.
     #[clap(long, env)]
@@ -633,7 +637,7 @@ struct Args
     #[clap(long, env)]
     disable_sgmv: bool,

-    #[clap(default_value = "0.8", long, env)]
+    #[clap(default_value = "0.9", long, env)]
     memory_wiggle_room: f32,
 }

@@ -654,6 +658,7 @@ fn shard_manager(
     compile: bool,
     compile_max_batch_size: usize,
     compile_max_rank: usize,
+    compile_batch_size: usize,
     speculative_tokens: Option<usize>,
     speculation_max_batch_size: usize,
     preloaded_adapter_ids: Vec<String>,
@@ -832,6 +837,12 @@ fn shard_manager(
         compile_max_rank.to_string().into(),
     ));

+    // Compile initial batch size
+    envs.push((
+        "LORAX_COMPILE_BATCH_SIZE".into(),
+        compile_batch_size.to_string().into(),
+    ));
+
     // Speculative decoding max batch size
     envs.push((
         "LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
@@ -1294,6 +1305,7 @@ fn spawn_shards(
         let compile = args.compile;
         let compile_max_batch_size = args.compile_max_batch_size;
         let compile_max_rank = args.compile_max_rank;
+        let compile_batch_size = args.compile_batch_size;
         let speculative_tokens = args.speculative_tokens;
         let speculation_max_batch_size = args.speculation_max_batch_size;
         let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();
@@ -1325,6 +1337,7 @@ fn spawn_shards(
                 compile,
                 compile_max_batch_size,
                 compile_max_rank,
+                compile_batch_size,
                 speculative_tokens,
                 speculation_max_batch_size,
                 preloaded_adapter_ids,
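For context, a minimal sketch (not part of the commit) of how the new flag reaches the Python server: the launcher exports LORAX_COMPILE_BATCH_SIZE to each shard, and the server reads it back alongside the other compile knobs, with the defaults visible elsewhere in this diff.

import os

# Defaults match the values shown in this diff (launcher clap defaults and the
# fallbacks in server/lorax_server/utils/graph.py).
compile_batch_size = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))            # batch sizes pre-traced at warmup
compile_max_batch_size = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))   # upper bound for on-demand retracing
compile_max_rank = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

print(compile_batch_size, compile_max_batch_size, compile_max_rank)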
2 changes: 1 addition & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -1350,7 +1350,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

         free_memory = get_cuda_free_memory(self.device, MEMORY_FRACTION - ADAPTER_MEMORY_FRACTION)
-        free_memory -= graph_cache_memory
+        free_memory = max(0, free_memory - graph_cache_memory)
         logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)

         batch_num_blocks = batch.num_blocks if batch is not None else 0
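The new max(0, ...) clamp covers the case where the estimated CUDA graph cache exceeds the memory currently free: previously the subtraction could drive the KV-cache budget negative, which then fed into the KV-cache sizing that follows. A small illustration with made-up numbers (not taken from the commit):

GiB = 1024**3
free_memory = 2 * GiB          # hypothetical free memory after the MEMORY_FRACTION / adapter reservations
graph_cache_memory = 3 * GiB   # hypothetical CUDA graph cache estimate

old_budget = free_memory - graph_cache_memory           # -1 GiB with the old code
new_budget = max(0, free_memory - graph_cache_memory)   # clamped to 0 with this change

print(old_budget / GiB, new_budget / GiB)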
2 changes: 1 addition & 1 deletion server/lorax_server/utils/dist.py
@@ -11,7 +11,7 @@
 # CUDA memory fraction
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

-MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.8"))
+MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.9"))


 class FakeBarrier:
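MEMORY_WIGGLE_ROOM acts as a safety factor on the memory estimate, and the default moves from 0.8 to 0.9 so less headroom is held back. A minimal sketch of how such a factor is typically applied; the helper name below is hypothetical and not the function lorax uses:

import os

MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.9"))

def usable_bytes(reported_free_bytes: int, wiggle_room: float = MEMORY_WIGGLE_ROOM) -> int:
    # Hypothetical helper: scale the reported free memory down so allocations
    # stay safely below what the driver says is available.
    return int(reported_free_bytes * wiggle_room)

# With 10 GiB reported free, the new default budgets ~9 GiB instead of ~8 GiB.
print(usable_bytes(10 * 1024**3) / 1024**3)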
15 changes: 12 additions & 3 deletions server/lorax_server/utils/graph.py
@@ -26,7 +26,8 @@
 from lorax_server.models.model import Model


-MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128))
+MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
+COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
 MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

 SLOT_PAD_VALUE = -1
@@ -40,7 +41,7 @@
 CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
     BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
 ]
-CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= MAX_BATCH_SIZE]
+CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]

 # Include 0 to ensure we can use cuda graphs without adapters
 # TODO(travis): use padding to allow for more ranks without increasing memory usage
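Filtering against COMPILE_BATCH_SIZE instead of MAX_BATCH_SIZE means only the small batch sizes are pre-traced at warmup, which keeps the warmup-time graph cache, and the memory estimate derived from it, small. Reproducing the bucket logic with the new defaults; BATCH_SIZE_INCREMENT is not visible in this hunk, so 32 is assumed:

MAX_BATCH_SIZE = 256        # LORAX_COMPILE_MAX_BATCH_SIZE default after this change
COMPILE_BATCH_SIZE = 32     # LORAX_COMPILE_BATCH_SIZE default
BATCH_SIZE_INCREMENT = 32   # assumption; not shown in this diff

CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
    BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
]
CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]
print(CACHED_BATCH_SIZES)   # [1, 2, 3, 4, 8, 16, 32] under these assumptions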
@@ -472,6 +473,7 @@ def __init__(
         self.sliding_window_blocks = sliding_window_blocks
         self.layer_to_lora_weights = layer_to_lora_weights
         self.punica_wrapper = punica_wrapper
+        self.batch_size = COMPILE_BATCH_SIZE

     def can_use_graph(
         self,
@@ -603,7 +605,13 @@ def forward(

         key = (batch_size, max_rank)
         graph = self.cache.get(key)
-        if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()):
+        if (
+            graph is None
+            or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names())
+            # This is the case where COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE so
+            # we just retrace the graph for that new size
+            or batch_size > self.batch_size
+        ):
             current_traced_adapter_layer_names = (
                 graph.input_state.traced_adapter_layer_names if graph is not None else set()
             )
@@ -631,6 +639,7 @@ def forward(
                 self.punica_wrapper,
             )
             self.cache[key] = graph
+            self.batch_size = batch_size

         output_states = graph.forward(
             input_ids=input_ids,
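Taken together, the graph.py changes implement retrace-on-demand: graphs are pre-traced only up to COMPILE_BATCH_SIZE, and the first batch larger than anything traced so far (still bounded by MAX_BATCH_SIZE) triggers a one-time retrace that is then cached. A simplified sketch of that policy; the class and method names below are illustrative, not lorax's actual graph-cache class:

class GraphCacheSketch:
    def __init__(self, compile_batch_size: int):
        self.cache = {}
        self.batch_size = compile_batch_size  # largest batch size traced so far

    def get_graph(self, key, batch_size, trace_fn):
        graph = self.cache.get(key)
        if graph is None or batch_size > self.batch_size:
            graph = trace_fn(batch_size)      # retrace for the new, larger size
            self.cache[key] = graph
            self.batch_size = batch_size      # later batches of this size reuse the cached graph
        return graph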
