Commit 1909074

Merge branch 'patch_3' into apply_plugin (vllm-project#11609)

2 parents bdc342f + 6fb8478

File tree: 10 files changed, +142 -149 lines

tests/kernels/test_attention_selector.py (+17 -14)

@@ -16,33 +16,36 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable.
-    Note that we do not test FlashAttn because it is the default backend.
-    """
+    """Test that the attention selector can be set via environment variable."""
 
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
-        assert backend.name == "TORCH_SDPA"
+        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
-        assert backend.name == "ROCM_FLASH"
+        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
-        assert backend.name == "OPENVINO"
+        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
     else:
         with patch("vllm.attention.selector.current_platform", CudaPlatform()):
             backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                         False)
-        assert backend.name == name
+        if name == "FLASHINFER":
+            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
+        if name == "XFORMERS":
+            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
+        else:
+            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
 
 
 def test_flash_attn(monkeypatch):
@@ -55,32 +58,32 @@ def test_flash_attn(monkeypatch):
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
         backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
 
     # Unsupported data type
     backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
     # Unsupported kv cache data type
     backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
     # Unsupported block size
     backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
         backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
 
     # Unsupported head size
     backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
     # Attention-free models should bypass env and use PlaceholderAttention
     backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
 
 def test_invalid_env(monkeypatch):
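
Note on the helper used above: `override_backend_env_variable` boils down to setting the backend-selection environment variable for the duration of a single test. A minimal sketch, assuming the helper does nothing beyond a `monkeypatch.setenv` call (the real test utility may carry extra bookkeeping):

import pytest

from vllm.utils import STR_BACKEND_ENV_VAR  # name of the VLLM_ATTENTION_BACKEND variable


def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # Scoped to one test; pytest restores the environment afterwards.
    monkeypatch.setenv(STR_BACKEND_ENV_VAR, backend_name)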

vllm/attention/selector.py (+14 -119)

@@ -9,7 +9,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
 
@@ -114,83 +114,32 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
+    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
+                                      block_size, is_attention_free, use_v1)
+    assert attention_cls != "", (
+        f"Invalid attention backend for {current_platform.device_name}")
+
+    return resolve_obj_by_qualname(attention_cls)
 
 
 def which_attn_to_use(head_size: int,
                       dtype: torch.dtype,
                       kv_cache_dtype: Optional[str],
                       block_size: int,
                       is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
+                      use_v1: bool = False) -> str:
     """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +150,10 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    return current_platform.get_attn_backend_cls(selected_backend, head_size,
+                                                 dtype, kv_cache_dtype,
+                                                 block_size, use_v1)
 
 
 @contextmanager
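
The key change above is that `which_attn_to_use` now returns a fully qualified class name instead of a `_Backend` enum member, and `_cached_get_attn_backend` resolves that string with `resolve_obj_by_qualname`, so a backend module is only imported when a platform actually selects it. The helper itself is not part of this diff; a minimal sketch of what such a resolver typically looks like, assuming the usual importlib approach (the in-tree `vllm.utils` implementation may differ in detail):

import importlib
from typing import Any


def resolve_obj_by_qualname(qualname: str) -> Any:
    """Resolve an object from a dotted path, e.g.
    "vllm.attention.backends.flash_attn.FlashAttentionBackend"."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)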

vllm/platforms/cpu.py (+5 -2)

@@ -28,10 +28,13 @@ def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
+        logger.info("Using Torch SDPA backend.")
+        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:

vllm/platforms/cuda.py (+77 -1)

@@ -16,7 +16,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 
-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -141,6 +141,82 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info("Using FlashInfer backend.")
+            return "vllm.attention.backends.flashinfer.FlashInferBackend"
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+        elif selected_backend == _Backend.FLASH_ATTN:
+            logger.info("Using FlashAttention backend.")
+            return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}")
+
+        target_backend = _Backend.FLASH_ATTN
+        if not cls.has_device_capability(80):
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            target_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            target_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and \
+            kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting environment variable "
+                "VLLM_ATTENTION_BACKEND=FLASHINFER")
+            target_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            target_backend = _Backend.XFORMERS
+
+        # FlashAttn is valid for the model, checking if the package is
+        # installed.
+        if target_backend == _Backend.FLASH_ATTN:
+            try:
+                import vllm.vllm_flash_attn  # noqa: F401
+                from vllm.attention.backends.flash_attn import (  # noqa: F401
+                    FlashAttentionBackend)
+
+                supported_sizes = \
+                    FlashAttentionBackend.get_supported_head_sizes()
+                if head_size not in supported_sizes:
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for head size %d.",
+                        head_size)
+                    target_backend = _Backend.XFORMERS
+            except ImportError:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend because the "
+                    "vllm.vllm_flash_attn package is not found. "
+                    "Make sure that vllm_flash_attn was built and installed "
+                    "(on by default).")
+                target_backend = _Backend.XFORMERS
+
+        if target_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+
+        logger.info("Using Flash Attention backend.")
+        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+
 
     # NVML utils
     # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
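
To see the two new pieces working together, the sketch below calls the CUDA hook directly and resolves the result, mirroring what the selector now does internally. It is illustrative only: it assumes a CUDA machine with vllm installed, and the returned path depends on the GPU generation, dtype, block size, and whether vllm_flash_attn is available.

import torch

from vllm.platforms.cuda import CudaPlatform
from vllm.utils import resolve_obj_by_qualname

# Nothing forced via VLLM_ATTENTION_BACKEND or the global setting.
qualname = CudaPlatform.get_attn_backend_cls(selected_backend=None,
                                             head_size=128,
                                             dtype=torch.float16,
                                             kv_cache_dtype=None,
                                             block_size=16,
                                             use_v1=False)
# e.g. "vllm.attention.backends.flash_attn.FlashAttentionBackend" on an
# Ampere-or-newer GPU with vllm_flash_attn installed, XFormers otherwise.
backend_cls = resolve_obj_by_qualname(qualname)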

vllm/platforms/hpu.py (+5 -2)

@@ -21,8 +21,11 @@ class HpuPlatform(Platform):
     dispatch_key: str = "HPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
-        return _Backend.HPU_ATTN
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        logger.info("Using HPUAttention backend.")
+        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:

vllm/platforms/interface.py (+5 -3)

@@ -115,9 +115,11 @@ def is_cuda_alike(self) -> bool:
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend):
-        """Get the default attention backend of a device."""
-        return None
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        """Get the attention backend class of a device."""
+        return ""
 
     @classmethod
     def get_device_capability(
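
Because the base `Platform` now exposes `get_attn_backend_cls` (returning a dotted class path, with the empty string meaning "no valid backend"), an out-of-tree platform plugin can select its own attention backend without modifying the selector. A hypothetical sketch; the plugin module, class names, and backend path below are invented placeholders, and a real plugin would use whatever `PlatformEnum` value the plugin mechanism assigns:

from typing import Optional

import torch

from vllm.platforms.interface import Platform, PlatformEnum, _Backend


class MyAcceleratorPlatform(Platform):
    # Placeholder values for illustration only.
    _enum = PlatformEnum.UNSPECIFIED
    device_name: str = "my_accel"

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        # Return a dotted path; the selector resolves it lazily via
        # resolve_obj_by_qualname, so the backend module is only imported
        # when this platform is active.
        return "my_accel_plugin.attention.MyAccelAttentionBackend"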
