diff --git a/lmdeploy/pytorch/engine/devices/ascend.py b/lmdeploy/pytorch/engine/devices/ascend.py
index 14017e05cb..8a0982c6eb 100644
--- a/lmdeploy/pytorch/engine/devices/ascend.py
+++ b/lmdeploy/pytorch/engine/devices/ascend.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from .dipu import DIPUDeviceUtils
+from .base_device_utils import BaseDeviceUtils
 
 
-class ASCENDDeviceUtils(DIPUDeviceUtils):
+class ASCENDDeviceUtils(BaseDeviceUtils):
 
     device = 'ascend'
 
@@ -38,4 +38,7 @@ def update_step_context(cls, step_context):
             kv_start_indices, device=step_context.block_offsets.device)
         setattr(step_context, 'kv_start_indices', kv_start_indices)
         setattr(step_context, 'attention_mask', attention_mask)
+        is_unpaged_prefill = (not step_context.is_decoding) and all(
+            (step_context.q_seq_length == step_context.kv_seq_length).tolist())
+        setattr(step_context, 'is_unpaged_prefill', is_unpaged_prefill)
         return step_context
diff --git a/lmdeploy/pytorch/engine/devices/dipu.py b/lmdeploy/pytorch/engine/devices/dipu.py
deleted file mode 100644
index d2cc9c4243..0000000000
--- a/lmdeploy/pytorch/engine/devices/dipu.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .base_device_utils import BaseDeviceUtils
-
-
-class DIPUDeviceUtils(BaseDeviceUtils):
-
-    device = 'dipu'
-
-    @classmethod
-    def update_step_context(cls, step_context):
-        """update step context."""
-        raise NotImplementedError('`update_step_context` of '
-                                  f'<{cls}> not implemented.')
diff --git a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
index ec5f669feb..acfa44a42f 100644
--- a/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
+++ b/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py
@@ -19,15 +19,22 @@ def apply_rotary_pos_emb(
     query_states_reshaped = query_states.reshape(1, bs, head, dim)
     key_states_reshaped = key_states.reshape(1, bs, num_kv_heads, dim)
     if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = cos[position_ids_1d].view(1, bs, 1, -1)
-        sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        if len(cos.shape) == 3 and len(sin.shape) == 3:
+            cos = cos[:, position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[:, position_ids_1d].view(1, bs, 1, -1)
+        elif len(cos.shape) == 2 and len(sin.shape) == 2:
+            cos = cos[position_ids_1d].view(1, bs, 1, -1)
+            sin = sin[position_ids_1d].view(1, bs, 1, -1)
+        else:
+            raise RuntimeError("Cannot handle cos/sin shape dims!")
+
         if context:
             setattr(context, 'cos', cos)
             setattr(context, 'sin', sin)
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if q_embed is None:
         q_embed = query_states
     else:
diff --git a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
index 01346bfb58..ed5e833d7c 100644
--- a/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
+++ b/lmdeploy/pytorch/kernels/ascend/fused_rotary_emb.py
@@ -33,7 +33,7 @@ def fused_rotary_emb(
     cached_cos = context.cos if context else cos
     cached_sin = context.sin if context else sin
     ext_ops.apply_rotary_pos_emb(query_states_reshaped, key_states_reshaped,
-                                 cached_cos, cached_sin, None, None, None)
+                                 cached_cos, cached_sin, None, None)
     if out_q is None:
         out_q = query_states
     else:
diff --git a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
index 42cd24cd7f..0e25a7de4d 100644
--- a/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
+++ b/lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py
@@ -23,44 +23,40 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
 
-    batch = q_start_loc.shape[0]
-    qkv_eq = query_states.shape[0] == key_states.shape[0]
-    for i in range(batch):
-        if qkv_eq:
-            ext_ops.context_attention(
-                query=query_states,
-                key=key_states,
-                value=value_states,
-                q_start_loc=q_start_loc[i:i + 1],
-                seq_len_list=q_seq_len_list[i:i + 1],
-                num_q_heads=num_q_heads,
-                num_kv_heads=num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len_list[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.context_attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            q_start_loc=q_start_loc,
+            seq_len_list=q_seq_len_list,
+            num_q_heads=num_q_heads,
+            num_kv_heads=num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len_list,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
 
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         query=q,
@@ -69,6 +65,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_table=block_offsets,
         block_size=block_size,
         kv_seq_len=kv_seq_len,
+        max_kv_seq_len=max_kv_seq_len,
         num_q_heads=num_q_heads,
         num_kv_heads=num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -120,6 +117,7 @@ def paged_attention_fwd(
             v,
             attn_output,
             kv_seqlens,
+            context.max_kv_seq_length,
             block_offsets,
             block_size,
         )
diff --git a/lmdeploy/pytorch/kernels/dipu/__init__.py b/lmdeploy/pytorch/kernels/dipu/__init__.py
deleted file mode 100644
index 65ebc8cec1..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from ..default import multinomial_sampling
-from .apply_rotary_pos_emb import apply_rotary_pos_emb
-from .fill_kv_cache import fill_kv_cache
-from .fused_rotary_emb import fused_rotary_emb
-from .pagedattention import paged_attention_fwd
-from .rms_norm import rms_norm
-
-__all__ = [
-    'rms_norm',
-    'apply_rotary_pos_emb',
-    'fused_rotary_emb',
-    'fill_kv_cache',
-    'paged_attention_fwd',
-    'multinomial_sampling',
-]
diff --git a/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py
deleted file mode 100644
index ab6a3e0cdc..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/apply_rotary_pos_emb.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-from torch import Tensor
-
-
-def apply_rotary_pos_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    cos: Tensor,
-    sin: Tensor,
-    position_ids: Tensor,
-    position_ids_1d: Tensor,
-    q_embed=None,
-    k_embed=None,
-    context=None,
-):
-    bs, head, dim = query_states.shape
-    numKeyValueHeads = key_states.shape[1]
-    query_states = query_states.reshape(bs, head * dim)
-    key_states = key_states.reshape(bs, numKeyValueHeads * dim)
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = cos[position_ids_1d].view(1, bs, 1, -1)
-        sin = sin[position_ids_1d].view(1, bs, 1, -1)
-        setattr(context, 'cos', cos)
-        setattr(context, 'sin', sin)
-    ext.rotary_embedding_v2(query_states, key_states, context.cos, context.sin,
-                            dim)
-    return query_states.view(bs, head,
-                             dim), key_states.view(bs, numKeyValueHeads, dim)
diff --git a/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py b/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py
deleted file mode 100644
index f51b851185..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/fill_kv_cache.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-from torch import Tensor
-
-
-def fill_kv_cache(
-    key_states: Tensor,
-    value_states: Tensor,
-    key_caches: Tensor,
-    value_caches: Tensor,
-    q_start_loc: Tensor,
-    q_seq_length: Tensor,
-    kv_seq_length: Tensor,
-    max_q_seq_length: int,
-    block_offsets: Tensor,
-    context: None,
-):
-    """fill key/value state to cache for paged attention."""
-    dest_index_copy_kv(key_states, context.kv_start_indices, key_caches)
-    dest_index_copy_kv(value_states, context.kv_start_indices, value_caches)
-
-
-def dest_index_copy_kv(states, dest_loc, caches):
-    block_num, block_size, head, dim = caches.size()
-    caches_tmp = caches.view(block_num * block_size, head, dim)
-    ext.dest_index_copy_kv(states, dest_loc, caches_tmp)
-    caches[:] = caches_tmp.view(block_num, block_size, head, dim)
diff --git a/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py b/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py
deleted file mode 100644
index 2a67a24516..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/fused_rotary_emb.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def fused_rotary_emb(
-    query_states: Tensor,
-    key_states: Tensor,
-    position_ids: torch.LongTensor,
-    inv_freq: Tensor,
-    scaling_factor: float,
-    out_q: Tensor = None,
-    out_k: Tensor = None,
-    context=None,
-):
-    _, bs, head, dim = query_states.shape
-    _, _, numKeyValueHeads, _ = key_states.shape
-    query_states = query_states.view(bs, head * dim)
-    key_states = key_states.view(bs, numKeyValueHeads * dim)
-    position_ids = position_ids.squeeze(0).unsqueeze(-1)
-    pos_freq = position_ids / scaling_factor * inv_freq
-    if not (hasattr(context, 'cos') or hasattr(context, 'sin')):
-        cos = (torch.cos(pos_freq).view(position_ids.shape[0], 1,
-                                        -1).repeat(1, 1,
-                                                   2).to(query_states.dtype))
-        sin = (torch.sin(pos_freq).view(position_ids.shape[0], 1,
-                                        -1).repeat(1, 1,
-                                                   2).to(query_states.dtype))
-        setattr(context, 'cos', cos)
-        setattr(context, 'sin', sin)
-    ext.rotary_embedding_v2(query_states, key_states, context.cos, context.sin,
-                            dim)
-    query_states = query_states.view(1, bs, head, dim)
-    key_states = key_states.view(1, bs, numKeyValueHeads, dim)
-    return query_states, key_states
diff --git a/lmdeploy/pytorch/kernels/dipu/pagedattention.py b/lmdeploy/pytorch/kernels/dipu/pagedattention.py
deleted file mode 100644
index 9304ec0a35..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/pagedattention.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def flash_context_attention(
-    query_states: Tensor,
-    key_states: Tensor,
-    value_states: Tensor,
-    attn_output: Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seqlens: list,
-    kv_seqlens: list,
-    block_size: int,
-    kv_cache_len: int,
-    context=None,
-):
-    batch, head, dim = (
-        q_start_loc.shape[0],
-        query_states.shape[1],
-        query_states.shape[2],
-    )
-    numKeyValueHeads = value_states.shape[1]
-    assert key_states.shape[1] == value_states.shape[1]
-    for i in range(batch):
-        start = q_start_loc[i]
-        end = start + q_seqlens[i]
-        single_seqlen = int(end - start)
-        single_q = query_states[start:end].view(1, single_seqlen, -1)
-        single_k = key_states[start:end].reshape(1, single_seqlen, -1)
-        single_v = value_states[start:end].reshape(1, single_seqlen, -1)
-        single_out = attn_output[start:end, :].view(1, single_seqlen, -1)
-        mask = context.attention_mask[i]
-        if q_seqlens[i] == kv_seqlens[i]:
-            ext.prompt_flash_attention(
-                single_out,
-                single_q,
-                single_k,
-                single_v,
-                mask,
-                [kv_seqlens[i]],
-                kv_seqlens[i],
-                head,
-                numKeyValueHeads,
-                dim,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len,
-                                          numKeyValueHeads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              numKeyValueHeads * dim)
-            for j in range(q_seqlens[i]):
-                single_q = query_states[start + j:start + j + 1].view(1, 1, -1)
-                single_out = attn_output[start + j:start + j + 1].view(
-                    1, 1, -1)
-                ext.paged_attention(
-                    single_out,
-                    single_q,
-                    key_cache,
-                    value_cache,
-                    mask[j:j + 1],
-                    [kv_seqlens[i]],
-                    head,
-                    numKeyValueHeads,
-                    dim,
-                    block_offsets[i:i + 1],
-                    block_size,
-                )
-
-
-def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seqlens,
-                          block_table, block_size):
-    numKeyValueHeads = k_cache.shape[1]
-    assert k_cache.shape[1] == v_cache.shape[1]
-    bs, head, dim = q.shape
-    kv_cache_len = k_cache.shape[0]
-    q = q.reshape(bs, 1, head * dim)
-    k_cache = k_cache.reshape(1, kv_cache_len, numKeyValueHeads * dim)
-    v_cache = v_cache.reshape(1, kv_cache_len, numKeyValueHeads * dim)
-    ext.paged_attention(
-        attn_output.view(q.shape),
-        q,
-        k_cache,
-        v_cache,
-        None,
-        kv_seqlens,
-        head,
-        numKeyValueHeads,
-        dim,
-        block_table,
-        block_size,
-    )
-
-
-def paged_attention_fwd(
-    query_states: Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    key_cache: Tensor,
-    value_cache: Tensor,
-    attn_output: Tensor,
-    block_offsets: Tensor,
-    q_start_loc: Tensor,
-    q_seqlens: Tensor,
-    kv_seqlens: Tensor,
-    max_seqlen: int,
-    window_size: int = 1,
-    context=None,
-):
-    is_decoding = query_states.shape[-3] == q_seqlens.size(0)
-    block_num, block_size, head, dim = key_cache.size()
-    kv_cache_len = block_num * block_size
-    k = key_cache.reshape(block_num * block_size, head, dim)
-    v = value_cache.reshape(block_num * block_size, head, dim)
-    if not is_decoding:
-        flash_context_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_output,
-            k,
-            v,
-            block_offsets.to(torch.int32),
-            q_start_loc,
-            q_seqlens.tolist(),
-            kv_seqlens.tolist(),
-            block_size,
-            kv_cache_len,
-            context=context,
-        )
-    else:
-        paged_token_attention(
-            query_states,
-            k,
-            v,
-            attn_output,
-            kv_seqlens.tolist(),
-            block_offsets.to(torch.int32),
-            block_size,
-        )
diff --git a/lmdeploy/pytorch/kernels/dipu/rms_norm.py b/lmdeploy/pytorch/kernels/dipu/rms_norm.py
deleted file mode 100644
index 8dbcf91ca2..0000000000
--- a/lmdeploy/pytorch/kernels/dipu/rms_norm.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import deeplink_ext.cpp_extensions as ext
-import torch
-from torch import Tensor
-
-
-def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
-    output = torch.empty_like(hidden_states)
-    inv_rms_shape = list(hidden_states.shape[:-1]) + [1]
-    inv_rms = torch.empty(inv_rms_shape,
-                          dtype=torch.float32,
-                          device=hidden_states.device)
-    ext.rms_norm(output, inv_rms, hidden_states, weight.shape, weight, None,
-                 eps)
-    return output