Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Core] Changes to support 0.2.0 flashinfer
Browse files Browse the repository at this point in the history
pavanimajety committed Dec 19, 2024

Verified

This commit was signed with the committer’s verified signature.
pavanimajety Pavani Majety
1 parent fdea8ec commit 690ae5d
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
@@ -356,14 +356,18 @@ def begin_forward(self):
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
self.prefill_wrapper.plan(
self.query_start_loc,
self.paged_kv_indptr[:self.num_prefills + 1],
self.paged_kv_indices,
self.paged_kv_last_page_len[:self.num_prefills],
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
self.num_qo_heads,
self.num_kv_heads,
self.head_dim,
self.page_size,
# pass query and kv data types.
q_data_type=self.q_data_type,
kv_data_type=self.data_type)
if self.num_decode_tokens > 0:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
@@ -379,8 +383,7 @@ def begin_forward(self):
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
self.decode_wrapper.begin_forward(
self.decode_wrapper.plan(
self.paged_kv_indptr[self.num_prefills:],
self.paged_kv_indices,
self.paged_kv_last_page_len[self.num_prefills:],
@@ -391,7 +394,7 @@ def begin_forward(self):
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
# kv-cache data type.
data_type=self.data_type,
kv_data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)

@@ -863,7 +866,7 @@ def forward(
else:
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
prefill_output = prefill_meta.prefill_wrapper.forward(
prefill_output = prefill_meta.prefill_wrapper.run(
query,
kv_cache,
logits_soft_cap=logits_soft_cap,
@@ -874,7 +877,7 @@ def forward(
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
assert decode_meta.decode_wrapper is not None
decode_output = decode_meta.decode_wrapper.forward(
decode_output = decode_meta.decode_wrapper.run(
decode_query,
kv_cache,
sm_scale=softmax_scale,

0 comments on commit 690ae5d

Please sign in to comment.