diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ccf8ab03a621b..91abaab78dcb8 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -133,13 +133,20 @@ def begin_forward(self):
                 return

             assert self.prefill_wrapper is not None
+            assert self.query_start_loc is not None
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
-            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            batch_size = self.query_start_loc.shape[0] - 1
+            assert batch_size >= 0
+            # The prefill stage does not read kv cache.
+            # Both paged_kv_indices and paged_kv_last_page_len are empty.
+            # paged_kv_indptr is a zero tensor with size batch_size + 1.
+            self.paged_kv_indptr = torch.zeros(batch_size + 1,
+                                               device=self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
                 self.query_start_loc, self.paged_kv_indptr,
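For context on the shape chosen above: the sketch below is a minimal illustration of why an all-zero `paged_kv_indptr` with `batch_size + 1` entries is the correct value for a prefill that reads no KV cache. It assumes FlashInfer's CSR-style paged-KV convention, where sequence `i` owns the block IDs in `paged_kv_indices[indptr[i]:indptr[i + 1]]`; the helper `pages_for_seq` is hypothetical and exists only for this example, not in vLLM or FlashInfer.

```python
import torch


def pages_for_seq(paged_kv_indptr: torch.Tensor,
                  paged_kv_indices: torch.Tensor,
                  i: int) -> torch.Tensor:
    # Hypothetical helper, assuming the CSR-style convention: sequence i
    # reads the KV-cache pages whose block IDs sit in
    # indices[indptr[i]:indptr[i + 1]].
    start = int(paged_kv_indptr[i])
    end = int(paged_kv_indptr[i + 1])
    return paged_kv_indices[start:end]


batch_size = 3

# The tensors the patch constructs for prefill: indices is empty and
# indptr is all zeros with batch_size + 1 entries, so every slice
# indices[indptr[i]:indptr[i + 1]] is empty and no KV pages are read.
paged_kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32)
paged_kv_indices = torch.empty(0, dtype=torch.int32)

for i in range(batch_size):
    assert pages_for_seq(paged_kv_indptr, paged_kv_indices, i).numel() == 0
```

Even though no pages are referenced, `indptr` still needs one entry per sequence plus one, since `begin_forward` derives the per-sequence page ranges from consecutive `indptr` entries. This is presumably why the patch sizes the zero tensor from `query_start_loc` (whose length is `batch_size + 1`) rather than reusing the old host-side `paged_kv_indptr`.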