diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..22720ff40b85a 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -356,14 +356,18 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( + self.prefill_wrapper.plan( self.query_start_loc, self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # pass query and kv data types. + q_data_type=self.q_data_type, + kv_data_type=self.data_type) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -379,8 +383,7 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( + self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], @@ -391,7 +394,7 @@ def begin_forward(self): # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", # kv-cache data type. - data_type=self.data_type, + kv_data_type=self.data_type, # query data type. q_data_type=self.q_data_type) @@ -863,7 +866,7 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( + prefill_output = prefill_meta.prefill_wrapper.run( query, kv_cache, logits_soft_cap=logits_soft_cap, @@ -874,7 +877,7 @@ def forward( if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.forward( + decode_output = decode_meta.decode_wrapper.run( decode_query, kv_cache, sm_scale=softmax_scale,