[Misc] Add FA2 support to ViT MHA layer (vllm-project#12355)
Signed-off-by: Isotr0py <[email protected]>
Isotr0py authored Jan 25, 2025
1 parent bf21481 commit f1fc051
Showing 2 changed files with 146 additions and 5 deletions.
126 changes: 126 additions & 0 deletions tests/kernels/test_mha_attn.py
@@ -0,0 +1,126 @@
"""
Test:
* Tests for MultiHeadAttention layer
"""
from unittest.mock import patch

import pytest
import torch

from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform


@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_mha_attn_platform(device: str):
"""
Test that the attention selector between different platform and device.
"""
torch.set_default_dtype(torch.float16)

if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN

with patch("vllm.attention.selector.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS


def ref_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """
    Native implementation of scaled dot product attention without mask:
    - query, key, value: [batch_size, seq_len, num_heads, head_size]
    - output: [batch_size, seq_len, num_heads, head_size]
    """
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    attn_weights = scale * torch.matmul(query, key.transpose(2, 3))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.matmul(attn_weights, value).transpose(1, 2)
    return out


BATCH_SIZES = [1, 16]
SEQ_LENS = [1]
NUM_HEADS = [1, 16]
NUM_KV_HEADS = [1]
HEAD_SIZES = [64, 80]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
    torch.half, torch.bfloat16, torch.float
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
CUDA_DEVICES = ["cuda"]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_mha_attn_forward(
    batch_size: int,
    seq_len: int,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
):
    current_platform.seed_everything(0)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    q = torch.randn(batch_size, seq_len, num_heads * head_size)
    k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
    scale = 1.0 / head_size**0.5
    attn = MultiHeadAttention(num_heads,
                              head_size,
                              scale=scale,
                              num_kv_heads=num_kv_heads)
    output = attn(q, k, v)

    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    q = q.reshape(batch_size, seq_len, num_heads, head_size)
    k = k.reshape(batch_size, seq_len, num_kv_heads, head_size)
    v = v.reshape(batch_size, seq_len, num_kv_heads, head_size)
    if num_queries_per_kv > 1:
        k = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
        v = torch.repeat_interleave(v, num_queries_per_kv, dim=2)

    ref_output = ref_attention(
        q,
        k,
        v,
        scale=scale,
    ).reshape(batch_size, seq_len, num_heads * head_size)
    torch.testing.assert_close(output, ref_output)
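
For orientation, here is a minimal usage sketch of the layer under test. It is not part of the commit; it assumes a CUDA build of vLLM with vllm_flash_attn available, and the ViT-like sizes are made up. With fp16 defaults and head_size 64, the new FLASH_ATTN backend is the one exercised:

import torch
from vllm.attention.layer import MultiHeadAttention

torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)

num_heads, head_size = 16, 64             # hypothetical ViT-like configuration
attn = MultiHeadAttention(num_heads, head_size, scale=head_size**-0.5)

# 2 images, 196 patch tokens each, hidden_size = num_heads * head_size = 1024
q = torch.randn(2, 196, num_heads * head_size)
k = torch.randn(2, 196, num_heads * head_size)
v = torch.randn(2, 196, num_heads * head_size)
out = attn(q, k, v)                       # [2, 196, 1024]; FA2 path on this setup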
25 changes: 20 additions & 5 deletions vllm/attention/layer.py
@@ -210,18 +210,22 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
-        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+            _Backend.TORCH_SDPA,
+            _Backend.XFORMERS,
+            _Backend.FLASH_ATTN,
+            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -231,15 +235,26 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
-        # TODO(Isotr0py): Use existing backend implementations and support FA2
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
         query = query.view(bsz, q_len, self.num_heads, self.head_size)
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if (num_repeat := self.num_queries_per_kv) > 1:
+            # Handle MQA and GQA
+            key = torch.repeat_interleave(key, num_repeat, dim=2)
+            value = torch.repeat_interleave(value, num_repeat, dim=2)
+
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+            from vllm.vllm_flash_attn import flash_attn_func
+
+            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
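As a side note on the new MQA/GQA handling: the layer now expands the key/value heads with torch.repeat_interleave before dispatching to either backend, mirroring the reference path in the test above. A small, illustrative shape check (the sizes are made up):

import torch

batch, seq, num_heads, num_kv_heads, head_size = 2, 8, 16, 4, 64
num_queries_per_kv = num_heads // num_kv_heads   # 4 query heads share each KV head

k = torch.randn(batch, seq, num_kv_heads, head_size)
# Repeat every KV head along the head dim so it lines up with its query group.
k_expanded = torch.repeat_interleave(k, num_queries_per_kv, dim=2)
assert k_expanded.shape == (batch, seq, num_heads, head_size)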
