Skip to content

Commit

Permalink
feat: CUDAGraph compatibility of multi-level cascade inference APIs (#…
Browse files Browse the repository at this point in the history
…586)

This PR add support for CUDAGraph compatibility for
`MultiLevelCascadeAttentionWrapper`.

cc @raywanb @pavanimajety @comaniac
  • Loading branch information
yzh119 authored Nov 6, 2024
1 parent 83e541d commit 2332e8a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 13 deletions.
69 changes: 63 additions & 6 deletions python/flashinfer/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
...
>>> outputs[0].shape
torch.Size([7, 64, 128])
See Also
--------
BatchPrefillWithPagedKVCacheWrapper
"""

def __init__(
self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
self,
num_levels,
float_workspace_buffer: torch.Tensor,
kv_layout: str = "NHD",
use_cuda_graph: bool = False,
qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
) -> None:
r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
Expand All @@ -298,14 +310,59 @@ def __init__(
buffer should be the same as the device of the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
use_cuda_graph : bool
Whether to use CUDA graph to capture the kernels, if enabled, the auxiliary data structures
will be stored in provided buffers.
qo_indptr_buf_arr : Optional[List[torch.Tensor]]
An array of qo indptr buffers for each level, the array length should be equal to
the number of levels.
The last element of each tensor should be the total number of queries/outputs.
paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache indptr buffers for each level, the array length should be
equal to the number of levels.
paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache indices buffers for each level, the array length should be
equal to the number of levels.
paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
An array of paged kv-cache last page length buffers for each level, the array length
should be equal to the number of levels.
"""
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
for _ in range(num_levels)
]
self._use_cuda_graph = use_cuda_graph
if use_cuda_graph:
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(
float_workspace_buffer,
kv_layout,
use_cuda_graph=True,
qo_indptr_buf=qo_indptr_buf,
paged_kv_indptr_buf=paged_kv_indptr_buf,
paged_kv_indices_buf=paged_kv_indices_buf,
paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
)
for (
qo_indptr_buf,
paged_kv_indptr_buf,
paged_kv_indices_buf,
paged_kv_last_page_len_buf,
) in zip(
qo_indptr_buf_arr,
paged_kv_indptr_buf_arr,
paged_kv_indices_buf_arr,
paged_kv_last_page_len_buf_arr,
)
]
else:
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
for _ in range(num_levels)
]
self._num_levels = num_levels
self._kv_layout = kv_layout

@property
def is_cuda_graph_enabled(self) -> bool:
return self._use_cuda_graph

def reset_workspace_buffer(
self,
float_workspace_buffer: torch.Tensor,
Expand Down Expand Up @@ -912,7 +969,7 @@ def forward(
k_shared: torch.Tensor,
v_shared: torch.Tensor,
unique_kv_cache: torch.Tensor,
causal: bool = True,
causal: bool = False,
allow_fp16_qk_reduction: bool = False,
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
Expand Down
12 changes: 6 additions & 6 deletions python/flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def __init__(
use_cuda_graph : bool
Whether to enable CUDA graph capture for the prefill kernels, if enabled, the
auxiliary data structures will be stored as provided buffers. The ``batch_size``
auxiliary data structures will be stored in provided buffers. The ``batch_size``
cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
qo_indptr_buf : Optional[torch.Tensor]
Expand Down Expand Up @@ -1095,7 +1095,7 @@ def forward(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
k_scale: Optional[float] = None,
Expand Down Expand Up @@ -1240,7 +1240,7 @@ def forward_return_lse(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
k_scale: Optional[float] = None,
Expand Down Expand Up @@ -1491,7 +1491,7 @@ def plan(
head_dim: int,
custom_mask: Optional[torch.Tensor] = None,
packed_custom_mask: Optional[torch.Tensor] = None,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
Expand Down Expand Up @@ -1683,7 +1683,7 @@ def forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
Expand Down Expand Up @@ -1812,7 +1812,7 @@ def forward_return_lse(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
causal: bool = True,
causal: bool = False,
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
window_left: int = -1,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_shared_prefix_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def ceil_div(a, b):
@pytest.mark.parametrize("unique_kv_len", [37, 17])
@pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
@pytest.mark.parametrize("num_heads", [8, 16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_attention_with_shared_prefix_paged_kv_cache(
Expand Down

0 comments on commit 2332e8a

Please sign in to comment.