From 85fce11e1ec339450c11dd8a41b0edd962940747 Mon Sep 17 00:00:00 2001 From: pyx9913 Date: Mon, 28 Oct 2024 22:04:51 +0800 Subject: [PATCH] add minicpm tool parser Signed-off-by: pyx9913 --- examples/tool_chat_template_minicpm3.jinja | 135 +++++++++ .../openai/tool_parsers/__init__.py | 3 +- .../tool_parsers/minicpm_tool_parser.py | 283 ++++++++++++++++++ 3 files changed, 420 insertions(+), 1 deletion(-) create mode 100644 examples/tool_chat_template_minicpm3.jinja create mode 100644 vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py diff --git a/examples/tool_chat_template_minicpm3.jinja b/examples/tool_chat_template_minicpm3.jinja new file mode 100644 index 0000000000000..14d014d133b5e --- /dev/null +++ b/examples/tool_chat_template_minicpm3.jinja @@ -0,0 +1,135 @@ +{%- macro json_to_python_type(param_name, json_spec) %} +{%- set basic_type_map = { + 'string': 'str', + 'number': 'float', + 'integer': 'int', + 'boolean': 'bool', + 'null': 'None' +} %} + +{%- if json_spec.enum %} + {{- param_name|title }} +{%- elif basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == 'array' %} + {{- 'List[' + json_to_python_type(param_name, json_spec['items']) + ']' }} +{%- elif json_spec.type == 'object' %} + {{- 'Dict[str, ' + json_to_python_type(param_name, json_spec.additionalProperties if json_spec.additionalProperties else 'Any') + ']' if not json_spec.properties else param_name|title }} +{%- elif json_spec.type is iterable %} + {{- 'Union[' }} + {%- for t in json_spec.type %} + {{- json_to_python_type(param_name, {'type': t}) }} + {{- ', ' if not loop.last }} + {%- endfor %} + {{- ']' }} +{%- else %} + {{- 'Any' }} +{%- endif %} +{%- endmacro %} + +{%- macro object_to_fields(json_spec, field_indent) %} + {%- set o_ns = namespace(f = caller()) %} + {%- for param_name, param_fields in json_spec.properties|items %} + {%- if param_fields.enum %} + {{- '\n\nclass ' + param_name|title + '(Enum):\n' }} + {%- for enum_option in param_fields.enum %} + {{- ' enum_' + loop.index0|string + ' = ' + enum_option|tojson + '\n' }} + {%- endfor %} + {%- elif param_fields.type == 'object' and param_fields.properties %} + {%- call object_to_fields(param_fields, ' ') %} + {{- '\n\nclass ' + param_name|title + '(BaseModel):\n' }} + {%- endcall %} + {%- elif param_fields.type == 'array' and param_fields['items'] and param_fields['items'].type == 'object' and param_fields['items'].properties %} + {%- call object_to_fields(param_fields['items'], ' ') %} + {{- '\n\nclass ' + param_name|title + '(BaseModel):\n' }} + {%- endcall %} + {%- endif %} + {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %} + {%- set o_ns.f = o_ns.f + field_indent + param_name + ': ' %} + {%- set o_ns.f = o_ns.f + ('Optional[' + json_to_python_type(param_name, param_fields) + ']' if param_name not in json_spec.required else json_to_python_type(param_name, param_fields)) %} + {%- if not param_fields.title and not param_fields.description and not param_fields.pattern %} + {%- set o_ns.f = o_ns.f + (' = ' + param_default if param_name not in json_spec.required else '') %} + {%- else %} + {%- set o_ns.f = o_ns.f + (' = Field(...' if param_name in json_spec.required else ' = Field(' + param_default) %} + {%- set o_ns.f = o_ns.f + (', description=' + param_fields.description|tojson if param_fields.description else '') %} + {%- set o_ns.f = o_ns.f + (', regex=' + param_fields.pattern|tojson if param_fields.pattern else '') %} + {%- set o_ns.f = o_ns.f + (', title=' + param_fields.title|tojson if param_fields.title else '') %} + {%- set o_ns.f = o_ns.f + ')' %} + {%- endif %} + {%- set o_ns.f = o_ns.f + '\n' %} + {%- endfor %} + {{- o_ns.f }} +{%- endmacro %} + +{%- macro tool_parser(tools) %} +{%- for tool in tools %} + {%- if tool.type is not defined or tool.type == 'function' %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {%- set tool_params = tool.parameters if tool.parameters is defined else none %} + {%- call object_to_fields(tool_params, ' ') %} + {{- '\n\ndef ' + tool.name + '(' }} + {%- if tool_params %} + {%- for param_name, param_fields in tool_params.properties|items %} + {%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %} + {{- ', ' if loop.index0 != 0 }} + {{- param_name }} + {{- '=' + param_default if param_name not in tool_params.required }} + {%- endfor %} + {%- endif %} + {{- '):\n """' }} + {{- tool.description }} + {{- '\n\n Args:\n' if tool_params else '\n' }} + {%- endcall %} + {{- ' """\n' }} + {%- endif %} +{%- endfor %} +{%- endmacro %} + +{%- if messages[0]['role'] == 'system' %} + {%- set loop_messages = messages[1:] %} + {%- set system_message = messages[0]['content'] %} +{%- else %} + {%- set loop_messages = messages %} + {%- set system_message = '' %} +{%- endif %} +{{- '<|im_start|>system\n' + system_message if system_message or tools }} +{%- if tools %} + {{- '\n# Functions\nHere is a list of functions that you can invoke:\n```python\nfrom enum import Enum\nfrom typing import List, Dict, Optional\nfrom pydantic import BaseModel, Field\n\n' }} + {{- tool_parser(tools) }} + {{- "\n```\n\n# Function Call Rule and Output Format\n- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.\n- If the user cannot be answered without calling any function, and the user does not provide enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.\n- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. In this situation, the assistant should return your thought and call the functions.\n- Use default parameters unless the user has specified otherwise.\n- You should answer in the following format:\n\n<|thought_start|>\n{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}\n<|thought_end|>\n<|tool_call_start|>\n```python\nfunc1(params_name=params_value, params_name2=params_value2...)\nfunc2(params)\n```\n<|tool_call_end|>\n{answer the user's question directly or ask the user for more information}" }} +{%- endif %} +{{- '<|im_end|>\n' if system_message or tools }} +{%- for message in loop_messages %} + {%- set content = message.content %} + {%- if message.role == 'assistant' and message.tool_calls %} + {{- '<|im_start|>' + message.role + '\n' }} + {{- '<|thought_start|>\n' + message.thought + '\n<|thought_end|>\n' if message.thought }} + {{- '<|tool_call_start|>\n```python\n' }} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' }} + {%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %} + {%- for param_name, param_value in tool_call.arguments|items %} + {{- param_name + '=' + param_value|tojson }} + {{- ',' if not loop.last }} + {%- endfor %} + {%- endif %} + {{- ')\n' }} + {%- endfor %} + {{- '```\n<|tool_call_end|>\n' }} + {{- content if content and not content.startswith('<|tool_call_start|>') }} + {{- '<|im_end|>\n' }} + {%- elif message.role == 'assistant' and message.thought %} + {{- '<|im_start|>' + message.role + '\n' + '<|thought_start|>\n' + message.thought + '\n<|thought_end|>\n' + content + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 0e88bb21ca75f..a8a8679e3c5c8 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,10 +3,11 @@ from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser from .llama_tool_parser import Llama3JsonToolParser +from .minicpm_tool_parser import MiniCPMToolParser from .mistral_tool_parser import MistralToolParser __all__ = [ "ToolParser", "ToolParserManager", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", - "JambaToolParser" + "JambaToolParser", "MiniCPMToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py new file mode 100644 index 0000000000000..683d049e71eba --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py @@ -0,0 +1,283 @@ +import ast +import json +import keyword +import re +import traceback +from typing import Dict, List, Sequence, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("minicpm") +class MiniCPMToolParser(ToolParser): + """ + Tool call parser for MiniCPM3 4B models intended for use with the + examples/tool_chat_template_minicpm3.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser minicpm are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.thought_start_token = "<|thought_start|>" + self.thought_end_token = "<|thought_end|>" + self.tool_call_start_token = "<|tool_call_start|>" + self.tool_call_end_token = "<|tool_call_end|>" + self.stop_token_ids = [2, 73440] + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + msg = fc2dict(model_output) + if ("tool_calls" in msg and msg["tool_calls"] is not None + and len(msg["tool_calls"]) > 0): + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False), + ), + ) for raw_function_call in msg["tool_calls"] + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=msg.get("content", None), + ) + return ret + else: + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[], + content=msg.get("content", None), + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # if no tools are provided, we don't need to parse tool calls + if not request.tools: + return DeltaMessage(content=delta_text) + if self.thought_end_token not in current_text: + return None + useful_text = current_text.split(self.thought_end_token)[1] + if (current_token_ids[-1] + in self.stop_token_ids): # case 3: stream generation ended + msg = fc2dict(current_text) + if ("tool_calls" in msg and msg["tool_calls"] is not None + and len(msg["tool_calls"]) > 0): + self.prev_tool_call_arr = msg["tool_calls"] + self.streamed_args_for_tool = ["" for tc in msg["tool_calls"]] + delta_message = DeltaMessage( + role="assistant", + content=msg.get("content", None), + ) + return delta_message + else: + return DeltaMessage(content=msg.get("content", None)) + elif (self.tool_call_start_token in useful_text + and self.tool_call_end_token + in useful_text): # case 2: tool call ended + return None + elif (self.tool_call_start_token + in useful_text): # case 1: tool call started + # Extract function name and arguments, handling nested parentheses + pattern = r"(\w+)\(((?:[^()]*|\([^()]*\))*)\)" + matches = re.finditer(pattern, useful_text) + tool_calls: List[Dict] = [] + delta = None + for idx, match in enumerate(matches): + if self.current_tool_id < idx: + self.current_tool_id = idx + func_name = match.group(1) + func_args = match.group(2) + tool_call_string = f"{func_name}({func_args})\n" + + parsed = ast.parse(tool_call_string) + for elem in parsed.body: + assert isinstance(elem.value, ast.Call) # type: ignore + calls = resolve_ast_call(elem.value) # type: ignore + + for func_name, func_args in calls.items(): + this_call = { + "name": + func_name, + "arguments": + json.dumps(func_args, ensure_ascii=False), + } + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + **this_call).model_dump(exclude_none=True), + ) + ]) + self.prev_tool_call_arr = tool_calls + self.streamed_args_for_tool = ["" for x in tool_calls] + self.current_tool_name_sent = True + return delta + else: + return None + + +def fc2dict( + sequence: str, + tool_call_start="<|tool_call_start|>", + tool_call_end="<|tool_call_end|>", + thought_start="<|thought_start|>", + thought_end="<|thought_end|>", +): + if thought_end in sequence and thought_start in sequence: + thought_string, sequence = sequence.rsplit(thought_end, 1) + thought_string = thought_string.split(thought_start, 1)[1] + else: + thought_string = "" + if tool_call_start in sequence and tool_call_end in sequence: + tool_call_string, content = sequence.rsplit(tool_call_end, 1) + tool_call_string = tool_call_string.split(tool_call_start, 1)[1] + try: + tool_calls = [] + tool_call_string = tool_call_string.strip() + if tool_call_string.startswith("```"): + tool_call_string = tool_call_string[3:].strip() + if tool_call_string.startswith("python"): + tool_call_string = tool_call_string.lstrip( + "python").strip() + if tool_call_string.endswith("```"): + tool_call_string = tool_call_string[:-3].strip() + for kw in keyword.kwlist: + tool_call_string = tool_call_string.replace( + "," + kw + "=", "," + kw + "_=") + tool_call_string = tool_call_string.replace( + " " + kw + "=", " " + kw + "_=") + tool_call_string = tool_call_string.replace( + "(" + kw + "=", "(" + kw + "_=") + + parsed: ast.Module = ast.parse(tool_call_string) + + for elem in parsed.body: + assert isinstance(elem.value, ast.Call) # type: ignore + calls = resolve_ast_call(elem.value) # type: ignore + + for func_name, func_args in calls.items(): + new_args = {} + for k, v in func_args.items(): + for kw in keyword.kwlist: + if k == kw + "_": + k = kw + new_args[k] = v + + this_one = {"name": func_name, "arguments": new_args} + tool_calls.append(this_one) + + return { + "content": content.strip(), + "tool_calls": tool_calls, + "role": "assistant", + } + except Exception as e: + logger.error("Error parsing tool call: %s", str(e)) + logger.error(traceback.format_exc()) + return { + "content": content.strip(), + "role": "assistant", + "thought": thought_string, + } + else: + return { + "content": sequence.strip(), + "role": "assistant", + "thought": thought_string, + } + + +# from ShishirPatil/gorilla +def resolve_ast_call(elem): + # Handle nested attributes for deeply nested module paths + func_parts = [] + func_part = elem.func + while isinstance(func_part, ast.Attribute): + func_parts.append(func_part.attr) + func_part = func_part.value + if isinstance(func_part, ast.Name): + func_parts.append(func_part.id) + func_name = ".".join(reversed(func_parts)) + args_dict = {} + for arg in elem.keywords: + output = resolve_ast_by_type(arg.value) + args_dict[arg.arg] = output + return {func_name: args_dict} + + +def resolve_ast_by_type(value): + if isinstance(value, ast.Constant): + output = "..." if value.value is Ellipsis else value.value + elif isinstance(value, ast.UnaryOp): + output = -value.operand.value # type: ignore + elif isinstance(value, ast.List): + output = [resolve_ast_by_type(v) for v in value.elts] + elif isinstance(value, ast.Dict): + output = { + resolve_ast_by_type(k): resolve_ast_by_type(v) + for k, v in zip(value.keys, value.values) + } + elif isinstance( + value, + ast.NameConstant): # Added this condition to handle boolean values + output = value.value + elif isinstance( + value, ast.BinOp + ): # Added this condition to handle function calls as arguments + output = ast.literal_eval(ast.unparse(value)) # type: ignore + elif isinstance(value, ast.Name): + output = value.id + elif isinstance(value, ast.Call): + if len(value.keywords) == 0: + output = ast.unparse(value) # type: ignore + else: + output = resolve_ast_call(value) + elif isinstance(value, ast.Tuple): + output = tuple(resolve_ast_by_type(v) for v in value.elts) + elif isinstance(value, ast.Lambda): + output = ast.literal_eval( + ast.unparse( # type: ignore + value.body[0].value)) # type: ignore + elif isinstance(value, ast.Ellipsis): + output = "..." + elif isinstance(value, ast.Subscript): + try: + output = ast.unparse(value.body[0].value) # type: ignore + except Exception as e: + logger.error("Error parsing tool call: %s", str(e)) + output = ( + ast.unparse(value.value) + "[" + # type: ignore + ast.unparse(value.slice) + "]") # type: ignore + else: + raise Exception(f"Unsupported AST type: {type(value)}") + return output