From 7439a8b5fcbc4d77bd73496f27d4048c5b43cb22 Mon Sep 17 00:00:00 2001 From: Clayton <132770471+cedonley@users.noreply.github.com> Date: Wed, 11 Dec 2024 17:10:12 -0800 Subject: [PATCH] [Bugfix] Multiple fixes to tool streaming with hermes and mistral (#10979) Signed-off-by: cedonley --- vllm/entrypoints/openai/serving_chat.py | 16 +++++- .../openai/tool_parsers/hermes_tool_parser.py | 51 +++++++++++++++---- .../tool_parsers/mistral_tool_parser.py | 23 ++++++--- 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0af7613a473a4..0738210e27cb6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -496,21 +496,33 @@ async def chat_completion_stream_generator( if self._should_check_for_unstreamed_tool_arg_tokens( delta_message, output) and tool_parser: + latest_delta_len = 0 + if ((isinstance( + delta_message.tool_calls[0].function, + DeltaFunctionCall)) and isinstance( + delta_message.tool_calls[0].function. + arguments, str)): + latest_delta_len = len( + delta_message.tool_calls[0].function. + arguments) + # get the expected call based on partial JSON # parsing which "autocompletes" the JSON expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( - "arguments", {})) + "arguments", {}), + ensure_ascii=False) # get what we've streamed so far for arguments # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] + if (latest_delta_len > 0): + actual_call = actual_call[:-latest_delta_len] # check to see if there's anything left to stream remaining_call = expected_call.replace( actual_call, "", 1) - # set that as a delta message delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=index, diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 18816cd665b3e..869d15ac359ea 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -91,7 +91,8 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"]))) + arguments=json.dumps(function_call["arguments"], + ensure_ascii=False))) for function_call in raw_function_calls ] @@ -139,13 +140,26 @@ def extract_tool_calls_streaming( self.tool_call_start_token_id) cur_tool_end_count = current_token_ids.count( self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None # case: if we're generating text, OR rounding out a tool call if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count): + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): logger.debug("Generating text content! skipping tool parsing.") - if delta_text != self.tool_call_end_token: - return DeltaMessage(content=delta_text) + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() # case: if tool open & close tag counts don't match, we're doing # imaginary "else" block here @@ -184,15 +198,21 @@ def extract_tool_calls_streaming( # case -- the current tool call is being closed. elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count > prev_tool_end_count): + and cur_tool_end_count >= prev_tool_end_count): + if (self.prev_tool_call_arr is None + or len(self.prev_tool_call_arr) == 0): + logger.debug( + "attempting to close tool call, but no tool call") + return None diff = self.prev_tool_call_arr[self.current_tool_id].get( "arguments") if diff: diff = diff.encode('utf-8').decode( 'unicode_escape') if diff is str else diff - diff = json.dumps( - diff, ensure_ascii=False - )[len(self.streamed_args_for_tool[self.current_tool_id]):] + if ('"}' not in delta_text): + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff) @@ -221,10 +241,15 @@ def extract_tool_calls_streaming( except partial_json_parser.core.exceptions.MalformedJSON: logger.debug('not enough tokens to parse into JSON yet') return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. if not self.current_tool_name_sent: + if (current_tool_call is None): + return None function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True @@ -284,13 +309,17 @@ def extract_tool_calls_streaming( # autocompleting the JSON elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) logger.debug("finding %s in %s", delta_text, cur_arguments_json) # get the location where previous args differ from current - args_delta_start_loc = cur_arguments_json.index(delta_text) \ - + len(delta_text) + if (delta_text not in cur_arguments_json[:-2]): + return None + args_delta_start_loc = cur_arguments_json[:-2]. \ + rindex(delta_text) + \ + len(delta_text) # use that to find the actual delta arguments_delta = cur_arguments_json[:args_delta_start_loc] diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 5caac84138e3b..bada805dd35b9 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -19,7 +19,6 @@ extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -109,7 +108,8 @@ def extract_tool_calls( function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(raw_function_call["arguments"]))) + arguments=json.dumps(raw_function_call["arguments"], + ensure_ascii=False))) for raw_function_call in function_call_arr ] @@ -199,7 +199,7 @@ def extract_tool_calls_streaming( diff: Union[str, None] = current_tool_call.get("arguments") if diff: - diff = json.dumps(diff).replace( + diff = json.dumps(diff, ensure_ascii=False).replace( self.streamed_args_for_tool[self.current_tool_id], "") delta = DeltaMessage(tool_calls=[ @@ -232,7 +232,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id=MistralToolCall.generate_random_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) @@ -250,6 +250,8 @@ def extract_tool_calls_streaming( cur_arguments = current_tool_call.get("arguments") new_text = delta_text.replace("\'", "\"") + if ('"}' in new_text): + new_text = new_text[:new_text.rindex('"}')] if not cur_arguments and not prev_arguments: @@ -260,12 +262,15 @@ def extract_tool_calls_streaming( "mid-arguments") delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False)[:-2] logger.debug("finding %s in %s", new_text, cur_arguments_json) + if (new_text not in cur_arguments_json): + return None arguments_delta = cur_arguments_json[:cur_arguments_json. - index(new_text) + + rindex(new_text) + len(new_text)] logger.debug("First tokens in arguments received: %s", arguments_delta) @@ -279,8 +284,10 @@ def extract_tool_calls_streaming( self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) + cur_args_json = json.dumps(cur_arguments, + ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, + ensure_ascii=False) logger.debug("Searching for diff between \n%s\n%s", cur_args_json, prev_args_json)