From 54bd9a03c4b2da0fd0b0e17b0552bbb0d517a081 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 Aug 2024 22:38:56 -0700 Subject: [PATCH] register custom op for flash attn and use from torch.ops (#7536) --- .buildkite/test-pipeline.yaml | 7 ++ tests/compile/test_full_graph.py | 20 ++++ tests/kernels/test_flash_attn.py | 73 ++++++++++-- vllm/attention/backends/flash_attn.py | 155 +++++++++++++++++++++----- vllm/attention/backends/flashinfer.py | 6 +- 5 files changed, 220 insertions(+), 41 deletions(-) create mode 100644 tests/compile/test_full_graph.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8c0fc8e05a33e..264b7f58ad1ac 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -163,6 +163,13 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py +- label: torch compile integration test + source_file_dependencies: + - vllm/ + commands: + - pytest -v -s ./compile/test_full_graph.py + + - label: Vision Language Models Test # 42min mirror_hardwares: [amd] source_file_dependencies: diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py new file mode 100644 index 0000000000000..d5b59db8c7887 --- /dev/null +++ b/tests/compile/test_full_graph.py @@ -0,0 +1,20 @@ +import os + +import pytest + + +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_full_graph(model): + # make sure these models can be captured in full graph mode + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + + from vllm import LLM, SamplingParams + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model="meta-llama/Meta-Llama-3-8B") + llm.generate(prompts, sampling_params) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 779f340b6398e..870a8bf65eb92 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -2,13 +2,16 @@ import pytest import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +import vllm.attention.backends.flash_attn # noqa: F401 + +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] BLOCK_SIZES = [16, 32] DTYPES = [torch.float16, torch.bfloat16] -NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. +# one value large enough to test overflow in index calculation. 
+# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] def ref_paged_attn( @@ -72,6 +75,7 @@ def ref_paged_attn( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_flash_attn_with_paged_kv( kv_lens: List[int], @@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv( dtype: torch.dtype, block_size: int, soft_cap: Optional[float], + num_blocks: int, ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv( scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(NUM_BLOCKS, + key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, @@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv( max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, - NUM_BLOCKS, + num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, + output = torch.ops.vllm.flash_attn_with_kvcache( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, softmax_scale=scale, causal=True, block_table=block_tables, @@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) + if num_blocks <= 2048: + test_utils = ["test_faketensor", "test_schema"] + else: + test_utils = ["test_faketensor"] + + torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, + args=tuple(), + kwargs=dict( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) + ref_output = ref_paged_attn( query=query, key_cache=key_cache, @@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("sliding_window", [None]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_varlen_with_paged_kv( seq_lens: List[Tuple[int, int]], @@ -146,6 +171,7 @@ def test_varlen_with_paged_kv( dtype: torch.dtype, block_size: int, soft_cap: Optional[float], + num_blocks: int, ) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) @@ -166,7 +192,7 @@ def test_varlen_with_paged_kv( num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(NUM_BLOCKS, + key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size, @@ -181,11 +207,11 @@ def test_varlen_with_paged_kv( max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size block_tables = torch.randint(0, - NUM_BLOCKS, + num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = flash_attn_varlen_func( + output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, @@ -200,6 +226,29 @@ def test_varlen_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ) + if num_blocks <= 2048: + test_utils = ["test_faketensor", "test_schema"] + else: + test_utils = ["test_faketensor"] + + torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, + args=tuple(), + kwargs=dict( + q=query, + 
k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) + ref_output = ref_paged_attn( query=query, key_cache=key_cache, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 160bf2307fbf5..f230bb57e3177 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -18,6 +17,108 @@ if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder +from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func +from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache + + +@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Optional[List[int]] = None, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # custom op does not support tuple input + real_window_size: Tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return _flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=causal, + window_size=real_window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + block_table=block_table, + ) + + +@flash_attn_varlen_func.register_fake # type: ignore +def _( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Optional[List[int]] = None, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.empty_like(q) + + +@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) +def flash_attn_with_kvcache( + decode_query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cache_seqlens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + alibi_slopes: Optional[torch.Tensor] = None, + softcap: float = 0.0, +) -> torch.Tensor: + return _flash_attn_with_kvcache( + decode_query, + key_cache, + value_cache, + cache_seqlens=cache_seqlens, + block_table=block_table, + softmax_scale=softmax_scale, + causal=causal, + alibi_slopes=alibi_slopes, + softcap=softcap, + ) + + +@flash_attn_with_kvcache.register_fake # type: ignore +def _( + decode_query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: 
torch.Tensor, + cache_seqlens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + alibi_slopes: Optional[torch.Tensor] = None, + softcap: float = 0.0, +) -> torch.Tensor: + return torch.empty_like(decode_query) + class FlashAttentionBackend(AttentionBackend): @@ -517,7 +618,7 @@ def forward( # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - out = flash_attn_varlen_func( + out = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value, @@ -537,34 +638,36 @@ def forward( # prefix-enabled attention assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, + output[: + num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=self.logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[ + num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, softmax_scale=self.scale, causal=True, alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index fad873b448a34..3022fa70e2ca7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -4,9 +4,9 @@ try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - from vllm_flash_attn import flash_attn_varlen_func + + import vllm.attention.backends.flash_attn # noqa except ImportError: - flash_attn_varlen_func = None BatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None @@ -520,7 +520,7 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. if kv_cache is None: - output = flash_attn_varlen_func( + output = torch.ops.vllm.flash_attn_varlen_func( q=query, k=key, v=value,
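
A minimal sketch (not part of the patch itself) of the pattern this change applies to vllm::flash_attn_varlen_func and vllm::flash_attn_with_kvcache: wrap a kernel with torch.library.custom_op, register a fake (meta) implementation so the op can be traced with FakeTensors, call it through torch.ops, and validate the registration with torch.library.opcheck. The toylib::toy_attention name and its body are invented for illustration only; just the torch.library / torch.ops APIs (PyTorch >= 2.4) mirror what the patch uses.

from typing import Optional

import torch


# A toy attention-like kernel registered as a custom op. mutates_args=[]
# declares that no inputs are mutated, so torch.compile can treat the call
# as opaque and keep it inside a fully captured graph, which is what the
# patch enables for the vllm_flash_attn kernels.
@torch.library.custom_op("toylib::toy_attention", mutates_args=[])
def toy_attention(q: torch.Tensor,
                  k: torch.Tensor,
                  v: torch.Tensor,
                  softmax_scale: Optional[float] = None) -> torch.Tensor:
    scale = softmax_scale if softmax_scale is not None else q.shape[-1]**-0.5
    attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
    return attn @ v


# The fake (meta) implementation only describes the output's shape and dtype;
# it is what runs under FakeTensor tracing, mirroring the register_fake stubs
# in the patch that return torch.empty_like(q).
@toy_attention.register_fake
def _(q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
      softmax_scale: Optional[float] = None) -> torch.Tensor:
    return torch.empty_like(q)


if __name__ == "__main__":
    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 4, 8)
    v = torch.randn(2, 4, 8)

    # Call through torch.ops, as the patched backends now do.
    out = torch.ops.toylib.toy_attention(q, k, v)
    assert out.shape == q.shape

    # opcheck exercises the schema and the fake-tensor registration, the same
    # checks the updated tests request via
    # test_utils=["test_faketensor", "test_schema"].
    torch.library.opcheck(torch.ops.toylib.toy_attention,
                          args=(q, k, v),
                          kwargs=dict(softmax_scale=None),
                          test_utils=("test_schema", "test_faketensor"))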