diff --git a/langport/tests/test_conversation.py b/langport/tests/test_conversation.py
index 8a73a1a..2f007c2 100644
--- a/langport/tests/test_conversation.py
+++ b/langport/tests/test_conversation.py
@@ -3,7 +3,7 @@
 import unittest
 
 from chatproto.conversation.history import ConversationHistory
-from chatproto.conversation.models.baichuan import baichuan
+from chatproto.conversation.models.baichuan import baichuan2
 from chatproto.conversation.models.chatglm import chatglm
 from chatproto.conversation.models.chatgpt import chatgpt
 from chatproto.conversation.models.llama import llama
@@ -11,123 +11,6 @@
 from chatproto.conversation.models.qwen import qwen
 from chatproto.conversation.models.starchat import starchat
 
 
-class TestLlamaMethods(unittest.TestCase):
-    B_INST, E_INST = "[INST]", "[/INST]"
-    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-    SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
-    UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
-
-    def get_llama_prompt(self, dialogs):
-        unsafe_requests = []
-        prompt_tokens = []
-        for dialog in dialogs:
-            unsafe_requests.append(
-                any([tag in msg["content"] for tag in self.SPECIAL_TAGS for msg in dialog])
-            )
-            if dialog[0]["role"] == "system":
-                dialog = [
-                    {
-                        "role": dialog[1]["role"],
-                        "content": self.B_SYS
-                        + dialog[0]["content"]
-                        + self.E_SYS
-                        + dialog[1]["content"],
-                    }
-                ] + dialog[2:]
-            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-                [msg["role"] == "assistant" for msg in dialog[1::2]]
-            ), (
-                "model only supports 'system', 'user' and 'assistant' roles, "
-                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
-            )
-            dialog_tokens: str = "".join([
-                f"{self.B_INST} {(prompt['content']).strip()} {self.E_INST} {(answer['content']).strip()} "
-                for prompt, answer in zip(
-                    dialog[::2],
-                    dialog[1::2],
-                )
-            ])
-            assert (
-                dialog[-1]["role"] == "user"
-            ), f"Last message must be from user, got {dialog[-1]['role']}"
-            dialog_tokens += f"{self.B_INST} {(dialog[-1]['content']).strip()} {self.E_INST}"
-            prompt_tokens.append(dialog_tokens)
-        return prompt_tokens
-
-    def test_conv(self):
-        history = ConversationHistory(
-            "SYSTEM_MESSAGE",
-            messages=[
-                (llama.roles[0], "aaa"),
-                (llama.roles[1], "bbb"),
-                (llama.roles[0], "ccc"),
-            ],
-            offset=0,
-            settings=llama
-        )
-        my_out = history.get_prompt()
-        llama_out = self.get_llama_prompt([[
-            {"role": "system", "content": "SYSTEM_MESSAGE"},
-            {"role": llama.roles[0], "content": "aaa"},
-            {"role": llama.roles[1], "content": "bbb"},
-            {"role": llama.roles[0], "content": "ccc"},
-        ]])[0]
-        self.assertEqual(my_out, llama_out)
-
-    def test_conv2(self):
-        history = ConversationHistory(
-            "SYSTEM_MESSAGE",
-            messages=[
-                (llama.roles[0], "aaa"),
-                (llama.roles[1], "bbb"),
-                (llama.roles[0], "ccc"),
-                (llama.roles[1], None),
-            ],
-            offset=0,
-            settings=llama
-        )
-        my_out = history.get_prompt()
-        llama_out = self.get_llama_prompt([[
-            {"role": "system", "content": "SYSTEM_MESSAGE"},
-            {"role": llama.roles[0], "content": "aaa"},
-            {"role": llama.roles[1], "content": "bbb"},
-            {"role": llama.roles[0], "content": "ccc"},
-        ]])[0]
-        print(my_out)
-        print(llama_out)
-        self.assertEqual(my_out, llama_out)
-
-
-class TestBaiChuanMethods(unittest.TestCase):
-
-    def test_conv(self):
-        history = ConversationHistory(
-            "SYSTEM_MESSAGE",
-            messages=[
-                (baichuan.roles[0], "aaa"),
-                (baichuan.roles[1], "bbb"),
-            ],
-            offset=0,
-            settings=baichuan
-        )
-        self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE aaa bbb")
-
-
-class TestChatGLMMethods(unittest.TestCase):
-
-    def test_conv(self):
-        history = ConversationHistory(
-            "SYSTEM_MESSAGE",
-            messages=[
-                (chatglm.roles[0], "aaa"),
-                (chatglm.roles[1], "bbb"),
-            ],
-            offset=0,
-            settings=chatglm
-        )
-        self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE\n\n[Round 1]\n\n问:aaa\n\n答:bbb\n\n")
-
 class TestChatGPTMethods(unittest.TestCase):
     def test_conv(self):