Skip to content

Commit

Permalink
[Kernel] Fix input for flashinfer prefill wrapper. (#7008)
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuXiaoxuanPKU authored Aug 2, 2024
1 parent 6ce01f3 commit 954f730
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,20 @@ def begin_forward(self):
return

assert self.prefill_wrapper is not None
assert self.query_start_loc is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# The prefill stage does not read kv cache.
# Both paged_kv_indices and paged_kv_last_page_len are empty.
# paged_kv_indptr is a zero tensor with size batch_size + 1.
self.paged_kv_indptr = torch.zeros(batch_size + 1,
device=self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
Expand Down

0 comments on commit 954f730

Please sign in to comment.