@@ -4,7 +4,7 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -16,78 +16,75 @@
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
 def test_env(name: str, device: str, monkeypatch):
-    """Test that the attention selector can be set via environment variable."""
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
 
     override_backend_env_variable(monkeypatch, name)
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
                    OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend == "vllm.attention.backends.openvino.OpenVINOAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == "OPENVINO"
    else:
        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        if name == "FLASHINFER":
-            assert backend == "vllm.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
-        if name == "XFORMERS":
-            assert backend == "vllm.attention.backends.xformers.XFormersBackend"
-        else:
-            assert backend == "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.name == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend != "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
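
Note for reviewers: the call sites above are all this diff shows of the new entry point, so the sketch below is inferred from them rather than from the selector's source. Under that assumption, it illustrates the shape of the new API: get_attn_backend is called positionally with (head size, dtype, kv-cache dtype, block size, attention-free flag) and returns a backend object whose short .name is compared, instead of the fully qualified class path that which_attn_to_use used to return. The parameter names in the comment are guesses for illustration only.

from unittest.mock import patch

import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cpu import CpuPlatform

# Pin the platform the same way the tests do, then ask the selector which
# backend it would choose. Positional arguments mirror the call sites in this
# diff: (head_size, dtype, kv_cache_dtype, block_size, is_attention_free);
# those names are illustrative, only the values come from the tests.
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)

# Per the cpu branch of test_env above, the CPU platform resolves to the
# TORCH_SDPA backend; assertions now compare backend.name rather than a
# dotted path like "vllm.attention.backends.torch_sdpa.TorchSDPABackend".
assert backend.name == "TORCH_SDPA"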
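
The diff also does not show what override_backend_env_variable does, nor where STR_FLASH_ATTN_VAL and STR_INVALID_VAL come from; they are imported elsewhere in this file. A plausible reading, given how the tests use them, is sketched below. The environment-variable name and the nature of the constants are assumptions, not confirmed by this diff.

# Hypothetical reconstruction of the helper used throughout these tests.
# The real implementation lives in tests/kernels/utils.py; the environment
# variable name below is an assumption about how the selector is steered.
def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    # Make the selector see the requested backend for the duration of the test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)


# STR_FLASH_ATTN_VAL and STR_INVALID_VAL are presumably plain backend-name
# strings (the FlashAttention name and a deliberately unknown value), which is
# why the tests can compare them directly against backend.name.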