Merge pull request #24 from vtuber-plan/development
Llama2 prompt template bug fix
jstzwj authored Nov 6, 2023
2 parents ee32f73 + de95d96 commit 2c5b9e9
Showing 7 changed files with 142 additions and 13 deletions.
30 changes: 20 additions & 10 deletions langport/data/conversation/__init__.py
@@ -178,20 +178,30 @@ def get_prompt(self) -> str:
             return ret
         elif self.settings.sep_style == SeparatorStyle.LLAMA:
             B_INST, E_INST = "[INST]", "[/INST]"
-            B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-            system_q = B_SYS + self.system + E_SYS
-            system_a = ""
-            ret = f"{B_INST} {system_q} {E_INST} {system_a}"
+            if system_prompt:
+                ret = system_prompt + self.settings.sep
+            else:
+                ret = ""
+
+            if self.messages[0][0] == "system":
+                self.messages.pop(0)
             for i, (role, message) in enumerate(self.messages):
-                if role == self.settings.roles[0]:
-                    if not(i != 0 and self.messages[i - 1][0] == self.settings.roles[0]):
-                        inst = B_INST
+                if i % 2 == 0:
+                    inst = B_INST + " "
                 else:
-                    inst = E_INST
+                    inst = E_INST + " "
+                if i == 0:
+                    inst = ""
                 if message:
-                    ret += inst + " " + message.strip() + " "
+                    if i % 2 == 0:
+                        ret += inst + message.strip() + " "
+                    else:
+                        ret += inst + message.strip() + " "
+                    if i == len(self.messages) - 1:
+                        ret += E_INST
                 else:
-                    ret += inst + " "
+                    ret += E_INST
+
             return ret
         elif self.settings.sep_style == SeparatorStyle.CHATLM:
             im_start, im_end = "<|im_start|>", "<|im_end|>"
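For reference, a minimal sketch (not part of the diff) of the Llama 2 chat layout the corrected LLAMA branch is meant to produce for the three-turn dialog used by TestLlamaMethods.test_conv further down; the variable names are illustrative only:

# Illustration only: expected Llama 2 prompt layout after this fix,
# mirroring the dialog in TestLlamaMethods.test_conv below.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

expected = (
    f"{B_INST} {B_SYS}SYSTEM_MESSAGE{E_SYS}aaa {E_INST} bbb "
    f"{B_INST} ccc {E_INST}"
)
# -> "[INST] <<SYS>>\nSYSTEM_MESSAGE\n<</SYS>>\n\naaa [/INST] bbb [INST] ccc [/INST]"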
1 change: 1 addition & 0 deletions langport/data/conversation/settings/llama.py
@@ -7,6 +7,7 @@
 # Llama default template
 llama = ConversationSettings(
     name="llama",
+    system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
     roles=("user", "assistant"),
     sep_style=SeparatorStyle.LLAMA,
     sep="",
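As a quick sanity check (illustration only; how langport itself substitutes the placeholder is not shown in this diff), filling the new system_template with plain str.format yields the Llama 2 system block:

# Illustration only: rendering the new Llama system template.
template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
rendered = template.format(system_message="SYSTEM_MESSAGE")
# -> "[INST] <<SYS>>\nSYSTEM_MESSAGE\n<</SYS>>\n\n"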
2 changes: 1 addition & 1 deletion langport/data/conversation/settings/starchat.py
@@ -7,7 +7,7 @@
 # StarChat default template
 starchat = ConversationSettings(
     name="starchat",
-    system_template="<system>\n{system_message}",
+    system_template="<|system|>\n{system_message}",
     roles=("<|user|>", "<|assistant|>"),
     sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
     sep="<|end|>\n",
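With the corrected <|system|> tag, a rendered StarChat prompt takes the shape asserted by the new TestStarChatMethods.test_conv below (illustration only):

# Illustration only: StarChat layout with the corrected template.
prompt = (
    "<|system|>\nSYSTEM_MESSAGE<|end|>\n"
    "<|user|>\naaa<|end|>\n"
    "<|assistant|>\nbbb<|end|>\n"
)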
Empty file.
118 changes: 118 additions & 0 deletions langport/tests/test_conversation.py
@@ -1,12 +1,103 @@

from typing import List
import unittest

from langport.data.conversation.conversation_settings import ConversationHistory
from langport.data.conversation.settings.baichuan import baichuan
from langport.data.conversation.settings.chatglm import chatglm
from langport.data.conversation.settings.chatgpt import chatgpt
from langport.data.conversation.settings.llama import llama
from langport.data.conversation.settings.openbuddy import openbuddy
from langport.data.conversation.settings.qwen import qwen
from langport.data.conversation.settings.starchat import starchat

class TestLlamaMethods(unittest.TestCase):
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
    UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."

    def get_llama_prompt(self, dialogs):
        unsafe_requests = []
        prompt_tokens = []
        for dialog in dialogs:
            unsafe_requests.append(
                any([tag in msg["content"] for tag in self.SPECIAL_TAGS for msg in dialog])
            )
            if dialog[0]["role"] == "system":
                dialog = [
                    {
                        "role": dialog[1]["role"],
                        "content": self.B_SYS
                        + dialog[0]["content"]
                        + self.E_SYS
                        + dialog[1]["content"],
                    }
                ] + dialog[2:]
            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
                [msg["role"] == "assistant" for msg in dialog[1::2]]
            ), (
                "model only supports 'system', 'user' and 'assistant' roles, "
                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
            )
            dialog_tokens: str = "".join([
                f"{self.B_INST} {(prompt['content']).strip()} {self.E_INST} {(answer['content']).strip()} "
                for prompt, answer in zip(
                    dialog[::2],
                    dialog[1::2],
                )
            ])
            assert (
                dialog[-1]["role"] == "user"
            ), f"Last message must be from user, got {dialog[-1]['role']}"
            dialog_tokens += f"{self.B_INST} {(dialog[-1]['content']).strip()} {self.E_INST}"
            prompt_tokens.append(dialog_tokens)
        return prompt_tokens

    def test_conv(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (llama.roles[0], "aaa"),
                (llama.roles[1], "bbb"),
                (llama.roles[0], "ccc"),
            ],
            offset=0,
            settings=llama
        )
        my_out = history.get_prompt()
        llama_out = self.get_llama_prompt([[
            {"role": "system", "content": "SYSTEM_MESSAGE"},
            {"role": llama.roles[0], "content": "aaa"},
            {"role": llama.roles[1], "content": "bbb"},
            {"role": llama.roles[0], "content": "ccc"},
        ]])[0]
        self.assertEqual(my_out, llama_out)

    def test_conv2(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (llama.roles[0], "aaa"),
                (llama.roles[1], "bbb"),
                (llama.roles[0], "ccc"),
                (llama.roles[1], None),
            ],
            offset=0,
            settings=llama
        )
        my_out = history.get_prompt()
        llama_out = self.get_llama_prompt([[
            {"role": "system", "content": "SYSTEM_MESSAGE"},
            {"role": llama.roles[0], "content": "aaa"},
            {"role": llama.roles[1], "content": "bbb"},
            {"role": llama.roles[0], "content": "ccc"},
        ]])[0]
        print(my_out)
        print(llama_out)
        self.assertEqual(my_out, llama_out)


class TestBaiChuanMethods(unittest.TestCase):

@@ -90,5 +181,32 @@ def test_conv(self):
bbb<|im_end|>
""")


class TestStarChatMethods(unittest.TestCase):

    def test_conv(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (starchat.roles[0], "aaa"),
                (starchat.roles[1], "bbb"),
            ],
            offset=0,
            settings=starchat
        )
        self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\nbbb<|end|>\n")

    def test_conv_question(self):
        history = ConversationHistory(
            "SYSTEM_MESSAGE",
            messages=[
                (starchat.roles[0], "aaa"),
                (starchat.roles[1], None),
            ],
            offset=0,
            settings=starchat
        )
        self.assertEqual(history.get_prompt(), "<|system|>\nSYSTEM_MESSAGE<|end|>\n<|user|>\naaa<|end|>\n<|assistant|>\n")

if __name__ == '__main__':
    unittest.main()
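The new cases can be run like any other unittest module, e.g. python -m unittest langport.tests.test_conversation from the repository root (assuming the package is importable from there).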
2 changes: 1 addition & 1 deletion langport/version.py
@@ -1 +1 @@
LANGPORT_VERSION = "0.3.7"
LANGPORT_VERSION = "0.3.8"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "langport"
-version = "0.3.7"
+version = "0.3.8"
 description = "A large language model serving platform."
 readme = "README.md"
 requires-python = ">=3.8"
