From d34c4a8351d301de0e531ced07e1cb25b28ddad0 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 4 Dec 2024 21:06:53 -0800
Subject: [PATCH 1/5] tmp

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flash_attn.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d37989055c2e5..afec0a232ad9e 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -33,10 +33,10 @@ def get_kv_cache_shape(
         block_size: int,
         num_kv_heads: int,
         head_size: int,
-    ) -> Tuple[int, ...]:
+    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-        return (2, num_blocks, block_size, num_kv_heads, head_size)
+        return ((num_blocks, block_size, num_kv_heads, head_size), ) * 2
 
 
 @dataclass
@@ -106,7 +106,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: FlashAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -138,14 +138,24 @@ def forward(
             # Profiling run.
             return output
 
-        num_actual_tokens = attn_metadata.num_actual_tokens
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Reshape the input keys and values and store them in the cache.
+        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+        # not padded. However, we don't need to do key[:num_actual_tokens] and
+        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
+        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache = kv_cache[0]
         value_cache = kv_cache[1]
         torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key[:num_actual_tokens],
-            value[:num_actual_tokens],
+            key,
+            value,
             key_cache,
             value_cache,
             attn_metadata.slot_mapping,

From 14e2f777b8eb6448787f54fce732fd732e6ec74d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 4 Dec 2024 21:14:56 -0800
Subject: [PATCH 2/5] minor

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flash_attn.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index afec0a232ad9e..c85f43c6288f0 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -36,7 +36,7 @@ def get_kv_cache_shape(
     ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
-        return ((num_blocks, block_size, num_kv_heads, head_size), ) * 2
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
 
 @dataclass
@@ -106,7 +106,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -151,8 +151,7 @@ def forward(
         # not padded. However, we don't need to do key[:num_actual_tokens] and
         # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
         # the slot_mapping's shape to determine the number of actual tokens.
-        key_cache = kv_cache[0]
-        value_cache = kv_cache[1]
+        key_cache, value_cache = kv_cache.unbind(0)
         torch.ops._C_cache_ops.reshape_and_cache_flash(
             key,
             value,
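For readers following the back-and-forth in patches 1 and 2: patch 1 splits the KV cache into a tuple of two tensors, while patch 2 goes back to a single stacked tensor of shape (2, num_blocks, block_size, num_kv_heads, head_size) and obtains the two halves with a single unbind(0) call. The snippet below is a minimal sketch in plain PyTorch, not code from this PR, and the sizes are made up; it only shows that unbind(0) yields views that share storage with the stacked cache, so the split costs no copy and only one Python-level op.

import torch

# Toy sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# Single stacked KV cache, matching the shape returned by get_kv_cache_shape.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

# unbind(0) returns two views over the first dimension (no copy),
# equivalent to kv_cache[0] / kv_cache[1] but as a single PyTorch op.
key_cache, value_cache = kv_cache.unbind(0)

assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert key_cache.data_ptr() == kv_cache.data_ptr()  # shares storage with the stack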
From fc025ec948790649c68a9bb795f854aa930639f1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 8 Dec 2024 02:49:58 -0800
Subject: [PATCH 3/5] fix

Signed-off-by: Woosuk Kwon
---
 csrc/cache_kernels.cu | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 1be806bbfa43c..8e435a339d681 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -307,10 +307,16 @@ void reshape_and_cache_flash(
     torch::Tensor& key_cache,    // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& value_cache,  // [num_blocks, block_size, num_heads, head_size]
-    torch::Tensor& slot_mapping,  // [num_tokens]
+    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
     const std::string& kv_cache_dtype, const double k_scale,
     const double v_scale) {
-  int num_tokens = key.size(0);
+  // NOTE(woosuk): key.size(0) can be different from slot_mapping.size(0)
+  // because of padding. Specifically, key.size(0) is the number of tokens
+  // after padding (for CUDA graphs), while slot_mapping.size(0) can be
+  // the actual number of tokens without padding (in vLLM V1).
+  // For compatibility with both cases, we use slot_mapping.size(0) as the
+  // number of tokens.
+  int num_tokens = slot_mapping.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
   int block_size = key_cache.size(1);

From 194fa9e564e16f8ef577c009980d4d88ab6b78ef Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 8 Dec 2024 02:52:25 -0800
Subject: [PATCH 4/5] minor

Signed-off-by: Woosuk Kwon
---
 vllm/v1/attention/backends/flash_attn.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index c85f43c6288f0..251a103e60f06 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -33,7 +33,7 @@ def get_kv_cache_shape(
         block_size: int,
         num_kv_heads: int,
         head_size: int,
-    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    ) -> Tuple[int, ...]:
         if block_size % 16 != 0:
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
@@ -144,6 +144,8 @@ def forward(
         # in this method. For example, `view` and `slice` (or `[:n]`) operations
         # are surprisingly slow even in the case they do not invoke any GPU ops.
         # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Reshape the input keys and values and store them in the cache.
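Patches 3 and 4 hinge on one invariant: key and value may be padded (e.g. to a captured CUDA-graph batch size) while slot_mapping only covers the actual tokens, and the kernel takes its token count from slot_mapping. The following is a toy, pure-PyTorch reference of that behavior, written only to illustrate why slicing key[:num_actual_tokens] on the Python side is unnecessary; it is not the real reshape_and_cache_flash CUDA kernel, and the helper name and sizes are invented for the example.

import torch

def reference_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
    # num_tokens comes from slot_mapping, not from key, so padded rows of
    # key/value (beyond the actual token count) are simply never written.
    num_tokens = slot_mapping.shape[0]
    block_size = key_cache.shape[1]
    block_idx = slot_mapping // block_size
    block_off = slot_mapping % block_size
    key_cache[block_idx, block_off] = key[:num_tokens]
    value_cache[block_idx, block_off] = value[:num_tokens]

num_actual_tokens, padded_tokens = 5, 8  # e.g. padded up for a captured CUDA graph
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key = torch.randn(padded_tokens, num_kv_heads, head_size)
value = torch.randn(padded_tokens, num_kv_heads, head_size)
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
key_cache, value_cache = kv_cache.unbind(0)
slot_mapping = torch.arange(num_actual_tokens)  # only the real tokens
reference_reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)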
From 269901d5498a6bdce7ea9aa90112408fad4c5c91 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 8 Dec 2024 02:59:44 -0800
Subject: [PATCH 5/5] comment

Signed-off-by: Woosuk Kwon
---
 csrc/cache_kernels.cu | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 8e435a339d681..8a95279f9a25a 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -310,10 +310,14 @@ void reshape_and_cache_flash(
     torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
     const std::string& kv_cache_dtype, const double k_scale,
     const double v_scale) {
-  // NOTE(woosuk): key.size(0) can be different from slot_mapping.size(0)
-  // because of padding. Specifically, key.size(0) is the number of tokens
-  // after padding (for CUDA graphs), while slot_mapping.size(0) can be
-  // the actual number of tokens without padding (in vLLM V1).
+  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
+  // slot_mapping.size(0) because of padding for CUDA graphs.
+  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
+  // both include padding.
+  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
+  // since key includes padding for CUDA graphs, while slot_mapping does not.
+  // In this case, slot_mapping.size(0) represents the actual number of tokens
+  // before padding.
   // For compatibility with both cases, we use slot_mapping.size(0) as the
   // number of tokens.
   int num_tokens = slot_mapping.size(0);
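The comment rewritten in patch 5 boils down to a simple contract, restated here as a sketch (a hypothetical helper, not code from the PR or from vLLM): in V0 key and slot_mapping are both padded and equal in length; in V1 key may be padded beyond slot_mapping, so the kernel can always trust slot_mapping.size(0).

import torch

def pick_num_tokens(key: torch.Tensor, slot_mapping: torch.Tensor) -> int:
    # Illustrative helper, not part of vLLM.
    # V0: key.size(0) == slot_mapping.size(0)  (both padded)
    # V1: key.size(0) >= slot_mapping.size(0)  (only key is padded for CUDA graphs)
    assert key.size(0) >= slot_mapping.size(0)
    # Either way, the number of tokens to write is taken from slot_mapping.
    return slot_mapping.size(0)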