diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ac4cd2ce4..ee0fc0c5a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,20 +1,16 @@
 repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v3.4.0  # Use the latest revision
-  hooks:
-  - id: trailing-whitespace
-  - id: end-of-file-fixer
-  - id: check-yaml
-- repo: https://github.com/psf/black
-  rev: 24.2.0
-  hooks:
-  - id: black
-    name: Format code
-    args:
-    - --line-length=120
-- repo: https://github.com/pycqa/flake8
-  rev: 7.0.0
-  hooks:
-  - id: flake8
-    name: flake8
-    args: ['--max-line-length=120']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0  # Use the latest revision
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.8.3
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 33116d167..e45b83a32 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -353,6 +353,10 @@ struct Args {
     #[clap(default_value = "64", long, env)]
     compile_max_rank: usize,
 
+    // The initial batch size for model CUDA compilations
+    #[clap(default_value = "32", long, env)]
+    compile_batch_size: usize,
+
     /// The number of speculative tokens to generate in the model per step.
     /// Defaults to 0, meaning no speculative decoding.
     #[clap(long, env)]
@@ -633,7 +637,7 @@ struct Args {
     #[clap(long, env)]
     disable_sgmv: bool,
 
-    #[clap(default_value = "0.8", long, env)]
+    #[clap(default_value = "0.9", long, env)]
     memory_wiggle_room: f32,
 }
@@ -654,6 +658,7 @@ fn shard_manager(
     compile: bool,
     compile_max_batch_size: usize,
     compile_max_rank: usize,
+    compile_batch_size: usize,
     speculative_tokens: Option<usize>,
     speculation_max_batch_size: usize,
     preloaded_adapter_ids: Vec<String>,
@@ -832,6 +837,12 @@ fn shard_manager(
         compile_max_rank.to_string().into(),
     ));
 
+    // Compile initial batch size
+    envs.push((
+        "LORAX_COMPILE_BATCH_SIZE".into(),
+        compile_batch_size.to_string().into(),
+    ));
+
     // Speculative decoding max batch size
     envs.push((
         "LORAX_SPECULATION_MAX_BATCH_SIZE".into(),
@@ -1294,6 +1305,7 @@ fn spawn_shards(
        let compile = args.compile;
        let compile_max_batch_size = args.compile_max_batch_size;
        let compile_max_rank = args.compile_max_rank;
+       let compile_batch_size = args.compile_batch_size;
        let speculative_tokens = args.speculative_tokens;
        let speculation_max_batch_size = args.speculation_max_batch_size;
        let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();
@@ -1325,6 +1337,7 @@ fn spawn_shards(
                compile,
                compile_max_batch_size,
                compile_max_rank,
+               compile_batch_size,
                speculative_tokens,
                speculation_max_batch_size,
                preloaded_adapter_ids,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 9fdd80d1b..92937511d 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1350,7 +1350,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
         free_memory = get_cuda_free_memory(self.device, MEMORY_FRACTION - ADAPTER_MEMORY_FRACTION)
-        free_memory -= graph_cache_memory
+        free_memory = max(0, free_memory - graph_cache_memory)
         logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)
 
         batch_num_blocks = batch.num_blocks if batch is not None else 0
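Note on the `warmup` change above: a minimal sketch of why the `max(0, ...)` clamp matters. The byte counts below are invented purely for illustration and are not taken from the PR.

```python
# Hypothetical numbers, only to illustrate the clamp in warmup().
graph_cache_memory = 6 * 1024**3  # pretend the CUDA graph cache needs ~6 GiB
free_memory = 4 * 1024**3         # pretend only ~4 GiB remain after MEMORY_FRACTION

# Before this change: free_memory -= graph_cache_memory  ->  -2 GiB,
# a negative budget that would poison the downstream kv-cache block math.
# After this change the budget bottoms out at zero instead:
free_memory = max(0, free_memory - graph_cache_memory)
print(f"Memory remaining for kv cache: {free_memory / 1024 / 1024} MB")  # 0.0 MB
```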
diff --git a/server/lorax_server/utils/dist.py b/server/lorax_server/utils/dist.py
index ca009a8e2..409720b94 100644
--- a/server/lorax_server/utils/dist.py
+++ b/server/lorax_server/utils/dist.py
@@ -11,7 +11,7 @@
 
 # CUDA memory fraction
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
-MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.8"))
+MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.9"))
 
 
 class FakeBarrier:
diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py
index 9a73163bf..64755f79f 100644
--- a/server/lorax_server/utils/graph.py
+++ b/server/lorax_server/utils/graph.py
@@ -26,7 +26,8 @@
 from lorax_server.models.model import Model
 
 
-MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128))
+MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
+COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
 MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))
 
 SLOT_PAD_VALUE = -1
@@ -40,7 +41,7 @@
 CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
     BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
 ]
-CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= MAX_BATCH_SIZE]
+CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]
 
 # Include 0 to ensure we can use cuda graphs without adapters
 # TODO(travis): use padding to allow for more ranks without increasing memory usage
@@ -472,6 +473,7 @@ def __init__(
         self.sliding_window_blocks = sliding_window_blocks
         self.layer_to_lora_weights = layer_to_lora_weights
         self.punica_wrapper = punica_wrapper
+        self.batch_size = COMPILE_BATCH_SIZE
 
     def can_use_graph(
         self,
@@ -603,7 +605,13 @@ def forward(
 
         key = (batch_size, max_rank)
         graph = self.cache.get(key)
-        if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()):
+        if (
+            graph is None
+            or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names())
+            # This is the case where COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE so
+            # we just retrace the graph for that new size
+            or batch_size > self.batch_size
+        ):
             current_traced_adapter_layer_names = (
                 graph.input_state.traced_adapter_layer_names if graph is not None else set()
             )
@@ -631,6 +639,7 @@ def forward(
                 self.punica_wrapper,
             )
             self.cache[key] = graph
+            self.batch_size = batch_size
 
         output_states = graph.forward(
             input_ids=input_ids,
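A rough sketch of the retrace-on-growth behavior added to `GraphCache.forward`. This is a simplified stand-in, not the real class: adapter-layer checks, batch-size rounding, and the actual tracing are omitted, and the `TinyGraphCache` name and its string "graphs" are invented for illustration. The two constants use the new defaults from this diff.

```python
COMPILE_BATCH_SIZE = 32  # graphs pre-traced at warmup only up to this batch size
MAX_BATCH_SIZE = 256     # ceiling above which CUDA graphs are not used at all


class TinyGraphCache:
    """Toy model of the cache: retrace when a batch exceeds anything traced so far."""

    def __init__(self):
        # Pretend warmup already traced one graph per batch-size bucket up to COMPILE_BATCH_SIZE.
        self.cache = {(bs, 0): f"traced(bs={bs})" for bs in (1, 2, 4, 8, 16, 32)}
        self.batch_size = COMPILE_BATCH_SIZE  # largest batch size traced so far

    def forward(self, batch_size: int, max_rank: int = 0) -> str:
        key = (batch_size, max_rank)
        graph = self.cache.get(key)
        # Retrace when nothing is cached for this key, or when
        # COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE (the new case in this PR).
        if graph is None or batch_size > self.batch_size:
            graph = f"traced(bs={batch_size}, rank={max_rank})"
            self.cache[key] = graph
            self.batch_size = batch_size
        return graph


cache = TinyGraphCache()
print(cache.forward(16))  # hits a warmup-time trace
print(cache.forward(64))  # 64 > 32, so the graph is retraced once and cached
print(cache.forward(64))  # subsequent calls at 64 reuse the retraced graph
```

The apparent upshot is that warmup cost stays bounded by `--compile-batch-size`, while larger batches up to `--compile-max-batch-size` pay a one-time retrace instead of falling back to eager execution.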