Add DeepSeek-V2 support #2763

Merged
merged 24 commits · Nov 29, 2024
Changes from 19 commits
2 changes: 2 additions & 0 deletions autotest/config.yaml
@@ -62,6 +62,7 @@ turbomind_chat_model:
- liuhaotian/llava-v1.6-vicuna-7b
- deepseek-ai/deepseek-vl-1.3b-chat
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
- codellama/CodeLlama-7b-Instruct-hf
- THUDM/glm-4-9b-chat
- openbmb/MiniCPM-Llama3-V-2_5
@@ -167,6 +168,7 @@ turbomind_quatization:
- Qwen/Qwen2-VL-7B-Instruct
- mistralai/Mistral-7B-Instruct-v0.3
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
- codellama/CodeLlama-7b-Instruct-hf
gptq:
- internlm/internlm2_5-7b-chat
4 changes: 2 additions & 2 deletions examples/cpp/llama/llama_triton_example.cc
@@ -114,14 +114,14 @@ broadCastRequest(const std::vector<int>& v_start_ids,
}
else {
// conditional case.
ft::deviceMalloc(&d_input_ids, size_1, false);
ft::deviceMalloc(&d_input_ids, size_1, nullptr, false);
// ft::deviceMalloc(&d_input_lengths, size_2, false);
ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1);
// ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2);
}

if (!v_input_bad_words.empty()) {
ft::deviceMalloc(&d_input_bad_words, size_bad_words, false);
ft::deviceMalloc(&d_input_bad_words, size_bad_words, nullptr, false);
ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words);
}
else {
20 changes: 16 additions & 4 deletions lmdeploy/turbomind/deploy/config.py
@@ -2,6 +2,7 @@
import inspect
import json
from dataclasses import asdict, fields
from typing import List

# use pydantic.dataclasses.dataclass to check data type
from pydantic.dataclasses import dataclass
@@ -43,22 +44,33 @@ class ModelConfig:
# of token_embedding
embedding_size: int = 0
num_layer: int = None
inter_size: int = None
inter_size: List[int] = None
norm_eps: float = None
attn_bias: int = 0
start_id: int = None
end_id: int = None
size_per_head: int = 128
group_size: int = 0
group_size: int = 64
weight_type: str = None
session_len: int = None
tp: int = 1
model_format: str = 'hf'
expert_num: int = 0
expert_num: List[int] = ()
expert_inter_size: int = 0
experts_per_token: int = 0
moe_shared_gate: int = False
moe_norm_topk: int = False
norm_topk_prob: int = False
routed_scale: float = 1.0
topk_group: int = 1
topk_method: str = 'greedy'
moe_group_num: int = 1
# MLA
q_lora_rank: int = 0
kv_lora_rank: int = 0
qk_rope_dim: int = 0
v_head_dim: int = 0
# tuning
tune_layer_num: int = 1

def verify(self):
invalid = {}
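For orientation, here is a minimal sketch (not part of the PR) of how the extended ModelConfig might be filled in for a DeepSeek-V2-Lite-style model: inter_size and expert_num are now per-layer lists, so the first (dense) layer can differ from the MoE layers, and the new MLA fields describe the low-rank attention geometry. The numbers roughly follow DeepSeek-V2-Lite but are illustrative only, and the remaining fields are assumed to keep their defaults.

from lmdeploy.turbomind.deploy.config import ModelConfig

num_layer = 27                       # illustrative DeepSeek-V2-Lite-like depth
cfg = ModelConfig(
    head_num=16,
    hidden_units=2048,
    num_layer=num_layer,
    # layer 0 keeps a dense FFN; the rest use shared-expert FFNs of width
    # n_shared_experts * moe_intermediate_size
    inter_size=[10944] + [2 * 1408] * (num_layer - 1),
    expert_num=[0] + [64] * (num_layer - 1),
    expert_inter_size=1408,
    experts_per_token=6,
    # MLA geometry: size_per_head = qk_nope_dim + qk_rope_dim
    size_per_head=192,
    q_lora_rank=0,                   # V2-Lite compresses only the KV path
    kv_lora_rank=512,
    qk_rope_dim=64,
    v_head_dim=128,
    tune_layer_num=2,
)
print(cfg.inter_size[0], cfg.expert_num[:2])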
7 changes: 3 additions & 4 deletions lmdeploy/turbomind/deploy/converter.py
@@ -241,11 +241,10 @@ def get_tm_model(model_path,
engine_config.model_format = quant_method
group_size = _group_size

# Compatible to awq models that are quantized by lmdeploy (<=v0.3.0)
if not group_size:
group_size = 128

if engine_config.model_format in ['awq', 'gptq']:
# Compatible to awq models that are quantized by lmdeploy (<=v0.3.0)
if not group_size:
group_size = 128
assert group_size == 128, \
f'model format is "{engine_config.model_format}" ' \
f'but group_size is {group_size}. Currently, only 128 ' \
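The net effect of moving the fallback: the implicit group_size of 128 (kept for awq models quantized by lmdeploy <= v0.3.0) now applies only when the model format is awq or gptq, so other formats no longer have their group size overridden. A small standalone sketch of the resulting resolution logic; resolve_group_size is a hypothetical helper, not a function in the PR.

def resolve_group_size(model_format: str, group_size: int) -> int:
    # hypothetical helper mirroring the updated branch in get_tm_model
    if model_format in ('awq', 'gptq'):
        # older lmdeploy exports (<= v0.3.0) did not record a group size
        if not group_size:
            group_size = 128
        assert group_size == 128, (
            f'model format is "{model_format}" but group_size is '
            f'{group_size}; only 128 is currently supported')
    return group_size

assert resolve_group_size('awq', 0) == 128
assert resolve_group_size('hf', 0) == 0   # untouched for non-quantized formats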
21 changes: 21 additions & 0 deletions lmdeploy/turbomind/deploy/loader.py
@@ -88,6 +88,27 @@ def items(self):
yield (-1, {k: f.get_tensor(k) for k in misc})
assert not params

# def items(self):
# params = defaultdict(dict)
# for shard in self.shards:
# # with safe_open(shard, 'pt') as f:
# with open(shard, 'rb') as f:
# w = safetensors.torch.load(f.read())
# misc = []
# for k in w.keys():
# match = re.findall(self.pattern, k)
# if not match:
# misc.append(k)
# else:
# idx = int(match[0])
# param = params[idx]
# param[k] = w[k]
# if len(param) == self.item_count[idx]:
# yield (idx, params.pop(idx))
# if misc:
# yield (-1, {k: w[k] for k in misc})
# assert not params


class PytorchLoader(BaseLoader):

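Both the active items() and the commented-out variant follow the same pattern: scan the safetensors shards, bucket tensors by the layer index extracted with the reader's regex, and emit one (layer_idx, tensors) pair per layer, with non-layer tensors grouped under index -1. A simplified, self-contained sketch of that grouping (pattern and shard paths are illustrative; unlike the real loader it yields only after reading all shards, because it does not track per-layer tensor counts):

import re
from collections import defaultdict

import safetensors.torch

def iter_layer_params(shard_paths, pattern=r'layers\.(\d+)\.'):
    """Yield (layer_idx, {name: tensor}); index -1 collects non-layer tensors."""
    params = defaultdict(dict)
    misc = {}
    for shard in shard_paths:
        with open(shard, 'rb') as f:
            weights = safetensors.torch.load(f.read())   # eager, whole shard
        for name, tensor in weights.items():
            match = re.search(pattern, name)
            if match:
                params[int(match.group(1))][name] = tensor
            else:
                misc[name] = tensor
    for idx in sorted(params):
        yield idx, params[idx]
    if misc:
        yield -1, misc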
90 changes: 82 additions & 8 deletions lmdeploy/turbomind/deploy/module.py
@@ -96,10 +96,13 @@ class Ffn(Module):
def __init__(self, model: BaseOutputModel):
self.model = model
self.tp = model.tensor_para_size
# inter_sizes in config are padded and may be different from what's
# in the weights
self.inter_size = model.model_config.inter_size
self.group_size = max(1, model.model_config.group_size)

def _export(self,
inter_size: int,
fmt: str,
idx: int,
w123,
@@ -110,11 +113,11 @@ def _export(self,
w1, w2, w3 = map(transpose, w123)

if not is_lora_a:
w1 = pad_out_dims(w1, self.inter_size)
w3 = pad_out_dims(w3, self.inter_size)
w1 = pad_out_dims(w1, inter_size)
w3 = pad_out_dims(w3, inter_size)
if not is_lora_b:
group_size = self.group_size if apply_gs else 1
w2 = pad_in_dims(w2, self.inter_size // group_size)
w2 = pad_in_dims(w2, inter_size // group_size)

w1, w2, w3 = map(pack_fn, (w1, w2, w3))
self.model.save_split(w1,
@@ -132,7 +135,8 @@ def apply(self, i: int, r: BaseReader):

def apply(self, i: int, r: BaseReader):
for e in get_params(r.ffn(i, None)):
e(partial(self._export, self._ffn), partial(r.ffn, i), i)
e(partial(self._export, self.inter_size[i], self._ffn),
partial(r.ffn, i), i)


class MoeFfn(Ffn):
@@ -154,11 +158,13 @@ def __init__(self, model: BaseOutputModel):
self.shared_gate = model.model_config.moe_shared_gate

def apply(self, i: int, r: BaseReader):
if self.expert_num[i] == 0:
return
for p in get_params(r.moe_ffn_expert()):
for e in range(self.expert_num):
for e in range(self.expert_num[i]):
fmt = self._moe_ffn_expert.replace('E', str(e))
p(partial(self._export, fmt), partial(r.moe_ffn_expert, e, i),
i)
p(partial(self._export, self.inter_size, fmt),
partial(r.moe_ffn_expert, e, i), i)

gate = transpose(r.moe_ffn_gate(i))
self.model.save_split(gate, self._moe_ffn_gate.format(i))
@@ -218,6 +224,70 @@ def apply(self, i: int, r: BaseReader):
e(self._export, partial(r.attn, i), i)


class MLA(Module):
"""
requires:
r.mla(i, kind)
r.mla_norm(i)
"""

_mla = 'layers.{0}.attention.{1}.{2}'

def __init__(self, model: BaseOutputModel):
self.model = model

def _export(self, idx: int, xs, kind: str, pack_fn, **kwargs):
if all(x is None for x in xs):
return
q_a, q_b, q, kv_a, kv_b, o = map(transpose, xs)

if q is not None:
q_b = q

cfg = self.model.model_config
qk_nope_dim = cfg.size_per_head - cfg.qk_rope_dim

q_b = q_b.reshape(-1, cfg.size_per_head)

# [nope_dim | rope_dim] -> [rope_dim | nope_dim]
q_nope, q_pe = torch.split(q_b, (qk_nope_dim, cfg.qk_rope_dim), dim=-1)
q_b = torch.cat((q_pe, q_nope),
dim=-1).view(-1, cfg.head_num * cfg.size_per_head)

o = o.reshape(cfg.head_num, cfg.v_head_dim, -1)
o = torch.nn.functional.pad(
o, (0, 0, 0, cfg.size_per_head - cfg.v_head_dim, 0, 0))
o = o.view(cfg.head_num * cfg.size_per_head, cfg.hidden_units)

if q_a is not None:
self.model.save_split(pack_fn(q_a),
self._mla.format(idx, 'q_a_proj', kind))
q_b_name = 'q_proj' if q_a is None else 'q_b_proj'
self.model.save_split(pack_fn(q_b),
self._mla.format(idx, q_b_name, kind),
split_dim=-1)
self.model.save_split(pack_fn(kv_a),
self._mla.format(idx, 'kv_a_proj', kind))
self.model.save_split(pack_fn(kv_b),
self._mla.format(idx, 'kv_b_proj', kind),
split_dim=-1)
self.model.save_split(pack_fn(o),
self._mla.format(idx, 'wo', kind),
split_dim=0)

_layernorm = 'layers.{0}.attention.{1}_a_layernorm'

def apply(self, i: int, r: BaseReader):

for f in get_params(r.attn(i, None), bias=False):
f(self._export, partial(r.mla, i), i)

q, k = r.mla_norm(i)
if q is not None:
self.model.save_split(q, self._layernorm.format(i, 'q'))
self.model.save_split(k, self._layernorm.format(i, 'kv'))


class Misc(Module):
"""
requires:
@@ -258,7 +328,11 @@ class Transformer:

def __init__(self, model: BaseOutputModel):
self.model = model
modules = [Attn, LayerNorm]
modules = [LayerNorm]
if model.model_config.kv_lora_rank:
modules.append(MLA)
else:
modules.append(Attn)
if model.model_config.inter_size:
modules.append(Ffn)
if model.model_config.expert_num:
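To make the two tensor transformations in MLA._export concrete: each query head is re-ordered from [nope_dim | rope_dim] to [rope_dim | nope_dim], and the output projection is zero-padded from v_head_dim rows per head up to size_per_head, so every head presents a uniform size_per_head to the downstream kernels. A self-contained sketch with toy shapes (not the real DeepSeek-V2 dimensions):

import torch
import torch.nn.functional as F

# toy MLA geometry; DeepSeek-V2 uses e.g. nope=128, rope=64, v=128
head_num, qk_nope_dim, qk_rope_dim, v_head_dim, hidden_units = 2, 4, 2, 4, 8
size_per_head = qk_nope_dim + qk_rope_dim

# query weight after transpose: rows of per-head [nope | rope] slices
q_b = torch.randn(3 * head_num, size_per_head)

# [nope_dim | rope_dim] -> [rope_dim | nope_dim], then back to full head layout
q_nope, q_pe = torch.split(q_b, (qk_nope_dim, qk_rope_dim), dim=-1)
q_b = torch.cat((q_pe, q_nope), dim=-1).view(-1, head_num * size_per_head)

# output projection: pad each head from v_head_dim up to size_per_head rows
o = torch.randn(head_num * v_head_dim, hidden_units)
o = o.reshape(head_num, v_head_dim, -1)
o = F.pad(o, (0, 0, 0, size_per_head - v_head_dim, 0, 0))   # zero rows
o = o.view(head_num * size_per_head, hidden_units)

print(q_b.shape, o.shape)   # torch.Size([3, 12]) torch.Size([12, 8])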
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/source_model/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .baichuan import Baichuan2Model, BaichuanModel # noqa: F401
from .deepseek2 import DeepSeek2Model # noqa: F401
from .deepseek_vl import DeepSeekVLModel # noqa: F401
from .glm4 import Glm4Model # noqa: F401
from .internlm2 import InternLM2Model # noqa: F401
100 changes: 100 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/deepseek2.py
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


class DeepSeek2Reader(LlamaReader):

def moe_ffn_gate(self, i):
return self.params.get(f'model.layers.{i}.mlp.gate.weight')

def moe_ffn_expert(self, e=None, i=None, kind=None):
if not kind:
return self.filter(r'experts')
result = []
for key in ['gate', 'down', 'up']:
name = f'model.layers.{i}.mlp.experts.{e}.{key}_proj.{kind}'
tensor = self.params.get(name)
tensor = self.transform(tensor, kind)
result.append(tensor)
return (*result, )

def _ffn(self, i: int, kind: str):
"""Get ffn kind for layer i."""
if not kind:
return self.filter(r'mlp' if i == 0 else r'shared_expert\.')
result = []
for key in ['gate', 'down', 'up']:
name = f'model.layers.{i}.mlp.shared_experts.{key}_proj.{kind}'
if i == 0:
name = name.replace('shared_experts.', '')
tensor = self.params.get(name)
tensor = self.transform(tensor, kind)
result.append(tensor)
return (*result, )

def mla(self, i: int, kind: str):
if not kind:
return self.filter(r'self_attn.*proj')
result = []
for key in [
'q_a_proj', 'q_b_proj', 'q_proj', 'kv_a_proj_with_mqa',
'kv_b_proj', 'o_proj'
]:
tensor = self.params.get(
f'{self.attn_layer_prefix}.{i}.self_attn.{key}.{kind}')
tensor = self.transform(tensor, kind)
result.append(tensor)
return (*result, )

def mla_norm(self, i: int):
result = []
for k in ['q', 'kv']:
name = f'{self.attn_layer_prefix}.{i}.self_attn.{k}_a_layernorm.weight' # noqa: E501
result.append(self.params.get(name))
return (*result, )


@INPUT_MODELS.register_module(name='deepseek2')
class DeepSeek2Model(LlamaModel):

Reader = DeepSeek2Reader

def tokenizer_info(self):
n_words = self.model_config['vocab_size']
bos_id = self.model_config['bos_token_id']
eos_id = self.model_config['eos_token_id']
return n_words, bos_id, eos_id

def model_info(self):
cfg = self.model_config
info = super().model_info()
qk_nope_dim = cfg['qk_nope_head_dim']
qk_rope_dim = cfg['qk_rope_head_dim']
num_layer = cfg['num_hidden_layers']
expert_num = cfg['n_routed_experts']
expert_num = [expert_num] * num_layer
expert_num[0] = 0
n_shared_experts = cfg['n_shared_experts']
expert_inter_size = cfg['moe_intermediate_size']
experts_per_token = cfg['num_experts_per_tok']
inter_size = [n_shared_experts * expert_inter_size] * num_layer
inter_size[0] = cfg['intermediate_size']
norm_topk_prob = cfg['norm_topk_prob']
info.update(kv_lora_rank=cfg['kv_lora_rank'],
q_lora_rank=cfg['q_lora_rank'] or 0,
qk_rope_dim=qk_rope_dim,
v_head_dim=cfg['v_head_dim'],
size_per_head=qk_rope_dim + qk_nope_dim,
rotary_embedding=qk_rope_dim,
expert_num=expert_num,
expert_inter_size=expert_inter_size,
experts_per_token=experts_per_token,
inter_size=inter_size,
norm_topk_prob=norm_topk_prob,
routed_scale=cfg['routed_scaling_factor'],
topk_method=cfg['topk_method'],
topk_group=cfg['topk_group'],
moe_group_num=cfg['n_group'],
tune_layer_num=2)
return info
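The per-layer lists produced by model_info encode DeepSeek-V2's layout: layer 0 keeps a dense FFN of width intermediate_size and routes to no experts, while every later layer gets n_routed_experts routed experts plus a shared-expert FFN of width n_shared_experts * moe_intermediate_size. A toy sketch of that mapping (the config numbers are made up for illustration):

# toy HF-style config values, chosen for illustration only
cfg = dict(num_hidden_layers=4, n_routed_experts=8, n_shared_experts=2,
           moe_intermediate_size=128, intermediate_size=1024)

num_layer = cfg['num_hidden_layers']

expert_num = [cfg['n_routed_experts']] * num_layer
expert_num[0] = 0            # layer 0 stays dense, no routed experts

inter_size = [cfg['n_shared_experts'] * cfg['moe_intermediate_size']] * num_layer
inter_size[0] = cfg['intermediate_size']

print(expert_num)            # [0, 8, 8, 8]
print(inter_size)            # [1024, 256, 256, 256]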
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/source_model/mixtral.py
@@ -33,6 +33,6 @@ def model_info(self):
info['expert_num'] = cfg['num_local_experts']
info['expert_inter_size'] = cfg['intermediate_size']
info['experts_per_token'] = cfg['num_experts_per_tok']
info['moe_norm_topk'] = True
info['norm_topk_prob'] = True
info['inter_size'] = 0
return info