[Kernel] Fix input for flashinfer prefill wrapper. #7008

Merged: 9 commits, Aug 2, 2024
11 changes: 9 additions & 2 deletions vllm/attention/backends/flashinfer.py
@@ -135,13 +135,20 @@ def begin_forward(self):
                 return
 
             assert self.prefill_wrapper is not None
+            assert self.query_start_loc is not None
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
-            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
(Review comment on the deleted line above) Per flashinfer-ai/flashinfer#362 (comment), perhaps we should keep this?

-            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            batch_size = self.query_start_loc.shape[0] - 1
+            assert batch_size >= 0
+            # The prefill stage does not read kv cache.
+            # Both paged_kv_indices and paged_kv_last_page_len are empty.
+            # paged_kv_indptr is a zero tensor with size batch_size + 1.
+            self.paged_kv_indptr = torch.zeros(batch_size + 1,
+                                               device=self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
                 self.query_start_loc, self.paged_kv_indptr,
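For context, here is a minimal sketch of the inputs this change hands to the prefill wrapper for a pure prefill batch (no prefix caching). The batch of two requests with 5 and 3 query tokens, the CPU device, and the explicit int32 dtypes are illustrative assumptions, not taken from the diff:

```python
import torch

# Hypothetical prefill batch: two requests with 5 and 3 query tokens.
device = torch.device("cpu")  # vLLM uses its GPU device; CPU keeps this runnable anywhere
query_start_loc = torch.tensor([0, 5, 8], dtype=torch.int32, device=device)

batch_size = query_start_loc.shape[0] - 1  # 2 requests
assert batch_size >= 0

# Prefill does not read the paged KV cache: request i would span pages
# paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]], so empty
# page lists plus an all-zeros indptr encode "zero cached pages per request".
paged_kv_indices = torch.empty(0, dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.empty(0, dtype=torch.int32, device=device)
paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)

print(paged_kv_indptr)  # tensor([0, 0, 0], dtype=torch.int32)
```

Since prefill reads no KV pages, `paged_kv_indptr[i + 1] - paged_kv_indptr[i] == 0` for every request, which is exactly what the all-zeros tensor of size `batch_size + 1` expresses. The sketch pins `dtype=torch.int32` explicitly because flashinfer's paged-KV index tensors are integer-typed, whereas the diff's `torch.zeros` call leaves the dtype at PyTorch's float32 default.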