From 690ae5d247bf3185443ccf35e9ff0db68886fe3b Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Wed, 18 Dec 2024 16:37:06 -0800 Subject: [PATCH 1/2] [Core] Changes to support 0.2.0 flashinfer --- vllm/attention/backends/flashinfer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..22720ff40b85a 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -356,14 +356,18 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.end_forward() - self.prefill_wrapper.begin_forward( + self.prefill_wrapper.plan( self.query_start_loc, self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # pass query and kv data types. + q_data_type=self.q_data_type, + kv_data_type=self.data_type) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -379,8 +383,7 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - self.decode_wrapper.end_forward() - self.decode_wrapper.begin_forward( + self.decode_wrapper.plan( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], @@ -391,7 +394,7 @@ def begin_forward(self): # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", # kv-cache data type. - data_type=self.data_type, + kv_data_type=self.data_type, # query data type. q_data_type=self.q_data_type) @@ -863,7 +866,7 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( + prefill_output = prefill_meta.prefill_wrapper.run( query, kv_cache, logits_soft_cap=logits_soft_cap, @@ -874,7 +877,7 @@ def forward( if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.forward( + decode_output = decode_meta.decode_wrapper.run( decode_query, kv_cache, sm_scale=softmax_scale, From 5439e7d283459720cad63cb6b7e3ba1f328aa86c Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Wed, 18 Dec 2024 18:24:06 -0800 Subject: [PATCH 2/2] Revert to using begin_forward/forward because plan/run inputs have changed Signed-off-by: Pavani Majety --- vllm/attention/backends/flashinfer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 22720ff40b85a..a70bb09624b10 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -356,7 +356,7 @@ def begin_forward(self): self.block_table_bound = self.block_table_bound.to(self.device) self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.prefill_wrapper.plan( + self.prefill_wrapper.begin_forward( self.query_start_loc, self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, @@ -383,7 +383,7 @@ def begin_forward(self): self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None - self.decode_wrapper.plan( + self.decode_wrapper.begin_forward( self.paged_kv_indptr[self.num_prefills:], self.paged_kv_indices, self.paged_kv_last_page_len[self.num_prefills:], @@ -866,7 +866,7 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.run( + prefill_output = prefill_meta.prefill_wrapper.forward( query, kv_cache, logits_soft_cap=logits_soft_cap, @@ -877,7 +877,7 @@ def forward( if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None - decode_output = decode_meta.decode_wrapper.run( + decode_output = decode_meta.decode_wrapper.forward( decode_query, kv_cache, sm_scale=softmax_scale,