Commit: restore accidental deletion

hliuca committed Dec 2, 2024
1 parent b6a8200 commit 9242621
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -218,6 +218,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             max_encoder_seq_len=self.max_encoder_seq_len,
             cross_slot_mapping=self.cross_slot_mapping,
             cross_block_tables=self.cross_block_tables)
+        # Batch may be composed of prefill|decodes, adjust query start indices
+        # to refer to the start of decodes when the two are split apart.
+        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
+        if self._cached_decode_metadata.query_start_loc is not None:
+            qs = self._cached_decode_metadata.query_start_loc
+            self._cached_decode_metadata.query_start_loc = qs - qs[0]
         return self._cached_decode_metadata
 
     def advance_step(self,
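
The restored lines re-base the cached query start offsets at zero once the prefill portion of the batch has been split off. A minimal standalone sketch of that arithmetic, assuming query_start_loc is a torch tensor of cumulative query start offsets as in vLLM's attention metadata (the variable names below are illustrative, not part of the commit):

```python
import torch

# Hypothetical values: a batch of 3 prefill tokens followed by 6 decode
# tokens. After the prefill/decode split, the cached decode metadata still
# holds query start offsets relative to the full batch.
query_start_loc = torch.tensor([3, 9])

# Re-base the offsets at zero (the qs - qs[0] step in the diff) so they
# index into the decode tokens only.
decode_query_start_loc = query_start_loc - query_start_loc[0]
print(decode_query_start_loc)  # tensor([0, 6])
```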
