Pass chat template args including meta_prompt to model (#225)
* pass args like meta_prompt to model

* update chatbot

* update

* rollback

* update llama2 and qwen

* refine
AllentDan authored Aug 21, 2023
1 parent f44ef17 commit 7785142
Showing 2 changed files with 116 additions and 52 deletions.
163 changes: 113 additions & 50 deletions lmdeploy/model.py
@@ -8,12 +8,18 @@
 class BaseModel:
     """Base model."""
 
-    def __init__(self):
-        self.session_len = 2048
-        self.top_p = 0.8
-        self.top_k = None
-        self.temperature = 0.8
-        self.repetition_penalty = 1.0
+    def __init__(self,
+                 session_len=2048,
+                 top_p=0.8,
+                 top_k=None,
+                 temperature=0.8,
+                 repetition_penalty=1.0,
+                 **kwargs):
+        self.session_len = session_len
+        self.top_p = top_p
+        self.top_k = top_k
+        self.temperature = temperature
+        self.repetition_penalty = repetition_penalty
 
     @staticmethod
     def get_prompt(prompt, sequence_start=True):
@@ -39,11 +45,16 @@ def stop_words(self):
 class Vicuna(BaseModel):
     """Chat template of vicuna model."""
 
-    def __init__(self):
-        super().__init__()
-        self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """  # noqa: E501
-        self.user = 'USER'
-        self.assistant = 'ASSISTANT'
+    def __init__(
+            self,
+            system="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """,  # noqa: E501
+            user='USER',
+            assistant='ASSISTANT',
+            **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.user = user
+        self.assistant = assistant
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -65,21 +76,27 @@ def get_prompt(self, prompt, sequence_start=True):
 @MODELS.register_module(name='internlm')
 class InternLM(BaseModel):
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
 
 
 @MODELS.register_module(name='internlm-chat-7b')
 class InternLMChat7B(BaseModel):
     """Chat template of InternLM model."""
 
-    def __init__(self):
-        super().__init__()
-        self.system = ''
-        self.user = '<|User|>'
-        self.eoh = '<eoh>'
-        self.eoa = '<eoa>'
-        self.assistant = '<|Bot|>'
+    def __init__(self,
+                 system='',
+                 user='<|User|>',
+                 eoh='<eoh>',
+                 eoa='<eoa>',
+                 assistant='<|Bot|>',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.system = system
+        self.user = user
+        self.eoh = eoh
+        self.eoa = eoa
+        self.assistant = assistant
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -108,39 +125,77 @@ def stop_words(self):
 @MODELS.register_module(name='internlm-chat-7b-8k')
 class InternLMChat7B8K(InternLMChat7B):
 
-    def __init__(self):
-        super(InternLMChat7B8K, self).__init__()
-        self.session_len = 8192
+    def __init__(self, session_len=8192, **kwargs):
+        super(InternLMChat7B8K, self).__init__(**kwargs)
+        self.session_len = session_len
 
 
 @MODELS.register_module(name='baichuan-7b')
 class Baichuan7B(BaseModel):
 
-    def __init__(self):
-        super().__init__()
-        self.repetition_penalty = 1.1
+    def __init__(self, repetition_penalty=1.1, **kwargs):
+        super().__init__(**kwargs)
+        self.repetition_penalty = repetition_penalty
 
 
+@MODELS.register_module(name='puyu')
+class Puyu(BaseModel):
+    """Chat template of puyu model.This is only for internal usage in Shanghai
+    AI Laboratory."""
+
+    def __init__(self,
+                 meta_instruction='',
+                 user='<|Human|>: ',
+                 eoh='',
+                 eosys='',
+                 assistant='<|Assistant|>: ',
+                 system='<|System|>: ',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.meta_instruction = meta_instruction
+        self.user = user
+        self.eoh = eoh
+        self.eosys = eosys
+        self.assistant = assistant
+        self.system = system
+
+    def get_prompt(self, prompt, sequence_start=True):
+        if sequence_start:
+            return f'<BOS>{self.system}{self.meta_instruction}{self.eosys}\n' \
+                   f'{self.user}{prompt}{self.eoh}\n' \
+                   f'{self.assistant}'
+        else:
+            return f'\n{self.user}{prompt}{self.eoh}\n{self.assistant}'
+
+    @property
+    def stop_words(self):
+        """Return the stop-words' token ids."""
+        return [45623]
+
+
 @MODELS.register_module(name='llama2')
 class Llama2(BaseModel):
     """Chat template of LLaMA2 model."""
 
-    def __init__(self):
-        super().__init__()
-        B_INST, E_INST = '[INST]', '[/INST]'
-        B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
-
-        DEFAULT_SYSTEM_PROMPT = """\
+    def __init__(
+            self,
+            b_inst='[INST]',
+            e_inst='[/INST]',
+            b_sys='<<SYS>>\n',
+            e_sys='\n<</SYS>>\n\n',
+            default_sys_prompt="""\
 You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501
-
-        self.b_inst = B_INST
-        self.e_inst = E_INST
-        self.b_sys = B_SYS
-        self.e_sys = E_SYS
-        self.default_sys_prompt = DEFAULT_SYSTEM_PROMPT
-        self.session_len = 4096
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",  # noqa: E501
+            session_len=4096,
+            **kwargs):
+        super().__init__(**kwargs)
+        self.b_inst = b_inst
+        self.e_inst = e_inst
+        self.b_sys = b_sys
+        self.e_sys = e_sys
+        self.default_sys_prompt = default_sys_prompt
+        self.session_len = session_len
 
     def get_prompt(self, prompt, sequence_start=True):
         """Return the prompt that is concatenated with other elements in the
@@ -165,16 +220,24 @@ def get_prompt(self, prompt, sequence_start=True):
 class Qwen7BChat(BaseModel):
     """Chat template for Qwen-7B-Chat."""
 
-    def __init__(self):
-        super().__init__()
-        self.session_len = 8192
-        self.top_p = 0.5
-        self.top_k = 40
-        self.temperature = 1.0
-
-        self.im_start = '<|im_start|>'
-        self.im_end = '<|im_end|>'
-        self.system = 'You are a helpful assistant.'
+    def __init__(self,
+                 session_len=8192,
+                 top_p=0.5,
+                 top_k=40,
+                 temperature=1.0,
+                 im_start='<|im_start|>',
+                 im_end='<|im_end|>',
+                 system='You are a helpful assistant.',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.session_len = session_len
+        self.top_p = top_p
+        self.top_k = top_k
+        self.temperature = temperature
+
+        self.im_start = im_start
+        self.im_end = im_end
+        self.system = system
 
     def get_prompt(self, prompt, sequence_start=True):
         if sequence_start:
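Taken together, the model.py changes turn every hard-coded template constant into a constructor argument that falls back to the old default, with **kwargs letting each subclass forward anything it does not consume to BaseModel. A minimal usage sketch, assuming MODELS is importable from lmdeploy.model as these files do; the meta_instruction, session_len and temperature values below are illustrative, not part of the commit:

from lmdeploy.model import MODELS

# Defaults are unchanged when no arguments are given.
default_puyu = MODELS.get('puyu')()

# The same template with a custom meta instruction and sampling setup,
# configured at construction time instead of by editing the class.
custom_puyu = MODELS.get('puyu')(meta_instruction='You are a concise assistant.',
                                 session_len=4096,
                                 temperature=0.7)
print(custom_puyu.get_prompt('Hello!', sequence_start=True))
# -> <BOS><|System|>: You are a concise assistant.
#    <|Human|>: Hello!
#    <|Assistant|>: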
5 changes: 3 additions & 2 deletions lmdeploy/serve/turbomind/chatbot.py
@@ -76,15 +76,16 @@ def __init__(self,
                  log_level: int = logging.INFO,
                  display: bool = False,
                  profile_generation: bool = False,
-                 profile_serving: bool = False):
+                 profile_serving: bool = False,
+                 **model_kwargs):
         self.tritonserver_addr = tritonserver_addr
         self.model_name = model_name
         if self.model_name == '':
             self.model_name = self._get_model_name()
         assert self.model_name in MODELS.module_dict.keys(), \
             f"'{self.model_name}' is not supported. " \
             f'The supported models are: {MODELS.module_dict.keys()}'
-        self.model = MODELS.get(self.model_name)()
+        self.model = MODELS.get(self.model_name)(**model_kwargs)
         self._session = None
         self.preprocess = Preprocessor(tritonserver_addr)
         self.postprocess = Postprocessor(tritonserver_addr)
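On the serving side, any keyword argument that Chatbot does not recognize is now collected into **model_kwargs and handed to the chat template's constructor, which is how options such as a meta prompt reach the model without code changes. A hedged sketch of the call, assuming a reachable triton inference server; the address and override values are placeholders, not from the commit:

from lmdeploy.serve.turbomind.chatbot import Chatbot

# Extra keyword arguments pass straight through to
# MODELS.get(model_name)(**model_kwargs).
chatbot = Chatbot(tritonserver_addr='0.0.0.0:33337',  # placeholder address
                  model_name='puyu',
                  meta_instruction='You are a concise assistant.',
                  temperature=0.7)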
