From 8dde898e64fd95c54b72c04a62bd1cc0bba80116 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 7 Oct 2023 16:50:22 +0800 Subject: [PATCH 1/7] optimize fill kv cache --- lmdeploy/pytorch_poc/kernels/__init__.py | 7 +- lmdeploy/pytorch_poc/kernels/fill_kv_cache.py | 165 ++++++++++++++++++ lmdeploy/pytorch_poc/patch/baichuan.py | 2 + lmdeploy/pytorch_poc/patch/functional.py | 92 ++-------- lmdeploy/pytorch_poc/patch/llama.py | 1 + 5 files changed, 184 insertions(+), 83 deletions(-) create mode 100644 lmdeploy/pytorch_poc/kernels/fill_kv_cache.py diff --git a/lmdeploy/pytorch_poc/kernels/__init__.py b/lmdeploy/pytorch_poc/kernels/__init__.py index 5b24fd7e0e..afa093d12e 100644 --- a/lmdeploy/pytorch_poc/kernels/__init__.py +++ b/lmdeploy/pytorch_poc/kernels/__init__.py @@ -3,10 +3,9 @@ from .context_biased_pagedattention import biased_paged_attention_fwd from .context_flashattention_nopad import context_attention_fwd from .context_pagedattention import paged_attention_fwd +from .fill_kv_cache import fill_kv_cache __all__ = [ - 'context_attention_fwd', - 'paged_attention_fwd', - 'biased_paged_attention_fwd', - 'alibi_paged_attention_fwd', + 'context_attention_fwd', 'paged_attention_fwd', + 'biased_paged_attention_fwd', 'alibi_paged_attention_fwd', 'fill_kv_cache' ] diff --git a/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py new file mode 100644 index 0000000000..d1ab5f958d --- /dev/null +++ b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Sequence + +import torch +import triton +import triton.language as tl +from torch import Tensor + + +@triton.jit +def _fill_kv_cache_kernel( + k_states, + v_states, + k_caches, + v_caches, + state_start, + state_len, + cache_start, + block_offsets1d, + stride_kss, # stride of key state token + stride_vss, # stride of value state token + stride_kcs: tl.constexpr, # stride of key cache token + stride_vcs: tl.constexpr, # stride of value cache token + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + prog_id = tl.program_id(0) + + stride_kb = stride_kcs * BLOCK_M + stride_vb = stride_vcs * BLOCK_M + + sstart = tl.load(state_start + prog_id) + slen = tl.load(state_len + prog_id) + cstart = tl.load(cache_start + prog_id) + boffset = tl.load(block_offsets1d + prog_id) + + off_m = tl.arange(0, BLOCK_M) + off_n = tl.arange(0, BLOCK_N) + + ks_ptrs = k_states + (sstart + + off_m[:, None]) * stride_kss + off_n[None, :] + vs_ptrs = v_states + (sstart + + off_m[:, None]) * stride_vss + off_n[None, :] + kc_ptrs = k_caches + boffset * stride_kb + ( + cstart + off_m[:, None]) * stride_kcs + off_n[None, :] + vc_ptrs = v_caches + boffset * stride_vb + ( + cstart + off_m[:, None]) * stride_vcs + off_n[None, :] + + mask = off_m[:, None] < slen + + for idx in range(0, stride_kcs, BLOCK_N): + ks = tl.load(ks_ptrs + idx, mask=mask) + tl.store(kc_ptrs + idx, ks, mask=mask) + + for idx in range(0, stride_vcs, BLOCK_N): + vs = tl.load(vs_ptrs + idx, mask=mask) + tl.store(vc_ptrs + idx, vs, mask=mask) + + +def fill_kv_cache(k_states: Tensor, + v_states: Tensor, + k_caches: Tensor, + v_caches: Tensor, + start_loc: Tensor, + seq_length: Tensor, + block_offsets: Tensor, + history_lengths: Sequence, + context: Any = None): + """fill kv cache for paged attention.""" + fill_cache_info = getattr(context, 'fill_cache_info', None) + + if fill_cache_info is None: + batch_size = block_offsets.size(0) + block_size = k_caches.size(1) + + if not 
isinstance(history_lengths, Tensor): + history_lengths = torch.tensor(history_lengths, + device=k_states.device) + + batch_ids = torch.arange(batch_size, device=k_states.device) + + first_block_ids = history_lengths // block_size + block_offsets1d = block_offsets[batch_ids, first_block_ids] + + token_ids_start = history_lengths % block_size + first_seq_len = torch.minimum(seq_length, block_size - token_ids_start) + + state_start = start_loc[:batch_size] + state_len = first_seq_len + cache_start = token_ids_start + + # middle + last = remain + remain_seq_len = torch.maximum(seq_length.new_zeros(1), + seq_length - first_seq_len) + last_seq_len = remain_seq_len % block_size + middle_seq_len = remain_seq_len - last_seq_len + middle_block_nums = middle_seq_len // block_size + remain_block_nums = (remain_seq_len / block_size).ceil().long() + + remain_state_start = [ + ss + slen + + torch.arange(0, rlen, block_size, device=k_states.device) + for ss, slen, rlen in zip(state_start, first_seq_len, + remain_seq_len) + ] + remain_seq_lens = [ + torch.full((mid, ), block_size, device=k_states.device) + for mid in middle_block_nums + ] + remain_seq_lens = [ + torch.cat([slen, last]) + for slen, last in zip(remain_seq_lens, last_seq_len.unsqueeze(-1)) + ] + remain_block_offsets1d = [ + block_offsets[bid, ids:ids + ids_len] + for bid, ids, ids_len in zip(range(batch_size), first_block_ids + + 1, remain_block_nums) + ] + + # state_start store the state index of the block + # state_len store the length to write in the block + # cache_start store the first index the write in block + # block_offsets1d store the index of block in caches + state_start = torch.cat([state_start] + remain_state_start) + state_len = torch.cat([state_len] + remain_seq_lens) + cache_start = torch.cat( + [cache_start] + + [state_start.new_zeros(state_start.size(0) - batch_size)]) + block_offsets1d = torch.cat([block_offsets1d] + remain_block_offsets1d) + + if context is not None: + fill_cache_info = dict() + fill_cache_info['state_start'] = state_start + fill_cache_info['state_len'] = state_len + fill_cache_info['cache_start'] = cache_start + fill_cache_info['block_offsets1d'] = block_offsets1d + context.fill_cache_info = fill_cache_info + else: + state_start = fill_cache_info['state_start'] + state_len = fill_cache_info['state_len'] + cache_start = fill_cache_info['cache_start'] + block_offsets1d = fill_cache_info['block_offsets1d'] + + grid = (state_start.size(0), ) + BLOCK_M = k_caches.size(-3) + BLOCK_N = min(128, k_caches.stride(-3), v_caches.stride(-3)) + + _fill_kv_cache_kernel[grid]( + k_states, + v_states, + k_caches, + v_caches, + state_start=state_start, + state_len=state_len, + cache_start=cache_start, + block_offsets1d=block_offsets1d, + stride_kss=k_states.stride(-3), + stride_vss=v_states.stride(-3), + stride_kcs=k_caches.stride(-3), + stride_vcs=v_caches.stride(-3), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=4, + num_stages=1, + ) diff --git a/lmdeploy/pytorch_poc/patch/baichuan.py b/lmdeploy/pytorch_poc/patch/baichuan.py index 8471e5f3c7..a5efaafd8b 100644 --- a/lmdeploy/pytorch_poc/patch/baichuan.py +++ b/lmdeploy/pytorch_poc/patch/baichuan.py @@ -120,6 +120,7 @@ def _rotary_emb_fn(query_states, key_states, value_states): head_dim=self.head_dim, position_ids=position_ids, past_key_value=past_key_value, + context=context, qkv_proj=_qkv_proj, o_proj=self.o_proj, rotary_emb_fn=_rotary_emb_fn, @@ -205,6 +206,7 @@ def _qkv_proj(hidden_states): head_dim=self.head_dim, position_ids=position_ids, 
past_key_value=past_key_value, + context=context, qkv_proj=_qkv_proj, o_proj=self.o_proj, rotary_emb_fn=_rotary_emb_fn, diff --git a/lmdeploy/pytorch_poc/patch/functional.py b/lmdeploy/pytorch_poc/patch/functional.py index 5d6bba8323..7b411810af 100644 --- a/lmdeploy/pytorch_poc/patch/functional.py +++ b/lmdeploy/pytorch_poc/patch/functional.py @@ -1,13 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import math -from typing import Callable, Optional, Sequence, Tuple +from typing import Any, Callable, Optional, Sequence, Tuple import torch from torch import Tensor from torch import distributed as dist from lmdeploy.pytorch_poc.kernels import (alibi_paged_attention_fwd, - paged_attention_fwd) + fill_kv_cache, paged_attention_fwd) def rotate_half(x: Tensor): @@ -50,73 +50,6 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, return q_embed, k_embed -def fill_kv_cache( - k_states: Tensor, - v_states: Tensor, - k_caches: Tensor, - v_caches: Tensor, - start_loc: Tensor, - seq_length: Tensor, - block_offsets: Tensor, - history_lengths: Sequence, -): - """Fill key/value cache with current key value states. - - Paged attention choose cache block by block tables. New key/value should be - filled into the cache blocks indicated by block tables. - - Args: - k_states (Tensor): key states - v_states (Tensor): value states - k_caches (Tensor): key caches - v_caches (Tensor): value caches - start_loc (Tensor): state location of each data in batch - seq_length (Tensor): sequence length of each data in batch - block_offsets (Tensor): block table of blocks in key/value caches. - history_lengths (Sequence): Cache length in k_caches/v_caches. - Does not include data in k_states/v_states - """ - block_size = k_caches.size(1) - - history_lengths = torch.tensor(history_lengths) - first_free_block_offsets = history_lengths // block_size - first_token_offsets = history_lengths % block_size - - for bid in range(len(history_lengths)): - loc = start_loc[bid] - seq_len = seq_length[bid] - b_offsets = block_offsets[bid] - free_offset = first_free_block_offsets[bid] - token_offset = first_token_offsets[bid] - - k_state = k_states[loc:loc + seq_len] - v_state = v_states[loc:loc + seq_len] - - # fill remain(last non-full block) - block_id = b_offsets[free_offset] - fill_token_num = min(block_size - token_offset, seq_len) - k_caches[block_id][token_offset:token_offset + - fill_token_num] = k_state[:fill_token_num] - v_caches[block_id][token_offset:token_offset + - fill_token_num] = v_state[:fill_token_num] - - # update offset - seq_len = seq_len - fill_token_num - free_offset += 1 - k_state = k_state[fill_token_num:] - v_state = v_state[fill_token_num:] - - for seq_offset in range(0, seq_len, block_size): - token_num = min(seq_len - seq_offset, block_size) - block_id = b_offsets[free_offset] - k_caches[block_id][:token_num] = k_state[:token_num] - v_caches[block_id][:token_num] = v_state[:token_num] - - free_offset += 1 - k_state = k_state[token_num:] - v_state = v_state[token_num:] - - def generate_batched_mask(q_lens, k_lens, max_q_len: int = None, @@ -177,6 +110,7 @@ def attention_forward_with_paged_attention( head_dim: int, position_ids: torch.LongTensor, past_key_value: Tuple[Tensor], + context: Any = None, q_proj: Optional[Callable] = None, k_proj: Optional[Callable] = None, v_proj: Optional[Callable] = None, @@ -233,16 +167,16 @@ def attention_forward_with_paged_attention( q_seq_length = kv_seq_length - kv_seq_length.new_tensor(history_lengths) q_start_loc = q_seq_length.cumsum(0) 
q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]]) - fill_kv_cache( - key_states, - value_states, - past_key_value[0], - past_key_value[1], - q_start_loc, - q_seq_length, - block_offsets=block_offsets, - history_lengths=history_lengths, - ) + + fill_kv_cache(key_states, + value_states, + past_key_value[0], + past_key_value[1], + q_start_loc, + q_seq_length, + block_offsets=block_offsets, + history_lengths=history_lengths, + context=context) attn_output = torch.empty_like(query_states) block_size = past_key_value[0].size(1) diff --git a/lmdeploy/pytorch_poc/patch/llama.py b/lmdeploy/pytorch_poc/patch/llama.py index b4cd93e0ff..6c701fe5a7 100644 --- a/lmdeploy/pytorch_poc/patch/llama.py +++ b/lmdeploy/pytorch_poc/patch/llama.py @@ -70,6 +70,7 @@ def _rotary_emb_fn(query_states, key_states, value_states): head_dim=self.head_dim, position_ids=position_ids, past_key_value=past_key_value, + context=context, q_proj=self.q_proj, k_proj=self.k_proj, v_proj=self.v_proj, From 6edc78cf5596878b3ae080c2b886fb141535925b Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 7 Oct 2023 19:20:53 +0800 Subject: [PATCH 2/7] update internlm --- lmdeploy/pytorch_poc/patch/internlm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmdeploy/pytorch_poc/patch/internlm.py b/lmdeploy/pytorch_poc/patch/internlm.py index f84fb7d212..3bc976215e 100644 --- a/lmdeploy/pytorch_poc/patch/internlm.py +++ b/lmdeploy/pytorch_poc/patch/internlm.py @@ -68,6 +68,7 @@ def _rotary_emb_fn(query_states, key_states, value_states): head_dim=self.head_dim, position_ids=position_ids, past_key_value=past_key_value, + context=context, q_proj=self.q_proj, k_proj=self.k_proj, v_proj=self.v_proj, From 996ac70a3ed076ec298c289a621958fe66b4d4c5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 7 Oct 2023 20:33:33 +0800 Subject: [PATCH 3/7] faster embedding --- lmdeploy/pytorch_poc/engine/engine.py | 8 ++++++++ lmdeploy/pytorch_poc/patch/functional.py | 23 ++++++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch_poc/engine/engine.py b/lmdeploy/pytorch_poc/engine/engine.py index 0ae4b51ec1..a175355ad9 100644 --- a/lmdeploy/pytorch_poc/engine/engine.py +++ b/lmdeploy/pytorch_poc/engine/engine.py @@ -106,12 +106,16 @@ def __init__( block_offsets: List[List[int]], history_lengths: List[int], position_ids: torch.Tensor, + q_start_loc: torch.Tensor, + seq_length: torch.Tensor, world_size: int = 1, device='cuda', ): self.block_offsets_list = block_offsets self.history_lengths = history_lengths self.position_ids = position_ids + self.q_start_loc = q_start_loc + self.seq_length = seq_length self.world_size = world_size # padding zero @@ -374,6 +378,8 @@ def _tp_model_loop( block_offsets=inputs['block_offsets'], history_lengths=inputs['history_lengths'], position_ids=inputs['position_ids'], + q_start_loc=inputs['q_start_loc'], + seq_length=inputs['seq_length'], world_size=world_size, ), q_seq_info=(inputs['q_start_loc'], inputs['seq_length']), @@ -717,6 +723,8 @@ def _model_forward(self, inputs: Dict, swap_in_map: Dict[int, int], block_offsets=inputs['block_offsets'], history_lengths=inputs['history_lengths'], position_ids=inputs['position_ids'], + q_start_loc=inputs['q_start_loc'], + seq_length=inputs['seq_length'], ), q_seq_info=(inputs['q_start_loc'], inputs['seq_length']), ) diff --git a/lmdeploy/pytorch_poc/patch/functional.py b/lmdeploy/pytorch_poc/patch/functional.py index 7b411810af..357862cc62 100644 --- a/lmdeploy/pytorch_poc/patch/functional.py +++ 
b/lmdeploy/pytorch_poc/patch/functional.py @@ -37,13 +37,11 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, sin = sin.to(q.device) cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids] # [bs, 1, seq_len, dim] - sin = sin[position_ids] # [bs, 1, seq_len, dim] seq_length = position_ids[..., -1] + 1 - cos = [s[:l] for s, l in zip(cos, seq_length)] - sin = [s[:l] for s, l in zip(sin, seq_length)] - cos = torch.cat(cos, 0).unsqueeze(1) - sin = torch.cat(sin, 0).unsqueeze(1) + position_ids_1d = [ids[:l] for ids, l in zip(position_ids, seq_length)] + position_ids_1d = torch.cat(position_ids_1d) + cos = cos[position_ids_1d].unsqueeze(1) + sin = sin[position_ids_1d].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) @@ -164,9 +162,16 @@ def attention_forward_with_paged_attention( query_states, key_states, value_states) kv_seq_length = position_ids[..., -1] + 1 - q_seq_length = kv_seq_length - kv_seq_length.new_tensor(history_lengths) - q_start_loc = q_seq_length.cumsum(0) - q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]]) + + q_seq_length = getattr(context, 'seq_length', None) + if q_seq_length is None: + q_seq_length = kv_seq_length - kv_seq_length.new_tensor( + history_lengths) + + q_start_loc = getattr(context, 'q_start_loc', None) + if q_start_loc is None: + q_start_loc = q_seq_length.cumsum(0) + q_start_loc = torch.cat([q_start_loc.new_zeros(1), q_start_loc[:-1]]) fill_kv_cache(key_states, value_states, From 3271d317b5d6489b8d567c705034ac201a330d3d Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 8 Oct 2023 15:58:20 +0800 Subject: [PATCH 4/7] fix bias tp --- lmdeploy/pytorch_poc/dist_utils.py | 3 +++ lmdeploy/pytorch_poc/kernels/fill_kv_cache.py | 2 +- lmdeploy/pytorch_poc/patch/functional.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch_poc/dist_utils.py b/lmdeploy/pytorch_poc/dist_utils.py index 38e37bc52b..a1b381f666 100644 --- a/lmdeploy/pytorch_poc/dist_utils.py +++ b/lmdeploy/pytorch_poc/dist_utils.py @@ -58,6 +58,9 @@ def rowwise_parallelize_linear_fn(module: nn.Module, dist_tensor = distribute_tensor(param, device_mesh, dist_spec) if to_local: dist_tensor = try_to_local(dist_tensor) + if name == 'bias': + # rowwise linear would add bias more than ones. 
+ dist_tensor /= device_mesh.size() dist_param = torch.nn.Parameter(dist_tensor) module.register_parameter(name, dist_param) diff --git a/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py index d1ab5f958d..83ad90678d 100644 --- a/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py +++ b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py @@ -108,7 +108,7 @@ def fill_kv_cache(k_states: Tensor, for mid in middle_block_nums ] remain_seq_lens = [ - torch.cat([slen, last]) + (torch.cat([slen, last]) if last != 0 else slen) for slen, last in zip(remain_seq_lens, last_seq_len.unsqueeze(-1)) ] remain_block_offsets1d = [ diff --git a/lmdeploy/pytorch_poc/patch/functional.py b/lmdeploy/pytorch_poc/patch/functional.py index 357862cc62..0f697669df 100644 --- a/lmdeploy/pytorch_poc/patch/functional.py +++ b/lmdeploy/pytorch_poc/patch/functional.py @@ -226,7 +226,7 @@ def attention_forward_with_paged_attention( else: raise ValueError(f'Unknown bias type: {bias_type}') hidden_size = num_heads * head_dim - attn_output = attn_output.reshape(-1, hidden_size) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], hidden_size) if o_proj is not None: attn_output = o_proj(attn_output) From c27aa34092b28b40699332a39062b5a691fdf1b5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 8 Oct 2023 17:30:38 +0800 Subject: [PATCH 5/7] fix baichuan2 --- lmdeploy/pytorch_poc/engine/engine.py | 20 +++++++++++++------- lmdeploy/pytorch_poc/patch/functional.py | 4 ++-- lmdeploy/pytorch_poc/patch/patch.py | 4 ++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/lmdeploy/pytorch_poc/engine/engine.py b/lmdeploy/pytorch_poc/engine/engine.py index a175355ad9..0dddd3d38a 100644 --- a/lmdeploy/pytorch_poc/engine/engine.py +++ b/lmdeploy/pytorch_poc/engine/engine.py @@ -18,7 +18,7 @@ TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper) -from transformers.utils import WEIGHTS_INDEX_NAME, cached_file +from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, cached_file from lmdeploy.pytorch.accel import LoadNoInit from lmdeploy.pytorch_poc.config import (CacheConfig, ModelConfig, @@ -261,14 +261,20 @@ def _tp_model_loop( torch_dtype=torch_dtype, trust_remote_code=True) - torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME) - with open(torch_model_json_path, mode='r') as f: - torch_model_json = json.load(f) + try: + torch_model_json_path = cached_file(model_path, WEIGHTS_INDEX_NAME) + with open(torch_model_json_path, mode='r') as f: + torch_model_json = json.load(f) - weight_map = torch_model_json['weight_map'] + weight_map = torch_model_json['weight_map'] - checkpoints = list(set(weight_map.values())) - checkpoints = [cached_file(model_path, ckpt) for ckpt in checkpoints] + checkpoints = list(set(weight_map.values())) + checkpoints = [ + cached_file(model_path, ckpt) for ckpt in checkpoints + ] + except Exception: + logger.warning(f'load failed, try load from {WEIGHTS_NAME}.') + checkpoints = [cached_file(model_path, WEIGHTS_NAME)] patched_model = patch( model, extra_args=extra_args, diff --git a/lmdeploy/pytorch_poc/patch/functional.py b/lmdeploy/pytorch_poc/patch/functional.py index 0f697669df..6ef955c30a 100644 --- a/lmdeploy/pytorch_poc/patch/functional.py +++ b/lmdeploy/pytorch_poc/patch/functional.py @@ -33,8 +33,8 @@ def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, """ # The first two dimensions of cos and sin are always 1, # so we can `squeeze` them. 
- cos = cos.to(q.device) - sin = sin.to(q.device) + cos = cos.to(device=q.device, dtype=q.dtype) + sin = sin.to(device=q.device, dtype=q.dtype) cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] seq_length = position_ids[..., -1] + 1 diff --git a/lmdeploy/pytorch_poc/patch/patch.py b/lmdeploy/pytorch_poc/patch/patch.py index a7862136b2..76c31af34d 100644 --- a/lmdeploy/pytorch_poc/patch/patch.py +++ b/lmdeploy/pytorch_poc/patch/patch.py @@ -9,6 +9,7 @@ import torch.distributed as dist from addict import Addict from torch.distributed._tensor import DeviceMesh +from transformers.utils import TRANSFORMERS_DYNAMIC_MODULE_NAME from lmdeploy.pytorch_poc.dist_utils import partition_module, replicate_module from lmdeploy.utils import get_logger @@ -27,6 +28,9 @@ MODULE_MAP.update({ 'modeling_baichuan.Model': 'lmdeploy.pytorch_poc.patch.llama.LlamaModel', # noqa + (f'{TRANSFORMERS_DYNAMIC_MODULE_NAME}.Baichuan2-7B-Chat' + '.modeling_baichuan.BaichuanModel'): + 'lmdeploy.pytorch_poc.patch.llama.LlamaModel', # noqa 'modeling_baichuan.BaichuanModel': 'lmdeploy.pytorch_poc.patch.baichuan.BaichuanModel', # noqa 'modeling_baichuan.Attention': From 46e71d4b50e26f69b4ea9c2df0fdcc3f9d57da66 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 8 Oct 2023 18:00:00 +0800 Subject: [PATCH 6/7] fix fill kv cache --- lmdeploy/pytorch_poc/kernels/fill_kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py index 83ad90678d..660a7403d0 100644 --- a/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py +++ b/lmdeploy/pytorch_poc/kernels/fill_kv_cache.py @@ -134,7 +134,7 @@ def fill_kv_cache(k_states: Tensor, fill_cache_info['state_len'] = state_len fill_cache_info['cache_start'] = cache_start fill_cache_info['block_offsets1d'] = block_offsets1d - context.fill_cache_info = fill_cache_info + context.fill_cache_info = fill_cache_info else: state_start = fill_cache_info['state_start'] state_len = fill_cache_info['state_len'] From dcc47196ed435882b5da55d17d9f40977c6684e0 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 9 Oct 2023 11:34:10 +0800 Subject: [PATCH 7/7] fix lint --- lmdeploy/pytorch_poc/kernels/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch_poc/kernels/__init__.py b/lmdeploy/pytorch_poc/kernels/__init__.py index 7824eadb32..7c780471e7 100644 --- a/lmdeploy/pytorch_poc/kernels/__init__.py +++ b/lmdeploy/pytorch_poc/kernels/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alibi_pagedattention import alibi_paged_attention_fwd from .biased_pagedattention import biased_paged_attention_fwd +from .fill_kv_cache import fill_kv_cache from .flashattention_nopad import context_attention_fwd from .pagedattention import paged_attention_fwd -from .fill_kv_cache import fill_kv_cache __all__ = [ 'context_attention_fwd', 'paged_attention_fwd',
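
The series replaces the per-sequence Python loop that previously filled the paged key/value caches in `patch/functional.py` with a single Triton launch: the host side of `fill_kv_cache` flattens each sequence's writes into per-block segments (`state_start`, `state_len`, `cache_start`, `block_offsets1d`) and each Triton program copies one segment. Below is a minimal sketch, not part of the patch, of driving the new kernel directly; the toy sizes, the block table, and the assumed layouts ((total_tokens, heads, head_dim) for the packed states, (num_blocks, block_size, heads, head_dim) for the caches) are illustrative assumptions, and a CUDA device with Triton installed is required.

# Minimal sketch (illustration only): write packed key/value states for two
# sequences into block-paged caches with the new Triton fill_kv_cache kernel.
import torch

from lmdeploy.pytorch_poc.kernels import fill_kv_cache

num_heads, head_dim, block_size, num_blocks = 32, 128, 64, 16

seq_length = torch.tensor([5, 3], device='cuda')   # new tokens per sequence
history_lengths = [60, 0]                          # tokens already in the cache
start_loc = torch.tensor([0, 5], device='cuda')    # offsets into the packed states

total_tokens = int(seq_length.sum())
k_states = torch.rand(total_tokens, num_heads, head_dim,
                      dtype=torch.float16, device='cuda')
v_states = torch.rand_like(k_states)
k_caches = torch.zeros(num_blocks, block_size, num_heads, head_dim,
                       dtype=torch.float16, device='cuda')
v_caches = torch.zeros_like(k_caches)

# block table: sequence 0 has already filled 60/64 slots of block 0, so the
# first 4 of its 5 new tokens finish block 0 and the last one spills into
# block 1; sequence 1 starts fresh in block 2.
block_offsets = torch.tensor([[0, 1], [2, 3]], device='cuda')

fill_kv_cache(k_states, v_states, k_caches, v_caches,
              start_loc, seq_length, block_offsets, history_lengths,
              context=None)

With `context=None` the segment metadata is recomputed on every call; when the engine passes a context, the first call stores the metadata on `context.fill_cache_info` and the remaining layers reuse it, which is the point of threading `context` through the patched attention modules.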
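
Patch 4 divides the bias of row-wise parallel linear layers by the device-mesh size. The following standalone sketch (not the patched `dist_utils` code) illustrates the arithmetic behind that fix under the usual row-parallel scheme, where input features and weight rows are split across ranks and the partial outputs are summed by an all-reduce, emulated here with a plain Python `sum`: an unscaled bias would be added once per rank.

# Standalone sketch: why a row-parallel linear scales its bias by 1/world_size
# when every rank adds the bias and the partial outputs are all-reduced.
import torch

world_size = 4
in_features, out_features = 8, 6
x = torch.rand(2, in_features)             # activations, feature dim split across ranks
w = torch.rand(in_features, out_features)  # weight, rows split across ranks
b = torch.rand(out_features)               # bias replicated on every rank

reference = x @ w + b                      # single-device result

chunk = in_features // world_size
partials = [
    x[:, r * chunk:(r + 1) * chunk] @ w[r * chunk:(r + 1) * chunk]
    + b / world_size                       # without this scaling, bias is added world_size times
    for r in range(world_size)
]
assert torch.allclose(sum(partials), reference, atol=1e-5)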