
Commit

Merge branch 'infer_ext' into daoxin/support-cogvlm
pdx1989 committed Aug 22, 2024
2 parents 70fd41c + 51ec61c commit e528255
Showing 11 changed files with 48 additions and 320 deletions.
7 changes: 5 additions & 2 deletions lmdeploy/pytorch/engine/devices/ascend.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils


-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):

     device = 'ascend'

@@ -38,4 +38,7 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        is_unpaged_prefill = (not step_context.is_decoding) and all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
         return step_context
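
The added `is_unpaged_prefill` flag marks a prefill step in which no token has to be gathered back from the paged KV cache, i.e. every sequence's query length equals its full KV length. A minimal standalone sketch of the same check, using illustrative tensors in place of the real `step_context` fields:

```python
import torch

# Illustrative stand-ins for step_context fields (assumed shapes).
is_decoding = False
q_seq_length = torch.tensor([5, 7, 3])   # new tokens fed in this step, per sequence
kv_seq_length = torch.tensor([5, 7, 3])  # cached + new tokens, per sequence

# Unpaged prefill: nothing needs to be read from the block cache,
# so the dense context-attention kernel can be used directly.
is_unpaged_prefill = (not is_decoding) and all(
    (q_seq_length == kv_seq_length).tolist())
print(is_unpaged_prefill)  # True
```
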
13 changes: 0 additions & 13 deletions lmdeploy/pytorch/engine/devices/dipu.py

This file was deleted.

13 changes: 10 additions & 3 deletions lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -19,15 +19,22 @@ def apply_rotary_pos_emb(
     query_states_reshaped = query_states.reshape(1, bs, head, dim)
     key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
     if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = cos[position_ids_1d].view(1, bs, 1, -1)
-        sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        if len(cos.shape) == 3 and len(sin.shape) == 3:
+            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
+        elif len(cos.shape) == 2 and len(sin.shape) == 2:
+            cos = cos[position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        else:
+            raise RuntimeError("Cannot handle cos/sin shape dims!")
+
         if context:
             setattr(context, 'cos', cos)
             setattr(context, 'sin', sin)
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if q_embed is None:
         q_embed = query_states
     else:
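
The new branch lets the kernel accept rotary cos/sin tables stored either as 2-D `[max_seq_len, dim]` or 3-D `[1, max_seq_len, dim]` tensors, presumably because different models emit them in different layouts. A small self-contained sketch of the two gather paths (tensor shapes and the helper name are illustrative, not the kernel's actual API):

```python
import torch

bs, dim = 4, 128
position_ids_1d = torch.arange(bs)

def gather_cos_sin(cos, sin, position_ids_1d, bs):
    """Index rotary tables that may be 2-D [seq, dim] or 3-D [1, seq, dim]."""
    if len(cos.shape) == 3 and len(sin.shape) == 3:
        cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
        sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
    elif len(cos.shape) == 2 and len(sin.shape) == 2:
        cos = cos[position_ids_1d].view(1, bs, 1, -1)
        sin = sin[position_ids_1d].view(1, bs, 1, -1)
    else:
        raise RuntimeError('Cannot handle cos/sin shape dims!')
    return cos, sin

cos2d, sin2d = torch.rand(1024, dim), torch.rand(1024, dim)
cos3d, sin3d = cos2d[None], sin2d[None]
# Both layouts gather to the same (1, bs, 1, dim) result.
assert torch.equal(gather_cos_sin(cos2d, sin2d, position_ids_1d, bs)[0],
                   gather_cos_sin(cos3d, sin3d, position_ids_1d, bs)[0])
```
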
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -33,7 +33,7 @@ def fused_rotary_emb(
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if out_q is None:
         out_q = query_states
     else:
66 changes: 32 additions & 34 deletions lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -23,44 +23,40 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]
 
-    qkv_eq = query_states.shape[0] == key_states.shape[0]
-    for i in range(batch):
-        if qkv_eq:
-            ext_ops.context_attention(
-                query=query_states,
-                key=key_states,
-                value=value_states,
-                q_start_loc=q_start_loc[i:i + 1],
-                seq_len_list=q_seq_len_list[i:i + 1],
-                num_q_heads=num_q_heads,
-                num_kv_heads=num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len_list[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.context_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            q_start_loc=q_start_loc[i:i + 1],
+            seq_len_list=q_seq_len_list[i:i + 1],
+            num_q_heads=num_q_heads,
+            num_kv_heads=num_kv_heads,
+            attn_mask=context.attention_mask[i:i + 1],
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc[i:i + 1],
+            q_seq_len_list[i:i + 1],
+            kv_seq_len[i:i + 1],
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask[i:i + 1],
+            attn_output=attn_output,
+        )
 
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         query=q,
@@ -69,6 +65,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_table=block_offsets,
         block_size=block_size,
         kv_seq_len=kv_seq_len,
+        max_kv_seq_len=max_kv_seq_len,
         num_q_heads=num_q_heads,
         num_kv_heads=num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -120,6 +117,7 @@ def paged_attention_fwd(
             v,
             attn_output,
             kv_seqlens,
+            context.max_kv_seq_length,
             block_offsets,
             block_size,
         )
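
Taken together, the prefill path now dispatches once on `context.is_unpaged_prefill` instead of branching on a per-request shape check: the dense context-attention kernel runs when queries span the whole KV range, and the paged-prefill kernel (reading K/V from the block cache) runs otherwise, while the decode path additionally threads the batch-wide maximum KV length through to the kernel. A simplified, self-contained sketch of that dispatch, in which plain PyTorch attention stands in for the `ext_ops` kernels and block paging is reduced to precomputed cache tensors (all names and shapes here are illustrative):

```python
import torch

def prefill_attention(query, key, value, key_cache, value_cache,
                      is_unpaged_prefill):
    """Toy dispatcher mirroring the structure of flash_context_attention.

    The real kernels are ext_ops.context_attention (dense/unpaged) and
    ext_ops.paged_prefill_attention (reads K/V from the block cache); both
    are replaced here by scaled-dot-product attention for illustration.
    """
    if is_unpaged_prefill:
        k, v = key, value              # all K/V for this step are materialized
    else:
        k, v = key_cache, value_cache  # history must come from the cache
    scores = query @ k.transpose(-1, -2) / k.shape[-1]**0.5
    return torch.softmax(scores, dim=-1) @ v

q = torch.rand(1, 8, 5, 64)   # [batch, heads, q_len, head_dim]
kv = torch.rand(1, 8, 5, 64)
out = prefill_attention(q, kv, kv, kv, kv, is_unpaged_prefill=True)
print(out.shape)              # torch.Size([1, 8, 5, 64])
```
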
16 changes: 0 additions & 16 deletions lmdeploy/pytorch/kernels/dipu/__init__.py

This file was deleted.

29 changes: 0 additions & 29 deletions lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py

This file was deleted.

27 changes: 0 additions & 27 deletions lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py

This file was deleted.

36 changes: 0 additions & 36 deletions lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py

This file was deleted.

(2 more changed files not shown)
