From 70a5c63a843c15bc2c21dae5cb5c710913784a1d Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Thu, 19 Oct 2023 14:40:37 +0800
Subject: [PATCH] add solar chat template (#576)

---
 README.md         |  1 +
 README_zh-CN.md   |  1 +
 lmdeploy/model.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/README.md b/README.md
index a2de4d6ac..b4da5eda5 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
 | :----------: | :-------------: | :--: | :-----: | :---: | :--: |
 |    Llama     |       Yes       | Yes  |   Yes   |  Yes  |  No  |
 |    Llama2    |       Yes       | Yes  |   Yes   |  Yes  |  No  |
+|    SOLAR     |       Yes       | Yes  |   Yes   |  Yes  |  No  |
 | InternLM-7B  |       Yes       | Yes  |   Yes   |  Yes  |  No  |
 | InternLM-20B |       Yes       | Yes  |   Yes   |  Yes  |  No  |
 |   QWen-7B    |       Yes       | Yes  |   Yes   |  No   |  No  |
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 09c66c282..10c03bd1a 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -64,6 +64,7 @@ LMDeploy 由 [MMDeploy](https://github.com/open-mmlab/mmdeploy) 和 [MMRazor](ht
 | :----------: | :------: | :--: | :-----: | :---: | :--: |
 |    Llama     |   Yes    | Yes  |   Yes   |  Yes  |  No  |
 |    Llama2    |   Yes    | Yes  |   Yes   |  Yes  |  No  |
+|    SOLAR     |   Yes    | Yes  |   Yes   |  Yes  |  No  |
 | InternLM-7B  |   Yes    | Yes  |   Yes   |  Yes  |  No  |
 | InternLM-20B |   Yes    | Yes  |   Yes   |  Yes  |  No  |
 |   QWen-7B    |   Yes    | Yes  |   Yes   |  No   |  No  |
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index 3bfc59aef..a588774c7 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -575,6 +575,73 @@ def messages2prompt(self, messages, sequence_start=True):
         return super().messages2prompt(messages, sequence_start)
 
 
+@MODELS.register_module(name='solar')
+class SOLAR(BaseModel):
+    """Chat template of SOLAR model.
+
+    `https://huggingface.co/upstage/SOLAR-0-70b-16bit`
+    """
+
+    def __init__(self,
+                 b_sys='### System:\n',
+                 e_sys='\n\n',
+                 boh='### User:\n',
+                 eoh='\n\n',
+                 boa='### Assistant:\n',
+                 eoa='\n\n',
+                 system='',
+                 session_len=2048,
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.b_sys = b_sys
+        self.e_sys = e_sys
+        self.boh = boh
+        self.eoh = eoh
+        self.boa = boa
+        self.eoa = eoa
+        self.system = system
+        self.session_len = session_len
+
+    def decorate_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): user's input prompt
+            sequence_start (bool): indicator for the first round chat of a
+                session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        assert self.capability == 'chat', \
+            f'{type(self).__name__} has no capability of {self.capability}'
+        if sequence_start:
+            return f'{self.b_sys}{self.system}{self.e_sys}' \
+                   f'{self.boh}{prompt}{self.eoh}{self.boa}'
+
+        return f'{self.boh}{prompt} {self.eoh}{self.boa}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{self.b_sys}{system}{self.e_sys}'
+        for i, (user, assistant) in enumerate(zip(users, assistants)):
+            ret += f'{self.boh}{user}{self.eoh}{self.boa}'
+            if assistant:
+                ret += f'{assistant}{self.eoa}'
+        return ret
+
+
 def main(model_name: str = 'test'):
     assert model_name in MODELS.module_dict.keys(), \
         f"'{model_name}' is not supported. " \
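
A quick way to sanity-check the template added above is to render a prompt from the registered class. The snippet below is only an illustrative sketch, not part of the patch: it assumes lmdeploy with this change applied is importable, that BaseModel defaults to capability='chat', and the message list is made-up sample data.

# Illustrative usage sketch (assumptions noted above).
from lmdeploy.model import MODELS

solar = MODELS.module_dict['solar']()  # class registered under name='solar'

# Single-turn prompt; decorate_prompt asserts capability == 'chat'.
print(solar.decorate_prompt('What is the capital of France?'))

# Multi-turn prompt built from an OpenAI-style message list (sample data).
messages = [
    {'role': 'user', 'content': 'Hello!'},
    {'role': 'assistant', 'content': 'Hi, how can I help?'},
    {'role': 'user', 'content': 'Summarize SOLAR in one line.'},
]
print(solar.messages2prompt(messages))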