Add UltraCM and WizardLM chat templates (#599)
* add ultracm eval chat template

* add WizardLM chat template

* use ultrachat template instead of ultracm usecase
AllentDan authored Nov 9, 2023
1 parent 18170ee commit 7749128
Showing 1 changed file with 69 additions and 0 deletions.
lmdeploy/model.py: 69 additions & 0 deletions
@@ -111,6 +111,7 @@ def sampling_param(self):
            repetition_penalty=self.repetition_penalty)


@MODELS.register_module(name='wizardlM')
@MODELS.register_module(name='vicuna')
class Vicuna(BaseModel):
    """Chat template of vicuna model."""
@@ -647,6 +648,74 @@ def messages2prompt(self, messages, sequence_start=True):
        return ret


@MODELS.register_module(name='ultracm')
@MODELS.register_module(name='ultralm')
class UltraChat(BaseModel):
    """Template of UltraCM and UltraLM models.

    `https://huggingface.co/openbmb/UltraCM-13b`
    `https://huggingface.co/openbmb/UltraLM-13b`
    """

    def __init__(
            self,
            system="""User: A one-turn chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, very detailed, and polite answers to the user's questions.</s>""",  # noqa: E501
            eos='</s>',
            user='User: ',
            assistant='Assistant: ',
            session_len=2048,
            **kwargs):
        super().__init__(**kwargs)
        self.system = system
        self.eos = eos
        self.session_len = session_len
        self.user = user
        self.assistant = assistant

    def decorate_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): the input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        assert self.capability == 'chat', \
            f'{type(self).__name__} has no capability of {self.capability}'
        if sequence_start:
            return f'{self.system}\n{self.user}{prompt}{self.eos}' \
                   f'\n{self.assistant}'

        return f'\n{self.user}{prompt}{self.eos}' \
               f'\n{self.assistant}'

    def messages2prompt(self, messages, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template. Only evaluate the last instruction completion pair.

        Args:
            messages (str | List): user's input prompt
        Returns:
            str: the concatenated prompt
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        system, users, assistants = self._translate_messages(messages)
        system = self.system if not system else system
        ret = f'{system}'
        for user, assistant in zip(users, assistants):
            if assistant:
                ret += f'\n{self.user}{user}{self.eos}' \
                       f'\n{self.assistant}{assistant}{self.eos}'
            else:
                ret += f'\n{self.user}{user}{self.eos}' \
                       f'\n{self.assistant}'
        return ret


def main(model_name: str = 'test'):
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
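For reference, a minimal usage sketch of the template added above (not part of the commit). It assumes lmdeploy.model exposes the MODELS registry and the new UltraChat class as shown in this diff, and that BaseModel leaves capability at a 'chat' default; the expected output is derived from decorate_prompt and messages2prompt as written.

# Hypothetical usage sketch -- not part of the commit.
from lmdeploy.model import MODELS, UltraChat

# The aliases registered by this commit should appear in the registry.
assert 'ultracm' in MODELS.module_dict and 'ultralm' in MODELS.module_dict

template = UltraChat()  # assumes BaseModel's default capability is 'chat'

# First round of a session: system prompt, user turn, then the assistant prefix.
print(template.decorate_prompt('How are you?', sequence_start=True))
# roughly: "User: A one-turn chat ...</s>\nUser: How are you?</s>\nAssistant: "

# OpenAI-style messages are flattened the same way by messages2prompt.
messages = [{'role': 'user', 'content': 'How are you?'}]
print(template.messages2prompt(messages))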
