🐛 fix v1 attention caching issues
Signed-off-by: Joe Runde <[email protected]>
joerunde committed Nov 1, 2024
1 parent fe8f209 commit 41fa3f1
Showing 3 changed files with 36 additions and 16 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -97,4 +97,5 @@ markers = [
     "skip_global_cleanup",
     "core_model: run this model test in each PR instead of just daily",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
+    "skip_v1: do not run this test with v1",
 ]
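
This hunk only registers the skip_v1 marker name with pytest so that marker checking does not flag it as unknown; the diff does not show where marked tests are actually skipped. As a hedged sketch only, a hypothetical conftest.py hook (not part of this commit) could enforce the marker whenever the V1 engine is enabled via VLLM_USE_V1:

import os

import pytest


def pytest_collection_modifyitems(config, items):
    # Hypothetical hook: when the V1 engine is enabled, attach a skip marker
    # to every test tagged with @pytest.mark.skip_v1.
    if os.environ.get("VLLM_USE_V1") != "1":
        return
    skip_v1 = pytest.mark.skip(reason="test is not supported with VLLM_USE_V1=1")
    for item in items:
        if "skip_v1" in item.keywords:
            item.add_marker(skip_v1)
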
39 changes: 29 additions & 10 deletions vllm/attention/selector.py
@@ -88,7 +88,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend


-@lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -98,15 +97,39 @@ def get_attn_backend(
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
+    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
+    # value to be returned from the cache if the value changes between calls.
+    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
+    # private function.
+    return _cached_get_attn_backend(
+        head_size=head_size,
+        dtype=dtype,
+        kv_cache_dtype=kv_cache_dtype,
+        block_size=block_size,
+        is_attention_free=is_attention_free,
+        is_blocksparse=is_blocksparse,
+        use_v1=envs.VLLM_USE_V1,
+    )
+
+
+@lru_cache(maxsize=None)
+def _cached_get_attn_backend(
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+    is_attention_free: bool,
+    is_blocksparse: bool = False,
+    use_v1: bool = False,
+) -> Type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

     backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free)
+                                is_attention_free, use_v1)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -157,13 +180,9 @@ def get_attn_backend(
         raise ValueError("Invalid attention backend.")


-def which_attn_to_use(
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
-    block_size: int,
-    is_attention_free: bool,
-) -> _Backend:
+def which_attn_to_use(head_size: int, dtype: torch.dtype,
+                      kv_cache_dtype: Optional[str], block_size: int,
+                      is_attention_free: bool, use_v1: bool) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
@@ -220,7 +239,7 @@ def which_attn_to_use(
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH

-    if envs.VLLM_USE_V1:
+    if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1

     # FlashAttn in NVIDIA GPUs.
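
The selector change addresses the caching bug described in the new comment: @lru_cache memoizes on the function's arguments only, so a cached function that reads envs.VLLM_USE_V1 internally keeps returning whatever it computed for the flag's value on the first call. Splitting the function into an uncached wrapper that reads the flag and a cached private helper that takes it as an argument makes the flag part of the cache key. A minimal standalone sketch of the failure mode and the fix (illustrative only, not vLLM code; backend names are stand-in strings):

import os
from functools import lru_cache


@lru_cache(maxsize=None)
def pick_backend_buggy() -> str:
    # The flag is read inside the cached function, so it is not part of the
    # cache key: whatever result is computed first is returned forever.
    use_v1 = os.environ.get("VLLM_USE_V1") == "1"
    return "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"


@lru_cache(maxsize=None)
def _pick_backend_cached(use_v1: bool) -> str:
    return "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"


def pick_backend_fixed() -> str:
    # Read the flag outside the cache and pass it in, mirroring the new
    # get_attn_backend() wrapper: different flag values hit different entries.
    return _pick_backend_cached(use_v1=os.environ.get("VLLM_USE_V1") == "1")


os.environ["VLLM_USE_V1"] = "0"
assert pick_backend_buggy() == "FLASH_ATTN"
assert pick_backend_fixed() == "FLASH_ATTN"

os.environ["VLLM_USE_V1"] = "1"
assert pick_backend_buggy() == "FLASH_ATTN"           # stale cached value
assert pick_backend_fixed() == "FLASH_ATTN_VLLM_V1"   # sees the new value
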
12 changes: 6 additions & 6 deletions vllm/v1/attention/backends/flash_attn.py
@@ -134,7 +134,7 @@ def forward(
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

-        output = torch.ops.vllm.unified_flash_attention(
+        output = torch.ops.vllm.unified_v1_flash_attention(
             query,
             key,
             value,
@@ -153,7 +153,7 @@ def forward(
         return output


-def unified_flash_attention(
+def unified_v1_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -216,7 +216,7 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)


-def unified_flash_attention_fake(
+def unified_v1_flash_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -236,8 +236,8 @@ def unified_flash_attention_fake(


 direct_register_custom_op(
-    op_name="unified_flash_attention",
-    op_func=unified_flash_attention,
+    op_name="unified_v1_flash_attention",
+    op_func=unified_v1_flash_attention,
     mutates_args=["kv_cache"],
-    fake_impl=unified_flash_attention_fake,
+    fake_impl=unified_v1_flash_attention_fake,
 )
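
The other half of the fix renames the v1 custom op to unified_v1_flash_attention, presumably so it no longer clashes with the similarly named op registered by the v0 flash-attention backend once both modules load into the torch.ops.vllm namespace. A rough illustration of why the names must be distinct, using plain torch.library in a throwaway demo namespace rather than vLLM's direct_register_custom_op helper (assumption: a duplicate definition in one namespace is rejected with a RuntimeError):

from torch.library import Library

# Throwaway namespace for illustration; vLLM registers its ops under "vllm".
demo_lib = Library("demo", "FRAGMENT")
demo_lib.define("unified_flash_attention(Tensor q) -> Tensor")

try:
    # A second definition with the same name in the same namespace is rejected,
    # so two different kernels cannot both be exposed as
    # torch.ops.demo.unified_flash_attention.
    demo_lib.define("unified_flash_attention(Tensor q) -> Tensor")
except RuntimeError as err:
    print(f"duplicate op definition rejected: {err}")
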
