Skip to content

Commit

Permalink
refactor: optimize performance of ascend backend's update_step_context() (#2521)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiajie-yang authored Sep 26, 2024
1 parent 2c71f27 commit 0323103
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions lmdeploy/pytorch/backends/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_v_block_shape(
def update_step_context(cls, step_context):
"""update step context."""
kv_start_indices, attention_mask = [], []
_, block_size, _ = step_context.kv_caches[0][0].shape
block_num, block_size, _ = step_context.kv_caches[0][0].shape
device = step_context.block_offsets.device

is_unpaged_prefill = False
Expand All @@ -98,6 +98,10 @@ def update_step_context(cls, step_context):
diagonal=max_kv_seq_len - max_q_seq_len,
))
attention_mask.append(single_attention_mask)
total_slots = torch.arange(block_num * block_size,
dtype=torch.long,
device=device)
total_slots = total_slots.view(block_num, block_size)
for i in range(step_context.q_start_loc.size(0)):
q_seq_len = int(step_context.q_seqlens[i])
kv_seq_len = int(step_context.kv_seqlens[i])
Expand All @@ -113,17 +117,11 @@ def update_step_context(cls, step_context):
))
attention_mask.append(single_attention_mask)
history_length = kv_seq_len - q_seq_len
block_idx = history_length // block_size
block_loc = step_context.block_offsets[i][block_idx]
token_loc = history_length % block_size
for j in range(q_seq_len):
kv_start_indices.append([block_loc * block_size + token_loc])
if j == q_seq_len - 1:
break
token_loc = (token_loc + 1) % block_size
block_idx = block_idx if token_loc else block_idx + 1
block_loc = step_context.block_offsets[i][block_idx]
kv_start_indices = torch.tensor(kv_start_indices, device=device)
slot_tables = total_slots[step_context.block_offsets[i]].flatten()
slot_indices = [p for p in range(history_length, kv_seq_len)]
slots = slot_tables[slot_indices].reshape((-1, 1))
kv_start_indices.append(slots)
kv_start_indices = torch.cat(kv_start_indices)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
Expand Down

0 comments on commit 0323103

Please sign in to comment.