From bfcc2c4f3df208b4a72f942de3017ba892a354ad Mon Sep 17 00:00:00 2001 From: pyx9913 Date: Tue, 29 Oct 2024 14:10:31 +0800 Subject: [PATCH] code format --- .../openai/tool_parsers/__init__.py | 4 +- .../tool_parsers/minicpm_tool_parser.py | 141 ++++++++---------- 2 files changed, 61 insertions(+), 84 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 5be6a9dfd3ad1..a8a8679e3c5c8 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,11 +3,11 @@ from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser from .llama_tool_parser import Llama3JsonToolParser +from .minicpm_tool_parser import MiniCPMToolParser from .mistral_tool_parser import MistralToolParser -from .minicpm_tool_parser import MiniCPMJsonToolParser __all__ = [ "ToolParser", "ToolParserManager", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", - "JambaToolParser", "MiniCPMJsonToolParser" + "JambaToolParser", "MiniCPMToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py index 77599b9061b85..683d049e71eba 100644 --- a/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py @@ -3,23 +3,17 @@ import keyword import re import traceback -from typing import List, Sequence, Union, Dict +from typing import Dict, List, Sequence, Union from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, - DeltaFunctionCall, - DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, - ToolCall, -) +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, - ToolParserManager, -) + ToolParser, ToolParserManager) from vllm.logger import init_logger logger = init_logger(__name__) @@ -43,29 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.stop_token_ids = [2, 73440] def extract_tool_calls( - self, model_output: str, request: ChatCompletionRequest - ) -> ExtractedToolCallInformation: + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. """ msg = fc2dict(model_output) - if ( - "tool_calls" in msg - and msg["tool_calls"] is not None - and len(msg["tool_calls"]) > 0 - ): + if ("tool_calls" in msg and msg["tool_calls"] is not None + and len(msg["tool_calls"]) > 0): tool_calls: List[ToolCall] = [ ToolCall( type="function", function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call["arguments"], ensure_ascii=False - ), + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False), ), - ) - for raw_function_call in msg["tool_calls"] + ) for raw_function_call in msg["tool_calls"] ] # get any content before the tool call @@ -98,15 +87,11 @@ def extract_tool_calls_streaming( if self.thought_end_token not in current_text: return None useful_text = current_text.split(self.thought_end_token)[1] - if ( - current_token_ids[-1] in self.stop_token_ids - ): # case 3: stream generation ended + if (current_token_ids[-1] + in self.stop_token_ids): # case 3: stream generation ended msg = fc2dict(current_text) - if ( - "tool_calls" in msg - and msg["tool_calls"] is not None - and len(msg["tool_calls"]) > 0 - ): + if ("tool_calls" in msg and msg["tool_calls"] is not None + and len(msg["tool_calls"]) > 0): self.prev_tool_call_arr = msg["tool_calls"] self.streamed_args_for_tool = ["" for tc in msg["tool_calls"]] delta_message = DeltaMessage( @@ -116,14 +101,12 @@ def extract_tool_calls_streaming( return delta_message else: return DeltaMessage(content=msg.get("content", None)) - elif ( - self.tool_call_start_token in useful_text - and self.tool_call_end_token in useful_text - ): # case 2: tool call ended + elif (self.tool_call_start_token in useful_text + and self.tool_call_end_token + in useful_text): # case 2: tool call ended return None - elif ( - self.tool_call_start_token in useful_text - ): # case 1: tool call started + elif (self.tool_call_start_token + in useful_text): # case 1: tool call started # Extract function name and arguments, handling nested parentheses pattern = r"(\w+)\(((?:[^()]*|\([^()]*\))*)\)" matches = re.finditer(pattern, useful_text) @@ -138,26 +121,23 @@ def extract_tool_calls_streaming( parsed = ast.parse(tool_call_string) for elem in parsed.body: - assert isinstance(elem.value, ast.Call) - calls = resolve_ast_call(elem.value) + assert isinstance(elem.value, ast.Call) # type: ignore + calls = resolve_ast_call(elem.value) # type: ignore for func_name, func_args in calls.items(): this_call = { - "name": func_name, - "arguments": json.dumps( - func_args, ensure_ascii=False - ), + "name": + func_name, + "arguments": + json.dumps(func_args, ensure_ascii=False), } - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - **this_call - ).model_dump(exclude_none=True), - ) - ] - ) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + **this_call).model_dump(exclude_none=True), + ) + ]) self.prev_tool_call_arr = tool_calls self.streamed_args_for_tool = ["" for x in tool_calls] self.current_tool_name_sent = True @@ -187,25 +167,23 @@ def fc2dict( if tool_call_string.startswith("```"): tool_call_string = tool_call_string[3:].strip() if tool_call_string.startswith("python"): - tool_call_string = tool_call_string.lstrip("python").strip() + tool_call_string = tool_call_string.lstrip( + "python").strip() if tool_call_string.endswith("```"): tool_call_string = tool_call_string[:-3].strip() for kw in keyword.kwlist: tool_call_string = tool_call_string.replace( - "," + kw + "=", "," + kw + "_=" - ) + "," + kw + "=", "," + kw + "_=") tool_call_string = tool_call_string.replace( - " " + kw + "=", " " + kw + "_=" - ) + " " + kw + "=", " " + kw + "_=") tool_call_string = tool_call_string.replace( - "(" + kw + "=", "(" + kw + "_=" - ) + "(" + kw + "=", "(" + kw + "_=") - parsed = ast.parse(tool_call_string) + parsed: ast.Module = ast.parse(tool_call_string) for elem in parsed.body: - assert isinstance(elem.value, ast.Call) - calls = resolve_ast_call(elem.value) + assert isinstance(elem.value, ast.Call) # type: ignore + calls = resolve_ast_call(elem.value) # type: ignore for func_name, func_args in calls.items(): new_args = {} @@ -224,7 +202,7 @@ def fc2dict( "role": "assistant", } except Exception as e: - logger.error(f"Error parsing tool call: {e}") + logger.error("Error parsing tool call: %s", str(e)) logger.error(traceback.format_exc()) return { "content": content.strip(), @@ -259,12 +237,9 @@ def resolve_ast_call(elem): def resolve_ast_by_type(value): if isinstance(value, ast.Constant): - if value.value is Ellipsis: - output = "..." - else: - output = value.value + output = "..." if value.value is Ellipsis else value.value elif isinstance(value, ast.UnaryOp): - output = -value.operand.value + output = -value.operand.value # type: ignore elif isinstance(value, ast.List): output = [resolve_ast_by_type(v) for v in value.elts] elif isinstance(value, ast.Dict): @@ -273,34 +248,36 @@ def resolve_ast_by_type(value): for k, v in zip(value.keys, value.values) } elif isinstance( - value, ast.NameConstant - ): # Added this condition to handle boolean values + value, + ast.NameConstant): # Added this condition to handle boolean values output = value.value elif isinstance( - value, ast.BinOp + value, ast.BinOp ): # Added this condition to handle function calls as arguments - output = ast.literal_eval(ast.unparse(value)) + output = ast.literal_eval(ast.unparse(value)) # type: ignore elif isinstance(value, ast.Name): output = value.id elif isinstance(value, ast.Call): if len(value.keywords) == 0: - output = ast.unparse(value) + output = ast.unparse(value) # type: ignore else: output = resolve_ast_call(value) elif isinstance(value, ast.Tuple): output = tuple(resolve_ast_by_type(v) for v in value.elts) elif isinstance(value, ast.Lambda): - output = ast.literal_eval(ast.unparse(value.body[0].value)) + output = ast.literal_eval( + ast.unparse( # type: ignore + value.body[0].value)) # type: ignore elif isinstance(value, ast.Ellipsis): output = "..." elif isinstance(value, ast.Subscript): try: - output = ast.unparse(value.body[0].value) + output = ast.unparse(value.body[0].value) # type: ignore except Exception as e: - logger.error(f"Error parsing tool call: {e}") + logger.error("Error parsing tool call: %s", str(e)) output = ( - ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]" - ) + ast.unparse(value.value) + "[" + # type: ignore + ast.unparse(value.slice) + "]") # type: ignore else: raise Exception(f"Unsupported AST type: {type(value)}") return output