Support baichuan2-chat chat template #378

Merged
merged 7 commits on Sep 8, 2023
12 changes: 7 additions & 5 deletions README.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/09\] TurboMind supports Baichuan2-7B
- \[2023/08\] TurboMind supports flash-attention2.
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
@@ -55,11 +56,12 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
> **Note**<br />
> W4A16 inference requires an Nvidia GPU with Ampere architecture or above.

| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
| :------: | :-------------: | :--: | :-----: | :---: | :--: |
| Llama | Yes | Yes | Yes | Yes | No |
| Llama2 | Yes | Yes | Yes | Yes | No |
| InternLM | Yes | Yes | Yes | Yes | No |
| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
| :-------: | :-------------: | :--: | :-----: | :---: | :--: |
| Llama | Yes | Yes | Yes | Yes | No |
| Llama2 | Yes | Yes | Yes | Yes | No |
| InternLM | Yes | Yes | Yes | Yes | No |
| Baichuan2 | Yes | Yes | No | No | No |

### Pytorch

12 changes: 7 additions & 5 deletions README_zh-CN.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/09\] TurboMind supports Baichuan2-7B
- \[2023/08\] TurboMind supports flash-attention2
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
@@ -56,11 +57,12 @@ LMDeploy is developed by [MMDeploy](https://github.com/open-mmlab/mmdeploy) and [MMRazor](ht
> **Note**<br />
> W4A16 inference requires an Nvidia GPU with Ampere architecture or above

| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
| :------: | :------: | :--: | :-----: | :---: | :--: |
| Llama | Yes | Yes | Yes | Yes | No |
| Llama2 | Yes | Yes | Yes | Yes | No |
| InternLM | Yes | Yes | Yes | Yes | No |
| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 |
| :-------: | :------: | :--: | :-----: | :---: | :--: |
| Llama | Yes | Yes | Yes | Yes | No |
| Llama2 | Yes | Yes | Yes | Yes | No |
| InternLM | Yes | Yes | Yes | Yes | No |
| Baichuan2 | Yes | Yes | No | No | No |

### Pytorch

50 changes: 50 additions & 0 deletions lmdeploy/model.py
@@ -227,6 +227,56 @@ def __init__(self, repetition_penalty=1.1, **kwargs):
self.repetition_penalty = repetition_penalty


@MODELS.register_module(name='baichuan2-7b-chat')
class Baichuan2_7BChat(BaseModel):

def __init__(self,
temperature=0.3,
top_k=5,
top_p=0.85,
repetition_penalty=1.05,
**kwargs):
super().__init__(temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
**kwargs)
self.user_token = '<reserved_106>' # id = 195
self.assistant_token = '<reserved_107>' # id = 196

def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.

Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
return f'{self.user_token}{prompt}{self.assistant_token}'

def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.

Args:
messages (str | List): user's input prompt
Returns:
str: the concatenated prompt
"""
if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
system, users, assistants = self._translate_messages(messages)
ret = ''
for user, assistant in zip(users, assistants):
ret += f'{self.user_token}{user}{self.assistant_token}'
if assistant:
ret += f'{assistant}'
return ret


@MODELS.register_module(name='puyu')
class Puyu(BaseModel):
"""Chat template of puyu model.This is only for internal usage in Shanghai
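For quick reference, here is a minimal standalone sketch of how the chat template added to `lmdeploy/model.py` renders a conversation. `render_baichuan2`, `USER_TOKEN`, and `ASSISTANT_TOKEN` are hypothetical names that mirror `Baichuan2_7BChat.get_prompt` / `messages2prompt` above; the sketch is not part of the PR.

```python
# Standalone sketch (not part of this PR) mirroring the Baichuan2 chat template.
# The reserved tokens are the ones registered in lmdeploy/model.py above.
USER_TOKEN = '<reserved_106>'       # id = 195
ASSISTANT_TOKEN = '<reserved_107>'  # id = 196


def render_baichuan2(messages):
    """Render a plain prompt string or an OpenAI-style message list into the
    concatenated prompt that Baichuan2-7B-Chat expects."""
    if isinstance(messages, str):
        return f'{USER_TOKEN}{messages}{ASSISTANT_TOKEN}'
    ret = ''
    for msg in messages:
        if msg['role'] == 'user':
            # each user turn is wrapped by the reserved user/assistant tokens
            ret += f'{USER_TOKEN}{msg["content"]}{ASSISTANT_TOKEN}'
        elif msg['role'] == 'assistant':
            # previous model replies are appended verbatim
            ret += msg['content']
    return ret


print(render_baichuan2([
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello! How can I help?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]))
# <reserved_106>Hi<reserved_107>Hello! How can I help?<reserved_106>Tell me a joke.<reserved_107>
```

Because every turn ends with the assistant token, the model generates its reply directly after `<reserved_107>`; note that the class above does not emit the system message returned by `_translate_messages`, so the rendered prompt contains only user and assistant turns.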
5 changes: 5 additions & 0 deletions lmdeploy/serve/turbomind/deploy.py
@@ -525,6 +525,11 @@ def get_tensor_transposed(name: str):
for ft, hf in other:
model_params[ft] = get_tensor(hf)

if model_name == 'baichuan2-7b-chat':
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
model_params['output.weight'] = torch.nn.functional.normalize(
model_params['output.weight'])

return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp)

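Why the extra step above: Baichuan2's `NormHead` L2-normalizes each row of the lm_head weight at inference time (see the linked `modeling_baichuan.py`), so the converter bakes that normalization into `output.weight` once and TurboMind can keep a plain linear output layer. A small sketch of the same transform, using a stand-in tensor rather than the real `[vocab_size, hidden_size]` weight:

```python
# Sketch of the export-time transform, where `output_weight` stands in for
# model_params['output.weight'] with shape [vocab_size, hidden_size].
import torch

output_weight = torch.randn(8, 16)  # stand-in shape; the real Baichuan2-7B head is much larger

# F.normalize defaults to p=2, dim=1, so every vocabulary row ends up with unit
# L2 norm - the same thing Baichuan2's NormHead computes on the fly at inference.
normalized = torch.nn.functional.normalize(output_weight)

print(normalized.norm(dim=1))  # ~1.0 for every row
```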
12 changes: 10 additions & 2 deletions lmdeploy/turbomind/turbomind.py
@@ -340,8 +340,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
output_ids, seq_start, sequence_length)
]
sequence_length -= seq_start.to(sequence_length.device)
yield [(output, l.item())
for output, l in zip(output_ids, sequence_length)]

outputs = []
for output, len_ in zip(output_ids, sequence_length):
output, len_ = output, len_.item()
if output[-1].item() == self.eos_id:
outputs.append((output[:-1], len_ - 1))
else:
outputs.append((output, len_))

yield outputs

if finish:
for t in self.threads:
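The loop added above replaces the old one-line `yield` so that a trailing EOS token is dropped before results are returned to callers. A minimal standalone sketch of that post-processing step, assuming `output_ids` is an iterable of token-id tensors, `sequence_length` holds the matching lengths, and `eos_id` is the tokenizer's end-of-sequence id (the helper name `strip_trailing_eos` is made up for illustration):

```python
# Standalone sketch of the EOS-stripping step added to the streaming loop above.
import torch


def strip_trailing_eos(output_ids, sequence_length, eos_id):
    """Drop a trailing EOS token (if present) from each generated sequence."""
    outputs = []
    for output, len_ in zip(output_ids, sequence_length):
        len_ = int(len_)
        if output[-1].item() == eos_id:
            # remove the EOS token and shorten the reported length accordingly
            outputs.append((output[:-1], len_ - 1))
        else:
            outputs.append((output, len_))
    return outputs


print(strip_trailing_eos([torch.tensor([5, 9, 2])], torch.tensor([3]), eos_id=2))
# [(tensor([5, 9]), 2)]  -- the EOS id 2 is stripped and the length drops from 3 to 2
```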