diff --git a/tests/kernels/test_mha_attn.py b/tests/kernels/test_mha_attn.py
new file mode 100644
index 0000000000000..22d434f5e40ef
--- /dev/null
+++ b/tests/kernels/test_mha_attn.py
@@ -0,0 +1,126 @@
+"""
+Test:
+
+* Tests for MultiHeadAttention layer
+"""
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.selector import _Backend, _cached_get_attn_backend
+from vllm.platforms import current_platform
+from vllm.platforms.cpu import CpuPlatform
+from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.rocm import RocmPlatform
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
+@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+def test_mha_attn_platform(device: str):
+    """
+    Test the attention selector with different platforms and devices.
+    """
+    torch.set_default_dtype(torch.float16)
+
+    if device == "cpu":
+        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+            attn = MultiHeadAttention(16, 64, scale=1)
+            assert attn.attn_backend == _Backend.TORCH_SDPA
+    elif device == "hip":
+        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+            attn = MultiHeadAttention(16, 64, scale=1)
+            assert attn.attn_backend == _Backend.TORCH_SDPA
+    else:
+        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            attn = MultiHeadAttention(16, 64, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
+        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.XFORMERS
+
+
+def ref_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: float,
+) -> torch.Tensor:
+    """
+    Native implementation of scaled dot product attention without mask:
+    - query, key, value: [batch_size, seq_len, num_heads, head_size]
+    - attn_mask: [batch_size, seq_len, seq_len]
+    """
+    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
+    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.matmul(attn_weights, value).transpose(1, 2)
+    return out
+
+
+BATCH_SIZES = [1, 16]
+SEQ_LENS = [1]
+NUM_HEADS = [1, 16]
+NUM_KV_HEADS = [1]
+HEAD_SIZES = [64, 80]
+# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
+DTYPES = [
+    torch.half, torch.bfloat16, torch.float
+] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
+CUDA_DEVICES = ["cuda"]
+
+
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_mha_attn_forward(
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    device: str,
+):
+    current_platform.seed_everything(0)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    q = torch.randn(batch_size, seq_len, num_heads * head_size)
+    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
+    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
+    scale = 1.0 / head_size**0.5
+    attn = MultiHeadAttention(num_heads,
+                              head_size,
+                              scale=scale,
+                              num_kv_heads=num_kv_heads)
+    output = attn(q, k, v)
+
+    assert num_heads % num_kv_heads == 0
+    num_queries_per_kv = num_heads // num_kv_heads
+    q = q.reshape(batch_size, seq_len, num_heads, head_size)
+    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
+    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
+    if num_queries_per_kv > 1:
+        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
+        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)
+
+    ref_output = ref_attention(
+        q,
+        k,
+        v,
+        scale=scale,
+    ).reshape(batch_size, seq_len, num_heads * head_size)
+    torch.testing.assert_close(output, ref_output)
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 79ea9b666c7e8..a90bb4fbf5ab3 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -210,6 +210,9 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
@@ -217,11 +220,12 @@ def __init__(
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+            _Backend.TORCH_SDPA,
+            _Backend.XFORMERS,
+            _Backend.FLASH_ATTN,
+            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -231,7 +235,6 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+            from vllm.vllm_flash_attn import flash_attn_func
+
+            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
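For reference, a minimal usage sketch of the updated `MultiHeadAttention` layer with grouped-query attention (GQA). It assumes a CUDA device where the FlashAttention backend is selectable; the head counts, batch/sequence sizes, and dtype below are illustrative only and not part of the patch:

```python
import torch

from vllm.attention.layer import MultiHeadAttention

# Illustrative GQA configuration: 8 query heads share 2 KV heads.
torch.set_default_dtype(torch.float16)
torch.set_default_device("cuda")

head_size = 64
attn = MultiHeadAttention(8, head_size, scale=head_size**-0.5, num_kv_heads=2)

# Inputs are [batch_size, seq_len, num_heads * head_size] for q and
# [batch_size, seq_len, num_kv_heads * head_size] for k/v.
q = torch.randn(2, 4, 8 * head_size)
k = torch.randn(2, 4, 2 * head_size)
v = torch.randn(2, 4, 2 * head_size)

# The layer repeats k/v across query groups internally and returns
# [batch_size, seq_len, num_heads * head_size].
out = attn(q, k, v)
print(out.shape)  # torch.Size([2, 4, 512])
```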