From 55764e0b33d8b9298f68b77484bab3832696c010 Mon Sep 17 00:00:00 2001 From: WRH <12756472+wangruohui@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:03:02 +0800 Subject: [PATCH] Support baichuan2-chat chat template (#378) * support baichuan2-chat * update args from generation config * update deploy.py * update readme * tested with tp * step-1 when last id is eos * add news --------- Co-authored-by: chenxin --- README.md | 12 ++++--- README_zh-CN.md | 12 ++++--- lmdeploy/model.py | 50 ++++++++++++++++++++++++++++++ lmdeploy/serve/turbomind/deploy.py | 5 +++ lmdeploy/turbomind/turbomind.py | 12 +++++-- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 820b26df9..3ed91d925 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## News 🎉 +- \[2023/09\] TurboMind supports Baichuan2-7B - \[2023/08\] TurboMind supports flash-attention2. - \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling - \[2023/08\] TurboMind supports Windows (tp=1) @@ -55,11 +56,12 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by > **Note**
> W4A16 inference requires Nvidia GPU with Ampere architecture or above. -| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 | -| :------: | :-------------: | :--: | :-----: | :---: | :--: | -| Llama | Yes | Yes | Yes | Yes | No | -| Llama2 | Yes | Yes | Yes | Yes | No | -| InternLM | Yes | Yes | Yes | Yes | No | +| Models | Tensor Parallel | FP16 | KV INT8 | W4A16 | W8A8 | +| :-------: | :-------------: | :--: | :-----: | :---: | :--: | +| Llama | Yes | Yes | Yes | Yes | No | +| Llama2 | Yes | Yes | Yes | Yes | No | +| InternLM | Yes | Yes | Yes | Yes | No | +| Baichuan2 | Yes | Yes | No | No | No | ### Pytorch diff --git a/README_zh-CN.md b/README_zh-CN.md index 9e9649f7d..1e3b101d3 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -20,6 +20,7 @@ ______________________________________________________________________ ## 更新 🎉 +- \[2023/09\] TurboMind 支持 Baichuan2-7B - \[2023/08\] TurboMind 支持 flash-attention2 - \[2023/08\] TurboMind 支持 Qwen-7B,动态NTK-RoPE缩放,动态logN缩放 - \[2023/08\] TurboMind 支持 Windows (tp=1) @@ -56,11 +57,12 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht > **Note**
> W4A16 推理需要 Ampere 及以上架构的 Nvidia GPU

-| 模型 | 模型并行 | FP16 | KV INT8 | W4A16 | W8A8 |
-| :------: | :------: | :--: | :-----: | :---: | :--: |
-| Llama | Yes | Yes | Yes | Yes | No |
-| Llama2 | Yes | Yes | Yes | Yes | No |
-| InternLM | Yes | Yes | Yes | Yes | No |
+| 模型 | 模型并行 | FP16 | KV INT8 | W4A16 | W8A8 |
+| :-------: | :------: | :--: | :-----: | :---: | :--: |
+| Llama | Yes | Yes | Yes | Yes | No |
+| Llama2 | Yes | Yes | Yes | Yes | No |
+| InternLM | Yes | Yes | Yes | Yes | No |
+| Baichuan2 | Yes | Yes | No | No | No |

 ### Pytorch

diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index 8fa03b7df..b3706a59d 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -227,6 +227,56 @@ def __init__(self, repetition_penalty=1.1, **kwargs):
         self.repetition_penalty = repetition_penalty


+@MODELS.register_module(name='baichuan2-7b-chat')
+class Baichuan2_7BChat(BaseModel):
+
+    def __init__(self,
+                 temperature=0.3,
+                 top_k=5,
+                 top_p=0.85,
+                 repetition_penalty=1.05,
+                 **kwargs):
+        super().__init__(temperature=temperature,
+                         top_k=top_k,
+                         top_p=top_p,
+                         repetition_penalty=repetition_penalty,
+                         **kwargs)
+        self.user_token = '<reserved_106>'  # id = 195
+        self.assistant_token = '<reserved_107>'  # id = 196
+
+    def get_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): user's input prompt
+            sequence_start (bool): indicator for the first round chat of a
+               session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        return f'{self.user_token}{prompt}{self.assistant_token}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        ret = ''
+        for user, assistant in zip(users, assistants):
+            ret += f'{self.user_token}{user}{self.assistant_token}'
+            if assistant:
+                ret += f'{assistant}'
+        return ret
+
+
 @MODELS.register_module(name='puyu')
 class Puyu(BaseModel):
     """Chat template of puyu model.This is only for internal usage in Shanghai
diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py
index 1aa88bb19..516afd793 100644
--- a/lmdeploy/serve/turbomind/deploy.py
+++ b/lmdeploy/serve/turbomind/deploy.py
@@ -525,6 +525,11 @@ def get_tensor_transposed(name: str):
     for ft, hf in other:
         model_params[ft] = get_tensor(hf)

+    if model_name == 'baichuan2-7b-chat':
+        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/modeling_baichuan.py#L507
+        model_params['output.weight'] = torch.nn.functional.normalize(
+            model_params['output.weight'])
+
     return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
                   tokenizer_path, triton_models_path, tp)

diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 807bd55c8..c39110b71 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -340,8 +340,16 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
                 output_ids, seq_start, sequence_length)
         ]
         sequence_length -= seq_start.to(sequence_length.device)
-        yield [(output, l.item())
-               for output, l in zip(output_ids, sequence_length)]
+
+            outputs = []
+            for output, len_ in zip(output_ids, sequence_length):
+                output, len_ = output, len_.item()
+                if output[-1].item() == self.eos_id:
+                    outputs.append((output[:-1], len_ - 1))
+                else:
+                    outputs.append((output, len_))
+
+            yield outputs

         if finish:
             for t in self.threads:
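
For reference, a minimal sketch (not part of the patch) of the prompt strings the new chat template builds once this change is applied. It assumes `Baichuan2_7BChat` is registered under `baichuan2-7b-chat` as above, that `MODELS.get()` returns the registered class, and that `messages2prompt` receives OpenAI-style role/content dicts; the strings in the comments follow from `get_prompt`/`messages2prompt` as written.

```python
# Illustration only: inspect the prompts produced by the Baichuan2 template.
from lmdeploy.model import MODELS

template = MODELS.get('baichuan2-7b-chat')()

# Single turn: user text wrapped by the user/assistant special tokens
# (ids 195 and 196); generation continues right after the assistant token.
print(template.get_prompt('Hello!'))
# <reserved_106>Hello!<reserved_107>

# Multi turn: each (user, assistant) pair is concatenated in order, and the
# last user message ends with the assistant token so decoding starts there.
messages = [
    {'role': 'user', 'content': 'Hi'},
    {'role': 'assistant', 'content': 'Hello! How can I help?'},
    {'role': 'user', 'content': 'Tell me a joke.'},
]
print(template.messages2prompt(messages))
# <reserved_106>Hi<reserved_107>Hello! How can I help?<reserved_106>Tell me a joke.<reserved_107>
```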
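
The `deploy.py` change bakes Baichuan2's NormHead into the exported weights: the model's lm_head L2-normalizes its weight rows on every forward pass (see the linked `modeling_baichuan.py`), while TurboMind's output layer is a plain linear projection, so the normalization is applied once at export time. A small self-contained check of that equivalence, using illustrative shapes and tensor names rather than the real model:

```python
# Illustration only: pre-normalizing the head weight once is equivalent to
# NormHead normalizing it at every inference step.
import torch
import torch.nn.functional as F

hidden_dim, vocab_size = 4, 8
weight = torch.randn(vocab_size, hidden_dim)   # lm_head weight, one row per token
hidden_states = torch.randn(2, hidden_dim)     # a couple of hidden states

# NormHead at inference time: logits = hidden_states @ normalize(W).T
normhead_logits = F.linear(hidden_states, F.normalize(weight))

# What deploy.py exports: normalize once, then use an ordinary linear head.
exported_weight = F.normalize(weight)
plain_logits = F.linear(hidden_states, exported_weight)

assert torch.allclose(normhead_logits, plain_logits)
```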