
Commit 07a0d14

Optimize the Qwen tools ReAct prompt
1 parent b2008dc commit 07a0d14


7 files changed: +176 −156 lines changed


gpt_server/model_handler/__init__.py

Whitespace-only changes.

gpt_server/model_handler/chatglm_react.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
from typing import Any, Dict, List, Tuple, Union
import json
import uuid

GLM4_TOOL_SUFFIX_PROMPT = "在调用上述函数时,请使用 Json 格式表示调用的参数。"

GLM4_TOOL_PROMPT = (
    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
    "{tool_text}"
)


def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    """Render an OpenAI-style tools spec into the GLM-4 tool system prompt."""
    tool_text = ""
    for tool in tools:
        tool = tool["function"]
        tool_name = tool["name"]
        tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
    return GLM4_TOOL_PROMPT.format(tool_text=tool_text)


def glm4_tool_extractor(content: str) -> Union[str, List[Dict[str, Any]]]:
    """Parse a GLM-4 tool call: tool name on the first line, JSON arguments on the second.

    Returns the original content unchanged when it is not a tool call.
    """
    lines = content.strip().split("\n")
    if len(lines) != 2:
        return content
    tool_name = lines[0].strip()
    tool_input = lines[1].strip()
    try:
        json.loads(tool_input)
    except json.JSONDecodeError:
        return content
    tool_calls = [
        {
            "id": "call_{}".format(uuid.uuid4().hex),
            "function": {"name": tool_name, "arguments": tool_input},
        }
    ]

    return tool_calls


if __name__ == "__main__":
    import json

    tools_str = """[{'type': 'function', 'function': {'name': 'track', 'description': '追踪指定股票的实时价格', 'parameters': {'type': 'object', 'properties': {'symbol': {'description': '需要追踪的股票代码', 'type': 'integer'}}, 'required': ['symbol']}}}, {'type': 'function', 'function': {'name': 'text-to-speech', 'description': '将文本转换为语音', 'parameters': {'type': 'object', 'properties': {'text': {'description': '需要转换成语音的文本', 'type': 'string'}, 'voice': {'description': '要使用的语音类型(男声、女声等', 'default': '男声', 'type': 'string'}, 'speed': {'description': '语音的速度(快、中等、慢等', 'default': '中等', 'type': 'string'}}, 'required': ['text']}}}]"""
    tools_str = tools_str.replace("'", '"')
    tools = json.loads(tools_str)

    res = glm4_tool_formatter(tools=tools)
    print(res)
    print()
    out = 'multiply\n{"first_int": 8, "second_int": 9}'
    r = glm4_tool_extractor(out)
    print(r)
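
For reference, a minimal sketch of how glm4_tool_extractor behaves on the two kinds of replies GLM-4 can produce, run after the definitions above (or with the function imported from gpt_server.model_handler.chatglm_react); the example strings are assumptions, not output captured from the model:

    # A tool call: tool name on the first line, JSON arguments on the second.
    print(glm4_tool_extractor('track\n{"symbol": 600519}'))
    # -> [{"id": "call_...", "function": {"name": "track", "arguments": '{"symbol": 600519}'}}]

    # Anything else (e.g. a plain natural-language answer) is returned unchanged.
    print(glm4_tool_extractor("The stock closed at 1700 today."))
    # -> the input string itself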

gpt_server/model_handler/qwen_react.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
import re
from typing import Any, Dict, List, Tuple, Union
import json
import uuid

# default ReAct-style tool prompt used for Qwen models
TOOL_SYSTEM_PROMPT = """Answer the following questions as best you can. You have access to the following tools:

{tool_text}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question:"""


def qwen_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    """Render an OpenAI-style tools spec into the Qwen ReAct system prompt."""
    tool_names = []
    param_text_list = []
    for tool in tools:
        tool = tool["function"]
        param_text = """{tool_name}: Call this tool to interact with the {tool_name} API. What is the {tool_name} API useful for? {description} Parameters: {parameters} Format the arguments as a JSON object."""
        parameters = []
        for name, param in tool["parameters"]["properties"].items():
            parameters.append(
                {
                    "name": name,
                    "description": param.get("description", ""),
                    "required": name in tool["parameters"].get("required", []),
                    "schema": {"type": param["type"]},
                }
            )
        param_text_str = param_text.format(
            tool_name=tool["name"],
            description=tool["description"],
            parameters=parameters,
        )
        param_text_list.append(param_text_str)

        tool_names.append(tool["name"])

    tool_text = "\n\n".join(param_text_list).strip()
    return TOOL_SYSTEM_PROMPT.format(
        tool_text=tool_text,
        tool_names=", ".join(tool_names),
    )


def qwen_tool_extractor(content: str) -> Union[str, List[Dict[str, Any]]]:
    """Extract the last Action / Action Input pair from a ReAct response.

    Returns the original content unchanged when no valid tool call is found.
    """
    i = content.rfind("Action:")
    j = content.rfind("Action Input:")
    if i == -1 or j == -1 or j < i:
        return content
    tool_name = content[i + len("Action:") : j].strip().strip(".")
    tool_input = content[j + len("Action Input:") :].strip()
    try:
        json.loads(tool_input)
    except json.JSONDecodeError:
        return content
    tool_calls = []
    tool_call = {
        "id": "call_{}".format(uuid.uuid4().hex),
        "function": {"name": tool_name, "arguments": tool_input},
    }
    tool_calls.append(tool_call)

    return tool_calls


if __name__ == "__main__":
    import json

    tools_str = """[{'type': 'function', 'function': {'name': 'track', 'description': '追踪指定股票的实时价格', 'parameters': {'type': 'object', 'properties': {'symbol': {'description': '需要追踪的股票代码', 'type': 'integer'}}, 'required': ['symbol']}}}, {'type': 'function', 'function': {'name': 'text-to-speech', 'description': '将文本转换为语音', 'parameters': {'type': 'object', 'properties': {'text': {'description': '需要转换成语音的文本', 'type': 'string'}, 'voice': {'description': '要使用的语音类型(男声、女声等', 'default': '男声', 'type': 'string'}, 'speed': {'description': '语音的速度(快、中等、慢等', 'default': '中等', 'type': 'string'}}, 'required': ['text']}}}]"""
    tools_str = tools_str.replace("'", '"')
    tools = json.loads(tools_str)
    res = qwen_tool_formatter(tools=tools)
    print(res)
    out = 'Action: multiply.\nAction Input: {"first_int": 8, "second_int": 9}\n'
    r = qwen_tool_extractor(out)
    print("\n\n")
    print(r)
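
To show how the formatter, the extractor, and the "Observation:" stop word (added to QwenWorker below) fit together, here is a minimal ReAct driver-loop sketch. call_model and run_tool are hypothetical placeholders, not part of this commit:

    import json

    from gpt_server.model_handler.qwen_react import qwen_tool_extractor, qwen_tool_formatter


    def react_loop(question, tools, call_model, run_tool, max_steps=5):
        # Build the ReAct system prompt from the OpenAI-style tools spec.
        messages = [
            {"role": "system", "content": qwen_tool_formatter(tools)},
            {"role": "user", "content": question},
        ]
        response = ""
        for _ in range(max_steps):
            # call_model (hypothetical) is expected to stop at "Observation:".
            response = call_model(messages, stop=["Observation:"])
            tool_calls = qwen_tool_extractor(response)
            if isinstance(tool_calls, str):  # no Action/Action Input -> final answer text
                return tool_calls
            call = tool_calls[0]["function"]
            observation = run_tool(call["name"], json.loads(call["arguments"]))
            # Feed the real tool result back so the model can continue the Thought/Action cycle.
            messages.append({"role": "assistant", "content": response})
            messages.append({"role": "user", "content": f"Observation: {observation}"})
        return response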

gpt_server/model_handler/tools.py

Lines changed: 0 additions & 152 deletions
This file was deleted.

gpt_server/model_handler/utils.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
from gpt_server.model_handler.qwen_react import qwen_tool_formatter
from gpt_server.model_handler.chatglm_react import glm4_tool_formatter


def add_tools2messages(params: dict, model_adapter: str = "default"):
    """Inject the adapter-specific tool system prompt into the message list."""
    messages = params["messages"]
    if params.get("tools", None):  # tools were passed in
        system_content = None
        if model_adapter == "qwen":
            system_content = qwen_tool_formatter(tools=params.get("tools"))

        elif model_adapter == "chatglm4":
            system_content = glm4_tool_formatter(tools=params.get("tools"))

        if system_content is not None:
            if messages[0]["role"] != "system":
                messages.insert(0, {"role": "system", "content": system_content})
            else:  # overwrite an existing system message
                messages[0]["content"] = system_content

    return messages
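
A quick usage sketch of add_tools2messages; the params dict mirrors an OpenAI-style chat request, and the tool spec is the "track" tool from the demos above (the user question is an assumption):

    from gpt_server.model_handler.utils import add_tools2messages

    params = {
        "messages": [{"role": "user", "content": "What is the current price of stock 600519?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "track",
                    "description": "追踪指定股票的实时价格",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "symbol": {"description": "需要追踪的股票代码", "type": "integer"}
                        },
                        "required": ["symbol"],
                    },
                },
            }
        ],
    }

    messages = add_tools2messages(params=params, model_adapter="qwen")
    # messages[0] is now a system message holding the ReAct prompt built by
    # qwen_tool_formatter; the original user message follows it unchanged.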

gpt_server/model_worker/chatglm.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,8 @@
 import torch
 from loguru import logger
 from gpt_server.model_worker.base import ModelWorkerBase
-from gpt_server.model_handler.tools import add_tools2messages, glm4_tool_extractor
+from gpt_server.model_handler.chatglm_react import glm4_tool_extractor
+from gpt_server.model_handler.utils import add_tools2messages


 class ChatGLMWorker(ModelWorkerBase):
@@ -109,6 +110,7 @@ async def generate_stream_gate(self, params):
             if params.get("tools", False) and isinstance(
                 tool_calls, list
             ):  # tools were passed in
+                logger.debug(f"工具解析成功, tool_calls: {tool_calls}")
                 ret["tool_calls"] = tool_calls
             yield json.dumps(ret).encode() + b"\0"
         except torch.cuda.OutOfMemoryError as e:

gpt_server/model_worker/qwen.py

Lines changed: 7 additions & 3 deletions
@@ -5,7 +5,8 @@
 import torch

 from gpt_server.model_worker.base import ModelWorkerBase
-from gpt_server.model_handler.tools import add_tools2messages, default_tool_extractor
+from gpt_server.model_handler.qwen_react import qwen_tool_extractor
+from gpt_server.model_handler.utils import add_tools2messages


 class QwenWorker(ModelWorkerBase):
@@ -39,6 +40,8 @@ def __init__(
         self.stop = [
             self.tokenizer.decode(skip_word) for skip_word in self.stop_words_ids
         ]
+        # extend with extra stop words (stop before the model writes its own "Observation:")
+        self.stop.extend(["Observation:"])
         logger.info(f"qwen停用词: {self.stop}")
         self.other_config = {
             "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -50,7 +53,7 @@ async def generate_stream_gate(self, params):
         logger.info(f"worker_id: {self.worker_id}")
         try:
             model_type = getattr(self.model_config, "model_type", "qwen")
-            messages = add_tools2messages(params=params, model_adapter="default")
+            messages = add_tools2messages(params=params, model_adapter="qwen")

             if isinstance(messages, list):
                 task = "chat"
@@ -94,10 +97,11 @@ async def generate_stream_gate(self, params):

             yield json.dumps(ret).encode() + b"\0"
             # ------ add tool_calls ------
-            tool_calls = default_tool_extractor(response)
+            tool_calls = qwen_tool_extractor(response)
             if params.get("tools", False) and isinstance(
                 tool_calls, list
             ):  # tools were passed in
+                logger.debug(f"工具解析成功, tool_calls: {tool_calls}")
                 ret["tool_calls"] = tool_calls
             yield json.dumps(ret).encode() + b"\0"
         except torch.cuda.OutOfMemoryError as e:
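
Both workers stream results as JSON chunks terminated by a null byte (json.dumps(ret).encode() + b"\0"), with tool_calls attached once ReAct parsing succeeds. A minimal client-side decoding sketch, assuming raw_stream is some iterable of bytes and that each chunk carries a "text" field alongside the tool_calls shown above:

    import json


    def iter_chunks(raw_stream):
        # Re-assemble null-byte-delimited JSON chunks from an iterable of byte pieces.
        buffer = b""
        for piece in raw_stream:
            buffer += piece
            while b"\0" in buffer:
                chunk, buffer = buffer.split(b"\0", 1)
                if chunk:
                    yield json.loads(chunk)


    def read_response(raw_stream):
        text, tool_calls = "", None
        for chunk in iter_chunks(raw_stream):
            text = chunk.get("text", text)      # field name assumed, not shown in this diff
            if chunk.get("tool_calls"):         # set by the worker when a tool call is extracted
                tool_calls = chunk["tool_calls"]
        return text, tool_calls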
