Skip to content

Commit

Permalink
code format
Browse files: browse the repository at this point in the history
  • Loading branch information
Cppowboy committed Oct 29, 2024
1 parent 2df396a commit bfcc2c4
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 84 deletions.
4 changes: 2 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .minicpm_tool_parser import MiniCPMToolParser
from .mistral_tool_parser import MistralToolParser
from .minicpm_tool_parser import MiniCPMJsonToolParser

__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
"JambaToolParser", "MiniCPMJsonToolParser"
"JambaToolParser", "MiniCPMToolParser"
]
141 changes: 59 additions & 82 deletions vllm/entrypoints/openai/tool_parsers/minicpm_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@
import keyword
import re
import traceback
from typing import List, Sequence, Union, Dict
from typing import Dict, List, Sequence, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
ToolParser, ToolParserManager)
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand All @@ -43,29 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
self.stop_token_ids = [2, 73440]

def extract_tool_calls(
self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
msg = fc2dict(model_output)
if (
"tool_calls" in msg
and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0
):
if ("tool_calls" in msg and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0):
tool_calls: List[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(
raw_function_call["arguments"], ensure_ascii=False
),
arguments=json.dumps(raw_function_call["arguments"],
ensure_ascii=False),
),
)
for raw_function_call in msg["tool_calls"]
) for raw_function_call in msg["tool_calls"]
]

# get any content before the tool call
Expand Down Expand Up @@ -98,15 +87,11 @@ def extract_tool_calls_streaming(
if self.thought_end_token not in current_text:
return None
useful_text = current_text.split(self.thought_end_token)[1]
if (
current_token_ids[-1] in self.stop_token_ids
): # case 3: stream generation ended
if (current_token_ids[-1]
in self.stop_token_ids): # case 3: stream generation ended
msg = fc2dict(current_text)
if (
"tool_calls" in msg
and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0
):
if ("tool_calls" in msg and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0):
self.prev_tool_call_arr = msg["tool_calls"]
self.streamed_args_for_tool = ["" for tc in msg["tool_calls"]]
delta_message = DeltaMessage(
Expand All @@ -116,14 +101,12 @@ def extract_tool_calls_streaming(
return delta_message
else:
return DeltaMessage(content=msg.get("content", None))
elif (
self.tool_call_start_token in useful_text
and self.tool_call_end_token in useful_text
): # case 2: tool call ended
elif (self.tool_call_start_token in useful_text
and self.tool_call_end_token
in useful_text): # case 2: tool call ended
return None
elif (
self.tool_call_start_token in useful_text
): # case 1: tool call started
elif (self.tool_call_start_token
in useful_text): # case 1: tool call started
# Extract function name and arguments, handling nested parentheses
pattern = r"(\w+)\(((?:[^()]*|\([^()]*\))*)\)"
matches = re.finditer(pattern, useful_text)
Expand All @@ -138,26 +121,23 @@ def extract_tool_calls_streaming(

parsed = ast.parse(tool_call_string)
for elem in parsed.body:
assert isinstance(elem.value, ast.Call)
calls = resolve_ast_call(elem.value)
assert isinstance(elem.value, ast.Call) # type: ignore
calls = resolve_ast_call(elem.value) # type: ignore

for func_name, func_args in calls.items():
this_call = {
"name": func_name,
"arguments": json.dumps(
func_args, ensure_ascii=False
),
"name":
func_name,
"arguments":
json.dumps(func_args, ensure_ascii=False),
}
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
**this_call
).model_dump(exclude_none=True),
)
]
)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
**this_call).model_dump(exclude_none=True),
)
])
self.prev_tool_call_arr = tool_calls
self.streamed_args_for_tool = ["" for x in tool_calls]
self.current_tool_name_sent = True
Expand Down Expand Up @@ -187,25 +167,23 @@ def fc2dict(
if tool_call_string.startswith("```"):
tool_call_string = tool_call_string[3:].strip()
if tool_call_string.startswith("python"):
tool_call_string = tool_call_string.lstrip("python").strip()
tool_call_string = tool_call_string.lstrip(
"python").strip()
if tool_call_string.endswith("```"):
tool_call_string = tool_call_string[:-3].strip()
for kw in keyword.kwlist:
tool_call_string = tool_call_string.replace(
"," + kw + "=", "," + kw + "_="
)
"," + kw + "=", "," + kw + "_=")
tool_call_string = tool_call_string.replace(
" " + kw + "=", " " + kw + "_="
)
" " + kw + "=", " " + kw + "_=")
tool_call_string = tool_call_string.replace(
"(" + kw + "=", "(" + kw + "_="
)
"(" + kw + "=", "(" + kw + "_=")

parsed = ast.parse(tool_call_string)
parsed: ast.Module = ast.parse(tool_call_string)

for elem in parsed.body:
assert isinstance(elem.value, ast.Call)
calls = resolve_ast_call(elem.value)
assert isinstance(elem.value, ast.Call) # type: ignore
calls = resolve_ast_call(elem.value) # type: ignore

for func_name, func_args in calls.items():
new_args = {}
Expand All @@ -224,7 +202,7 @@ def fc2dict(
"role": "assistant",
}
except Exception as e:
logger.error(f"Error parsing tool call: {e}")
logger.error("Error parsing tool call: %s", str(e))
logger.error(traceback.format_exc())
return {
"content": content.strip(),
Expand Down Expand Up @@ -259,12 +237,9 @@ def resolve_ast_call(elem):

def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
if value.value is Ellipsis:
output = "..."
else:
output = value.value
output = "..." if value.value is Ellipsis else value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value
output = -value.operand.value # type: ignore
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
Expand All @@ -273,34 +248,36 @@ def resolve_ast_by_type(value):
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value, ast.NameConstant
): # Added this condition to handle boolean values
value,
ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(
value, ast.BinOp
value, ast.BinOp
): # Added this condition to handle function calls as arguments
output = ast.literal_eval(ast.unparse(value))
output = ast.literal_eval(ast.unparse(value)) # type: ignore
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value)
output = ast.unparse(value) # type: ignore
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = ast.literal_eval(ast.unparse(value.body[0].value))
output = ast.literal_eval(
ast.unparse( # type: ignore
value.body[0].value)) # type: ignore
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value)
output = ast.unparse(value.body[0].value) # type: ignore
except Exception as e:
logger.error(f"Error parsing tool call: {e}")
logger.error("Error parsing tool call: %s", str(e))
output = (
ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
)
ast.unparse(value.value) + "[" + # type: ignore
ast.unparse(value.slice) + "]") # type: ignore
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output

0 comments on commit bfcc2c4

Please sign in to comment.