Skip to content

Commit 72d86c1

Browse files
committed
优化glm4的 tools
1 parent 5001a58 commit 72d86c1

File tree

2 files changed

+47
-23
lines changed

2 files changed

+47
-23
lines changed

gpt_server/model_handler/chatglm_react.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,37 +4,56 @@
44

55
# Sentence appended after each tool description; it asks the model to express
# the call arguments as JSON.
GLM4_TOOL_SUFFIX_PROMPT = "在调用上述函数时,请使用 Json 格式表示调用的参数。"

# ReAct-style system prompt template for GLM-4 tool calling.
# Placeholders filled through str.format():
#   {tool_text}  - the rendered description of every available tool
#   {tool_names} - comma-separated tool names offered as valid "Action" values
# The literal opens with exactly three quotes: a fourth quote would become the
# first character of the prompt and leave a stray, unbalanced '"' in the text
# sent to the model.
GLM4_TOOL_PROMPT = """你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。

# 可用工具
{tool_text}
Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!

Question:
"""
1126

1227

1328
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    """Render the GLM-4 tool-calling system prompt for ``tools``.

    Args:
        tools: Tool descriptions; each item must carry a ``"function"`` dict
            that has at least a ``"name"`` key.

    Returns:
        ``GLM4_TOOL_PROMPT`` with one ``## <name>`` section per tool (the
        function dict dumped as indented JSON, followed by
        ``GLM4_TOOL_SUFFIX_PROMPT``) and the comma-separated tool names
        substituted in, with leading/trailing whitespace stripped.

    Raises:
        KeyError: If a tool lacks the ``"function"`` or ``"name"`` key.
    """
    sections = []
    names = []
    for entry in tools:
        spec = entry["function"]
        name = spec["name"]
        schema = json.dumps(spec, ensure_ascii=False, indent=4)
        sections.append(f"## {name}\n\n{schema}\n{GLM4_TOOL_SUFFIX_PROMPT}\n\n")
        names.append(name)

    # The tool text always starts with a newline, even when there are no tools.
    rendered = GLM4_TOOL_PROMPT.format(
        tool_text="\n" + "".join(sections),
        tool_names=", ".join(names),
    )
    return rendered.strip()
2039

2140

2241
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
    """Extract a ReAct-style tool call from model output.

    The last ``Action:`` marker and the last ``Action Input:`` marker are
    located; the text between them is the tool name and the text after
    ``Action Input:`` is the JSON-encoded argument string.

    Args:
        content: Raw text generated by the model.

    Returns:
        ``content`` unchanged when no well-formed tool call is found (a marker
        is missing, the markers are out of order, or the action input is not
        valid JSON). Otherwise a one-element list holding a dict with the keys
        ``"index"``, ``"id"`` and ``"function"``; ``"function"`` maps to
        ``{"name": ..., "arguments": ...}`` where ``arguments`` is the raw
        JSON string, not the parsed object.
    """
    action_marker = "Action:"
    input_marker = "Action Input:"

    action_pos = content.rfind(action_marker)
    input_pos = content.rfind(input_marker)
    # rfind() returns -1 when a marker is absent. Slicing with -1 would pick
    # arbitrary pieces of ordinary text (e.g. "The answer: 42" would yield a
    # bogus call), so bail out unless both markers exist in the right order.
    if action_pos == -1 or input_pos == -1 or input_pos < action_pos:
        return content

    tool_name = content[action_pos + len(action_marker) : input_pos].strip().strip(".")
    tool_input = content[input_pos + len(input_marker) :].strip()

    # Parse only to validate; callers receive the original JSON text.
    try:
        json.loads(tool_input)
    except json.JSONDecodeError:
        return content

    return [
        {
            "index": 0,
            "id": "call_{}".format(uuid.uuid4().hex),
            "function": {"name": tool_name, "arguments": tool_input},
        }
    ]
4059

gpt_server/model_worker/chatglm.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@ def __init__(
3131
)
3232

3333
self.stop = ["<|user|>", "<|observation|>", "<|endoftext|>"]
34-
self.stop_words_ids = [
35-
self.tokenizer.convert_tokens_to_ids(i) for i in self.stop
36-
]
34+
# 拓展额外的stop
35+
self.stop.extend(["Observation:"])
36+
self.stop_words_ids = []
37+
for i in self.stop:
38+
try:
39+
self.stop_words_ids.append(self.tokenizer.convert_tokens_to_ids(i))
40+
except Exception as e:
41+
pass
3742

3843
logger.info(f"chatglm停用词: {self.stop}")
3944

@@ -71,7 +76,7 @@ async def generate_stream_gate(self, params):
7176
if isinstance(messages, list):
7277
task = "chat"
7378
for msg in messages:
74-
if msg["role"] == "function":
79+
if msg["role"] == "function" or msg["role"] == "tool":
7580
msg["role"] = "observation"
7681

7782
if messages[-1]["role"] == "user":

0 commit comments

Comments
 (0)