From 7785142d7c13a21bc01c2e7c0bc10b82964371f1 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 21 Aug 2023 21:20:39 +0800 Subject: [PATCH] Pass chat template args including meta_prompt to model (#225) * pass args like meta_prompt to model * update chatbot * update * rollback * update llama2 and qwen * refine --- lmdeploy/model.py | 163 +++++++++++++++++++--------- lmdeploy/serve/turbomind/chatbot.py | 5 +- 2 files changed, 116 insertions(+), 52 deletions(-) diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 8dc52b7b7..8039ca12f 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -8,12 +8,18 @@ class BaseModel: """Base model.""" - def __init__(self): - self.session_len = 2048 - self.top_p = 0.8 - self.top_k = None - self.temperature = 0.8 - self.repetition_penalty = 1.0 + def __init__(self, + session_len=2048, + top_p=0.8, + top_k=None, + temperature=0.8, + repetition_penalty=1.0, + **kwargs): + self.session_len = session_len + self.top_p = top_p + self.top_k = top_k + self.temperature = temperature + self.repetition_penalty = repetition_penalty @staticmethod def get_prompt(prompt, sequence_start=True): @@ -39,11 +45,16 @@ def stop_words(self): class Vicuna(BaseModel): """Chat template of vicuna model.""" - def __init__(self): - super().__init__() - self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ # noqa: E501 - self.user = 'USER' - self.assistant = 'ASSISTANT' + def __init__( + self, + system="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """, # noqa: E501 + user='USER', + assistant='ASSISTANT', + **kwargs): + super().__init__(**kwargs) + self.system = system + self.user = user + self.assistant = assistant def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the @@ -65,21 +76,27 @@ def get_prompt(self, prompt, sequence_start=True): @MODELS.register_module(name='internlm') class InternLM(BaseModel): - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) @MODELS.register_module(name='internlm-chat-7b') class InternLMChat7B(BaseModel): """Chat template of InternLM model.""" - def __init__(self): - super().__init__() - self.system = '' - self.user = '<|User|>' - self.eoh = '' - self.eoa = '' - self.assistant = '<|Bot|>' + def __init__(self, + system='', + user='<|User|>', + eoh='', + eoa='', + assistant='<|Bot|>', + **kwargs): + super().__init__(**kwargs) + self.system = system + self.user = user + self.eoh = eoh + self.eoa = eoa + self.assistant = assistant def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the @@ -108,39 +125,77 @@ def stop_words(self): @MODELS.register_module(name='internlm-chat-7b-8k') class InternLMChat7B8K(InternLMChat7B): - def __init__(self): - super(InternLMChat7B8K, self).__init__() - self.session_len = 8192 + def __init__(self, session_len=8192, **kwargs): + super(InternLMChat7B8K, self).__init__(**kwargs) + self.session_len = session_len @MODELS.register_module(name='baichuan-7b') class Baichuan7B(BaseModel): - def __init__(self): - super().__init__() - self.repetition_penalty = 1.1 + def __init__(self, repetition_penalty=1.1, **kwargs): + super().__init__(**kwargs) + self.repetition_penalty = repetition_penalty + + +@MODELS.register_module(name='puyu') +class Puyu(BaseModel): + """Chat template of puyu model.This is only for internal usage in Shanghai + AI Laboratory.""" + + def __init__(self, + meta_instruction='', + user='<|Human|>: ', + eoh='', + eosys='', + assistant='<|Assistant|>: ', + system='<|System|>: ', + **kwargs): + super().__init__(**kwargs) + self.meta_instruction = meta_instruction + self.user = user + self.eoh = eoh + self.eosys = eosys + self.assistant = assistant + self.system = system + + def get_prompt(self, prompt, sequence_start=True): + if sequence_start: + return f'{self.system}{self.meta_instruction}{self.eosys}\n' \ + f'{self.user}{prompt}{self.eoh}\n' \ + f'{self.assistant}' + else: + return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}' + + @property + def stop_words(self): + """Return the stop-words' token ids.""" + return [45623] @MODELS.register_module(name='llama2') class Llama2(BaseModel): """Chat template of LLaMA2 model.""" - def __init__(self): - super().__init__() - B_INST, E_INST = '[INST]', '[/INST]' - B_SYS, E_SYS = '<>\n', '\n<>\n\n' - - DEFAULT_SYSTEM_PROMPT = """\ + def __init__( + self, + b_inst='[INST]', + e_inst='[/INST]', + b_sys='<>\n', + e_sys='\n<>\n\n', + default_sys_prompt="""\ You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" # noqa: E501 - - self.b_inst = B_INST - self.e_inst = E_INST - self.b_sys = B_SYS - self.e_sys = E_SYS - self.default_sys_prompt = DEFAULT_SYSTEM_PROMPT - self.session_len = 4096 +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", # noqa: E501 + session_len=4096, + **kwargs): + super().__init__(**kwargs) + self.b_inst = b_inst + self.e_inst = e_inst + self.b_sys = b_sys + self.e_sys = e_sys + self.default_sys_prompt = default_sys_prompt + self.session_len = session_len def get_prompt(self, prompt, sequence_start=True): """Return the prompt that is concatenated with other elements in the @@ -165,16 +220,24 @@ def get_prompt(self, prompt, sequence_start=True): class Qwen7BChat(BaseModel): """Chat template for Qwen-7B-Chat.""" - def __init__(self): - super().__init__() - self.session_len = 8192 - self.top_p = 0.5 - self.top_k = 40 - self.temperature = 1.0 - - self.im_start = '<|im_start|>' - self.im_end = '<|im_end|>' - self.system = 'You are a helpful assistant.' + def __init__(self, + session_len=8192, + top_p=0.5, + top_k=40, + temperature=1.0, + im_start='<|im_start|>', + im_end='<|im_end|>', + system='You are a helpful assistant.', + **kwargs): + super().__init__(**kwargs) + self.session_len = session_len + self.top_p = top_p + self.top_k = top_k + self.temperature = temperature + + self.im_start = im_start + self.im_end = im_end + self.system = system def get_prompt(self, prompt, sequence_start=True): if sequence_start: diff --git a/lmdeploy/serve/turbomind/chatbot.py b/lmdeploy/serve/turbomind/chatbot.py index c1aadac04..a8d32825a 100644 --- a/lmdeploy/serve/turbomind/chatbot.py +++ b/lmdeploy/serve/turbomind/chatbot.py @@ -76,7 +76,8 @@ def __init__(self, log_level: int = logging.INFO, display: bool = False, profile_generation: bool = False, - profile_serving: bool = False): + profile_serving: bool = False, + **model_kwargs): self.tritonserver_addr = tritonserver_addr self.model_name = model_name if self.model_name == '': @@ -84,7 +85,7 @@ def __init__(self, assert self.model_name in MODELS.module_dict.keys(), \ f"'{self.model_name}' is not supported. " \ f'The supported models are: {MODELS.module_dict.keys()}' - self.model = MODELS.get(self.model_name)() + self.model = MODELS.get(self.model_name)(**model_kwargs) self._session = None self.preprocess = Preprocessor(tritonserver_addr) self.postprocess = Postprocessor(tritonserver_addr)