Commit 4195f8e

Update neuron interface
Signed-off-by: wangxiyuan <[email protected]>
1 parent bb83429

File tree (3 files changed: +119, -55 lines)


tests/kernels/test_attention_selector.py (+32 -35)
@@ -4,7 +4,7 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable."""
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
 
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "ROCM_FLASH"
    elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "OPENVINO"
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        if name == "FLASHINFER":
-            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-        if name == "XFORMERS":
-            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
-        else:
-            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
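
Note: the updated tests reflect that the public entry point now returns the backend class itself rather than a dotted qualname string. A minimal sketch of that call pattern outside pytest (argument values are illustrative; it assumes a vLLM checkout at this commit on a supported accelerator):

```python
import torch

from vllm.attention.selector import get_attn_backend

# Positional arguments, in order: head_size, dtype, kv_cache_dtype,
# block_size, is_attention_free (matching the signature shown in the
# selector diff below).
backend = get_attn_backend(16, torch.float16, None, 16, False)

# Callers inspect the returned class directly instead of comparing
# against "vllm.attention.backends..." strings.
print(backend.name)  # e.g. "FLASH_ATTN" when FlashAttention is usable
```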

vllm/attention/selector.py (+8 -19)
@@ -114,25 +114,12 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
-                                      block_size, is_attention_free, use_v1)
-    assert attention_cls != "", (
-        f"Invalid attention backend for {current_platform.device_name}")
-
-    return resolve_obj_by_qualname(attention_cls)
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> str:
-    """Returns which flash attention backend to use."""
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
@@ -151,9 +138,11 @@ def which_attn_to_use(head_size: int,
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
     # get device-specific attn_backend
-    return current_platform.get_attn_backend_cls(selected_backend, head_size,
-                                                 dtype, kv_cache_dtype,
-                                                 block_size, use_v1)
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    assert attention_cls != "", (
+        f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
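
Note: after this refactor the platform hook still returns a dotted qualname string; the selector asserts it is non-empty and imports it via resolve_obj_by_qualname. A rough sketch of that mechanism, written as a stand-in helper rather than the actual vLLM utility, with the example qualname taken from the diff above:

```python
import importlib
from typing import Any


def resolve_qualname_sketch(qualname: str) -> Any:
    """Import 'some.module.ClassName' and return the named attribute."""
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Example (assumes vLLM is installed): turn the string a platform's
# get_attn_backend_cls() returns into the backend class object.
backend_cls = resolve_qualname_sketch(
    "vllm.attention.backends.xformers.XFormersBackend")
print(backend_cls.__name__)
```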

vllm/platforms/neuron.py (+79 -1)
@@ -1,8 +1,10 @@
 from typing import TYPE_CHECKING, Optional
 
+import torch
+
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum
+from .interface import Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -43,3 +45,79 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
         return False
+
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info("Using FlashInfer backend.")
+            return "vllm.attention.backends.flashinfer.FlashInferBackend"
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info("Using FlashAttention backend.")
+            return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}")
+
+        target_backend = _Backend.FLASH_ATTN
+        if not cls.has_device_capability(80):
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            target_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            target_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and \
+            kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting environment variable "
+                "VLLM_ATTENTION_BACKEND=FLASHINFER")
+            target_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            target_backend = _Backend.XFORMERS
+
+        # FlashAttn is valid for the model, checking if the package is
+        # installed.
+        if target_backend == _Backend.FLASH_ATTN:
+            try:
+                import vllm.vllm_flash_attn  # noqa: F401
+                from vllm.attention.backends.flash_attn import (  # noqa: F401
+                    FlashAttentionBackend)
+
+                supported_sizes = \
+                    FlashAttentionBackend.get_supported_head_sizes()
+                if head_size not in supported_sizes:
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for head size %d.",
+                        head_size)
+                    target_backend = _Backend.XFORMERS
+            except ImportError:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend because the "
+                    "vllm.vllm_flash_attn package is not found. "
+                    "Make sure that vllm_flash_attn was built and installed "
+                    "(on by default).")
+                target_backend = _Backend.XFORMERS
+
+        if target_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+
+        logger.info("Using Flash Attention backend.")
+        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

0 commit comments
