diff --git a/langport/data/conversation/__init__.py b/langport/data/conversation/__init__.py
index 9dd9c21..f3993c5 100644
--- a/langport/data/conversation/__init__.py
+++ b/langport/data/conversation/__init__.py
@@ -178,20 +178,30 @@ def get_prompt(self) -> str:
return ret
elif self.settings.sep_style == SeparatorStyle.LLAMA:
B_INST, E_INST = "[INST]", "[/INST]"
-            B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
- system_q = B_SYS + self.system + E_SYS
- system_a = ""
- ret = f"{B_INST} {system_q} {E_INST} {system_a}"
+ if system_prompt:
+ ret = system_prompt + self.settings.sep
+ else:
+ ret = ""
+
+ if self.messages[0][0] == "system":
+ self.messages.pop(0)
for i, (role, message) in enumerate(self.messages):
- if role == self.settings.roles[0]:
- if not(i != 0 and self.messages[i - 1][0] == self.settings.roles[0]):
- inst = B_INST
+ if i % 2 == 0:
+ inst = B_INST + " "
else:
- inst = E_INST
+ inst = E_INST + " "
+ if i == 0:
+ inst = ""
if message:
- ret += inst + " " + message.strip() + " "
+ if i % 2 == 0:
+ ret += inst + message.strip() + " "
+ else:
+ ret += inst + message.strip() + " "
+ if i == len(self.messages) - 1:
+ ret += E_INST
else:
- ret += inst + " "
+ ret += E_INST
+
return ret
elif self.settings.sep_style == SeparatorStyle.CHATLM:
im_start, im_end = "<|im_start|>", "<|im_end|>"
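The rewritten LLAMA branch stops hard-coding the <<SYS>> wrapper (it now comes from the llama settings' system_template, changed below) and targets Meta's reference Llama-2 chat layout: the system block rides inside the first [INST], and each completed user/assistant pair becomes "[INST] user [/INST] assistant ". For comparison, a minimal self-contained sketch of that layout; llama2_prompt is a hypothetical helper, not part of langport:

    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    def llama2_prompt(system: str, users: list, assistants: list) -> str:
        # users carries one more entry than assistants: the pending question.
        users = [B_SYS + system + E_SYS + users[0]] + users[1:]
        ret = "".join(
            f"{B_INST} {u.strip()} {E_INST} {a.strip()} "
            for u, a in zip(users, assistants)
        )
        # Close with the unanswered user turn so the model generates the reply.
        return ret + f"{B_INST} {users[-1].strip()} {E_INST}"

    # llama2_prompt("SYSTEM_MESSAGE", ["aaa", "ccc"], ["bbb"]) ==
    # "[INST] <<SYS>>\nSYSTEM_MESSAGE\n<</SYS>>\n\naaa [/INST] bbb [INST] ccc [/INST]"
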
diff --git a/langport/data/conversation/settings/llama.py b/langport/data/conversation/settings/llama.py
index b6b9e57..ab6211a 100644
--- a/langport/data/conversation/settings/llama.py
+++ b/langport/data/conversation/settings/llama.py
@@ -7,6 +7,7 @@
# Llama default template
llama = ConversationSettings(
name="llama",
+    system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA,
sep="",
diff --git a/langport/data/conversation/settings/starchat.py b/langport/data/conversation/settings/starchat.py
index dee7a5f..cb5def4 100644
--- a/langport/data/conversation/settings/starchat.py
+++ b/langport/data/conversation/settings/starchat.py
@@ -7,7 +7,7 @@
# StarChat default template
starchat = ConversationSettings(
name="starchat",
- system_template="\n{system_message}",
+ system_template="<|system|>\n{system_message}",
roles=("<|user|>", "<|assistant|>"),
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
sep="<|end|>\n",
diff --git a/langport/service/server/tensorrt_worker.py b/langport/service/server/tensorrt_worker.py
new file mode 100644
index 0000000..e69de29
diff --git a/langport/tests/test_conversation.py b/langport/tests/test_conversation.py
index 4bf21d5..886b784 100644
--- a/langport/tests/test_conversation.py
+++ b/langport/tests/test_conversation.py
@@ -1,12 +1,103 @@
+from typing import List
import unittest
from langport.data.conversation.conversation_settings import ConversationHistory
from langport.data.conversation.settings.baichuan import baichuan
from langport.data.conversation.settings.chatglm import chatglm
from langport.data.conversation.settings.chatgpt import chatgpt
+from langport.data.conversation.settings.llama import llama
from langport.data.conversation.settings.openbuddy import openbuddy
from langport.data.conversation.settings.qwen import qwen
+from langport.data.conversation.settings.starchat import starchat
+
+class TestLlamaMethods(unittest.TestCase):
+ B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+    SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
+ UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
+
+ def get_llama_prompt(self, dialogs):
+ unsafe_requests = []
+ prompt_tokens = []
+ for dialog in dialogs:
+ unsafe_requests.append(
+ any([tag in msg["content"] for tag in self.SPECIAL_TAGS for msg in dialog])
+ )
+ if dialog[0]["role"] == "system":
+ dialog = [
+ {
+ "role": dialog[1]["role"],
+ "content": self.B_SYS
+ + dialog[0]["content"]
+ + self.E_SYS
+ + dialog[1]["content"],
+ }
+ ] + dialog[2:]
+ assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+ [msg["role"] == "assistant" for msg in dialog[1::2]]
+ ), (
+ "model only supports 'system', 'user' and 'assistant' roles, "
+ "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+ )
+ dialog_tokens: str = "".join([
+ f"{self.B_INST} {(prompt['content']).strip()} {self.E_INST} {(answer['content']).strip()} "
+ for prompt, answer in zip(
+ dialog[::2],
+ dialog[1::2],
+ )
+ ])
+ assert (
+ dialog[-1]["role"] == "user"
+ ), f"Last message must be from user, got {dialog[-1]['role']}"
+ dialog_tokens += f"{self.B_INST} {(dialog[-1]['content']).strip()} {self.E_INST}"
+ prompt_tokens.append(dialog_tokens)
+ return prompt_tokens
+
+ def test_conv(self):
+ history = ConversationHistory(
+ "SYSTEM_MESSAGE",
+ messages=[
+ (llama.roles[0], "aaa"),
+ (llama.roles[1], "bbb"),
+ (llama.roles[0], "ccc"),
+ ],
+ offset=0,
+ settings=llama
+ )
+ my_out = history.get_prompt()
+ llama_out = self.get_llama_prompt([[
+ {"role": "system", "content": "SYSTEM_MESSAGE"},
+ {"role": llama.roles[0], "content": "aaa"},
+ {"role": llama.roles[1], "content": "bbb"},
+ {"role": llama.roles[0], "content": "ccc"},
+ ]])[0]
+ self.assertEqual(my_out, llama_out)
+
+ def test_conv2(self):
+ history = ConversationHistory(
+ "SYSTEM_MESSAGE",
+ messages=[
+ (llama.roles[0], "aaa"),
+ (llama.roles[1], "bbb"),
+ (llama.roles[0], "ccc"),
+ (llama.roles[1], None),
+ ],
+ offset=0,
+ settings=llama
+ )
+ my_out = history.get_prompt()
+ llama_out = self.get_llama_prompt([[
+ {"role": "system", "content": "SYSTEM_MESSAGE"},
+ {"role": llama.roles[0], "content": "aaa"},
+ {"role": llama.roles[1], "content": "bbb"},
+ {"role": llama.roles[0], "content": "ccc"},
+ ]])[0]
+ print(my_out)
+ print(llama_out)
+ self.assertEqual(my_out, llama_out)
+
class TestBaiChuanMethods(unittest.TestCase):
@@ -90,5 +181,32 @@ def test_conv(self):
bbb<|im_end|>
""")
+
+class TestStarChatMethods(unittest.TestCase):
+
+ def test_conv(self):
+ history = ConversationHistory(
+ "SYSTEM_MESSAGE",
+ messages=[
+ (starchat.roles[0], "aaa"),
+ (starchat.roles[1], "bbb"),
+ ],
+ offset=0,
+ settings=starchat
+ )
+ self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\nbbb<|end|>\n")
+
+ def test_conv_question(self):
+ history = ConversationHistory(
+ "SYSTEM_MESSAGE",
+ messages=[
+ (starchat.roles[0], "aaa"),
+ (starchat.roles[1], None),
+ ],
+ offset=0,
+ settings=starchat
+ )
+ self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\n")
+
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
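The new suites check langport's output against a string port of Meta's reference prompt builder. A quick way to run just this module, assuming langport is importable in the current environment:

    import unittest
    from langport.tests import test_conversation

    # Load and run only the conversation-format suites.
    suite = unittest.defaultTestLoader.loadTestsFromModule(test_conversation)
    unittest.TextTestRunner(verbosity=2).run(suite)
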
diff --git a/langport/version.py b/langport/version.py
index a16650d..47dbd54 100644
--- a/langport/version.py
+++ b/langport/version.py
@@ -1 +1 @@
-LANGPORT_VERSION = "0.3.7"
\ No newline at end of file
+LANGPORT_VERSION = "0.3.8"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 882f48f..f9a9ba2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "langport"
-version = "0.3.7"
+version = "0.3.8"
description = "A large language model serving platform."
readme = "README.md"
requires-python = ">=3.8"