From 6cfc0f33dfdc7145fd38aa55e6f9ffbc59339c4b Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 28 Mar 2024 14:19:04 -0700
Subject: [PATCH] Fix CUDA compile when using long sequence lengths (#363)

---
 .../custom_modeling/flash_mistral_modeling.py |  7 ++--
 .../custom_modeling/flash_qwen2_modeling.py   |  6 +--
 server/lorax_server/models/flash_causal_lm.py | 17 ++++++--
 server/lorax_server/models/flash_mistral.py   |  3 --
 server/lorax_server/models/flash_qwen2.py     |  2 -
 server/lorax_server/models/model.py           | 10 +++--
 server/lorax_server/utils/graph.py            | 41 +++++++++++++++----
 7 files changed, 58 insertions(+), 28 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
index b43b6a7cd..b6e9813cb 100644
--- a/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
@@ -217,7 +217,7 @@ def __init__(
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -538,8 +538,7 @@ def __init__(self, config, weights):
         ), 0, LM_HEAD, process_group=weights.process_group)
 
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
+
 
     def forward(
         self,
@@ -558,7 +557,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
diff --git a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
index 115b71dc8..ad587c92b 100644
--- a/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -155,7 +155,7 @@ def __init__(
         super().__init__()
 
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
 
         self.num_heads = config.num_attention_heads
@@ -457,8 +457,6 @@ def __init__(self, config, weights):
         ), 0, LM_HEAD, process_group=weights.process_group)
 
         self.max_past = config.sliding_window
-        if self.max_past is None:
-            raise ValueError("max_past cannot be None")
 
     def forward(
         self,
@@ -477,7 +475,7 @@ def forward(
         if prefill_cache_indices is not None:
             # Slots also need to be sliced as it has the same size as the whole kv tensor
             slots = slots[prefill_cache_indices]
-        else:
+        elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
             max_s = min(self.max_past, max_s)
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 9b7112d16..2f53ffd6c 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -761,6 +761,10 @@ def __init__(
 
         self.compile = compile
         self.model_graph_wrapper: GraphCache = None
+
+    @property
+    def sliding_window_blocks(self) -> Optional[int]:
+        return SLIDING_WINDOW_BLOCKS
 
     @property
     def batch_type(self) -> Type[FlashCausalLMBatch]:
@@ -771,6 +775,8 @@ def adapter_memory_size(self) -> int:
         return ADAPTER_MEMORY_FRACTION * total_gpu_memory
 
     def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
+        max_total_tokens = batch.max_seqlen + max_new_tokens
+
         torch.cuda.empty_cache()
         try:
             cache_manager = set_cache_manager(
@@ -809,7 +815,13 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
             # Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache.
             # Needs to be estimated here rather than fully initialized as the graph cache relies on the
             # cache manager being set.
-            self.model_graph_wrapper = GraphCache(self.model, self.device, self.adapter_layers)
+            self.model_graph_wrapper = GraphCache(
+                self.model,
+                self.device,
+                self.adapter_layers,
+                max_total_tokens,
+                self.sliding_window_blocks
+            )
             graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory()
             logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024)
             torch.cuda.synchronize(self.device)
@@ -900,9 +912,6 @@ def generate_token(
         prefill_logprobs = batch.prefill_next_token_indices is not None
         return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests)
 
-        # Debugging for LoRAX
-        # print("!!! adapter_indices", batch.adapter_indices)
-
         if batch.needed_blocks_slots:
             # Allocate blocks to this batch
             block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py
index d480697e3..97df03804 100644
--- a/server/lorax_server/models/flash_mistral.py
+++ b/server/lorax_server/models/flash_mistral.py
@@ -58,9 +58,6 @@ def __init__(
         )
         config.quantize = quantize
 
-        if config.sliding_window is None:
-            config.sliding_window = config.max_position_embeddings
-
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py
index 4c8a88e16..693b1e381 100644
--- a/server/lorax_server/models/flash_qwen2.py
+++ b/server/lorax_server/models/flash_qwen2.py
@@ -99,8 +99,6 @@ def __init__(
 
         model = FlashQwen2ForCausalLM(config, weights)
 
-        if config.sliding_window is None:
-            config.sliding_window = config.max_position_embeddings
         self.config = config
 
         torch.distributed.barrier(group=self.process_group)
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 48505a4ed..047bf8cd4 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -72,6 +72,10 @@ def info(self) -> InfoResponse:
             device_type=self.device.type,
             window_size=self.sliding_window,
         )
+
+    @property
+    def sliding_window_blocks(self) -> Optional[int]:
+        return None
 
     @property
     @abstractmethod
@@ -150,11 +154,11 @@ def load_adapter(
         adapter_index: int,
         api_token: str,
     ):
-        """Physically loads the adapter weights into the model.
+        """Loads adapter weights from disk / host memory on the GPU.
 
         adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
-        into model. Otherwise, the adapter weights are merged into the model
-        weights on the fly.
+        into model. Otherwise, the adapter weights are applied during the forward
+        pass and stored separately from the base model parameters.
""" if adapter_index in self.loaded_adapters: # Adapter already loaded diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 956346e39..4a1db6be1 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from functools import lru_cache +import math from statistics import median from typing import TYPE_CHECKING, List, Optional, Tuple import numpy as np @@ -23,7 +24,6 @@ # TODO(travis): make this configurable by model / user MAX_BATCH_SIZE = 256 -MAX_CONTEXT_LENGTH = 8192 MAX_RANK = 64 SLOT_PAD_VALUE = -1 @@ -77,8 +77,18 @@ class GraphState: @lru_cache(maxsize=1) -def get_max_graph_state(device: torch.device, adapter_layers: Tuple[str]) -> GraphState: - max_num_blocks = (MAX_CONTEXT_LENGTH + BLOCK_SIZE - 1) // BLOCK_SIZE +def get_max_graph_state( + device: torch.device, + adapter_layers: Tuple[str], + max_total_tokens: int, + sliding_window_blocks: Optional[int] = None, +) -> GraphState: + # max_num_blocks = (max_total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE + max_num_blocks = math.ceil((max_total_tokens - 1) / BLOCK_SIZE) + if sliding_window_blocks is not None: + # Needed blocks can not go over SLIDING_WINDOW_BLOCKS + max_num_blocks = max(max_num_blocks, sliding_window_blocks) + block_tables_arr = np.zeros((MAX_BATCH_SIZE, max_num_blocks), dtype=np.int32) block_tables = torch.from_numpy(block_tables_arr).to(device=device) @@ -133,7 +143,7 @@ def __init__( memory_pool: Tuple[int, int], input_state: GraphState, output_states: torch.Tensor, - model, + model: nn.Module, ): self.graph = graph self.memory_pool = memory_pool @@ -149,8 +159,10 @@ def trace( batch_size: int, max_rank: int, memory_pool: Tuple[int, int], + max_total_tokens: int, + sliding_window_blocks: Optional[int] = None, ) -> "GraphWrapper": - max_input_state = get_max_graph_state(device, adapter_layers) + max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks) # WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2 # as the segment size. This is a workaround until we can figure out why. 
@@ -216,7 +228,7 @@ def trace(
                 block_tables=input_state.block_tables,
                 slots=input_state.slots,
                 input_lengths=input_state.input_lengths,
-                max_s=MAX_CONTEXT_LENGTH,
+                max_s=max_total_tokens,
                 adapter_data=input_state.adapter_data,
                 lm_head_indices=None,
             )
@@ -277,12 +289,21 @@ def __call__(self, *args, **kwargs):
 
 
 class GraphCache:
-    def __init__(self, model: nn.Module, device: torch.device, adapter_layers: List[str]):
+    def __init__(
+        self,
+        model: nn.Module,
+        device: torch.device,
+        adapter_layers: List[str],
+        max_total_tokens: int,
+        sliding_window_blocks: Optional[int] = None,
+    ):
         self.model = model
         self.device = device
         self.adapter_layers = tuple(adapter_layers)
         self.memory_pool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.cache = {}
+        self.max_total_tokens = max_total_tokens
+        self.sliding_window_blocks = sliding_window_blocks
 
     def can_use_graph(
         self,
@@ -303,7 +324,7 @@ def can_use_graph(
         return (
             torch.cuda.is_available()
             and batch_size <= MAX_BATCH_SIZE
-            and max_s <= MAX_CONTEXT_LENGTH
+            and max_s <= self.max_total_tokens
             and max_rank <= MAX_RANK
             and nranks <= 1
             and max_rank in _allowed_ranks
@@ -331,6 +352,8 @@ def get_estimated_cache_memory(self) -> int:
                     batch_size,
                     max_rank,
                     pool,
+                    self.max_total_tokens,
+                    self.sliding_window_blocks,
                 )
                 tmp_cache[key] = graph
                 pool = graph.memory_pool
@@ -368,6 +391,8 @@ def warmup(self):
                     batch_size,
                     max_rank,
                     pool,
+                    self.max_total_tokens,
+                    self.sliding_window_blocks,
                 )
                 self.cache[key] = graph
                 pool = graph.memory_pool
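
Note (outside the patch itself): the heart of the change is that the CUDA graph cache is now sized from the max_total_tokens computed at warmup (batch.max_seqlen + max_new_tokens) instead of the fixed MAX_CONTEXT_LENGTH = 8192, optionally widened to cover a sliding window. A minimal standalone sketch of that sizing rule follows; BLOCK_SIZE = 16 and the helper name estimate_max_num_blocks are illustrative assumptions, not identifiers taken from the patch.

import math
from typing import Optional

BLOCK_SIZE = 16  # assumed paged-attention block size, for illustration only


def estimate_max_num_blocks(max_total_tokens: int, sliding_window_blocks: Optional[int] = None) -> int:
    # Blocks needed to cover the longest sequence the captured graphs must handle,
    # mirroring the computation in get_max_graph_state() above.
    max_num_blocks = math.ceil((max_total_tokens - 1) / BLOCK_SIZE)
    if sliding_window_blocks is not None:
        # Keep at least enough block-table columns for a full sliding window, as the patch does.
        max_num_blocks = max(max_num_blocks, sliding_window_blocks)
    return max_num_blocks


if __name__ == "__main__":
    print(estimate_max_num_blocks(32768))      # 2048 blocks for a 32k-token budget
    print(estimate_max_num_blocks(2048, 256))  # floor raised to 256 by the sliding window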