diff --git a/lmdeploy/turbomind/deploy/module.py b/lmdeploy/turbomind/deploy/module.py
index 52497175e..1754161ff 100644
--- a/lmdeploy/turbomind/deploy/module.py
+++ b/lmdeploy/turbomind/deploy/module.py
@@ -191,7 +191,7 @@ def __init__(self, model: BaseOutputModel):
         self.attn_bias = model.model_config.attn_bias
 
     def _reorder_and_merge(self, qkvo):
-        q, k, v, o = map(transpose, qkvo)
+        q, k, v, o = qkvo
         # reorder output dim for tm's rotary embedding layout
         if self.model.permute_qk:
             q = permute_v2(q, self.head_dim)
@@ -202,6 +202,27 @@ def _reorder_and_merge(self, qkvo):
             o = torch.zeros_like(q)
         return qkv, o
 
+    def _repeat_kv(self, qkvo, kind: str):
+        """replicate kv."""
+        q, k, v, o = qkvo
+        head_dim = self.model.model_config.size_per_head
+        hidden_dim = self.model.model_config.hidden_units
+
+        def _repeat(x):
+            dim = hidden_dim if kind != 'bias' else 1
+            x = x.reshape(dim, -1, head_dim)
+            x = x.repeat(1, 1, self.model.repeat_kv)
+            x = x.reshape(dim, -1)
+            return x
+
+        k, v = map(_repeat, (k, v))
+        if kind == 'bias':
+            if o is None:
+                o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
+            q, k, v, o = map(torch.squeeze, (q, k, v, o))
+
+        return (q, k, v, o)
+
     def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
         if all(x is None for x in qkvo):
             return
@@ -209,6 +230,9 @@ def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
         if is_lora_a:
             qkv, o = map(transpose, qkvo)
         else:
+            qkvo = tuple(map(transpose, qkvo))
+            if self.model.repeat_kv:
+                qkvo = self._repeat_kv(qkvo, kind)
             qkv, o = self._reorder_and_merge(qkvo)
         self.model.save_split(pack_fn(qkv),
                               self._attn.format(idx, 'w_qkv', kind),
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index f2c981bb2..7ea1a84f3 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -78,6 +78,17 @@ def __init__(self,
             self.model_config.expert_inter_size = _pad_inter_size(
                 self.model_config.expert_inter_size,
                 self.model_config.group_size, self.tensor_para_size)
+
+        # head_num is divisible by tp but kv_head_num is not,
+        # and tp is divisible by kv_head_num
+        assert self.model_config.head_num % self.tensor_para_size == 0
+        self.repeat_kv = 0
+        if (self.tensor_para_size > self.model_config.kv_head_num and
+                self.tensor_para_size % self.model_config.kv_head_num == 0):
+            self.repeat_kv = (self.tensor_para_size //
+                              self.model_config.kv_head_num)
+            self.model_config.kv_head_num = self.tensor_para_size
+
         self.model_config.verify()
         assert self.model_config.kv_head_num % self.tensor_para_size == 0
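
Note: a minimal, standalone sketch of the KV-head replication that `_repeat` in the patch performs, for readers who want to check the tensor layout. The helper name `repeat_kv_heads` is hypothetical (not part of the patch), and the shapes assume a weight laid out as `[hidden_dim, kv_head_num * head_dim]`, as in the exporter above. Each KV head is tiled `repeat_kv` times back to back so that `kv_head_num` can be raised to equal the tensor-parallel size:

```python
# Hypothetical illustration of the patch's `_repeat`; not lmdeploy code.
import torch


def repeat_kv_heads(x: torch.Tensor, head_dim: int, repeat: int) -> torch.Tensor:
    """Tile each KV head `repeat` times, mirroring `_repeat` in the patch.

    `x` is [dim, kv_head_num * head_dim]; the result is
    [dim, kv_head_num * repeat * head_dim] with layout h0, h0, ..., h1, h1, ...
    """
    dim = x.shape[0]
    x = x.reshape(dim, -1, head_dim)  # [dim, kv_head_num, head_dim]
    x = x.repeat(1, 1, repeat)        # copies land next to each source head
    return x.reshape(dim, -1)


# Toy case: hidden_dim=2, 2 KV heads of head_dim=4, tp=4 -> repeat_kv=2
w = torch.arange(16, dtype=torch.float32).reshape(2, 8)
out = repeat_kv_heads(w, head_dim=4, repeat=2)
assert out.shape == (2, 16)
assert torch.equal(out[:, 0:4], out[:, 4:8])     # two copies of head 0
assert torch.equal(out[:, 8:12], out[:, 12:16])  # two copies of head 1
```

The adjacent layout appears deliberate: after replication `kv_head_num` equals tp, so when the merged qkv tensor is later split across ranks by `save_split`, each rank receives one KV head, and ranks sharing a query group get their own copy of the same original head.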