opt update_step_ctx on maca.
Reinerzhou committed Nov 29, 2024
1 parent 3913ead commit 62bbc72
Showing 1 changed file with 28 additions and 37 deletions.
lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py (65 changes: 28 additions & 37 deletions)
@@ -47,50 +47,41 @@ def update_step_context(cls, step_context):
         device = step_context.block_offsets.device

         is_unpaged_prefill = False
+        if not step_context.is_decoding:
+            is_unpaged_prefill = \
+                all((step_context.q_seqlens ==
+                     step_context.kv_seqlens).tolist())
         q_start_loc = torch.cat((torch.tensor([0], device=device),
                                  step_context.q_seqlens.cumsum(0))).int()
         q_seqlens = step_context.q_seqlens.int()
         kv_seqlens = step_context.kv_seqlens.int()
         max_q_seq_len = torch.max(q_seqlens).item()
         max_kv_seq_len = torch.max(kv_seqlens).item()

-        if not step_context.is_decoding:
-            is_unpaged_prefill = \
-                all((step_context.q_seqlens ==
-                     step_context.kv_seqlens).tolist())
-            if is_unpaged_prefill:
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(max_q_seq_len,
-                                   max_kv_seq_len,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
-        total_slots = torch.arange(block_num * block_size,
-                                   dtype=torch.long,
-                                   device=device)
-        total_slots = total_slots.view(block_num, block_size)
-        for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = int(step_context.q_seqlens[i])
-            kv_seq_len = int(step_context.kv_seqlens[i])
-            if not (step_context.is_decoding or is_unpaged_prefill):
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(step_context.q_seqlens[i],
-                                   step_context.block_offsets.shape[1] *
-                                   block_size,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=step_context.kv_seqlens[i] -
-                        step_context.q_seqlens[i],
-                    ))
-                attention_mask.append(single_attention_mask)
-            history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
-            kv_start_indices.append(slots)
-        kv_start_indices = torch.cat(kv_start_indices)
+        if step_context.is_decoding:
+            # collect kv_start_indices without using a for-loop,
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            b_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = step_context.block_offsets.gather(
+                1, b_num.view(-1, 1)).view(-1)
+            kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
+        else:
+            total_slots = torch.arange(block_num * block_size,
+                                       dtype=torch.long,
+                                       device=device)
+            total_slots = total_slots.view(block_num, block_size)
+            for i in range(step_context.q_start_loc.size(0)):
+                q_seq_len = int(step_context.q_seqlens[i])
+                kv_seq_len = int(step_context.kv_seqlens[i])
+                # collect kv start indices during the prefill phase.
+                history_length = kv_seq_len - q_seq_len
+                slot_tables = total_slots[
+                    step_context.block_offsets[i]].flatten()
+                slot_indices = [p for p in range(history_length, kv_seq_len)]
+                slots = slot_tables[slot_indices].reshape((-1, 1))
+                kv_start_indices.append(slots)
+            kv_start_indices = torch.cat(kv_start_indices)

         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
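The decoding branch added above avoids a per-sequence Python loop: for each sequence, the slot that receives the single newly decoded token is found by locating its last logical block with (kv_seqlens - 1) // block_size, gathering the physical block id from the block table, and adding the in-block offset (kv_seqlens - 1) % block_size. Below is a minimal standalone sketch, not lmdeploy code; the tensor names mirror the step_context fields used in the diff, while the shapes and values are made-up toy inputs chosen to check the vectorized result against the old loop-based slot lookup.

import torch

# toy inputs (assumed shapes: the block table is [batch, max_blocks_per_seq])
block_num, block_size = 8, 4
block_offsets = torch.tensor([[3, 5, 0],    # physical blocks owned by sequence 0
                              [1, 2, 6]])   # physical blocks owned by sequence 1
kv_seqlens = torch.tensor([6, 9])           # total kv length per sequence

# vectorized form from the commit: slot of the last (newly decoded) token
idx = (kv_seqlens - 1) % block_size                    # offset inside the last block
b_num = (kv_seqlens - 1) // block_size                 # logical index of the last block
last_block = block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))

# loop-based reference, mirroring the removed per-sequence lookup
total_slots = torch.arange(block_num * block_size).view(block_num, block_size)
reference = []
for i in range(kv_seqlens.size(0)):
    slots = total_slots[block_offsets[i]].flatten()    # all slots of sequence i, in order
    reference.append(slots[kv_seqlens[i] - 1].reshape(1, 1))
reference = torch.cat(reference)

assert torch.equal(kv_start_indices, reference)
print(kv_start_indices.view(-1).tolist())              # [21, 24] for these toy values

The real tensors live on the MACA device; the toy values stay on CPU purely to illustrate the index arithmetic.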

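For the prefill branch, a sequence may already hold history_length = kv_seq_len - q_seq_len cached tokens, so only the slots at positions history_length..kv_seq_len - 1 of its flattened block table receive new keys and values. A second toy sketch, again with assumed shapes and values rather than lmdeploy code, shows that selection for one sequence.

import torch

block_num, block_size = 8, 4
total_slots = torch.arange(block_num * block_size,
                           dtype=torch.long).view(block_num, block_size)

block_offsets = torch.tensor([[3, 5, 0]])    # block table of a single sequence
q_seq_len, kv_seq_len = 3, 6                 # 3 new tokens on top of 3 cached ones
history_length = kv_seq_len - q_seq_len      # tokens already written to the kv-cache

slot_tables = total_slots[block_offsets[0]].flatten()
slot_indices = list(range(history_length, kv_seq_len))
slots = slot_tables[slot_indices].reshape((-1, 1))
print(slots.view(-1).tolist())               # [15, 20, 21]: last slot of block 3, first two of block 5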