
Refactor chat template and support accurate name matching. #1216

Merged
merged 33 commits on Mar 12, 2024

Changes from 20 commits

Commits (33)
b1767ab
remove sampling parameters in model.py
AllentDan Feb 28, 2024
37ffff3
Refactor chat template
AllentDan Feb 29, 2024
98e7d19
Merge branch 'main' into refactor-model
AllentDan Feb 29, 2024
267a8cc
gemma
AllentDan Feb 29, 2024
b233d4e
add match
AllentDan Feb 29, 2024
df8b4fd
fix UT
AllentDan Mar 1, 2024
0015384
remove fuzzywuzzy
AllentDan Mar 1, 2024
ff28a3c
fix llama2 UT after removing fuzzy matching
AllentDan Mar 1, 2024
1cb8c01
Merge branch 'main' into refactor-model
AllentDan Mar 1, 2024
6e1c88d
update chatglm2
AllentDan Mar 1, 2024
d1ea3cf
Merge branch 'main' into refactor-model
AllentDan Mar 1, 2024
fae1faf
remove fuzzywuzzy in readthedocs
AllentDan Mar 1, 2024
5eed3ee
rename stop_word_suffix -> seperator
AllentDan Mar 4, 2024
cba0e4a
fix UT
AllentDan Mar 4, 2024
a4f4897
resolve comments
AllentDan Mar 5, 2024
c2cd240
hide some model_names
AllentDan Mar 6, 2024
9478f83
fix lmdeploy list --engine turbomind
AllentDan Mar 6, 2024
0b291ea
fix
AllentDan Mar 6, 2024
498ebb2
Merge branch 'main' into refactor-model
AllentDan Mar 6, 2024
34a36ce
mv deprecate to front
AllentDan Mar 6, 2024
b459083
match is chat
AllentDan Mar 6, 2024
4be54ad
fix UT
AllentDan Mar 7, 2024
c293b8b
recover two names
AllentDan Mar 7, 2024
4d853b3
clean names
AllentDan Mar 7, 2024
c85fcf9
recover vicuna back
AllentDan Mar 7, 2024
bff5c43
put llama back
AllentDan Mar 7, 2024
c7d83fd
fix yi
AllentDan Mar 7, 2024
b5e727e
better match for deepseek
AllentDan Mar 11, 2024
96f4284
fix UT
AllentDan Mar 11, 2024
551ce22
remove eoa from solar model chat template
AllentDan Mar 11, 2024
340e5b4
use default if not passed in
AllentDan Mar 11, 2024
1a4cf76
better deprecate hint
AllentDan Mar 11, 2024
9ed42c1
update example
AllentDan Mar 11, 2024
45 changes: 20 additions & 25 deletions examples/vl/qwen_model.py
@@ -19,20 +19,12 @@ def __init__(self,
                  top_p=0.3,
                  top_k=None,
                  temperature=1.0,
-                 im_start='<|im_start|>',
-                 im_end='<|im_end|>',
-                 system='You are a helpful assistant.',
-                 stop_words=['<|im_end|>'],
                  **kwargs):
         super().__init__(**kwargs)
         self.session_len = session_len
         self.top_p = top_p
         self.top_k = top_k
         self.temperature = temperature
-        self.im_start = im_start
-        self.im_end = im_end
-        self.system = system
-        self.stop_words = stop_words

     def _concat_image_info(self, prompt):
         """Append image placeholder."""
@@ -45,27 +37,30 @@ def _concat_image_info(self, prompt):
             prompt = res + prompt
         return prompt

-    def decorate_prompt(self, prompt, sequence_start=True):
+    def get_prompt(self, prompt, sequence_start=True):
         """Apply chat template to prompt."""
         prompt = self._concat_image_info(prompt)
-        return super().decorate_prompt(prompt, sequence_start)
+        return super().get_prompt(prompt, sequence_start)

     def messages2prompt(self, messages, sequence_start=True):
         """Apply chat template to history."""
         if isinstance(messages, str) or isinstance(messages[0], str):
-            return self.decorate_prompt(messages, sequence_start)
-        system, users, assistants = self._translate_messages(messages)
-        ret = f'{self.im_start}system\n{system}{self.im_end}'
-        for user, assistant in zip(users, assistants):
-            if not isinstance(user, str):
-                user = [user[0]['text'], len(user) - 1]
-            user = self._concat_image_info(user)
-            if assistant:
-                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
-                       f'\n{self.im_start}assistant\n{assistant}'
-            else:
-                ret += f'\n{self.im_start}user\n{user}{self.im_end}' \
-                       f'\n{self.im_start}assistant\n'
+            return self.get_prompt(messages, sequence_start)
+        box_map = dict(user=self.user,
+                       assistant=self.assistant,
+                       system=self.system)
+        eox_map = dict(user=self.eoh,
+                       assistant=self.eoa + self.stop_word_suffix,
+                       system=self.eosys)
+        ret = ''
+        for message in messages:
+            role = message['role']
+            content = message['content']
+            if role == 'user' and not isinstance(content, str):
+                content = [content[0]['text'], len(content) - 1]
+                content = self._concat_image_info(content)
+            ret += f'{box_map[role]}{content}{eox_map[role]}'
+        ret += f'{self.assistant}'
+        return ret


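For intuition, a minimal standalone sketch of the box_map/eox_map rendering loop introduced above — not the library implementation. The ChatML-style token values are inferred from the '<|im_start|>'/'<|im_end|>' defaults this diff removes from the Qwen template; the actual base-class defaults are an assumption.

# Hedged sketch of BaseChatTemplate-style rendering; token defaults are assumptions.
def render(messages,
           system='<|im_start|>system\n', eosys='<|im_end|>\n',
           user='<|im_start|>user\n', eoh='<|im_end|>\n',
           assistant='<|im_start|>assistant\n', eoa='<|im_end|>',
           stop_word_suffix='\n'):
    box_map = dict(user=user, assistant=assistant, system=system)
    eox_map = dict(user=eoh, assistant=eoa + stop_word_suffix, system=eosys)
    ret = ''
    for message in messages:
        ret += f"{box_map[message['role']]}{message['content']}{eox_map[message['role']]}"
    return ret + assistant  # trailing assistant box cues the model to respond

print(render([dict(role='system', content='You are a helpful assistant.'),
              dict(role='user', content='Hi')]))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant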
@@ -134,8 +129,8 @@ def prepare_query(self, query, sequence_start=True):
         image_paths = []
         if not isinstance(query, str):
             query, image_paths = query[0], query[1:]
-        decorate_text = self.decorator.decorate_prompt(
-            (query, len(image_paths)), sequence_start)
+        decorate_text = self.decorator.get_prompt((query, len(image_paths)),
+                                                  sequence_start)
         return self._to_inputs(decorate_text, image_paths, sequence_start)

     def prepare_message(self, messages):
54 changes: 27 additions & 27 deletions examples/vl/xcomposer_model.py
@@ -9,7 +9,7 @@
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

-from lmdeploy.model import MODELS, BaseModel
+from lmdeploy.model import MODELS, BaseChatTemplate

 meta_instruction = """meta instruction
 You are an AI assistant whose name is 浦语.
@@ -20,20 +20,20 @@


 @MODELS.register_module(name='internlm-xcomposer-7b')
-class InternLMXComposerTemplate(BaseModel):
+class InternLMXComposerTemplate(BaseChatTemplate):
     """Internlm xcomposer chat template."""

     def __init__(self,
-                 system=meta_instruction,
-                 user='<|User|>:',
-                 assistant='<|Bot|>:',
+                 meta_instruction=meta_instruction,
+                 user=' <|User|>: ',
+                 assistant=' <|Bot|>: ',
                  eoh='<TOKENS_UNUSED_0>',
                  eoa='<TOKENS_UNUSED_1>',
                  stop_words=['<TOKENS_UNUSED_0>', '<TOKENS_UNUSED_1>'],
                  image_placeholder='<Img><ImageHere></Img>',
                  **kwargs):
         super().__init__(**kwargs)
-        self.system = system
+        self.meta_instruction = meta_instruction
         self.user = user
         self.assistant = assistant
         self.eoh = eoh
@@ -51,31 +51,32 @@ def _concat_image_info(self, prompt):
             prompt = f'{self.image_placeholder}{prompt}'
         return prompt

-    def decorate_prompt(self, prompt, sequence_start=True):
+    def get_prompt(self, prompt, sequence_start=True):
         """Apply chat template to prompt."""
         prompt = self._concat_image_info(prompt)
-        if sequence_start:
-            return f'{self.system} {self.user} {prompt}{self.eoh} {self.assistant}'  # noqa
-        else:
-            return f' {self.user} {prompt}{self.eoh} {self.assistant}'
+        return super().get_prompt(prompt, sequence_start)

     def messages2prompt(self, messages, sequence_start=True):
         """Apply chat template to history."""
         if isinstance(messages, str) or isinstance(messages[0], str):
-            return self.decorate_prompt(messages, sequence_start)
-        system, users, assistants = self._translate_messages(messages)
-        system = self.system if not system else system
-        ret = system
-        for user, assistant in zip(users, assistants):
-            if not isinstance(user, str):
-                assert isinstance(user, Sequence)
-                assert all(isinstance(item, dict) for item in user)
-                user = [user[0]['text'], len(user) - 1]
-            user = self._concat_image_info(user)
-            if assistant:
-                ret += f' {self.user} {user}{self.eoh} {self.assistant} {assistant}{self.eoa}'  # noqa
-            else:
-                ret += f' {self.user} {user}{self.eoh} {self.assistant}'
+            return self.get_prompt(messages, sequence_start)
+        box_map = dict(user=self.user,
+                       assistant=self.assistant,
+                       system=self.system)
+        eox_map = dict(user=self.eoh,
+                       assistant=self.eoa + self.stop_word_suffix,
+                       system=self.eosys)
+        ret = ''
+        for message in messages:
+            role = message['role']
+            content = message['content']
+            if role == 'user' and not isinstance(content, str):
+                assert isinstance(content, Sequence)
+                assert all(isinstance(item, dict) for item in content)
+                content = [content[0]['text'], len(content) - 1]
+                content = self._concat_image_info(content)
+            ret += f'{box_map[role]}{content}{eox_map[role]}'
+        ret += f'{self.assistant}'
+        return ret


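To make the user-content branch above concrete, a hedged sketch of how an OpenAI-style multimodal message is collapsed to [text, image_count] and given the image placeholder; the content-list shape is an assumption, while the placeholder is the default from this diff.

# Sketch only: assumes content is one text dict followed by image dicts.
image_placeholder = '<Img><ImageHere></Img>'

content = [{'text': 'What is in the picture?'},
           {'image_url': 'cat.jpg'}]
if not isinstance(content, str):
    text, n_images = content[0]['text'], len(content) - 1  # ('What is in the picture?', 1)
    content = f'{image_placeholder}{text}'  # single placeholder, as in _concat_image_info

print(content)  # <Img><ImageHere></Img>What is in the picture?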
@@ -155,8 +156,7 @@ def prepare_query(self, query, sequence_start=True):
         if len(image_paths) > 1:
             print('does not support multiple images, use last one.')
             image_paths = image_paths[-1:]
-        decorate_text = self.decorator.decorate_prompt(
-            (query, len(image_paths)))
+        decorate_text = self.decorator.get_prompt((query, len(image_paths)))
         return self._to_inputs(decorate_text, image_paths, sequence_start)

     def prepare_message(self, messages):
31 changes: 19 additions & 12 deletions lmdeploy/cli/cli.py
@@ -102,20 +102,27 @@ def convert(args):
     @staticmethod
     def list(args):
         """List the supported model names."""
-        engine = args.engine
-        assert engine in ['turbomind', 'pytorch']
-        if engine == 'pytorch':
-            model_names = [
-                'llama', 'llama2', 'internlm', 'internlm2', 'baichuan2',
-                'chatglm2', 'falcon', 'yi', 'mistral', 'mixtral', 'qwen1.5',
-                'gemma', 'deepseek'
-            ]
-        elif engine == 'turbomind':
-            from lmdeploy.model import MODELS
-            model_names = list(MODELS.module_dict.keys())
-            model_names = [n for n in model_names if n.lower() not in ['base']]
+        from lmdeploy.model import MODELS
+        model_names = list(MODELS.module_dict.keys())
+        deprecate_names = [
+            'baichuan-7b', 'baichuan2-7b', 'chatglm2-6b', 'internlm',
+            'internlm-chat-20b', 'internlm-chat-7b', 'internlm-chat-7b-8k',
+            'internlm2', 'internlm2-1_8b', 'internlm2-20b', 'internlm2-7b',
+            'internlm2-chat-1_8b', 'internlm2-chat-20b', 'internlm2-chat-7b',
+            'llama-2-chat', 'llama2', 'qwen-14b', 'qwen-7b', 'solar-70b',
+            'yi-200k', 'yi-34b', 'yi-chat'
+        ]
+        model_names = [
+            n for n in model_names
+            if n.lower() not in deprecate_names + ['base']
+        ]
         model_names.sort()
         print('Supported model names:')
+        yellow = '\033[33m'
+        reset = '\033[0m'
+        max_name_width = max([len(name) for name in deprecate_names])
+        for name in deprecate_names:
+            print(f'{name:<{max_name_width}} {yellow}Deprecate soon{reset}')
         print('\n'.join(model_names))

     @staticmethod
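For context, a sketch of the output the reworked list command produces; the two names below are an illustrative subset of deprecate_names.

# Deprecated aliases print first, padded to a common width, tagged in yellow.
yellow, reset = '\033[33m', '\033[0m'
deprecate_names = ['internlm-chat-7b', 'qwen-7b']  # illustrative subset
width = max(len(name) for name in deprecate_names)
for name in deprecate_names:
    print(f'{name:<{width}} {yellow}Deprecate soon{reset}')
# internlm-chat-7b Deprecate soon
# qwen-7b          Deprecate soon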