Fix CUDA compile when using long sequence lengths (#363)
tgaddair authored Mar 28, 2024
1 parent 0b9117f commit 6cfc0f3
Showing 7 changed files with 58 additions and 28 deletions.
@@ -217,7 +217,7 @@ def __init__(
):
super().__init__()
self.max_past = (
config.sliding_window if config.sliding_window is not None else 0
config.sliding_window if config.sliding_window is not None else -1
)
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
@@ -538,8 +538,7 @@ def __init__(self, config, weights):
), 0, LM_HEAD, process_group=weights.process_group)

self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")


def forward(
self,
@@ -558,7 +557,7 @@ def forward(
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
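The hunks above relax the sliding-window handling in this modeling file: max_past now defaults to -1 instead of 0, the hard requirement that config.sliding_window be set is dropped, and the decode-time clamp is applied only when a window is actually configured. A minimal sketch of that clamping rule (hypothetical helper, not the repository's API):

    def clamp_max_s(max_s, max_past, is_prefill):
        # Prefill uses the flash attention kernel, which needs the true sequence
        # length; decode uses paged attention, which expects max_s clamped to the
        # sliding window. When no window is configured (max_past is None), skip
        # the clamp instead of clamping against a placeholder value.
        if is_prefill or max_past is None:
            return max_s
        return min(max_past, max_s)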
@@ -155,7 +155,7 @@ def __init__(
super().__init__()

self.max_past = (
config.sliding_window if config.sliding_window is not None else 0
config.sliding_window if config.sliding_window is not None else -1
)

self.num_heads = config.num_attention_heads
@@ -457,8 +457,6 @@ def __init__(self, config, weights):
), 0, LM_HEAD, process_group=weights.process_group)

self.max_past = config.sliding_window
if self.max_past is None:
raise ValueError("max_past cannot be None")

def forward(
self,
@@ -477,7 +475,7 @@ def forward(
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
else:
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
max_s = min(self.max_past, max_s)
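The second modeling file receives the same treatment. The switch from 0 to -1 matters because flash-attention-style kernels conventionally treat a negative window size as "unbounded", while 0 would describe an empty window. A speculative illustration of that sentinel (assumed semantics, not the exact kernel call):

    def effective_window(max_past, seq_len):
        # Negative values (the new -1 default) mean "attend to the full prefix";
        # a non-negative value limits attention to the last max_past positions.
        return seq_len if max_past < 0 else min(max_past, seq_len)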
server/lorax_server/models/flash_causal_lm.py (17 changes: 13 additions & 4 deletions)
@@ -761,6 +761,10 @@ def __init__(

self.compile = compile
self.model_graph_wrapper: GraphCache = None

@property
def sliding_window_blocks(self) -> Optional[int]:
return SLIDING_WINDOW_BLOCKS

@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
@@ -771,6 +775,8 @@ def adapter_memory_size(self) -> int:
return ADAPTER_MEMORY_FRACTION * total_gpu_memory

def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
max_total_tokens = batch.max_seqlen + max_new_tokens

torch.cuda.empty_cache()
try:
cache_manager = set_cache_manager(
@@ -809,7 +815,13 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
# Estimate the memory overhead from CUDA graphs so we can subtract it from the kv cache.
# Needs to be estimated here rather than fully initialized as the graph cache relies on the
# cache manager being set.
self.model_graph_wrapper = GraphCache(self.model, self.device, self.adapter_layers)
self.model_graph_wrapper = GraphCache(
self.model,
self.device,
self.adapter_layers,
max_total_tokens,
self.sliding_window_blocks
)
graph_cache_memory = self.model_graph_wrapper.get_estimated_cache_memory()
logger.info("Estimated graph cache memory: {} MB", graph_cache_memory / 1024 / 1024)
torch.cuda.synchronize(self.device)
@@ -900,9 +912,6 @@ def generate_token(
prefill_logprobs = batch.prefill_next_token_indices is not None
return_alternatives = any(req.parameters.return_k_alternatives > 0 for req in batch.requests)

# Debugging for LoRAX
# print("!!! adapter_indices", batch.adapter_indices)

if batch.needed_blocks_slots:
# Allocate blocks to this batch
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
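In flash_causal_lm.py the warmup budget is now taken from the actual request (batch.max_seqlen + max_new_tokens) and handed to the GraphCache along with the model's optional sliding_window_blocks property. How a sliding-window model might derive that block count is sketched below (BLOCK_SIZE and the rounding are assumptions, not the repository's code):

    import math

    BLOCK_SIZE = 16  # assumed paged-attention block size

    def window_to_blocks(sliding_window):
        # e.g. a 4096-token window with 16-token blocks -> 256 blocks;
        # models without a sliding window report None (see model.py below).
        return None if sliding_window is None else math.ceil(sliding_window / BLOCK_SIZE)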
server/lorax_server/models/flash_mistral.py (3 changes: 0 additions & 3 deletions)
@@ -58,9 +58,6 @@ def __init__(
)
config.quantize = quantize

if config.sliding_window is None:
config.sliding_window = config.max_position_embeddings

torch.distributed.barrier(group=self.process_group)

filenames = weight_files(model_id, revision=revision, extension=".safetensors")
server/lorax_server/models/flash_qwen2.py (2 changes: 0 additions & 2 deletions)
@@ -99,8 +99,6 @@ def __init__(

model = FlashQwen2ForCausalLM(config, weights)

if config.sliding_window is None:
config.sliding_window = config.max_position_embeddings
self.config = config

torch.distributed.barrier(group=self.process_group)
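Both loaders above previously forced config.sliding_window to max_position_embeddings when the model had no window; that fallback is removed so None flows through unchanged. A small before/after illustration with a dummy config (not the real classes):

    class DummyConfig:
        sliding_window = None
        max_position_embeddings = 32768

    config = DummyConfig()

    # Removed behaviour:
    #     if config.sliding_window is None:
    #         config.sliding_window = config.max_position_embeddings
    # New behaviour: the attribute stays None, and the explicit checks added
    # above ("elif self.max_past is not None", the -1 default) take over.
    assert config.sliding_window is None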
server/lorax_server/models/model.py (10 changes: 7 additions & 3 deletions)
@@ -72,6 +72,10 @@ def info(self) -> InfoResponse:
device_type=self.device.type,
window_size=self.sliding_window,
)

@property
def sliding_window_blocks(self) -> Optional[int]:
return None

@property
@abstractmethod
@@ -150,11 +154,11 @@ def load_adapter(
adapter_index: int,
api_token: str,
):
"""Physically loads the adapter weights into the model.
"""Loads adapter weights from disk / host memory on the GPU.
adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
into model. Otherwise, the adapter weights are merged into the model
weights on the fly.
into model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if adapter_index in self.loaded_adapters:
# Adapter already loaded
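The reworded docstring makes explicit that dynamically loaded adapters are never merged into the base weights; they are kept separate and applied during the forward pass. A generic LoRA-style sketch of what that means (illustrative math only, not LoRAX's batched SGMV kernels):

    import torch

    def lora_linear(x, base_weight, lora_a, lora_b, scaling=1.0):
        # The pretrained projection is left untouched.
        base = x @ base_weight.t()
        # The adapter contributes a low-rank update computed on the fly, so
        # unloading an adapter is just dropping lora_a / lora_b.
        delta = (x @ lora_a.t()) @ lora_b.t()
        return base + scaling * delta

    x = torch.randn(2, 1024)
    out = lora_linear(
        x,
        base_weight=torch.randn(4096, 1024),
        lora_a=torch.randn(16, 1024),   # rank 16
        lora_b=torch.zeros(4096, 16),   # B starts at zero in standard LoRA
    )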
server/lorax_server/utils/graph.py (41 changes: 33 additions & 8 deletions)
@@ -3,6 +3,7 @@

from dataclasses import dataclass
from functools import lru_cache
import math
from statistics import median
from typing import TYPE_CHECKING, List, Optional, Tuple
import numpy as np
@@ -23,7 +24,6 @@

# TODO(travis): make this configurable by model / user
MAX_BATCH_SIZE = 256
MAX_CONTEXT_LENGTH = 8192
MAX_RANK = 64

SLOT_PAD_VALUE = -1
@@ -77,8 +77,18 @@ class GraphState:


@lru_cache(maxsize=1)
def get_max_graph_state(device: torch.device, adapter_layers: Tuple[str]) -> GraphState:
max_num_blocks = (MAX_CONTEXT_LENGTH + BLOCK_SIZE - 1) // BLOCK_SIZE
def get_max_graph_state(
device: torch.device,
adapter_layers: Tuple[str],
max_total_tokens: int,
sliding_window_blocks: Optional[int] = None,
) -> GraphState:
# max_num_blocks = (max_total_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE
max_num_blocks = math.ceil((max_total_tokens - 1) / BLOCK_SIZE)
if sliding_window_blocks is not None:
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
max_num_blocks = max(max_num_blocks, sliding_window_blocks)

block_tables_arr = np.zeros((MAX_BATCH_SIZE, max_num_blocks), dtype=np.int32)
block_tables = torch.from_numpy(block_tables_arr).to(device=device)
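With the global MAX_CONTEXT_LENGTH removed, get_max_graph_state sizes the dummy block table from max_total_tokens (raised to at least the sliding-window block count when one is set), which is presumably what the commit title's long-sequence fix refers to. A worked example with assumed numbers:

    import math

    BLOCK_SIZE = 16            # assumed block size
    max_total_tokens = 32768   # e.g. a long-context deployment

    max_num_blocks = math.ceil((max_total_tokens - 1) / BLOCK_SIZE)  # 2048 blocks
    # Under the removed constant the table was always sized for 8192 tokens:
    # (8192 + 16 - 1) // 16 = 512 blocks, regardless of the requested context.

    sliding_window_blocks = 256  # e.g. a 4096-token window at 16 tokens per block
    max_num_blocks = max(max_num_blocks, sliding_window_blocks)      # still 2048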

@@ -133,7 +143,7 @@ def __init__(
memory_pool: Tuple[int, int],
input_state: GraphState,
output_states: torch.Tensor,
model,
model: nn.Module,
):
self.graph = graph
self.memory_pool = memory_pool
@@ -149,8 +159,10 @@ def trace(
batch_size: int,
max_rank: int,
memory_pool: Tuple[int, int],
max_total_tokens: int,
sliding_window_blocks: Optional[int] = None,
) -> "GraphWrapper":
max_input_state = get_max_graph_state(device, adapter_layers)
max_input_state = get_max_graph_state(device, adapter_layers, max_total_tokens, sliding_window_blocks)

# WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2
# as the segment size. This is a workaround until we can figure out why.
@@ -216,7 +228,7 @@ def trace(
block_tables=input_state.block_tables,
slots=input_state.slots,
input_lengths=input_state.input_lengths,
max_s=MAX_CONTEXT_LENGTH,
max_s=max_total_tokens,
adapter_data=input_state.adapter_data,
lm_head_indices=None,
)
@@ -277,12 +289,21 @@ def __call__(self, *args, **kwargs):


class GraphCache:
def __init__(self, model: nn.Module, device: torch.device, adapter_layers: List[str]):
def __init__(
self,
model: nn.Module,
device: torch.device,
adapter_layers: List[str],
max_total_tokens: int,
sliding_window_blocks: Optional[int] = None,
):
self.model = model
self.device = device
self.adapter_layers = tuple(adapter_layers)
self.memory_pool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
self.cache = {}
self.max_total_tokens = max_total_tokens
self.sliding_window_blocks = sliding_window_blocks

def can_use_graph(
self,
@@ -303,7 +324,7 @@ def can_use_graph(
return (
torch.cuda.is_available()
and batch_size <= MAX_BATCH_SIZE
and max_s <= MAX_CONTEXT_LENGTH
and max_s <= self.max_total_tokens
and max_rank <= MAX_RANK
and nranks <= 1
and max_rank in _allowed_ranks
@@ -331,6 +352,8 @@ def get_estimated_cache_memory(self) -> int:
batch_size,
max_rank,
pool,
self.max_total_tokens,
self.sliding_window_blocks,
)
tmp_cache[key] = graph
pool = graph.memory_pool
@@ -368,6 +391,8 @@ def warmup(self):
batch_size,
max_rank,
pool,
self.max_total_tokens,
self.sliding_window_blocks,
)
self.cache[key] = graph
pool = graph.memory_pool
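Graph eligibility and tracing now key off the same per-deployment limit: can_use_graph compares max_s against self.max_total_tokens rather than the old constant, and trace() captures with max_s=max_total_tokens. A condensed restatement of that gate (simplified sketch; the other conditions are omitted):

    MAX_BATCH_SIZE = 256
    MAX_RANK = 64

    def can_replay_graph(batch_size, max_s, max_rank, max_total_tokens):
        # Mirrors the checks visible in can_use_graph above; rank-uniformity and
        # allowed-rank conditions from the diff are left out for brevity.
        return (
            batch_size <= MAX_BATCH_SIZE
            and max_s <= max_total_tokens   # was: max_s <= MAX_CONTEXT_LENGTH (8192)
            and max_rank <= MAX_RANK
        )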
