diff --git a/langport/data/conversation/__init__.py b/langport/data/conversation/__init__.py index 9dd9c21..f3993c5 100644 --- a/langport/data/conversation/__init__.py +++ b/langport/data/conversation/__init__.py @@ -178,20 +178,30 @@ def get_prompt(self) -> str: return ret elif self.settings.sep_style == SeparatorStyle.LLAMA: B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - system_q = B_SYS + self.system + E_SYS - system_a = "" - ret = f"{B_INST} {system_q} {E_INST} {system_a}" + if system_prompt: + ret = system_prompt + self.settings.sep + else: + ret = "" + + if self.messages[0][0] == "system": + self.messages.pop(0) for i, (role, message) in enumerate(self.messages): - if role == self.settings.roles[0]: - if not(i != 0 and self.messages[i - 1][0] == self.settings.roles[0]): - inst = B_INST + if i % 2 == 0: + inst = B_INST + " " else: - inst = E_INST + inst = E_INST + " " + if i == 0: + inst = "" if message: - ret += inst + " " + message.strip() + " " + if i % 2 == 0: + ret += inst + message.strip() + " " + else: + ret += inst + message.strip() + " " + if i == len(self.messages) - 1: + ret += E_INST else: - ret += inst + " " + ret += E_INST + return ret elif self.settings.sep_style == SeparatorStyle.CHATLM: im_start, im_end = "<|im_start|>", "<|im_end|>" diff --git a/langport/data/conversation/settings/llama.py b/langport/data/conversation/settings/llama.py index b6b9e57..ab6211a 100644 --- a/langport/data/conversation/settings/llama.py +++ b/langport/data/conversation/settings/llama.py @@ -7,6 +7,7 @@ # Llama default template llama = ConversationSettings( name="llama", + system_template="[INST] <>\n{system_message}\n<>\n\n", roles=("user", "assistant"), sep_style=SeparatorStyle.LLAMA, sep="", diff --git a/langport/data/conversation/settings/starchat.py b/langport/data/conversation/settings/starchat.py index dee7a5f..cb5def4 100644 --- a/langport/data/conversation/settings/starchat.py +++ b/langport/data/conversation/settings/starchat.py @@ -7,7 +7,7 @@ # StarChat default template starchat = ConversationSettings( name="starchat", - system_template="\n{system_message}", + system_template="<|system|>\n{system_message}", roles=("<|user|>", "<|assistant|>"), sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, sep="<|end|>\n", diff --git a/langport/service/server/tensorrt_worker.py b/langport/service/server/tensorrt_worker.py new file mode 100644 index 0000000..e69de29 diff --git a/langport/tests/test_conversation.py b/langport/tests/test_conversation.py index 4bf21d5..886b784 100644 --- a/langport/tests/test_conversation.py +++ b/langport/tests/test_conversation.py @@ -1,12 +1,103 @@ +from typing import List import unittest from langport.data.conversation.conversation_settings import ConversationHistory from langport.data.conversation.settings.baichuan import baichuan from langport.data.conversation.settings.chatglm import chatglm from langport.data.conversation.settings.chatgpt import chatgpt +from langport.data.conversation.settings.llama import llama from langport.data.conversation.settings.openbuddy import openbuddy from langport.data.conversation.settings.qwen import qwen +from langport.data.conversation.settings.starchat import starchat + +class TestLlamaMethods(unittest.TestCase): + B_INST, E_INST = "[INST]", "[/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + + SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] + UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." + + def get_llama_prompt(self, dialogs): + unsafe_requests = [] + prompt_tokens = [] + for dialog in dialogs: + unsafe_requests.append( + any([tag in msg["content"] for tag in self.SPECIAL_TAGS for msg in dialog]) + ) + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": self.B_SYS + + dialog[0]["content"] + + self.E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: str = "".join([ + f"{self.B_INST} {(prompt['content']).strip()} {self.E_INST} {(answer['content']).strip()} " + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ]) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += f"{self.B_INST} {(dialog[-1]['content']).strip()} {self.E_INST}" + prompt_tokens.append(dialog_tokens) + return prompt_tokens + + def test_conv(self): + history = ConversationHistory( + "SYSTEM_MESSAGE", + messages=[ + (llama.roles[0], "aaa"), + (llama.roles[1], "bbb"), + (llama.roles[0], "ccc"), + ], + offset=0, + settings=llama + ) + my_out = history.get_prompt() + llama_out = self.get_llama_prompt([[ + {"role": "system", "content": "SYSTEM_MESSAGE"}, + {"role": llama.roles[0], "content": "aaa"}, + {"role": llama.roles[1], "content": "bbb"}, + {"role": llama.roles[0], "content": "ccc"}, + ]])[0] + self.assertEqual(my_out, llama_out) + + def test_conv2(self): + history = ConversationHistory( + "SYSTEM_MESSAGE", + messages=[ + (llama.roles[0], "aaa"), + (llama.roles[1], "bbb"), + (llama.roles[0], "ccc"), + (llama.roles[1], None), + ], + offset=0, + settings=llama + ) + my_out = history.get_prompt() + llama_out = self.get_llama_prompt([[ + {"role": "system", "content": "SYSTEM_MESSAGE"}, + {"role": llama.roles[0], "content": "aaa"}, + {"role": llama.roles[1], "content": "bbb"}, + {"role": llama.roles[0], "content": "ccc"}, + ]])[0] + print(my_out) + print(llama_out) + self.assertEqual(my_out, llama_out) + class TestBaiChuanMethods(unittest.TestCase): @@ -90,5 +181,32 @@ def test_conv(self): bbb<|im_end|> """) + +class TestStarChatMethods(unittest.TestCase): + + def test_conv(self): + history = ConversationHistory( + "SYSTEM_MESSAGE", + messages=[ + (starchat.roles[0], "aaa"), + (starchat.roles[1], "bbb"), + ], + offset=0, + settings=starchat + ) + self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\nbbb<|end|>\n") + + def test_conv_question(self): + history = ConversationHistory( + "SYSTEM_MESSAGE", + messages=[ + (starchat.roles[0], "aaa"), + (starchat.roles[1], None), + ], + offset=0, + settings=starchat + ) + self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\n") + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/langport/version.py b/langport/version.py index a16650d..47dbd54 100644 --- a/langport/version.py +++ b/langport/version.py @@ -1 +1 @@ -LANGPORT_VERSION = "0.3.7" \ No newline at end of file +LANGPORT_VERSION = "0.3.8" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 882f48f..f9a9ba2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "langport" -version = "0.3.7" +version = "0.3.8" description = "A large language model serving platform." readme = "README.md" requires-python = ">=3.8"