Support baichuan2-chat chat template (#378)
* support baichuan2-chat

* update args from generation config

* update deploy.py

* update readme

* tested with tp

* step-1 when last id is eos

* add news

---------

Co-authored-by: chenxin <[email protected]>
wangruohui and irexyc authored Sep 8, 2023
1 parent ce21a31 commit 55764e0
Showing 5 changed files with 79 additions and 12 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/09\] TurboMind supports Baichuan2-7B
- \[2023/08\] TurboMind supports flash-attention2.
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
@@ -55,11 +56,12 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
> **Note**<br />
> W4A16 inference requires an Nvidia GPU with Ampere architecture or above.
-| Models   | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
-| :------: | :-------------: | :--: | :-----: | :---: | :--: |
-| Llama    | Yes             | Yes  | Yes     | Yes   | No   |
-| Llama2   | Yes             | Yes  | Yes     | Yes   | No   |
-| InternLM | Yes             | Yes  | Yes     | Yes   | No   |
+| Models    | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
+| :-------: | :-------------: | :--: | :-----: | :---: | :--: |
+| Llama     | Yes             | Yes  | Yes     | Yes   | No   |
+| Llama2    | Yes             | Yes  | Yes     | Yes   | No   |
+| InternLM  | Yes             | Yes  | Yes     | Yes   | No   |
+| Baichuan2 | Yes             | Yes  | No      | No    | No   |

### Pytorch

12 changes: 7 additions & 5 deletions README_zh-CN.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/09\] TurboMind supports Baichuan2-7B
- \[2023/08\] TurboMind supports flash-attention2
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling, and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
@@ -56,11 +57,12 @@ LMDeploy is developed by [MMDeploy](https://github.com/open-mmlab/mmdeploy) and [MMRazor](ht
> **Note**<br />
> W4A16 inference requires an Nvidia GPU with Ampere architecture or above
-| Models   | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
-| :------: | :-------------: | :--: | :-----: | :---: | :--: |
-| Llama    | Yes             | Yes  | Yes     | Yes   | No   |
-| Llama2   | Yes             | Yes  | Yes     | Yes   | No   |
-| InternLM | Yes             | Yes  | Yes     | Yes   | No   |
+| Models    | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
+| :-------: | :-------------: | :--: | :-----: | :---: | :--: |
+| Llama     | Yes             | Yes  | Yes     | Yes   | No   |
+| Llama2    | Yes             | Yes  | Yes     | Yes   | No   |
+| InternLM  | Yes             | Yes  | Yes     | Yes   | No   |
+| Baichuan2 | Yes             | Yes  | No      | No    | No   |

### Pytorch

50 changes: 50 additions & 0 deletions lmdeploy/model.py
@@ -227,6 +227,56 @@ def __init__(self, repetition_penalty=1.1, **kwargs):
        self.repetition_penalty = repetition_penalty


@MODELS.register_module(name='baichuan2-7b-chat')
class Baichuan2_7BChat(BaseModel):

    def __init__(self,
                 temperature=0.3,
                 top_k=5,
                 top_p=0.85,
                 repetition_penalty=1.05,
                 **kwargs):
        super().__init__(temperature=temperature,
                         top_k=top_k,
                         top_p=top_p,
                         repetition_penalty=repetition_penalty,
                         **kwargs)
        self.user_token = '<reserved_106>'  # id = 195
        self.assistant_token = '<reserved_107>'  # id = 196

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        return f'{self.user_token}{prompt}{self.assistant_token}'

    def messages2prompt(self, messages, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            messages (str | List): user's input prompt
        Returns:
            str: the concatenated prompt
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        system, users, assistants = self._translate_messages(messages)
        ret = ''
        for user, assistant in zip(users, assistants):
            ret += f'{self.user_token}{user}{self.assistant_token}'
            if assistant:
                ret += f'{assistant}'
        return ret


@MODELS.register_module(name='puyu')
class Puyu(BaseModel):
    """Chat template of puyu model. This is only for internal usage in Shanghai
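For context, a minimal usage sketch (not part of the commit) of what the Baichuan2 template above produces. It assumes `_translate_messages`, defined on `BaseModel` elsewhere in `lmdeploy/model.py`, splits an OpenAI-style message list into `(system, users, assistants)` with `None` in the pending assistant slot:

```python
# Hypothetical usage of the Baichuan2 chat template registered above.
model = Baichuan2_7BChat()

# Single round: the user text is wrapped in the reserved special tokens
# (ids 195 and 196 in the Baichuan2 tokenizer).
print(model.get_prompt('Hello'))
# '<reserved_106>Hello<reserved_107>'

# Multiple rounds: each (user, assistant) pair is concatenated; the final
# assistant slot stays empty so generation continues from there.
messages = [
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello! How can I help?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]
print(model.messages2prompt(messages))
# '<reserved_106>Hi<reserved_107>Hello! How can I help?<reserved_106>Tell me a joke.<reserved_107>'
```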
5 changes: 5 additions & 0 deletions lmdeploy/serve/turbomind/deploy.py
@@ -525,6 +525,11 @@ def get_tensor_transposed(name: str):
    for ft, hf in other:
        model_params[ft] = get_tensor(hf)

    if model_name == 'baichuan2-7b-chat':
        # Baichuan2's NormHead normalizes the lm_head weight at inference
        # time, so bake the normalization into the exported weight once:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
        model_params['output.weight'] = torch.nn.functional.normalize(
            model_params['output.weight'])

    return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
                  tokenizer_path, triton_models_path, tp)

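The export-time normalization above mirrors Baichuan2's `NormHead`, which L2-normalizes each row of the output projection before computing logits (see the linked `modeling_baichuan.py`); baking it into the exported weight once is equivalent and avoids the per-forward cost. A toy check of that equivalence, with illustrative shapes:

```python
# Toy equivalence check (shapes are made up; not part of the commit).
import torch
import torch.nn.functional as F

vocab_size, hidden_dim = 8, 4
w = torch.randn(vocab_size, hidden_dim)  # stand-in for output.weight
h = torch.randn(hidden_dim)              # one hidden state

# F.normalize defaults to p=2, dim=1: every vocab row gets unit L2 norm.
w_baked = F.normalize(w)
assert torch.allclose(w_baked.norm(dim=1), torch.ones(vocab_size))

# Logits from the baked weight match normalizing on the fly at inference.
assert torch.allclose(w_baked @ h, F.normalize(w, p=2, dim=1) @ h)
```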
12 changes: 10 additions & 2 deletions lmdeploy/turbomind/turbomind.py
@@ -340,8 +340,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
                     output_ids, seq_start, sequence_length)
             ]
             sequence_length -= seq_start.to(sequence_length.device)
-            yield [(output, l.item())
-                   for output, l in zip(output_ids, sequence_length)]
+            outputs = []
+            for output, len_ in zip(output_ids, sequence_length):
+                output, len_ = output, len_.item()
+                if output[-1].item() == self.eos_id:
+                    # drop the trailing EOS token and shorten length by one
+                    outputs.append((output[:-1], len_ - 1))
+                else:
+                    outputs.append((output, len_))
+
+            yield outputs

             if finish:
                 for t in self.threads:
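A small illustration (toy ids, not part of the commit) of the "step-1 when last id is eos" change above: when the last generated id is the EOS token, it is dropped from the output and the reported sequence length is decremented by one:

```python
# Toy illustration of the EOS-stripping logic; eos_id is made up.
import torch

eos_id = 2
output = torch.tensor([5, 9, 7, eos_id])  # a generation that ends with EOS
length = output.numel()

if output[-1].item() == eos_id:
    output, length = output[:-1], length - 1

print(output.tolist(), length)  # [5, 9, 7] 3
```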
