update test cases
jstzwj committed Jan 31, 2024
1 parent 746b878 commit 197a672
Showing 1 changed file with 1 addition and 118 deletions.
119 changes: 1 addition & 118 deletions langport/tests/test_conversation.py
@@ -3,131 +3,14 @@
import unittest

from chatproto.conversation.history import ConversationHistory
from chatproto.conversation.models.baichuan import baichuan
from chatproto.conversation.models.baichuan import baichuan2
from chatproto.conversation.models.chatglm import chatglm
from chatproto.conversation.models.chatgpt import chatgpt
from chatproto.conversation.models.llama import llama
from chatproto.conversation.models.openbuddy import openbuddy
from chatproto.conversation.models.qwen import qwen
from chatproto.conversation.models.starchat import starchat

class TestLlamaMethods(unittest.TestCase):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
    UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

    def get_llama_prompt(self, dialogs):
        unsafe_requests = []
        prompt_tokens = []
        for dialog in dialogs:
            unsafe_requests.append(
                any([tag in msg["content"] for tag in self.SPECIAL_TAGS for msg in dialog])
            )
            if dialog[0]["role"] == "system":
                dialog = [
                    {
                        "role": dialog[1]["role"],
                        "content": self.B_SYS
                        + dialog[0]["content"]
                        + self.E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            dialog_tokens: str = "".join([
                f"{self.B_INST} {(prompt['content']).strip()} {self.E_INST} {(answer['content']).strip()} "
                for prompt, answer in zip(
                    dialog[::2],
                    dialog[1::2],
                )
            ])
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += f"{self.B_INST} {(dialog[-1]['content']).strip()} {self.E_INST}"
            prompt_tokens.append(dialog_tokens)
        return prompt_tokens

    def test_conv(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (llama.roles[0], "aaa"),
                (llama.roles[1], "bbb"),
                (llama.roles[0], "ccc"),
            ],
            offset=0,
            settings=llama
        )
        my_out = history.get_prompt()
        llama_out = self.get_llama_prompt([[
            {"role": "system", "content": "SYSTEM_MESSAGE"},
            {"role": llama.roles[0], "content": "aaa"},
            {"role": llama.roles[1], "content": "bbb"},
            {"role": llama.roles[0], "content": "ccc"},
        ]])[0]
        self.assertEqual(my_out, llama_out)

    def test_conv2(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (llama.roles[0], "aaa"),
                (llama.roles[1], "bbb"),
                (llama.roles[0], "ccc"),
                (llama.roles[1], None),
            ],
            offset=0,
            settings=llama
        )
        my_out = history.get_prompt()
        llama_out = self.get_llama_prompt([[
            {"role": "system", "content": "SYSTEM_MESSAGE"},
            {"role": llama.roles[0], "content": "aaa"},
            {"role": llama.roles[1], "content": "bbb"},
            {"role": llama.roles[0], "content": "ccc"},
        ]])[0]
        print(my_out)
        print(llama_out)
        self.assertEqual(my_out, llama_out)


class TestBaiChuanMethods(unittest.TestCase):

    def test_conv(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (baichuan.roles[0], "aaa"),
                (baichuan.roles[1], "bbb"),
            ],
            offset=0,
            settings=baichuan
        )
        self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE <reserved_102> aaa <reserved_103> bbb</s>")


class TestChatGLMMethods(unittest.TestCase):

    def test_conv(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (chatglm.roles[0], "aaa"),
                (chatglm.roles[1], "bbb"),
            ],
            offset=0,
            settings=chatglm
        )
        self.assertEqual(history.get_prompt(), "SYSTEM_MESSAGE\n\n[Round 1]\n\n问:aaa\n\n答:bbb\n\n")


class TestChatGPTMethods(unittest.TestCase):

    def test_conv(self):