From 7749128467a5b96fa41c846d174ae2e939a307e5 Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Thu, 9 Nov 2023 15:09:47 +0800
Subject: [PATCH] Add UltraCM and WizardLM chat templates (#599)

* add ultracm eval chat template

* add WizardLM chat template

* use ultrachat template instead of ultracm usecase
---
 lmdeploy/model.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index c5cddcde7..7284f3490 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -111,6 +111,7 @@ def sampling_param(self):
                              repetition_penalty=self.repetition_penalty)
 
 
+@MODELS.register_module(name='wizardlM')
 @MODELS.register_module(name='vicuna')
 class Vicuna(BaseModel):
     """Chat template of vicuna model."""
@@ -647,6 +648,74 @@ def messages2prompt(self, messages, sequence_start=True):
         return ret
 
 
+@MODELS.register_module(name='ultracm')
+@MODELS.register_module(name='ultralm')
+class UltraChat(BaseModel):
+    """Template of UltraCM and UltraLM models.
+
+    `https://huggingface.co/openbmb/UltraCM-13b`
+    `https://huggingface.co/openbmb/UltraLM-13b`
+    """
+
+    def __init__(
+            self,
+            system="""User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.""",  # noqa: E501
+            eos='</s>',
+            user='User: ',
+            assistant='Assistant: ',
+            session_len=2048,
+            **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.eos = eos
+        self.session_len = session_len
+        self.user = user
+        self.assistant = assistant
+
+    def decorate_prompt(self, prompt, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            prompt (str): the input prompt
+            sequence_start (bool): indicator for the first round chat of a
+                session sequence
+        Returns:
+            str: the concatenated prompt
+        """
+        assert self.capability == 'chat', \
+            f'{type(self).__name__} has no capability of {self.capability}'
+        if sequence_start:
+            return f'{self.system}\n{self.user}{prompt}{self.eos}' \
+                   f'\n{self.assistant}'
+
+        return f'\n{self.user}{prompt}{self.eos}' \
+               f'\n{self.assistant}'
+
+    def messages2prompt(self, messages, sequence_start=True):
+        """Return the prompt that is concatenated with other elements in the
+        chat template. Only evaluate the last instruction completion pair.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        system, users, assistants = self._translate_messages(messages)
+        system = self.system if not system else system
+        ret = f'{system}'
+        for user, assistant in zip(users, assistants):
+            if assistant:
+                ret += f'\n{self.user}{user}{self.eos}' \
+                       f'\n{self.assistant}{assistant}{self.eos}'
+            else:
+                ret += f'\n{self.user}{user}{self.eos}' \
+                       f'\n{self.assistant}'
+        return ret
+
+
 def main(model_name: str = 'test'):
     assert model_name in MODELS.module_dict.keys(), \
         f"'{model_name}' is not supported. " \
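
For orientation, here is a minimal sketch of how the templates registered by this patch might be exercised. It assumes the mmengine-style `MODELS` registry used above exposes `get()` and that the template constructor accepts a `capability='chat'` keyword (which `decorate_prompt` checks); neither is introduced by this change, so treat it as an illustrative sketch rather than an official usage snippet.

from lmdeploy.model import MODELS

# Look up the chat template registered above by name; `MODELS.get()` and the
# `capability` keyword are assumptions based on the surrounding code base.
template = MODELS.get('ultracm')(capability='chat')

# First round of a session: system prompt, user turn, then assistant prefix.
print(template.decorate_prompt('How do chat templates work?',
                               sequence_start=True))

# Later rounds omit the system prompt and only append the new user turn.
print(template.decorate_prompt('Can you elaborate?', sequence_start=False))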