Skip to content

Commit

Permalink
replicate kv for some models when tp is divisible by kv_head_num
Browse files Browse the repository at this point in the history
  • Loading branch information
irexyc committed Dec 3, 2024
1 parent 4ede631 commit 00c435a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
20 changes: 20 additions & 0 deletions lmdeploy/turbomind/deploy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,26 @@ def _reorder_and_merge(self, qkvo):
o = torch.zeros_like(q)
return qkv, o

def _repeat_kv(self, qkvo, kind: str):
    """Replicate the key/value tensors of an attention layer.

    Each of ``k`` and ``v`` is viewed as ``(-1, size_per_head, dim)`` and
    tiled ``self.model.repeat_kv`` times along the middle axis, so every
    ``size_per_head``-row group ends up followed by its own copies. The
    result is flattened back to ``(-1, dim)``.

    Args:
        qkvo: 4-tuple ``(q, k, v, o)``.
        kind: Parameter kind. For ``'bias'`` the trailing dimension ``dim``
            is 1; for anything else it is ``hidden_units``.

    Returns:
        Tuple ``(q, k, v, o)``. For non-bias kinds ``q`` and ``o`` are
        passed through untouched. For ``'bias'`` a missing ``o`` is
        replaced by zeros of length ``hidden_units`` (same dtype/device as
        ``q``) and all four tensors are squeezed.
    """
    q, k, v, o = qkvo
    config = self.model.model_config
    head_dim = config.size_per_head
    hidden_dim = config.hidden_units
    is_bias = kind == 'bias'
    # Bias tensors are treated as a single trailing column.
    width = 1 if is_bias else hidden_dim

    def _tile_heads(tensor):
        # Group rows per head, then tile each group back to back.
        per_head = tensor.view(-1, head_dim, width)
        tiled = per_head.repeat(1, self.model.repeat_kv, 1)
        return tiled.reshape(-1, width)

    k = _tile_heads(k)
    v = _tile_heads(v)

    if not is_bias:
        return (q, k, v, o)

    if o is None:
        # No output bias was supplied: substitute an all-zero one.
        o = torch.zeros(hidden_dim, dtype=q.dtype, device=q.device)
    # Drop the helper trailing axis (and any other size-1 axes).
    return tuple(torch.squeeze(t) for t in (q, k, v, o))

def _export(self, idx: int, qkvo, kind: str, pack_fn, **kwargs):
if all(x is None for x in qkvo):
return
Expand Down
11 changes: 11 additions & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ def __init__(self,
self.model_config.expert_inter_size = _pad_inter_size(
self.model_config.expert_inter_size,
self.model_config.group_size, self.tensor_para_size)

# head_num is divisible by tp but kv_head_num is not,
# and tp is divisible by kv_head_num
assert self.model_config.head_num % self.tensor_para_size == 0
self.repeat_kv = 0
if (self.tensor_para_size > self.model_config.kv_head_num and
self.tensor_para_size % self.model_config.kv_head_num == 0):
self.repeat_kv = (self.tensor_para_size //
self.model_config.kv_head_num)
self.model_config.kv_head_num = self.tensor_para_size

self.model_config.verify()
assert self.model_config.kv_head_num % self.tensor_para_size == 0

Expand Down

0 comments on commit 00c435a

Please sign in to comment.