 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

 logger = init_logger(__name__)

@@ -114,83 +114,32 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
+    attention_cls = which_attn_to_use(head_size, dtype, kv_cache_dtype,
+                                      block_size, is_attention_free, use_v1)
+    assert attention_cls != "", (
+        f"Invalid attention backend for {current_platform.device_name}")
+
+    return resolve_obj_by_qualname(attention_cls)


 def which_attn_to_use(head_size: int,
                       dtype: torch.dtype,
                       kv_cache_dtype: Optional[str],
                       block_size: int,
                       is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
+                      use_v1: bool = False) -> str:
     """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        return "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend"  # noqa: E501

     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
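Note on the new resolution step above: which_attn_to_use now returns a fully qualified class name string, and _cached_get_attn_backend turns it back into a backend class via resolve_obj_by_qualname. Below is a minimal sketch of how such a dotted-path resolver typically works; it is an illustrative assumption, not the actual body of vllm.utils.resolve_obj_by_qualname.

# Illustrative sketch only: same contract as vllm.utils.resolve_obj_by_qualname
# is assumed to have ("package.module.ClassName" in, the imported object out).
import importlib
from typing import Any


def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
    # Split "pkg.module.ClassName" into the module path and the attribute name,
    # import the module, then look the attribute up on it.
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example (requires vllm to be installed): resolve the placeholder backend
# returned for attention-free models.
# backend_cls = resolve_obj_by_qualname_sketch(
#     "vllm.attention.backends.placeholder_attn.PlaceholderAttentionBackend")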
@@ -201,64 +150,10 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)

-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    return current_platform.get_attn_backend_cls(selected_backend, head_size,
+                                                 dtype, kv_cache_dtype,
+                                                 block_size, use_v1)


 @contextmanager
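With this change, which_attn_to_use no longer encodes any per-device policy: after the override checks it forwards to current_platform.get_attn_backend_cls(...), which is expected to return a fully qualified backend class name, or an empty string when no backend fits. The sketch below shows what such a platform hook could look like for a CPU-only platform; only the hook signature is taken from the diff, while the class name CpuPlatformSketch and its body are hypothetical.

# Hypothetical sketch of the per-platform hook this diff delegates to.
from typing import Any, Optional

import torch


class CpuPlatformSketch:

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: Any, head_size: int,
                             dtype: torch.dtype,
                             kv_cache_dtype: Optional[str], block_size: int,
                             use_v1: bool) -> str:
        # On a CPU-only platform, Torch SDPA is the only applicable backend,
        # so return its qualified class name regardless of the other inputs.
        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

The practical effect is that supporting a new device type means implementing this hook on its Platform class instead of extending the if/elif chain that this diff removes from the selector.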