diff --git a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
index 084cae1bfe..61cd5fee87 100644
--- a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -47,6 +47,10 @@ def update_step_context(cls, step_context):
         device = step_context.block_offsets.device

         is_unpaged_prefill = False
+        if not step_context.is_decoding:
+            is_unpaged_prefill = \
+                all((step_context.q_seqlens ==
+                     step_context.kv_seqlens).tolist())
         q_start_loc = torch.cat((torch.tensor([0], device=device),
                                  step_context.q_seqlens.cumsum(0))).int()
         q_seqlens = step_context.q_seqlens.int()
@@ -54,43 +58,30 @@ def update_step_context(cls, step_context):
         max_q_seq_len = torch.max(q_seqlens).item()
         max_kv_seq_len = torch.max(kv_seqlens).item()

-        if not step_context.is_decoding:
-            is_unpaged_prefill = \
-                all((step_context.q_seqlens ==
-                     step_context.kv_seqlens).tolist())
-            if is_unpaged_prefill:
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(max_q_seq_len,
-                                   max_kv_seq_len,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=max_kv_seq_len - max_q_seq_len,
-                    ))
-                attention_mask.append(single_attention_mask)
-        total_slots = torch.arange(block_num * block_size,
-                                   dtype=torch.long,
-                                   device=device)
-        total_slots = total_slots.view(block_num, block_size)
-        for i in range(step_context.q_start_loc.size(0)):
-            q_seq_len = int(step_context.q_seqlens[i])
-            kv_seq_len = int(step_context.kv_seqlens[i])
-            if not (step_context.is_decoding or is_unpaged_prefill):
-                single_attention_mask = torch.logical_not(
-                    torch.tril(
-                        torch.ones(step_context.q_seqlens[i],
-                                   step_context.block_offsets.shape[1] *
-                                   block_size,
-                                   dtype=torch.bool).cuda(),
-                        diagonal=step_context.kv_seqlens[i] -
-                        step_context.q_seqlens[i],
-                    ))
-                attention_mask.append(single_attention_mask)
-            history_length = kv_seq_len - q_seq_len
-            slot_tables = total_slots[step_context.block_offsets[i]].flatten()
-            slot_indices = [p for p in range(history_length, kv_seq_len)]
-            slots = slot_tables[slot_indices].reshape((-1, 1))
-            kv_start_indices.append(slots)
-        kv_start_indices = torch.cat(kv_start_indices)
+        if step_context.is_decoding:
+            # collect kv_start_indices without using a for-loop
+            # (fill kv-cache for just ONE token during the decoding phase)
+            idx = (step_context.kv_seqlens - 1) % block_size
+            b_num = (step_context.kv_seqlens - 1) // block_size
+            last_block = step_context.block_offsets.gather(
+                1, b_num.view(-1, 1)).view(-1)
+            kv_start_indices = (last_block * block_size + idx).reshape((-1, 1))
+        else:
+            total_slots = torch.arange(block_num * block_size,
+                                       dtype=torch.long,
+                                       device=device)
+            total_slots = total_slots.view(block_num, block_size)
+            for i in range(step_context.q_start_loc.size(0)):
+                q_seq_len = int(step_context.q_seqlens[i])
+                kv_seq_len = int(step_context.kv_seqlens[i])
+                # collect kv start indices during the prefill phase.
+                history_length = kv_seq_len - q_seq_len
+                slot_tables = total_slots[
+                    step_context.block_offsets[i]].flatten()
+                slot_indices = [p for p in range(history_length, kv_seq_len)]
+                slots = slot_tables[slot_indices].reshape((-1, 1))
+                kv_start_indices.append(slots)
+            kv_start_indices = torch.cat(kv_start_indices)

         attn_meta_cls = cls.get_attention_metadata_cls()
         attn_metadata = attn_meta_cls(
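
Below is a minimal standalone sketch, not part of the patch, that checks the vectorized decode-phase slot computation against the original loop. The toy values for block_size, block_offsets, and kv_seqlens are assumptions for illustration; only those names mirror the fields of the real step_context used in the patch.

import torch

# Assumed toy setup: 4 cache blocks of 16 slots, two decoding sequences
# with kv lengths 5 and 20, and hypothetical per-sequence block tables.
block_num, block_size = 4, 16
block_offsets = torch.tensor([[2, 0], [1, 3]])
kv_seqlens = torch.tensor([5, 20])

# Loop-based computation (decode case of the removed code: q_seq_len == 1,
# so only the slot of the last kv position is collected per sequence).
total_slots = torch.arange(block_num * block_size,
                           dtype=torch.long).view(block_num, block_size)
loop_result = []
for i in range(kv_seqlens.numel()):
    kv_seq_len = int(kv_seqlens[i])
    slot_tables = total_slots[block_offsets[i]].flatten()
    loop_result.append(slot_tables[kv_seq_len - 1:kv_seq_len].reshape((-1, 1)))
loop_result = torch.cat(loop_result)

# Vectorized computation from the patch: the last block id and the in-block
# offset of the last token give the flat slot index directly.
idx = (kv_seqlens - 1) % block_size
b_num = (kv_seqlens - 1) // block_size
last_block = block_offsets.gather(1, b_num.view(-1, 1)).view(-1)
vectorized_result = (last_block * block_size + idx).reshape((-1, 1))

assert torch.equal(loop_result, vectorized_result)
print(vectorized_result.squeeze(1).tolist())  # -> [36, 51]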