Commit

Revert to using begin_forward/forward because plan/run inputs have changed
pavanimajety committed Dec 19, 2024
1 parent 690ae5d commit c1e4b21
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -356,7 +356,7 @@ def begin_forward(self):
             self.block_table_bound = self.block_table_bound.to(self.device)
             self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
             self.paged_kv_indices = self.paged_kv_indices.to(self.device)
-            self.prefill_wrapper.plan(
+            self.prefill_wrapper.begin_forward(
                 self.query_start_loc,
                 self.paged_kv_indptr[:self.num_prefills + 1],
                 self.paged_kv_indices,
@@ -383,7 +383,7 @@ def begin_forward(self):
             self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
 
             assert self.decode_wrapper is not None
-            self.decode_wrapper.plan(
+            self.decode_wrapper.begin_forward(
                 self.paged_kv_indptr[self.num_prefills:],
                 self.paged_kv_indices,
                 self.paged_kv_last_page_len[self.num_prefills:],
@@ -866,7 +866,7 @@ def forward(
             else:
                 assert prefill_meta is not None
                 assert prefill_meta.prefill_wrapper is not None
-                prefill_output = prefill_meta.prefill_wrapper.run(
+                prefill_output = prefill_meta.prefill_wrapper.forward(
                     query,
                     kv_cache,
                     logits_soft_cap=logits_soft_cap,
@@ -877,7 +877,7 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
             assert decode_meta.decode_wrapper is not None
-            decode_output = decode_meta.decode_wrapper.run(
+            decode_output = decode_meta.decode_wrapper.forward(
                 decode_query,
                 kv_cache,
                 sm_scale=softmax_scale,
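
Taken together, the four hunks move the metadata setup back from plan() to begin_forward() and the attention execution back from run() to forward(), since the newer plan()/run() API expects different inputs (per the commit message). Below is a minimal sketch of the execution-time call shape this commit restores, assuming the wrappers and paged-KV tensors have already been prepared via begin_forward() as in the first two hunks; the helper name and its parameter list are hypothetical and only mirror the calls visible in the last two hunks.

# Hypothetical helper (not part of this commit) mirroring the post-revert
# execution calls: per-call kwargs such as logits_soft_cap and sm_scale are
# handed straight to forward() on wrappers prepared via begin_forward().
def run_attention_old_api(prefill_wrapper, decode_wrapper, query,
                          decode_query, kv_cache, softmax_scale,
                          logits_soft_cap):
    # Prefill path, as in the third hunk above.
    prefill_output = prefill_wrapper.forward(
        query,
        kv_cache,
        logits_soft_cap=logits_soft_cap,
    )
    # Decode path, as in the fourth hunk above.
    decode_output = decode_wrapper.forward(
        decode_query,
        kv_cache,
        sm_scale=softmax_scale,
    )
    return prefill_output, decode_output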
