diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index be206b2d..8efe2168 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -197,7 +197,7 @@ def append_paged_kv_cache( append_value: torch.Tensor, batch_indices: torch.Tensor, positions: torch.Tensor, - paged_kv_cache: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor,