From 6b762e5d7d1cf71593c391305e78ecf0de3d0f3a Mon Sep 17 00:00:00 2001 From: constellate Date: Sun, 16 Jun 2024 01:21:28 -0500 Subject: [PATCH 001/222] add hermes 2 pro function calling template --- examples/tool_template_hermes_2_pro.jinja | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 examples/tool_template_hermes_2_pro.jinja diff --git a/examples/tool_template_hermes_2_pro.jinja b/examples/tool_template_hermes_2_pro.jinja new file mode 100644 index 0000000000000..21ac11505eb7c --- /dev/null +++ b/examples/tool_template_hermes_2_pro.jinja @@ -0,0 +1,7 @@ +You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: +{{tools}} +Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} +For each function call return a json object with function name and arguments within XML tags as follows: + +{"arguments": , "name": } + \ No newline at end of file From 606ec64bd1732fa0a1ffb13ec33d0c56f2714aab Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 16 Jun 2024 16:39:01 -0500 Subject: [PATCH 002/222] feat(example): add example chat completion with tool usage --- ...penai_chat_completion_client_with_tools.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 examples/openai_chat_completion_client_with_tools.py diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py new file mode 100644 index 0000000000000..0414a4be0e2e9 --- /dev/null +++ b/examples/openai_chat_completion_client_with_tools.py @@ -0,0 +1,61 @@ +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +}] + +chat_completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": "user", + "content": "Can you tell me what the weather will be in Dallas Texas?" + } + ], + model=model, + tools=tools +) + +print("Chat completion results:") +print(chat_completion) From d27446f26a77b51616fbd9364dc29a3004c088f2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 16 Jun 2024 23:26:32 -0500 Subject: [PATCH 003/222] feat: add CLI argument for OpenAI API-style tool use system prompt jinja template --- vllm/engine/arg_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ba53b5c86fa72..f0add4244df8d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -152,6 +152,28 @@ def add_cli_args_for_vlm( return parser + @staticmethod + def add_cli_args_for_tools(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """ + CLI arguments to configure tool use for the OpenAI API-style endpoint. + At this point, only a template for taking the provided tools and formatting them + into a model-specific system prompt format is supported, but others may be added + in the future, e.g. for decoding the tool call generated by the model into the + OpenAI API style. + """ + parser.add_argument( + '--tool-use-prompt-template', + type=str, + default=None, + help="The path to the jinja template that should be used to format " + "any provided OpenAI API-style function definitions into a system prompt " + "that instructs the model how to use tools, and which tools are " + "available. If not provided, tools will be ignored. An example is " + "provided at 'examples/tool_template_hermes_2_pro.jinja'." + ) + + return parser + @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -510,6 +532,9 @@ def add_cli_args( # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) + # Related to OpenAI API-style entrypoint + parser = EngineArgs.add_cli_args_for_tools(parser) + parser.add_argument( '--scheduler-delay-factor', type=float, @@ -598,6 +623,7 @@ def add_cli_args( type=str, default=None, help='Name or path of the QLoRA adapter.') + return parser @classmethod From bdf48a16eba1eda1133695a73c3bef6e9dd501dd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 16 Jun 2024 23:50:36 -0500 Subject: [PATCH 004/222] feat: add better validation for tool_choice, support for tool_choice = "auto"; set default `tool_choice` to "auto" --- vllm/entrypoints/openai/protocol.py | 45 +++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b57d79859aec5..9258a1a417c92 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -305,13 +305,48 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod - def check_tool_choice(cls, data): - if "tool_choice" in data and data["tool_choice"] != "none": - if not isinstance(data["tool_choice"], dict): - raise ValueError("Currently only named tools are supported.") + def check_tool_usage(cls, data): + print("DATA", data) + + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: + raise ValueError("When using `tool_choice`, `tools` must be set.") + + + # make sure that tool choice is either a named tool OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance(data["tool_choice"], dict): raise ValueError( - "When using `tool_choice`, `tools` must be set.") + "`tool_choice` must either be a named tool or \"auto\". `tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + return ValueError( + 'Incorrectly formatted `tool_choice`. Should be like ' + + '`{"type": "function", "function": {"name": "my_function"}}`' + ) + specified_function_name = specified_function["name"] + if not specified_function_name: + return ValueError( + 'Incorrectly formatted `tool_choice`. Should be like ' + + '`{"type": "function", "function": {"name": "my_function"}}`' + ) + for tool in data['tools']: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + return ValueError("The tool specified in `tool_choice` does not match any of the specified `tools`") + + # per OpenAI spec, make sure that tool_choice defaults to "auto" when tools are specified + elif "tools" in data and "tool_choice" not in data: + data["tool_choice"] = "auto" + + # TODO validate tools return data @model_validator(mode="before") From 9493c6763dbd1db1283d5562d9944d4c30d1473e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 22:40:32 -0500 Subject: [PATCH 005/222] fix(cli): set OpenAI tool args in vllm/entrypoints/openai/cli_args.py instead of vllm/engine/arg_utils.py --- vllm/engine/arg_utils.py | 24 ------------------------ vllm/entrypoints/openai/cli_args.py | 16 ++++++++++++++++ vllm/entrypoints/openai/protocol.py | 2 +- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0add4244df8d..fd97ab4fb4960 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -152,28 +152,6 @@ def add_cli_args_for_vlm( return parser - @staticmethod - def add_cli_args_for_tools(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """ - CLI arguments to configure tool use for the OpenAI API-style endpoint. - At this point, only a template for taking the provided tools and formatting them - into a model-specific system prompt format is supported, but others may be added - in the future, e.g. for decoding the tool call generated by the model into the - OpenAI API style. - """ - parser.add_argument( - '--tool-use-prompt-template', - type=str, - default=None, - help="The path to the jinja template that should be used to format " - "any provided OpenAI API-style function definitions into a system prompt " - "that instructs the model how to use tools, and which tools are " - "available. If not provided, tools will be ignored. An example is " - "provided at 'examples/tool_template_hermes_2_pro.jinja'." - ) - - return parser - @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -532,8 +510,6 @@ def add_cli_args( # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) - # Related to OpenAI API-style entrypoint - parser = EngineArgs.add_cli_args_for_tools(parser) parser.add_argument( '--scheduler-delay-factor', diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4c0cb1e4f3e49..3b8970fb8c206 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -111,5 +111,21 @@ def make_arg_parser(): "If a class is provided, vLLM will add it to the server " "using app.add_middleware(). ") + parser.add_argument("--enable-api-tools", + action="store_true", + help="Enable OpenAI-like tools API " + "(only function calls are currently supported)") + + parser.add_argument( + '--tool-use-prompt-template', + type=str, + default=None, + help="The path to the jinja template that should be used to format " + "any provided OpenAI API-style function definitions into a system prompt " + "that instructs the model how to use tools, and which tools are " + "available. If not provided, tools will be ignored. An example is " + "provided at 'examples/tool_template_hermes_2_pro.jinja'." + ) + parser = AsyncEngineArgs.add_cli_args(parser) return parser diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9258a1a417c92..9bf5bb0940d9d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -341,7 +341,7 @@ def check_tool_usage(cls, data): break if not valid_tool: return ValueError("The tool specified in `tool_choice` does not match any of the specified `tools`") - + # per OpenAI spec, make sure that tool_choice defaults to "auto" when tools are specified elif "tools" in data and "tool_choice" not in data: data["tool_choice"] = "auto" From 35c9aa7f24530bf2ae2bd53ac2b91933bbc43b9b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 22:44:27 -0500 Subject: [PATCH 006/222] fix(types): add "auto" as an option for tool_choice in pydantic models --- vllm/entrypoints/openai/protocol.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9bf5bb0940d9d..8bae87e1a45ec 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -148,8 +148,15 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Literal["none"], - ChatCompletionNamedToolChoiceParam]] = "none" + tool_choice: Optional[ + Union[ + Union[ + Literal["none"], + Literal["auto"] + ], + ChatCompletionNamedToolChoiceParam + ] + ] = "none" user: Optional[str] = None # doc: begin-chat-completion-sampling-params From 8ff80fb5cdb4720742acea46c9c8ad5a43e85359 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 22:49:35 -0500 Subject: [PATCH 007/222] fix: validation - guided decoding not valid with tool_choice = auto --- vllm/entrypoints/openai/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8bae87e1a45ec..391776c7651de 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -304,8 +304,8 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none": + if (guide_count > 1 and "tool_choice" in data and data[ + "tool_choice"] != "none" and data["tool_choice"] != "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data From 3c4acb1f5a857ebbf856ee35b8846a280556f38f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 22:55:41 -0500 Subject: [PATCH 008/222] feat: ensure guided deoding is only applied when tool_choice is NOT "none" AND NOT "auto" --- .../model_executor/guided_decoding/__init__.py | 4 ++-- .../guided_decoding/outlines_decoding.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4a..af70e00727e20 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -34,8 +34,8 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request) is CompletionRequest: return request - # user has chosen to not use any tool - if request.tool_choice == "none": + # user has chosen to not use any tool, OR is allowing the model to choose a tool. + if request.tool_choice == "none" or request.tool_choice == "auto": return request # user has chosen to use a named tool diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 721f7e0530cb7..9a5e82e29f146 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -9,11 +9,11 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) + CompletionRequest, + ChatCompletionNamedToolChoiceParam) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) - class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" @@ -81,7 +81,19 @@ def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: - if request.guided_json: + # if the request is a chat completion request, AND the tool choice is a named tool choice, do guided decoding + # using that tool as the JSON schema + if isinstance(request, ChatCompletionRequest) and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam): + # Guided generation for tools/functions parameters + if request.tool_choice.type == "function": + for tool in request.tools: + if tool.type == "function" and tool.function.name == request.tool_choice.function.name: + json = json_dumps(tool.function.parameters, sort_keys=True) + return json, GuidedDecodingMode.JSON + return None, None + + elif request.guided_json: json = request.guided_json if isinstance(json, dict): # turn dict into hashable string From 9c5ef667f59951f6db7cff9f6e2e2c5c9de66438 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 23:30:56 -0500 Subject: [PATCH 009/222] feat(cli): update CLI args for auto tool choice and OpenAIServingChat to receive the arguments & validate them --- vllm/entrypoints/openai/api_server.py | 6 ++++- vllm/entrypoints/openai/cli_args.py | 17 ++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 30 ++++++++++++++++++++++++- 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ea6275920c79d..e5e4f86a31b40 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -214,7 +214,11 @@ async def authentication(request: Request, call_next): served_model_names, args.response_role, args.lora_modules, - args.chat_template) + args.chat_template, + args.enable_auto_tool_choice, + args.tool_use_prompt_template, + args.tool_use_prompt_role + ) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 3b8970fb8c206..5133f0010a01e 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -116,6 +116,12 @@ def make_arg_parser(): help="Enable OpenAI-like tools API " "(only function calls are currently supported)") + parser.add_argument("--enable-auto-tool-choice", + action="store_true", + help='Enable auto tool choice for models that support it. ' + 'Requires specifying --tool-use-prompt-template.' + ) + parser.add_argument( '--tool-use-prompt-template', type=str, @@ -127,5 +133,16 @@ def make_arg_parser(): "provided at 'examples/tool_template_hermes_2_pro.jinja'." ) + parser.add_argument( + '--tool-use-prompt-role', + type=str, + default='system', + help='The chat role to use for the system prompt that instructs the model what tools are ' + 'available and how to use them. The default is "system". If the "system" role is used for the tool ' + 'use system prompt (default) _and_ the client specifies a system prompt, then the client-' + 'specified system prompt will be appended to the tool use system prompt. If a non-"system" role is ' + 'specified, it will be placed as the first non-system message in the conversation.' + ) + parser = AsyncEngineArgs.add_cli_args(parser) return parser diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 76940612496a0..d79ba2ce7c762 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -57,7 +57,11 @@ def __init__(self, served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRAModulePath]] = None, - chat_template: Optional[str] = None): + chat_template: Optional[str] = None, + enable_auto_tools: Optional[bool] = False, + tool_prompt_jinja_template_path: Optional[str] = None, + tool_prompt_role: Optional[str] = 'system' + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, @@ -66,6 +70,30 @@ def __init__(self, self.response_role = response_role self._load_chat_template(chat_template) + + print('CONSTELLATE enable_auto_tools', enable_auto_tools) + print("CONSTELLATE tool_prompt_jinja_template_path", tool_prompt_jinja_template_path) + print("CONSTELLATE tool_prompt_role", tool_prompt_role) + # set up tool use + self.enable_auto_tools = enable_auto_tools + self.tool_prompt_role = tool_prompt_role + if self.enable_auto_tools and tool_prompt_jinja_template_path: + self.tool_use_prompt_template = self._load_tool_prompt_template(tool_prompt_jinja_template_path) + elif self.enable_auto_tools and not tool_prompt_jinja_template_path: + raise ValueError( + 'Argument --enable-auto-tool-choice requires --tool-use-prompt-path to set the prompt for instructing ' + 'the model on which tools are available and how to use them.' + ) + + # TODO set the system prompt for tools and system prompt role for tools if applicable + def _load_tool_prompt_template(self, tool_prompt_jinja_template_path: Optional[str] = None) -> None: + """ + Load the Jinja template for the tool prompt + """ + print("Loading tool prompt template!", tool_prompt_jinja_template_path) + return None + + def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer From f1a1e7b6e5fb681d1fb3c9de58db6557e7521201 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 17 Jun 2024 23:41:40 -0500 Subject: [PATCH 010/222] feat: add loading in the system prompt jinja template if specified; along with validation --- vllm/entrypoints/openai/serving_chat.py | 26 +++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d79ba2ce7c762..530174297ae9e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -6,6 +6,7 @@ from typing import Sequence as GenericSequence from typing import TypedDict, Union, cast, final +import jinja2 from fastapi import Request from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) @@ -33,6 +34,12 @@ from vllm.sequence import Logprob from vllm.utils import random_uuid +from jinja2 import Environment, FileSystemLoader, select_autoescape +env = Environment( + loader=FileSystemLoader('./'), + autoescape=select_autoescape() +) + logger = init_logger(__name__) @@ -75,8 +82,10 @@ def __init__(self, print("CONSTELLATE tool_prompt_jinja_template_path", tool_prompt_jinja_template_path) print("CONSTELLATE tool_prompt_role", tool_prompt_role) # set up tool use - self.enable_auto_tools = enable_auto_tools - self.tool_prompt_role = tool_prompt_role + self.enable_auto_tools: bool = enable_auto_tools + self.tool_prompt_role: str = tool_prompt_role + self.tool_use_prompt_template: Optional[jinja2.Template] = None + if self.enable_auto_tools and tool_prompt_jinja_template_path: self.tool_use_prompt_template = self._load_tool_prompt_template(tool_prompt_jinja_template_path) elif self.enable_auto_tools and not tool_prompt_jinja_template_path: @@ -85,13 +94,22 @@ def __init__(self, 'the model on which tools are available and how to use them.' ) + # TODO set the system prompt for tools and system prompt role for tools if applicable - def _load_tool_prompt_template(self, tool_prompt_jinja_template_path: Optional[str] = None) -> None: + def _load_tool_prompt_template(self, tool_prompt_jinja_template_path: str) -> jinja2.Template: """ Load the Jinja template for the tool prompt """ print("Loading tool prompt template!", tool_prompt_jinja_template_path) - return None + + template = env.get_template(tool_prompt_jinja_template_path) + if not template: + raise FileNotFoundError( + f'The specified tool use prompt template {tool_prompt_jinja_template_path} was not found' + ) + + # Load the JINJA template + return template def _load_chat_template(self, chat_template: Optional[str]): From 07a67d212fa3240cede59343fc5a7d06fbe4aa1d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 18 Jun 2024 19:14:52 -0500 Subject: [PATCH 011/222] wip: add a case for tool choice = auto when handling chat completion requests --- vllm/entrypoints/openai/serving_chat.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 530174297ae9e..d7fa6734cc425 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -550,8 +550,10 @@ async def chat_completion_full_generator( else: logprobs = None + # if the reqeust uses tools and specified a tool choice if request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + print("CONSTELLATE handling named tool choice") message = ChatMessage( role=role, content="", @@ -560,7 +562,17 @@ async def chat_completion_full_generator( name=request.tool_choice.function.name, arguments=output.text)) ]) + + # if the request doesn't use tool choice OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": + print("CONSTELLATE handling no tool choice or tool_choice = none") + message = ChatMessage(role=role, content=output.text) + + # handle when there are tools and tool choice is auto + elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None): + print("CONSTELLATE handling tool choice = auto") + print(output) + # FOR NOW make it a chat message; we will have to detect the type to make it later. message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( From 590a5590f146f9fd0a57e79299f584995be258aa Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 18 Jun 2024 21:12:28 -0500 Subject: [PATCH 012/222] fix: hermes 2 pro prompt template to prevent newlines --- examples/tool_template_hermes_2_pro.jinja | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/tool_template_hermes_2_pro.jinja b/examples/tool_template_hermes_2_pro.jinja index 21ac11505eb7c..c03e5dc9e66de 100644 --- a/examples/tool_template_hermes_2_pro.jinja +++ b/examples/tool_template_hermes_2_pro.jinja @@ -2,6 +2,4 @@ You are a function calling AI model. You are provided with function signatures w {{tools}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within XML tags as follows: - -{"arguments": , "name": } - \ No newline at end of file +{"arguments": , "name": } \ No newline at end of file From 6919acf05027f6868a52964074473fdad6574175 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 18 Jun 2024 21:12:58 -0500 Subject: [PATCH 013/222] feat: handle building the system prompt via template with --enable-auto-tool-choice and --tool-use-prompt-template are specified --- vllm/entrypoints/openai/serving_chat.py | 26 +++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d7fa6734cc425..e4e8d8fb92184 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -262,15 +262,34 @@ async def create_chat_completion( conversation.extend(chat_parsed_result.messages) image_futures.extend(chat_parsed_result.image_futures) + # if specified, add the system prompt template + print('CONSTELLATE request tools', request.tools) + if self.enable_auto_tools and self.tool_use_prompt_template: + print('CONSTELLATE configuring tools') + # create the system prompt from the template + templated_prompt_with_tools: str = self.tool_use_prompt_template.render(tools=request.tools) + + # if there is already a system prompt + if conversation[0]['role'] == 'system': + print('CONSTELLATE modifying existing system prompt with tool template') + conversation[0]['content'] = f'{templated_prompt_with_tools}\n\n{conversation[0]["content"]}' + + # if there isn't a system prompt already + else: + conversation.insert(0, ConversationMessage(role='system', content=templated_prompt_with_tools)) + print('CONSTELLATE conversation:', conversation) + prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, add_generation_prompt=request.add_generation_prompt, ) + print('CONSTELLATE prompt', prompt) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) + # Fetch image data image_data: Optional[ImagePixelData] = None try: @@ -516,8 +535,11 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], request_id: str, + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, conversation: List[ConversationMessage] ) -> Union[ErrorResponse, ChatCompletionResponse]: From db9c29aa319f1a559b5a228ec9f6e94f7171cfc5 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 18 Jun 2024 22:27:08 -0500 Subject: [PATCH 014/222] fix: remove debugging log lines --- vllm/entrypoints/openai/serving_chat.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e4e8d8fb92184..a3d0c6d5702da 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -77,10 +77,6 @@ def __init__(self, self.response_role = response_role self._load_chat_template(chat_template) - - print('CONSTELLATE enable_auto_tools', enable_auto_tools) - print("CONSTELLATE tool_prompt_jinja_template_path", tool_prompt_jinja_template_path) - print("CONSTELLATE tool_prompt_role", tool_prompt_role) # set up tool use self.enable_auto_tools: bool = enable_auto_tools self.tool_prompt_role: str = tool_prompt_role @@ -263,28 +259,23 @@ async def create_chat_completion( image_futures.extend(chat_parsed_result.image_futures) # if specified, add the system prompt template - print('CONSTELLATE request tools', request.tools) if self.enable_auto_tools and self.tool_use_prompt_template: - print('CONSTELLATE configuring tools') # create the system prompt from the template templated_prompt_with_tools: str = self.tool_use_prompt_template.render(tools=request.tools) # if there is already a system prompt if conversation[0]['role'] == 'system': - print('CONSTELLATE modifying existing system prompt with tool template') conversation[0]['content'] = f'{templated_prompt_with_tools}\n\n{conversation[0]["content"]}' # if there isn't a system prompt already else: conversation.insert(0, ConversationMessage(role='system', content=templated_prompt_with_tools)) - print('CONSTELLATE conversation:', conversation) prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, add_generation_prompt=request.add_generation_prompt, ) - print('CONSTELLATE prompt', prompt) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) @@ -575,7 +566,7 @@ async def chat_completion_full_generator( # if the reqeust uses tools and specified a tool choice if request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: - print("CONSTELLATE handling named tool choice") + message = ChatMessage( role=role, content="", @@ -587,13 +578,12 @@ async def chat_completion_full_generator( # if the request doesn't use tool choice OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - print("CONSTELLATE handling no tool choice or tool_choice = none") + message = ChatMessage(role=role, content=output.text) # handle when there are tools and tool choice is auto elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None): - print("CONSTELLATE handling tool choice = auto") - print(output) + # FOR NOW make it a chat message; we will have to detect the type to make it later. message = ChatMessage(role=role, content=output.text) From c16aa9a153f16d343c3223facd0220bc2fcff990 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 20 Jun 2024 10:18:19 -0500 Subject: [PATCH 015/222] fix(template): update hermes 2 pro template with newlines to get newlines as expected --- examples/tool_template_hermes_2_pro.jinja | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/tool_template_hermes_2_pro.jinja b/examples/tool_template_hermes_2_pro.jinja index c03e5dc9e66de..21ac11505eb7c 100644 --- a/examples/tool_template_hermes_2_pro.jinja +++ b/examples/tool_template_hermes_2_pro.jinja @@ -2,4 +2,6 @@ You are a function calling AI model. You are provided with function signatures w {{tools}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} For each function call return a json object with function name and arguments within XML tags as follows: -{"arguments": , "name": } \ No newline at end of file + +{"arguments": , "name": } + \ No newline at end of file From 9b9d86121fed74d49493eb38f8f7c5eb85f485ac Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 19:32:03 -0500 Subject: [PATCH 016/222] feat: add hermes 2 pro and mistral full tool use chat templates --- .../tool_chat_template_hermes_2_pro.jinja | 120 ++++++++++++++++++ examples/tool_chat_template_mistral.jinja | 49 +++++++ 2 files changed, 169 insertions(+) create mode 100644 examples/tool_chat_template_hermes_2_pro.jinja create mode 100644 examples/tool_chat_template_mistral.jinja diff --git a/examples/tool_chat_template_hermes_2_pro.jinja b/examples/tool_chat_template_hermes_2_pro.jinja new file mode 100644 index 0000000000000..61192c02866a0 --- /dev/null +++ b/examples/tool_chat_template_hermes_2_pro.jinja @@ -0,0 +1,120 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": ' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"arguments": , "name": } +' }} +{{- '<|im_end|>' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n\n' }} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{ ' }} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {{- ', '}} + {%- endif %} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"}' }} + {{- '\n ' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if not message.name is defined %} + {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' }} + {{- '{"name": "' }} + {{- message.name }} + {{- '", "content": ' }} + {{- message.content|tojson + '}' }} + {{- '\n <|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja new file mode 100644 index 0000000000000..4eac99e20adaf --- /dev/null +++ b/examples/tool_chat_template_mistral.jinja @@ -0,0 +1,49 @@ +{{- bos_token }} +{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- if tools and (message == user_messages[-1]) %} + {{- ' [AVAILABLE_TOOLS] [' }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool|items if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- ' [/AVAILABLE_TOOLS]' }} + {%- endif %} + {{- ' [INST] ' + message['content'] + ' [/INST]' }} + {%- elif message['role'] == 'assistant' %} + {%- if message.tool_calls is defined and message.tool_calls|length > 0 %} + {{- ' [TOOL_CALLS] [' }} + {%- for tool_call in message.tool_calls %} + {{- {"name": tool_call.function.name, "arguments": tool_call.function.arguments, "id": tool_call.id}|tojson }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- '] ' }} + {{- eos_token }} + {%- elif message.content is defined %} + {{- ' ' + message.content + ' ' + eos_token}} + {%- endif %} + {%- elif message['role'] == 'tool' %} + {{- ' [TOOL_RESULTS] ' }} + {{- '{"call_id": "' + message.tool_call_id + '", "content": ' + message.content|string + '}' }} + {{- ' [/TOOL_RESULTS] ' }} + {%- endif %} +{%- endfor %} From d221364437451c5e809fefb1fa0c7de8d81e0b70 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 19:32:21 -0500 Subject: [PATCH 017/222] chore: delete old template --- examples/tool_template_hermes_2_pro.jinja | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 examples/tool_template_hermes_2_pro.jinja diff --git a/examples/tool_template_hermes_2_pro.jinja b/examples/tool_template_hermes_2_pro.jinja deleted file mode 100644 index 21ac11505eb7c..0000000000000 --- a/examples/tool_template_hermes_2_pro.jinja +++ /dev/null @@ -1,7 +0,0 @@ -You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: -{{tools}} -Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} -For each function call return a json object with function name and arguments within XML tags as follows: - -{"arguments": , "name": } - \ No newline at end of file From 33c669b0a7d9516982b32a0e632e0999f0f7cde1 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 20:25:56 -0500 Subject: [PATCH 018/222] feat: tool calls are now returned in the chat completion response --- vllm/entrypoints/openai/api_server.py | 7 ++++ vllm/entrypoints/openai/serving_chat.py | 45 ++++--------------------- 2 files changed, 14 insertions(+), 38 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c4f63f94c498b..971f9ed583943 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -131,12 +131,19 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): generator = await openai_serving_chat.create_chat_completion( request, raw_request) + + # if there's an error, return it if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) + + # if streaming is requested, handle streaming + # TODO implement for streaming later if request.stream: return StreamingResponse(content=generator, media_type="text/event-stream") + + # handle non-streaming requests else: assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8f2e85fed4486..a83935897b33b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -84,32 +84,6 @@ def __init__(self, self.tool_prompt_role: str = tool_prompt_role self.tool_use_prompt_template: Optional[jinja2.Template] = None - if self.enable_auto_tools and tool_prompt_jinja_template_path: - self.tool_use_prompt_template = self._load_tool_prompt_template(tool_prompt_jinja_template_path) - elif self.enable_auto_tools and not tool_prompt_jinja_template_path: - raise ValueError( - 'Argument --enable-auto-tool-choice requires --tool-use-prompt-path to set the prompt for instructing ' - 'the model on which tools are available and how to use them.' - ) - - - # TODO set the system prompt for tools and system prompt role for tools if applicable - def _load_tool_prompt_template(self, tool_prompt_jinja_template_path: str) -> jinja2.Template: - """ - Load the Jinja template for the tool prompt - """ - print("Loading tool prompt template!", tool_prompt_jinja_template_path) - - template = env.get_template(tool_prompt_jinja_template_path) - if not template: - raise FileNotFoundError( - f'The specified tool use prompt template {tool_prompt_jinja_template_path} was not found' - ) - - # Load the JINJA template - return template - - def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer @@ -260,24 +234,19 @@ async def create_chat_completion( conversation.extend(chat_parsed_result.messages) image_futures.extend(chat_parsed_result.image_futures) - # if specified, add the system prompt template - if self.enable_auto_tools and self.tool_use_prompt_template: - # create the system prompt from the template - templated_prompt_with_tools: str = self.tool_use_prompt_template.render(tools=request.tools) - - # if there is already a system prompt - if conversation[0]['role'] == 'system': - conversation[0]['content'] = f'{templated_prompt_with_tools}\n\n{conversation[0]["content"]}' - - # if there isn't a system prompt already - else: - conversation.insert(0, ConversationMessage(role='system', content=templated_prompt_with_tools)) + tools = None + if self.enable_auto_tools and request.tools: + tools = [tool.model_dump() for tool in request.tools] + print('using tools', tools) prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, add_generation_prompt=request.add_generation_prompt, + tools=tools ) + + print('fully tokenized prompt:', prompt) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) From 7214e7043cdc9858297973b27f92ae93adae1f80 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 21:06:10 -0500 Subject: [PATCH 019/222] fix: mistral chat template. replace huggingface-suggested one with the mixtral tool_use chat template --- examples/tool_chat_template_mistral.jinja | 50 +---------------------- 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index 4eac99e20adaf..249cae0f258a5 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -1,49 +1 @@ -{{- bos_token }} -{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- if tools and (message == user_messages[-1]) %} - {{- ' [AVAILABLE_TOOLS] [' }} - {%- for tool in tools %} - {%- set tool = tool.function %} - {{- '{"type": "function", "function": {' }} - {%- for key, val in tool|items if key != "return" %} - {%- if val is string %} - {{- '"' + key + '": "' + val + '"' }} - {%- else %} - {{- '"' + key + '": ' + val|tojson }} - {%- endif %} - {%- if not loop.last %} - {{- ", " }} - {%- endif %} - {%- endfor %} - {{- "}}" }} - {%- if not loop.last %} - {{- ", " }} - {%- else %} - {{- "]" }} - {%- endif %} - {%- endfor %} - {{- ' [/AVAILABLE_TOOLS]' }} - {%- endif %} - {{- ' [INST] ' + message['content'] + ' [/INST]' }} - {%- elif message['role'] == 'assistant' %} - {%- if message.tool_calls is defined and message.tool_calls|length > 0 %} - {{- ' [TOOL_CALLS] [' }} - {%- for tool_call in message.tool_calls %} - {{- {"name": tool_call.function.name, "arguments": tool_call.function.arguments, "id": tool_call.id}|tojson }} - {%- if not loop.last %} - {{- ", " }} - {%- endif %} - {%- endfor %} - {{- '] ' }} - {{- eos_token }} - {%- elif message.content is defined %} - {{- ' ' + message.content + ' ' + eos_token}} - {%- endif %} - {%- elif message['role'] == 'tool' %} - {{- ' [TOOL_RESULTS] ' }} - {{- '{"call_id": "' + message.tool_call_id + '", "content": ' + message.content|string + '}' }} - {{- ' [/TOOL_RESULTS] ' }} - {%- endif %} -{%- endfor %} +{{bos_token}}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{'[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]'}}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% elif message['role'] == 'tool_results' %}{{'[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]'}}{% elif message['role'] == 'tool_calls' %}{{'[TOOL_CALLS]' + message['content']|string + eos_token}}{% endif %}{% endfor %} \ No newline at end of file From 6a3c61edcc8b72e09b642d355e42107fede846d3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 23:46:04 -0500 Subject: [PATCH 020/222] feat: add mistral tool parser, and empty hermes tool parser. non-streaming ONLY --- vllm/entrypoints/openai/tool_parsers.py | 61 +++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 vllm/entrypoints/openai/tool_parsers.py diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py new file mode 100644 index 0000000000000..847d0ee1067a5 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -0,0 +1,61 @@ +from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse +from vllm.logger import init_logger +from typing import List +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) +import json +from pydantic import BaseModel + +logger = init_logger(__name__) + + +class ToolParser: + + def __init__(self): + pass + + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: + """ + Abstract method intended to be used for extracting tool calls for use in a NON-STREAMING response + """ + raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') + + def extract_tool_calls_streaming(self, generator): + raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') + + +class MistralToolParser(ToolParser): + bot_token: str = '[TOOL_CALLS]' + + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: + + # Get the tool call token from the tokenizer + if self.bot_token not in model_response.choices[0].message.content: + return [] + else: + try: + # extract the token so we hopefully have a JSON string + raw_tool_call = (model_response.choices[0].message.content + .replace(MistralToolParser.bot_token, '') # remove BOT token + .replace("'", '"')) # ... hack to parse broken mistral JSON + tool_call_arr = json.loads(raw_tool_call) + print('tool call array', tool_call_arr) + function_calls: List[FunctionCall] = [FunctionCall.parse_obj(obj) for obj in tool_call_arr] + print('got mistral tool calls', function_calls) + tool_calls = [ToolCall(type='function', function=function_call) for function_call in function_calls] + return tool_calls + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + return [] + + def extract_tool_calls_streaming(self, generator): + raise NotImplementedError('MistralToolParser.extract_tool_calls_streaming has not been implemented!') + + +class Hermes2ProToolParser(ToolParser): + def extract_tool_calls_streaming(self, generator): + raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') + + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: + raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls has not been implemented!') \ No newline at end of file From 73046e60f8c02b231d34e953ecb17adaebf13e43 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 23:46:11 -0500 Subject: [PATCH 021/222] feat: update example --- examples/openai_chat_completion_client_with_tools.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 0414a4be0e2e9..a3220fff7fe2e 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -12,6 +12,8 @@ models = client.models.list() model = models.data[0].id + + tools = [{ "type": "function", "function": { @@ -38,8 +40,7 @@ } }] -chat_completion = client.chat.completions.create( - messages=[ +messages = [ { "role": "user", "content": "Hi! How are you doing today?" @@ -52,7 +53,10 @@ "role": "user", "content": "Can you tell me what the weather will be in Dallas Texas?" } - ], + ] + +chat_completion = client.chat.completions.create( + messages=messages, model=model, tools=tools ) From d7311e66cb48cd2cf88250aa4a7eafae3ab6e1b4 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 23:47:01 -0500 Subject: [PATCH 022/222] chore: update FunctionCall type to allow arguments as a Dict. non-auto tool-choice uses string, auto uses Dict --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d38489f4806bf..77bcfb7369a16 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -636,7 +636,7 @@ class EmbeddingResponse(BaseModel): class FunctionCall(OpenAIBaseModel): name: str - arguments: str + arguments: str | Dict class ToolCall(OpenAIBaseModel): From 344b241c52886ce5e8fe9249354d339d83de6f6c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 1 Jul 2024 23:47:58 -0500 Subject: [PATCH 023/222] feat: clean up CLI arguments, engine. implement tool parser selection; tool call parsing for NON-STREAMING responses --- vllm/entrypoints/openai/api_server.py | 29 ++++++++++++++++++++++--- vllm/entrypoints/openai/cli_args.py | 28 ++++++------------------ vllm/entrypoints/openai/serving_chat.py | 19 ++++++++++++---- 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 971f9ed583943..8ca3db2b08c74 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -146,7 +146,31 @@ async def create_chat_completion(request: ChatCompletionRequest, # handle non-streaming requests else: assert isinstance(generator, ChatCompletionResponse) - return JSONResponse(content=generator.model_dump()) + print('enable auto tools?', openai_serving_chat.enable_auto_tools) + print('tool parser?', openai_serving_chat.tool_parser) + if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: + + print('returning tool call response') + response = generator.model_dump() + print('Handling response with auto tools and a configured parser!') + tool_calls = openai_serving_chat.tool_parser.extract_tool_calls(generator) + if tool_calls and len(tool_calls): + response['choices'][0]['message']['content'] = None + response['choices'][0]['message']['tool_calls'] = [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } for tool_call in tool_calls + ] + return JSONResponse(content=response) + + else: + print('Returning regular response') + return JSONResponse(content=generator.model_dump()) @app.post("/v1/completions") @@ -252,8 +276,7 @@ async def authentication(request: Request, call_next): args.lora_modules, args.chat_template, args.enable_auto_tool_choice, - args.tool_use_prompt_template, - args.tool_use_prompt_role + args.tool_call_parser ) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 870a1403c3b2c..5900650983777 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -123,27 +123,13 @@ def make_arg_parser(): 'Requires specifying --tool-use-prompt-template.' ) - parser.add_argument( - '--tool-use-prompt-template', - type=str, - default=None, - help="The path to the jinja template that should be used to format " - "any provided OpenAI API-style function definitions into a system prompt " - "that instructs the model how to use tools, and which tools are " - "available. If not provided, tools will be ignored. An example is " - "provided at 'examples/tool_template_hermes_2_pro.jinja'." - ) - - parser.add_argument( - '--tool-use-prompt-role', - type=str, - default='system', - help='The chat role to use for the system prompt that instructs the model what tools are ' - 'available and how to use them. The default is "system". If the "system" role is used for the tool ' - 'use system prompt (default) _and_ the client specifies a system prompt, then the client-' - 'specified system prompt will be appended to the tool use system prompt. If a non-"system" role is ' - 'specified, it will be placed as the first non-system message in the conversation.' - ) + parser.add_argument("--tool-call-parser", + type=str, + choices=['mistral', 'hermes'], + help='Select the tool call parser depending on the model that you\'re using. ' + 'This is used to parse the model-generated tool call into OpenAI API format. ' + 'Required for --enable-auto-tool-choice. Options: "mistral", "hermes"' + ) parser = AsyncEngineArgs.add_cli_args(parser) return parser diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a83935897b33b..6957116a7f7b2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -36,6 +36,8 @@ log_tracing_disabled_warning) from vllm.utils import random_uuid +from vllm.entrypoints.openai.tool_parsers import ToolParser, MistralToolParser, Hermes2ProToolParser + from jinja2 import Environment, FileSystemLoader, select_autoescape env = Environment( loader=FileSystemLoader('./'), @@ -68,8 +70,7 @@ def __init__(self, lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None, enable_auto_tools: Optional[bool] = False, - tool_prompt_jinja_template_path: Optional[str] = None, - tool_prompt_role: Optional[str] = 'system' + tool_parser: Optional[str] = None ): super().__init__(engine=engine, model_config=model_config, @@ -81,8 +82,16 @@ def __init__(self, # set up tool use self.enable_auto_tools: bool = enable_auto_tools - self.tool_prompt_role: str = tool_prompt_role - self.tool_use_prompt_template: Optional[jinja2.Template] = None + + if self.enable_auto_tools and not tool_parser: + raise TypeError('Error: --enable-auto-tool-choice requires --tool-choice-parser') + + if tool_parser == 'mistral': + self.tool_parser: ToolParser = MistralToolParser() + elif tool_parser == 'hermes': + self.tool_parser: ToolParser = Hermes2ProToolParser() + else: + raise ValueError(f'Invalid tool parser value {tool_parser}!') def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer @@ -238,7 +247,9 @@ async def create_chat_completion( if self.enable_auto_tools and request.tools: tools = [tool.model_dump() for tool in request.tools] + print() print('using tools', tools) + print('add generation prompt? ', request.add_generation_prompt) prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, From 8740525cb95f4b32f65d7d8e8108c92de8d86bad Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 2 Jul 2024 16:54:25 -0500 Subject: [PATCH 024/222] feat: add methods to FunctionCall and ToolCall in protocol to make it easier to JSON-serialize --- vllm/entrypoints/openai/protocol.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 77bcfb7369a16..59cd49eeb7d34 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -638,12 +638,24 @@ class FunctionCall(OpenAIBaseModel): name: str arguments: str | Dict + def to_dict(self): + return { + "name": self.name, + "arguments": self.arguments + } class ToolCall(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") type: Literal["function"] = "function" function: FunctionCall + def to_dict(self): + return { + "id": self.id, + "type": self.type, + "function": self.function.to_dict() + } + class ChatMessage(OpenAIBaseModel): role: str From 7265be5c3ee824cb8d15df8b033fc9eaf0416e4b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 2 Jul 2024 16:54:54 -0500 Subject: [PATCH 025/222] fix: ensure finish_reason = "tool_calls" when a tool call is generated --- vllm/entrypoints/openai/api_server.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8ca3db2b08c74..216331b6c7748 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -146,26 +146,14 @@ async def create_chat_completion(request: ChatCompletionRequest, # handle non-streaming requests else: assert isinstance(generator, ChatCompletionResponse) - print('enable auto tools?', openai_serving_chat.enable_auto_tools) - print('tool parser?', openai_serving_chat.tool_parser) if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: - - print('returning tool call response') response = generator.model_dump() - print('Handling response with auto tools and a configured parser!') tool_calls = openai_serving_chat.tool_parser.extract_tool_calls(generator) if tool_calls and len(tool_calls): + logger.info("Model chat completion response contains tool calls! Formatting...") response['choices'][0]['message']['content'] = None - response['choices'][0]['message']['tool_calls'] = [ - { - "id": tool_call.id, - "type": tool_call.type, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments - } - } for tool_call in tool_calls - ] + response['choices'][0]['message']['tool_calls'] = [tool_call.to_dict() for tool_call in tool_calls] + response['choices'][0]['finish_reason'] = 'tool_calls' return JSONResponse(content=response) else: From a1207f2ae9579cdddf6c5f169f46f16ee274eb92 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 2 Jul 2024 17:23:43 -0500 Subject: [PATCH 026/222] unfinished: work on hermes 2 tool call parser --- vllm/entrypoints/openai/tool_parsers.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 847d0ee1067a5..8a068f15a7e6a 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -5,6 +5,7 @@ PreTrainedTokenizerFast) import json from pydantic import BaseModel +import re logger = init_logger(__name__) @@ -14,11 +15,25 @@ class ToolParser: def __init__(self): pass + tool_call_start: str = '' + tool_call_end: str = '' + + # regex to match between and OR between and EOS (happens sometimes :)) + tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: """ Abstract method intended to be used for extracting tool calls for use in a NON-STREAMING response """ - raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') + + # sanity check; avoid unnecessary processing + if self.tool_call_start not in model_response.choices[0].message.content: + return [] + + tool_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + tool_calls = [match[0] if match[0] else match[1] for match in tool_call_tuples] + print('got tool calls for hermes 2 pro!', tool_calls) + return [] def extract_tool_calls_streaming(self, generator): raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') From 30ffa168188938a4e41f703a362a9de9787f077c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 2 Jul 2024 19:34:39 -0500 Subject: [PATCH 027/222] partial: hermes 2 pro tool parsing --- vllm/entrypoints/openai/tool_parsers.py | 44 +++++++++++++------------ 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 8a068f15a7e6a..450b16f6f4d9b 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -15,25 +15,8 @@ class ToolParser: def __init__(self): pass - tool_call_start: str = '' - tool_call_end: str = '' - - # regex to match between and OR between and EOS (happens sometimes :)) - tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: - """ - Abstract method intended to be used for extracting tool calls for use in a NON-STREAMING response - """ - - # sanity check; avoid unnecessary processing - if self.tool_call_start not in model_response.choices[0].message.content: - return [] - - tool_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) - tool_calls = [match[0] if match[0] else match[1] for match in tool_call_tuples] - print('got tool calls for hermes 2 pro!', tool_calls) - return [] + raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') def extract_tool_calls_streaming(self, generator): raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') @@ -69,8 +52,27 @@ def extract_tool_calls_streaming(self, generator): class Hermes2ProToolParser(ToolParser): - def extract_tool_calls_streaming(self, generator): - raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') + + tool_call_start: str = '' + tool_call_end: str = '' + + # regex to match between and OR between and EOS (happens sometimes :)) + tool_call_regex = re.compile(r'\n(.*?)\n|\n(.*)', re.DOTALL) def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: - raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls has not been implemented!') \ No newline at end of file + """ + Abstract method intended to be used for extracting tool calls for use in a NON-STREAMING response + """ + + # sanity check; avoid unnecessary processing + if self.tool_call_start not in model_response.choices[0].message.content: + return [] + + tool_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + function_calls = [FunctionCall.parse_obj(json.loads(match[0] if match[0] else match[1])) for match in tool_call_tuples] + tool_calls = [ToolCall(type='function', function=function_call) for function_call in function_calls] + print('got tool calls for hermes 2 pro!', tool_calls) + return tool_calls + + def extract_tool_calls_streaming(self, generator): + raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') From 294c99e983bedc5ac8a48cf12216525ae1f282a2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 3 Jul 2024 11:34:59 -0500 Subject: [PATCH 028/222] fix: hermes tool call parser, work on example --- ...penai_chat_completion_client_with_tools.py | 15 +++++++++++ vllm/entrypoints/openai/api_server.py | 4 ++- vllm/entrypoints/openai/tool_parsers.py | 26 ++++++++++++------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index a3220fff7fe2e..addae145a26a5 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -63,3 +63,18 @@ print("Chat completion results:") print(chat_completion) + +# Now, simulate a tool call +def get_current_weather(city: str, state: str, unit: 'str'): + return "The weather in Dallas, Texas is 85 degrees fahrenheit. It is partly cloudly, with highs in the 90's." + +available_tools = { + "get_current_weather": get_current_weather +} + +completion_tool_calls = chat_completion.choices[0].message.tool_calls +for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = call.function.arguments + print(tool_to_call, args) + diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 216331b6c7748..dd213eade0c46 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -150,10 +150,12 @@ async def create_chat_completion(request: ChatCompletionRequest, response = generator.model_dump() tool_calls = openai_serving_chat.tool_parser.extract_tool_calls(generator) if tool_calls and len(tool_calls): - logger.info("Model chat completion response contains tool calls! Formatting...") + print('TOOL CALLS', tool_calls) response['choices'][0]['message']['content'] = None response['choices'][0]['message']['tool_calls'] = [tool_call.to_dict() for tool_call in tool_calls] response['choices'][0]['finish_reason'] = 'tool_calls' + else: + print('TOOL: no tool calls detected') return JSONResponse(content=response) else: diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 450b16f6f4d9b..a18e50774b4e2 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -36,11 +36,14 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[Too raw_tool_call = (model_response.choices[0].message.content .replace(MistralToolParser.bot_token, '') # remove BOT token .replace("'", '"')) # ... hack to parse broken mistral JSON - tool_call_arr = json.loads(raw_tool_call) - print('tool call array', tool_call_arr) - function_calls: List[FunctionCall] = [FunctionCall.parse_obj(obj) for obj in tool_call_arr] - print('got mistral tool calls', function_calls) - tool_calls = [ToolCall(type='function', function=function_call) for function_call in function_calls] + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type='function', + function=FunctionCall.parse_obj(raw_function_call) + ) + for raw_function_call in function_call_arr + ] return tool_calls except Exception as e: @@ -57,7 +60,7 @@ class Hermes2ProToolParser(ToolParser): tool_call_end: str = '' # regex to match between and OR between and EOS (happens sometimes :)) - tool_call_regex = re.compile(r'\n(.*?)\n|\n(.*)', re.DOTALL) + tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: """ @@ -66,12 +69,17 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[Too # sanity check; avoid unnecessary processing if self.tool_call_start not in model_response.choices[0].message.content: + print('TOOL tool_call_start is not in the response') return [] - tool_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) - function_calls = [FunctionCall.parse_obj(json.loads(match[0] if match[0] else match[1])) for match in tool_call_tuples] + # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall + # is an array of tuples where one is a function call and the other is None + function_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + + + # filter out & enforce the schema + function_calls = [FunctionCall.parse_obj(json.loads(match[0] if match[0] else match[1])) for match in function_call_tuples] tool_calls = [ToolCall(type='function', function=function_call) for function_call in function_calls] - print('got tool calls for hermes 2 pro!', tool_calls) return tool_calls def extract_tool_calls_streaming(self, generator): From 00be988c554b174b7483078b21dc4f9af85ede00 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 3 Jul 2024 11:56:24 -0500 Subject: [PATCH 029/222] fix: tool call arguments should be returned as JSON string not as a literal dict/object --- .../openai_chat_completion_client_with_tools.py | 9 +++++++-- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/tool_parsers.py | 17 ++++++++++++++--- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index addae145a26a5..7bfa6fb0dd0f0 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -1,4 +1,5 @@ from openai import OpenAI +import json # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" @@ -75,6 +76,10 @@ def get_current_weather(city: str, state: str, unit: 'str'): completion_tool_calls = chat_completion.choices[0].message.tool_calls for call in completion_tool_calls: tool_to_call = available_tools[call.function.name] - args = call.function.arguments - print(tool_to_call, args) + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print(result) + + + diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 59cd49eeb7d34..597b5625fb454 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -636,7 +636,7 @@ class EmbeddingResponse(BaseModel): class FunctionCall(OpenAIBaseModel): name: str - arguments: str | Dict + arguments: str def to_dict(self): return { diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index a18e50774b4e2..c96c59c592a74 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -40,7 +40,10 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[Too tool_calls: List[ToolCall] = [ ToolCall( type='function', - function=FunctionCall.parse_obj(raw_function_call) + function=FunctionCall( + name=raw_function_call['name'], + arguments=json.dumps(raw_function_call['arguments']) + ) ) for raw_function_call in function_call_arr ] @@ -78,8 +81,16 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[Too # filter out & enforce the schema - function_calls = [FunctionCall.parse_obj(json.loads(match[0] if match[0] else match[1])) for match in function_call_tuples] - tool_calls = [ToolCall(type='function', function=function_call) for function_call in function_calls] + raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] + tool_calls = [ + ToolCall( + type='function', + function=FunctionCall( + name=function_call['name'], + arguments=json.dumps(function_call['arguments']) + ) + ) for function_call in raw_function_calls + ] return tool_calls def extract_tool_calls_streaming(self, generator): From b70e7d7456918f9d8294032441aa2f645429c0fa Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 3 Jul 2024 12:55:52 -0500 Subject: [PATCH 030/222] feat: enable both content and tool_calls if the model allows --- vllm/entrypoints/openai/api_server.py | 9 ++- vllm/entrypoints/openai/protocol.py | 12 ++++ vllm/entrypoints/openai/tool_parsers.py | 88 ++++++++++++++++++------- 3 files changed, 79 insertions(+), 30 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index dd213eade0c46..6756b0192509e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -148,11 +148,10 @@ async def create_chat_completion(request: ChatCompletionRequest, assert isinstance(generator, ChatCompletionResponse) if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: response = generator.model_dump() - tool_calls = openai_serving_chat.tool_parser.extract_tool_calls(generator) - if tool_calls and len(tool_calls): - print('TOOL CALLS', tool_calls) - response['choices'][0]['message']['content'] = None - response['choices'][0]['message']['tool_calls'] = [tool_call.to_dict() for tool_call in tool_calls] + tool_call_info = openai_serving_chat.tool_parser.extract_tool_calls(generator) + if tool_call_info.tools_called: + response['choices'][0]['message']['content'] = tool_call_info.content + response['choices'][0]['message']['tool_calls'] = [tool_call.to_dict() for tool_call in tool_call_info.tool_calls] response['choices'][0]['finish_reason'] = 'tool_calls' else: print('TOOL: no tool calls detected') diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 597b5625fb454..ee49691f83523 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -657,6 +657,18 @@ def to_dict(self): } +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned ALTHOUGH THIS IS VERY RARE + # But some models will do this intentionally + content: str | None + + class ChatMessage(OpenAIBaseModel): role: str content: str diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index c96c59c592a74..247e625012802 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,4 +1,4 @@ -from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse +from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, ExtractedToolCallInformation from vllm.logger import init_logger from typing import List from transformers import (AutoTokenizer, PreTrainedTokenizer, @@ -15,7 +15,7 @@ class ToolParser: def __init__(self): pass - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') def extract_tool_calls_streaming(self, generator): @@ -25,33 +25,49 @@ def extract_tool_calls_streaming(self, generator): class MistralToolParser(ToolParser): bot_token: str = '[TOOL_CALLS]' - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: # Get the tool call token from the tokenizer if self.bot_token not in model_response.choices[0].message.content: - return [] + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_response.choices[0].message.content + ) else: try: # extract the token so we hopefully have a JSON string raw_tool_call = (model_response.choices[0].message.content .replace(MistralToolParser.bot_token, '') # remove BOT token .replace("'", '"')) # ... hack to parse broken mistral JSON + # load the JSON, and then use it to build the Function and Tool Call function_call_arr = json.loads(raw_tool_call) tool_calls: List[ToolCall] = [ ToolCall( type='function', function=FunctionCall( name=raw_function_call['name'], + # function call args are JSON but as a string arguments=json.dumps(raw_function_call['arguments']) ) ) for raw_function_call in function_call_arr ] - return tool_calls + content = model_response.choices[0].message.content.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None + ) except Exception as e: + # TODO discussion on how to best handle invalidly-generated tool calls logger.error("Error in extracting tool call from response: %s", e) - return [] + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_response.choices[0].message.content + ) def extract_tool_calls_streaming(self, generator): raise NotImplementedError('MistralToolParser.extract_tool_calls_streaming has not been implemented!') @@ -64,34 +80,56 @@ class Hermes2ProToolParser(ToolParser): # regex to match between and OR between and EOS (happens sometimes :)) tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) + scratch_pad_regex = re.compile(r'(.*?)', re.DOTALL) - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> List[ToolCall]: - """ - Abstract method intended to be used for extracting tool calls for use in a NON-STREAMING response - """ + def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_call_start not in model_response.choices[0].message.content: print('TOOL tool_call_start is not in the response') - return [] + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_response.choices[0].message.content + ) + + else: - # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall - # is an array of tuples where one is a function call and the other is None - function_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + try: + # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall + # is an array of tuples where one is a function call and the other is None + function_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + # load the JSON, and then use it to build the Function and Tool Call + raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] + tool_calls = [ + ToolCall( + type='function', + function=FunctionCall( + name=function_call['name'], + # function call args are JSON but as a string + arguments=json.dumps(function_call['arguments']) + ) + ) for function_call in raw_function_calls + ] + content_match = self.scratch_pad_regex.search(model_response.choices[0].message.content) + print("CONTENT MATCH", content_match) + content = content_match.group(1) if content_match else None + print("CONTENT", content) + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None + ) - # filter out & enforce the schema - raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] - tool_calls = [ - ToolCall( - type='function', - function=FunctionCall( - name=function_call['name'], - arguments=json.dumps(function_call['arguments']) + except Exception as e: + logger.error("Error in extracting tool call from response %s", e) + # TODO discussion on how to best handle invalidly-generated tool calls + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_response.choices[0].message.content ) - ) for function_call in raw_function_calls - ] - return tool_calls def extract_tool_calls_streaming(self, generator): raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') From ece61826c7f59e908d6416113fd10f6571346650 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 3 Jul 2024 13:42:09 -0500 Subject: [PATCH 031/222] feat: update example with tool call --- .../openai_chat_completion_client_with_tools.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 7bfa6fb0dd0f0..283740a21174d 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -64,6 +64,10 @@ print("Chat completion results:") print(chat_completion) +messages.append({ + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls +}) # Now, simulate a tool call def get_current_weather(city: str, state: str, unit: 'str'): @@ -79,7 +83,20 @@ def get_current_weather(city: str, state: str, unit: 'str'): args = json.loads(call.function.arguments) result = tool_to_call(**args) print(result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) +print("Sending new chat with messages", messages) +chat_completion_2 = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, +) +print(chat_completion_2) From 705ca625b43350e6dc4a1c202f6aa30025a618d9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 3 Jul 2024 14:45:45 -0500 Subject: [PATCH 032/222] fix: typing-related issues for chat messages --- vllm/entrypoints/openai/protocol.py | 7 +- vllm/entrypoints/openai/serving_chat.py | 94 +++++++++++++++++-------- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ee49691f83523..1375f79c71f73 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -34,12 +34,16 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): content: Union[str, List[ChatCompletionContentPartParam]] """The contents of the message.""" - name: str + name: Optional[str] """An optional name for the participant. Provides the model information to differentiate between participants of the same role. """ + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + ChatCompletionMessageParam = Union[ @@ -320,7 +324,6 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): - print("DATA", data) if "tool_choice" in data: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6957116a7f7b2..9a38f894438bd 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -39,6 +39,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, MistralToolParser, Hermes2ProToolParser from jinja2 import Environment, FileSystemLoader, select_autoescape + env = Environment( loader=FileSystemLoader('./'), autoescape=select_autoescape() @@ -50,7 +51,10 @@ @final # So that it should be compatible with Dict[str, str] class ConversationMessage(TypedDict): role: str - content: str + content: Optional[str] # optional IFF tool_calls is specified + tool_call_id: Optional[str] + name: str | None + tool_calls: Optional[List] @dataclass(frozen=True) @@ -122,11 +126,15 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.warning( "No chat template provided. Chat API will not work.") - def _parse_chat_message_content_parts( - self, - role: str, - parts: Iterable[ChatCompletionContentPartParam], + def _parse_chat_message_content_parts_for_image( + self, + role: str, + parts: Iterable[ChatCompletionContentPartParam], ) -> ChatMessageParseResult: + + """ + Handle parsing out the image data for image chat completions + """ texts: List[str] = [] image_futures: List[Awaitable[ImagePixelData]] = [] @@ -200,26 +208,46 @@ def _parse_chat_message_content_parts( image_futures=image_futures) def _parse_chat_message_content( - self, - message: ChatCompletionMessageParam, + self, + message: ChatCompletionMessageParam, ) -> ChatMessageParseResult: - role = message["role"] + role = message.get('role') content = message.get("content") + tool_call_id = message.get('tool_call_id') + tool_calls = message.get('tool_calls') + name = message.get('tool_calls') - if content is None: + # invariant + if content is None and tool_calls is None: + print('Parsing message as empty:', message) return ChatMessageParseResult(messages=[], image_futures=[]) - if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] + + # if content is a string OR if there's tool calls + if isinstance(content, str) or isinstance(tool_calls, list): + print('parsing message as content', message) + + messages: List[ConversationMessage] = [] + if role == 'tool': + messages = [ConversationMessage(role=role, name=name, content=content, tool_call_id=tool_call_id)] + elif role == 'assistant': + if tool_calls: + messages = [ConversationMessage(role=role, content=content, tool_calls=tool_calls)] + else: + messages = [ConversationMessage(role=role, content=content)] + else: # user and system messages can be handled the same way + messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, image_futures=[]) - return self._parse_chat_message_content_parts(role, content) + elif isinstance(content, list): + print('parsing message as image stuff') + return self._parse_chat_message_content_parts_for_image(role, content) async def create_chat_completion( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -229,17 +257,21 @@ async def create_chat_completion( NOTE: Currently we do not support the following feature: - function_call (Users should implement this by themselves) """ + print('checking model') error_check_ret = await self._check_model(request) if error_check_ret is not None: + print('Error with model') return error_check_ret try: + print('trying to parse messages') conversation: List[ConversationMessage] = [] image_futures: List[Awaitable[ImagePixelData]] = [] for msg in request.messages: + print('parsing messages...') chat_parsed_result = self._parse_chat_message_content(msg) - + print('messages parsed...') conversation.extend(chat_parsed_result.messages) image_futures.extend(chat_parsed_result.image_futures) @@ -250,6 +282,7 @@ async def create_chat_completion( print() print('using tools', tools) print('add generation prompt? ', request.add_generation_prompt) + print('conversation', conversation) prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, @@ -262,7 +295,6 @@ async def create_chat_completion( logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) - # Fetch image data image_data: Optional[ImagePixelData] = None try: @@ -285,7 +317,7 @@ async def create_chat_completion( lora_request = self._maybe_get_lora(request) decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( guided_decoding_backend, request, await @@ -387,7 +419,7 @@ async def chat_completion_stream_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( - "role") == role: + "role") == role: last_msg_content = conversation[-1]["content"] if last_msg_content: @@ -421,7 +453,7 @@ async def chat_completion_stream_generator( delta_token_ids = output.token_ids[previous_num_tokens[i]:] out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None + previous_num_tokens[i]:] if output.logprobs else None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, ( @@ -517,12 +549,12 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage] + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage] ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -630,10 +662,10 @@ def _get_top_logprobs( ] def _create_chat_logprobs( - self, - token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - num_output_top_logprobs: Optional[int] = None, + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" From b3a62e90e0aa955f9c3f1c7db6cd9b8c143f6aeb Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 4 Jul 2024 13:09:43 -0500 Subject: [PATCH 033/222] feat: fix lots of parsing & extraction issues to ensure tool calls & results are parsed properly --- vllm/entrypoints/openai/api_server.py | 1 - vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 38 ++++++++++++------------- vllm/entrypoints/openai/tool_parsers.py | 3 -- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6756b0192509e..70c9a93276e6e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -158,7 +158,6 @@ async def create_chat_completion(request: ChatCompletionRequest, return JSONResponse(content=response) else: - print('Returning regular response') return JSONResponse(content=generator.model_dump()) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1375f79c71f73..a153c9bb969d2 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -41,7 +41,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): same role. """ tool_call_id: Optional[str] - + tool_calls: Optional[List[dict]] diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9a38f894438bd..80049faf54b7d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -219,27 +219,27 @@ def _parse_chat_message_content( # invariant if content is None and tool_calls is None: - print('Parsing message as empty:', message) return ChatMessageParseResult(messages=[], image_futures=[]) # if content is a string OR if there's tool calls - if isinstance(content, str) or isinstance(tool_calls, list): - print('parsing message as content', message) + if isinstance(content, str) or tool_calls: - messages: List[ConversationMessage] = [] if role == 'tool': messages = [ConversationMessage(role=role, name=name, content=content, tool_call_id=tool_call_id)] elif role == 'assistant': if tool_calls: - messages = [ConversationMessage(role=role, content=content, tool_calls=tool_calls)] + # tool_calls is a ValidatorIterator and should be flattened into a list + # (although it doesn't have to be) + messages = [ConversationMessage(role=role, content=content, tool_calls=list(tool_calls))] else: messages = [ConversationMessage(role=role, content=content)] - else: # user and system messages can be handled the same way + else: + # user and system messages can be handled the same way messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages, image_futures=[]) elif isinstance(content, list): - print('parsing message as image stuff') return self._parse_chat_message_content_parts_for_image(role, content) async def create_chat_completion( @@ -257,21 +257,17 @@ async def create_chat_completion( NOTE: Currently we do not support the following feature: - function_call (Users should implement this by themselves) """ - print('checking model') error_check_ret = await self._check_model(request) if error_check_ret is not None: - print('Error with model') + print('Error with model', error_check_ret) return error_check_ret try: - print('trying to parse messages') conversation: List[ConversationMessage] = [] image_futures: List[Awaitable[ImagePixelData]] = [] for msg in request.messages: - print('parsing messages...') chat_parsed_result = self._parse_chat_message_content(msg) - print('messages parsed...') conversation.extend(chat_parsed_result.messages) image_futures.extend(chat_parsed_result.image_futures) @@ -279,10 +275,6 @@ async def create_chat_completion( if self.enable_auto_tools and request.tools: tools = [tool.model_dump() for tool in request.tools] - print() - print('using tools', tools) - print('add generation prompt? ', request.add_generation_prompt) - print('conversation', conversation) prompt = self.tokenizer.apply_chat_template( conversation=conversation, tokenize=False, @@ -355,12 +347,20 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation) + request, + result_generator, + request_id, + conversation + ) else: try: return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id, - conversation) + request, + raw_request, + result_generator, + request_id, + conversation + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 247e625012802..f573f4ed1b894 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -86,7 +86,6 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte # sanity check; avoid unnecessary processing if self.tool_call_start not in model_response.choices[0].message.content: - print('TOOL tool_call_start is not in the response') return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -113,9 +112,7 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte ) for function_call in raw_function_calls ] content_match = self.scratch_pad_regex.search(model_response.choices[0].message.content) - print("CONTENT MATCH", content_match) content = content_match.group(1) if content_match else None - print("CONTENT", content) return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, From c2d1afc68a20a05e15a18d491a2b6e9f17072132 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 4 Jul 2024 14:43:04 -0500 Subject: [PATCH 034/222] chore: refactor tool extraction for non-streaming responses to be in serving_chat.py --- vllm/entrypoints/openai/api_server.py | 24 ++++++++++-------------- vllm/entrypoints/openai/serving_chat.py | 17 ++++++++++++++++- vllm/entrypoints/openai/tool_parsers.py | 3 ++- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 70c9a93276e6e..2ed680e1df925 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -140,25 +140,21 @@ async def create_chat_completion(request: ChatCompletionRequest, # if streaming is requested, handle streaming # TODO implement for streaming later if request.stream: - return StreamingResponse(content=generator, + + if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: + print('handling streaming response') + + return StreamingResponse(content=generator, + media_type="text/event-stream") + + else: + return StreamingResponse(content=generator, media_type="text/event-stream") # handle non-streaming requests else: assert isinstance(generator, ChatCompletionResponse) - if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: - response = generator.model_dump() - tool_call_info = openai_serving_chat.tool_parser.extract_tool_calls(generator) - if tool_call_info.tools_called: - response['choices'][0]['message']['content'] = tool_call_info.content - response['choices'][0]['message']['tool_calls'] = [tool_call.to_dict() for tool_call in tool_call_info.tool_calls] - response['choices'][0]['finish_reason'] = 'tool_calls' - else: - print('TOOL: no tool calls detected') - return JSONResponse(content=response) - - else: - return JSONResponse(content=generator.model_dump()) + return JSONResponse(content=generator.model_dump()) @app.post("/v1/completions") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 80049faf54b7d..dd84eabd77992 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -354,13 +354,28 @@ async def create_chat_completion( ) else: try: - return await self.chat_completion_full_generator( + generator = await self.chat_completion_full_generator( request, raw_request, result_generator, request_id, conversation ) + + assert isinstance(generator, ChatCompletionResponse) + print('generator', generator) + + # handle tool extraction + if self.enable_auto_tools and self.tool_parser: + tool_call_info = self.tool_parser.extract_tool_calls(generator) + if tool_call_info.tools_called: + generator.choices[0].message.content = tool_call_info.content + generator.choices[0].message.tool_calls = [tool_call.to_dict() for tool_call in + tool_call_info.tool_calls] + generator.choices[0].finish_reason = 'tool_calls' + else: + print('no tool calls detected') + return generator except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index f573f4ed1b894..459ede4a10b2a 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -70,7 +70,8 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte ) def extract_tool_calls_streaming(self, generator): - raise NotImplementedError('MistralToolParser.extract_tool_calls_streaming has not been implemented!') + for chunk in generator: + print('CHUNK', chunk) class Hermes2ProToolParser(ToolParser): From 2d4b302af306ff9b71e60135530e75ccf29b15f7 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 5 Jul 2024 11:48:05 -0500 Subject: [PATCH 035/222] fix: mistral tool calling chat template --- examples/tool_chat_template_mistral.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index 249cae0f258a5..6a99d80802067 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -1 +1 @@ -{{bos_token}}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{'[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]'}}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% elif message['role'] == 'tool_results' %}{{'[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]'}}{% elif message['role'] == 'tool_calls' %}{{'[TOOL_CALLS]' + message['content']|string + eos_token}}{% endif %}{% endfor %} \ No newline at end of file +{{ bos_token }}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{ '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %}{{ '[TOOL_CALLS]' + message['content']|string + eos_token }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}{% endif %}{% endfor %} \ No newline at end of file From 1fcd4f5387b6c4197170b2ab6c58c6e49b838192 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 5 Jul 2024 11:48:48 -0500 Subject: [PATCH 036/222] fix: make ChatMessage content Optional since it could be an assistant message with tool_calls and no content --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a153c9bb969d2..844b2a3756603 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -674,7 +674,7 @@ class ExtractedToolCallInformation(BaseModel): class ChatMessage(OpenAIBaseModel): role: str - content: str + content: Optional[str | None] tool_calls: List[ToolCall] = Field(default_factory=list) From 2926c3e903d2fa628b3545117010169abc84820a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 5 Jul 2024 11:52:10 -0500 Subject: [PATCH 037/222] refactor: move tool parsing into the right place in serving_chat, and update tool parsers accordingly --- vllm/entrypoints/openai/serving_chat.py | 30 +++++++++++-------------- vllm/entrypoints/openai/tool_parsers.py | 27 +++++++++++----------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index dd84eabd77992..4792cf514499a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -282,7 +282,6 @@ async def create_chat_completion( tools=tools ) - print('fully tokenized prompt:', prompt) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) @@ -363,18 +362,6 @@ async def create_chat_completion( ) assert isinstance(generator, ChatCompletionResponse) - print('generator', generator) - - # handle tool extraction - if self.enable_auto_tools and self.tool_parser: - tool_call_info = self.tool_parser.extract_tool_calls(generator) - if tool_call_info.tools_called: - generator.choices[0].message.content = tool_call_info.content - generator.choices[0].message.tool_calls = [tool_call.to_dict() for tool_call in - tool_call_info.tool_calls] - generator.choices[0].finish_reason = 'tool_calls' - else: - print('no tool calls detected') return generator except ValueError as e: # TODO: Use a vllm-specific Validation Error @@ -601,6 +588,7 @@ async def chat_completion_full_generator( else: logprobs = None + tools_called = False # if the reqeust uses tools and specified a tool choice if request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: @@ -613,6 +601,7 @@ async def chat_completion_full_generator( name=request.tool_choice.function.name, arguments=output.text)) ]) + tools_called = True # if the request doesn't use tool choice OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": @@ -620,17 +609,24 @@ async def chat_completion_full_generator( message = ChatMessage(role=role, content=output.text) # handle when there are tools and tool choice is auto - elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None): + elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None) and self.enable_auto_tools and self.tool_parser: + + tool_call_info = self.tool_parser.extract_tool_calls(output.text) + tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, content=tool_call_info.content, tool_calls=tool_call_info.tool_calls) + else: # FOR NOW make it a chat message; we will have to detect the type to make it later. - message = ChatMessage(role=role, content=output.text) + message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) + finish_reason='tool_calls' if tools_called else output.stop_reason, + stop_reason=output.stop_reason + ) choices.append(choice_data) if request.echo: diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 459ede4a10b2a..88e505859c946 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -15,7 +15,7 @@ class ToolParser: def __init__(self): pass - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: + def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') def extract_tool_calls_streaming(self, generator): @@ -25,19 +25,19 @@ def extract_tool_calls_streaming(self, generator): class MistralToolParser(ToolParser): bot_token: str = '[TOOL_CALLS]' - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: + def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: # Get the tool call token from the tokenizer - if self.bot_token not in model_response.choices[0].message.content: + if self.bot_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_response.choices[0].message.content + content=model_output ) else: try: # extract the token so we hopefully have a JSON string - raw_tool_call = (model_response.choices[0].message.content + raw_tool_call = (model_output .replace(MistralToolParser.bot_token, '') # remove BOT token .replace("'", '"')) # ... hack to parse broken mistral JSON # load the JSON, and then use it to build the Function and Tool Call @@ -53,7 +53,7 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte ) for raw_function_call in function_call_arr ] - content = model_response.choices[0].message.content.split(self.bot_token)[0] + content = model_output.split(self.bot_token)[0] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -63,10 +63,11 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte except Exception as e: # TODO discussion on how to best handle invalidly-generated tool calls logger.error("Error in extracting tool call from response: %s", e) + print('ERROR', e) return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_response.choices[0].message.content + content=model_output ) def extract_tool_calls_streaming(self, generator): @@ -83,14 +84,14 @@ class Hermes2ProToolParser(ToolParser): tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) scratch_pad_regex = re.compile(r'(.*?)', re.DOTALL) - def extract_tool_calls(self, model_response: ChatCompletionResponse) -> ExtractedToolCallInformation: + def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing - if self.tool_call_start not in model_response.choices[0].message.content: + if self.tool_call_start not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_response.choices[0].message.content + content=model_output ) else: @@ -98,7 +99,7 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte try: # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall # is an array of tuples where one is a function call and the other is None - function_call_tuples = self.tool_call_regex.findall(model_response.choices[0].message.content) + function_call_tuples = self.tool_call_regex.findall(model_output) # load the JSON, and then use it to build the Function and Tool Call raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] @@ -112,7 +113,7 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte ) ) for function_call in raw_function_calls ] - content_match = self.scratch_pad_regex.search(model_response.choices[0].message.content) + content_match = self.scratch_pad_regex.search(model_output) content = content_match.group(1) if content_match else None return ExtractedToolCallInformation( tools_called=True, @@ -126,7 +127,7 @@ def extract_tool_calls(self, model_response: ChatCompletionResponse) -> Extracte return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_response.choices[0].message.content + content=model_output ) def extract_tool_calls_streaming(self, generator): From fa082e0b213f3f510a8c6085c38a9573223375e9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 17:13:12 -0500 Subject: [PATCH 038/222] fix: finish_reason should NEVER be None; OpenAI defualt is "stop" --- vllm/entrypoints/openai/protocol.py | 4 ++-- vllm/entrypoints/openai/serving_chat.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 844b2a3756603..f02626fdd8fd3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -696,8 +696,8 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + finish_reason: Optional[str] = Field(default='stop') # per OpenAI spec this is the default + stop_reason: Optional[Union[int, str]] = None # ??? Not part of the OpenAI spec class ChatCompletionResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4792cf514499a..395621919f2db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -246,8 +246,7 @@ async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -375,7 +374,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str, + result_generator: AsyncIterator[RequestOfiutput], request_id: str, conversation: List[ConversationMessage] ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] @@ -620,11 +619,12 @@ async def chat_completion_full_generator( # FOR NOW make it a chat message; we will have to detect the type to make it later. message = ChatMessage(role=role, content=output.text) + choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason='tool_calls' if tools_called else output.stop_reason, + finish_reason='tool_calls' if tools_called else output.stop_reason if output.stop_reason else 'stop', stop_reason=output.stop_reason ) choices.append(choice_data) From 845578699633bd46041b89cc2becab58a26c248a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 17:17:47 -0500 Subject: [PATCH 039/222] fix: typo introduced in earlier commit --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 395621919f2db..c0c845b2c2a74 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -374,7 +374,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOfiutput], request_id: str, + result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage] ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] From a97ccc75668b032c934e0a642668ba25057dba88 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 19:11:28 -0500 Subject: [PATCH 040/222] feat: signature updates and refactoring to tool parser streaming; prepare for streaming tools --- vllm/entrypoints/openai/serving_chat.py | 31 +++++++++++++++++++++---- vllm/entrypoints/openai/tool_parsers.py | 20 +++++++++++++--- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c0c845b2c2a74..6aeafe27a4028 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -373,8 +373,10 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] async def chat_completion_stream_generator( - self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str, + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, conversation: List[ConversationMessage] ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] @@ -386,6 +388,7 @@ async def chat_completion_stream_generator( assert request.n is not None previous_texts = [""] * request.n previous_num_tokens = [0] * request.n + previous_token_ids = [[]] * request.n finish_reason_sent = [False] * request.n try: async for res in result_generator: @@ -447,7 +450,9 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: + i = output.index + print(f'[{i}]:', output) if finish_reason_sent[i]: continue @@ -467,21 +472,38 @@ async def chat_completion_stream_generator( else: logprobs = None + delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + # handle streaming deltas for tools with tool_choice if request.tool_choice and type( request.tool_choice ) is ChatCompletionNamedToolChoiceParam: + print('handling streaming for tools with tool choice!') delta_message = DeltaMessage(tool_calls=[ ToolCall(function=FunctionCall( name=request.tool_choice.function.name, arguments=delta_text)) ]) + + # handle streaming deltas for tools with tool_choice + elif request.tools and (request.tool_choice is None or request.tool_choice == 'auto'): + print('handling streaming for tools with no tool choice!') + delta_message = DeltaMessage(content=delta_text) else: + print('handling streaming for normal message') delta_message = DeltaMessage(content=delta_text) + # handle setting the previous values for the next iteration + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + previous_token_ids[i] = output.token_ids + print('previous texts:', previous_texts) + print('delta_text:', delta_text) + print('previous_num_tokens:', previous_num_tokens) + print('previous token IDs', previous_token_ids) + print('delta token IDs: ', delta_token_ids) + if output.finish_reason is None: # Send token-by-token response for each request.n @@ -523,6 +545,7 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True + # once the final token is handled, if stream_options.include_usage is sent, send the usage if (request.stream_options and request.stream_options.include_usage): final_usage = UsageInfo( diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 88e505859c946..4a362806f58ee 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -6,7 +6,7 @@ import json from pydantic import BaseModel import re - +from vllm.entrypoints.openai.protocol import DeltaMessage logger = init_logger(__name__) @@ -18,7 +18,14 @@ def __init__(self): def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') - def extract_tool_calls_streaming(self, generator): + def extract_tool_calls_streaming(self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> DeltaMessage: raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') @@ -130,5 +137,12 @@ def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: content=model_output ) - def extract_tool_calls_streaming(self, generator): + def extract_tool_calls_streaming(self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int] + ) -> DeltaMessage: raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') From c697e9f279bfd33520f6df0750a679c80cc67439 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 23:18:44 -0500 Subject: [PATCH 041/222] fix: kind of fixed mistral chat template --- examples/tool_chat_template_mistral.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index 6a99d80802067..29efc61c2fd50 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -1 +1 @@ -{{ bos_token }}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{ '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %}{{ '[TOOL_CALLS]' + message['content']|string + eos_token }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}{% endif %}{% endfor %} \ No newline at end of file +{{ bos_token }}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{ '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %}{{ '[TOOL_CALLS]' + message['tool_calls']|string + eos_token }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}{% endif %}{% endfor %} \ No newline at end of file From df877f6449d464f2e9c6e523582ffa25a3026133 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 23:33:01 -0500 Subject: [PATCH 042/222] feat: make non-streaming tool parsing a static method so that streaming can have state --- vllm/entrypoints/openai/serving_chat.py | 34 +++++++++++++++-------- vllm/entrypoints/openai/tool_parsers.py | 36 ++++++++++++++++--------- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6aeafe27a4028..7fc802d1acb37 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -2,7 +2,7 @@ import time from dataclasses import dataclass, field from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, - List, Optional) + List, Optional, Type) from typing import Sequence as GenericSequence from typing import TypedDict, Union, cast, final @@ -91,9 +91,9 @@ def __init__(self, raise TypeError('Error: --enable-auto-tool-choice requires --tool-choice-parser') if tool_parser == 'mistral': - self.tool_parser: ToolParser = MistralToolParser() + self.tool_parser: Type[ToolParser] = MistralToolParser elif tool_parser == 'hermes': - self.tool_parser: ToolParser = Hermes2ProToolParser() + self.tool_parser: Type[ToolParser] = Hermes2ProToolParser else: raise ValueError(f'Invalid tool parser value {tool_parser}!') @@ -390,6 +390,8 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * request.n previous_token_ids = [[]] * request.n finish_reason_sent = [False] * request.n + + tool_parser: ToolParser = self.tool_parser() try: async for res in result_generator: # We need to do it here, because if there are exceptions in @@ -475,6 +477,7 @@ async def chat_completion_stream_generator( delta_text = output.text[len(previous_texts[i]):] + # handle streaming deltas for tools with tool_choice if request.tool_choice and type( request.tool_choice @@ -487,9 +490,18 @@ async def chat_completion_stream_generator( ]) # handle streaming deltas for tools with tool_choice - elif request.tools and (request.tool_choice is None or request.tool_choice == 'auto'): + elif (request.tools and (request.tool_choice is None or request.tool_choice == 'auto') + and self.enable_auto_tools): + print('handling streaming for tools with no tool choice!') - delta_message = DeltaMessage(content=delta_text) + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_texts[i], + current_text=output.text, + delta_text=delta_text, + previous_token_ids=previous_token_ids[i], + current_token_ids=output.token_ids, + delta_token_ids=delta_token_ids + ) else: print('handling streaming for normal message') delta_message = DeltaMessage(content=delta_text) @@ -498,11 +510,11 @@ async def chat_completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) previous_token_ids[i] = output.token_ids - print('previous texts:', previous_texts) - print('delta_text:', delta_text) - print('previous_num_tokens:', previous_num_tokens) - print('previous token IDs', previous_token_ids) - print('delta token IDs: ', delta_token_ids) + + # if the message delta is None (e.g. because it was a "control token" for tool calls, then + # get the next token without streaming a chunk + if delta_message is None: + continue if output.finish_reason is None: # Send token-by-token response for each request.n @@ -631,7 +643,7 @@ async def chat_completion_full_generator( message = ChatMessage(role=role, content=output.text) # handle when there are tools and tool choice is auto - elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None) and self.enable_auto_tools and self.tool_parser: + elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None) and self.enable_auto_tools: tool_call_info = self.tool_parser.extract_tool_calls(output.text) tools_called = tool_call_info.tools_called diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 4a362806f58ee..8a7c44c3238ef 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -15,7 +15,8 @@ class ToolParser: def __init__(self): pass - def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') def extract_tool_calls_streaming(self, @@ -25,17 +26,19 @@ def extract_tool_calls_streaming(self, previous_token_ids: List[int], current_token_ids: List[int], delta_token_ids: List[int], - ) -> DeltaMessage: + ) -> DeltaMessage | None: raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') class MistralToolParser(ToolParser): bot_token: str = '[TOOL_CALLS]' + bot_token_id: int = 5 - def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # Get the tool call token from the tokenizer - if self.bot_token not in model_output: + if MistralToolParser.bot_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -60,7 +63,7 @@ def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: ) for raw_function_call in function_call_arr ] - content = model_output.split(self.bot_token)[0] + content = model_output.split(MistralToolParser.bot_token)[0] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -77,9 +80,17 @@ def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: content=model_output ) - def extract_tool_calls_streaming(self, generator): - for chunk in generator: - print('CHUNK', chunk) + def extract_tool_calls_streaming(self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> DeltaMessage | None: + + + return DeltaMessage(content=delta_text) class Hermes2ProToolParser(ToolParser): @@ -91,10 +102,11 @@ class Hermes2ProToolParser(ToolParser): tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) scratch_pad_regex = re.compile(r'(.*?)', re.DOTALL) - def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing - if self.tool_call_start not in model_output: + if Hermes2ProToolParser.tool_call_start not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -106,7 +118,7 @@ def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: try: # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall # is an array of tuples where one is a function call and the other is None - function_call_tuples = self.tool_call_regex.findall(model_output) + function_call_tuples = Hermes2ProToolParser.tool_call_regex.findall(model_output) # load the JSON, and then use it to build the Function and Tool Call raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] @@ -120,7 +132,7 @@ def extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: ) ) for function_call in raw_function_calls ] - content_match = self.scratch_pad_regex.search(model_output) + content_match = Hermes2ProToolParser.scratch_pad_regex.search(model_output) content = content_match.group(1) if content_match else None return ExtractedToolCallInformation( tools_called=True, From 8fa57aed2a10eb0000f78e58852211e741865573 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 7 Jul 2024 23:57:47 -0500 Subject: [PATCH 043/222] partial: work on streaming tool call parser for mistral --- vllm/entrypoints/openai/tool_parsers.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 8a7c44c3238ef..f32b87d908a07 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -89,8 +89,23 @@ def extract_tool_calls_streaming(self, delta_token_ids: List[int], ) -> DeltaMessage | None: + # if the tool call token ID is not in the tokens generated so far, append output to contents + if self.bot_token_id not in current_token_ids: + return DeltaMessage(content=delta_text) + else: + + # if the bot token is the only token in the delta, return None so we don't ship a delta to the client + if len(delta_token_ids) == 1 and delta_token_ids[0] == self.bot_token_id: + return None + + # for mistral, everything after the BOT token is tool call, not content. If there's content + # which I have yet to see, it would HAVE to come BEFORE the BOT token + else: + # Now we get into partial JSON parsing + # TODO IMPLEMENT THIS + return DeltaMessage(content=delta_text) + - return DeltaMessage(content=delta_text) class Hermes2ProToolParser(ToolParser): @@ -149,6 +164,11 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=model_output ) + + def __init__(self): + self.current_tool_count: int = 0 + self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array + def extract_tool_calls_streaming(self, previous_text: str, current_text: str, From 301b02e0c163f0f93fab21e40fda4a0bd8c2f858 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 9 Jul 2024 19:20:45 -0500 Subject: [PATCH 044/222] deps: add partial-json-parser for parsing streaming JSON --- requirements-common.txt | 1 + vllm/entrypoints/openai/tool_parsers.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 636f85343e1f2..036c440e5cbc3 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1 outlines >= 0.0.43 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +partial-json-parser # used for parsing partial JSON outputs \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index f32b87d908a07..9e246968174db 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -4,7 +4,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) import json -from pydantic import BaseModel +import partial_json_parser import re from vllm.entrypoints.openai.protocol import DeltaMessage logger = init_logger(__name__) From 1364bc104eb28e06d39c2c67a7f5ef26a6329940 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 9 Jul 2024 23:19:42 -0500 Subject: [PATCH 045/222] fix: protocol stuff, work on mistral streaming --- vllm/entrypoints/openai/protocol.py | 21 +++- vllm/entrypoints/openai/serving_chat.py | 6 +- vllm/entrypoints/openai/tool_parsers.py | 139 +++++++++++++++++++++--- 3 files changed, 147 insertions(+), 19 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f02626fdd8fd3..b6d532bf5e423 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -647,6 +647,7 @@ def to_dict(self): "arguments": self.arguments } + class ToolCall(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") type: Literal["function"] = "function" @@ -660,6 +661,24 @@ def to_dict(self): } +class DeltaFunctionCall(FunctionCall): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(ToolCall): + index: int # this is always required, the index of the tool call in the tool_calls array. + function: Optional[DeltaFunctionCall] = None + + +# the initial delta that gets sent once a new tool call is started; differs in that it includes an auto-set id and type +class InitialDeltaToolCall(DeltaToolCall): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + + class ExtractedToolCallInformation(BaseModel): # indicate if tools were called tools_called: bool @@ -712,7 +731,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + tool_calls: List[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7fc802d1acb37..613e071272735 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -454,7 +454,8 @@ async def chat_completion_stream_generator( for output in res.outputs: i = output.index - print(f'[{i}]:', output) + # prints the full completion so far including text and tokens + #print(f'[{i}]:', output) if finish_reason_sent[i]: continue @@ -482,7 +483,6 @@ async def chat_completion_stream_generator( if request.tool_choice and type( request.tool_choice ) is ChatCompletionNamedToolChoiceParam: - print('handling streaming for tools with tool choice!') delta_message = DeltaMessage(tool_calls=[ ToolCall(function=FunctionCall( name=request.tool_choice.function.name, @@ -493,7 +493,6 @@ async def chat_completion_stream_generator( elif (request.tools and (request.tool_choice is None or request.tool_choice == 'auto') and self.enable_auto_tools): - print('handling streaming for tools with no tool choice!') delta_message = tool_parser.extract_tool_calls_streaming( previous_text=previous_texts[i], current_text=output.text, @@ -503,7 +502,6 @@ async def chat_completion_stream_generator( delta_token_ids=delta_token_ids ) else: - print('handling streaming for normal message') delta_message = DeltaMessage(content=delta_text) # handle setting the previous values for the next iteration diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 9e246968174db..4240f936f48d3 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,12 +1,15 @@ -from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, ExtractedToolCallInformation +from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, \ + ExtractedToolCallInformation, DeltaToolCall, InitialDeltaToolCall, DeltaFunctionCall from vllm.logger import init_logger -from typing import List +from typing import List, Dict from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) import json import partial_json_parser +from partial_json_parser import Allow import re from vllm.entrypoints.openai.protocol import DeltaMessage + logger = init_logger(__name__) @@ -48,8 +51,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: try: # extract the token so we hopefully have a JSON string raw_tool_call = (model_output - .replace(MistralToolParser.bot_token, '') # remove BOT token - .replace("'", '"')) # ... hack to parse broken mistral JSON + .replace(MistralToolParser.bot_token, '') # remove BOT token + .replace("'", '"')) # ... hack to parse broken mistral JSON # load the JSON, and then use it to build the Function and Tool Call function_call_arr = json.loads(raw_tool_call) tool_calls: List[ToolCall] = [ @@ -80,6 +83,13 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=model_output ) + def __init__(self): + super().__init__() + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + def extract_tool_calls_streaming(self, previous_text: str, current_text: str, @@ -92,24 +102,125 @@ def extract_tool_calls_streaming(self, # if the tool call token ID is not in the tokens generated so far, append output to contents if self.bot_token_id not in current_token_ids: return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that means we're parsing as tool calls now else: - # if the bot token is the only token in the delta, return None so we don't ship a delta to the client - if len(delta_token_ids) == 1 and delta_token_ids[0] == self.bot_token_id: - return None + # handle if we detected the BOT token which means the start of tool calling + if self.bot_token_id in delta_token_ids: + logger.info('Found bot_token!') + + # if it's the only token, return None, so we don't send a chat completion + if len(delta_token_ids) == 1: + return None + # for mistral, everything after the BOT token is tool call, not content. If there's content # which I have yet to see, it would HAVE to come BEFORE the BOT token - else: - # Now we get into partial JSON parsing - # TODO IMPLEMENT THIS - return DeltaMessage(content=delta_text) + # flags for partial JSON parsing (lib uses bit mask) + # if the tool name has been sent then allow any incomplete field ELSE allow everything BUT strings + # to avoid sending the partial tool name incorrectly + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: + # note this is basically only the way to do this - just make sure your tool arguments will + # never be something containing an apostrophe + parsable_arr = (current_text + .replace(self.bot_token, '') # remove BOT token to get valid json + .replace('\'', '"') # replace mistral single quotes with double for JSON parsing + ) + logger.info('parsing: %s', parsable_arr) + tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) + #print('parsed ', tool_call_arr) + # case: we are starting a new tool in the array + # -> array has nonzero length AND length has moved past cursor + if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + logger.info('starting on new tool %d', self.current_tool_id) -class Hermes2ProToolParser(ToolParser): + # case: there is no tool in the array + elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: + logger.info('update to tool %d', self.current_tool_id) + + # if there is NOTHING in the array + else: + logger.info('No tool call detected yet!') + return None # TODO FIX + + # handle parsing + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + # if the current tool initial data incl. the id, type=function and idx not sent, send that + if not self.current_tool_initial_sent: + logger.info('Sending InitialDeltaToolCall') + self.current_tool_initial_sent = True + delta = DeltaMessage( + tool_calls=[ + InitialDeltaToolCall(index=self.current_tool_id).model_dump(exclude_none=True)] + ) + + # if the current tool name hasn't been sent, send if available - otherwise no chunks + elif not self.current_tool_name_sent: + function_name = current_tool_call.get('name') + logger.info('Sending DeltaToolCall with function name!') + if function_name: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + # now we know we're on the same tool call and we're streaming arguments + else: + # TODO be more clever about this - I think we can grab the raw string for the current tool ONCE + # the arguments key is defined and stream everything generated after it? Be careful of end of arr tho... + # diff arguments from previous generation against arguments from current generation + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') + cur_arguments = current_tool_call.get('arguments') + + # if arguments is not defined for the current function yet, skip the chunk + if not cur_arguments and not prev_arguments: + logger.info('Skipping chunk - no argument characters received yet!') + delta = None + + # INVARIANT - we have previously-defined arguments, but now they're undefined + elif prev_arguments and not cur_arguments: + logger.error('INVARIANT - we have current arguments for the function, but not previous ones!') + delta = None + + # we have first values for arguments: + elif not prev_arguments and cur_arguments: + logger.info('We have arguments for the function!') + delta = None # TODO replace + + # if the arguments are the same it's prob a structural/control char difference; don't stream a chunk + elif prev_arguments and cur_arguments and prev_arguments == cur_arguments: + # TODO can we be clever and figure out how to stream this by processing the string of + # only the current tool? + logger.info(f'Skipping - control/structure character received in arguments: {prev_arguments} vs. {cur_arguments}') + delta = None + + elif prev_arguments and cur_arguments and prev_arguments != cur_arguments: + logger.info('We have new values for arguments for the function!') + + # check to see if the name is defined and has been sent. if so, stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error(f'Error trying to handle streaming tool call: {e}') + logger.info('skipping returning a chunk here - maybe we just need more?') + return None + + +class Hermes2ProToolParser(ToolParser): tool_call_start: str = '' tool_call_end: str = '' @@ -164,10 +275,10 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=model_output ) - def __init__(self): + super().__init__() self.current_tool_count: int = 0 - self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array + self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array def extract_tool_calls_streaming(self, previous_text: str, From cbd8919bef34fa7313f207d81bd4ad738093b622 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 01:48:22 -0500 Subject: [PATCH 046/222] feat: progress on mistral streaming parser --- vllm/entrypoints/openai/tool_parsers.py | 111 ++++++++++++++++++------ 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 4240f936f48d3..8d13f5882eb1b 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -13,6 +13,54 @@ logger = init_logger(__name__) +def find_common_prefix(s1: str, s2: str) -> str: + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i]: + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(s1: str, s2: str) -> str: + """ + Extract the difference in the middle between two strings that are KNOWN to have a common prefix and OPTIONALLY + also a common suffix + """ + prefix = find_common_prefix(s1, s2) + suffix = find_common_suffix(s1, s2) + diff = s1 + if len(prefix): + diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + return diff + + +def find_all_indices(string, substring): + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices + + class ToolParser: def __init__(self): @@ -127,8 +175,8 @@ def extract_tool_calls_streaming(self, # note this is basically only the way to do this - just make sure your tool arguments will # never be something containing an apostrophe parsable_arr = (current_text - .replace(self.bot_token, '') # remove BOT token to get valid json - .replace('\'', '"') # replace mistral single quotes with double for JSON parsing + .replace(self.bot_token, '') # remove BOT token to get valid json + .replace('\'', '"') # replace mistral single quotes with double for JSON parsing ) logger.info('parsing: %s', parsable_arr) tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) @@ -143,14 +191,15 @@ def extract_tool_calls_streaming(self, self.current_tool_initial_sent = False logger.info('starting on new tool %d', self.current_tool_id) - # case: there is no tool in the array + # case: update an existing tool elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: - logger.info('update to tool %d', self.current_tool_id) + # logger.info('update to tool %d', self.current_tool_id) + pass # if there is NOTHING in the array else: logger.info('No tool call detected yet!') - return None # TODO FIX + return None # handle parsing current_tool_call: Dict = tool_call_arr[self.current_tool_id] @@ -167,8 +216,8 @@ def extract_tool_calls_streaming(self, # if the current tool name hasn't been sent, send if available - otherwise no chunks elif not self.current_tool_name_sent: function_name = current_tool_call.get('name') - logger.info('Sending DeltaToolCall with function name!') if function_name: + logger.info(f'Sending DeltaToolCall with function name {function_name}!') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) ]) @@ -178,37 +227,43 @@ def extract_tool_calls_streaming(self, # now we know we're on the same tool call and we're streaming arguments else: - # TODO be more clever about this - I think we can grab the raw string for the current tool ONCE - # the arguments key is defined and stream everything generated after it? Be careful of end of arr tho... - # diff arguments from previous generation against arguments from current generation + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') cur_arguments = current_tool_call.get('arguments') - # if arguments is not defined for the current function yet, skip the chunk if not cur_arguments and not prev_arguments: - logger.info('Skipping chunk - no argument characters received yet!') + logger.info(f'Skipping text {delta_text} (tokens {delta_token_ids}) - no arguments yet') delta = None - - # INVARIANT - we have previously-defined arguments, but now they're undefined - elif prev_arguments and not cur_arguments: - logger.error('INVARIANT - we have current arguments for the function, but not previous ones!') + elif not cur_arguments and prev_arguments: + logger.error('INVARIANT - impossible to have arguments reset mid-arguments') delta = None + elif cur_arguments and not prev_arguments: + logger.info('First tokens in arguments received') + cur_arguments_json = json.dumps(cur_arguments) + logger.info(f'Finding {delta_text} in |{cur_arguments_json}') + arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] + logger.info(f'First tokens in arguments received: {arguments_delta}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True)) + ]) - # we have first values for arguments: - elif not prev_arguments and cur_arguments: - logger.info('We have arguments for the function!') - delta = None # TODO replace - - # if the arguments are the same it's prob a structural/control char difference; don't stream a chunk - elif prev_arguments and cur_arguments and prev_arguments == cur_arguments: - # TODO can we be clever and figure out how to stream this by processing the string of - # only the current tool? - logger.info(f'Skipping - control/structure character received in arguments: {prev_arguments} vs. {cur_arguments}') + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + shared_prefix = find_common_prefix(cur_args_json, prev_args_json) + cur_args_json = cur_args_json.replace(shared_prefix, '', 1) + argument_diff = cur_args_json[:cur_args_json.index(delta_text) + len(delta_text)] + logger.info(f'got arguments diff: {argument_diff}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True)) + ]) + else: delta = None - elif prev_arguments and cur_arguments and prev_arguments != cur_arguments: - logger.info('We have new values for arguments for the function!') - # check to see if the name is defined and has been sent. if so, stream the name - otherwise keep waiting # finish by setting old and returning None as base case self.prev_tool_call_arr = tool_call_arr From d480db60cfa15c4d2ba6c626a0ba7d2c0f039194 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 02:14:29 -0500 Subject: [PATCH 047/222] fix: some tool parser stuff. best its working yet --- vllm/entrypoints/openai/tool_parsers.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 8d13f5882eb1b..b5f5a1026b15d 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -231,17 +231,18 @@ def extract_tool_calls_streaming(self, prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') cur_arguments = current_tool_call.get('arguments') + new_text = delta_text.replace('\'', '"') + if not cur_arguments and not prev_arguments: - logger.info(f'Skipping text {delta_text} (tokens {delta_token_ids}) - no arguments yet') + logger.info(f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet') delta = None elif not cur_arguments and prev_arguments: logger.error('INVARIANT - impossible to have arguments reset mid-arguments') delta = None elif cur_arguments and not prev_arguments: - logger.info('First tokens in arguments received') cur_arguments_json = json.dumps(cur_arguments) - logger.info(f'Finding {delta_text} in |{cur_arguments_json}') - arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] + logger.info(f'Finding {new_text} in |{cur_arguments_json}') + arguments_delta = cur_arguments_json[:cur_arguments_json.index(new_text) + len(new_text)] logger.info(f'First tokens in arguments received: {arguments_delta}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -252,9 +253,13 @@ def extract_tool_calls_streaming(self, elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) + logger.info(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') shared_prefix = find_common_prefix(cur_args_json, prev_args_json) + logger.info(f'Shared prefix: |{shared_prefix}|', ) cur_args_json = cur_args_json.replace(shared_prefix, '', 1) - argument_diff = cur_args_json[:cur_args_json.index(delta_text) + len(delta_text)] + logger.info(f'Cur args JSON: {cur_args_json}') + logger.info(f'new text: {new_text}') + argument_diff = cur_args_json[:cur_args_json.index(new_text) + len(new_text)] logger.info(f'got arguments diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( From 305685e26676790397c57e2a50c861afd7893f36 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 02:31:06 -0500 Subject: [PATCH 048/222] fix: major parsing logic issue when overlapping prefix & suffix due to json autocompletion --- vllm/entrypoints/openai/tool_parsers.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index b5f5a1026b15d..1de479802224c 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -40,13 +40,23 @@ def extract_intermediate_diff(s1: str, s2: str) -> str: Extract the difference in the middle between two strings that are KNOWN to have a common prefix and OPTIONALLY also a common suffix """ - prefix = find_common_prefix(s1, s2) suffix = find_common_suffix(s1, s2) + logger.info(f'Found suffix {suffix}') + + # prevent double-counting + s2_old = s2 + s2 = s2[::-1].replace(suffix[::-1], '', 1)[::-1] + logger.info(f'Updated search term s2 from {s2_old} to {s2}') + prefix = find_common_prefix(s1, s2) diff = s1 - if len(prefix): - diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored if len(suffix): + logger.info(f'Nuking suffix {suffix}') diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + logger.info(f'Nuking prefix {prefix}') + diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored + return diff @@ -254,12 +264,7 @@ def extract_tool_calls_streaming(self, cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) logger.info(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') - shared_prefix = find_common_prefix(cur_args_json, prev_args_json) - logger.info(f'Shared prefix: |{shared_prefix}|', ) - cur_args_json = cur_args_json.replace(shared_prefix, '', 1) - logger.info(f'Cur args JSON: {cur_args_json}') - logger.info(f'new text: {new_text}') - argument_diff = cur_args_json[:cur_args_json.index(new_text) + len(new_text)] + argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json ) logger.info(f'got arguments diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( From e47f70f32f6ade2457dfdc330f2b30d2c3997123 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 03:25:03 -0500 Subject: [PATCH 049/222] feat: implement mistral tool calling streaming for ONE TOOL ONLY RIGHT NOW --- vllm/entrypoints/openai/serving_chat.py | 22 +++++++++++++++- vllm/entrypoints/openai/tool_parsers.py | 34 +++++++++++++++++++++---- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 613e071272735..047362697f02a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,5 +1,6 @@ import codecs import time +import json from dataclasses import dataclass, field from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Type) @@ -20,7 +21,7 @@ ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo) + FunctionCall, ToolCall, UsageInfo, DeltaToolCall, DeltaFunctionCall) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.inputs import PromptInputs @@ -534,6 +535,25 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" else: + logger.info(f'Sending FINISH message with delta {delta_message.model_dump()}') + # check to make sure we haven't missed something on the last function call + if ( + delta_message.tool_calls[0].function.arguments == '' + or delta_message.tool_calls[0].function.arguments + and (output.finish_reason == 'stop' or output.finish_reason == 'tool_calls') + ): + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[len(tool_parser.prev_tool_call_arr) - 1].get('arguments', {}) + ) + logger.info(f'Expected tool call {expected_call}') + actual_call = tool_parser.streamed_args_for_tool[len(tool_parser.prev_tool_call_arr) - 1] + logger.info(f'Actual tool call {actual_call}') + remaining_call = expected_call.replace(actual_call, '', 1) + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=len(tool_parser.prev_tool_call_arr) - 1, function=DeltaFunctionCall( + arguments=remaining_call + ).model_dump(exclude_none=True)) + ]) # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 1de479802224c..cb5199456d5d3 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -74,7 +74,11 @@ def find_all_indices(string, substring): class ToolParser: def __init__(self): - pass + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: @@ -147,6 +151,7 @@ def __init__(self): self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list def extract_tool_calls_streaming(self, previous_text: str, @@ -190,15 +195,28 @@ def extract_tool_calls_streaming(self, ) logger.info('parsing: %s', parsable_arr) tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + #print('parsed ', tool_call_arr) # case: we are starting a new tool in the array # -> array has nonzero length AND length has moved past cursor if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: + logger.info('Checking for completeness of previous tool before moving on to next tool') + # if we're moving on to a new call, first make sure we haven't missed anything due to JSON completions + diff: str | None = current_tool_call.get('arguments') + if diff: + diff = diff.replace(self.streamed_args_for_tool[self.current_tool_id], '') + logger.info(f'Found diff between tools: {diff}') + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) + ]) + # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False self.current_tool_initial_sent = False + self.streamed_args_for_tool.append('') logger.info('starting on new tool %d', self.current_tool_id) # case: update an existing tool @@ -211,9 +229,6 @@ def extract_tool_calls_streaming(self, logger.info('No tool call detected yet!') return None - # handle parsing - current_tool_call: Dict = tool_call_arr[self.current_tool_id] - # if the current tool initial data incl. the id, type=function and idx not sent, send that if not self.current_tool_initial_sent: logger.info('Sending InitialDeltaToolCall') @@ -259,19 +274,23 @@ def extract_tool_calls_streaming(self, arguments=arguments_delta ).model_dump(exclude_none=True)) ]) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) logger.info(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') - argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json ) + argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) logger.info(f'got arguments diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=argument_diff ).model_dump(exclude_none=True)) ]) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff else: + # try parsing it with regular JSON - if it works we're at the end, and we need to send the + # difference between tokens streamed so far and the valid JSON delta = None # check to see if the name is defined and has been sent. if so, stream the name - otherwise keep waiting @@ -344,6 +363,11 @@ def __init__(self): super().__init__() self.current_tool_count: int = 0 self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list def extract_tool_calls_streaming(self, previous_text: str, From d8f4487483fe3949b6c2ef05aef118cff91d2c0f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 03:51:27 -0500 Subject: [PATCH 050/222] feat: update openai client to showcase streaming --- ...penai_chat_completion_client_with_tools.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 283740a21174d..202422eb94242 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -64,6 +64,30 @@ print("Chat completion results:") print(chat_completion) +print('\n\n') + +tool_calls_stream = client.chat.completions.create( + messages=messages, + model=model, + tools=tools, + stream=True +) + +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) +arguments = '' +for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + if chunk.choices[0].delta.tool_calls[0].id: + print(f'streamed tool call id: {chunk.choices[0].delta.tool_calls[0].id}') + if chunk.choices[0].delta.tool_calls[0].function: + if chunk.choices[0].delta.tool_calls[0].function.name: + print(f'streamed tool call name: {chunk.choices[0].delta.tool_calls[0].function.name}') + if chunk.choices[0].delta.tool_calls[0].function.arguments: + arguments += chunk.choices[0].delta.tool_calls[0].function.arguments +print(f'streamed tool call arguments: {arguments}\n\n') + messages.append({ "role": "assistant", "tool_calls": chat_completion.choices[0].message.tool_calls @@ -95,8 +119,10 @@ def get_current_weather(city: str, state: str, unit: 'str'): messages=messages, model=model, tools=tools, + stream=False ) print(chat_completion_2) +print('\n\n') From cfa6d039970f2cf6c63783c1a87e93681deaf960 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 03:52:04 -0500 Subject: [PATCH 051/222] fix: finish reason & debug logging --- vllm/entrypoints/openai/serving_chat.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 047362697f02a..58a809f09d2e9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -535,7 +535,6 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" else: - logger.info(f'Sending FINISH message with delta {delta_message.model_dump()}') # check to make sure we haven't missed something on the last function call if ( delta_message.tool_calls[0].function.arguments == '' @@ -547,7 +546,7 @@ async def chat_completion_stream_generator( ) logger.info(f'Expected tool call {expected_call}') actual_call = tool_parser.streamed_args_for_tool[len(tool_parser.prev_tool_call_arr) - 1] - logger.info(f'Actual tool call {actual_call}') + logger.info(f'Actual tool call {actual_call}, correcting.') remaining_call = expected_call.replace(actual_call, '', 1) delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=len(tool_parser.prev_tool_call_arr) - 1, function=DeltaFunctionCall( @@ -560,7 +559,7 @@ async def chat_completion_stream_generator( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason=output.finish_reason if not len(tool_parser.prev_tool_call_arr) else 'tool_calls', stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, From b2cb8fb177b1bb9eedd712c156a109c1e5e0c547 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 10 Jul 2024 03:53:44 -0500 Subject: [PATCH 052/222] fix(docs): CLI argument description was bad --- vllm/entrypoints/openai/cli_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 5900650983777..3c76c133a9806 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -119,8 +119,7 @@ def make_arg_parser(): parser.add_argument("--enable-auto-tool-choice", action="store_true", - help='Enable auto tool choice for models that support it. ' - 'Requires specifying --tool-use-prompt-template.' + help='Enable auto tool choice for models that support it. Requires --tool-call-parser' ) parser.add_argument("--tool-call-parser", From 625584a2b3ad2f36d9b6a8b2df7f96dcf552334b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 13 Jul 2024 12:42:44 -0500 Subject: [PATCH 053/222] fix(parser): mistral tool parser issue that was giving me a stroke. Also, code cleanup & documentation --- vllm/entrypoints/openai/tool_parsers.py | 159 +++++++++++++++++------- 1 file changed, 117 insertions(+), 42 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index cb5199456d5d3..c31c5e7587fb6 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -14,6 +14,15 @@ def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, + to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap' + """ prefix = '' min_length = min(len(s1), len(s2)) for i in range(0, min_length): @@ -25,42 +34,61 @@ def find_common_prefix(s1: str, s2: str) -> str: def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of arguments is NOT important. + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ suffix = '' min_length = min(len(s1), len(s2)) for i in range(1, min_length + 1): - if s1[-i] == s2[-i]: + if s1[-i] == s2[-i] and not s1[-i].isalnum(): suffix = s1[-i] + suffix else: break return suffix -def extract_intermediate_diff(s1: str, s2: str) -> str: +def extract_intermediate_diff(curr: str, old: str) -> str: """ - Extract the difference in the middle between two strings that are KNOWN to have a common prefix and OPTIONALLY - also a common suffix + Given two strings, extract the difference in the middle between two strings that are known to have a common + prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, + to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS important - the new version of the + partially-parsed JSON must be the first argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') -> 'ple' + e.g. extract_intermediate_diff('{"name": "get_current_weather", "city": "D"}', '{"name": "get_current_weather"}' -> + '", "city": "D' """ - suffix = find_common_suffix(s1, s2) + suffix = find_common_suffix(curr, old) logger.info(f'Found suffix {suffix}') # prevent double-counting - s2_old = s2 - s2 = s2[::-1].replace(suffix[::-1], '', 1)[::-1] - logger.info(f'Updated search term s2 from {s2_old} to {s2}') - prefix = find_common_prefix(s1, s2) - diff = s1 + s2_old = old + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + logger.info(f'Updated search term s2 from {s2_old} to {old}') + prefix = find_common_prefix(curr, old) + diff = curr if len(suffix): logger.info(f'Nuking suffix {suffix}') diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] if len(prefix): logger.info(f'Nuking prefix {prefix}') - diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored return diff def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for tool call extraction + """ indices = [] index = -1 while True: @@ -72,16 +100,32 @@ def find_all_indices(string, substring): class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided properties and methods should be used in + derived classes. + """ def __init__(self): + # the tool call array derived from partial JSON parsing from the previous execution of the function self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 + # indicates whether the name of the tool call that is currently being parsed has been sent. I have only seen + # OpenAI send the entire tool call name in a single chunk, so we wait until it has finished parsing. self.current_tool_name_sent: bool = False + # indicates if the initial tool call chunk with index, tool call ID etc has been sent. happens BEFORE the name + # is sent. self.current_tool_initial_sent: bool = False + # array of the argument strings (one for each tool) that have been streamed to the client. self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from a complete model-generated string. + Used for non-streaming responses where we have the entire model response available before sending to the client. + Static because it's stateless. + """ raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') def extract_tool_calls_streaming(self, @@ -92,15 +136,40 @@ def extract_tool_calls_streaming(self, current_token_ids: List[int], delta_token_ids: List[int], ) -> DeltaMessage | None: + """ + Instance method that should be implemented for extracting tool calls from an incomplete response; for use when + handling tool calls and streaming. Has to be an instance method because it requires state - the current text/ + tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) + """ raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') class MistralToolParser(ToolParser): - bot_token: str = '[TOOL_CALLS]' - bot_token_id: int = 5 + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the examples/tool_chat_template_mistral.jinja + template. There are server IMPORTANT CAVEATS for this parser: + - The chat template is NOT official and does not work well if you try to get the model to call 2+ tools at once. + Stick to only one tool call per generation, as the chat template is not reliable with > 1 and the model + Will lose coherence. + - Mistral's tool call format, that this translates into an OpenAI format, uses SINGLE QUOTES which cannot be + parsed to JSON. To enable JSON parsing and serialization, we find-and-replace these with DOUBLE QUOTES. To + prevent tool call corruption / deserialization failure, ensure that your tool calls and in particular your + ARGUMENTS never contain single or double quotes except as JSON control characters. + + Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + # the bot_token is the token indicating tool call(s) follow. Tokens before this token will be parsed as content; and + # if not present, the entire response will be parsed as text content. + bot_token: str = '[TOOL_CALLS]' # string literal + bot_token_id: int = 5 # token ID thereof from the models' tokenizer @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires find-and-replacing single quotes with double + quotes for JSON parsing, make sure your tool call arguments don't ever include quotes! + """ # Get the tool call token from the tokenizer if MistralToolParser.bot_token not in model_output: @@ -136,9 +205,9 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ) except Exception as e: - # TODO discussion on how to best handle invalidly-generated tool calls logger.error("Error in extracting tool call from response: %s", e) print('ERROR', e) + # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -147,6 +216,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: def __init__(self): super().__init__() + + # initialize properties used for state when parsing tool calls in streaming mode self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False @@ -162,7 +233,7 @@ def extract_tool_calls_streaming(self, delta_token_ids: List[int], ) -> DeltaMessage | None: - # if the tool call token ID is not in the tokens generated so far, append output to contents + # if the tool call token is not in the tokens generated so far, append output to contents since it's not a tool if self.bot_token_id not in current_token_ids: return DeltaMessage(content=delta_text) @@ -173,44 +244,47 @@ def extract_tool_calls_streaming(self, if self.bot_token_id in delta_token_ids: logger.info('Found bot_token!') - # if it's the only token, return None, so we don't send a chat completion + # if it's the only token, return None, so we don't send a chat completion any don't send a control token if len(delta_token_ids) == 1: return None - - # for mistral, everything after the BOT token is tool call, not content. If there's content - # which I have yet to see, it would HAVE to come BEFORE the BOT token - - # flags for partial JSON parsing (lib uses bit mask) - # if the tool name has been sent then allow any incomplete field ELSE allow everything BUT strings - # to avoid sending the partial tool name incorrectly + # bit mask flags for partial JSON parsing. If the name hasn't been sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have seen) allows sending the entire tool/ + # function name at once. flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # note this is basically only the way to do this - just make sure your tool arguments will - # never be something containing an apostrophe - parsable_arr = (current_text - .replace(self.bot_token, '') # remove BOT token to get valid json - .replace('\'', '"') # replace mistral single quotes with double for JSON parsing - ) + # replace BOT token with empty string, and convert single quotes to double to allow parsing as JSON + # since mistral uses single quotes instead of double for tool calls + tool_call_message_portion = current_text.split(self.bot_token)[1] + parsable_arr = tool_call_message_portion.replace('\'', '"') + logger.info('parsing: %s', parsable_arr) + + # tool calls are generated in an array, so do partial JSON parsing on the entire array tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) + + # select as the current tool call the one we're on the state at + logger.info(f'Current tool call ID: {self.current_tool_id}') current_tool_call: Dict = tool_call_arr[self.current_tool_id] - #print('parsed ', tool_call_arr) + # print('parsed ', tool_call_arr) # case: we are starting a new tool in the array - # -> array has nonzero length AND length has moved past cursor + # -> array has nonzero length AND length has moved past curscor if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: logger.info('Checking for completeness of previous tool before moving on to next tool') + # if we're moving on to a new call, first make sure we haven't missed anything due to JSON completions - diff: str | None = current_tool_call.get('arguments') - if diff: - diff = diff.replace(self.streamed_args_for_tool[self.current_tool_id], '') - logger.info(f'Found diff between tools: {diff}') - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) - ]) + if self.current_tool_id >= 0: + diff: str | None = current_tool_call.get('arguments') + if diff: + diff = diff.replace(self.streamed_args_for_tool[self.current_tool_id], '') + logger.info(f'Found diff between tools: {diff}') + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) + ]) # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 @@ -219,12 +293,12 @@ def extract_tool_calls_streaming(self, self.streamed_args_for_tool.append('') logger.info('starting on new tool %d', self.current_tool_id) - # case: update an existing tool + # case: update an existing tool - this is handled below elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: # logger.info('update to tool %d', self.current_tool_id) pass - # if there is NOTHING in the array + # if there is NOTHING in the array, e.g. if only the open bracket was streamed yet else: logger.info('No tool call detected yet!') return None @@ -244,7 +318,8 @@ def extract_tool_calls_streaming(self, if function_name: logger.info(f'Sending DeltaToolCall with function name {function_name}!') delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) ]) self.current_tool_name_sent = True else: @@ -266,7 +341,7 @@ def extract_tool_calls_streaming(self, delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.info(f'Finding {new_text} in |{cur_arguments_json}') + logger.info(f'Finding {new_text} in |{cur_arguments_json}|') arguments_delta = cur_arguments_json[:cur_arguments_json.index(new_text) + len(new_text)] logger.info(f'First tokens in arguments received: {arguments_delta}') delta = DeltaMessage(tool_calls=[ From 7a6f6ac7a4ab57db32bfe8781f73bb60d42f5041 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 13 Jul 2024 12:43:25 -0500 Subject: [PATCH 054/222] chore: update examples & logging --- examples/openai_chat_completion_client_with_tools.py | 9 ++++++++- vllm/entrypoints/openai/serving_chat.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 202422eb94242..2d9fce716ff85 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -36,7 +36,8 @@ "description": "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"] } - } + }, + "required": ["city", "state", "unit"] } } }] @@ -88,6 +89,12 @@ arguments += chunk.choices[0].delta.tool_calls[0].function.arguments print(f'streamed tool call arguments: {arguments}\n\n') +for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + +print('\n\n') + messages.append({ "role": "assistant", "tool_calls": chat_completion.choices[0].message.tool_calls diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58a809f09d2e9..45b661c7645de 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -548,6 +548,7 @@ async def chat_completion_stream_generator( actual_call = tool_parser.streamed_args_for_tool[len(tool_parser.prev_tool_call_arr) - 1] logger.info(f'Actual tool call {actual_call}, correcting.') remaining_call = expected_call.replace(actual_call, '', 1) + logger.info(f'Remaining call: {remaining_call}') delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=len(tool_parser.prev_tool_call_arr) - 1, function=DeltaFunctionCall( arguments=remaining_call From 9eb24520b4b1d27a09d18eee1a5d2ff7b6c974d8 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 13 Jul 2024 13:26:29 -0500 Subject: [PATCH 055/222] fix: accidentally broke non-tool streaming earlier; this fixes it --- vllm/entrypoints/openai/serving_chat.py | 9 ++++++--- vllm/entrypoints/openai/tool_parsers.py | 22 ++++++++++++++++++---- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 45b661c7645de..aa347562dd3b2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -537,9 +537,12 @@ async def chat_completion_stream_generator( else: # check to make sure we haven't missed something on the last function call if ( - delta_message.tool_calls[0].function.arguments == '' - or delta_message.tool_calls[0].function.arguments - and (output.finish_reason == 'stop' or output.finish_reason == 'tool_calls') + delta_message.tool_calls + and ( + delta_message.tool_calls[0].function.arguments == '' + or delta_message.tool_calls[0].function.arguments + and (output.finish_reason == 'stop' or output.finish_reason == 'tool_calls') + ) ): expected_call = json.dumps( tool_parser.prev_tool_call_arr[len(tool_parser.prev_tool_call_arr) - 1].get('arguments', {}) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index c31c5e7587fb6..7c4ca876d82ed 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -36,6 +36,7 @@ def find_common_prefix(s1: str, s2: str) -> str: def find_common_suffix(s1: str, s2: str) -> str: """ Finds a common suffix shared between two strings, if there is one. Order of arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' """ @@ -380,8 +381,11 @@ def extract_tool_calls_streaming(self, class Hermes2ProToolParser(ToolParser): - tool_call_start: str = '' - tool_call_end: str = '' + tool_call_start_token: str = '' + tool_call_end_token: str = '' + tool_call_start_token_id: int = 128004 + tool_call_start_token_id: int = 128011 + # regex to match between and OR between and EOS (happens sometimes :)) tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) @@ -391,7 +395,7 @@ class Hermes2ProToolParser(ToolParser): def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing - if Hermes2ProToolParser.tool_call_start not in model_output: + if Hermes2ProToolParser.tool_call_start_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -444,6 +448,7 @@ def __init__(self): self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list + def extract_tool_calls_streaming(self, previous_text: str, current_text: str, @@ -452,4 +457,13 @@ def extract_tool_calls_streaming(self, current_token_ids: List[int], delta_token_ids: List[int] ) -> DeltaMessage: - raise NotImplementedError('Hermes2ProToolParser.extract_tool_calls_streaming has not been implemented!') + + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token not in current_token_ids: + logger.info(f'No tool call tokens found!') + return DeltaMessage(content=delta_text) + + else: + # TODO check if we are in the middle of a tool call OR if it has passed + + return DeltaMessage(content=delta_text) \ No newline at end of file From 08bd8d08e35c92fcc1662df4a126c1a8906012e9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 13 Jul 2024 14:13:02 -0500 Subject: [PATCH 056/222] fix: some stuff in the example, and some mistral stuff --- ...penai_chat_completion_client_with_tools.py | 26 ++++++++++++++----- vllm/entrypoints/openai/tool_parsers.py | 13 +++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 2d9fce716ff85..3d29f622e8028 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -53,7 +53,7 @@ }, { "role": "user", - "content": "Can you tell me what the weather will be in Dallas Texas?" + "content": "Can you tell me what the weather will be in Dallas and San Francisco?" } ] @@ -77,21 +77,33 @@ chunks = [] for chunk in tool_calls_stream: chunks.append(chunk) -arguments = '' + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + + +arguments = [] +tool_call_idx = -1 for chunk in chunks: + if chunk.choices[0].delta.tool_calls: + if chunk.choices[0].delta.tool_calls[0].index != tool_call_idx: + if tool_call_idx >= 0: + print(f'streamed tool call arguments: {arguments[tool_call_idx]}\n\n') + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append('') if chunk.choices[0].delta.tool_calls[0].id: print(f'streamed tool call id: {chunk.choices[0].delta.tool_calls[0].id}') if chunk.choices[0].delta.tool_calls[0].function: if chunk.choices[0].delta.tool_calls[0].function.name: print(f'streamed tool call name: {chunk.choices[0].delta.tool_calls[0].function.name}') if chunk.choices[0].delta.tool_calls[0].function.arguments: - arguments += chunk.choices[0].delta.tool_calls[0].function.arguments -print(f'streamed tool call arguments: {arguments}\n\n') + arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[0].function.arguments + +if len(arguments): + print(f'streamed tool call arguments: {arguments[-1]}') -for chunk in chunks: - if chunk.choices[0].delta.tool_calls: - print(chunk.choices[0].delta.tool_calls[0]) print('\n\n') diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 7c4ca876d82ed..ba960b4c9a829 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -280,19 +280,24 @@ def extract_tool_calls_streaming(self, if self.current_tool_id >= 0: diff: str | None = current_tool_call.get('arguments') if diff: - diff = diff.replace(self.streamed_args_for_tool[self.current_tool_id], '') + diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') logger.info(f'Found diff between tools: {diff}') - return DeltaMessage(tool_calls=[ + delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) ]) - + self.streamed_args_for_tool[self.current_tool_id] += diff + else: + delta = None + else: + delta = None # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') logger.info('starting on new tool %d', self.current_tool_id) + return delta # case: update an existing tool - this is handled below elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: @@ -465,5 +470,5 @@ def extract_tool_calls_streaming(self, else: # TODO check if we are in the middle of a tool call OR if it has passed - + return DeltaMessage(content=delta_text) \ No newline at end of file From 62b9ad49d3da0c3364e38531bde74f45e01f7f84 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 13 Jul 2024 15:03:08 -0500 Subject: [PATCH 057/222] feat: work on hermes tool parser --- ...penai_chat_completion_client_with_tools.py | 2 +- vllm/entrypoints/openai/tool_parsers.py | 30 +++++++++++++++++-- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 3d29f622e8028..085f6f5d838c0 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -53,7 +53,7 @@ }, { "role": "user", - "content": "Can you tell me what the weather will be in Dallas and San Francisco?" + "content": "Can you tell me what the temperate will be in Dallas and San Francisco, in fahrenheit?" } ] diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index ba960b4c9a829..3540f36386a6f 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -389,7 +389,7 @@ class Hermes2ProToolParser(ToolParser): tool_call_start_token: str = '' tool_call_end_token: str = '' tool_call_start_token_id: int = 128004 - tool_call_start_token_id: int = 128011 + tool_call_end_token_id: int = 128011 # regex to match between and OR between and EOS (happens sometimes :)) @@ -426,6 +426,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ) ) for function_call in raw_function_calls ] + + # TODO extract including the scratch pad into content content_match = Hermes2ProToolParser.scratch_pad_regex.search(model_output) content = content_match.group(1) if content_match else None return ExtractedToolCallInformation( @@ -436,7 +438,6 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: except Exception as e: logger.error("Error in extracting tool call from response %s", e) - # TODO discussion on how to best handle invalidly-generated tool calls return ExtractedToolCallInformation( tools_called=False, tool_calls=[], @@ -463,12 +464,35 @@ def extract_tool_calls_streaming(self, delta_token_ids: List[int] ) -> DeltaMessage: + logger.info(f'delta_text: {delta_text}') + logger.info(f'delta_token_ids: {delta_token_ids}') # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token not in current_token_ids: + if self.tool_call_start_token_id not in current_token_ids: logger.info(f'No tool call tokens found!') return DeltaMessage(content=delta_text) else: # TODO check if we are in the middle of a tool call OR if it has passed + prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) + + if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: + logger.info('Starting a new tool call!') + + elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: + logger.info('Working on an existing tool call!') + + + # TODO These are not working + elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: + logger.info('Closing the current tool call!') + + elif cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: + logger.info('Generating text content!') + + else: + logger.info('INVARIANT') return DeltaMessage(content=delta_text) \ No newline at end of file From 6e537870b336a814be28eb20531182fd2ec7b6b5 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 00:09:06 -0500 Subject: [PATCH 058/222] fix(serving_chat): issue with deep vs. shallow copy caused bug where current_token_ids == previous_token_ids; fixed it. --- vllm/entrypoints/openai/serving_chat.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index aa347562dd3b2..1ff9029039cb7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -389,7 +389,6 @@ async def chat_completion_stream_generator( assert request.n is not None previous_texts = [""] * request.n previous_num_tokens = [0] * request.n - previous_token_ids = [[]] * request.n finish_reason_sent = [False] * request.n tool_parser: ToolParser = self.tool_parser() @@ -498,7 +497,7 @@ async def chat_completion_stream_generator( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=previous_token_ids[i], + previous_token_ids=output.token_ids[:-1 * len(delta_token_ids)], current_token_ids=output.token_ids, delta_token_ids=delta_token_ids ) @@ -508,7 +507,6 @@ async def chat_completion_stream_generator( # handle setting the previous values for the next iteration previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - previous_token_ids[i] = output.token_ids # if the message delta is None (e.g. because it was a "control token" for tool calls, then # get the next token without streaming a chunk From 26b97dcd3630fc97ecae41713cbf8f0de282465b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 01:06:31 -0500 Subject: [PATCH 059/222] feat: change ordering in hermes chat template so that function name is generated before args so that function name can be streamed first --- examples/tool_chat_template_hermes_2_pro.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tool_chat_template_hermes_2_pro.jinja b/examples/tool_chat_template_hermes_2_pro.jinja index 61192c02866a0..205f1fc55f7d8 100644 --- a/examples/tool_chat_template_hermes_2_pro.jinja +++ b/examples/tool_chat_template_hermes_2_pro.jinja @@ -79,7 +79,7 @@ " }} {{- " " }} -{{- '{"arguments": , "name": } +{{- '{"name": , "arguments": } ' }} {{- '<|im_end|>' }} {%- for message in messages %} From bfd10396b6a66b89988a8fe9e0f75145f6cc38b3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 02:28:38 -0500 Subject: [PATCH 060/222] feat(tool_parsers): hermes 2 pro streaming parser --- vllm/entrypoints/openai/tool_parsers.py | 174 ++++++++++++++++++++---- 1 file changed, 150 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 3540f36386a6f..dc807a8d1fafa 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,5 +1,5 @@ from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, \ - ExtractedToolCallInformation, DeltaToolCall, InitialDeltaToolCall, DeltaFunctionCall + ExtractedToolCallInformation, DeltaToolCall, InitialDeltaToolCall, DeltaFunctionCall, DeltaMessage from vllm.logger import init_logger from typing import List, Dict from transformers import (AutoTokenizer, PreTrainedTokenizer, @@ -163,7 +163,7 @@ class MistralToolParser(ToolParser): # the bot_token is the token indicating tool call(s) follow. Tokens before this token will be parsed as content; and # if not present, the entire response will be parsed as text content. bot_token: str = '[TOOL_CALLS]' # string literal - bot_token_id: int = 5 # token ID thereof from the models' tokenizer + bot_token_id: int = 5 # token ID thereof from the models' tokenizer @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: @@ -381,7 +381,7 @@ def extract_tool_calls_streaming(self, except Exception as e: logger.error(f'Error trying to handle streaming tool call: {e}') - logger.info('skipping returning a chunk here - maybe we just need more?') + logger.info(f'Skipping chunk as a result of tool streaming extraction error') return None @@ -391,7 +391,6 @@ class Hermes2ProToolParser(ToolParser): tool_call_start_token_id: int = 128004 tool_call_end_token_id: int = 128011 - # regex to match between and OR between and EOS (happens sometimes :)) tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) scratch_pad_regex = re.compile(r'(.*?)', re.DOTALL) @@ -446,7 +445,6 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: def __init__(self): super().__init__() - self.current_tool_count: int = 0 self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 @@ -454,7 +452,6 @@ def __init__(self): self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list - def extract_tool_calls_streaming(self, previous_text: str, current_text: str, @@ -462,7 +459,7 @@ def extract_tool_calls_streaming(self, previous_token_ids: List[int], current_token_ids: List[int], delta_token_ids: List[int] - ) -> DeltaMessage: + ) -> DeltaMessage | None: logger.info(f'delta_text: {delta_text}') logger.info(f'delta_token_ids: {delta_token_ids}') @@ -472,27 +469,156 @@ def extract_tool_calls_streaming(self, return DeltaMessage(content=delta_text) else: - # TODO check if we are in the middle of a tool call OR if it has passed - prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) + try: - if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: - logger.info('Starting a new tool call!') + # figure out where we are in the parsing by counting tool call start & end tags + prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) - elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: - logger.info('Working on an existing tool call!') + # a cheap case - we're generating text, NOT tool calls. + if cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: + logger.info('Generating text content! skipping tool parsing.') + return DeltaMessage(content=delta_text) + # most of the time, we're going in here - we need to do partial JSON parsing and build stuff. + else: + # flags for partial JSON parting. exported constants from "Allow" are handled via BIT MASK + # generally, we don't allow sending an incomplete function name. so we don't allow + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + + # if a new tool call is being started. unusual since normally the first "cheap case" will be hit. + if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + text_portion = None + else: + tool_call_portion = None + text_portion = None + delta = None - # TODO These are not working - elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: - logger.info('Closing the current tool call!') + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append('') + logger.info(f'Starting on a new tool {self.current_tool_id}') + + # if an existing tool call is being updated - the most common case! + elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + text_portion = None + + # if the current tool call is being closed + elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: + logger.info('Closing the current tool call!') + diff = self.prev_tool_call_arr[self.current_tool_id].get('arguments') + if diff: + diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') + logger.info(f'Finishing tool and found diff that wasn\'t streamed yet: {diff}') + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( + arguments=diff + ).model_dump(exclude_none=True)) + ]) - elif cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: - logger.info('Generating text content!') + else: + logger.error('INVARIANT - invalid state trying to parse tool calls (wtf?)') + delta = None + return delta + + logger.info(f'Tool call portion: {tool_call_portion}') + current_tool_call = partial_json_parser.loads(tool_call_portion, flags) if tool_call_portion else None + logger.info(f'Parsed tool call {current_tool_call}') + + # make sure to send the initial message first if we haven't already - with the tool ID + if not self.current_tool_initial_sent: + logger.info('Sending InitialDeltaToolCall') + self.current_tool_initial_sent = True + return DeltaMessage( + tool_calls=[ + InitialDeltaToolCall(index=self.current_tool_id).model_dump(exclude_none=True) + ] + ) - else: - logger.info('INVARIANT') + # after that, make sure we send the function name before any arguments + elif not self.current_tool_name_sent: + function_name: str | None = current_tool_call.get('name') + if function_name: + logger.info(f'Sending DeltaToolCall with function name {function_name}!') + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True) + )]) + else: + return None + else: + # if there is no tool calls + if tool_call_portion is None: + # if there's text but not tool calls, send that - otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) if text_portion is not None else None + # now, the nitty-gritty of tool calls + else: + # now we have the portion to parse as tool call. + if text_portion is not None: + logger.info(f'Also, will send text portion {text_portion}') + + logger.info(f'Trying to parse current tool call with ID {self.current_tool_id}') + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + logger.info('Pushed dummy value into tool call arr') + # main logic for tool parsing here + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') # previous arguments for this tool + cur_arguments = current_tool_call.get('arguments') # arguments, if any, in current dict + + logger.info(f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}') + if not cur_arguments and not prev_arguments: + logger.info(f'Skipping text {delta_text} - no arguments!') + delta = None + elif not cur_arguments and prev_arguments: + logger.error('INVARIANT - impossible to have arguments reset mid-call') + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.info(f'Finding {delta_text} in {cur_arguments_json}') + arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] + logger.info(f'First tokens in arguments received: {arguments_delta}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.info(f"Searching for diff between \n{cur_args_json}\n{prev_args_json}") + argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) + logger.info(f'Got argument diff: {argument_diff}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff + else: + delta = None + + # handle saving the state for the current tool into the "prev" list for use in diffing for + # the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + # TODO REPLACE ME WITH TOOL CALL + #delta = DeltaMessage(content=delta_text) + return delta - return DeltaMessage(content=delta_text) \ No newline at end of file + except Exception as e: + logger.error(f'Error trying to handle streaming tool call: {e}') + logger.info(f'Skipping chunk as a result of tool streaming extraction error') + return None # do not stream a delta. skip this token ID. From 2ccb893aacb4d8986048a450c4b4ca9a3ea55638 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 20:21:18 -0500 Subject: [PATCH 061/222] fix(docs): some type issue that the doc CI check did not like --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b6d532bf5e423..597f7e1afd155 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -688,7 +688,7 @@ class ExtractedToolCallInformation(BaseModel): # content - per OpenAI spec, content AND tool calls can be returned ALTHOUGH THIS IS VERY RARE # But some models will do this intentionally - content: str | None + content: Union[str, None] class ChatMessage(OpenAIBaseModel): From 45962f9766ecf7102f4ab33ea18b375c51d5da8a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 20:32:37 -0500 Subject: [PATCH 062/222] fix(docs): some type issue that the doc CI check did not like --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 597f7e1afd155..6809572051745 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -688,7 +688,7 @@ class ExtractedToolCallInformation(BaseModel): # content - per OpenAI spec, content AND tool calls can be returned ALTHOUGH THIS IS VERY RARE # But some models will do this intentionally - content: Union[str, None] + content: Optional[str | None] class ChatMessage(OpenAIBaseModel): From de27e655b011596f3b727b9df182eca535ac54f2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 20:59:56 -0500 Subject: [PATCH 063/222] fix(types): try Optional[str] = None --- vllm/entrypoints/openai/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6809572051745..8ef45e8e07391 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -688,12 +688,12 @@ class ExtractedToolCallInformation(BaseModel): # content - per OpenAI spec, content AND tool calls can be returned ALTHOUGH THIS IS VERY RARE # But some models will do this intentionally - content: Optional[str | None] + content: Optional[str] = None class ChatMessage(OpenAIBaseModel): role: str - content: Optional[str | None] + content: Optional[str] = None tool_calls: List[ToolCall] = Field(default_factory=list) From 862078b31340f58ab050171de9a05dca07aa22c8 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 17 Jul 2024 22:41:20 -0500 Subject: [PATCH 064/222] fix: refactor tool chat template and add docs --- .../serving/openai_compatible_server.md | 59 +++++++++++++++++-- ....jinja => tool_chat_template_hermes.jinja} | 0 vllm/entrypoints/openai/serving_chat.py | 35 ++++++++--- 3 files changed, 80 insertions(+), 14 deletions(-) rename examples/{tool_chat_template_hermes_2_pro.jinja => tool_chat_template_hermes.jinja} (100%) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 6248d84683753..33decd5a91140 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -112,14 +112,61 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :func: make_arg_parser :prog: -m vllm.entrypoints.openai.api_server ``` +## Tool Calling in the Chat Completion API +### Named Function Calling +vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. -## Tool calling in the chat completion API -vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. -To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter. - -It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.** +It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. -Please refer to the OpenAI API reference documentation for more information. + +### Automatic Function Calling +_This feature is in **beta**. It has limited model support, is not guaranteed to be stable, and does not have +well-defined failure modes._ As such, it must be explicitly enabled when desired. + +To enable this feature, you must set the following flags: +* `--enable-api-tools` -- **mandatory** for Auto tool choice. tells vLLM that you want to enable tool templating and extraction. +* `--enable-auto-toolchoice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its' own tool scalls when it +deems appropriate. +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls.This argument can be set to `tool_use` if your model has a tool use chat +template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here]() +* `--tool-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. + +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! + +#### Hermes Models +Supported models in this series: +* `NousResearch/Hermes-2-Pro-Llama-3-8B` +* `NousResearch/Hermes-2-Theta-Llama-3-70B` +* `NousResearch/Hermes-2-Pro-Llama-3-70B` +* `NousResearch/Hermes-2-Theta-Llama-3-8B` +* `NousResearch/Hermes-2-Pro-Mistral-7B` + +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. It is recommended to use the Hermes 2 **Pro** models. + +Recommended flags: `--tool-parser hermes --chat-template examples/tool_chat_template_hermes.jinja` + +#### Mistral Models +Supported models: +* `mistralai/Mistral-7B-Instruct-v0.3` + +There are several known issues with tool-calling in Mistral models: +* Attempting to generate > 1 tool call at a time usually results in a parser failure, since the model generates the calls +in an unpredictable format due to the aforementioned chat template issue. +* Mistral function-calling / tool use generates calls with _single_ quotes `'` instead of double quotes `"`. As a +result, tool call generations can't be handled as JSON by the parser automatically without using `eval`, which would +present security issues for vLLM users. As a result, to support Mistral tool calls, we find-and-replace single-quotes +with double-quotes in mistral-generated tool calls. Therefore, **it is important to ensure that your tool call +arguments do not contain single quotes.** Escaped double quotes may be handled properly, but otherwise you should +expect parser issues. + +Recommended flags: `--tool-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` diff --git a/examples/tool_chat_template_hermes_2_pro.jinja b/examples/tool_chat_template_hermes.jinja similarity index 100% rename from examples/tool_chat_template_hermes_2_pro.jinja rename to examples/tool_chat_template_hermes.jinja diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ff9029039cb7..238a593d1bcbb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -83,6 +83,7 @@ def __init__(self, lora_modules=lora_modules) self.response_role = response_role + self.use_tool_use_model_template = False self._load_chat_template(chat_template) # set up tool use @@ -102,6 +103,12 @@ def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer if chat_template is not None: + + # if `tool_use` is supplied; try to use it as the tool use name for the template + if chat_template == 'tool_use': + self.use_tool_use_model_template = True + logger.info('The "tool_use" chat template was specified. This will be loaded from tokenizer_config.json. Expect runtime errors if this is not present!') + return try: with open(chat_template, "r") as f: tokenizer.chat_template = f.read() @@ -121,8 +128,10 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.info("Using supplied chat template:\n%s", tokenizer.chat_template) elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) + if self.enable_auto_tools: + logger.info('Trying to find a tool_use chat template in tokenizer_config.json! Will use default template otherwise.') + else: + logger.info("Using default chat template:\n%s",tokenizer.chat_template) else: logger.warning( "No chat template provided. Chat API will not work.") @@ -275,12 +284,22 @@ async def create_chat_completion( if self.enable_auto_tools and request.tools: tools = [tool.model_dump() for tool in request.tools] - prompt = self.tokenizer.apply_chat_template( - conversation=conversation, - tokenize=False, - add_generation_prompt=request.add_generation_prompt, - tools=tools - ) + # default. use the pre-loaded template not the "tool_use" template option introduced by huggingface + if not self.use_tool_use_model_template: + prompt = self.tokenizer.apply_chat_template( + conversation=conversation, + tokenize=False, + add_generation_prompt=request.add_generation_prompt, + tools=tools + ) + else: + prompt = self.tokenizer.apply_chat_template( + conversation=conversation, + tokenize=False, + add_generation_prompt=request.add_generation_prompt, + tools=tools, + chat_template='tool_use' + ) except Exception as e: logger.error("Error in applying chat template from request: %s", e) From f0569272ca2591760e87cf58825947df6a8e932a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 31 Jul 2024 23:45:31 -0500 Subject: [PATCH 065/222] chore: annotate that parallel_tool_calls will be ignored --- vllm/entrypoints/openai/protocol.py | 1 + vllm/entrypoints/openai/serving_chat.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8ef45e8e07391..66d31b82c39ac 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -161,6 +161,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ChatCompletionNamedToolChoiceParam ] ] = "none" + parallel_tool_calls: Optional[bool] = False # NOTE this will be ignored by VLLM as the behavior is determined by the model user: Optional[str] = None # doc: begin-chat-completion-sampling-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 238a593d1bcbb..df86acee528a0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -88,6 +88,9 @@ def __init__(self, # set up tool use self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info('"Auto" tool choice has been enabled please note that while the parallel_tool_calls client ' + 'option is preset for compatibility reasons, it will be ignored.') if self.enable_auto_tools and not tool_parser: raise TypeError('Error: --enable-auto-tool-choice requires --tool-choice-parser') From 9560591291da958051f848176bbf397c3f3fbbb2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 31 Jul 2024 23:51:47 -0500 Subject: [PATCH 066/222] fix: handle un-handled "theoretically unreachable" case because such things are important --- vllm/entrypoints/openai/serving_chat.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index df86acee528a0..2672e78221c97 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -663,7 +663,9 @@ async def chat_completion_full_generator( else: logprobs = None + # by default, tools are not used. tools_called = False + # if the reqeust uses tools and specified a tool choice if request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: @@ -695,6 +697,11 @@ async def chat_completion_full_generator( # FOR NOW make it a chat message; we will have to detect the type to make it later. message = ChatMessage(role=role, content=output.text) + # undetermined case that is still important to handle + else: + logger.error('Error in chat_completion_full_generator - cannot determine if tools shouuld ' + 'be extracted. Returning a standard chat completion.') + message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, From b2ceb71f0d9e4c2c6eea748196c8f3c15153d40f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 31 Jul 2024 23:55:27 -0500 Subject: [PATCH 067/222] fix: hermes tool call template to omit tool-use system prompt if tools are not specified --- examples/tool_chat_template_hermes.jinja | 134 ++++++++++++----------- 1 file changed, 68 insertions(+), 66 deletions(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index 205f1fc55f7d8..d807c12dbba16 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -1,87 +1,89 @@ {%- macro json_to_python_type(json_spec) %} -{%- set basic_type_map = { + {%- set basic_type_map = { "string": "str", "number": "float", "integer": "int", "boolean": "bool" } %} -{%- if basic_type_map[json_spec.type] is defined %} - {{- basic_type_map[json_spec.type] }} -{%- elif json_spec.type == "array" %} - {{- "list[" + json_to_python_type(json_spec|items) + "]"}} -{%- elif json_spec.type == "object" %} - {%- if json_spec.additionalProperties is defined %} - {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} {%- else %} - {{- "dict" }} + {{- "Any" }} {%- endif %} -{%- elif json_spec.type is iterable %} - {{- "Union[" }} - {%- for t in json_spec.type %} - {{- json_to_python_type({"type": t}) }} - {%- if not loop.last %} - {{- "," }} - {%- endif %} - {%- endfor %} - {{- "]" }} -{%- else %} - {{- "Any" }} -{%- endif %} {%- endmacro %} {{- bos_token }} -{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} -{%- for tool in tools %} - {%- if tool.function is defined %} - {%- set tool = tool.function %} - {%- endif %} - {{- '{"type": "function", "function": ' }} - {{- '{"name": ' + tool.name + '", ' }} - {{- '"description": "' + tool.name + '(' }} - {%- for param_name, param_fields in tool.parameters.properties|items %} - {{- param_name + ": " + json_to_python_type(param_fields) }} - {%- if not loop.last %} - {{- ", " }} +{%- if tools is iterable and tools | length > 0 %} + {{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} {%- endif %} - {%- endfor %} - {{- ")" }} - {%- if tool.return is defined %} - {{- " -> " + json_to_python_type(tool.return) }} - {%- endif %} - {{- " - " + tool.description + "\n\n" }} - {%- for param_name, param_fields in tool.parameters.properties|items %} - {%- if loop.first %} - {{- " Args:\n" }} + {{- '{"type": "function", "function": ' }} + {{- '{"name": ' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} {%- endif %} - {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} {%- endfor %} - {%- if tool.return is defined and tool.return.description is defined %} - {{- "\n Returns:\n " + tool.return.description }} - {%- endif %} - {{- '"' }} - {{- ', "parameters": ' }} - {%- if tool.parameters.properties | length == 0 %} - {{- "{}" }} - {%- else %} - {{- tool.parameters|tojson }} - {%- endif %} - {{- "}" }} - {%- if not loop.last %} - {{- "\n" }} - {%- endif %} -{%- endfor %} -{{- " " }} -{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} + {{- " " }} + {{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} ' }} -{{- "For each function call return a json object with function name and arguments within XML tags as follows: + {{- "For each function call return a json object with function name and arguments within XML tags as follows: " }} -{{- " + {{- " " }} -{{- '{"name": , "arguments": } + {{- '{"name": , "arguments": } ' }} -{{- '<|im_end|>' }} + {{- '<|im_end|>' }} +{%- endif %} {%- for message in messages %} {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} @@ -95,7 +97,7 @@ {%- if tool_call.arguments is defined %} {{- '"arguments": ' }} {{- tool_call.arguments|tojson }} - {{- ', '}} + {{- ', ' }} {%- endif %} {{- '"name": "' }} {{- tool_call.name }} @@ -112,7 +114,7 @@ {{- message.name }} {{- '", "content": ' }} {{- message.content|tojson + '}' }} - {{- '\n <|im_end|>\n' }} + {{- '\n <|im_end|>\n' }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} From 558461a74b7b4792e471e28438828b27eedb73a4 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 00:19:57 -0500 Subject: [PATCH 068/222] fix: implement access to tool call token IDs via tokenizer vocab in tool call parser non-static methods --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/tool_parsers.py | 35 +++++++++++++++++++------ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2672e78221c97..ff221e1b09957 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -413,7 +413,7 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * request.n finish_reason_sent = [False] * request.n - tool_parser: ToolParser = self.tool_parser() + tool_parser: ToolParser = self.tool_parser(self.tokenizer) try: async for res in result_generator: # We need to do it here, because if there are exceptions in diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index dc807a8d1fafa..298f2539e39d7 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,7 +1,7 @@ from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, \ ExtractedToolCallInformation, DeltaToolCall, InitialDeltaToolCall, DeltaFunctionCall, DeltaMessage from vllm.logger import init_logger -from typing import List, Dict +from typing import List, Dict, Optional from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) import json @@ -106,7 +106,10 @@ class ToolParser: derived classes. """ - def __init__(self): + def __init__( + self, + tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer]=None + ): # the tool call array derived from partial JSON parsing from the previous execution of the function self.prev_tool_call_arr: List[Dict] = [] # the index of the tool call that is currently being parsed @@ -120,6 +123,8 @@ def __init__(self): # array of the argument strings (one for each tool) that have been streamed to the client. self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list + self.model_tokenizer = tokenizer + @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ @@ -215,8 +220,12 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=model_output ) - def __init__(self): - super().__init__() + def __init__( + self, + tokenizer: Optional[ + PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer] = None + ): + super().__init__(tokenizer) # initialize properties used for state when parsing tool calls in streaming mode self.prev_tool_call_arr: List[Dict] = [] @@ -388,8 +397,7 @@ def extract_tool_calls_streaming(self, class Hermes2ProToolParser(ToolParser): tool_call_start_token: str = '' tool_call_end_token: str = '' - tool_call_start_token_id: int = 128004 - tool_call_end_token_id: int = 128011 + # regex to match between and OR between and EOS (happens sometimes :)) tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) @@ -443,8 +451,12 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=model_output ) - def __init__(self): - super().__init__() + def __init__( + self, + tokenizer: Optional[ + PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer] = None + ): + super().__init__(tokenizer) self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 @@ -452,6 +464,13 @@ def __init__(self): self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list + if not self.model_tokenizer: + raise ValueError('The model tokenizer must be passed to the ToolParser constructor during construction.') + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[''] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[''] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError('Hermes 2 Pro Tool parser could not locate tool call start/end tokens in the tokenizer!') + def extract_tool_calls_streaming(self, previous_text: str, current_text: str, From f7f15faede1dcccbbed709efbddaf4dc659c84e0 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 00:23:54 -0500 Subject: [PATCH 069/222] fix: hermes tool parser does not extract non-tool-call content the same way it does in streaming as in non-streaming --- vllm/entrypoints/openai/tool_parsers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 298f2539e39d7..f3c54e4f18732 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -434,9 +434,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ) for function_call in raw_function_calls ] - # TODO extract including the scratch pad into content - content_match = Hermes2ProToolParser.scratch_pad_regex.search(model_output) - content = content_match.group(1) if content_match else None + content = model_output[:model_output.find(Hermes2ProToolParser.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, From 4356ec4dde3f544020a674f1f65a91710feb319f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 00:27:30 -0500 Subject: [PATCH 070/222] fix: grab the name properly from chat completions and fall back to empty string since its optional --- vllm/entrypoints/openai/serving_chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ff221e1b09957..aefdf41089a84 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -228,7 +228,8 @@ def _parse_chat_message_content( content = message.get("content") tool_call_id = message.get('tool_call_id') tool_calls = message.get('tool_calls') - name = message.get('tool_calls') + # "name" is optional now per OAI spec, so clients using models that need it, should make sure to pass it! + name = message.get('name', '') # invariant if content is None and tool_calls is None: From 0abeb536060db1ce0b47653517b77c257cdff3d9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 00:28:48 -0500 Subject: [PATCH 071/222] fix: type annotation --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index aefdf41089a84..97619bf620a53 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -54,7 +54,7 @@ class ConversationMessage(TypedDict): role: str content: Optional[str] # optional IFF tool_calls is specified tool_call_id: Optional[str] - name: str | None + name: Optional[str] tool_calls: Optional[List] From 9c0e6d8a67f7d05b19463e249596c29652ac1805 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 00:58:09 -0500 Subject: [PATCH 072/222] fix: mistral tool extraction when dealing with poor precision --- vllm/entrypoints/openai/tool_parsers.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index f3c54e4f18732..ea3724453265f 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -106,6 +106,7 @@ class ToolParser: derived classes. """ + def __init__( self, tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer]=None @@ -169,6 +170,8 @@ class MistralToolParser(ToolParser): # if not present, the entire response will be parsed as text content. bot_token: str = '[TOOL_CALLS]' # string literal bot_token_id: int = 5 # token ID thereof from the models' tokenizer + tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) + @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: @@ -177,6 +180,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: quotes for JSON parsing, make sure your tool call arguments don't ever include quotes! """ + logger.info('Trying to extract mistral tool calls from the following:') + logger.info(model_output) # Get the tool call token from the tokenizer if MistralToolParser.bot_token not in model_output: return ExtractedToolCallInformation( @@ -186,10 +191,14 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ) else: try: - # extract the token so we hopefully have a JSON string - raw_tool_call = (model_output - .replace(MistralToolParser.bot_token, '') # remove BOT token - .replace("'", '"')) # ... hack to parse broken mistral JSON + + # this will throw an exception if we can't find the tool call properly + raw_tool_call = MistralToolParser.tool_call_regex.findall( + model_output + .replace(MistralToolParser.bot_token, '') # remove BOT token + .replace("'", '"') # replace string quotes + )[0] + # load the JSON, and then use it to build the Function and Tool Call function_call_arr = json.loads(raw_tool_call) tool_calls: List[ToolCall] = [ From 45ecd68f463da3c437109c2433980156a580427b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 01:26:01 -0500 Subject: [PATCH 073/222] doc: indicate that temperature should be set to 0 when doing mistral tool calls --- docs/source/serving/openai_compatible_server.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 33decd5a91140..0270abec36b3e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -161,7 +161,9 @@ Supported models: There are several known issues with tool-calling in Mistral models: * Attempting to generate > 1 tool call at a time usually results in a parser failure, since the model generates the calls -in an unpredictable format due to the aforementioned chat template issue. +in an unpredictable format due to the aforementioned chat template issue. **This can be mitigated by setting the +`temperature` to `0` in the OpenAI-style API call** - do this, and tool calls (including parallel ones) are **far** more +consistent * Mistral function-calling / tool use generates calls with _single_ quotes `'` instead of double quotes `"`. As a result, tool call generations can't be handled as JSON by the parser automatically without using `eval`, which would present security issues for vLLM users. As a result, to support Mistral tool calls, we find-and-replace single-quotes From f63efe8ea6db963203d7204401df11b4ef15b12b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 23:36:37 -0500 Subject: [PATCH 074/222] fix: log levels --- vllm/entrypoints/openai/serving_chat.py | 2 - vllm/entrypoints/openai/tool_parsers.py | 81 +++++++++++-------------- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d186b4677cd3d..36d7f4537cee6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -105,8 +105,6 @@ async def create_chat_completion( for the API specification. This API mimics the OpenAI ChatCompletion API. - NOTE: Currently we do not support the following feature: - - function_call (Users should implement this by themselves) """ error_check_ret = await self._check_model(request) if error_check_ret is not None: diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index ea3724453265f..8735a69a9269f 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -67,20 +67,16 @@ def extract_intermediate_diff(curr: str, old: str) -> str: '", "city": "D' """ suffix = find_common_suffix(curr, old) - logger.info(f'Found suffix {suffix}') # prevent double-counting s2_old = old old = old[::-1].replace(suffix[::-1], '', 1)[::-1] - logger.info(f'Updated search term s2 from {s2_old} to {old}') prefix = find_common_prefix(curr, old) diff = curr if len(suffix): - logger.info(f'Nuking suffix {suffix}') diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] if len(prefix): - logger.info(f'Nuking prefix {prefix}') diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored return diff @@ -180,8 +176,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: quotes for JSON parsing, make sure your tool call arguments don't ever include quotes! """ - logger.info('Trying to extract mistral tool calls from the following:') - logger.info(model_output) + logger.debug('Trying to extract mistral tool calls from the following:') + logger.debug(model_output) # Get the tool call token from the tokenizer if MistralToolParser.bot_token not in model_output: return ExtractedToolCallInformation( @@ -261,8 +257,6 @@ def extract_tool_calls_streaming(self, # handle if we detected the BOT token which means the start of tool calling if self.bot_token_id in delta_token_ids: - logger.info('Found bot_token!') - # if it's the only token, return None, so we don't send a chat completion any don't send a control token if len(delta_token_ids) == 1: return None @@ -278,28 +272,23 @@ def extract_tool_calls_streaming(self, tool_call_message_portion = current_text.split(self.bot_token)[1] parsable_arr = tool_call_message_portion.replace('\'', '"') - logger.info('parsing: %s', parsable_arr) + logger.debug('parsing: %s', parsable_arr) # tool calls are generated in an array, so do partial JSON parsing on the entire array tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) # select as the current tool call the one we're on the state at - logger.info(f'Current tool call ID: {self.current_tool_id}') current_tool_call: Dict = tool_call_arr[self.current_tool_id] - # print('parsed ', tool_call_arr) - # case: we are starting a new tool in the array # -> array has nonzero length AND length has moved past curscor if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: - logger.info('Checking for completeness of previous tool before moving on to next tool') # if we're moving on to a new call, first make sure we haven't missed anything due to JSON completions if self.current_tool_id >= 0: diff: str | None = current_tool_call.get('arguments') if diff: diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') - logger.info(f'Found diff between tools: {diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) @@ -314,22 +303,22 @@ def extract_tool_calls_streaming(self, self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') - logger.info('starting on new tool %d', self.current_tool_id) + logger.debug('starting on new tool %d', self.current_tool_id) return delta # case: update an existing tool - this is handled below elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: - # logger.info('update to tool %d', self.current_tool_id) + # logger.debug('update to tool %d', self.current_tool_id) pass # if there is NOTHING in the array, e.g. if only the open bracket was streamed yet else: - logger.info('No tool call detected yet!') + logger.debug('No tool call detected yet!') return None # if the current tool initial data incl. the id, type=function and idx not sent, send that if not self.current_tool_initial_sent: - logger.info('Sending InitialDeltaToolCall') + logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True delta = DeltaMessage( tool_calls=[ @@ -340,7 +329,7 @@ def extract_tool_calls_streaming(self, elif not self.current_tool_name_sent: function_name = current_tool_call.get('name') if function_name: - logger.info(f'Sending DeltaToolCall with function name {function_name}!') + logger.debug(f'Sending DeltaToolCall with function name {function_name}!') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) @@ -358,16 +347,16 @@ def extract_tool_calls_streaming(self, new_text = delta_text.replace('\'', '"') if not cur_arguments and not prev_arguments: - logger.info(f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet') + logger.debug(f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet') delta = None elif not cur_arguments and prev_arguments: logger.error('INVARIANT - impossible to have arguments reset mid-arguments') delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.info(f'Finding {new_text} in |{cur_arguments_json}|') + logger.debug(f'Finding {new_text} in |{cur_arguments_json}|') arguments_delta = cur_arguments_json[:cur_arguments_json.index(new_text) + len(new_text)] - logger.info(f'First tokens in arguments received: {arguments_delta}') + logger.debug(f'First tokens in arguments received: {arguments_delta}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=arguments_delta @@ -378,9 +367,9 @@ def extract_tool_calls_streaming(self, elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.info(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') + logger.debug(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) - logger.info(f'got arguments diff: {argument_diff}') + logger.debug(f'got arguments diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=argument_diff @@ -399,7 +388,7 @@ def extract_tool_calls_streaming(self, except Exception as e: logger.error(f'Error trying to handle streaming tool call: {e}') - logger.info(f'Skipping chunk as a result of tool streaming extraction error') + logger.debug(f'Skipping chunk as a result of tool streaming extraction error') return None @@ -487,11 +476,11 @@ def extract_tool_calls_streaming(self, delta_token_ids: List[int] ) -> DeltaMessage | None: - logger.info(f'delta_text: {delta_text}') - logger.info(f'delta_token_ids: {delta_token_ids}') + logger.debug(f'delta_text: {delta_text}') + logger.debug(f'delta_token_ids: {delta_token_ids}') # check to see if we should be streaming a tool call - is there a if self.tool_call_start_token_id not in current_token_ids: - logger.info(f'No tool call tokens found!') + logger.debug(f'No tool call tokens found!') return DeltaMessage(content=delta_text) else: @@ -505,7 +494,7 @@ def extract_tool_calls_streaming(self, # a cheap case - we're generating text, NOT tool calls. if cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: - logger.info('Generating text content! skipping tool parsing.') + logger.debug('Generating text content! skipping tool parsing.') return DeltaMessage(content=delta_text) # most of the time, we're going in here - we need to do partial JSON parsing and build stuff. @@ -529,7 +518,7 @@ def extract_tool_calls_streaming(self, self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') - logger.info(f'Starting on a new tool {self.current_tool_id}') + logger.debug(f'Starting on a new tool {self.current_tool_id}') # if an existing tool call is being updated - the most common case! elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: @@ -538,11 +527,11 @@ def extract_tool_calls_streaming(self, # if the current tool call is being closed elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: - logger.info('Closing the current tool call!') + logger.debug('Closing the current tool call!') diff = self.prev_tool_call_arr[self.current_tool_id].get('arguments') if diff: diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') - logger.info(f'Finishing tool and found diff that wasn\'t streamed yet: {diff}') + logger.debug(f'Finishing tool and found diff that wasn\'t streamed yet: {diff}') return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff @@ -554,13 +543,13 @@ def extract_tool_calls_streaming(self, delta = None return delta - logger.info(f'Tool call portion: {tool_call_portion}') + logger.debug(f'Tool call portion: {tool_call_portion}') current_tool_call = partial_json_parser.loads(tool_call_portion, flags) if tool_call_portion else None - logger.info(f'Parsed tool call {current_tool_call}') + logger.debug(f'Parsed tool call {current_tool_call}') # make sure to send the initial message first if we haven't already - with the tool ID if not self.current_tool_initial_sent: - logger.info('Sending InitialDeltaToolCall') + logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True return DeltaMessage( tool_calls=[ @@ -572,7 +561,7 @@ def extract_tool_calls_streaming(self, elif not self.current_tool_name_sent: function_name: str | None = current_tool_call.get('name') if function_name: - logger.info(f'Sending DeltaToolCall with function name {function_name}!') + logger.debug(f'Sending DeltaToolCall with function name {function_name}!') self.current_tool_name_sent = True return DeltaMessage(tool_calls=[DeltaToolCall( index=self.current_tool_id, @@ -589,28 +578,28 @@ def extract_tool_calls_streaming(self, else: # now we have the portion to parse as tool call. if text_portion is not None: - logger.info(f'Also, will send text portion {text_portion}') + logger.debug(f'Also, will send text portion {text_portion}') - logger.info(f'Trying to parse current tool call with ID {self.current_tool_id}') + logger.debug(f'Trying to parse current tool call with ID {self.current_tool_id}') if len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - logger.info('Pushed dummy value into tool call arr') + logger.debug('Pushed dummy value into tool call arr') # main logic for tool parsing here prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') # previous arguments for this tool cur_arguments = current_tool_call.get('arguments') # arguments, if any, in current dict - logger.info(f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}') + logger.debug(f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}') if not cur_arguments and not prev_arguments: - logger.info(f'Skipping text {delta_text} - no arguments!') + logger.debug(f'Skipping text {delta_text} - no arguments!') delta = None elif not cur_arguments and prev_arguments: logger.error('INVARIANT - impossible to have arguments reset mid-call') delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.info(f'Finding {delta_text} in {cur_arguments_json}') + logger.debug(f'Finding {delta_text} in {cur_arguments_json}') arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] - logger.info(f'First tokens in arguments received: {arguments_delta}') + logger.debug(f'First tokens in arguments received: {arguments_delta}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=arguments_delta @@ -621,9 +610,9 @@ def extract_tool_calls_streaming(self, elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.info(f"Searching for diff between \n{cur_args_json}\n{prev_args_json}") + logger.debug(f"Searching for diff between \n{cur_args_json}\n{prev_args_json}") argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) - logger.info(f'Got argument diff: {argument_diff}') + logger.debug(f'Got argument diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=argument_diff @@ -646,5 +635,5 @@ def extract_tool_calls_streaming(self, except Exception as e: logger.error(f'Error trying to handle streaming tool call: {e}') - logger.info(f'Skipping chunk as a result of tool streaming extraction error') + logger.debug(f'Skipping chunk as a result of tool streaming extraction error') return None # do not stream a delta. skip this token ID. From dc27bec4e0d35018bb7547342e3ce34fe760a253 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 1 Aug 2024 23:37:16 -0500 Subject: [PATCH 075/222] fix: more logging changes --- vllm/entrypoints/openai/serving_chat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 36d7f4537cee6..f22c2cc5a1c7c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -453,11 +453,8 @@ async def chat_completion_stream_generator( expected_call = json.dumps( tool_parser.prev_tool_call_arr[len(tool_parser.prev_tool_call_arr) - 1].get('arguments', {}) ) - logger.info(f'Expected tool call {expected_call}') actual_call = tool_parser.streamed_args_for_tool[len(tool_parser.prev_tool_call_arr) - 1] - logger.info(f'Actual tool call {actual_call}, correcting.') remaining_call = expected_call.replace(actual_call, '', 1) - logger.info(f'Remaining call: {remaining_call}') delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=len(tool_parser.prev_tool_call_arr) - 1, function=DeltaFunctionCall( arguments=remaining_call From 85515c04f1dfe5786e2bbf50bfccfb7405add994 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 2 Aug 2024 00:10:21 -0500 Subject: [PATCH 076/222] fix(ci): formatting --- ...penai_chat_completion_client_with_tools.py | 92 ++-- vllm/entrypoints/chat_utils.py | 39 +- vllm/entrypoints/openai/api_server.py | 5 +- vllm/entrypoints/openai/cli_args.py | 29 +- vllm/entrypoints/openai/protocol.py | 59 +- vllm/entrypoints/openai/serving_chat.py | 164 +++--- vllm/entrypoints/openai/tool_parsers.py | 504 +++++++++++------- .../guided_decoding/outlines_decoding.py | 7 +- 8 files changed, 511 insertions(+), 388 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 085f6f5d838c0..417e15bce118c 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -14,7 +14,6 @@ models = client.models.list() model = models.data[0].id - tools = [{ "type": "function", "function": { @@ -24,12 +23,16 @@ "type": "object", "properties": { "city": { - "type": "string", - "description": "The city to find the weather for, e.g. 'San Francisco'" + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" }, "state": { - "type": "string", - "description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'" + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'" }, "unit": { "type": "string", @@ -42,37 +45,31 @@ } }] -messages = [ - { - "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": "user", - "content": "Can you tell me what the temperate will be in Dallas and San Francisco, in fahrenheit?" - } - ] +messages = [{ + "role": "user", + "content": "Hi! How are you doing today?" +}, { + "role": "assistant", + "content": "I'm doing well! How can I help you?" +}, { + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas and San Francisco, in fahrenheit?" +}] -chat_completion = client.chat.completions.create( - messages=messages, - model=model, - tools=tools -) +chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools) print("Chat completion results:") print(chat_completion) print('\n\n') -tool_calls_stream = client.chat.completions.create( - messages=messages, - model=model, - tools=tools, - stream=True -) +tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) chunks = [] for chunk in tool_calls_stream: @@ -82,7 +79,6 @@ else: print(chunk.choices[0].delta) - arguments = [] tool_call_idx = -1 for chunk in chunks: @@ -90,21 +86,27 @@ if chunk.choices[0].delta.tool_calls: if chunk.choices[0].delta.tool_calls[0].index != tool_call_idx: if tool_call_idx >= 0: - print(f'streamed tool call arguments: {arguments[tool_call_idx]}\n\n') + print( + f'streamed tool call arguments: {arguments[tool_call_idx]}\n\n' + ) tool_call_idx = chunk.choices[0].delta.tool_calls[0].index arguments.append('') if chunk.choices[0].delta.tool_calls[0].id: - print(f'streamed tool call id: {chunk.choices[0].delta.tool_calls[0].id}') + print( + f'streamed tool call id: {chunk.choices[0].delta.tool_calls[0].id}' + ) if chunk.choices[0].delta.tool_calls[0].function: if chunk.choices[0].delta.tool_calls[0].function.name: - print(f'streamed tool call name: {chunk.choices[0].delta.tool_calls[0].function.name}') + print( + f'streamed tool call name: {chunk.choices[0].delta.tool_calls[0].function.name}' + ) if chunk.choices[0].delta.tool_calls[0].function.arguments: - arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[0].function.arguments + arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[ + 0].function.arguments if len(arguments): print(f'streamed tool call arguments: {arguments[-1]}') - print('\n\n') messages.append({ @@ -112,13 +114,13 @@ "tool_calls": chat_completion.choices[0].message.tool_calls }) + # Now, simulate a tool call def get_current_weather(city: str, state: str, unit: 'str'): return "The weather in Dallas, Texas is 85 degrees fahrenheit. It is partly cloudly, with highs in the 90's." -available_tools = { - "get_current_weather": get_current_weather -} + +available_tools = {"get_current_weather": get_current_weather} completion_tool_calls = chat_completion.choices[0].message.tool_calls for call in completion_tool_calls: @@ -134,14 +136,10 @@ def get_current_weather(city: str, state: str, unit: 'str'): }) print("Sending new chat with messages", messages) -chat_completion_2 = client.chat.completions.create( - messages=messages, - model=model, - tools=tools, - stream=False -) +chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) print(chat_completion_2) print('\n\n') - - diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3c80304c2bc83..6be74ea7ae512 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -33,7 +33,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, -CustomChatCompletionContentPartParam] + CustomChatCompletionContentPartParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -53,11 +53,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, -CustomChatCompletionMessageParam] + CustomChatCompletionMessageParam] @final # So that it should be compatible with Dict[str, str] -class ConversationMessage(TypedDict): +class ConversationMessage(TypedDict, total=False): role: str content: Optional[str] tool_call_id: Optional[str] @@ -126,10 +126,10 @@ def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: def _parse_chat_message_content_parts( - role: str, - parts: Iterable[ChatCompletionContentPartParam], - model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + role: str, + parts: Iterable[ChatCompletionContentPartParam], + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -178,15 +178,17 @@ def _parse_chat_message_content_parts( def parse_chat_message_content( - message: ChatCompletionMessageParam, - model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + message: ChatCompletionMessageParam, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") tool_call_id = message.get('content') tool_calls = message.get('tool_calls') - name = message.get('name', '') # no longer used by OpenAI, was formerly. used for tool calls by some models still + name = message.get( + 'name', '' + ) # no longer used by OpenAI, was formerly. used for tool calls by some models still # empty case if content is None and tool_calls is None: @@ -194,13 +196,22 @@ def parse_chat_message_content( # special case - assistant message where tool calls are provided. if role == 'assistant' and tool_calls is not None and len(tool_calls): - messages = [ConversationMessage(role=role, content=content, tool_calls=list(tool_calls))] + messages = [ + ConversationMessage(role=role, + content=content, + tool_calls=list(tool_calls)) + ] return ChatMessageParseResult(messages=messages, mm_futures=[]) # special case - tool call result message elif role == 'tool': - messages = [ConversationMessage(role=role, name=name, content=content, tool_call_id=tool_call_id)] - return ChatMessageParseResult(messages=messages,mm_futures=[]) + messages = [ + ConversationMessage(role=role, + name=name, + content=content, + tool_call_id=tool_call_id) + ] + return ChatMessageParseResult(messages=messages, mm_futures=[]) # other cases - normal assistant response, user message or system message elif isinstance(content, str): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bfdc74044184e..afe80e042e763 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -147,7 +147,7 @@ async def create_chat_completion(request: ChatCompletionRequest, else: return StreamingResponse(content=generator, - media_type="text/event-stream") + media_type="text/event-stream") # handle non-streaming requests else: @@ -271,8 +271,7 @@ async def build_server( chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser - ) + tool_parser=args.tool_call_parser) openai_serving_completion = OpenAIServingCompletion( engine, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 474599031164f..3dbc209180193 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -138,20 +138,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument("--enable-api-tools", action="store_true", help="Enable OpenAI-like tools API " - "(only function calls are currently supported)") + "(only function calls are currently supported)") - parser.add_argument("--enable-auto-tool-choice", - action="store_true", - help='Enable auto tool choice for models that support it. Requires --tool-call-parser' - ) - - parser.add_argument("--tool-call-parser", - type=str, - choices=['mistral', 'hermes'], - help='Select the tool call parser depending on the model that you\'re using. ' - 'This is used to parse the model-generated tool call into OpenAI API format. ' - 'Required for --enable-auto-tool-choice. Options: "mistral", "hermes"' - ) + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + help= + 'Enable auto tool choice for models that support it. Requires --tool-call-parser' + ) + + parser.add_argument( + "--tool-call-parser", + type=str, + choices=['mistral', 'hermes'], + help= + 'Select the tool call parser depending on the model that you\'re using. ' + 'This is used to parse the model-generated tool call into OpenAI API format. ' + 'Required for --enable-auto-tool-choice. Options: "mistral", "hermes"') parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index e8528ae15ef78..36e65b5ea3b35 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union, Type import openai import torch @@ -47,10 +47,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tool_calls: Optional[List[dict]] - -ChatCompletionMessageParam = Union[ - openai.types.chat.ChatCompletionMessageParam, - CustomChatCompletionMessageParam] +ChatCompletionMessageParam = Type[ + Union[openai.types.chat.ChatCompletionMessageParam, + CustomChatCompletionMessageParam]] class OpenAIBaseModel(BaseModel): @@ -155,16 +154,10 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[ - Union[ - Union[ - Literal["none"], - Literal["auto"] - ], - ChatCompletionNamedToolChoiceParam - ] - ] = "none" - parallel_tool_calls: Optional[bool] = False # NOTE this will be ignored by VLLM as the behavior is determined by the model + tool_choice: Optional[Union[Union[Literal["none"], Literal["auto"]], + ChatCompletionNamedToolChoiceParam]] = "none" + parallel_tool_calls: Optional[ + bool] = False # NOTE this will be ignored by VLLM as the behavior is determined by the model user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -329,8 +322,9 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if (guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none" and data["tool_choice"] != "auto"): + if (guide_count > 1 and "tool_choice" in data + and data["tool_choice"] != "none" + and data["tool_choice"] != "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @@ -343,13 +337,15 @@ def check_tool_usage(cls, data): # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: - raise ValueError("When using `tool_choice`, `tools` must be set.") - + raise ValueError( + "When using `tool_choice`, `tools` must be set.") # make sure that tool choice is either a named tool OR that it's set to "auto" - if data["tool_choice"] != "auto" and not isinstance(data["tool_choice"], dict): + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): raise ValueError( - "`tool_choice` must either be a named tool or \"auto\". `tool_choice=\"none\" is not supported.") + "`tool_choice` must either be a named tool or \"auto\". `tool_choice=\"none\" is not supported." + ) # ensure that if "tool_choice" is specified as an object, it matches a valid tool if isinstance(data["tool_choice"], dict): @@ -357,13 +353,15 @@ def check_tool_usage(cls, data): specified_function = data["tool_choice"]["function"] if not specified_function: return ValueError( - 'Incorrectly formatted `tool_choice`. Should be like ' + + 'Incorrectly formatted `tool_choice`. Should be like ' + + '`{"type": "function", "function": {"name": "my_function"}}`' ) specified_function_name = specified_function["name"] if not specified_function_name: return ValueError( - 'Incorrectly formatted `tool_choice`. Should be like ' + + 'Incorrectly formatted `tool_choice`. Should be like ' + + '`{"type": "function", "function": {"name": "my_function"}}`' ) for tool in data['tools']: @@ -371,7 +369,9 @@ def check_tool_usage(cls, data): valid_tool = True break if not valid_tool: - return ValueError("The tool specified in `tool_choice` does not match any of the specified `tools`") + return ValueError( + "The tool specified in `tool_choice` does not match any of the specified `tools`" + ) # per OpenAI spec, make sure that tool_choice defaults to "auto" when tools are specified elif "tools" in data and "tool_choice" not in data: @@ -651,10 +651,7 @@ class FunctionCall(OpenAIBaseModel): arguments: str def to_dict(self): - return { - "name": self.name, - "arguments": self.arguments - } + return {"name": self.name, "arguments": self.arguments} class ToolCall(OpenAIBaseModel): @@ -724,8 +721,10 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = Field(default='stop') # per OpenAI spec this is the default - stop_reason: Optional[Union[int, str]] = None # ??? Not part of the OpenAI spec + finish_reason: Optional[str] = Field( + default='stop') # per OpenAI spec this is the default + stop_reason: Optional[Union[int, + str]] = None # ??? Not part of the OpenAI spec class ChatCompletionResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f22c2cc5a1c7c..73a8747fa4e44 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,7 @@ import time import json -from dataclasses import dataclass, field -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, - List, Optional, Type, final) +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, + Optional, Type) from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, Optional, Union, Sequence as GenericSequence) from typing import Union @@ -14,8 +13,6 @@ from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content, - ChatCompletionMessageParam, - ChatMessageParseResult, ConversationMessage) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( @@ -41,31 +38,27 @@ from jinja2 import Environment, FileSystemLoader, select_autoescape -env = Environment( - loader=FileSystemLoader('./'), - autoescape=select_autoescape() -) +env = Environment(loader=FileSystemLoader('./'), + autoescape=select_autoescape()) logger = init_logger(__name__) class OpenAIServingChat(OpenAIServing): - def __init__( - self, - engine: AsyncLLMEngine, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - enable_auto_tools: Optional[bool] = False, - tool_parser: Optional[str] = None - ): + def __init__(self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: Optional[bool] = False, + tool_parser: Optional[str] = None): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, @@ -81,11 +74,15 @@ def __init__( # set up tool use self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: - logger.info('"Auto" tool choice has been enabled please note that while the parallel_tool_calls client ' - 'option is preset for compatibility reasons, it will be ignored.') + logger.info( + '"Auto" tool choice has been enabled please note that while the parallel_tool_calls client ' + 'option is preset for compatibility reasons, it will be ignored.' + ) if self.enable_auto_tools and not tool_parser: - raise TypeError('Error: --enable-auto-tool-choice requires --tool-choice-parser') + raise TypeError( + 'Error: --enable-auto-tool-choice requires --tool-choice-parser' + ) if tool_parser == 'mistral': self.tool_parser: Type[ToolParser] = MistralToolParser @@ -95,10 +92,11 @@ def __init__( raise ValueError(f'Invalid tool parser value {tool_parser}!') async def create_chat_completion( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]: + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None + ) -> Union[ErrorResponse, AsyncGenerator[str, None], + ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -215,22 +213,12 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, - result_generator, - request_id, - conversation, - tokenizer - ) + request, result_generator, request_id, conversation, tokenizer) else: try: generator = await self.chat_completion_full_generator( - request, - raw_request, - result_generator, - request_id, - conversation, - tokenizer - ) + request, raw_request, result_generator, request_id, + conversation, tokenizer) assert isinstance(generator, ChatCompletionResponse) return generator @@ -307,7 +295,7 @@ async def chat_completion_stream_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( - "role") == role: + "role") == role: last_msg_content = conversation[-1]["content"] if last_msg_content: @@ -355,7 +343,7 @@ async def chat_completion_stream_generator( delta_token_ids = output.token_ids[previous_num_tokens[i]:] out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None + previous_num_tokens[i]:] if output.logprobs else None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, ( @@ -369,10 +357,8 @@ async def chat_completion_stream_generator( else: logprobs = None - delta_text = output.text[len(previous_texts[i]):] - # handle streaming deltas for tools with tool_choice if request.tool_choice and type( request.tool_choice @@ -384,7 +370,8 @@ async def chat_completion_stream_generator( ]) # handle streaming deltas for tools with tool_choice - elif (request.tools and (request.tool_choice is None or request.tool_choice == 'auto') + elif (request.tools and (request.tool_choice is None + or request.tool_choice == 'auto') and self.enable_auto_tools): print('output token IDs', output.token_ids) @@ -393,10 +380,10 @@ async def chat_completion_stream_generator( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=output.token_ids[:-1 * len(delta_token_ids)], + previous_token_ids=output. + token_ids[:-1 * len(delta_token_ids)], current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids - ) + delta_token_ids=delta_token_ids) else: delta_message = DeltaMessage(content=delta_text) @@ -442,31 +429,37 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" else: # check to make sure we haven't missed something on the last function call - if ( - delta_message.tool_calls - and ( - delta_message.tool_calls[0].function.arguments == '' - or delta_message.tool_calls[0].function.arguments - and (output.finish_reason == 'stop' or output.finish_reason == 'tool_calls') - ) - ): + if (delta_message.tool_calls and + (delta_message.tool_calls[0].function.arguments + == '' or + delta_message.tool_calls[0].function.arguments and + (output.finish_reason == 'stop' + or output.finish_reason == 'tool_calls'))): expected_call = json.dumps( - tool_parser.prev_tool_call_arr[len(tool_parser.prev_tool_call_arr) - 1].get('arguments', {}) - ) - actual_call = tool_parser.streamed_args_for_tool[len(tool_parser.prev_tool_call_arr) - 1] - remaining_call = expected_call.replace(actual_call, '', 1) + tool_parser.prev_tool_call_arr[ + len(tool_parser.prev_tool_call_arr) - + 1].get('arguments', {})) + actual_call = tool_parser.streamed_args_for_tool[ + len(tool_parser.prev_tool_call_arr) - 1] + remaining_call = expected_call.replace( + actual_call, '', 1) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(index=len(tool_parser.prev_tool_call_arr) - 1, function=DeltaFunctionCall( - arguments=remaining_call - ).model_dump(exclude_none=True)) - ]) + DeltaToolCall( + index=len(tool_parser.prev_tool_call_arr) - + 1, + function=DeltaFunctionCall( + arguments=remaining_call).model_dump( + exclude_none=True)) + ]) # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason if not len(tool_parser.prev_tool_call_arr) else 'tool_calls', + finish_reason=output.finish_reason + if not len(tool_parser.prev_tool_call_arr) else + 'tool_calls', stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -582,30 +575,36 @@ async def chat_completion_full_generator( message = ChatMessage(role=role, content=output.text) # handle when there are tools and tool choice is auto - elif request.tools and (request.tool_choice == "auto" or request.tool_choice is None) and self.enable_auto_tools: + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools: - tool_call_info = self.tool_parser.extract_tool_calls(output.text) + tool_call_info = self.tool_parser.extract_tool_calls( + output.text) tools_called = tool_call_info.tools_called if tool_call_info.tools_called: - message = ChatMessage(role=role, content=tool_call_info.content, tool_calls=tool_call_info.tool_calls) + message = ChatMessage(role=role, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) else: - # FOR NOW make it a chat message; we will have to detect the type to make it later. + # FOR NOW make it a chat message; we will have to detect the type to make it later. message = ChatMessage(role=role, content=output.text) # undetermined case that is still important to handle else: - logger.error('Error in chat_completion_full_generator - cannot determine if tools should ' - 'be extracted. Returning a standard chat completion.') + logger.error( + 'Error in chat_completion_full_generator - cannot determine if tools should ' + 'be extracted. Returning a standard chat completion.') message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason='tool_calls' if tools_called else output.stop_reason if output.stop_reason else 'stop', - stop_reason=output.stop_reason - ) + finish_reason='tool_calls' if tools_called else + output.stop_reason if output.stop_reason else 'stop', + stop_reason=output.stop_reason) choices.append(choice_data) if request.echo: @@ -615,7 +614,7 @@ async def chat_completion_full_generator( last_msg_content = conversation[-1]["content"] for choice in choices: - full_message = last_msg_content + choice.message.content + full_message = last_msg_content + choice.message.content if choice.message.content else '' choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) @@ -637,11 +636,8 @@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, - logprobs: Dict[int, Logprob], - top_logprobs: Optional[int], - tokenizer: PreTrainedTokenizer - ) -> List[ChatCompletionLogProb]: + self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb(token=(token := self._get_decoded_token( p[1], diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 8735a69a9269f..28d939fc93d92 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,7 +1,10 @@ -from vllm.entrypoints.openai.protocol import ToolCall, FunctionCall, ChatCompletionResponse, \ - ExtractedToolCallInformation, DeltaToolCall, InitialDeltaToolCall, DeltaFunctionCall, DeltaMessage +from vllm.entrypoints.openai.protocol import (ToolCall, FunctionCall, + ExtractedToolCallInformation, + DeltaToolCall, + InitialDeltaToolCall, + DeltaFunctionCall, DeltaMessage) from vllm.logger import init_logger -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) import json @@ -77,7 +80,9 @@ def extract_intermediate_diff(curr: str, old: str) -> str: diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] if len(prefix): - diff = diff.replace(prefix, '', 1) # replace the prefix only once in case it's mirrored + diff = diff.replace( + prefix, '', + 1) # replace the prefix only once in case it's mirrored return diff @@ -102,11 +107,11 @@ class ToolParser: derived classes. """ - - def __init__( - self, - tokenizer: Optional[PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer]=None - ): + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): # the tool call array derived from partial JSON parsing from the previous execution of the function self.prev_tool_call_arr: List[Dict] = [] # the index of the tool call that is currently being parsed @@ -118,7 +123,8 @@ def __init__( # is sent. self.current_tool_initial_sent: bool = False # array of the argument strings (one for each tool) that have been streamed to the client. - self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list self.model_tokenizer = tokenizer @@ -129,22 +135,26 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: Used for non-streaming responses where we have the entire model response available before sending to the client. Static because it's stateless. """ - raise NotImplementedError('AbstractToolParser.extract_tool_calls has not been implemented!') - - def extract_tool_calls_streaming(self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], - ) -> DeltaMessage | None: + raise NotImplementedError( + 'AbstractToolParser.extract_tool_calls has not been implemented!') + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting tool calls from an incomplete response; for use when handling tool calls and streaming. Has to be an instance method because it requires state - the current text/ tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) """ - raise NotImplementedError('AbstractToolParser.extract_tool_calls_streaming has not been implemented!') + raise NotImplementedError( + 'AbstractToolParser.extract_tool_calls_streaming has not been implemented!' + ) class MistralToolParser(ToolParser): @@ -168,7 +178,6 @@ class MistralToolParser(ToolParser): bot_token_id: int = 5 # token ID thereof from the models' tokenizer tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) - @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ @@ -176,22 +185,21 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: quotes for JSON parsing, make sure your tool call arguments don't ever include quotes! """ - logger.debug('Trying to extract mistral tool calls from the following:') + logger.debug( + 'Trying to extract mistral tool calls from the following:') logger.debug(model_output) # Get the tool call token from the tokenizer if MistralToolParser.bot_token not in model_output: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output - ) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) else: try: # this will throw an exception if we can't find the tool call properly raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output - .replace(MistralToolParser.bot_token, '') # remove BOT token + model_output.replace(MistralToolParser.bot_token, + '') # remove BOT token .replace("'", '"') # replace string quotes )[0] @@ -203,33 +211,30 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: function=FunctionCall( name=raw_function_call['name'], # function call args are JSON but as a string - arguments=json.dumps(raw_function_call['arguments']) - ) - ) + arguments=json.dumps( + raw_function_call['arguments']))) for raw_function_call in function_call_arr ] content = model_output.split(MistralToolParser.bot_token)[0] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if len(content) > 0 else None - ) + content=content if len(content) > 0 else None) except Exception as e: - logger.error("Error in extracting tool call from response: %s", e) + logger.error("Error in extracting tool call from response: %s", + e) print('ERROR', e) # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output - ) - - def __init__( - self, - tokenizer: Optional[ - PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer] = None - ): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) # initialize properties used for state when parsing tool calls in streaming mode @@ -237,16 +242,18 @@ def __init__( self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list - - def extract_tool_calls_streaming(self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], - ) -> DeltaMessage | None: + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append output to contents since it's not a tool if self.bot_token_id not in current_token_ids: @@ -269,31 +276,41 @@ def extract_tool_calls_streaming(self, # replace BOT token with empty string, and convert single quotes to double to allow parsing as JSON # since mistral uses single quotes instead of double for tool calls - tool_call_message_portion = current_text.split(self.bot_token)[1] + tool_call_message_portion = current_text.split( + self.bot_token)[1] parsable_arr = tool_call_message_portion.replace('\'', '"') logger.debug('parsing: %s', parsable_arr) # tool calls are generated in an array, so do partial JSON parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, flags) + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) # select as the current tool call the one we're on the state at current_tool_call: Dict = tool_call_arr[self.current_tool_id] # case: we are starting a new tool in the array # -> array has nonzero length AND length has moved past curscor - if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1: + if len(tool_call_arr) > 0 and len( + tool_call_arr) > self.current_tool_id + 1: - # if we're moving on to a new call, first make sure we haven't missed anything due to JSON completions + # if we're moving on to a new call, first make sure we haven't missed anything in the previous + # one that was auto-generated due to JSON completions, but wasn't streamed to the client yet. if self.current_tool_id >= 0: - diff: str | None = current_tool_call.get('arguments') + diff: Union[str, + None] = current_tool_call.get('arguments') if diff: - diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], '') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True)) + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) ]) - self.streamed_args_for_tool[self.current_tool_id] += diff + self.streamed_args_for_tool[ + self.current_tool_id] += diff else: delta = None else: @@ -303,11 +320,14 @@ def extract_tool_calls_streaming(self, self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') - logger.debug('starting on new tool %d', self.current_tool_id) + logger.debug('starting on new tool %d', + self.current_tool_id) return delta # case: update an existing tool - this is handled below - elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: + elif len( + tool_call_arr + ) - 1 == self.current_tool_id and self.current_tool_id >= 0: # logger.debug('update to tool %d', self.current_tool_id) pass @@ -320,19 +340,24 @@ def extract_tool_calls_streaming(self, if not self.current_tool_initial_sent: logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True - delta = DeltaMessage( - tool_calls=[ - InitialDeltaToolCall(index=self.current_tool_id).model_dump(exclude_none=True)] - ) + delta = DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) # if the current tool name hasn't been sent, send if available - otherwise no chunks elif not self.current_tool_name_sent: function_name = current_tool_call.get('name') if function_name: - logger.debug(f'Sending DeltaToolCall with function name {function_name}!') + logger.debug( + f'Sending DeltaToolCall with function name {function_name}!' + ) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True)) + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) ]) self.current_tool_name_sent = True else: @@ -341,41 +366,59 @@ def extract_tool_calls_streaming(self, # now we know we're on the same tool call and we're streaming arguments else: - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') cur_arguments = current_tool_call.get('arguments') new_text = delta_text.replace('\'', '"') if not cur_arguments and not prev_arguments: - logger.debug(f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet') + logger.debug( + f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet' + ) delta = None elif not cur_arguments and prev_arguments: - logger.error('INVARIANT - impossible to have arguments reset mid-arguments') + logger.error( + 'INVARIANT - impossible to have arguments reset mid-arguments' + ) delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug(f'Finding {new_text} in |{cur_arguments_json}|') - arguments_delta = cur_arguments_json[:cur_arguments_json.index(new_text) + len(new_text)] - logger.debug(f'First tokens in arguments received: {arguments_delta}') + logger.debug( + f'Finding {new_text} in |{cur_arguments_json}|') + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index(new_text) + + len(new_text)] + logger.debug( + f'First tokens in arguments received: {arguments_delta}' + ) delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump(exclude_none=True)) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) ]) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug(f'Searching for diff between \n{cur_args_json}\n{prev_args_json}') - argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) + logger.debug( + f'Searching for diff between \n{cur_args_json}\n{prev_args_json}' + ) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) logger.debug(f'got arguments diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=argument_diff - ).model_dump(exclude_none=True)) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) ]) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff else: # try parsing it with regular JSON - if it works we're at the end, and we need to send the # difference between tokens streamed so far and the valid JSON @@ -387,8 +430,11 @@ def extract_tool_calls_streaming(self, return delta except Exception as e: - logger.error(f'Error trying to handle streaming tool call: {e}') - logger.debug(f'Skipping chunk as a result of tool streaming extraction error') + logger.error( + f'Error trying to handle streaming tool call: {e}') + logger.debug( + 'Skipping chunk as a result of tool streaming extraction error' + ) return None @@ -396,105 +442,114 @@ class Hermes2ProToolParser(ToolParser): tool_call_start_token: str = '' tool_call_end_token: str = '' - # regex to match between and OR between and EOS (happens sometimes :)) - tool_call_regex = re.compile(r'(.*?)|(.*)', re.DOTALL) - scratch_pad_regex = re.compile(r'(.*?)', re.DOTALL) + tool_call_regex = re.compile( + r'(.*?)|(.*)', re.DOTALL) + scratch_pad_regex = re.compile(r'(.*?)', + re.DOTALL) @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if Hermes2ProToolParser.tool_call_start_token not in model_output: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output - ) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) else: try: - # there are two possible captures - between tags, or between a tag and end-of-string so the result of findall - # is an array of tuples where one is a function call and the other is None - function_call_tuples = Hermes2ProToolParser.tool_call_regex.findall(model_output) + # there are two possible captures - between tags, or between a tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and the other is None + function_call_tuples = Hermes2ProToolParser.tool_call_regex.findall( + model_output) # load the JSON, and then use it to build the Function and Tool Call - raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] tool_calls = [ ToolCall( type='function', function=FunctionCall( name=function_call['name'], # function call args are JSON but as a string - arguments=json.dumps(function_call['arguments']) - ) - ) for function_call in raw_function_calls + arguments=json.dumps(function_call['arguments']))) + for function_call in raw_function_calls ] - content = model_output[:model_output.find(Hermes2ProToolParser.tool_call_start_token)] + content = model_output[:model_output.find( + Hermes2ProToolParser.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if content else None - ) + content=content if content else None) except Exception as e: - logger.error("Error in extracting tool call from response %s", e) - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output - ) - - def __init__( - self, - tokenizer: Optional[ - PreTrainedTokenizer | PreTrainedTokenizerFast | PreTrainedTokenizerFast | AutoTokenizer] = None - ): + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list if not self.model_tokenizer: - raise ValueError('The model tokenizer must be passed to the ToolParser constructor during construction.') - self.tool_call_start_token_id: int = self.model_tokenizer.vocab[''] - self.tool_call_end_token_id: int = self.model_tokenizer.vocab[''] + raise ValueError( + 'The model tokenizer must be passed to the ToolParser constructor during construction.' + ) + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + ''] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + ''] if not self.tool_call_start_token_id or not self.tool_call_end_token_id: - raise RuntimeError('Hermes 2 Pro Tool parser could not locate tool call start/end tokens in the tokenizer!') + raise RuntimeError( + 'Hermes 2 Pro Tool parser could not locate tool call start/end tokens in the tokenizer!' + ) - def extract_tool_calls_streaming(self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int] - ) -> DeltaMessage | None: + def extract_tool_calls_streaming( + self, previous_text: str, current_text: str, delta_text: str, + previous_token_ids: List[int], current_token_ids: List[int], + delta_token_ids: List[int]) -> Union[DeltaMessage, None]: logger.debug(f'delta_text: {delta_text}') logger.debug(f'delta_token_ids: {delta_token_ids}') # check to see if we should be streaming a tool call - is there a if self.tool_call_start_token_id not in current_token_ids: - logger.debug(f'No tool call tokens found!') + logger.debug('No tool call tokens found!') return DeltaMessage(content=delta_text) else: try: # figure out where we are in the parsing by counting tool call start & end tags - prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) # a cheap case - we're generating text, NOT tool calls. if cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: - logger.debug('Generating text content! skipping tool parsing.') + logger.debug( + 'Generating text content! skipping tool parsing.') return DeltaMessage(content=delta_text) # most of the time, we're going in here - we need to do partial JSON parsing and build stuff. @@ -506,7 +561,8 @@ def extract_tool_calls_streaming(self, # if a new tool call is being started. unusual since normally the first "cheap case" will be hit. if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: if len(delta_token_ids) > 1: - tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] text_portion = None else: tool_call_portion = None @@ -518,122 +574,182 @@ def extract_tool_calls_streaming(self, self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') - logger.debug(f'Starting on a new tool {self.current_tool_id}') + logger.debug( + f'Starting on a new tool {self.current_tool_id}') # if an existing tool call is being updated - the most common case! elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: - tool_call_portion = current_text.split(self.tool_call_start_token)[-1] + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] text_portion = None # if the current tool call is being closed elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: logger.debug('Closing the current tool call!') - diff = self.prev_tool_call_arr[self.current_tool_id].get('arguments') + diff = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') if diff: - diff = json.dumps(diff).replace(self.streamed_args_for_tool[self.current_tool_id], '') - logger.debug(f'Finishing tool and found diff that wasn\'t streamed yet: {diff}') + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], '') + logger.debug( + f'Finishing tool and found diff that wasn\'t streamed yet: {diff}' + ) return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=diff - ).model_dump(exclude_none=True)) - ]) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) else: - logger.error('INVARIANT - invalid state trying to parse tool calls (wtf?)') + logger.error( + 'INVARIANT - invalid state trying to parse tool calls (wtf?)' + ) delta = None return delta logger.debug(f'Tool call portion: {tool_call_portion}') - current_tool_call = partial_json_parser.loads(tool_call_portion, flags) if tool_call_portion else None + current_tool_call = partial_json_parser.loads( + tool_call_portion, + flags) if tool_call_portion else None logger.debug(f'Parsed tool call {current_tool_call}') # make sure to send the initial message first if we haven't already - with the tool ID if not self.current_tool_initial_sent: logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True - return DeltaMessage( - tool_calls=[ - InitialDeltaToolCall(index=self.current_tool_id).model_dump(exclude_none=True) - ] - ) + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) # after that, make sure we send the function name before any arguments elif not self.current_tool_name_sent: - function_name: str | None = current_tool_call.get('name') + function_name: Union[ + str, None] = current_tool_call.get('name') if function_name: - logger.debug(f'Sending DeltaToolCall with function name {function_name}!') + logger.debug( + f'Sending DeltaToolCall with function name {function_name}!' + ) self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True) - )]) + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name). + model_dump(exclude_none=True)) + ]) else: return None else: # if there is no tool calls if tool_call_portion is None: # if there's text but not tool calls, send that - otherwise None to skip chunk - delta = DeltaMessage(content=delta_text) if text_portion is not None else None + delta = DeltaMessage( + content=delta_text + ) if text_portion is not None else None # now, the nitty-gritty of tool calls else: # now we have the portion to parse as tool call. if text_portion is not None: - logger.debug(f'Also, will send text portion {text_portion}') - - logger.debug(f'Trying to parse current tool call with ID {self.current_tool_id}') - if len(self.prev_tool_call_arr) <= self.current_tool_id: + logger.debug( + f'Also, will send text portion {text_portion}' + ) + + logger.debug( + f'Trying to parse current tool call with ID {self.current_tool_id}' + ) + if len(self.prev_tool_call_arr + ) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - logger.debug('Pushed dummy value into tool call arr') + logger.debug( + 'Pushed dummy value into tool call arr') # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get('arguments') # previous arguments for this tool - cur_arguments = current_tool_call.get('arguments') # arguments, if any, in current dict - - logger.debug(f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}') + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') + cur_arguments = current_tool_call.get( + 'arguments' + ) # arguments, if any, in current dict + + logger.debug( + f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}' + ) if not cur_arguments and not prev_arguments: - logger.debug(f'Skipping text {delta_text} - no arguments!') + logger.debug( + f'Skipping text {delta_text} - no arguments!' + ) delta = None elif not cur_arguments and prev_arguments: - logger.error('INVARIANT - impossible to have arguments reset mid-call') + logger.error( + 'INVARIANT - impossible to have arguments reset mid-call' + ) delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug(f'Finding {delta_text} in {cur_arguments_json}') - arguments_delta = cur_arguments_json[:cur_arguments_json.index(delta_text) + len(delta_text)] - logger.debug(f'First tokens in arguments received: {arguments_delta}') + logger.debug( + f'Finding {delta_text} in {cur_arguments_json}' + ) + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index( + delta_text + ) + + len(delta_text + )] + logger.debug( + f'First tokens in arguments received: {arguments_delta}' + ) delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump(exclude_none=True)) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump( + exclude_none=True)) ]) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug(f"Searching for diff between \n{cur_args_json}\n{prev_args_json}") - argument_diff = extract_intermediate_diff(cur_args_json, prev_args_json) - logger.debug(f'Got argument diff: {argument_diff}') + logger.debug( + f"Searching for diff between \n{cur_args_json}\n{prev_args_json}" + ) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug( + f'Got argument diff: {argument_diff}') delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=argument_diff - ).model_dump(exclude_none=True)) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump( + exclude_none=True)) ]) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff else: delta = None # handle saving the state for the current tool into the "prev" list for use in diffing for # the next iteration - if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[self.current_tool_id] = current_tool_call + if self.current_tool_id == len( + self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call else: - self.prev_tool_call_arr.append(current_tool_call) + self.prev_tool_call_arr.append( + current_tool_call) # TODO REPLACE ME WITH TOOL CALL #delta = DeltaMessage(content=delta_text) return delta except Exception as e: - logger.error(f'Error trying to handle streaming tool call: {e}') - logger.debug(f'Skipping chunk as a result of tool streaming extraction error') + logger.error( + f'Error trying to handle streaming tool call: {e}') + logger.debug( + 'Skipping chunk as a result of tool streaming extraction error' + ) return None # do not stream a delta. skip this token ID. diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 9a5e82e29f146..ccb46d9537aae 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -8,12 +8,13 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, - ChatCompletionNamedToolChoiceParam) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, CompletionRequest, + ChatCompletionNamedToolChoiceParam) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) + class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" From 15aa9b43ef7c046df91f5c0dba56a1ffcbcaf36c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 2 Aug 2024 00:35:43 -0500 Subject: [PATCH 077/222] fix: formatting --- vllm/entrypoints/openai/api_server.py | 3 +- vllm/entrypoints/openai/cli_args.py | 15 ++++-- vllm/entrypoints/openai/protocol.py | 6 +-- vllm/entrypoints/openai/serving_chat.py | 3 +- vllm/entrypoints/openai/tool_parsers.py | 66 ++++++++++++++----------- 5 files changed, 53 insertions(+), 40 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index afe80e042e763..9c37cd8c08261 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -139,7 +139,8 @@ async def create_chat_completion(request: ChatCompletionRequest, # TODO implement for streaming later if request.stream: - if openai_serving_chat.enable_auto_tools and openai_serving_chat.tool_parser: + if (openai_serving_chat.enable_auto_tools and + openai_serving_chat.tool_parser): print('handling streaming response') return StreamingResponse(content=generator, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 3dbc209180193..3322ae3fb2dff 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -39,7 +39,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=nullable_str, default=None, help="host name") - parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument("--port", + type=int, + default=8000, + help="port number") parser.add_argument( "--uvicorn-log-level", type=str, @@ -144,7 +147,8 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--enable-auto-tool-choice", action="store_true", help= - 'Enable auto tool choice for models that support it. Requires --tool-call-parser' + 'Enable auto tool choice for supported models. Use --tool-call-parser' + 'to specify which parser to use' ) parser.add_argument( @@ -152,9 +156,10 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, choices=['mistral', 'hermes'], help= - 'Select the tool call parser depending on the model that you\'re using. ' - 'This is used to parse the model-generated tool call into OpenAI API format. ' - 'Required for --enable-auto-tool-choice. Options: "mistral", "hermes"') + 'Select the tool call parser depending on the model that you\'re using.' + ' This is used to parse the model-generated tool call into OpenAI API ' + 'format. Required for --enable-auto-tool-choice. Options: "mistral", ' + '"hermes"') parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 36e65b5ea3b35..c82839f100676 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -674,11 +674,11 @@ class DeltaFunctionCall(FunctionCall): # a tool call delta where everything is optional class DeltaToolCall(ToolCall): - index: int # this is always required, the index of the tool call in the tool_calls array. + index: int # this is always required, the index of the tool call in the arr function: Optional[DeltaFunctionCall] = None -# the initial delta that gets sent once a new tool call is started; differs in that it includes an auto-set id and type +# the initial delta that gets sent once a new tool call is started; class InitialDeltaToolCall(DeltaToolCall): id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") type: Literal["function"] = "function" @@ -692,7 +692,7 @@ class ExtractedToolCallInformation(BaseModel): # extracted tool calls tool_calls: List[ToolCall] - # content - per OpenAI spec, content AND tool calls can be returned ALTHOUGH THIS IS VERY RARE + # content - per OpenAI spec, content AND tool calls can be returned rarely # But some models will do this intentionally content: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 73a8747fa4e44..be30fdbe2fc6c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -34,7 +34,8 @@ log_tracing_disabled_warning) from vllm.utils import random_uuid -from vllm.entrypoints.openai.tool_parsers import ToolParser, MistralToolParser, Hermes2ProToolParser +from vllm.entrypoints.openai.tool_parsers import ( + ToolParser, MistralToolParser, Hermes2ProToolParser) from jinja2 import Environment, FileSystemLoader, select_autoescape diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 28d939fc93d92..2fb12d6bec7b7 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -18,13 +18,16 @@ def find_common_prefix(s1: str, s2: str) -> str: """ - Finds a common prefix that is shared between two strings, if there is one. Order of arguments is NOT important. + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. - This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, - to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and close-braces are not returned prematurely. - e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap' + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' """ prefix = '' min_length = min(len(s1), len(s2)) @@ -38,7 +41,8 @@ def find_common_prefix(s1: str, s2: str) -> str: def find_common_suffix(s1: str, s2: str) -> str: """ - Finds a common suffix shared between two strings, if there is one. Order of arguments is NOT important. + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. Stops when the suffix ends OR it hits an alphanumeric character e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' @@ -55,19 +59,21 @@ def find_common_suffix(s1: str, s2: str) -> str: def extract_intermediate_diff(curr: str, old: str) -> str: """ - Given two strings, extract the difference in the middle between two strings that are known to have a common - prefix and/or suffix. + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. - This function is provided as a UTILITY for extracting information from JSON generated by partial_json_parser, - to help in ensuring that the right tokens are returned in streaming, so that close-quotes, close-brackets and - close-braces are not returned prematurely. The order of arguments IS important - the new version of the - partially-parsed JSON must be the first argument, and the secnod argument must be from the previous generation. + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. What it returns, is tokens that should be streamed to the client. - e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') -> 'ple' - e.g. extract_intermediate_diff('{"name": "get_current_weather", "city": "D"}', '{"name": "get_current_weather"}' -> - '", "city": "D' + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + """ suffix = find_common_suffix(curr, old) @@ -89,7 +95,8 @@ def extract_intermediate_diff(curr: str, old: str) -> str: def find_all_indices(string, substring): """ - Find all (starting) indices of a substring in a given string. Useful for tool call extraction + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction """ indices = [] index = -1 @@ -103,7 +110,8 @@ def find_all_indices(string, substring): class ToolParser: """ - Abstract ToolParser class that should not be used directly. Provided properties and methods should be used in + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in derived classes. """ @@ -112,27 +120,22 @@ def __init__(self, PreTrainedTokenizerFast, PreTrainedTokenizerFast, AutoTokenizer]] = None): - # the tool call array derived from partial JSON parsing from the previous execution of the function self.prev_tool_call_arr: List[Dict] = [] # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 - # indicates whether the name of the tool call that is currently being parsed has been sent. I have only seen - # OpenAI send the entire tool call name in a single chunk, so we wait until it has finished parsing. self.current_tool_name_sent: bool = False - # indicates if the initial tool call chunk with index, tool call ID etc has been sent. happens BEFORE the name - # is sent. self.current_tool_initial_sent: bool = False - # array of the argument strings (one for each tool) that have been streamed to the client. - self.streamed_args_for_tool: List[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: List[str] = [] self.model_tokenizer = tokenizer @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ - Static method that should be implemented for extracting tool calls from a complete model-generated string. - Used for non-streaming responses where we have the entire model response available before sending to the client. + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. Static because it's stateless. """ raise NotImplementedError( @@ -148,12 +151,15 @@ def extract_tool_calls_streaming( delta_token_ids: List[int], ) -> Union[DeltaMessage, None]: """ - Instance method that should be implemented for extracting tool calls from an incomplete response; for use when - handling tool calls and streaming. Has to be an instance method because it requires state - the current text/ - tokens/diffs, but also the information about what has previously been parsed and extracted (see constructor) + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current text/ tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) """ raise NotImplementedError( - 'AbstractToolParser.extract_tool_calls_streaming has not been implemented!' + 'AbstractToolParser.extract_tool_calls_streaming has not been ' + 'implemented!' ) From 9380ad74548012e038c91d5979d132e96ebd7884 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 2 Aug 2024 21:16:41 -0500 Subject: [PATCH 078/222] fix: remove unnecessary case that was artifact from previous approach --- vllm/entrypoints/openai/api_server.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e9649d3a756ac..2c37b3323214f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -187,6 +187,7 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( request, raw_request) @@ -196,18 +197,9 @@ async def create_chat_completion(request: ChatCompletionRequest, status_code=generator.code) # if streaming is requested, handle streaming - # TODO implement for streaming later - if request.stream: - - if (openai_serving_chat.enable_auto_tools and - openai_serving_chat.tool_parser): - print('handling streaming response') - - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: - return StreamingResponse(content=generator, + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") # handle non-streaming requests From 0390f8cc4ffc5a7a52e813f6e3278154dc8849e3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 2 Aug 2024 22:02:56 -0500 Subject: [PATCH 079/222] fix: validation errors --- ...penai_chat_completion_client_with_tools.py | 5 +- vllm/entrypoints/chat_utils.py | 55 +++---------------- vllm/entrypoints/openai/protocol.py | 55 ++++++++++++++----- 3 files changed, 53 insertions(+), 62 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 417e15bce118c..bfe3c7c019671 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -135,11 +135,10 @@ def get_current_weather(city: str, state: str, unit: 'str'): "name": call.function.name }) -print("Sending new chat with messages", messages) + chat_completion_2 = client.chat.completions.create(messages=messages, model=model, tools=tools, stream=False) - -print(chat_completion_2) print('\n\n') +print(chat_completion_2) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6be74ea7ae512..08e914e2b7874 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,16 +1,14 @@ import codecs from dataclasses import dataclass, field from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, Union, cast, final +from typing import Awaitable, Iterable, List, Optional, cast # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import ChatCompletionContentPartImageParam -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) + from openai.types.chat import ChatCompletionContentPartTextParam -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) + # yapf: enable # pydantic needs the TypedDict from typing_extensions from pydantic import ConfigDict @@ -21,50 +19,15 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image +from vllm.entrypoints.openai.protocol import ( + ChatCompletionMessageParam, + ChatCompletionContentPartParam, + ConversationMessage +) logger = init_logger(__name__) -class CustomChatCompletionContentPartParam(TypedDict, total=False): - __pydantic_config__ = ConfigDict(extra="allow") # type: ignore - - type: Required[str] - """The type of the content part.""" - - -ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, - CustomChatCompletionContentPartParam] - - -class CustomChatCompletionMessageParam(TypedDict, total=False): - """Enables custom roles in the Chat Completion API.""" - role: Required[str] - """The role of the message's author.""" - - content: Union[str, List[ChatCompletionContentPartParam]] - """The contents of the message.""" - - name: str - """An optional name for the participant. - - Provides the model information to differentiate between participants of the - same role. - """ - - -ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, - CustomChatCompletionMessageParam] - - -@final # So that it should be compatible with Dict[str, str] -class ConversationMessage(TypedDict, total=False): - role: str - content: Optional[str] - tool_call_id: Optional[str] - name: Optional[str] - tool_calls: Optional[List] - - @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] @@ -195,7 +158,7 @@ def parse_chat_message_content( return ChatMessageParseResult(messages=[], mm_futures=[]) # special case - assistant message where tool calls are provided. - if role == 'assistant' and tool_calls is not None and len(tool_calls): + if role == 'assistant' and tool_calls is not None: messages = [ ConversationMessage(role=role, content=content, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c82839f100676..092e751df5a95 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,31 +1,55 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Any, Dict, List, Literal, Optional, Union, Type -import openai +from typing import Any, Dict, List, Literal, Optional, Union, Type, final import torch from pydantic import BaseModel, ConfigDict, Field, model_validator from transformers import PreTrainedTokenizer from typing_extensions import Annotated, TypedDict, Required -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +from openai.types.chat import ( + ChatCompletionContentPartParam, + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam, + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam +) -class CustomChatCompletionContentPartParam(TypedDict, total=False): - __pydantic_config__ = ConfigDict(extra="allow") # type: ignore +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + +@final # So that it should be compatible with Dict[str, str] +class ConversationMessage(TypedDict, total=False): + role: str + content: Optional[str] + tool_call_id: Optional[str] + name: Optional[str] + tool_calls: Optional[List] - type: Required[str] - """The type of the content part.""" -ChatCompletionContentPartParam = Union[ - openai.types.chat.ChatCompletionContentPartParam, - CustomChatCompletionContentPartParam] +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -46,10 +70,15 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tool_calls: Optional[List[dict]] +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + -ChatCompletionMessageParam = Type[ - Union[openai.types.chat.ChatCompletionMessageParam, - CustomChatCompletionMessageParam]] +ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, + CustomChatCompletionContentPartParam] class OpenAIBaseModel(BaseModel): From e29a62aefe4b4ff536ddf13c9bf63e769ad12d46 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 2 Aug 2024 22:43:56 -0500 Subject: [PATCH 080/222] fix: hermes prompt template issue that occurred when passing multiple tool calls in --- examples/tool_chat_template_hermes.jinja | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index d807c12dbba16..3cc07d9ad0525 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -88,8 +88,9 @@ {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} {%- elif message.role == "assistant" %} - {{- '<|im_start|>' + message.role + '\n\n' }} + {{- '<|im_start|>' + message.role }} {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} From e393c66c751b37c897fe26e9d0862e9203d05691 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 00:17:18 -0500 Subject: [PATCH 081/222] fix: formatting & mypy fixes --- ...penai_chat_completion_client_with_tools.py | 1 - vllm/entrypoints/chat_utils.py | 8 +- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/cli_args.py | 8 +- vllm/entrypoints/openai/protocol.py | 13 +- vllm/entrypoints/openai/serving_chat.py | 149 +++++++++--------- vllm/entrypoints/openai/tool_parsers.py | 3 +- 7 files changed, 90 insertions(+), 94 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index bfe3c7c019671..bddf97869f59e 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -135,7 +135,6 @@ def get_current_weather(city: str, state: str, unit: 'str'): "name": call.function.name }) - chat_completion_2 = client.chat.completions.create(messages=messages, model=model, tools=tools, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 08e914e2b7874..98b19c37fb68e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -19,11 +19,9 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image -from vllm.entrypoints.openai.protocol import ( - ChatCompletionMessageParam, - ChatCompletionContentPartParam, - ConversationMessage -) +from vllm.entrypoints.openai.protocol import (ChatCompletionMessageParam, + ChatCompletionContentPartParam, + ConversationMessage) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2c37b3323214f..0c69c59b3cd96 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -200,7 +200,7 @@ async def create_chat_completion(request: ChatCompletionRequest, if request.stream: return StreamingResponse(content=generator, - media_type="text/event-stream") + media_type="text/event-stream") # handle non-streaming requests else: diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 96f276a5289f0..4b6f76f40913e 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -39,10 +39,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=nullable_str, default=None, help="host name") - parser.add_argument("--port", - type=int, - default=8000, - help="port number") + parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument( "--uvicorn-log-level", type=str, @@ -154,8 +151,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help= 'Enable auto tool choice for supported models. Use --tool-call-parser' - 'to specify which parser to use' - ) + 'to specify which parser to use') parser.add_argument( "--tool-call-parser", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 092e751df5a95..301442ecbad6b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -12,11 +12,11 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid -from openai.types.chat import ( - ChatCompletionContentPartParam, - ChatCompletionMessageParam as OpenAIChatCompletionMessageParam, - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam -) +from openai.types.chat import (ChatCompletionContentPartParam, + ChatCompletionMessageParam as + OpenAIChatCompletionMessageParam, + ChatCompletionContentPartParam as + OpenAIChatCompletionContentPartParam) class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -38,6 +38,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tool_calls: Optional[List[dict]] + @final # So that it should be compatible with Dict[str, str] class ConversationMessage(TypedDict, total=False): role: str @@ -47,7 +48,6 @@ class ConversationMessage(TypedDict, total=False): tool_calls: Optional[List] - ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] @@ -70,6 +70,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): tool_calls: Optional[List[dict]] + class CustomChatCompletionContentPartParam(TypedDict, total=False): __pydantic_config__ = ConfigDict(extra="allow") # type: ignore diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0e5ca05ba7f79..9f5c9b669d418 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -34,8 +34,9 @@ log_tracing_disabled_warning) from vllm.utils import random_uuid -from vllm.entrypoints.openai.tool_parsers import ( - ToolParser, MistralToolParser, Hermes2ProToolParser) +from vllm.entrypoints.openai.tool_parsers import (ToolParser, + MistralToolParser, + Hermes2ProToolParser) from jinja2 import Environment, FileSystemLoader, select_autoescape @@ -47,21 +48,19 @@ class OpenAIServingChat(OpenAIServing): - def __init__( - self, - async_engine_client: AsyncEngineClient, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - enable_auto_tools: Optional[bool] = False, - tool_parser: Optional[str] = None - ): + def __init__(self, + async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: Optional[bool] = False, + tool_parser: Optional[str] = None): super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, @@ -80,27 +79,27 @@ def __init__( logger.info( '"Auto" tool choice has been enabled please note that while the ' 'parallel_tool_calls client option is preset for compatibility ' - 'reasons, it will be ignored.' - ) + 'reasons, it will be ignored.') if self.enable_auto_tools: if tool_parser == 'mistral': - self.tool_parser: Optional[Type[ToolParser]] = MistralToolParser + self.tool_parser: Optional[ + Type[ToolParser]] = MistralToolParser elif tool_parser == 'hermes': - self.tool_parser: Optional[Type[ToolParser]] = Hermes2ProToolParser + self.tool_parser: Optional[ + Type[ToolParser]] = Hermes2ProToolParser else: raise TypeError( - 'Error: --enable-auto-tool-choice requires --tool-parser' - ) + 'Error: --enable-auto-tool-choice requires --tool-parser') else: self.tool_parser: Optional[Type[ToolParser]] = None async def create_chat_completion( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -180,7 +179,7 @@ async def create_chat_completion( tokenizer, guided_decode_logits_processor, default_max_tokens=self.max_model_len - - len(prompt_inputs["prompt_token_ids"])) + len(prompt_inputs["prompt_token_ids"])) self._log_inputs(request_id, prompt_inputs, @@ -239,12 +238,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] async def chat_completion_stream_generator( - self, - request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -257,7 +256,8 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices - tool_parser: ToolParser = self.tool_parser(tokenizer) + tool_parser: ToolParser = self.tool_parser( + tokenizer) if self.tool_parser else None try: async for res in result_generator: @@ -300,8 +300,8 @@ async def chat_completion_stream_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( - "role") == role: - last_msg_content = conversation[-1]["content"] + "role") == role: + last_msg_content = conversation[-1]["content"] or '' if last_msg_content: for i in range(num_choices): @@ -340,16 +340,13 @@ async def chat_completion_stream_generator( for output in res.outputs: i = output.index - # prints the full completion so far including text and tokens - # print(f'[{i}]:', output) if finish_reason_sent[i]: continue delta_token_ids = output.token_ids[previous_num_tokens[i]:] out_logprobs = output.logprobs[ - previous_num_tokens[ - i]:] if output.logprobs else None + previous_num_tokens[i]:] if output.logprobs else None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, ( @@ -364,6 +361,7 @@ async def chat_completion_stream_generator( logprobs = None delta_text = output.text[len(previous_texts[i]):] + delta_message: Optional[DeltaMessage] = None # handle streaming deltas for tools with tool_choice if request.tool_choice and type( @@ -380,14 +378,12 @@ async def chat_completion_stream_generator( or request.tool_choice == 'auto') and self.enable_auto_tools): - print('output token IDs', output.token_ids) - print('delta token IDs', delta_token_ids) delta_message = tool_parser.extract_tool_calls_streaming( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=output.token_ids[ - :-1 * len(delta_token_ids)], + previous_token_ids=output. + token_ids[:-1 * len(delta_token_ids)], current_token_ids=output.token_ids, delta_token_ids=delta_token_ids) else: @@ -425,7 +421,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -434,14 +430,18 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" else: - # check to make sure we haven't missed something on the last function call - if (delta_message.tool_calls and - (delta_message.tool_calls[0].function.arguments - == '' or - delta_message.tool_calls[ - 0].function.arguments and - (output.finish_reason == 'stop' - or output.finish_reason == 'tool_calls'))): + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + if (delta_message.tool_calls + and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function and + (delta_message.tool_calls[0].function.arguments + == '' or + delta_message.tool_calls[0].function.arguments and + (output.finish_reason == 'stop' + or output.finish_reason == 'tool_calls'))): + expected_call = json.dumps( tool_parser.prev_tool_call_arr[ len(tool_parser.prev_tool_call_arr) - @@ -453,10 +453,10 @@ async def chat_completion_stream_generator( delta_message = DeltaMessage(tool_calls=[ DeltaToolCall( index=len(tool_parser.prev_tool_call_arr) - - 1, + 1, function=DeltaFunctionCall( arguments=remaining_call).model_dump( - exclude_none=True)) + exclude_none=True)) ]) # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) @@ -483,7 +483,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -520,13 +520,13 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -564,8 +564,10 @@ async def chat_completion_full_generator( # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if not self.enable_auto_tools and not isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam): + if not (self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): message = ChatMessage(role=role, content=output.text) # if the reqeust uses tools and specified a tool choice @@ -590,7 +592,8 @@ async def chat_completion_full_generator( # handle when there are tools and tool choice is auto elif request.tools and ( request.tool_choice == "auto" - or request.tool_choice is None) and self.enable_auto_tools: + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: tool_call_info = self.tool_parser.extract_tool_calls( output.text) @@ -624,7 +627,7 @@ async def chat_completion_full_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] + last_msg_content = conversation[-1]["content"] or '' for choice in choices: full_message = last_msg_content + choice.message.content if choice.message.content else '' @@ -665,11 +668,11 @@ def _get_top_logprobs( ] def _create_chat_logprobs( - self, - token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - tokenizer: PreTrainedTokenizer, - num_output_top_logprobs: Optional[int] = None, + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: PreTrainedTokenizer, + num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 2fb12d6bec7b7..1e345a69badd5 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -159,8 +159,7 @@ def extract_tool_calls_streaming( """ raise NotImplementedError( 'AbstractToolParser.extract_tool_calls_streaming has not been ' - 'implemented!' - ) + 'implemented!') class MistralToolParser(ToolParser): From 8d1eac198a0a0b5e227f452a885421280955ab72 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 00:25:24 -0500 Subject: [PATCH 082/222] fix: more types --- vllm/entrypoints/openai/serving_chat.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9f5c9b669d418..6517bb3302dbb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -81,18 +81,15 @@ def __init__(self, 'parallel_tool_calls client option is preset for compatibility ' 'reasons, it will be ignored.') + self.tool_parser: Optional[Type[ToolParser]] = None if self.enable_auto_tools: if tool_parser == 'mistral': - self.tool_parser: Optional[ - Type[ToolParser]] = MistralToolParser + self.tool_parser = MistralToolParser elif tool_parser == 'hermes': - self.tool_parser: Optional[ - Type[ToolParser]] = Hermes2ProToolParser + self.tool_parser = Hermes2ProToolParser else: raise TypeError( 'Error: --enable-auto-tool-choice requires --tool-parser') - else: - self.tool_parser: Optional[Type[ToolParser]] = None async def create_chat_completion( self, @@ -256,7 +253,7 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices - tool_parser: ToolParser = self.tool_parser( + tool_parser: Optional[ToolParser] = self.tool_parser( tokenizer) if self.tool_parser else None try: @@ -374,8 +371,9 @@ async def chat_completion_stream_generator( ]) # handle streaming deltas for tools with tool_choice - elif (request.tools and (request.tool_choice is None - or request.tool_choice == 'auto') + elif (request.tools and tool_parser + and (request.tool_choice is None + or request.tool_choice == 'auto') and self.enable_auto_tools): delta_message = tool_parser.extract_tool_calls_streaming( @@ -440,8 +438,8 @@ async def chat_completion_stream_generator( == '' or delta_message.tool_calls[0].function.arguments and (output.finish_reason == 'stop' - or output.finish_reason == 'tool_calls'))): - + or output.finish_reason == 'tool_calls')) + and tool_parser): expected_call = json.dumps( tool_parser.prev_tool_call_arr[ len(tool_parser.prev_tool_call_arr) - From 813c3c554c55c801d32177b84bfbdc9de666ec9d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 00:27:29 -0500 Subject: [PATCH 083/222] fix: more mypy fixes --- vllm/entrypoints/openai/serving_chat.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6517bb3302dbb..5ae9a37d62e9b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -74,7 +74,7 @@ def __init__(self, self.chat_template = load_chat_template(chat_template) # set up tool use - self.enable_auto_tools: bool = enable_auto_tools + self.enable_auto_tools: bool = enable_auto_tools or False if self.enable_auto_tools: logger.info( '"Auto" tool choice has been enabled please note that while the ' @@ -463,8 +463,9 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs, finish_reason=output.finish_reason - if not len(tool_parser.prev_tool_call_arr) else - 'tool_calls', + if not (tool_parser + and len(tool_parser.prev_tool_call_arr)) + else 'tool_calls', stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, From f9da83219972cae8283eb13df0727760a87cccfd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 00:55:48 -0500 Subject: [PATCH 084/222] fix: formatting --- vllm/entrypoints/chat_utils.py | 8 ++++---- vllm/entrypoints/openai/tool_parsers.py | 10 ++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 98b19c37fb68e..cbedbb09764cf 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,7 @@ import codecs from dataclasses import dataclass, field from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, cast +from typing import Awaitable, Iterable, List, Optional, cast, Any # yapf conflicts with isort for this block # yapf: disable @@ -11,7 +11,6 @@ # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict from transformers import PreTrainedTokenizer from typing_extensions import Required, TypedDict @@ -179,5 +178,6 @@ def parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(role, content, model_config, - tokenizer) + return _parse_chat_message_content_parts(role, + cast(Iterable[Any], content), + model_config, tokenizer) diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 1e345a69badd5..978e7b78534e9 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -285,7 +285,7 @@ def extract_tool_calls_streaming( self.bot_token)[1] parsable_arr = tool_call_message_portion.replace('\'', '"') - logger.debug('parsing: %s', parsable_arr) + #logger.debug('parsing: %s', parsable_arr) # tool calls are generated in an array, so do partial JSON parsing on the entire array tool_call_arr: List[Dict] = partial_json_parser.loads( @@ -338,7 +338,7 @@ def extract_tool_calls_streaming( # if there is NOTHING in the array, e.g. if only the open bracket was streamed yet else: - logger.debug('No tool call detected yet!') + #logger.debug('No tool call detected yet!') return None # if the current tool initial data incl. the id, type=function and idx not sent, send that @@ -378,9 +378,7 @@ def extract_tool_calls_streaming( new_text = delta_text.replace('\'', '"') if not cur_arguments and not prev_arguments: - logger.debug( - f'Skipping text {new_text} (tokens {delta_token_ids}) - no arguments yet' - ) + delta = None elif not cur_arguments and prev_arguments: logger.error( @@ -508,7 +506,7 @@ def __init__(self, self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False + self.current_tool_name_sent = False self.current_tool_initial_sent: bool = False self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list From 08d54b1025346a34b5a90ac711e169894e582b81 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 01:01:31 -0500 Subject: [PATCH 085/222] fix: more mypy fixes --- vllm/entrypoints/chat_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index cbedbb09764cf..39d3206d3e098 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,7 @@ import codecs from dataclasses import dataclass, field from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, cast, Any +from typing import Awaitable, Iterable, List, Optional, cast, Any, Union # yapf conflicts with isort for this block # yapf: disable @@ -158,7 +158,7 @@ def parse_chat_message_content( if role == 'assistant' and tool_calls is not None: messages = [ ConversationMessage(role=role, - content=content, + content=cast(Union[str, None], content), tool_calls=list(tool_calls)) ] return ChatMessageParseResult(messages=messages, mm_futures=[]) @@ -168,8 +168,9 @@ def parse_chat_message_content( messages = [ ConversationMessage(role=role, name=name, - content=content, - tool_call_id=tool_call_id) + content=cast(Union[str, None], content), + tool_call_id=cast(Union[str, None], + tool_call_id)) ] return ChatMessageParseResult(messages=messages, mm_futures=[]) From c87a6eceea44a003243847075dd77ec337a0c75b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 01:07:21 -0500 Subject: [PATCH 086/222] fix: final mypy fixes --- vllm/entrypoints/openai/protocol.py | 34 +++++++++---------------- vllm/entrypoints/openai/serving_chat.py | 2 +- 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 301442ecbad6b..42a976817c196 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -52,25 +52,6 @@ class ConversationMessage(TypedDict, total=False): CustomChatCompletionMessageParam] -class CustomChatCompletionMessageParam(TypedDict, total=False): - """Enables custom roles in the Chat Completion API.""" - role: Required[str] - """The role of the message's author.""" - - content: Union[str, List[ChatCompletionContentPartParam]] - """The contents of the message.""" - - name: Optional[str] - """An optional name for the participant. - - Provides the model information to differentiate between participants of the - same role. - """ - tool_call_id: Optional[str] - - tool_calls: Optional[List[dict]] - - class CustomChatCompletionContentPartParam(TypedDict, total=False): __pydantic_config__ = ConfigDict(extra="allow") # type: ignore @@ -697,16 +678,25 @@ def to_dict(self): } -class DeltaFunctionCall(FunctionCall): +class DeltaFunctionCall(BaseModel): name: Optional[str] = None arguments: Optional[str] = None # a tool call delta where everything is optional -class DeltaToolCall(ToolCall): - index: int # this is always required, the index of the tool call in the arr +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int function: Optional[DeltaFunctionCall] = None + def to_dict(self): + return { + "id": self.id, + "type": self.type, + "function": self.function.to_dict() if self.function else None + } + # the initial delta that gets sent once a new tool call is started; class InitialDeltaToolCall(DeltaToolCall): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5ae9a37d62e9b..1ac4077cf49e6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -569,7 +569,7 @@ async def chat_completion_full_generator( ChatCompletionNamedToolChoiceParam): message = ChatMessage(role=role, content=output.text) - # if the reqeust uses tools and specified a tool choice + # if the request uses tools and specified a tool choice elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: From 19eab7a23c43156ad7be1010945db5e4847d85df Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 3 Aug 2024 18:09:50 -0500 Subject: [PATCH 087/222] fix: finish_reason behavior was broken for non-streaming calls --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ac4077cf49e6..3a29570b4107d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -618,7 +618,7 @@ async def chat_completion_full_generator( message=message, logprobs=logprobs, finish_reason='tool_calls' if tools_called else - output.stop_reason if output.stop_reason else 'stop', + output.finish_reason if output.finish_reason else 'stop', stop_reason=output.stop_reason) choices.append(choice_data) From 0c72dc6b9da07e8cacae702444a8aca53ac887ce Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 09:55:39 -0500 Subject: [PATCH 088/222] fix(test): ensure tool_choice="required" throws a BadRequestError, and that "auto" tool choice throws a BadRequestError without the proper CLI args --- vllm/entrypoints/openai/serving_chat.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3a29570b4107d..713232e8532be 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -158,6 +158,21 @@ async def create_chat_completion( logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) + # validation for OpenAI tools + try: + # tool_choice = "required" is not supported + assert(request.tool_choice != 'required', 'tool_choice="required" is not supported.') + + # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser + if request.tool_choice == 'auto': + assert(self.enable_auto_tools and self.tool_parser is not None, + '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') + + except Exception as e: + logger.error('Error validating OpenAI tool configuration: %s', e) + return self.create_error_response(str(e)) + + request_id = f"chat-{random_uuid()}" try: @@ -380,8 +395,7 @@ async def chat_completion_stream_generator( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=output. - token_ids[:-1 * len(delta_token_ids)], + previous_token_ids=output.token_ids[:-1 * len(delta_token_ids)], current_token_ids=output.token_ids, delta_token_ids=delta_token_ids) else: From ee3b6adb56f27bc09cc80d02d95fa76654a923fe Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 09:56:25 -0500 Subject: [PATCH 089/222] fix: remove deprecated CLI argument --- vllm/entrypoints/openai/cli_args.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4b6f76f40913e..7272208feddd5 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -140,11 +140,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") - parser.add_argument("--enable-api-tools", - action="store_true", - help="Enable OpenAI-like tools API " - "(only function calls are currently supported)") - parser.add_argument( "--enable-auto-tool-choice", action="store_true", From 04ba399a2751449843fd930cd49c8db78302233a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 17:27:59 -0500 Subject: [PATCH 090/222] fix: type --- vllm/entrypoints/chat_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 39d3206d3e098..6aa7a9d8b7e94 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -180,5 +180,5 @@ def parse_chat_message_content( return ChatMessageParseResult(messages=messages, mm_futures=[]) return _parse_chat_message_content_parts(role, - cast(Iterable[Any], content), + cast(Iterable, content), model_config, tokenizer) From f2c1254cec9e0e671fa63d0e968b7ad34b6081cb Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 17:33:43 -0500 Subject: [PATCH 091/222] fix: remoev another assertion and replace with if/exception --- vllm/entrypoints/openai/serving_chat.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 713232e8532be..17717e0868bc7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -147,31 +147,26 @@ async def create_chat_completion( return self.create_error_response(str(e)) mm_data: Optional[MultiModalDataDict] = None + try: if len(mm_futures): # since we support only single mm data currently - assert len( - mm_futures - ) == 1, "Multiple 'image_url' input is currently not supported." + if len(mm_futures) != 1: + return self.create_error_response("Multiple 'image_url' input is currently not supported.") mm_data = await mm_futures[0] except Exception as e: logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) # validation for OpenAI tools - try: # tool_choice = "required" is not supported - assert(request.tool_choice != 'required', 'tool_choice="required" is not supported.') + if request.tool_choice != 'required': + return self.create_error_response('tool_choice = "required" is not supported!') # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser - if request.tool_choice == 'auto': - assert(self.enable_auto_tools and self.tool_parser is not None, - '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') - - except Exception as e: - logger.error('Error validating OpenAI tool configuration: %s', e) - return self.create_error_response(str(e)) - + if request.tool_choice == 'auto' and not (self.enable_auto_tools and self.tool_parser is not None): + return self.create_error_response( + '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') request_id = f"chat-{random_uuid()}" try: @@ -236,7 +231,8 @@ async def create_chat_completion( request, raw_request, result_generator, request_id, conversation, tokenizer) - assert isinstance(generator, ChatCompletionResponse) + if not isinstance(generator, ChatCompletionResponse): + raise ValueError('Expected generator to be instance of ChatCompletionResponse') return generator except ValueError as e: From eefbee5f90ba797a202df294027f788537e663bb Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 18:01:20 -0500 Subject: [PATCH 092/222] fix: bad condition --- vllm/entrypoints/chat_utils.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6aa7a9d8b7e94..70db6078b1158 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -158,7 +158,7 @@ def parse_chat_message_content( if role == 'assistant' and tool_calls is not None: messages = [ ConversationMessage(role=role, - content=cast(Union[str, None], content), + content=cast(Optional[str], content), tool_calls=list(tool_calls)) ] return ChatMessageParseResult(messages=messages, mm_futures=[]) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 17717e0868bc7..8cfdfeba80fed 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -160,7 +160,7 @@ async def create_chat_completion( # validation for OpenAI tools # tool_choice = "required" is not supported - if request.tool_choice != 'required': + if request.tool_choice == 'required': return self.create_error_response('tool_choice = "required" is not supported!') # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser From d18e9c363790458f3d7d0d6f862163a0f8d25197 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 18:02:38 -0500 Subject: [PATCH 093/222] fix: cleaner concat --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8cfdfeba80fed..08bf483d67889 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -639,7 +639,7 @@ async def chat_completion_full_generator( last_msg_content = conversation[-1]["content"] or '' for choice in choices: - full_message = last_msg_content + choice.message.content if choice.message.content else '' + full_message = last_msg_content + (choice.message.content or '') choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) From 5d43a00735c5bfa4f973c1fc322b583c6bd8b691 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 18:05:59 -0500 Subject: [PATCH 094/222] fix: clean up vode --- vllm/entrypoints/openai/protocol.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 42a976817c196..c715600fe284b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -333,9 +333,7 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if (guide_count > 1 and "tool_choice" in data - and data["tool_choice"] != "none" - and data["tool_choice"] != "auto"): + if guide_count > 1 and data.get('tool_choice', 'none') not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data From 092224c83684bdd5967f5030ae4288a20d57d209 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 18:08:23 -0500 Subject: [PATCH 095/222] chore: more cleanup --- vllm/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 08bf483d67889..5266995449d4c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -106,7 +106,7 @@ async def create_chat_completion( """ error_check_ret = await self._check_model(request) if error_check_ret is not None: - print('Error with model', error_check_ret) + logger.error('Error with model %s', error_check_ret) return error_check_ret try: From 8f6029f574ee8ba1d18593ae3f62fabee5dc5c31 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 18:13:23 -0500 Subject: [PATCH 096/222] chore: clean up conditional and document better --- vllm/entrypoints/openai/protocol.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c715600fe284b..ec48a2652cc2e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -342,6 +342,11 @@ def check_guided_decoding_count(cls, data): @classmethod def check_tool_usage(cls, data): + # if "tool_choice" is not specified but tools are provided, default to "auto" tool_choice + if "tool_choice" not in data and "tools" in data: + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation if "tool_choice" in data: # ensure that if "tool choice" is specified, tools are present @@ -382,10 +387,6 @@ def check_tool_usage(cls, data): "The tool specified in `tool_choice` does not match any of the specified `tools`" ) - # per OpenAI spec, make sure that tool_choice defaults to "auto" when tools are specified - elif "tools" in data and "tool_choice" not in data: - data["tool_choice"] = "auto" - # TODO validate tools return data From 1d856c7c9a5a314ece3ca3d003fa60a3a663edd0 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 20:11:07 -0500 Subject: [PATCH 097/222] fix(ci): broken tool streaming when using guided decoding --- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 106 +++++++++++++----------- 2 files changed, 57 insertions(+), 51 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ec48a2652cc2e..4a7b5d1efe9d6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -758,7 +758,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: List[DeltaToolCall] = Field(default_factory=list) + tool_calls: Optional[List[DeltaToolCall]] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5266995449d4c..68c6cdcb577e6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -92,11 +92,11 @@ def __init__(self, 'Error: --enable-auto-tool-choice requires --tool-parser') async def create_chat_completion( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -159,14 +159,14 @@ async def create_chat_completion( return self.create_error_response(str(e)) # validation for OpenAI tools - # tool_choice = "required" is not supported + # tool_choice = "required" is not supported if request.tool_choice == 'required': return self.create_error_response('tool_choice = "required" is not supported!') # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser if request.tool_choice == 'auto' and not (self.enable_auto_tools and self.tool_parser is not None): return self.create_error_response( - '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') + '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') request_id = f"chat-{random_uuid()}" try: @@ -186,7 +186,7 @@ async def create_chat_completion( tokenizer, guided_decode_logits_processor, default_max_tokens=self.max_model_len - - len(prompt_inputs["prompt_token_ids"])) + len(prompt_inputs["prompt_token_ids"])) self._log_inputs(request_id, prompt_inputs, @@ -246,12 +246,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] async def chat_completion_stream_generator( - self, - request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -308,7 +308,7 @@ async def chat_completion_stream_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( - "role") == role: + "role") == role: last_msg_content = conversation[-1]["content"] or '' if last_msg_content: @@ -354,7 +354,7 @@ async def chat_completion_stream_generator( delta_token_ids = output.token_ids[previous_num_tokens[i]:] out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None + previous_num_tokens[i]:] if output.logprobs else None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, ( @@ -376,9 +376,11 @@ async def chat_completion_stream_generator( request.tool_choice ) is ChatCompletionNamedToolChoiceParam: delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( + DeltaToolCall(function=DeltaFunctionCall( name=request.tool_choice.function.name, - arguments=delta_text)) + arguments=delta_text), + index=i # note: ok to hard-code to 0 since named tool calling doesn't support arrays + ) ]) # handle streaming deltas for tools with tool_choice @@ -429,7 +431,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -441,30 +443,33 @@ async def chat_completion_stream_generator( # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously # matched by partial json parsing - if (delta_message.tool_calls - and delta_message.tool_calls[0] - and delta_message.tool_calls[0].function and - (delta_message.tool_calls[0].function.arguments - == '' or - delta_message.tool_calls[0].function.arguments and - (output.finish_reason == 'stop' - or output.finish_reason == 'tool_calls')) - and tool_parser): + # only happens if we are NOT using guided decoding + index = len(tool_parser.prev_tool_call_arr) - 1 if len( + tool_parser.prev_tool_call_arr) > 0 else 0 + if ( + delta_message.tool_calls and + delta_message.tool_calls[0] and + delta_message.tool_calls[0].function and + ( + delta_message.tool_calls[0].function.arguments == '' or + delta_message.tool_calls[0].function.arguments and ( + output.finish_reason == 'stop' or + output.finish_reason == 'tool_calls' + ) + ) and + tool_parser and request.tool_choice == 'auto' + ): expected_call = json.dumps( - tool_parser.prev_tool_call_arr[ - len(tool_parser.prev_tool_call_arr) - - 1].get('arguments', {})) - actual_call = tool_parser.streamed_args_for_tool[ - len(tool_parser.prev_tool_call_arr) - 1] + tool_parser.prev_tool_call_arr[index].get('arguments', {})) + actual_call = tool_parser.streamed_args_for_tool[index] remaining_call = expected_call.replace( actual_call, '', 1) delta_message = DeltaMessage(tool_calls=[ DeltaToolCall( - index=len(tool_parser.prev_tool_call_arr) - - 1, + index=index, function=DeltaFunctionCall( arguments=remaining_call).model_dump( - exclude_none=True)) + exclude_none=True)) ]) # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) @@ -492,7 +497,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -523,19 +528,20 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error + logger.error('error in chat completion stream generator: %s', e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -575,8 +581,8 @@ async def chat_completion_full_generator( # outlines is not being used if not (self.enable_auto_tools or not self.tool_parser) and not isinstance( - request.tool_choice, - ChatCompletionNamedToolChoiceParam): + request.tool_choice, + ChatCompletionNamedToolChoiceParam): message = ChatMessage(role=role, content=output.text) # if the request uses tools and specified a tool choice @@ -677,11 +683,11 @@ def _get_top_logprobs( ] def _create_chat_logprobs( - self, - token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - tokenizer: PreTrainedTokenizer, - num_output_top_logprobs: Optional[int] = None, + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: PreTrainedTokenizer, + num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" From 38635ad1c09c934494b1931958b12b43fa1e065c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 20:23:18 -0500 Subject: [PATCH 098/222] fix: formatting --- vllm/entrypoints/chat_utils.py | 3 +- vllm/entrypoints/openai/protocol.py | 5 +- vllm/entrypoints/openai/serving_chat.py | 137 ++++++++------ vllm/entrypoints/openai/tool_parsers.py | 242 ++++++++++++++---------- 4 files changed, 220 insertions(+), 167 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 70db6078b1158..055092ca55f52 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -179,6 +179,5 @@ def parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(role, - cast(Iterable, content), + return _parse_chat_message_content_parts(role, cast(Iterable, content), model_config, tokenizer) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4a7b5d1efe9d6..013b2ceb184a1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -333,7 +333,8 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and data.get('tool_choice', 'none') not in ("none", "auto"): + if guide_count > 1 and data.get('tool_choice', + 'none') not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @@ -758,7 +759,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: Optional[List[DeltaToolCall]] = Field(default_factory=list) + tool_calls: List[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 68c6cdcb577e6..fabaf9e8042dd 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -92,11 +92,11 @@ def __init__(self, 'Error: --enable-auto-tool-choice requires --tool-parser') async def create_chat_completion( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + ChatCompletionResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/chat/create @@ -152,7 +152,9 @@ async def create_chat_completion( if len(mm_futures): # since we support only single mm data currently if len(mm_futures) != 1: - return self.create_error_response("Multiple 'image_url' input is currently not supported.") + return self.create_error_response( + "Multiple 'image_url' input is currently not supported." + ) mm_data = await mm_futures[0] except Exception as e: logger.error("Error in loading multi-modal data: %s", e) @@ -161,12 +163,15 @@ async def create_chat_completion( # validation for OpenAI tools # tool_choice = "required" is not supported if request.tool_choice == 'required': - return self.create_error_response('tool_choice = "required" is not supported!') + return self.create_error_response( + 'tool_choice = "required" is not supported!') # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser - if request.tool_choice == 'auto' and not (self.enable_auto_tools and self.tool_parser is not None): + if request.tool_choice == 'auto' and not ( + self.enable_auto_tools and self.tool_parser is not None): return self.create_error_response( - '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set') + '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set' + ) request_id = f"chat-{random_uuid()}" try: @@ -186,7 +191,7 @@ async def create_chat_completion( tokenizer, guided_decode_logits_processor, default_max_tokens=self.max_model_len - - len(prompt_inputs["prompt_token_ids"])) + len(prompt_inputs["prompt_token_ids"])) self._log_inputs(request_id, prompt_inputs, @@ -232,7 +237,9 @@ async def create_chat_completion( conversation, tokenizer) if not isinstance(generator, ChatCompletionResponse): - raise ValueError('Expected generator to be instance of ChatCompletionResponse') + raise ValueError( + 'Expected generator to be instance of ChatCompletionResponse' + ) return generator except ValueError as e: @@ -246,12 +253,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] async def chat_completion_stream_generator( - self, - request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -308,7 +315,7 @@ async def chat_completion_stream_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( - "role") == role: + "role") == role: last_msg_content = conversation[-1]["content"] or '' if last_msg_content: @@ -354,7 +361,7 @@ async def chat_completion_stream_generator( delta_token_ids = output.token_ids[previous_num_tokens[i]:] out_logprobs = output.logprobs[ - previous_num_tokens[i]:] if output.logprobs else None + previous_num_tokens[i]:] if output.logprobs else None if request.logprobs and request.top_logprobs is not None: assert out_logprobs is not None, ( @@ -376,10 +383,12 @@ async def chat_completion_stream_generator( request.tool_choice ) is ChatCompletionNamedToolChoiceParam: delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(function=DeltaFunctionCall( - name=request.tool_choice.function.name, - arguments=delta_text), - index=i # note: ok to hard-code to 0 since named tool calling doesn't support arrays + DeltaToolCall( + function=DeltaFunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text), + index= + i # note: ok to hard-code to 0 since named tool calling doesn't support arrays ) ]) @@ -393,7 +402,8 @@ async def chat_completion_stream_generator( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=output.token_ids[:-1 * len(delta_token_ids)], + previous_token_ids=output. + token_ids[:-1 * len(delta_token_ids)], current_token_ids=output.token_ids, delta_token_ids=delta_token_ids) else: @@ -431,7 +441,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -444,32 +454,34 @@ async def chat_completion_stream_generator( # any tokens that were generated but previously # matched by partial json parsing # only happens if we are NOT using guided decoding - index = len(tool_parser.prev_tool_call_arr) - 1 if len( - tool_parser.prev_tool_call_arr) > 0 else 0 - if ( - delta_message.tool_calls and - delta_message.tool_calls[0] and - delta_message.tool_calls[0].function and - ( - delta_message.tool_calls[0].function.arguments == '' or - delta_message.tool_calls[0].function.arguments and ( - output.finish_reason == 'stop' or - output.finish_reason == 'tool_calls' - ) - ) and - tool_parser and request.tool_choice == 'auto' - ): + if tool_parser: + index = len( + tool_parser.prev_tool_call_arr) - 1 if len( + tool_parser.prev_tool_call_arr) > 0 else 0 + else: + index = 0 + if (delta_message.tool_calls + and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function and + (delta_message.tool_calls[0].function.arguments + == '' or + delta_message.tool_calls[0].function.arguments and + (output.finish_reason == 'stop' + or output.finish_reason == 'tool_calls')) + and tool_parser + and request.tool_choice == 'auto'): expected_call = json.dumps( - tool_parser.prev_tool_call_arr[index].get('arguments', {})) - actual_call = tool_parser.streamed_args_for_tool[index] + tool_parser.prev_tool_call_arr[index].get( + 'arguments', {})) + actual_call = tool_parser.streamed_args_for_tool[ + index] remaining_call = expected_call.replace( actual_call, '', 1) delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=index, - function=DeltaFunctionCall( - arguments=remaining_call).model_dump( - exclude_none=True)) + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) ]) # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) @@ -497,7 +509,7 @@ async def chat_completion_stream_generator( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + - completion_tokens, + completion_tokens, ) chunk.usage = usage else: @@ -535,13 +547,13 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - raw_request: Optional[Request], - result_generator: AsyncIterator[RequestOutput], - request_id: str, - conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + self, + request: ChatCompletionRequest, + raw_request: Optional[Request], + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: PreTrainedTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -581,8 +593,8 @@ async def chat_completion_full_generator( # outlines is not being used if not (self.enable_auto_tools or not self.tool_parser) and not isinstance( - request.tool_choice, - ChatCompletionNamedToolChoiceParam): + request.tool_choice, + ChatCompletionNamedToolChoiceParam): message = ChatMessage(role=role, content=output.text) # if the request uses tools and specified a tool choice @@ -645,7 +657,8 @@ async def chat_completion_full_generator( last_msg_content = conversation[-1]["content"] or '' for choice in choices: - full_message = last_msg_content + (choice.message.content or '') + full_message = last_msg_content + (choice.message.content + or '') choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) @@ -683,11 +696,11 @@ def _get_top_logprobs( ] def _create_chat_logprobs( - self, - token_ids: GenericSequence[int], - top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - tokenizer: PreTrainedTokenizer, - num_output_top_logprobs: Optional[int] = None, + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: PreTrainedTokenizer, + num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 978e7b78534e9..864dff3d31c7f 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -164,20 +164,28 @@ def extract_tool_calls_streaming( class MistralToolParser(ToolParser): """ - Tool call parser for Mistral 7B Instruct v0.3, intended for use with the examples/tool_chat_template_mistral.jinja - template. There are server IMPORTANT CAVEATS for this parser: - - The chat template is NOT official and does not work well if you try to get the model to call 2+ tools at once. - Stick to only one tool call per generation, as the chat template is not reliable with > 1 and the model + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. There are several + IMPORTANT CAVEATS for this parser: + - The chat template is NOT official and does not work well if you try to + get the model to call 2+ tools at once without temperature=0. + Stick to only one tool call per generation, or set temp to 0 + as the chat template is not reliable with > 1 and the model Will lose coherence. - - Mistral's tool call format, that this translates into an OpenAI format, uses SINGLE QUOTES which cannot be - parsed to JSON. To enable JSON parsing and serialization, we find-and-replace these with DOUBLE QUOTES. To - prevent tool call corruption / deserialization failure, ensure that your tool calls and in particular your - ARGUMENTS never contain single or double quotes except as JSON control characters. - - Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser mistral are all set + - Mistral's tool call format, that this translates into an OpenAI + format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable + JSON parsing and serialization, we find-and-replace these with + DOUBLE QUOTES. To prevent tool call corruption / deserialization + failure, ensure that your tool calls and in particular your + ARGUMENTS never contain single or double quotes except as JSON + control characters. + + Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser + mistral are all set """ - # the bot_token is the token indicating tool call(s) follow. Tokens before this token will be parsed as content; and + # the bot_token is the token indicating tool call(s) follow. Tokens before + # this token will be parsed as content; and # if not present, the entire response will be parsed as text content. bot_token: str = '[TOOL_CALLS]' # string literal bot_token_id: int = 5 # token ID thereof from the models' tokenizer @@ -186,8 +194,9 @@ class MistralToolParser(ToolParser): @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ - Extract the tool calls from a complete model response. Requires find-and-replacing single quotes with double - quotes for JSON parsing, make sure your tool call arguments don't ever include quotes! + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! """ logger.debug( @@ -201,14 +210,16 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: else: try: - # this will throw an exception if we can't find the tool call properly + # this will throw an exception if we can't find the tool call + # properly raw_tool_call = MistralToolParser.tool_call_regex.findall( model_output.replace(MistralToolParser.bot_token, '') # remove BOT token .replace("'", '"') # replace string quotes )[0] - # load the JSON, and then use it to build the Function and Tool Call + # load the JSON, and then use it to build the Function and + # Tool Call function_call_arr = json.loads(raw_tool_call) tool_calls: List[ToolCall] = [ ToolCall( @@ -242,7 +253,8 @@ def __init__(self, AutoTokenizer]] = None): super().__init__(tokenizer) - # initialize properties used for state when parsing tool calls in streaming mode + # initialize properties used for state when parsing tool calls in + # streaming mode self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False @@ -260,34 +272,42 @@ def extract_tool_calls_streaming( delta_token_ids: List[int], ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append output to contents since it's not a tool + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool if self.bot_token_id not in current_token_ids: return DeltaMessage(content=delta_text) - # if the tool call token ID IS in the tokens generated so far, that means we're parsing as tool calls now + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now else: - # handle if we detected the BOT token which means the start of tool calling + # handle if we detected the BOT token which means the start of tool + # calling if self.bot_token_id in delta_token_ids: - # if it's the only token, return None, so we don't send a chat completion any don't send a control token + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token if len(delta_token_ids) == 1: return None - # bit mask flags for partial JSON parsing. If the name hasn't been sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have seen) allows sending the entire tool/ - # function name at once. - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR try: - # replace BOT token with empty string, and convert single quotes to double to allow parsing as JSON - # since mistral uses single quotes instead of double for tool calls + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls tool_call_message_portion = current_text.split( self.bot_token)[1] parsable_arr = tool_call_message_portion.replace('\'', '"') - #logger.debug('parsing: %s', parsable_arr) + # logger.debug('parsing: %s', parsable_arr) - # tool calls are generated in an array, so do partial JSON parsing on the entire array + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array tool_call_arr: List[Dict] = partial_json_parser.loads( parsable_arr, flags) @@ -336,12 +356,14 @@ def extract_tool_calls_streaming( # logger.debug('update to tool %d', self.current_tool_id) pass - # if there is NOTHING in the array, e.g. if only the open bracket was streamed yet + # if there is NOTHING in the array, e.g. if only the open + # bracket was streamed yet else: - #logger.debug('No tool call detected yet!') + # logger.debug('No tool call detected yet!') return None - # if the current tool initial data incl. the id, type=function and idx not sent, send that + # if the current tool initial data incl. the id, type=function + # and idx not sent, send that if not self.current_tool_initial_sent: logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True @@ -351,13 +373,14 @@ def extract_tool_calls_streaming( exclude_none=True) ]) - # if the current tool name hasn't been sent, send if available - otherwise no chunks + # if the current tool name hasn't been sent, send if available + # - otherwise no chunks elif not self.current_tool_name_sent: function_name = current_tool_call.get('name') if function_name: logger.debug( - f'Sending DeltaToolCall with function name {function_name}!' - ) + f'Sending DeltaToolCall with function name ' + f'{function_name}!') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -368,7 +391,8 @@ def extract_tool_calls_streaming( else: delta = None - # now we know we're on the same tool call and we're streaming arguments + # now we know we're on the same tool call and we're streaming + # arguments else: prev_arguments = self.prev_tool_call_arr[ @@ -382,8 +406,8 @@ def extract_tool_calls_streaming( delta = None elif not cur_arguments and prev_arguments: logger.error( - 'INVARIANT - impossible to have arguments reset mid-arguments' - ) + 'INVARIANT - impossible to have arguments reset ' + 'mid-arguments') delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) @@ -393,9 +417,8 @@ def extract_tool_calls_streaming( cur_arguments_json .index(new_text) + len(new_text)] - logger.debug( - f'First tokens in arguments received: {arguments_delta}' - ) + logger.debug(f'First tokens in arguments received: ' + f'{arguments_delta}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -409,8 +432,8 @@ def extract_tool_calls_streaming( cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) logger.debug( - f'Searching for diff between \n{cur_args_json}\n{prev_args_json}' - ) + f'Searching for diff between \n{cur_args_json}\n' + f'{prev_args_json}') argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) logger.debug(f'got arguments diff: {argument_diff}') @@ -423,11 +446,13 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ self.current_tool_id] += argument_diff else: - # try parsing it with regular JSON - if it works we're at the end, and we need to send the - # difference between tokens streamed so far and the valid JSON + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON delta = None - # check to see if the name is defined and has been sent. if so, stream the name - otherwise keep waiting + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting # finish by setting old and returning None as base case self.prev_tool_call_arr = tool_call_arr return delta @@ -436,8 +461,8 @@ def extract_tool_calls_streaming( logger.error( f'Error trying to handle streaming tool call: {e}') logger.debug( - 'Skipping chunk as a result of tool streaming extraction error' - ) + 'Skipping chunk as a result of tool streaming extraction ' + 'error') return None @@ -445,7 +470,8 @@ class Hermes2ProToolParser(ToolParser): tool_call_start_token: str = '' tool_call_end_token: str = '' - # regex to match between and OR between and EOS (happens sometimes :)) + # regex to match between and OR between + # and EOS (happens sometimes :)) tool_call_regex = re.compile( r'(.*?)|(.*)', re.DOTALL) scratch_pad_regex = re.compile(r'(.*?)', @@ -463,12 +489,15 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: else: try: - # there are two possible captures - between tags, or between a tag and end-of-string so the result of - # findall is an array of tuples where one is a function call and the other is None - function_call_tuples = Hermes2ProToolParser.tool_call_regex.findall( - model_output) - - # load the JSON, and then use it to build the Function and Tool Call + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + Hermes2ProToolParser.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call raw_function_calls = [ json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples @@ -503,7 +532,7 @@ def __init__(self, PreTrainedTokenizerFast, AutoTokenizer]] = None): super().__init__(tokenizer) - self.current_tool_name_sent: bool = False # reset each time we encounter a new tool in the array + self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent = False @@ -513,16 +542,16 @@ def __init__(self, if not self.model_tokenizer: raise ValueError( - 'The model tokenizer must be passed to the ToolParser constructor during construction.' - ) + 'The model tokenizer must be passed to the ToolParser ' + 'constructor during construction.') self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ ''] self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ ''] if not self.tool_call_start_token_id or not self.tool_call_end_token_id: raise RuntimeError( - 'Hermes 2 Pro Tool parser could not locate tool call start/end tokens in the tokenizer!' - ) + 'Hermes 2 Pro Tool parser could not locate tool call start/end ' + 'tokens in the tokenizer!') def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, @@ -539,7 +568,8 @@ def extract_tool_calls_streaming( else: try: - # figure out where we are in the parsing by counting tool call start & end tags + # figure out where we are in the parsing by counting tool call + # start & end tags prev_tool_start_count = previous_token_ids.count( self.tool_call_start_token_id) prev_tool_end_count = previous_token_ids.count( @@ -550,19 +580,24 @@ def extract_tool_calls_streaming( self.tool_call_end_token_id) # a cheap case - we're generating text, NOT tool calls. - if cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count: + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): logger.debug( 'Generating text content! skipping tool parsing.') return DeltaMessage(content=delta_text) - # most of the time, we're going in here - we need to do partial JSON parsing and build stuff. + # most of the time, we're going in here - we need to do partial + # JSON parsing and build stuff. else: - # flags for partial JSON parting. exported constants from "Allow" are handled via BIT MASK - # generally, we don't allow sending an incomplete function name. so we don't allow - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR - - # if a new tool call is being started. unusual since normally the first "cheap case" will be hit. - if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # if a new tool call is being started. unusual since + # normally the first "cheap case" will be hit. + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): if len(delta_token_ids) > 1: tool_call_portion = current_text.split( self.tool_call_start_token)[-1] @@ -580,14 +615,17 @@ def extract_tool_calls_streaming( logger.debug( f'Starting on a new tool {self.current_tool_id}') - # if an existing tool call is being updated - the most common case! - elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: + # if an existing tool call is being updated - the most + # common case! + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): tool_call_portion = current_text.split( self.tool_call_start_token)[-1] text_portion = None # if the current tool call is being closed - elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count: + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): logger.debug('Closing the current tool call!') diff = self.prev_tool_call_arr[ self.current_tool_id].get('arguments') @@ -596,8 +634,8 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ self.current_tool_id], '') logger.debug( - f'Finishing tool and found diff that wasn\'t streamed yet: {diff}' - ) + f'Finishing tool and found diff that had not ' + f'been streamed yet: {diff}') return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -607,8 +645,8 @@ def extract_tool_calls_streaming( else: logger.error( - 'INVARIANT - invalid state trying to parse tool calls (wtf?)' - ) + 'INVARIANT - invalid state trying to parse tool ' + 'calls (wtf?)') delta = None return delta @@ -618,7 +656,8 @@ def extract_tool_calls_streaming( flags) if tool_call_portion else None logger.debug(f'Parsed tool call {current_tool_call}') - # make sure to send the initial message first if we haven't already - with the tool ID + # make sure to send the initial message first if we haven't + # already - with the tool ID if not self.current_tool_initial_sent: logger.debug('Sending InitialDeltaToolCall') self.current_tool_initial_sent = True @@ -628,14 +667,15 @@ def extract_tool_calls_streaming( exclude_none=True) ]) - # after that, make sure we send the function name before any arguments + # after that, make sure we send the function name before + # any arguments elif not self.current_tool_name_sent: function_name: Union[ str, None] = current_tool_call.get('name') if function_name: logger.debug( - f'Sending DeltaToolCall with function name {function_name}!' - ) + f'Sending DeltaToolCall with function name ' + f'{function_name}!') self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -648,7 +688,8 @@ def extract_tool_calls_streaming( else: # if there is no tool calls if tool_call_portion is None: - # if there's text but not tool calls, send that - otherwise None to skip chunk + # if there's text but not tool calls, send that - + # otherwise None to skip chunk delta = DeltaMessage( content=delta_text ) if text_portion is not None else None @@ -656,13 +697,12 @@ def extract_tool_calls_streaming( else: # now we have the portion to parse as tool call. if text_portion is not None: - logger.debug( - f'Also, will send text portion {text_portion}' - ) + logger.debug(f'Also, will send text portion ' + f'{text_portion}') logger.debug( - f'Trying to parse current tool call with ID {self.current_tool_id}' - ) + f'Trying to parse current tool call with ID ' + f'{self.current_tool_id}') if len(self.prev_tool_call_arr ) <= self.current_tool_id: self.prev_tool_call_arr.append({}) @@ -676,23 +716,22 @@ def extract_tool_calls_streaming( ) # arguments, if any, in current dict logger.debug( - f'Diffing old arguments {prev_arguments} against new ones {cur_arguments}' - ) + f'Diffing old arguments {prev_arguments} ' + f'against new ones {cur_arguments}') if not cur_arguments and not prev_arguments: logger.debug( - f'Skipping text {delta_text} - no arguments!' + f'Skipping text {delta_text} - no arguments' ) delta = None elif not cur_arguments and prev_arguments: logger.error( - 'INVARIANT - impossible to have arguments reset mid-call' - ) + 'INVARIANT - impossible to have arguments ' + 'reset mid-call') delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug( - f'Finding {delta_text} in {cur_arguments_json}' - ) + logger.debug(f'Finding {delta_text} in ' + f'{cur_arguments_json}') arguments_delta = cur_arguments_json[: cur_arguments_json .index( @@ -701,8 +740,8 @@ def extract_tool_calls_streaming( len(delta_text )] logger.debug( - f'First tokens in arguments received: {arguments_delta}' - ) + f'First tokens in arguments received:' + f' {arguments_delta}') delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -717,8 +756,8 @@ def extract_tool_calls_streaming( cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) logger.debug( - f"Searching for diff between \n{cur_args_json}\n{prev_args_json}" - ) + f"Searching for diff between " + f"\n{cur_args_json}\n{prev_args_json}") argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) logger.debug( @@ -735,7 +774,8 @@ def extract_tool_calls_streaming( else: delta = None - # handle saving the state for the current tool into the "prev" list for use in diffing for + # handle saving the state for the current tool into + # the "prev" list for use in diffing for # the next iteration if self.current_tool_id == len( self.prev_tool_call_arr) - 1: @@ -746,13 +786,13 @@ def extract_tool_calls_streaming( current_tool_call) # TODO REPLACE ME WITH TOOL CALL - #delta = DeltaMessage(content=delta_text) + # delta = DeltaMessage(content=delta_text) return delta except Exception as e: logger.error( f'Error trying to handle streaming tool call: {e}') logger.debug( - 'Skipping chunk as a result of tool streaming extraction error' - ) + 'Skipping chunk as a result of tool streaming extraction ' + 'error') return None # do not stream a delta. skip this token ID. From 28da76c396ec926b12dc05834227313992540f26 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 20:54:53 -0500 Subject: [PATCH 099/222] chore: refactor tool parsers structure to make it more maintainable --- vllm/entrypoints/openai/tool_parsers.py | 795 ------------------ .../openai/tool_parsers/__init__.py | 5 + .../tool_parsers/abstract_tool_parser.py | 64 ++ .../openai/tool_parsers/hermes_tool_parser.py | 349 ++++++++ .../tool_parsers/mistral_tool_parser.py | 321 +++++++ vllm/entrypoints/openai/tool_parsers/utils.py | 90 ++ 6 files changed, 829 insertions(+), 795 deletions(-) create mode 100644 vllm/entrypoints/openai/tool_parsers/__init__.py create mode 100644 vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/utils.py diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py index 864dff3d31c7f..b86b92ce9c057 100644 --- a/vllm/entrypoints/openai/tool_parsers.py +++ b/vllm/entrypoints/openai/tool_parsers.py @@ -1,798 +1,3 @@ -from vllm.entrypoints.openai.protocol import (ToolCall, FunctionCall, - ExtractedToolCallInformation, - DeltaToolCall, - InitialDeltaToolCall, - DeltaFunctionCall, DeltaMessage) from vllm.logger import init_logger -from typing import List, Dict, Optional, Union -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) -import json -import partial_json_parser -from partial_json_parser import Allow -import re -from vllm.entrypoints.openai.protocol import DeltaMessage logger = init_logger(__name__) - - -def find_common_prefix(s1: str, s2: str) -> str: - """ - Finds a common prefix that is shared between two strings, if there is one. - Order of arguments is NOT important. - - This function is provided as a UTILITY for extracting information from JSON - generated by partial_json_parser, to help in ensuring that the right tokens - are returned in streaming, so that close-quotes, close-brackets and - close-braces are not returned prematurely. - - e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> - '{"fruit": "ap' - """ - prefix = '' - min_length = min(len(s1), len(s2)) - for i in range(0, min_length): - if s1[i] == s2[i]: - prefix += s1[i] - else: - break - return prefix - - -def find_common_suffix(s1: str, s2: str) -> str: - """ - Finds a common suffix shared between two strings, if there is one. Order of - arguments is NOT important. - Stops when the suffix ends OR it hits an alphanumeric character - - e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' - """ - suffix = '' - min_length = min(len(s1), len(s2)) - for i in range(1, min_length + 1): - if s1[-i] == s2[-i] and not s1[-i].isalnum(): - suffix = s1[-i] + suffix - else: - break - return suffix - - -def extract_intermediate_diff(curr: str, old: str) -> str: - """ - Given two strings, extract the difference in the middle between two strings - that are known to have a common prefix and/or suffix. - - This function is provided as a UTILITY for extracting information from JSON - generated by partial_json_parser, to help in ensuring that the right tokens - are returned in streaming, so that close-quotes, close-brackets and - close-braces are not returned prematurely. The order of arguments IS - important - the new version of the partially-parsed JSON must be the first - argument, and the secnod argument must be from the previous generation. - - What it returns, is tokens that should be streamed to the client. - - e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') - -> 'ple' - - """ - suffix = find_common_suffix(curr, old) - - # prevent double-counting - s2_old = old - old = old[::-1].replace(suffix[::-1], '', 1)[::-1] - prefix = find_common_prefix(curr, old) - diff = curr - if len(suffix): - diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] - - if len(prefix): - diff = diff.replace( - prefix, '', - 1) # replace the prefix only once in case it's mirrored - - return diff - - -def find_all_indices(string, substring): - """ - Find all (starting) indices of a substring in a given string. Useful for - tool call extraction - """ - indices = [] - index = -1 - while True: - index = string.find(substring, index + 1) - if index == -1: - break - indices.append(index) - return indices - - -class ToolParser: - """ - Abstract ToolParser class that should not be used directly. Provided - properties and methods should be used in - derived classes. - """ - - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): - self.prev_tool_call_arr: List[Dict] = [] - # the index of the tool call that is currently being parsed - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [] - - self.model_tokenizer = tokenizer - - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: - """ - Static method that should be implemented for extracting tool calls from - a complete model-generated string. - Used for non-streaming responses where we have the entire model response - available before sending to the client. - Static because it's stateless. - """ - raise NotImplementedError( - 'AbstractToolParser.extract_tool_calls has not been implemented!') - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], - ) -> Union[DeltaMessage, None]: - """ - Instance method that should be implemented for extracting tool calls - from an incomplete response; for use when handling tool calls and - streaming. Has to be an instance method because it requires state - - the current text/ tokens/diffs, but also the information about what has - previously been parsed and extracted (see constructor) - """ - raise NotImplementedError( - 'AbstractToolParser.extract_tool_calls_streaming has not been ' - 'implemented!') - - -class MistralToolParser(ToolParser): - """ - Tool call parser for Mistral 7B Instruct v0.3, intended for use with the - examples/tool_chat_template_mistral.jinja template. There are several - IMPORTANT CAVEATS for this parser: - - The chat template is NOT official and does not work well if you try to - get the model to call 2+ tools at once without temperature=0. - Stick to only one tool call per generation, or set temp to 0 - as the chat template is not reliable with > 1 and the model - Will lose coherence. - - Mistral's tool call format, that this translates into an OpenAI - format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable - JSON parsing and serialization, we find-and-replace these with - DOUBLE QUOTES. To prevent tool call corruption / deserialization - failure, ensure that your tool calls and in particular your - ARGUMENTS never contain single or double quotes except as JSON - control characters. - - Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser - mistral are all set - """ - - # the bot_token is the token indicating tool call(s) follow. Tokens before - # this token will be parsed as content; and - # if not present, the entire response will be parsed as text content. - bot_token: str = '[TOOL_CALLS]' # string literal - bot_token_id: int = 5 # token ID thereof from the models' tokenizer - tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) - - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: - """ - Extract the tool calls from a complete model response. Requires - find-and-replacing single quotes with double quotes for JSON parsing, - make sure your tool call arguments don't ever include quotes! - """ - - logger.debug( - 'Trying to extract mistral tool calls from the following:') - logger.debug(model_output) - # Get the tool call token from the tokenizer - if MistralToolParser.bot_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - else: - try: - - # this will throw an exception if we can't find the tool call - # properly - raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output.replace(MistralToolParser.bot_token, - '') # remove BOT token - .replace("'", '"') # replace string quotes - )[0] - - # load the JSON, and then use it to build the Function and - # Tool Call - function_call_arr = json.loads(raw_tool_call) - tool_calls: List[ToolCall] = [ - ToolCall( - type='function', - function=FunctionCall( - name=raw_function_call['name'], - # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call['arguments']))) - for raw_function_call in function_call_arr - ] - content = model_output.split(MistralToolParser.bot_token)[0] - return ExtractedToolCallInformation( - tools_called=True, - tool_calls=tool_calls, - content=content if len(content) > 0 else None) - - except Exception as e: - logger.error("Error in extracting tool call from response: %s", - e) - print('ERROR', e) - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): - super().__init__(tokenizer) - - # initialize properties used for state when parsing tool calls in - # streaming mode - self.prev_tool_call_arr: List[Dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [ - ] # map what has been streamed for each tool so far to a list - - def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], - ) -> Union[DeltaMessage, None]: - - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool - if self.bot_token_id not in current_token_ids: - return DeltaMessage(content=delta_text) - - # if the tool call token ID IS in the tokens generated so far, that - # means we're parsing as tool calls now - else: - - # handle if we detected the BOT token which means the start of tool - # calling - if self.bot_token_id in delta_token_ids: - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - if len(delta_token_ids) == 1: - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - try: - - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - tool_call_message_portion = current_text.split( - self.bot_token)[1] - parsable_arr = tool_call_message_portion.replace('\'', '"') - - # logger.debug('parsing: %s', parsable_arr) - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads( - parsable_arr, flags) - - # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] - - # case: we are starting a new tool in the array - # -> array has nonzero length AND length has moved past curscor - if len(tool_call_arr) > 0 and len( - tool_call_arr) > self.current_tool_id + 1: - - # if we're moving on to a new call, first make sure we haven't missed anything in the previous - # one that was auto-generated due to JSON completions, but wasn't streamed to the client yet. - if self.current_tool_id >= 0: - diff: Union[str, - None] = current_tool_call.get('arguments') - if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], '') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.current_tool_initial_sent = False - self.streamed_args_for_tool.append('') - logger.debug('starting on new tool %d', - self.current_tool_id) - return delta - - # case: update an existing tool - this is handled below - elif len( - tool_call_arr - ) - 1 == self.current_tool_id and self.current_tool_id >= 0: - # logger.debug('update to tool %d', self.current_tool_id) - pass - - # if there is NOTHING in the array, e.g. if only the open - # bracket was streamed yet - else: - # logger.debug('No tool call detected yet!') - return None - - # if the current tool initial data incl. the id, type=function - # and idx not sent, send that - if not self.current_tool_initial_sent: - logger.debug('Sending InitialDeltaToolCall') - self.current_tool_initial_sent = True - delta = DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - - # if the current tool name hasn't been sent, send if available - # - otherwise no chunks - elif not self.current_tool_name_sent: - function_name = current_tool_call.get('name') - if function_name: - logger.debug( - f'Sending DeltaToolCall with function name ' - f'{function_name}!') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None - - # now we know we're on the same tool call and we're streaming - # arguments - else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') - cur_arguments = current_tool_call.get('arguments') - - new_text = delta_text.replace('\'', '"') - - if not cur_arguments and not prev_arguments: - - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - 'INVARIANT - impossible to have arguments reset ' - 'mid-arguments') - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug( - f'Finding {new_text} in |{cur_arguments_json}|') - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index(new_text) + - len(new_text)] - logger.debug(f'First tokens in arguments received: ' - f'{arguments_delta}') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug( - f'Searching for diff between \n{cur_args_json}\n' - f'{prev_args_json}') - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug(f'got arguments diff: {argument_diff}') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta - - except Exception as e: - logger.error( - f'Error trying to handle streaming tool call: {e}') - logger.debug( - 'Skipping chunk as a result of tool streaming extraction ' - 'error') - return None - - -class Hermes2ProToolParser(ToolParser): - tool_call_start_token: str = '' - tool_call_end_token: str = '' - - # regex to match between and OR between - # and EOS (happens sometimes :)) - tool_call_regex = re.compile( - r'(.*?)|(.*)', re.DOTALL) - scratch_pad_regex = re.compile(r'(.*?)', - re.DOTALL) - - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: - - # sanity check; avoid unnecessary processing - if Hermes2ProToolParser.tool_call_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - else: - - try: - # there are two possible captures - between tags, or between a - # tag and end-of-string so the result of - # findall is an array of tuples where one is a function call and - # the other is None - function_call_tuples = ( - Hermes2ProToolParser.tool_call_regex.findall(model_output)) - - # load the JSON, and then use it to build the Function and - # Tool Call - raw_function_calls = [ - json.loads(match[0] if match[0] else match[1]) - for match in function_call_tuples - ] - tool_calls = [ - ToolCall( - type='function', - function=FunctionCall( - name=function_call['name'], - # function call args are JSON but as a string - arguments=json.dumps(function_call['arguments']))) - for function_call in raw_function_calls - ] - - content = model_output[:model_output.find( - Hermes2ProToolParser.tool_call_start_token)] - return ExtractedToolCallInformation( - tools_called=True, - tool_calls=tool_calls, - content=content if content else None) - - except Exception as e: - logger.error("Error in extracting tool call from response %s", - e) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) - - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): - super().__init__(tokenizer) - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: List[Dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent = False - self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [ - ] # map what has been streamed for each tool so far to a list - - if not self.model_tokenizer: - raise ValueError( - 'The model tokenizer must be passed to the ToolParser ' - 'constructor during construction.') - self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ - ''] - self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ - ''] - if not self.tool_call_start_token_id or not self.tool_call_end_token_id: - raise RuntimeError( - 'Hermes 2 Pro Tool parser could not locate tool call start/end ' - 'tokens in the tokenizer!') - - def extract_tool_calls_streaming( - self, previous_text: str, current_text: str, delta_text: str, - previous_token_ids: List[int], current_token_ids: List[int], - delta_token_ids: List[int]) -> Union[DeltaMessage, None]: - - logger.debug(f'delta_text: {delta_text}') - logger.debug(f'delta_token_ids: {delta_token_ids}') - # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token_id not in current_token_ids: - logger.debug('No tool call tokens found!') - return DeltaMessage(content=delta_text) - - else: - try: - - # figure out where we are in the parsing by counting tool call - # start & end tags - prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) - - # a cheap case - we're generating text, NOT tool calls. - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count): - logger.debug( - 'Generating text content! skipping tool parsing.') - return DeltaMessage(content=delta_text) - - # most of the time, we're going in here - we need to do partial - # JSON parsing and build stuff. - else: - # flags for partial JSON parting. exported constants from - # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - - # if a new tool call is being started. unusual since - # normally the first "cheap case" will be hit. - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): - if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - text_portion = None - else: - tool_call_portion = None - text_portion = None - delta = None - - # set cursors and state appropriately - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.current_tool_initial_sent = False - self.streamed_args_for_tool.append('') - logger.debug( - f'Starting on a new tool {self.current_tool_id}') - - # if an existing tool call is being updated - the most - # common case! - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - text_portion = None - - # if the current tool call is being closed - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count > prev_tool_end_count): - logger.debug('Closing the current tool call!') - diff = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') - if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], '') - logger.debug( - f'Finishing tool and found diff that had not ' - f'been streamed yet: {diff}') - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - - else: - logger.error( - 'INVARIANT - invalid state trying to parse tool ' - 'calls (wtf?)') - delta = None - return delta - - logger.debug(f'Tool call portion: {tool_call_portion}') - current_tool_call = partial_json_parser.loads( - tool_call_portion, - flags) if tool_call_portion else None - logger.debug(f'Parsed tool call {current_tool_call}') - - # make sure to send the initial message first if we haven't - # already - with the tool ID - if not self.current_tool_initial_sent: - logger.debug('Sending InitialDeltaToolCall') - self.current_tool_initial_sent = True - return DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) - - # after that, make sure we send the function name before - # any arguments - elif not self.current_tool_name_sent: - function_name: Union[ - str, None] = current_tool_call.get('name') - if function_name: - logger.debug( - f'Sending DeltaToolCall with function name ' - f'{function_name}!') - self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - name=function_name). - model_dump(exclude_none=True)) - ]) - else: - return None - else: - # if there is no tool calls - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = DeltaMessage( - content=delta_text - ) if text_portion is not None else None - # now, the nitty-gritty of tool calls - else: - # now we have the portion to parse as tool call. - if text_portion is not None: - logger.debug(f'Also, will send text portion ' - f'{text_portion}') - - logger.debug( - f'Trying to parse current tool call with ID ' - f'{self.current_tool_id}') - if len(self.prev_tool_call_arr - ) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - logger.debug( - 'Pushed dummy value into tool call arr') - # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') - cur_arguments = current_tool_call.get( - 'arguments' - ) # arguments, if any, in current dict - - logger.debug( - f'Diffing old arguments {prev_arguments} ' - f'against new ones {cur_arguments}') - if not cur_arguments and not prev_arguments: - logger.debug( - f'Skipping text {delta_text} - no arguments' - ) - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - 'INVARIANT - impossible to have arguments ' - 'reset mid-call') - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug(f'Finding {delta_text} in ' - f'{cur_arguments_json}') - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index( - delta_text - ) + - len(delta_text - )] - logger.debug( - f'First tokens in arguments received:' - f' {arguments_delta}') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug( - f"Searching for diff between " - f"\n{cur_args_json}\n{prev_args_json}") - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug( - f'Got argument diff: {argument_diff}') - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - - # handle saving the state for the current tool into - # the "prev" list for use in diffing for - # the next iteration - if self.current_tool_id == len( - self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append( - current_tool_call) - - # TODO REPLACE ME WITH TOOL CALL - # delta = DeltaMessage(content=delta_text) - return delta - - except Exception as e: - logger.error( - f'Error trying to handle streaming tool call: {e}') - logger.debug( - 'Skipping chunk as a result of tool streaming extraction ' - 'error') - return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000..64a33a6d4eded --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,5 @@ +from .abstract_tool_parser import ToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .mistral_tool_parser import MistralToolParser + +__all__ = ['ToolParser', 'Hermes2ProToolParser', 'MistralToolParser'] \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000000000..e6c1a2809f2e6 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,64 @@ +from typing import Optional, Union, List, Dict + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ + AutoTokenizer + +from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ + DeltaMessage +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + 'AbstractToolParser.extract_tool_calls has not been implemented!') + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current text/ tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + 'AbstractToolParser.extract_tool_calls_streaming has not been ' + 'implemented!') diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000000000..72a8a96b9fc73 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,349 @@ +import json +import re +from typing import Optional, Union, List, Dict + +import partial_json_parser +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ + AutoTokenizer + +from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ + ToolCall, FunctionCall, DeltaMessage, DeltaToolCall, DeltaFunctionCall, \ + InitialDeltaToolCall +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class Hermes2ProToolParser(ToolParser): + tool_call_start_token: str = '' + tool_call_end_token: str = '' + + # regex to match between and OR between + # and EOS (happens sometimes :)) + tool_call_regex = re.compile( + r'(.*?)|(.*)', re.DOTALL) + scratch_pad_regex = re.compile(r'(.*?)', + re.DOTALL) + + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if Hermes2ProToolParser.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + Hermes2ProToolParser.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type='function', + function=FunctionCall( + name=function_call['name'], + # function call args are JSON but as a string + arguments=json.dumps(function_call['arguments']))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output.find( + Hermes2ProToolParser.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): + super().__init__(tokenizer) + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + if not self.model_tokenizer: + raise ValueError( + 'The model tokenizer must be passed to the ToolParser ' + 'constructor during construction.') + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + ''] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + ''] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + 'Hermes 2 Pro Tool parser could not locate tool call start/end ' + 'tokens in the tokenizer!') + + def extract_tool_calls_streaming( + self, previous_text: str, current_text: str, delta_text: str, + previous_token_ids: List[int], current_token_ids: List[int], + delta_token_ids: List[int]) -> Union[DeltaMessage, None]: + + logger.debug(f'delta_text: {delta_text}') + logger.debug(f'delta_token_ids: {delta_token_ids}') + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug('No tool call tokens found!') + return DeltaMessage(content=delta_text) + + else: + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # a cheap case - we're generating text, NOT tool calls. + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug( + 'Generating text content! skipping tool parsing.') + return DeltaMessage(content=delta_text) + + # most of the time, we're going in here - we need to do partial + # JSON parsing and build stuff. + else: + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # if a new tool call is being started. unusual since + # normally the first "cheap case" will be hit. + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + else: + tool_call_portion = None + text_portion = None + delta = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append('') + logger.debug( + f'Starting on a new tool {self.current_tool_id}') + + # if an existing tool call is being updated - the most + # common case! + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # if the current tool call is being closed + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + logger.debug('Closing the current tool call!') + diff = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], '') + logger.debug( + f'Finishing tool and found diff that had not ' + f'been streamed yet: {diff}') + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + else: + logger.error( + 'INVARIANT - invalid state trying to parse tool ' + 'calls (wtf?)') + delta = None + return delta + + logger.debug(f'Tool call portion: {tool_call_portion}') + current_tool_call = partial_json_parser.loads( + tool_call_portion, + flags) if tool_call_portion else None + logger.debug(f'Parsed tool call {current_tool_call}') + + # make sure to send the initial message first if we haven't + # already - with the tool ID + if not self.current_tool_initial_sent: + logger.debug('Sending InitialDeltaToolCall') + self.current_tool_initial_sent = True + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # after that, make sure we send the function name before + # any arguments + elif not self.current_tool_name_sent: + function_name: Union[ + str, None] = current_tool_call.get('name') + if function_name: + logger.debug( + f'Sending DeltaToolCall with function name ' + f'{function_name}!') + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name). + model_dump(exclude_none=True)) + ]) + else: + return None + else: + # if there is no tool calls + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage( + content=delta_text + ) if text_portion is not None else None + # now, the nitty-gritty of tool calls + else: + # now we have the portion to parse as tool call. + if text_portion is not None: + logger.debug(f'Also, will send text portion ' + f'{text_portion}') + + logger.debug( + f'Trying to parse current tool call with ID ' + f'{self.current_tool_id}') + if len(self.prev_tool_call_arr + ) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + logger.debug( + 'Pushed dummy value into tool call arr') + # main logic for tool parsing here + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') + cur_arguments = current_tool_call.get( + 'arguments' + ) # arguments, if any, in current dict + + logger.debug( + f'Diffing old arguments {prev_arguments} ' + f'against new ones {cur_arguments}') + if not cur_arguments and not prev_arguments: + logger.debug( + f'Skipping text {delta_text} - no arguments' + ) + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + 'INVARIANT - impossible to have arguments ' + 'reset mid-call') + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug(f'Finding {delta_text} in ' + f'{cur_arguments_json}') + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index( + delta_text + ) + + len(delta_text + )] + logger.debug( + f'First tokens in arguments received:' + f' {arguments_delta}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug( + f"Searching for diff between " + f"\n{cur_args_json}\n{prev_args_json}") + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug( + f'Got argument diff: {argument_diff}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for + # the next iteration + if self.current_tool_id == len( + self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append( + current_tool_call) + + # TODO REPLACE ME WITH TOOL CALL + # delta = DeltaMessage(content=delta_text) + return delta + + except Exception as e: + logger.error( + f'Error trying to handle streaming tool call: {e}') + logger.debug( + 'Skipping chunk as a result of tool streaming extraction ' + 'error') + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000..5fd475a7e5f30 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,321 @@ +import json +import re +from typing import List, Optional, Union, Dict + +import partial_json_parser +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ + AutoTokenizer + +from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ + ToolCall, FunctionCall, DeltaMessage, DeltaToolCall, DeltaFunctionCall, \ + InitialDeltaToolCall +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. There are several + IMPORTANT CAVEATS for this parser: + - The chat template is NOT official and does not work well if you try to + get the model to call 2+ tools at once without temperature=0. + Stick to only one tool call per generation, or set temp to 0 + as the chat template is not reliable with > 1 and the model + Will lose coherence. + - Mistral's tool call format, that this translates into an OpenAI + format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable + JSON parsing and serialization, we find-and-replace these with + DOUBLE QUOTES. To prevent tool call corruption / deserialization + failure, ensure that your tool calls and in particular your + ARGUMENTS never contain single or double quotes except as JSON + control characters. + + Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser + mistral are all set + """ + + # the bot_token is the token indicating tool call(s) follow. Tokens before + # this token will be parsed as content; and + # if not present, the entire response will be parsed as text content. + bot_token: str = '[TOOL_CALLS]' # string literal + bot_token_id: int = 5 # token ID thereof from the models' tokenizer + tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) + + @staticmethod + def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + logger.debug( + 'Trying to extract mistral tool calls from the following:') + logger.debug(model_output) + # Get the tool call token from the tokenizer + if MistralToolParser.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + else: + try: + + # this will throw an exception if we can't find the tool call + # properly + raw_tool_call = MistralToolParser.tool_call_regex.findall( + model_output.replace(MistralToolParser.bot_token, + '') # remove BOT token + .replace("'", '"') # replace string quotes + )[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type='function', + function=FunctionCall( + name=raw_function_call['name'], + # function call args are JSON but as a string + arguments=json.dumps( + raw_function_call['arguments']))) + for raw_function_call in function_call_arr + ] + content = model_output.split(MistralToolParser.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", + e) + print('ERROR', e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token_id not in current_token_ids: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + else: + + # handle if we detected the BOT token which means the start of tool + # calling + if self.bot_token_id in delta_token_ids: + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + if len(delta_token_ids) == 1: + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + tool_call_message_portion = current_text.split( + self.bot_token)[1] + parsable_arr = tool_call_message_portion.replace('\'', '"') + + # logger.debug('parsing: %s', parsable_arr) + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + # case: we are starting a new tool in the array + # -> array has nonzero length AND length has moved past curscor + if len(tool_call_arr) > 0 and len( + tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we haven't missed anything in the previous + # one that was auto-generated due to JSON completions, but wasn't streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, + None] = current_tool_call.get('arguments') + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], '') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append('') + logger.debug('starting on new tool %d', + self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + elif len( + tool_call_arr + ) - 1 == self.current_tool_id and self.current_tool_id >= 0: + # logger.debug('update to tool %d', self.current_tool_id) + pass + + # if there is NOTHING in the array, e.g. if only the open + # bracket was streamed yet + else: + # logger.debug('No tool call detected yet!') + return None + + # if the current tool initial data incl. the id, type=function + # and idx not sent, send that + if not self.current_tool_initial_sent: + logger.debug('Sending InitialDeltaToolCall') + self.current_tool_initial_sent = True + delta = DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # if the current tool name hasn't been sent, send if available + # - otherwise no chunks + elif not self.current_tool_name_sent: + function_name = current_tool_call.get('name') + if function_name: + logger.debug( + f'Sending DeltaToolCall with function name ' + f'{function_name}!') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get('arguments') + cur_arguments = current_tool_call.get('arguments') + + new_text = delta_text.replace('\'', '"') + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + 'INVARIANT - impossible to have arguments reset ' + 'mid-arguments') + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug( + f'Finding {new_text} in |{cur_arguments_json}|') + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index(new_text) + + len(new_text)] + logger.debug(f'First tokens in arguments received: ' + f'{arguments_delta}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug( + f'Searching for diff between \n{cur_args_json}\n' + f'{prev_args_json}') + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug(f'got arguments diff: {argument_diff}') + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error( + f'Error trying to handle streaming tool call: {e}') + logger.debug( + 'Skipping chunk as a result of tool streaming extraction ' + 'error') + return None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000..50eaa906377bb --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,90 @@ +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + # prevent double-counting + s2_old = old + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + diff = diff.replace( + prefix, '', + 1) # replace the prefix only once in case it's mirrored + + return diff + + +def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices From 76a27bd4de1b8074db800778ccd20c9afe9f1559 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 5 Aug 2024 21:32:52 -0500 Subject: [PATCH 100/222] fix(tests): raise a valueError that was being passed instead of raised --- vllm/entrypoints/openai/protocol.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 013b2ceb184a1..b7538b39603a7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -322,6 +322,9 @@ def validate_stream_options(cls, values): @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, From 990a0e58c6ab5440464ab8c026bb60f2049962fd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 18:29:15 -0500 Subject: [PATCH 101/222] fix(PEP8): openai chat completion client with tools --- ...penai_chat_completion_client_with_tools.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index bddf97869f59e..909bf7fe681eb 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -32,7 +32,8 @@ "type": "string", "description": - "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'" + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" }, "unit": { "type": "string", @@ -55,7 +56,7 @@ "role": "user", "content": - "Can you tell me what the temperate will be in Dallas and San Francisco, in fahrenheit?" + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" }] chat_completion = client.chat.completions.create(messages=messages, @@ -87,19 +88,17 @@ if chunk.choices[0].delta.tool_calls[0].index != tool_call_idx: if tool_call_idx >= 0: print( - f'streamed tool call arguments: {arguments[tool_call_idx]}\n\n' + f'streamed tool call arguments: {arguments[tool_call_idx]}' ) tool_call_idx = chunk.choices[0].delta.tool_calls[0].index arguments.append('') if chunk.choices[0].delta.tool_calls[0].id: - print( - f'streamed tool call id: {chunk.choices[0].delta.tool_calls[0].id}' - ) + print(f'streamed tool call id: ' + f'{chunk.choices[0].delta.tool_calls[0].id}') if chunk.choices[0].delta.tool_calls[0].function: if chunk.choices[0].delta.tool_calls[0].function.name: - print( - f'streamed tool call name: {chunk.choices[0].delta.tool_calls[0].function.name}' - ) + print(f'streamed tool call name: ' + f'{chunk.choices[0].delta.tool_calls[0].function.name}') if chunk.choices[0].delta.tool_calls[0].function.arguments: arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[ 0].function.arguments @@ -117,7 +116,8 @@ # Now, simulate a tool call def get_current_weather(city: str, state: str, unit: 'str'): - return "The weather in Dallas, Texas is 85 degrees fahrenheit. It is partly cloudly, with highs in the 90's." + return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's.") available_tools = {"get_current_weather": get_current_weather} From 751d5a868f9a6a3d80f30cd1ee3145e799da0f11 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 18:31:46 -0500 Subject: [PATCH 102/222] fix(PEP8): chat_utils --- vllm/entrypoints/chat_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 055092ca55f52..61dbb224d9ba4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,7 +1,7 @@ import codecs from dataclasses import dataclass, field from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, cast, Any, Union +from typing import Awaitable, Iterable, List, Optional, cast, Union # yapf conflicts with isort for this block # yapf: disable @@ -12,7 +12,6 @@ # yapf: enable # pydantic needs the TypedDict from typing_extensions from transformers import PreTrainedTokenizer -from typing_extensions import Required, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger @@ -146,9 +145,8 @@ def parse_chat_message_content( content = message.get("content") tool_call_id = message.get('content') tool_calls = message.get('tool_calls') - name = message.get( - 'name', '' - ) # no longer used by OpenAI, was formerly. used for tool calls by some models still + # no longer used by OpenAI, but some models still use it for tool calls. + name = message.get('name', '') # empty case if content is None and tool_calls is None: From 834969fb537e35fa11d41acb10594bc0e12a3f74 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 18:34:28 -0500 Subject: [PATCH 103/222] fix(PEP8): protocol.py --- vllm/entrypoints/openai/protocol.py | 37 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b7538b39603a7..a257478c187ca 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,7 +1,7 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Any, Dict, List, Literal, Optional, Union, Type, final +from typing import Any, Dict, List, Literal, Optional, Union, final import torch from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -167,8 +167,9 @@ class ChatCompletionRequest(OpenAIBaseModel): tools: Optional[List[ChatCompletionToolsParam]] = None tool_choice: Optional[Union[Union[Literal["none"], Literal["auto"]], ChatCompletionNamedToolChoiceParam]] = "none" - parallel_tool_calls: Optional[ - bool] = False # NOTE this will be ignored by VLLM as the behavior is determined by the model + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -346,7 +347,8 @@ def check_guided_decoding_count(cls, data): @classmethod def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, default to "auto" tool_choice + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice if "tool_choice" not in data and "tools" in data: data["tool_choice"] = "auto" @@ -358,38 +360,38 @@ def check_tool_usage(cls, data): raise ValueError( "When using `tool_choice`, `tools` must be set.") - # make sure that tool choice is either a named tool OR that it's set to "auto" + # make sure that tool choice is either a named tool + # OR that it's set to "auto" if data["tool_choice"] != "auto" and not isinstance( data["tool_choice"], dict): raise ValueError( - "`tool_choice` must either be a named tool or \"auto\". `tool_choice=\"none\" is not supported." - ) + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") - # ensure that if "tool_choice" is specified as an object, it matches a valid tool + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool if isinstance(data["tool_choice"], dict): valid_tool = False specified_function = data["tool_choice"]["function"] if not specified_function: return ValueError( 'Incorrectly formatted `tool_choice`. Should be like ' - + - '`{"type": "function", "function": {"name": "my_function"}}`' - ) + '`{"type": "function",' + ' "function": {"name": "my_function"}}`') specified_function_name = specified_function["name"] if not specified_function_name: return ValueError( 'Incorrectly formatted `tool_choice`. Should be like ' - + - '`{"type": "function", "function": {"name": "my_function"}}`' - ) + '`{"type": "function", ' + '"function": {"name": "my_function"}}`') for tool in data['tools']: if tool["function"]["name"] == specified_function_name: valid_tool = True break if not valid_tool: return ValueError( - "The tool specified in `tool_choice` does not match any of the specified `tools`" - ) + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") # TODO validate tools return data @@ -465,7 +467,7 @@ class CompletionRequest(OpenAIBaseModel): ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), + description="If specified, the output will follow the JSON schema.", ) guided_regex: Optional[str] = Field( default=None, @@ -865,3 +867,4 @@ class DetokenizeRequest(OpenAIBaseModel): class DetokenizeResponse(OpenAIBaseModel): prompt: str +ser \ No newline at end of file From fcd69d753d7f89a860943806285130c9ae6eb617 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 19:18:24 -0500 Subject: [PATCH 104/222] fix(PEP8): serving_chat & fix typo in protocol --- vllm/entrypoints/openai/protocol.py | 1 - vllm/entrypoints/openai/serving_chat.py | 78 ++++++++++++------------- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a257478c187ca..76655a5f1b5ef 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -867,4 +867,3 @@ class DetokenizeRequest(OpenAIBaseModel): class DetokenizeResponse(OpenAIBaseModel): prompt: str -ser \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index fabaf9e8042dd..d4e42e5ca21dd 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,17 +1,15 @@ import time import json +from typing import Type from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional, Type) -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional, Union, Sequence as GenericSequence) + Optional, Sequence as GenericSequence) from typing import Union from fastapi import Request from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient -from vllm.entrypoints.chat_utils import (ConversationMessage, - load_chat_template, +from vllm.entrypoints.chat_utils import (load_chat_template, parse_chat_message_content, ConversationMessage) from vllm.entrypoints.logger import RequestLogger @@ -77,9 +75,9 @@ def __init__(self, self.enable_auto_tools: bool = enable_auto_tools or False if self.enable_auto_tools: logger.info( - '"Auto" tool choice has been enabled please note that while the ' - 'parallel_tool_calls client option is preset for compatibility ' - 'reasons, it will be ignored.') + '"Auto" tool choice has been enabled please note that while' + ' the parallel_tool_calls client option is preset for ' + 'compatibility reasons, it will be ignored.') self.tool_parser: Optional[Type[ToolParser]] = None if self.enable_auto_tools: @@ -166,12 +164,12 @@ async def create_chat_completion( return self.create_error_response( 'tool_choice = "required" is not supported!') - # "auto" tools requires --enable-api-tools --enable-auto-tool-choice and --tool-parser + # "auto" tools requires --enable-auto-tool-choice and --tool-parser if request.tool_choice == 'auto' and not ( self.enable_auto_tools and self.tool_parser is not None): return self.create_error_response( - '"auto" tool choice requires --enable-auto-tool-choice and --tool-parser to be set' - ) + '"auto" tool choice requires ' + '--enable-auto-tool-choice and --tool-parser to be set') request_id = f"chat-{random_uuid()}" try: @@ -237,9 +235,8 @@ async def create_chat_completion( conversation, tokenizer) if not isinstance(generator, ChatCompletionResponse): - raise ValueError( - 'Expected generator to be instance of ChatCompletionResponse' - ) + raise ValueError('Expected generator to be instance of ' + 'ChatCompletionResponse') return generator except ValueError as e: @@ -297,7 +294,7 @@ async def chat_completion_stream_generator( model=model_name) if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): + if request.stream_options.continuous_usage_stats: prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, @@ -379,17 +376,13 @@ async def chat_completion_stream_generator( delta_message: Optional[DeltaMessage] = None # handle streaming deltas for tools with tool_choice - if request.tool_choice and type( - request.tool_choice - ) is ChatCompletionNamedToolChoiceParam: + if (request.tool_choice and type(request.tool_choice) is + ChatCompletionNamedToolChoiceParam): delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - function=DeltaFunctionCall( - name=request.tool_choice.function.name, - arguments=delta_text), - index= - i # note: ok to hard-code to 0 since named tool calling doesn't support arrays - ) + DeltaToolCall(function=DeltaFunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text), + index=i) ]) # handle streaming deltas for tools with tool_choice @@ -398,14 +391,15 @@ async def chat_completion_stream_generator( or request.tool_choice == 'auto') and self.enable_auto_tools): - delta_message = tool_parser.extract_tool_calls_streaming( - previous_text=previous_texts[i], - current_text=output.text, - delta_text=delta_text, - previous_token_ids=output. - token_ids[:-1 * len(delta_token_ids)], - current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids) + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_texts[i], + current_text=output.text, + delta_text=delta_text, + previous_token_ids=output. + token_ids[:-1 * len(delta_token_ids)], + current_token_ids=output.token_ids, + delta_token_ids=delta_token_ids)) else: delta_message = DeltaMessage(content=delta_text) @@ -413,7 +407,9 @@ async def chat_completion_stream_generator( previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - # if the message delta is None (e.g. because it was a "control token" for tool calls, then + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then # get the next token without streaming a chunk if delta_message is None: continue @@ -518,7 +514,8 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True - # once the final token is handled, if stream_options.include_usage is sent, send the usage + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): final_usage = UsageInfo( @@ -611,7 +608,8 @@ async def chat_completion_full_generator( ]) tools_called = True - # if the request doesn't use tool choice OR specifies to not use a tool + # if the request doesn't use tool choice + # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": message = ChatMessage(role=role, content=output.text) @@ -631,14 +629,16 @@ async def chat_completion_full_generator( tool_calls=tool_call_info.tool_calls) else: - # FOR NOW make it a chat message; we will have to detect the type to make it later. + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. message = ChatMessage(role=role, content=output.text) # undetermined case that is still important to handle else: logger.error( - 'Error in chat_completion_full_generator - cannot determine if tools should ' - 'be extracted. Returning a standard chat completion.') + 'Error in chat_completion_full_generator - cannot determine' + ' if tools should be extracted. Returning a standard chat ' + 'completion.') message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( From c4486376b1d5f6a281a1e792032a658cdbcc1631 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 19:33:03 -0500 Subject: [PATCH 105/222] fix(PEP8): Hermes Tool parser --- .../openai/tool_parsers/hermes_tool_parser.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 72a8a96b9fc73..02ce5e8af2400 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -109,8 +109,8 @@ def extract_tool_calls_streaming( previous_token_ids: List[int], current_token_ids: List[int], delta_token_ids: List[int]) -> Union[DeltaMessage, None]: - logger.debug(f'delta_text: {delta_text}') - logger.debug(f'delta_token_ids: {delta_token_ids}') + logger.debug('delta_text: %s', delta_text) + logger.debug('delta_token_ids: %s', delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_call_start_token_id not in current_token_ids: logger.debug('No tool call tokens found!') @@ -163,8 +163,8 @@ def extract_tool_calls_streaming( self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append('') - logger.debug( - f'Starting on a new tool {self.current_tool_id}') + logger.debug('Starting on a new tool %s', + self.current_tool_id) # if an existing tool call is being updated - the most # common case! @@ -185,8 +185,8 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ self.current_tool_id], '') logger.debug( - f'Finishing tool and found diff that had not ' - f'been streamed yet: {diff}') + 'Finishing tool and found diff that had not ' + 'been streamed yet: %s', diff) return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -201,11 +201,12 @@ def extract_tool_calls_streaming( delta = None return delta - logger.debug(f'Tool call portion: {tool_call_portion}') + logger.debug('Tool call portion: %s', tool_call_portion + or '') current_tool_call = partial_json_parser.loads( - tool_call_portion, + tool_call_portion or '{}', flags) if tool_call_portion else None - logger.debug(f'Parsed tool call {current_tool_call}') + logger.debug('Parsed tool call %s', current_tool_call) # make sure to send the initial message first if we haven't # already - with the tool ID @@ -225,8 +226,8 @@ def extract_tool_calls_streaming( str, None] = current_tool_call.get('name') if function_name: logger.debug( - f'Sending DeltaToolCall with function name ' - f'{function_name}!') + 'Sending DeltaToolCall with function name %s', + function_name) self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -248,12 +249,13 @@ def extract_tool_calls_streaming( else: # now we have the portion to parse as tool call. if text_portion is not None: - logger.debug(f'Also, will send text portion ' - f'{text_portion}') + logger.debug( + 'Also, will send text portion: %s', + text_portion) logger.debug( - f'Trying to parse current tool call with ID ' - f'{self.current_tool_id}') + 'Trying to parse current tool call with ID %s', + self.current_tool_id) if len(self.prev_tool_call_arr ) <= self.current_tool_id: self.prev_tool_call_arr.append({}) @@ -266,13 +268,13 @@ def extract_tool_calls_streaming( 'arguments' ) # arguments, if any, in current dict - logger.debug( - f'Diffing old arguments {prev_arguments} ' - f'against new ones {cur_arguments}') + logger.debug('diffing old arguments: %s', + prev_arguments) + logger.debug('against new ones: %s', cur_arguments) + if not cur_arguments and not prev_arguments: - logger.debug( - f'Skipping text {delta_text} - no arguments' - ) + logger.debug('Skipping text %s - no arguments', + delta_text) delta = None elif not cur_arguments and prev_arguments: logger.error( @@ -281,8 +283,8 @@ def extract_tool_calls_streaming( delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug(f'Finding {delta_text} in ' - f'{cur_arguments_json}') + logger.debug('finding %s in %s', delta_text, + cur_arguments_json) arguments_delta = cur_arguments_json[: cur_arguments_json .index( @@ -291,8 +293,8 @@ def extract_tool_calls_streaming( len(delta_text )] logger.debug( - f'First tokens in arguments received:' - f' {arguments_delta}') + 'First tokens in arguments received: %s', + arguments_delta) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -306,13 +308,13 @@ def extract_tool_calls_streaming( elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug( - f"Searching for diff between " - f"\n{cur_args_json}\n{prev_args_json}") + logger.debug('Searching for dif between\n%s', + cur_args_json) + logger.debug('and\n%s', prev_args_json) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) - logger.debug( - f'Got argument diff: {argument_diff}') + logger.debug('got argument diff %s', + argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -341,8 +343,8 @@ def extract_tool_calls_streaming( return delta except Exception as e: - logger.error( - f'Error trying to handle streaming tool call: {e}') + logger.error('Error trying to handle streaming tool call: %s', + e) logger.debug( 'Skipping chunk as a result of tool streaming extraction ' 'error') From bd0b3a72fbadcc24588d3531e33395d3c098c55c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 19:38:25 -0500 Subject: [PATCH 106/222] fix(PEP8): format files with ./format --fix --- ...penai_chat_completion_client_with_tools.py | 3 +- vllm/entrypoints/chat_utils.py | 14 ++--- vllm/entrypoints/openai/protocol.py | 12 ++-- vllm/entrypoints/openai/serving_chat.py | 29 +++++---- .../tool_parsers/abstract_tool_parser.py | 10 +-- .../openai/tool_parsers/hermes_tool_parser.py | 22 ++++--- .../tool_parsers/mistral_tool_parser.py | 61 +++++++++++-------- vllm/entrypoints/openai/tool_parsers/utils.py | 2 +- .../guided_decoding/__init__.py | 3 +- .../guided_decoding/outlines_decoding.py | 10 +-- 10 files changed, 89 insertions(+), 77 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 909bf7fe681eb..4c177247986bc 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -1,6 +1,7 @@ -from openai import OpenAI import json +from openai import OpenAI + # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 61dbb224d9ba4..1e671c586d8db 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,25 +1,23 @@ import codecs from dataclasses import dataclass, field from functools import lru_cache -from typing import Awaitable, Iterable, List, Optional, cast, Union +from typing import Awaitable, Iterable, List, Optional, Union, cast # yapf conflicts with isort for this block # yapf: disable -from openai.types.chat import ChatCompletionContentPartImageParam - -from openai.types.chat import ChatCompletionContentPartTextParam - +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions from transformers import PreTrainedTokenizer from vllm.config import ModelConfig +from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam, + ChatCompletionMessageParam, + ConversationMessage) from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image -from vllm.entrypoints.openai.protocol import (ChatCompletionMessageParam, - ChatCompletionContentPartParam, - ConversationMessage) logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 76655a5f1b5ef..14fbd9645a54c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,19 +4,19 @@ from typing import Any, Dict, List, Literal, Optional, Union, final import torch +from openai.types.chat import ChatCompletionContentPartParam +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from pydantic import BaseModel, ConfigDict, Field, model_validator from transformers import PreTrainedTokenizer -from typing_extensions import Annotated, TypedDict, Required +from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid -from openai.types.chat import (ChatCompletionContentPartParam, - ChatCompletionMessageParam as - OpenAIChatCompletionMessageParam, - ChatCompletionContentPartParam as - OpenAIChatCompletionContentPartParam) class CustomChatCompletionMessageParam(TypedDict, total=False): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d4e42e5ca21dd..06cbf4a17abb4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,28 +1,33 @@ -import time import json -from typing import Type +import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional, Sequence as GenericSequence) -from typing import Union + Optional) +from typing import Sequence as GenericSequence +from typing import Type, Union + from fastapi import Request +from jinja2 import Environment, FileSystemLoader, select_autoescape from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient -from vllm.entrypoints.chat_utils import (load_chat_template, - parse_chat_message_content, - ConversationMessage) +from vllm.entrypoints.chat_utils import (ConversationMessage, + load_chat_template, + parse_chat_message_content) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo, DeltaToolCall, DeltaFunctionCall) + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, + MistralToolParser, + ToolParser) from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict @@ -32,12 +37,6 @@ log_tracing_disabled_warning) from vllm.utils import random_uuid -from vllm.entrypoints.openai.tool_parsers import (ToolParser, - MistralToolParser, - Hermes2ProToolParser) - -from jinja2 import Environment, FileSystemLoader, select_autoescape - env = Environment(loader=FileSystemLoader('./'), autoescape=select_autoescape()) diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index e6c1a2809f2e6..6892657f9de50 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,10 +1,10 @@ -from typing import Optional, Union, List, Dict +from typing import Dict, List, Optional, Union -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ - AutoTokenizer +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) -from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ - DeltaMessage +from vllm.entrypoints.openai.protocol import (DeltaMessage, + ExtractedToolCallInformation) from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 02ce5e8af2400..2859b84c613b6 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -1,17 +1,21 @@ import json import re -from typing import Optional, Union, List, Dict +from typing import Dict, List, Optional, Union import partial_json_parser from partial_json_parser.core.options import Allow -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ - AutoTokenizer - -from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ - ToolCall, FunctionCall, DeltaMessage, DeltaToolCall, DeltaFunctionCall, \ - InitialDeltaToolCall -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 5fd475a7e5f30..433db4b902d5f 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,17 +1,21 @@ import json import re -from typing import List, Optional, Union, Dict +from typing import Dict, List, Optional, Union import partial_json_parser from partial_json_parser.core.options import Allow -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, \ - AutoTokenizer - -from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation, \ - ToolCall, FunctionCall, DeltaMessage, DeltaToolCall, DeltaFunctionCall, \ - InitialDeltaToolCall -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + InitialDeltaToolCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) from vllm.logger import init_logger logger = init_logger(__name__) @@ -138,11 +142,11 @@ def extract_tool_calls_streaming( # handle if we detected the BOT token which means the start of tool # calling - if self.bot_token_id in delta_token_ids: + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): # if it's the only token, return None, so we don't send a chat # completion any don't send a control token - if len(delta_token_ids) == 1: - return None + return None # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending @@ -170,12 +174,14 @@ def extract_tool_calls_streaming( current_tool_call: Dict = tool_call_arr[self.current_tool_id] # case: we are starting a new tool in the array - # -> array has nonzero length AND length has moved past curscor + # -> array has > 0 length AND length has moved past cursor if len(tool_call_arr) > 0 and len( tool_call_arr) > self.current_tool_id + 1: - # if we're moving on to a new call, first make sure we haven't missed anything in the previous - # one that was auto-generated due to JSON completions, but wasn't streamed to the client yet. + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. if self.current_tool_id >= 0: diff: Union[str, None] = current_tool_call.get('arguments') @@ -234,8 +240,8 @@ def extract_tool_calls_streaming( function_name = current_tool_call.get('name') if function_name: logger.debug( - f'Sending DeltaToolCall with function name ' - f'{function_name}!') + 'Sending DeltaToolCall with function name %s', + function_name) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -266,14 +272,15 @@ def extract_tool_calls_streaming( delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug( - f'Finding {new_text} in |{cur_arguments_json}|') + logger.debug('finding %s in |%s|', new_text, + cur_arguments_json) + arguments_delta = cur_arguments_json[: cur_arguments_json .index(new_text) + len(new_text)] - logger.debug(f'First tokens in arguments received: ' - f'{arguments_delta}') + logger.debug('First tokens in arguments received: %s', + arguments_delta) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -286,12 +293,12 @@ def extract_tool_calls_streaming( elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug( - f'Searching for diff between \n{cur_args_json}\n' - f'{prev_args_json}') + logger.debug('Searching for diff between \n%s\n%s', + cur_args_json, prev_args_json) + argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) - logger.debug(f'got arguments diff: {argument_diff}') + logger.debug('got arguments diff: %s', argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -313,8 +320,8 @@ def extract_tool_calls_streaming( return delta except Exception as e: - logger.error( - f'Error trying to handle streaming tool call: {e}') + logger.error('Error trying to handle streaming tool call: %s', + e) logger.debug( 'Skipping chunk as a result of tool streaming extraction ' 'error') diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index 50eaa906377bb..52d2e6ed985fd 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -60,7 +60,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str: suffix = find_common_suffix(curr, old) # prevent double-counting - s2_old = old + #s2_old = old old = old[::-1].replace(suffix[::-1], '', 1)[::-1] prefix = find_common_prefix(curr, old) diff = curr diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index af70e00727e20..4e8b312b04e49 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -34,7 +34,8 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request) is CompletionRequest: return request - # user has chosen to not use any tool, OR is allowing the model to choose a tool. + # user has chosen to not use any tool, + # OR is allowing the model to choose a tool. if request.tool_choice == "none" or request.tool_choice == "auto": return request diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index ccb46d9537aae..1429df2c1d793 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -9,8 +9,8 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.openai.protocol import ( - ChatCompletionRequest, CompletionRequest, - ChatCompletionNamedToolChoiceParam) + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) @@ -82,14 +82,16 @@ def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: - # if the request is a chat completion request, AND the tool choice is a named tool choice, do guided decoding + # if the request is a chat completion request, AND the tool choice is a + # named tool choice, do guided decoding # using that tool as the JSON schema if isinstance(request, ChatCompletionRequest) and isinstance( request.tool_choice, ChatCompletionNamedToolChoiceParam): # Guided generation for tools/functions parameters if request.tool_choice.type == "function": for tool in request.tools: - if tool.type == "function" and tool.function.name == request.tool_choice.function.name: + if (tool.type == "function" and tool.function.name + == request.tool_choice.function.name): json = json_dumps(tool.function.parameters, sort_keys=True) return json, GuidedDecodingMode.JSON return None, None From 66049d8d4dab445d4be73ce8d99cca6d81d6cc4e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 20:45:25 -0500 Subject: [PATCH 107/222] fix: docs; allow specifying the tool_use huggingface template --- docs/source/serving/openai_compatible_server.md | 4 ++-- vllm/entrypoints/openai/serving_tokenization.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index f06414e7f4091..6df0136b5afd5 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -130,12 +130,12 @@ well-defined failure modes._ As such, it must be explicitly enabled when desired To enable this feature, you must set the following flags: * `--enable-api-tools` -- **mandatory** for Auto tool choice. tells vLLM that you want to enable tool templating and extraction. -* `--enable-auto-toolchoice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its' own tool scalls when it +* `--enable-auto-toolchoice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls.This argument can be set to `tool_use` if your model has a tool use chat template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) -from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here]() +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) * `--tool-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index c4350881a27a6..6a252670e9afa 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -40,7 +40,11 @@ def __init__( request_logger=request_logger) # If this is None we use the tokenizer's default chat template - self.chat_template = load_chat_template(chat_template) + # the list of commonly-used chat template names for HF named templates + hf_chat_templates: List[str] = ['default', 'tool_use'] + self.chat_template = chat_template \ + if chat_template in hf_chat_templates \ + else load_chat_template(chat_template) async def create_tokenize( self, From c10611130c69fbecee363ca00e0b78a46d613779 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 20:54:25 -0500 Subject: [PATCH 108/222] chore: formatting --- vllm/entrypoints/openai/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0c69c59b3cd96..7f5ded3e962d5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -197,7 +197,6 @@ async def create_chat_completion(request: ChatCompletionRequest, status_code=generator.code) # if streaming is requested, handle streaming - if request.stream: return StreamingResponse(content=generator, media_type="text/event-stream") From 643c7927cc84204d483e001608dc479cbeb0ed08 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 21:11:14 -0500 Subject: [PATCH 109/222] fix: mistral chat template formatting --- examples/tool_chat_template_mistral.jinja | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index 29efc61c2fd50..d2b1380745b85 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -1 +1,20 @@ -{{ bos_token }}{% set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}{% for message in messages %}{% if message['role'] == 'user' %}{% if message == user_messages[-1] %}{% if tools %}{{ '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST]' + message['content'] + '[/INST]' }}{% else %}{{ '[INST]' + message['content'] + '[/INST]' }}{% endif %}{% elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %}{{ '[TOOL_CALLS]' + message['tool_calls']|string + eos_token }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }}{% endif %}{% endfor %} \ No newline at end of file +{{- bos_token }} +{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- if message == user_messages[-1] %} + {%- if tools %} + {{- '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }} + {%- endif %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- else %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- endif %} + {%- elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %} + {{- '[TOOL_CALLS]' + message['tool_calls']|string + eos_token }} + {%- elif message['role'] == 'assistant' %} + {{- ' ' + message['content'] + ' ' + eos_token }} + {%- elif message['role'] == 'tool' %} + {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {%- endif %} +{%- endfor %} \ No newline at end of file From 55ece00ba40b4e3f60c5925a2680f15f5779a6da Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 21:20:00 -0500 Subject: [PATCH 110/222] feat: add official mistral 7B instruct v0.3 chat template --- .../tool_chat_template_mistral_official.jinja | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 examples/tool_chat_template_mistral_official.jinja diff --git a/examples/tool_chat_template_mistral_official.jinja b/examples/tool_chat_template_mistral_official.jinja new file mode 100644 index 0000000000000..3716dceb2d9bb --- /dev/null +++ b/examples/tool_chat_template_mistral_official.jinja @@ -0,0 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + {%- if not tools is defined %} + {%- set tools = none %} + {%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- ', "id": "' + tool_call.id + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} From 722501a4bf968c1c1089d24698978bd8ad55129e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 21:39:10 -0500 Subject: [PATCH 111/222] fix: patch official mistral template to handle vLLM-generated tool call IDs in an appropriate way --- examples/tool_chat_template_mistral_official.jinja | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/tool_chat_template_mistral_official.jinja b/examples/tool_chat_template_mistral_official.jinja index 3716dceb2d9bb..49855b6506f9f 100644 --- a/examples/tool_chat_template_mistral_official.jinja +++ b/examples/tool_chat_template_mistral_official.jinja @@ -57,10 +57,10 @@ {%- for tool_call in tool_calls %} {%- set out = tool_call.function|tojson %} {{- out[:-1] }} - {%- if not tool_call.id is defined or tool_call.id|length != 9 %} - {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} {%- endif %} - {{- ', "id": "' + tool_call.id + '"}' }} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} {%- if not loop.last %} {{- ", " }} {%- else %} @@ -76,10 +76,10 @@ {%- set content = message.content %} {%- endif %} {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} - {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %} - {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} {%- endif %} - {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} {%- else %} {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} {%- endif %} From a70b013b48d11a97ab57b1b8c8eb06113782ccb2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 21:40:04 -0500 Subject: [PATCH 112/222] fix: replace unofficial mistral chat template with official one --- examples/tool_chat_template_mistral.jinja | 96 ++++++++++++++++--- .../tool_chat_template_mistral_official.jinja | 86 ----------------- 2 files changed, 81 insertions(+), 101 deletions(-) delete mode 100644 examples/tool_chat_template_mistral_official.jinja diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index d2b1380745b85..49855b6506f9f 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -1,20 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + {%- if not tools is defined %} + {%- set tools = none %} + {%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + {{- bos_token }} -{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %} -{%- for message in messages %} - {%- if message['role'] == 'user' %} - {%- if message == user_messages[-1] %} - {%- if tools %} - {{- '[AVAILABLE_TOOLS]'+ tools|string + '[/AVAILABLE_TOOLS]' }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} {%- endif %} - {{- '[INST]' + message['content'] + '[/INST]' }} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} {%- else %} - {{- '[INST]' + message['content'] + '[/INST]' }} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} {%- endif %} - {%- elif message['role'] == 'assistant' and message['tool_calls'] and message['tool_calls']|length > 0 %} - {{- '[TOOL_CALLS]' + message['tool_calls']|string + eos_token }} - {%- elif message['role'] == 'assistant' %} - {{- ' ' + message['content'] + ' ' + eos_token }} - {%- elif message['role'] == 'tool' %} - {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} {%- endif %} -{%- endfor %} \ No newline at end of file +{%- endfor %} diff --git a/examples/tool_chat_template_mistral_official.jinja b/examples/tool_chat_template_mistral_official.jinja deleted file mode 100644 index 49855b6506f9f..0000000000000 --- a/examples/tool_chat_template_mistral_official.jinja +++ /dev/null @@ -1,86 +0,0 @@ -{%- if messages[0]["role"] == "system" %} - {%- set system_message = messages[0]["content"] %} - {%- set loop_messages = messages[1:] %} -{%- else %} - {%- set loop_messages = messages %} -{%- endif %} - {%- if not tools is defined %} - {%- set tools = none %} - {%- endif %} -{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} - -{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} - {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} - {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} - {%- endif %} -{%- endfor %} - -{{- bos_token }} -{%- for message in loop_messages %} - {%- if message["role"] == "user" %} - {%- if tools is not none and (message == user_messages[-1]) %} - {{- "[AVAILABLE_TOOLS] [" }} - {%- for tool in tools %} - {%- set tool = tool.function %} - {{- '{"type": "function", "function": {' }} - {%- for key, val in tool.items() if key != "return" %} - {%- if val is string %} - {{- '"' + key + '": "' + val + '"' }} - {%- else %} - {{- '"' + key + '": ' + val|tojson }} - {%- endif %} - {%- if not loop.last %} - {{- ", " }} - {%- endif %} - {%- endfor %} - {{- "}}" }} - {%- if not loop.last %} - {{- ", " }} - {%- else %} - {{- "]" }} - {%- endif %} - {%- endfor %} - {{- "[/AVAILABLE_TOOLS]" }} - {%- endif %} - {%- if loop.last and system_message is defined %} - {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }} - {%- else %} - {{- "[INST] " + message["content"] + "[/INST]" }} - {%- endif %} - {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} - {%- if message.tool_calls is defined %} - {%- set tool_calls = message.tool_calls %} - {%- else %} - {%- set tool_calls = message.content %} - {%- endif %} - {{- "[TOOL_CALLS] [" }} - {%- for tool_call in tool_calls %} - {%- set out = tool_call.function|tojson %} - {{- out[:-1] }} - {%- if not tool_call.id is defined or tool_call.id|length < 9 %} - {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} - {%- endif %} - {{- ', "id": "' + tool_call.id[-9:] + '"}' }} - {%- if not loop.last %} - {{- ", " }} - {%- else %} - {{- "]" + eos_token }} - {%- endif %} - {%- endfor %} - {%- elif message["role"] == "assistant" %} - {{- " " + message["content"] + eos_token}} - {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} - {%- if message.content is defined and message.content.content is defined %} - {%- set content = message.content.content %} - {%- else %} - {%- set content = message.content %} - {%- endif %} - {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} - {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} - {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} - {%- endif %} - {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} - {%- else %} - {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} - {%- endif %} -{%- endfor %} From 941bd038a0c172d632cf79ea0e2e619f39c42077 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 21:43:39 -0500 Subject: [PATCH 113/222] chore(docs): update mistral tool calling docs to remove the notes about prompt limitations, since we now have a high-quality mistral prompt template --- docs/source/serving/openai_compatible_server.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 6df0136b5afd5..83ddd74ee8a51 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -156,12 +156,9 @@ Recommended flags: `--tool-parser hermes --chat-template examples/tool_chat_temp #### Mistral Models Supported models: * `mistralai/Mistral-7B-Instruct-v0.3` +* Possibly mistral-large and mixtral? These have not been tested at the time of this writing. -There are several known issues with tool-calling in Mistral models: -* Attempting to generate > 1 tool call at a time usually results in a parser failure, since the model generates the calls -in an unpredictable format due to the aforementioned chat template issue. **This can be mitigated by setting the -`temperature` to `0` in the OpenAI-style API call** - do this, and tool calls (including parallel ones) are **far** more -consistent +There is a several known with tool-calling in Mistral models: * Mistral function-calling / tool use generates calls with _single_ quotes `'` instead of double quotes `"`. As a result, tool call generations can't be handled as JSON by the parser automatically without using `eval`, which would present security issues for vLLM users. As a result, to support Mistral tool calls, we find-and-replace single-quotes From 0d0b556055948d9fcaf6bac6da4e9bb5d12463cd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 22:00:39 -0500 Subject: [PATCH 114/222] fix(test): transformers no longer supports using a default chat template when the model doesnt provide one I upgraded the transformers version in this PR to support their addition of "tool_use" prompt template in tokenizer_config.json as well as other named prompts. However, in this new version of transformers, falling back to a default prompt template is not supported when the model does not specify one (facebook/opt-125m does not specify one), so I removed tests for when there is no specified prompt template and no default in the tokenizer_config.json file --- tests/async_engine/test_chat_template.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index aea8a7fed6e33..ff65764648c54 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -13,10 +13,6 @@ # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATON_OUTPUT = [ - ("facebook/opt-125m", None, True, - "HelloHi there!What is the capital of"), - ("facebook/opt-125m", None, False, - "HelloHi there!What is the capital of"), ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user Hello<|im_end|> <|im_start|>assistant @@ -80,8 +76,12 @@ def test_no_load_chat_template_literallike(): @pytest.mark.parametrize( "model,template,add_generation_prompt,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) -def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): +def test_get_gen_prompt( + model, + template, + add_generation_prompt, + expected_output +): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) template_content = load_chat_template(chat_template=template) From 8ec75882a27fb42bd307155cd04bd09bd196f2e4 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 22:04:18 -0500 Subject: [PATCH 115/222] fix: add chat template path for opt-125m since not specifying this is not supported when the model does not have a template in the new version of transformers --- tests/async_engine/test_openapi_server_ray.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 5ecd770ede836..353c3cfc4466c 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -1,11 +1,15 @@ import openai # use the official client for correctness check import pytest +import pathlib +import os from ..utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" +chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( + __file__))).parent.parent / "examples/template_chatml.jinja" @pytest.fixture(scope="module") def server(): @@ -16,7 +20,9 @@ def server(): "--max-model-len", "2048", "--enforce-eager", - "--engine-use-ray" + "--engine-use-ray", + "--chat-template", + chatml_jinja_path ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: From cd1c095c6ade2f682fcff8928803745bdbd48dd2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 22:50:51 -0500 Subject: [PATCH 116/222] fix(test): cast posix path to string --- tests/async_engine/test_openapi_server_ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 353c3cfc4466c..b16e1d8a4dffd 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -22,7 +22,7 @@ def server(): "--enforce-eager", "--engine-use-ray", "--chat-template", - chatml_jinja_path + str(chatml_jinja_path) ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: From eb8a1ea4ea47f0fd2211404f0908a240a4858b13 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 23:00:16 -0500 Subject: [PATCH 117/222] fix(test): updated expected token count because of applying chatml template --- tests/async_engine/test_openapi_server_ray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index b16e1d8a4dffd..1470304e21dcd 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -88,8 +88,9 @@ async def test_single_chat_session(client: openai.AsyncOpenAI): choice = chat_completion.choices[0] assert choice.finish_reason == "length" + print('USAGE', chat_completion.usage) assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=13, total_tokens=23) + completion_tokens=10, prompt_tokens=55, total_tokens=65) message = choice.message assert message.content is not None and len(message.content) >= 10 From 7fc67e5bac2e6c36cad5f7b392a4da0ed79a7e6d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 6 Aug 2024 23:00:33 -0500 Subject: [PATCH 118/222] chore: remove print --- tests/async_engine/test_openapi_server_ray.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 1470304e21dcd..e6a12101c9b7b 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -88,7 +88,6 @@ async def test_single_chat_session(client: openai.AsyncOpenAI): choice = chat_completion.choices[0] assert choice.finish_reason == "length" - print('USAGE', chat_completion.usage) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=10, prompt_tokens=55, total_tokens=65) From 05b366ff3c176104e51bd43cf9e8ebd6f5e97ea1 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 08:50:59 -0500 Subject: [PATCH 119/222] fix: add chat template due to bumped transformers version --- tests/entrypoints/openai/test_oot_registration.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index 5272ac4065f1d..0f82614d77786 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -3,12 +3,15 @@ import torch from openai import OpenAI, OpenAIError - +import os +import pathlib from vllm import ModelRegistry from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( + __file__))).parent.parent / "examples/template_chatml.jinja" class MyOPTForCausalLM(OPTForCausalLM): @@ -26,6 +29,7 @@ def server_function(port): ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--chat-template {str(chatml_jinja_path)}" f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') From b417e2b1686370e8cf64f07fb215aa6bb0aa0ec4 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 09:12:00 -0500 Subject: [PATCH 120/222] fix: tests --- tests/entrypoints/openai/test_oot_registration.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index 0f82614d77786..43f9f399b42cd 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -1,10 +1,11 @@ +import os +import pathlib import sys import time import torch from openai import OpenAI, OpenAIError -import os -import pathlib + from vllm import ModelRegistry from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -13,6 +14,7 @@ chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( __file__))).parent.parent / "examples/template_chatml.jinja" + class MyOPTForCausalLM(OPTForCausalLM): def compute_logits(self, hidden_states: torch.Tensor, From 1bf96f7b68eeedea117afea19c34fdc695c9e404 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 09:20:58 -0500 Subject: [PATCH 121/222] fix: yapf --- vllm/entrypoints/openai/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0cde80168fa92..e03c3318b4ae1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,10 +6,10 @@ import torch from openai.types.chat import ChatCompletionContentPartParam -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionContentPartParam as + OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionMessageParam as + OpenAIChatCompletionMessageParam) from pydantic import BaseModel, ConfigDict, Field, model_validator from transformers import PreTrainedTokenizer from typing_extensions import Annotated, Required, TypedDict From 11dbdd7c4701ce04653370e57b3a30b0a816876b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 11:39:21 -0500 Subject: [PATCH 122/222] fix: disable yapf for block conflicting with isort --- vllm/entrypoints/openai/protocol.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index e03c3318b4ae1..86e39d5cdb710 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,11 +5,13 @@ from typing import Any, Dict, List, Literal, Optional, Union, final import torch +# yapf conflicts with isort for this block +# yapf: disable from openai.types.chat import ChatCompletionContentPartParam -from openai.types.chat import (ChatCompletionContentPartParam as - OpenAIChatCompletionContentPartParam) -from openai.types.chat import (ChatCompletionMessageParam as - OpenAIChatCompletionMessageParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from pydantic import BaseModel, ConfigDict, Field, model_validator from transformers import PreTrainedTokenizer from typing_extensions import Annotated, Required, TypedDict @@ -19,6 +21,8 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid +# yapf: enable + # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) From de33564c2b1ab6708f5abb9d88f04fa3bf81ddd0 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:05:53 -0500 Subject: [PATCH 123/222] fix: tool_call_id was accidentally message content --- vllm/entrypoints/chat_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a6ab4626de482..c75f50c720d3b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -146,7 +146,7 @@ def _parse_chat_message_content( ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") - tool_call_id = message.get('content') + tool_call_id = message.get('tool_call_id') tool_calls = message.get('tool_calls') # no longer used by OpenAI, but some models still use it for tool calls. name = message.get('name', '') From 122fdc3847543b2ea10f6a988c2bf41f7fc9755b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:40:44 -0500 Subject: [PATCH 124/222] fix: use double quotes in example --- ...penai_chat_completion_client_with_tools.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 4c177247986bc..adf0632f15b7a 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -66,7 +66,7 @@ print("Chat completion results:") print(chat_completion) -print('\n\n') +print("\n\n") tool_calls_stream = client.chat.completions.create(messages=messages, model=model, @@ -89,25 +89,25 @@ if chunk.choices[0].delta.tool_calls[0].index != tool_call_idx: if tool_call_idx >= 0: print( - f'streamed tool call arguments: {arguments[tool_call_idx]}' + f"streamed tool call arguments: {arguments[tool_call_idx]}" ) tool_call_idx = chunk.choices[0].delta.tool_calls[0].index - arguments.append('') + arguments.append("") if chunk.choices[0].delta.tool_calls[0].id: - print(f'streamed tool call id: ' - f'{chunk.choices[0].delta.tool_calls[0].id}') + print(f"streamed tool call id: " + f"{chunk.choices[0].delta.tool_calls[0].id}") if chunk.choices[0].delta.tool_calls[0].function: if chunk.choices[0].delta.tool_calls[0].function.name: - print(f'streamed tool call name: ' - f'{chunk.choices[0].delta.tool_calls[0].function.name}') + print(f"streamed tool call name: " + f"{chunk.choices[0].delta.tool_calls[0].function.name}") if chunk.choices[0].delta.tool_calls[0].function.arguments: arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[ 0].function.arguments if len(arguments): - print(f'streamed tool call arguments: {arguments[-1]}') + print(f"streamed tool call arguments: {arguments[-1]}") -print('\n\n') +print("\n\n") messages.append({ "role": "assistant", @@ -140,5 +140,5 @@ def get_current_weather(city: str, state: str, unit: 'str'): model=model, tools=tools, stream=False) -print('\n\n') +print("\n\n") print(chat_completion_2) From 3b0589d16a1be2afc828f3ada25c80a4eb08526d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:40:55 -0500 Subject: [PATCH 125/222] fix: use double quotes in chat_utils --- vllm/entrypoints/chat_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c75f50c720d3b..d120d0d72c6ad 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -146,17 +146,17 @@ def _parse_chat_message_content( ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") - tool_call_id = message.get('tool_call_id') - tool_calls = message.get('tool_calls') + tool_call_id = message.get("tool_call_id") + tool_calls = message.get("tool_calls") # no longer used by OpenAI, but some models still use it for tool calls. - name = message.get('name', '') + name = message.get("name", "") # empty case if content is None and tool_calls is None: return ChatMessageParseResult(messages=[], mm_futures=[]) # special case - assistant message where tool calls are provided. - if role == 'assistant' and tool_calls is not None: + if role == "assistant" and tool_calls is not None: messages = [ ConversationMessage(role=role, content=cast(Optional[str], content), @@ -165,7 +165,7 @@ def _parse_chat_message_content( return ChatMessageParseResult(messages=messages, mm_futures=[]) # special case - tool call result message - elif role == 'tool': + elif role == "tool": messages = [ ConversationMessage(role=role, name=name, From 8634184e83e9f509cb88c2f20a1b661f070dc23c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:41:45 -0500 Subject: [PATCH 126/222] fix: use double quotes in api_server --- vllm/entrypoints/openai/api_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index fabd8c10b8beb..9ab97577db9b3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -55,7 +55,7 @@ openai_serving_embedding: OpenAIServingEmbedding openai_serving_tokenization: OpenAIServingTokenization -logger = init_logger('vllm.entrypoints.openai.api_server') +logger = init_logger("vllm.entrypoints.openai.api_server") _running_tasks: Set[asyncio.Task] = set() @@ -142,7 +142,7 @@ def mount_metrics(app: FastAPI): # Add prometheus asgi middleware to route /metrics requests metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) From a952c15d70169d0e2a199fa5c75c301bd3b272f6 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:43:42 -0500 Subject: [PATCH 127/222] fix: single quotes (that I added) in cli_args are now double quotes --- vllm/entrypoints/openai/cli_args.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 7272208feddd5..19f8fc93e38c9 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -145,19 +145,19 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action="store_true", default=False, help= - 'Enable auto tool choice for supported models. Use --tool-call-parser' - 'to specify which parser to use') + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") parser.add_argument( "--tool-call-parser", type=str, - choices=['mistral', 'hermes'], + choices=["mistral", "hermes"], default=None, help= - 'Select the tool call parser depending on the model that you\'re using.' - ' This is used to parse the model-generated tool call into OpenAI API ' - 'format. Required for --enable-auto-tool-choice. Options: "mistral", ' - '"hermes"') + "Select the tool call parser depending on the model that you\'re using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice. Options: 'hermes', " + "'mistral'") parser = AsyncEngineArgs.add_cli_args(parser) From cf85b1c6582fb6a854800058e5eaf6334a9b0ba0 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:47:09 -0500 Subject: [PATCH 128/222] fix: double quotes in protocol.py --- vllm/entrypoints/openai/protocol.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 39cf66abfd79a..6d1d67ac05624 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -358,8 +358,8 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and data.get('tool_choice', - 'none') not in ("none", "auto"): + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @@ -396,16 +396,16 @@ def check_tool_usage(cls, data): specified_function = data["tool_choice"]["function"] if not specified_function: return ValueError( - 'Incorrectly formatted `tool_choice`. Should be like ' - '`{"type": "function",' - ' "function": {"name": "my_function"}}`') + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") specified_function_name = specified_function["name"] if not specified_function_name: return ValueError( - 'Incorrectly formatted `tool_choice`. Should be like ' - '`{"type": "function", ' - '"function": {"name": "my_function"}}`') - for tool in data['tools']: + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: if tool["function"]["name"] == specified_function_name: valid_tool = True break @@ -766,7 +766,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[str] = Field( - default='stop') # per OpenAI spec this is the default + default="stop") # per OpenAI spec this is the default stop_reason: Optional[Union[int, str]] = None # ??? Not part of the OpenAI spec From 1f8ea1a8ebc19c21fdad4516fd0e32fea1fad987 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:49:21 -0500 Subject: [PATCH 129/222] fix: double quotes in serving_chat --- vllm/entrypoints/openai/serving_chat.py | 60 ++++++++++++------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 52a00f425d8ca..ee813ea4029fa 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -38,7 +38,7 @@ log_tracing_disabled_warning) from vllm.utils import iterate_with_cancellation, random_uuid -env = Environment(loader=FileSystemLoader('./'), +env = Environment(loader=FileSystemLoader("./"), autoescape=select_autoescape()) logger = init_logger(__name__) @@ -75,19 +75,19 @@ def __init__(self, self.enable_auto_tools: bool = enable_auto_tools or False if self.enable_auto_tools: logger.info( - '"Auto" tool choice has been enabled please note that while' - ' the parallel_tool_calls client option is preset for ' - 'compatibility reasons, it will be ignored.') + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") self.tool_parser: Optional[Type[ToolParser]] = None if self.enable_auto_tools: - if tool_parser == 'mistral': + if tool_parser == "mistral": self.tool_parser = MistralToolParser - elif tool_parser == 'hermes': + elif tool_parser == "hermes": self.tool_parser = Hermes2ProToolParser else: raise TypeError( - 'Error: --enable-auto-tool-choice requires --tool-parser') + "Error: --enable-auto-tool-choice requires --tool-parser") async def create_chat_completion( self, @@ -104,7 +104,7 @@ async def create_chat_completion( """ error_check_ret = await self._check_model(request) if error_check_ret is not None: - logger.error('Error with model %s', error_check_ret) + logger.error("Error with model %s", error_check_ret) return error_check_ret try: @@ -153,16 +153,16 @@ async def create_chat_completion( # validation for OpenAI tools # tool_choice = "required" is not supported - if request.tool_choice == 'required': + if request.tool_choice == "required": return self.create_error_response( - 'tool_choice = "required" is not supported!') + "tool_choice = \"required\" is not supported!") # "auto" tools requires --enable-auto-tool-choice and --tool-parser - if request.tool_choice == 'auto' and not ( + if request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): return self.create_error_response( - '"auto" tool choice requires ' - '--enable-auto-tool-choice and --tool-parser to be set') + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-parser to be set") request_id = f"chat-{random_uuid()}" try: @@ -303,7 +303,7 @@ async def chat_completion_stream_generator( if conversation and conversation[-1].get( "content") and conversation[-1].get( "role") == role: - last_msg_content = conversation[-1]["content"] or '' + last_msg_content = conversation[-1]["content"] or "" if last_msg_content: for i in range(num_choices): @@ -378,7 +378,7 @@ async def chat_completion_stream_generator( # handle streaming deltas for tools with tool_choice elif (request.tools and tool_parser and (request.tool_choice is None - or request.tool_choice == 'auto') + or request.tool_choice == "auto") and self.enable_auto_tools): delta_message = ( @@ -450,19 +450,19 @@ async def chat_completion_stream_generator( and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and (delta_message.tool_calls[0].function.arguments - == '' or + == "" or delta_message.tool_calls[0].function.arguments and - (output.finish_reason == 'stop' - or output.finish_reason == 'tool_calls')) + (output.finish_reason == "stop" + or output.finish_reason == "tool_calls")) and tool_parser - and request.tool_choice == 'auto'): + and request.tool_choice == "auto"): expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( - 'arguments', {})) + "arguments", {})) actual_call = tool_parser.streamed_args_for_tool[ index] remaining_call = expected_call.replace( - actual_call, '', 1) + actual_call, "", 1) delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=index, function=DeltaFunctionCall( @@ -478,7 +478,7 @@ async def chat_completion_stream_generator( finish_reason=output.finish_reason if not (tool_parser and len(tool_parser.prev_tool_call_arr)) - else 'tool_calls', + else "tool_calls", stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -527,7 +527,7 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error - logger.error('error in chat completion stream generator: %s', e) + logger.error("error in chat completion stream generator: %s", e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished @@ -625,17 +625,17 @@ async def chat_completion_full_generator( # undetermined case that is still important to handle else: logger.error( - 'Error in chat_completion_full_generator - cannot determine' - ' if tools should be extracted. Returning a standard chat ' - 'completion.') + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason='tool_calls' if tools_called else - output.finish_reason if output.finish_reason else 'stop', + finish_reason="tool_calls" if tools_called else + output.finish_reason if output.finish_reason else "stop", stop_reason=output.stop_reason) choices.append(choice_data) @@ -643,11 +643,11 @@ async def chat_completion_full_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] or '' + last_msg_content = conversation[-1]["content"] or "" for choice in choices: full_message = last_msg_content + (choice.message.content - or '') + or "") choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) From 49fb3ae5def6136313767571fb826a1d08a4fe35 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:50:26 -0500 Subject: [PATCH 130/222] fix: double quotes in abstracttoolparser --- vllm/entrypoints/openai/tool_parsers/__init__.py | 2 +- .../entrypoints/openai/tool_parsers/abstract_tool_parser.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 64a33a6d4eded..5d5d53784fedf 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -2,4 +2,4 @@ from .hermes_tool_parser import Hermes2ProToolParser from .mistral_tool_parser import MistralToolParser -__all__ = ['ToolParser', 'Hermes2ProToolParser', 'MistralToolParser'] \ No newline at end of file +__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 6892657f9de50..e0870396a69ee 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -41,7 +41,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: Static because it's stateless. """ raise NotImplementedError( - 'AbstractToolParser.extract_tool_calls has not been implemented!') + "AbstractToolParser.extract_tool_calls has not been implemented!") def extract_tool_calls_streaming( self, @@ -60,5 +60,5 @@ def extract_tool_calls_streaming( previously been parsed and extracted (see constructor) """ raise NotImplementedError( - 'AbstractToolParser.extract_tool_calls_streaming has not been ' - 'implemented!') + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") From c45f82488586e876e01c05a86d1f434bfab6c24b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:52:20 -0500 Subject: [PATCH 131/222] fix: double quotes in hermes tool parser --- .../openai/tool_parsers/hermes_tool_parser.py | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 2859b84c613b6..1d06d806a8c97 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -22,14 +22,14 @@ class Hermes2ProToolParser(ToolParser): - tool_call_start_token: str = '' - tool_call_end_token: str = '' + tool_call_start_token: str = "" + tool_call_end_token: str = "" # regex to match between and OR between # and EOS (happens sometimes :)) tool_call_regex = re.compile( - r'(.*?)|(.*)', re.DOTALL) - scratch_pad_regex = re.compile(r'(.*?)', + r"(.*?)|(.*)", re.DOTALL) + scratch_pad_regex = re.compile(r"(.*?)", re.DOTALL) @staticmethod @@ -59,11 +59,11 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ] tool_calls = [ ToolCall( - type='function', + type="function", function=FunctionCall( - name=function_call['name'], + name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call['arguments']))) + arguments=json.dumps(function_call["arguments"]))) for function_call in raw_function_calls ] @@ -97,27 +97,27 @@ def __init__(self, if not self.model_tokenizer: raise ValueError( - 'The model tokenizer must be passed to the ToolParser ' - 'constructor during construction.') + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ - ''] + ""] self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ - ''] + ""] if not self.tool_call_start_token_id or not self.tool_call_end_token_id: raise RuntimeError( - 'Hermes 2 Pro Tool parser could not locate tool call start/end ' - 'tokens in the tokenizer!') + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: List[int], current_token_ids: List[int], delta_token_ids: List[int]) -> Union[DeltaMessage, None]: - logger.debug('delta_text: %s', delta_text) - logger.debug('delta_token_ids: %s', delta_token_ids) + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_call_start_token_id not in current_token_ids: - logger.debug('No tool call tokens found!') + logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) else: @@ -138,7 +138,7 @@ def extract_tool_calls_streaming( if (cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count): logger.debug( - 'Generating text content! skipping tool parsing.') + "Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) # most of the time, we're going in here - we need to do partial @@ -166,8 +166,8 @@ def extract_tool_calls_streaming( self.current_tool_id += 1 self.current_tool_name_sent = False self.current_tool_initial_sent = False - self.streamed_args_for_tool.append('') - logger.debug('Starting on a new tool %s', + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) # if an existing tool call is being updated - the most @@ -181,16 +181,16 @@ def extract_tool_calls_streaming( # if the current tool call is being closed elif (cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count): - logger.debug('Closing the current tool call!') + logger.debug("Closing the current tool call!") diff = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') + self.current_tool_id].get("arguments") if diff: diff = json.dumps(diff).replace( self.streamed_args_for_tool[ - self.current_tool_id], '') + self.current_tool_id], "") logger.debug( - 'Finishing tool and found diff that had not ' - 'been streamed yet: %s', diff) + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -200,22 +200,22 @@ def extract_tool_calls_streaming( else: logger.error( - 'INVARIANT - invalid state trying to parse tool ' - 'calls (wtf?)') + "INVARIANT - invalid state trying to parse tool " + "calls (wtf?)") delta = None return delta - logger.debug('Tool call portion: %s', tool_call_portion - or '') + logger.debug("Tool call portion: %s", tool_call_portion + or "") current_tool_call = partial_json_parser.loads( - tool_call_portion or '{}', + tool_call_portion or "{}", flags) if tool_call_portion else None - logger.debug('Parsed tool call %s', current_tool_call) + logger.debug("Parsed tool call %s", current_tool_call) - # make sure to send the initial message first if we haven't + # make sure to send the initial message first if we haven"t # already - with the tool ID if not self.current_tool_initial_sent: - logger.debug('Sending InitialDeltaToolCall') + logger.debug("Sending InitialDeltaToolCall") self.current_tool_initial_sent = True return DeltaMessage(tool_calls=[ InitialDeltaToolCall( @@ -227,10 +227,10 @@ def extract_tool_calls_streaming( # any arguments elif not self.current_tool_name_sent: function_name: Union[ - str, None] = current_tool_call.get('name') + str, None] = current_tool_call.get("name") if function_name: logger.debug( - 'Sending DeltaToolCall with function name %s', + "Sending DeltaToolCall with function name %s", function_name) self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ @@ -254,40 +254,40 @@ def extract_tool_calls_streaming( # now we have the portion to parse as tool call. if text_portion is not None: logger.debug( - 'Also, will send text portion: %s', + "Also, will send text portion: %s", text_portion) logger.debug( - 'Trying to parse current tool call with ID %s', + "Trying to parse current tool call with ID %s", self.current_tool_id) if len(self.prev_tool_call_arr ) <= self.current_tool_id: self.prev_tool_call_arr.append({}) logger.debug( - 'Pushed dummy value into tool call arr') + "Pushed dummy value into tool call arr") # main logic for tool parsing here prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') + self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get( - 'arguments' + "arguments" ) # arguments, if any, in current dict - logger.debug('diffing old arguments: %s', + logger.debug("diffing old arguments: %s", prev_arguments) - logger.debug('against new ones: %s', cur_arguments) + logger.debug("against new ones: %s", cur_arguments) if not cur_arguments and not prev_arguments: - logger.debug('Skipping text %s - no arguments', + logger.debug("Skipping text %s - no arguments", delta_text) delta = None elif not cur_arguments and prev_arguments: logger.error( - 'INVARIANT - impossible to have arguments ' - 'reset mid-call') + "INVARIANT - impossible to have arguments " + "reset mid-call") delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug('finding %s in %s', delta_text, + logger.debug("finding %s in %s", delta_text, cur_arguments_json) arguments_delta = cur_arguments_json[: cur_arguments_json @@ -297,7 +297,7 @@ def extract_tool_calls_streaming( len(delta_text )] logger.debug( - 'First tokens in arguments received: %s', + "First tokens in arguments received: %s", arguments_delta) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -312,12 +312,12 @@ def extract_tool_calls_streaming( elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug('Searching for dif between\n%s', + logger.debug("Searching for dif between\n%s", cur_args_json) - logger.debug('and\n%s', prev_args_json) + logger.debug("and\n%s", prev_args_json) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) - logger.debug('got argument diff %s', + logger.debug("got argument diff %s", argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -347,9 +347,9 @@ def extract_tool_calls_streaming( return delta except Exception as e: - logger.error('Error trying to handle streaming tool call: %s', + logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - 'Skipping chunk as a result of tool streaming extraction ' - 'error') + "Skipping chunk as a result of tool streaming extraction " + "error") return None # do not stream a delta. skip this token ID. From 9b7cbabdafbc0148c4b4466af405647102900e68 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:54:56 -0500 Subject: [PATCH 132/222] fix: double quotes in mistral tool parser --- .../tool_parsers/mistral_tool_parser.py | 62 +++++++++---------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 433db4b902d5f..58f2b18231f26 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -46,9 +46,9 @@ class MistralToolParser(ToolParser): # the bot_token is the token indicating tool call(s) follow. Tokens before # this token will be parsed as content; and # if not present, the entire response will be parsed as text content. - bot_token: str = '[TOOL_CALLS]' # string literal - bot_token_id: int = 5 # token ID thereof from the models' tokenizer - tool_call_regex = re.compile(r'\[{.*?}\]', re.DOTALL) + bot_token: str = "[TOOL_CALLS]" # string literal + bot_token_id: int = 5 # token ID thereof from the models" tokenizer + tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) @staticmethod def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: @@ -59,7 +59,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ logger.debug( - 'Trying to extract mistral tool calls from the following:') + "Trying to extract mistral tool calls from the following:") logger.debug(model_output) # Get the tool call token from the tokenizer if MistralToolParser.bot_token not in model_output: @@ -73,8 +73,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # properly raw_tool_call = MistralToolParser.tool_call_regex.findall( model_output.replace(MistralToolParser.bot_token, - '') # remove BOT token - .replace("'", '"') # replace string quotes + "") # remove BOT token + .replace("'", "\"") # replace string quotes )[0] # load the JSON, and then use it to build the Function and @@ -82,12 +82,12 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: function_call_arr = json.loads(raw_tool_call) tool_calls: List[ToolCall] = [ ToolCall( - type='function', + type="function", function=FunctionCall( - name=raw_function_call['name'], + name=raw_function_call["name"], # function call args are JSON but as a string arguments=json.dumps( - raw_function_call['arguments']))) + raw_function_call["arguments"]))) for raw_function_call in function_call_arr ] content = model_output.split(MistralToolParser.bot_token)[0] @@ -99,7 +99,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: except Exception as e: logger.error("Error in extracting tool call from response: %s", e) - print('ERROR', e) + print("ERROR", e) # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, tool_calls=[], @@ -161,7 +161,7 @@ def extract_tool_calls_streaming( # quotes instead of double for tool calls tool_call_message_portion = current_text.split( self.bot_token)[1] - parsable_arr = tool_call_message_portion.replace('\'', '"') + parsable_arr = tool_call_message_portion.replace("\'", "\"") # logger.debug('parsing: %s', parsable_arr) @@ -184,11 +184,11 @@ def extract_tool_calls_streaming( # streamed to the client yet. if self.current_tool_id >= 0: diff: Union[str, - None] = current_tool_call.get('arguments') + None] = current_tool_call.get("arguments") if diff: diff = json.dumps(diff).replace( self.streamed_args_for_tool[ - self.current_tool_id], '') + self.current_tool_id], "") delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -205,8 +205,8 @@ def extract_tool_calls_streaming( self.current_tool_id = len(tool_call_arr) - 1 self.current_tool_name_sent = False self.current_tool_initial_sent = False - self.streamed_args_for_tool.append('') - logger.debug('starting on new tool %d', + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) return delta @@ -214,19 +214,17 @@ def extract_tool_calls_streaming( elif len( tool_call_arr ) - 1 == self.current_tool_id and self.current_tool_id >= 0: - # logger.debug('update to tool %d', self.current_tool_id) pass # if there is NOTHING in the array, e.g. if only the open # bracket was streamed yet else: - # logger.debug('No tool call detected yet!') return None # if the current tool initial data incl. the id, type=function # and idx not sent, send that if not self.current_tool_initial_sent: - logger.debug('Sending InitialDeltaToolCall') + logger.debug("Sending InitialDeltaToolCall") self.current_tool_initial_sent = True delta = DeltaMessage(tool_calls=[ InitialDeltaToolCall( @@ -237,10 +235,10 @@ def extract_tool_calls_streaming( # if the current tool name hasn't been sent, send if available # - otherwise no chunks elif not self.current_tool_name_sent: - function_name = current_tool_call.get('name') + function_name = current_tool_call.get("name") if function_name: logger.debug( - 'Sending DeltaToolCall with function name %s', + "Sending DeltaToolCall with function name %s", function_name) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -257,29 +255,29 @@ def extract_tool_calls_streaming( else: prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get('arguments') - cur_arguments = current_tool_call.get('arguments') + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") - new_text = delta_text.replace('\'', '"') + new_text = delta_text.replace("\'", "\"") if not cur_arguments and not prev_arguments: delta = None elif not cur_arguments and prev_arguments: logger.error( - 'INVARIANT - impossible to have arguments reset ' - 'mid-arguments') + "INVARIANT - impossible to have arguments reset " + "mid-arguments") delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug('finding %s in |%s|', new_text, + logger.debug("finding %s in |%s|", new_text, cur_arguments_json) arguments_delta = cur_arguments_json[: cur_arguments_json .index(new_text) + len(new_text)] - logger.debug('First tokens in arguments received: %s', + logger.debug("First tokens in arguments received: %s", arguments_delta) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -293,12 +291,12 @@ def extract_tool_calls_streaming( elif cur_arguments and prev_arguments: cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug('Searching for diff between \n%s\n%s', + logger.debug("Searching for diff between \n%s\n%s", cur_args_json, prev_args_json) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) - logger.debug('got arguments diff: %s', argument_diff) + logger.debug("got arguments diff: %s", argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -320,9 +318,9 @@ def extract_tool_calls_streaming( return delta except Exception as e: - logger.error('Error trying to handle streaming tool call: %s', + logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - 'Skipping chunk as a result of tool streaming extraction ' - 'error') + "Skipping chunk as a result of tool streaming extraction " + "error") return None From f63908ffbebb12f9ca3947999426649222970346 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 16:57:39 -0500 Subject: [PATCH 133/222] fix: remove todo --- vllm/entrypoints/openai/protocol.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6d1d67ac05624..782746a8b21bb 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -413,8 +413,6 @@ def check_tool_usage(cls, data): return ValueError( "The tool specified in `tool_choice` does not match any" " of the specified `tools`") - - # TODO validate tools return data @model_validator(mode="before") From 8db2a0d2cec6f3b7c4f3f830ef566f9135f674de Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 17:05:15 -0500 Subject: [PATCH 134/222] fix: remove deprecated to_dict method --- vllm/entrypoints/openai/protocol.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 782746a8b21bb..b26a2e0dcfeb0 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -683,22 +683,12 @@ class FunctionCall(OpenAIBaseModel): name: str arguments: str - def to_dict(self): - return {"name": self.name, "arguments": self.arguments} - class ToolCall(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") type: Literal["function"] = "function" function: FunctionCall - def to_dict(self): - return { - "id": self.id, - "type": self.type, - "function": self.function.to_dict() - } - class DeltaFunctionCall(BaseModel): name: Optional[str] = None @@ -712,13 +702,6 @@ class DeltaToolCall(OpenAIBaseModel): index: int function: Optional[DeltaFunctionCall] = None - def to_dict(self): - return { - "id": self.id, - "type": self.type, - "function": self.function.to_dict() if self.function else None - } - # the initial delta that gets sent once a new tool call is started; class InitialDeltaToolCall(DeltaToolCall): From 3895fd91a508104f49353c9c6ed47184409b6c1e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 17:06:41 -0500 Subject: [PATCH 135/222] fix: give comments their own line --- vllm/entrypoints/openai/protocol.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b26a2e0dcfeb0..3e690238841d6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -746,10 +746,10 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = Field( - default="stop") # per OpenAI spec this is the default - stop_reason: Optional[Union[int, - str]] = None # ??? Not part of the OpenAI spec + # per OpenAI spec this is the default + finish_reason: Optional[str] = Field(default="stop") + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int,str]] = None class ChatCompletionResponse(OpenAIBaseModel): From d9615197ed3a19e2ad00858d9d584b4bcb88b422 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 17:07:35 -0500 Subject: [PATCH 136/222] fix: remove unused loader --- vllm/entrypoints/openai/serving_chat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ee813ea4029fa..4345de90846ff 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -38,9 +38,6 @@ log_tracing_disabled_warning) from vllm.utils import iterate_with_cancellation, random_uuid -env = Environment(loader=FileSystemLoader("./"), - autoescape=select_autoescape()) - logger = init_logger(__name__) From 43a63183d66504486d6520b2db4740e60325ef56 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 17:10:06 -0500 Subject: [PATCH 137/222] fix: formatting --- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3e690238841d6..a778ea24e7e1d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -749,7 +749,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): # per OpenAI spec this is the default finish_reason: Optional[str] = Field(default="stop") # not part of the OpenAI spec but included in vLLM for legacy reasons - stop_reason: Optional[Union[int,str]] = None + stop_reason: Optional[Union[int, str]] = None class ChatCompletionResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4345de90846ff..73738f9e395b2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -6,7 +6,6 @@ from typing import Type, Union from fastapi import Request -from jinja2 import Environment, FileSystemLoader, select_autoescape from transformers import PreTrainedTokenizer from vllm.config import ModelConfig From 7e90682401f989131ee2e1fec476c74380725406 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:20:47 -0500 Subject: [PATCH 138/222] fix: indents in hermes tool parser by making cases better --- .../openai/tool_parsers/hermes_tool_parser.py | 436 +++++++++--------- 1 file changed, 218 insertions(+), 218 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 1d06d806a8c97..0ca1f06181b08 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -120,236 +120,236 @@ def extract_tool_calls_streaming( logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - else: - try: - - # figure out where we are in the parsing by counting tool call - # start & end tags - prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) - cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) - - # a cheap case - we're generating text, NOT tool calls. - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count): - logger.debug( - "Generating text content! skipping tool parsing.") - return DeltaMessage(content=delta_text) - - # most of the time, we're going in here - we need to do partial - # JSON parsing and build stuff. - else: - # flags for partial JSON parting. exported constants from - # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - - # if a new tool call is being started. unusual since - # normally the first "cheap case" will be hit. - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): - if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - text_portion = None - else: - tool_call_portion = None - text_portion = None - delta = None - # set cursors and state appropriately - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.current_tool_initial_sent = False - self.streamed_args_for_tool.append("") - logger.debug("Starting on a new tool %s", - self.current_tool_id) - - # if an existing tool call is being updated - the most - # common case! - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # case: if we're generating text, NOT tools, return a text delta + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug( + "Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + # case: if tool open & close tag counts don't match, we're doing + # something with tools with this diff. + else: + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # if a new tool call is being started. unusual since + # normally the first "cheap case" will be hit. + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: tool_call_portion = current_text.split( self.tool_call_start_token)[-1] text_portion = None - - # if the current tool call is being closed - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count > prev_tool_end_count): - logger.debug("Closing the current tool call!") - diff = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], "") - logger.debug( - "Finishing tool and found diff that had not " - "been streamed yet: %s", diff) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - else: - logger.error( - "INVARIANT - invalid state trying to parse tool " - "calls (wtf?)") + tool_call_portion = None + text_portion = None delta = None - return delta - - logger.debug("Tool call portion: %s", tool_call_portion - or "") - current_tool_call = partial_json_parser.loads( - tool_call_portion or "{}", - flags) if tool_call_portion else None - logger.debug("Parsed tool call %s", current_tool_call) - - # make sure to send the initial message first if we haven"t - # already - with the tool ID - if not self.current_tool_initial_sent: - logger.debug("Sending InitialDeltaToolCall") - self.current_tool_initial_sent = True + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", + self.current_tool_id) + + # if an existing tool call is being updated - the most + # common case! + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # if the current tool call is being closed + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + logger.debug("Closing the current tool call!") + diff = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) return DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) ]) - # after that, make sure we send the function name before - # any arguments - elif not self.current_tool_name_sent: - function_name: Union[ - str, None] = current_tool_call.get("name") - if function_name: + else: + logger.error( + "INVARIANT - invalid state trying to parse tool " + "calls (wtf?)") + delta = None + return delta + + logger.debug("Tool call portion: %s", tool_call_portion + or "") + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + + # make sure to send the initial message first if we haven"t + # already - with the tool ID + if not self.current_tool_initial_sent: + logger.debug("Sending InitialDeltaToolCall") + self.current_tool_initial_sent = True + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # after that, make sure we send the function name before + # any arguments + elif not self.current_tool_name_sent: + function_name: Union[ + str, None] = current_tool_call.get("name") + if function_name: + logger.debug( + "Sending DeltaToolCall with function name %s", + function_name) + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name). + model_dump(exclude_none=True)) + ]) + else: + return None + else: + # if there is no tool calls + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage( + content=delta_text + ) if text_portion is not None else None + # now, the nitty-gritty of tool calls + else: + # now we have the portion to parse as tool call. + if text_portion is not None: + logger.debug( + "Also, will send text portion: %s", + text_portion) + + logger.debug( + "Trying to parse current tool call with ID %s", + self.current_tool_id) + if len(self.prev_tool_call_arr + ) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + logger.debug( + "Pushed dummy value into tool call arr") + # main logic for tool parsing here + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get( + "arguments" + ) # arguments, if any, in current dict + + logger.debug("diffing old arguments: %s", + prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", + delta_text) + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments " + "reset mid-call") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index( + delta_text + ) + + len(delta_text + )] logger.debug( - "Sending DeltaToolCall with function name %s", - function_name) - self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ + "First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - name=function_name). - model_dump(exclude_none=True)) + arguments=arguments_delta + ).model_dump( + exclude_none=True)) ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for dif between\n%s", + cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", + argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff else: - return None - else: - # if there is no tool calls - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = DeltaMessage( - content=delta_text - ) if text_portion is not None else None - # now, the nitty-gritty of tool calls - else: - # now we have the portion to parse as tool call. - if text_portion is not None: - logger.debug( - "Also, will send text portion: %s", - text_portion) - - logger.debug( - "Trying to parse current tool call with ID %s", - self.current_tool_id) - if len(self.prev_tool_call_arr - ) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - logger.debug( - "Pushed dummy value into tool call arr") - # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get( - "arguments" - ) # arguments, if any, in current dict - - logger.debug("diffing old arguments: %s", - prev_arguments) - logger.debug("against new ones: %s", cur_arguments) - - if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", - delta_text) - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments " - "reset mid-call") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug("finding %s in %s", delta_text, - cur_arguments_json) - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index( - delta_text - ) + - len(delta_text - )] - logger.debug( - "First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for dif between\n%s", - cur_args_json) - logger.debug("and\n%s", prev_args_json) - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got argument diff %s", - argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - - # handle saving the state for the current tool into - # the "prev" list for use in diffing for - # the next iteration - if self.current_tool_id == len( - self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append( - current_tool_call) - - # TODO REPLACE ME WITH TOOL CALL - # delta = DeltaMessage(content=delta_text) - return delta + delta = None - except Exception as e: - logger.error("Error trying to handle streaming tool call: %s", - e) - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") - return None # do not stream a delta. skip this token ID. + # handle saving the state for the current tool into + # the "prev" list for use in diffing for + # the next iteration + if self.current_tool_id == len( + self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append( + current_tool_call) + + # TODO REPLACE ME WITH TOOL CALL + # delta = DeltaMessage(content=delta_text) + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", + e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None # do not stream a delta. skip this token ID. From c4c480c237cf52809b56a6842c12d9885e95a937 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:30:40 -0500 Subject: [PATCH 139/222] fix: readability for hermes tool call parser --- .../openai/tool_parsers/hermes_tool_parser.py | 380 +++++++++--------- 1 file changed, 188 insertions(+), 192 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 0ca1f06181b08..f5a1636cd2fb0 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -120,7 +120,6 @@ def extract_tool_calls_streaming( logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - try: # figure out where we are in the parsing by counting tool call @@ -137,214 +136,211 @@ def extract_tool_calls_streaming( # case: if we're generating text, NOT tools, return a text delta if (cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count): - logger.debug( - "Generating text content! skipping tool parsing.") + logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) # case: if tool open & close tag counts don't match, we're doing - # something with tools with this diff. - else: - # flags for partial JSON parting. exported constants from - # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - - # if a new tool call is being started. unusual since - # normally the first "cheap case" will be hit. - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): - if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] - text_portion = None - else: - tool_call_portion = None - text_portion = None - delta = None - - # set cursors and state appropriately - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.current_tool_initial_sent = False - self.streamed_args_for_tool.append("") - logger.debug("Starting on a new tool %s", - self.current_tool_id) - - # if an existing tool call is being updated - the most - # common case! - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: tool_call_portion = current_text.split( self.tool_call_start_token)[-1] - text_portion = None - - # if the current tool call is being closed - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count > prev_tool_end_count): - logger.debug("Closing the current tool call!") - diff = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], "") - logger.debug( - "Finishing tool and found diff that had not " - "been streamed yet: %s", diff) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - else: - logger.error( - "INVARIANT - invalid state trying to parse tool " - "calls (wtf?)") + tool_call_portion = None delta = None - return delta - - logger.debug("Tool call portion: %s", tool_call_portion - or "") - current_tool_call = partial_json_parser.loads( - tool_call_portion or "{}", - flags) if tool_call_portion else None - logger.debug("Parsed tool call %s", current_tool_call) - - # make sure to send the initial message first if we haven"t - # already - with the tool ID - if not self.current_tool_initial_sent: - logger.debug("Sending InitialDeltaToolCall") - self.current_tool_initial_sent = True + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", + self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + logger.debug("Closing the current tool call!") + diff = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) return DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) ]) - # after that, make sure we send the function name before - # any arguments - elif not self.current_tool_name_sent: - function_name: Union[ - str, None] = current_tool_call.get("name") - if function_name: + # case -- otherwise we're just generating text + else: + delta = DeltaMessage(tool_calls=[], content=delta_text) + return delta + + logger.debug("Tool call portion: %s", tool_call_portion + or "") + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + + # case - we haven't sent the initial delta with the tool call ID + # (it will be sent) + if not self.current_tool_initial_sent: + logger.debug("Sending InitialDeltaToolCall") + self.current_tool_initial_sent = True + return DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + elif not self.current_tool_name_sent: + function_name: Union[ + str, None] = current_tool_call.get("name") + if function_name: + logger.debug( + "Sending DeltaToolCall with function name %s", + function_name) + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name). + model_dump(exclude_none=True)) + ]) + else: + return None + else: + # if there is no tool calls + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage( + content=delta_text + ) if text_portion is not None else None + # now, the nitty-gritty of tool calls + else: + # now we have the portion to parse as tool call. + if text_portion is not None: + logger.debug( + "Also, will send text portion: %s", + text_portion) + + logger.debug( + "Trying to parse current tool call with ID %s", + self.current_tool_id) + if len(self.prev_tool_call_arr + ) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + logger.debug( + "Pushed dummy value into tool call arr") + # main logic for tool parsing here + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get( + "arguments" + ) # arguments, if any, in current dict + + logger.debug("diffing old arguments: %s", + prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", + delta_text) + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments " + "reset mid-call") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index( + delta_text + ) + + len(delta_text + )] logger.debug( - "Sending DeltaToolCall with function name %s", - function_name) - self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ + "First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for dif between\n%s", + cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", + argument_diff) + delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - name=function_name). - model_dump(exclude_none=True)) + arguments=argument_diff). + model_dump( + exclude_none=True)) ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff else: - return None - else: - # if there is no tool calls - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = DeltaMessage( - content=delta_text - ) if text_portion is not None else None - # now, the nitty-gritty of tool calls + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for + # the next iteration + if self.current_tool_id == len( + self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call else: - # now we have the portion to parse as tool call. - if text_portion is not None: - logger.debug( - "Also, will send text portion: %s", - text_portion) + self.prev_tool_call_arr.append( + current_tool_call) - logger.debug( - "Trying to parse current tool call with ID %s", - self.current_tool_id) - if len(self.prev_tool_call_arr - ) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - logger.debug( - "Pushed dummy value into tool call arr") - # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get( - "arguments" - ) # arguments, if any, in current dict - - logger.debug("diffing old arguments: %s", - prev_arguments) - logger.debug("against new ones: %s", cur_arguments) - - if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", - delta_text) - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments " - "reset mid-call") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug("finding %s in %s", delta_text, - cur_arguments_json) - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index( - delta_text - ) + - len(delta_text - )] - logger.debug( - "First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for dif between\n%s", - cur_args_json) - logger.debug("and\n%s", prev_args_json) - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got argument diff %s", - argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - - # handle saving the state for the current tool into - # the "prev" list for use in diffing for - # the next iteration - if self.current_tool_id == len( - self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append( - current_tool_call) - - # TODO REPLACE ME WITH TOOL CALL - # delta = DeltaMessage(content=delta_text) - return delta + # TODO REPLACE ME WITH TOOL CALL + # delta = DeltaMessage(content=delta_text) + return delta except Exception as e: logger.error("Error trying to handle streaming tool call: %s", From f7a0e76dbbb54111f7719f081a4905ebfbefd9f3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:35:25 -0500 Subject: [PATCH 140/222] fix: more hermes tool parser readability --- .../openai/tool_parsers/hermes_tool_parser.py | 215 +++++++++--------- 1 file changed, 108 insertions(+), 107 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index f5a1636cd2fb0..ff495f3042e20 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -83,9 +83,9 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: def __init__(self, tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] @@ -193,7 +193,7 @@ def extract_tool_calls_streaming( DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff).model_dump( - exclude_none=True)) + exclude_none=True)) ]) # case -- otherwise we're just generating text @@ -216,7 +216,7 @@ def extract_tool_calls_streaming( return DeltaMessage(tool_calls=[ InitialDeltaToolCall( index=self.current_tool_id).model_dump( - exclude_none=True) + exclude_none=True) ]) # case - we haven't sent the tool name yet. If it's available, send @@ -237,110 +237,111 @@ def extract_tool_calls_streaming( ]) else: return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + if text_portion is not None: + logger.debug( + "Also, will send text portion: %s", + text_portion) + + logger.debug( + "Trying to parse current tool call with ID %s", + self.current_tool_id) + if len(self.prev_tool_call_arr + ) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + logger.debug( + "Pushed dummy value into tool call arr") + # main logic for tool parsing here + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get( + "arguments" + ) # arguments, if any, in current dict + + logger.debug("diffing old arguments: %s", + prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", + delta_text) + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments " + "reset mid-call") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index( + delta_text + ) + + len(delta_text + )] + logger.debug( + "First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for dif between\n%s", + cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", + argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff else: - # if there is no tool calls - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = DeltaMessage( - content=delta_text - ) if text_portion is not None else None - # now, the nitty-gritty of tool calls - else: - # now we have the portion to parse as tool call. - if text_portion is not None: - logger.debug( - "Also, will send text portion: %s", - text_portion) + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for + # the next iteration + if self.current_tool_id == len( + self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append( + current_tool_call) - logger.debug( - "Trying to parse current tool call with ID %s", - self.current_tool_id) - if len(self.prev_tool_call_arr - ) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - logger.debug( - "Pushed dummy value into tool call arr") - # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get( - "arguments" - ) # arguments, if any, in current dict - - logger.debug("diffing old arguments: %s", - prev_arguments) - logger.debug("against new ones: %s", cur_arguments) - - if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", - delta_text) - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments " - "reset mid-call") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug("finding %s in %s", delta_text, - cur_arguments_json) - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index( - delta_text - ) + - len(delta_text - )] - logger.debug( - "First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for dif between\n%s", - cur_args_json) - logger.debug("and\n%s", prev_args_json) - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got argument diff %s", - argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None - - # handle saving the state for the current tool into - # the "prev" list for use in diffing for - # the next iteration - if self.current_tool_id == len( - self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append( - current_tool_call) - - # TODO REPLACE ME WITH TOOL CALL - # delta = DeltaMessage(content=delta_text) - return delta + # TODO REPLACE ME WITH TOOL CALL + # delta = DeltaMessage(content=delta_text) + return delta except Exception as e: logger.error("Error trying to handle streaming tool call: %s", From 34746fd1875603b66656ace7d0f0e78924576571 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:42:23 -0500 Subject: [PATCH 141/222] fix: more clarify updates to hermes parser --- .../openai/tool_parsers/hermes_tool_parser.py | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index ff495f3042e20..b6f5580049bed 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -254,45 +254,52 @@ def extract_tool_calls_streaming( "Also, will send text portion: %s", text_portion) - logger.debug( - "Trying to parse current tool call with ID %s", - self.current_tool_id) - if len(self.prev_tool_call_arr - ) <= self.current_tool_id: + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) logger.debug( "Pushed dummy value into tool call arr") - # main logic for tool parsing here - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get( - "arguments" - ) # arguments, if any, in current dict - - logger.debug("diffing old arguments: %s", - prev_arguments) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = (self.prev_tool_call_arr[self.current_tool_id] + .get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s",prev_arguments) logger.debug("against new ones: %s", cur_arguments) + # case -- no arguments have been created yet. skip sending a delta. if not cur_arguments and not prev_arguments: logger.debug("Skipping text %s - no arguments", delta_text) delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments " - "reset mid-call") + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) logger.debug("finding %s in %s", delta_text, cur_arguments_json) - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index( - delta_text - ) + - len(delta_text - )] + + # get the location where previous args differ from current + args_delta_start_loc = cur_arguments_json.index(delta_text) \ + + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] logger.debug( "First tokens in arguments received: %s", arguments_delta) From 40dab7978fbb3bc4729caf8ce3607f65b8da6bb0 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:47:02 -0500 Subject: [PATCH 142/222] fix: last hermes tool parser formatting & logic tweaks --- .../openai/tool_parsers/hermes_tool_parser.py | 105 +++++++----------- 1 file changed, 43 insertions(+), 62 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index b6f5580049bed..9de0b5a3471b2 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -83,9 +83,9 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: def __init__(self, tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] @@ -164,8 +164,7 @@ def extract_tool_calls_streaming( self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") - logger.debug("Starting on a new tool %s", - self.current_tool_id) + logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call elif (cur_tool_start_count > cur_tool_end_count @@ -180,12 +179,11 @@ def extract_tool_calls_streaming( elif (cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count): logger.debug("Closing the current tool call!") - diff = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") if diff: diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], "") + self.streamed_args_for_tool[self.current_tool_id], "") logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff) @@ -193,7 +191,7 @@ def extract_tool_calls_streaming( DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff).model_dump( - exclude_none=True)) + exclude_none=True)) ]) # case -- otherwise we're just generating text @@ -201,8 +199,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[], content=delta_text) return delta - logger.debug("Tool call portion: %s", tool_call_portion - or "") + logger.debug("Tool call portion: %s", tool_call_portion or "") current_tool_call = partial_json_parser.loads( tool_call_portion or "{}", flags) if tool_call_portion else None @@ -216,24 +213,22 @@ def extract_tool_calls_streaming( return DeltaMessage(tool_calls=[ InitialDeltaToolCall( index=self.current_tool_id).model_dump( - exclude_none=True) + exclude_none=True) ]) # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. elif not self.current_tool_name_sent: - function_name: Union[ - str, None] = current_tool_call.get("name") + function_name: Union[str, None] = current_tool_call.get("name") if function_name: - logger.debug( - "Sending DeltaToolCall with function name %s", - function_name) + logger.debug("Sending DeltaToolCall with function name %s", + function_name) self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - name=function_name). - model_dump(exclude_none=True)) + name=function_name).model_dump( + exclude_none=True)) ]) else: return None @@ -250,9 +245,7 @@ def extract_tool_calls_streaming( # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. if text_portion is not None: - logger.debug( - "Also, will send text portion: %s", - text_portion) + logger.debug("Also, will send text portion: %s", text_portion) logger.debug("Trying to parse current tool call with ID %s", self.current_tool_id) @@ -261,22 +254,20 @@ def extract_tool_calls_streaming( # a placeholder for the arguments if len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - logger.debug( - "Pushed dummy value into tool call arr") + logger.debug("Pushed dummy value into tool call arr") # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON - prev_arguments = (self.prev_tool_call_arr[self.current_tool_id] - .get("arguments")) + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) cur_arguments = current_tool_call.get("arguments") - logger.debug("diffing old arguments: %s",prev_arguments) + logger.debug("diffing old arguments: %s", prev_arguments) logger.debug("against new ones: %s", cur_arguments) # case -- no arguments have been created yet. skip sending a delta. if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", - delta_text) + logger.debug("Skipping text %s - no arguments", delta_text) delta = None # case -- prev arguments are defined, but non are now. @@ -296,63 +287,53 @@ def extract_tool_calls_streaming( # get the location where previous args differ from current args_delta_start_loc = cur_arguments_json.index(delta_text) \ - + len(delta_text) + + len(delta_text) # use that to find the actual delta arguments_delta = cur_arguments_json[:args_delta_start_loc] - logger.debug( - "First tokens in arguments received: %s", - arguments_delta) + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump( - exclude_none=True)) + arguments=arguments_delta).model_dump( + exclude_none=True)) ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + # last case -- we have an update to existing arguments. elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for dif between\n%s", - cur_args_json) + logger.debug("Searching for dif between\n%s", cur_args_json) logger.debug("and\n%s", prev_args_json) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) - logger.debug("got argument diff %s", - argument_diff) + logger.debug("got argument diff %s", argument_diff) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=argument_diff). - model_dump( - exclude_none=True)) + arguments=argument_diff).model_dump( + exclude_none=True)) ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - delta = None + self.streamed_args_for_tool[self.current_tool_id] \ + += argument_diff # handle saving the state for the current tool into - # the "prev" list for use in diffing for - # the next iteration - if self.current_tool_id == len( - self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call else: - self.prev_tool_call_arr.append( - current_tool_call) + self.prev_tool_call_arr.append(current_tool_call) - # TODO REPLACE ME WITH TOOL CALL - # delta = DeltaMessage(content=delta_text) return delta except Exception as e: - logger.error("Error trying to handle streaming tool call: %s", - e) + logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( "Skipping chunk as a result of tool streaming extraction " "error") From ba26f22662ee566c23a0d17b0936a9949295ae62 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:55:18 -0500 Subject: [PATCH 143/222] fix: mistral tool call parser formatting --- .../tool_parsers/mistral_tool_parser.py | 140 +++++++++--------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 58f2b18231f26..f491346411024 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -24,22 +24,18 @@ class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with the - examples/tool_chat_template_mistral.jinja template. There are several - IMPORTANT CAVEATS for this parser: - - The chat template is NOT official and does not work well if you try to - get the model to call 2+ tools at once without temperature=0. - Stick to only one tool call per generation, or set temp to 0 - as the chat template is not reliable with > 1 and the model - Will lose coherence. - - Mistral's tool call format, that this translates into an OpenAI - format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable - JSON parsing and serialization, we find-and-replace these with - DOUBLE QUOTES. To prevent tool call corruption / deserialization - failure, ensure that your tool calls and in particular your - ARGUMENTS never contain single or double quotes except as JSON - control characters. - - Used when --enable-api-tools --enable-auto-tool-choice --tool-call-parser + examples/tool_chat_template_mistral.jinja template. There is an + IMPORTANT CAVEAT for this parser: + + NOTE: Mistral's tool call format, that this translates into an OpenAI + format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable + JSON parsing and serialization, we find-and-replace these with + DOUBLE QUOTES. To prevent tool call corruption / deserialization + failure, ensure that your tool calls and in particular your + ARGUMENTS never contain single or double quotes except as JSON + control characters. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set """ @@ -58,58 +54,60 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: make sure your tool call arguments don't ever include quotes! """ - logger.debug( - "Trying to extract mistral tool calls from the following:") + logger.debug("Trying to extract mistral tool calls from the following:") logger.debug(model_output) - # Get the tool call token from the tokenizer + + # case -- if a tool call token is not present, return a text response if MistralToolParser.bot_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + try: + + # use a regex to find the tool call. remove the BOT token + # and make sure to replace single quotes with double quotes + raw_tool_call = MistralToolParser.tool_call_regex.findall( + model_output.replace(MistralToolParser.bot_token, "") + .replace("'", "\"") + )[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps( + raw_function_call["arguments"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(MistralToolParser.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", + e) + print("ERROR", e) + # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) - else: - try: - - # this will throw an exception if we can't find the tool call - # properly - raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output.replace(MistralToolParser.bot_token, - "") # remove BOT token - .replace("'", "\"") # replace string quotes - )[0] - - # load the JSON, and then use it to build the Function and - # Tool Call - function_call_arr = json.loads(raw_tool_call) - tool_calls: List[ToolCall] = [ - ToolCall( - type="function", - function=FunctionCall( - name=raw_function_call["name"], - # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call["arguments"]))) - for raw_function_call in function_call_arr - ] - content = model_output.split(MistralToolParser.bot_token)[0] - return ExtractedToolCallInformation( - tools_called=True, - tool_calls=tool_calls, - content=content if len(content) > 0 else None) - - except Exception as e: - logger.error("Error in extracting tool call from response: %s", - e) - print("ERROR", e) - # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) def __init__(self, tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) # initialize properties used for state when parsing tool calls in @@ -122,13 +120,13 @@ def __init__(self, ] # map what has been streamed for each tool so far to a list def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append @@ -184,7 +182,7 @@ def extract_tool_calls_streaming( # streamed to the client yet. if self.current_tool_id >= 0: diff: Union[str, - None] = current_tool_call.get("arguments") + None] = current_tool_call.get("arguments") if diff: diff = json.dumps(diff).replace( self.streamed_args_for_tool[ @@ -193,7 +191,7 @@ def extract_tool_calls_streaming( DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff).model_dump( - exclude_none=True)) + exclude_none=True)) ]) self.streamed_args_for_tool[ self.current_tool_id] += diff @@ -229,7 +227,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ InitialDeltaToolCall( index=self.current_tool_id).model_dump( - exclude_none=True) + exclude_none=True) ]) # if the current tool name hasn't been sent, send if available @@ -244,7 +242,7 @@ def extract_tool_calls_streaming( DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( name=function_name).model_dump( - exclude_none=True)) + exclude_none=True)) ]) self.current_tool_name_sent = True else: From 79e8bb310d9018be47554cda14c148ab667affaf Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 7 Aug 2024 23:55:49 -0500 Subject: [PATCH 144/222] fix: remove unnecessary else block in mistral tool call parser --- .../tool_parsers/mistral_tool_parser.py | 341 +++++++++--------- 1 file changed, 170 insertions(+), 171 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index f491346411024..89865431e8d22 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -136,189 +136,188 @@ def extract_tool_calls_streaming( # if the tool call token ID IS in the tokens generated so far, that # means we're parsing as tool calls now - else: - - # handle if we detected the BOT token which means the start of tool - # calling - if (self.bot_token_id in delta_token_ids - and len(delta_token_ids) == 1): - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - try: - - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - tool_call_message_portion = current_text.split( - self.bot_token)[1] - parsable_arr = tool_call_message_portion.replace("\'", "\"") - - # logger.debug('parsing: %s', parsable_arr) - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads( - parsable_arr, flags) - - # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - if len(tool_call_arr) > 0 and len( - tool_call_arr) > self.current_tool_id + 1: - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: Union[str, - None] = current_tool_call.get("arguments") - if diff: - diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.current_tool_initial_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", - self.current_tool_id) - return delta - - # case: update an existing tool - this is handled below - elif len( - tool_call_arr - ) - 1 == self.current_tool_id and self.current_tool_id >= 0: - pass - - # if there is NOTHING in the array, e.g. if only the open - # bracket was streamed yet - else: - return None - # if the current tool initial data incl. the id, type=function - # and idx not sent, send that - if not self.current_tool_initial_sent: - logger.debug("Sending InitialDeltaToolCall") - self.current_tool_initial_sent = True - delta = DeltaMessage(tool_calls=[ - InitialDeltaToolCall( - index=self.current_tool_id).model_dump( - exclude_none=True) - ]) + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: - # if the current tool name hasn't been sent, send if available - # - otherwise no chunks - elif not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - logger.debug( - "Sending DeltaToolCall with function name %s", - function_name) + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + tool_call_message_portion = current_text.split( + self.bot_token)[1] + parsable_arr = tool_call_message_portion.replace("\'", "\"") + + # logger.debug('parsing: %s', parsable_arr) + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + if len(tool_call_arr) > 0 and len( + tool_call_arr) > self.current_tool_id + 1: + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, + None] = current_tool_call.get("arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[ + self.current_tool_id], "") delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - name=function_name).model_dump( + arguments=diff).model_dump( exclude_none=True)) ]) - self.current_tool_name_sent = True + self.streamed_args_for_tool[ + self.current_tool_id] += diff else: delta = None - - # now we know we're on the same tool call and we're streaming - # arguments else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.current_tool_initial_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", + self.current_tool_id) + return delta - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("\'", "\"") - - if not cur_arguments and not prev_arguments: + # case: update an existing tool - this is handled below + elif len( + tool_call_arr + ) - 1 == self.current_tool_id and self.current_tool_id >= 0: + pass - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments) - logger.debug("finding %s in |%s|", new_text, - cur_arguments_json) - - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments) - prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) + # if there is NOTHING in the array, e.g. if only the open + # bracket was streamed yet + else: + return None - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None + # if the current tool initial data incl. the id, type=function + # and idx not sent, send that + if not self.current_tool_initial_sent: + logger.debug("Sending InitialDeltaToolCall") + self.current_tool_initial_sent = True + delta = DeltaMessage(tool_calls=[ + InitialDeltaToolCall( + index=self.current_tool_id).model_dump( + exclude_none=True) + ]) + + # if the current tool name hasn't been sent, send if available + # - otherwise no chunks + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + logger.debug( + "Sending DeltaToolCall with function name %s", + function_name) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in |%s|", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[: + cur_arguments_json + .index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta - except Exception as e: - logger.error("Error trying to handle streaming tool call: %s", - e) - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") - return None + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", + e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None From 844d26595b7f61ad5ebaca0d1d6c5ae8143710c4 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 8 Aug 2024 00:02:46 -0500 Subject: [PATCH 145/222] fix: formatting and control flow for mistral tool parser --- .../tool_parsers/mistral_tool_parser.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 89865431e8d22..914cab857f783 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -160,28 +160,33 @@ def extract_tool_calls_streaming( self.bot_token)[1] parsable_arr = tool_call_message_portion.replace("\'", "\"") - # logger.debug('parsing: %s', parsable_arr) - # tool calls are generated in an array, so do partial JSON # parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads( - parsable_arr, flags) + tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, + flags) # select as the current tool call the one we're on the state at current_tool_call: Dict = tool_call_arr[self.current_tool_id] + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - if len(tool_call_arr) > 0 and len( - tool_call_arr) > self.current_tool_id + 1: + elif ( + len(tool_call_arr) > 0 and + len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't # streamed to the client yet. if self.current_tool_id >= 0: - diff: Union[str, - None] = current_tool_call.get("arguments") + diff: Union[str,None] = current_tool_call.get("arguments") + if diff: diff = json.dumps(diff).replace( self.streamed_args_for_tool[ @@ -208,15 +213,8 @@ def extract_tool_calls_streaming( return delta # case: update an existing tool - this is handled below - elif len( - tool_call_arr - ) - 1 == self.current_tool_id and self.current_tool_id >= 0: - pass - - # if there is NOTHING in the array, e.g. if only the open - # bracket was streamed yet - else: - return None + #elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: + # pass # if the current tool initial data incl. the id, type=function # and idx not sent, send that From 1fa202759740d0babb9dc5abd077ba902e731f98 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 8 Aug 2024 00:07:22 -0500 Subject: [PATCH 146/222] fix: formatting updates to mistral tool call parser --- .../tool_parsers/mistral_tool_parser.py | 91 ++++++++----------- 1 file changed, 39 insertions(+), 52 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 914cab857f783..2a3b70f6b1cc4 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -54,24 +54,22 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: make sure your tool call arguments don't ever include quotes! """ - logger.debug("Trying to extract mistral tool calls from the following:") + logger.debug( + "Trying to extract mistral tool calls from the following:") logger.debug(model_output) # case -- if a tool call token is not present, return a text response if MistralToolParser.bot_token not in model_output: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output - ) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) try: # use a regex to find the tool call. remove the BOT token # and make sure to replace single quotes with double quotes raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output.replace(MistralToolParser.bot_token, "") - .replace("'", "\"") - )[0] + model_output.replace(MistralToolParser.bot_token, + "").replace("'", "\""))[0] # load the JSON, and then use it to build the Function and # Tool Call @@ -82,8 +80,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string - arguments=json.dumps( - raw_function_call["arguments"]))) + arguments=json.dumps(raw_function_call["arguments"]))) for raw_function_call in function_call_arr ] @@ -95,8 +92,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: content=content if len(content) > 0 else None) except Exception as e: - logger.error("Error in extracting tool call from response: %s", - e) + logger.error("Error in extracting tool call from response: %s", e) print("ERROR", e) # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, @@ -105,9 +101,9 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: def __init__(self, tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): super().__init__(tokenizer) # initialize properties used for state when parsing tool calls in @@ -120,13 +116,13 @@ def __init__(self, ] # map what has been streamed for each tool so far to a list def extract_tool_calls_streaming( - self, - previous_text: str, - current_text: str, - delta_text: str, - previous_token_ids: List[int], - current_token_ids: List[int], - delta_token_ids: List[int], + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: List[int], + current_token_ids: List[int], + delta_token_ids: List[int], ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append @@ -156,14 +152,13 @@ def extract_tool_calls_streaming( # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls - tool_call_message_portion = current_text.split( - self.bot_token)[1] + tool_call_message_portion = current_text.split(self.bot_token)[1] parsable_arr = tool_call_message_portion.replace("\'", "\"") # tool calls are generated in an array, so do partial JSON # parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads(parsable_arr, - flags) + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) # select as the current tool call the one we're on the state at current_tool_call: Dict = tool_call_arr[self.current_tool_id] @@ -175,27 +170,25 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif ( - len(tool_call_arr) > 0 and - len(tool_call_arr) > self.current_tool_id + 1 - ): + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't # streamed to the client yet. if self.current_tool_id >= 0: - diff: Union[str,None] = current_tool_call.get("arguments") + diff: Union[str, None] = current_tool_call.get("arguments") if diff: diff = json.dumps(diff).replace( - self.streamed_args_for_tool[ - self.current_tool_id], "") + self.streamed_args_for_tool[self.current_tool_id], + "") delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff).model_dump( - exclude_none=True)) + exclude_none=True)) ]) self.streamed_args_for_tool[ self.current_tool_id] += diff @@ -208,13 +201,10 @@ def extract_tool_calls_streaming( self.current_tool_name_sent = False self.current_tool_initial_sent = False self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", - self.current_tool_id) + logger.debug("starting on new tool %d", self.current_tool_id) return delta # case: update an existing tool - this is handled below - #elif len(tool_call_arr) - 1 == self.current_tool_id and self.current_tool_id >= 0: - # pass # if the current tool initial data incl. the id, type=function # and idx not sent, send that @@ -224,22 +214,21 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ InitialDeltaToolCall( index=self.current_tool_id).model_dump( - exclude_none=True) + exclude_none=True) ]) # if the current tool name hasn't been sent, send if available - # - otherwise no chunks + # - otherwise send nothing elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - logger.debug( - "Sending DeltaToolCall with function name %s", - function_name) + logger.debug("Sending DeltaToolCall with function name %s", + function_name) delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( name=function_name).model_dump( - exclude_none=True)) + exclude_none=True)) ]) self.current_tool_name_sent = True else: @@ -268,9 +257,8 @@ def extract_tool_calls_streaming( logger.debug("finding %s in |%s|", new_text, cur_arguments_json) - arguments_delta = cur_arguments_json[: - cur_arguments_json - .index(new_text) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + len(new_text)] logger.debug("First tokens in arguments received: %s", arguments_delta) @@ -295,8 +283,8 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( - arguments=argument_diff). - model_dump(exclude_none=True)) + arguments=argument_diff).model_dump( + exclude_none=True)) ]) self.streamed_args_for_tool[ self.current_tool_id] += argument_diff @@ -313,8 +301,7 @@ def extract_tool_calls_streaming( return delta except Exception as e: - logger.error("Error trying to handle streaming tool call: %s", - e) + logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( "Skipping chunk as a result of tool streaming extraction " "error") From 6fc1f757c7ac74b1d4d04af203394f0c250e519c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 8 Aug 2024 17:22:29 -0500 Subject: [PATCH 147/222] fix: catch a silent exception in hermes tool parser that was generating error lines but not causing problems --- .../openai/tool_parsers/hermes_tool_parser.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 9de0b5a3471b2..fb361e6a01dbc 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -200,10 +200,15 @@ def extract_tool_calls_streaming( return delta logger.debug("Tool call portion: %s", tool_call_portion or "") - current_tool_call = partial_json_parser.loads( - tool_call_portion or "{}", - flags) if tool_call_portion else None - logger.debug("Parsed tool call %s", current_tool_call) + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None # case - we haven't sent the initial delta with the tool call ID # (it will be sent) @@ -334,7 +339,4 @@ def extract_tool_calls_streaming( except Exception as e: logger.error("Error trying to handle streaming tool call: %s", e) - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") return None # do not stream a delta. skip this token ID. From a89e565b37a8aca4512fef72d4ea82907051b50b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 8 Aug 2024 17:50:15 -0500 Subject: [PATCH 148/222] fix: refactoring & cleanup serving_chat --- vllm/entrypoints/openai/serving_chat.py | 89 +++++++++++++++++++------ 1 file changed, 67 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 73738f9e395b2..e7053ab4f8f7c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -31,7 +31,7 @@ from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -361,7 +361,7 @@ async def chat_completion_stream_generator( delta_text = output.text[len(previous_texts[i]):] delta_message: Optional[DeltaMessage] = None - # handle streaming deltas for tools with tool_choice + # handle streaming deltas for tools with named tool_choice if (request.tool_choice and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam): delta_message = DeltaMessage(tool_calls=[ @@ -371,25 +371,28 @@ async def chat_completion_stream_generator( index=i) ]) - # handle streaming deltas for tools with tool_choice - elif (request.tools and tool_parser - and (request.tool_choice is None - or request.tool_choice == "auto") - and self.enable_auto_tools): - + # handle streaming deltas for tools with "auto" tool choice + elif (self._should_stream_with_auto_tool_parsing(request) + and tool_parser): delta_message = ( tool_parser.extract_tool_calls_streaming( previous_text=previous_texts[i], current_text=output.text, delta_text=delta_text, - previous_token_ids=output. - token_ids[:-1 * len(delta_token_ids)], + previous_token_ids= \ + output.token_ids[ + :-1 * len(delta_token_ids) + ], current_token_ids=output.token_ids, - delta_token_ids=delta_token_ids)) + delta_token_ids=delta_token_ids + ) + ) + + # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) - # handle setting the previous values for the next iteration + # set the previous values for the next iteration previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) @@ -431,6 +434,8 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + + # if the model is finished generating else: # check to make sure we haven't "forgotten" to stream # any tokens that were generated but previously @@ -442,29 +447,33 @@ async def chat_completion_stream_generator( tool_parser.prev_tool_call_arr) > 0 else 0 else: index = 0 - if (delta_message.tool_calls - and delta_message.tool_calls[0] - and delta_message.tool_calls[0].function and - (delta_message.tool_calls[0].function.arguments - == "" or - delta_message.tool_calls[0].function.arguments and - (output.finish_reason == "stop" - or output.finish_reason == "tool_calls")) - and tool_parser - and request.tool_choice == "auto"): + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( "arguments", {})) + + # get what we've streamed so for for arguments + # for the current tool actual_call = tool_parser.streamed_args_for_tool[ index] + + # check to see if there's anything left to stream remaining_call = expected_call.replace( actual_call, "", 1) + + # set that as a delta message delta_message = DeltaMessage(tool_calls=[ DeltaToolCall(index=index, function=DeltaFunctionCall( arguments=remaining_call). model_dump(exclude_none=True)) ]) + # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( @@ -524,6 +533,7 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) + print(e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished @@ -717,3 +727,38 @@ def _create_chat_logprobs( tokenizer))) return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + and output.finish_reason is not None + ) From 9f0a8039adf85a08913423cd651119043251cd0c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 8 Aug 2024 18:02:00 -0500 Subject: [PATCH 149/222] fix: two silent errors in mistral tool parser (not causing problems) and remove the replaceAll since its not necessary any more --- .../tool_parsers/mistral_tool_parser.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 2a3b70f6b1cc4..5b9bf90e765a8 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -68,8 +68,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # use a regex to find the tool call. remove the BOT token # and make sure to replace single quotes with double quotes raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output.replace(MistralToolParser.bot_token, - "").replace("'", "\""))[0] + model_output.replace(MistralToolParser.bot_token, ""))[0] # load the JSON, and then use it to build the Function and # Tool Call @@ -152,16 +151,21 @@ def extract_tool_calls_streaming( # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls - tool_call_message_portion = current_text.split(self.bot_token)[1] - parsable_arr = tool_call_message_portion.replace("\'", "\"") + parsable_arr = current_text.split(self.bot_token)[1] # tool calls are generated in an array, so do partial JSON # parsing on the entire array - tool_call_arr: List[Dict] = partial_json_parser.loads( - parsable_arr, flags) + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None # select as the current tool call the one we're on the state at - current_tool_call: Dict = tool_call_arr[self.current_tool_id] + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -242,6 +246,8 @@ def extract_tool_calls_streaming( self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") + logger.debug("new text: %s", current_text) + new_text = delta_text.replace("\'", "\"") if not cur_arguments and not prev_arguments: From b2a08841f4c42c730d01791ed3c7bea257ca54c5 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 10:07:57 -0500 Subject: [PATCH 150/222] fix: CLI args in docs & comments --- docs/source/serving/openai_compatible_server.md | 7 +++---- vllm/entrypoints/openai/serving_chat.py | 7 ++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 83ddd74ee8a51..f867f5023c76b 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -129,8 +129,7 @@ _This feature is in **beta**. It has limited model support, is not guaranteed to well-defined failure modes._ As such, it must be explicitly enabled when desired. To enable this feature, you must set the following flags: -* `--enable-api-tools` -- **mandatory** for Auto tool choice. tells vLLM that you want to enable tool templating and extraction. -* `--enable-auto-toolchoice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls.This argument can be set to `tool_use` if your model has a tool use chat @@ -151,7 +150,7 @@ Supported models in this series: _Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge step in their creation_. It is recommended to use the Hermes 2 **Pro** models. -Recommended flags: `--tool-parser hermes --chat-template examples/tool_chat_template_hermes.jinja` +Recommended flags: `--tool-call-parser hermes --chat-template examples/tool_chat_template_hermes.jinja` #### Mistral Models Supported models: @@ -166,4 +165,4 @@ with double-quotes in mistral-generated tool calls. Therefore, **it is important arguments do not contain single quotes.** Escaped double quotes may be handled properly, but otherwise you should expect parser issues. -Recommended flags: `--tool-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e7053ab4f8f7c..7320e4d32dd2f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -83,7 +83,7 @@ def __init__(self, self.tool_parser = Hermes2ProToolParser else: raise TypeError( - "Error: --enable-auto-tool-choice requires --tool-parser") + "Error: --enable-auto-tool-choice requires --tool-call-parser") async def create_chat_completion( self, @@ -153,12 +153,13 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # "auto" tools requires --enable-auto-tool-choice and --tool-parser + # "auto" tools requires --enable-auto-tool-choice + # and --tool-call-parser if request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): return self.create_error_response( "\"auto\" tool choice requires " - "--enable-auto-tool-choice and --tool-parser to be set") + "--enable-auto-tool-choice and --tool-call-parser to be set") request_id = f"chat-{random_uuid()}" try: From 698d11291bb046e55949d69de13296c15929e3e1 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 10:08:40 -0500 Subject: [PATCH 151/222] fix: remove line about singlq quotes for mistral --- docs/source/serving/openai_compatible_server.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index f867f5023c76b..1fe111478c852 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,12 +157,5 @@ Supported models: * `mistralai/Mistral-7B-Instruct-v0.3` * Possibly mistral-large and mixtral? These have not been tested at the time of this writing. -There is a several known with tool-calling in Mistral models: -* Mistral function-calling / tool use generates calls with _single_ quotes `'` instead of double quotes `"`. As a -result, tool call generations can't be handled as JSON by the parser automatically without using `eval`, which would -present security issues for vLLM users. As a result, to support Mistral tool calls, we find-and-replace single-quotes -with double-quotes in mistral-generated tool calls. Therefore, **it is important to ensure that your tool call -arguments do not contain single quotes.** Escaped double quotes may be handled properly, but otherwise you should -expect parser issues. Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` From 2e6a48a068d29bdca2634eb61b3f48ef9a19fb42 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 10:09:04 -0500 Subject: [PATCH 152/222] fix: update doc in mistral tool parser as well --- .../openai/tool_parsers/mistral_tool_parser.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 5b9bf90e765a8..2bc9bfc6792ae 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -24,19 +24,9 @@ class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with the - examples/tool_chat_template_mistral.jinja template. There is an - IMPORTANT CAVEAT for this parser: - - NOTE: Mistral's tool call format, that this translates into an OpenAI - format, uses SINGLE QUOTES which cannot be parsed to JSON. To enable - JSON parsing and serialization, we find-and-replace these with - DOUBLE QUOTES. To prevent tool call corruption / deserialization - failure, ensure that your tool calls and in particular your - ARGUMENTS never contain single or double quotes except as JSON - control characters. - - Used when --enable-auto-tool-choice --tool-call-parser - mistral are all set + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set """ # the bot_token is the token indicating tool call(s) follow. Tokens before From 9078126d07eb65e4d411c75cf9c545ab54f4cdb1 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 12:21:05 -0500 Subject: [PATCH 153/222] fix: update mistral chat templates and docs for mistral tool calling --- .../serving/openai_compatible_server.md | 20 ++++ examples/tool_chat_template_mistral.jinja | 10 +- .../tool_chat_template_mistral_parallel.jinja | 94 +++++++++++++++++++ 3 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 examples/tool_chat_template_mistral_parallel.jinja diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 1fe111478c852..8211410e96d87 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,5 +157,25 @@ Supported models: * `mistralai/Mistral-7B-Instruct-v0.3` * Possibly mistral-large and mixtral? These have not been tested at the time of this writing. +Known issues: +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. + +To address this, the following additional chat templates are provided: + +* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) +* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +when tools are provided, that results in much better reliability when working with parallel tool calling. + +**Please note** that the model's default chat template in `tokenizer_config.json` will not work with vLLM, as it expects +tool_call_id fields to be exactly 9 digits, which is shorter than vLLM's format. You **must** do one of the following +to get tool calling to work with mistral: +1. use one of the 2 provided tool chat templates +2. provide your own tool chat template that corrects for this +3. in your client code, ignore the vLLM-generated `tool_call_id`, and manually generate and pass in your own 9-digit +`tool_call_id`s for `assistant`-role messages containing tool calls, and `tool`-role messages containing tool call +results. Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja index 49855b6506f9f..49691f59c2f2c 100644 --- a/examples/tool_chat_template_mistral.jinja +++ b/examples/tool_chat_template_mistral.jinja @@ -4,9 +4,9 @@ {%- else %} {%- set loop_messages = messages %} {%- endif %} - {%- if not tools is defined %} - {%- set tools = none %} - {%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} {%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} {%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} @@ -43,7 +43,7 @@ {{- "[/AVAILABLE_TOOLS]" }} {%- endif %} {%- if loop.last and system_message is defined %} - {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} {%- else %} {{- "[INST] " + message["content"] + "[/INST]" }} {%- endif %} @@ -68,7 +68,7 @@ {%- endif %} {%- endfor %} {%- elif message["role"] == "assistant" %} - {{- " " + message["content"] + eos_token}} + {{- " " + message["content"] + eos_token }} {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} {%- if message.content is defined and message.content.content is defined %} {%- set content = message.content.content %} diff --git a/examples/tool_chat_template_mistral_parallel.jinja b/examples/tool_chat_template_mistral_parallel.jinja new file mode 100644 index 0000000000000..a294cbfd026be --- /dev/null +++ b/examples/tool_chat_template_mistral_parallel.jinja @@ -0,0 +1,94 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- if tools is defined %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} From f627b4806a9a70010f4af7db411521d5b762c981 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 12:21:53 -0500 Subject: [PATCH 154/222] format: serving_chat --- vllm/entrypoints/openai/serving_chat.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7320e4d32dd2f..ca44f00485151 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -82,8 +82,8 @@ def __init__(self, elif tool_parser == "hermes": self.tool_parser = Hermes2ProToolParser else: - raise TypeError( - "Error: --enable-auto-tool-choice requires --tool-call-parser") + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") async def create_chat_completion( self, @@ -451,7 +451,6 @@ async def chat_completion_stream_generator( if self._should_check_for_unstreamed_tool_arg_tokens( delta_message, output) and tool_parser: - # get the expected call based on partial JSON # parsing which "autocompletes" the JSON expected_call = json.dumps( @@ -755,11 +754,11 @@ def _should_check_for_unstreamed_tool_arg_tokens( # yapf: disable return bool( - # if there is a delta message that includes tool calls which - # include a function that has arguments - self.enable_auto_tools and self.tool_parser and delta_message - and delta_message.tool_calls and delta_message.tool_calls[0] - and delta_message.tool_calls[0].function - and delta_message.tool_calls[0].function.arguments is not None - and output.finish_reason is not None + # if there is a delta message that includes tool calls which + # include a function that has arguments + self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + and output.finish_reason is not None ) From ce9afeb8a9552e2062d573273c6a1a4b17f81319 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 14:20:03 -0500 Subject: [PATCH 155/222] doc: in example, explain how to start server for usage --- .../openai_chat_completion_client_with_tools.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index adf0632f15b7a..5d423a8e65aa0 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -1,3 +1,20 @@ +""" +Set up this example by starting a vLLM OpenAI-compatible server with tool call +options enabled. For example: + +IMPORTANT: for mistral, you must use one of the provided mistral tool call +templates, or your own - the model default doesn't work for tool calls with vLLM +See the vLLM docs on OpenAI server & tool calling for more details. + +vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ + --chat-template examples/tool_chat_template_mistral.jinja \ + --enable-auto-tool-choice --tool-call-parser mistral + +OR +vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ + --chat-template examples/tool_chat_template_hermes.jinja \ + --enable-auto-tool-choice --tool-call-parser hermes +""" import json from openai import OpenAI From 6a2757bc4edb81ea98a0cc2a2b748ab54afc614e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 14:31:06 -0500 Subject: [PATCH 156/222] chore: docs in serving_chat and swap `or ""` for the correct type annotation on the left-hand operand of assignment --- vllm/entrypoints/openai/serving_chat.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index ca44f00485151..e7d4fc336a3fb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -267,6 +267,9 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, @@ -279,14 +282,18 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # if usage should be included if (request.stream_options and request.stream_options.include_usage): + # if continuous usage stats are requested, add it if request.stream_options.continuous_usage_stats: prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens) chunk.usage = usage + # otherwise don't else: chunk.usage = None @@ -296,11 +303,11 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content = "" + last_msg_content: Optional[str] = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( "role") == role: - last_msg_content = conversation[-1]["content"] or "" + last_msg_content = conversation[-1]["content"] if last_msg_content: for i in range(num_choices): @@ -418,6 +425,8 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats): From 66aa580a8fe3c6d73fdecc9abb387f5dcf33a321 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 19:33:42 -0500 Subject: [PATCH 157/222] fix: patch the hermes chat template which was missing a quote (@Nous bls fix in huggingface) --- examples/tool_chat_template_hermes.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index 3cc07d9ad0525..7da8db52de891 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -39,7 +39,7 @@ {%- set tool = tool.function %} {%- endif %} {{- '{"type": "function", "function": ' }} - {{- '{"name": ' + tool.name + '", ' }} + {{- '{"name": "' + tool.name + '", ' }} {{- '"description": "' + tool.name + '(' }} {%- for param_name, param_fields in tool.parameters.properties|items %} {{- param_name + ": " + json_to_python_type(param_fields) }} From b43b8a96d2d44af5cb2d2ed0a00a11b52182724d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 9 Aug 2024 19:43:53 -0500 Subject: [PATCH 158/222] tests: add parametrized pytest fixture, and begin adding tests --- tests/entrypoints/openai/test_tools.py | 220 +++++++++++++++++++++++++ 1 file changed, 220 insertions(+) create mode 100644 tests/entrypoints/openai/test_tools.py diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py new file mode 100644 index 0000000000000..5f36aa893237f --- /dev/null +++ b/tests/entrypoints/openai/test_tools.py @@ -0,0 +1,220 @@ +import pytest +from ...utils import VLLM_PATH, RemoteOpenAIServer +from typing import List, TypedDict, Dict +import openai +from openai.types.chat import ChatCompletionMessageParam + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +class TestConfig(TypedDict): + client: openai.AsyncOpenAI + model: str + + +ARGS: List[str] = [ + "--dtype", "half", # TODO change to BF16 + "--kv-cache-dtype", "fp8", + "--enable-auto-tool-choice" +] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": "NousResearch/Hermes-2-Pro-Llama-3-8B", + "arguments": [ + "--tool-call-parser", + "hermes", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", + "mistral", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") + ] + + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": "Hi! How are you?" + }, + { + "role": "assistant", + "content": "I'm doing great! How can I assist you?" + }, + { + "role": "user", + "content": "Can you write a simple 'hello world' program in python?" + } +] + +WEATHER_TOOL = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": [ + "celsius", + "fahrenheit" + ] + } + } + } + } +} + +SEARCH_TOOL = { + "type": "function", + "function": { + "name": "web_search", + "description": "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": "string", + "description": "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + + +# Parameterize with the keys in the configs dict, instead of the items, or using +# a list of configs, so that if tests fail, we can easily/nicely see which +# model/config was the param that caused the failure. +@pytest.fixture(params=CONFIGS.keys(), scope="module") +def config(request) -> ServerConfig: + server_config: ServerConfig = CONFIGS[request.param] + with (RemoteOpenAIServer(server_config["model"], + ARGS + server_config["arguments"]) as server): + yield TestConfig( + client=server.get_async_client(), + model=server_config["model"] + ) + + +@pytest.mark.asyncio +async def test_get_models(config: TestConfig): + client = config["client"] + assert client is not None + assert isinstance(client, openai.AsyncOpenAI) + + models = await client.models.list() + assert len(models.data) == 1 + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(config: TestConfig): + chat_completion = await config["client"].chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=16, + model=config["model"], + logprobs=False + ) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None or + len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await config["client"].chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=16, + model=config["model"], + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert len(chunks) + assert "".join(chunks) == output_text + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools_expecting_chat(config: TestConfig): + pass + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_chat_completion_with_tools_expecting_tools(config: TestConfig): + pass +# test: providing tools and results back to model to get a non-tool response (streaming/not) + +# test: getting the model to generate parallel tool calls (streaming/not) + +# test: providing parallel tool calls back to the model to get a response (streaming/not) From 01d528c8bf943be09aa5e5722401dbe7ad16eb0d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 10 Aug 2024 12:32:42 -0500 Subject: [PATCH 159/222] fix: formatting --- tests/entrypoints/openai/test_tools.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 5f36aa893237f..00bdcc89d6178 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -204,17 +204,16 @@ async def test_chat_completion_without_tools(config: TestConfig): # test: conversation with tools enabled and provided that should not invoke # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools -@pytest.mark.asyncio -async def test_chat_completion_with_tools_expecting_chat(config: TestConfig): - pass + # test: request a chat completion that should return tool calls, so we know they # are parsable -@pytest.mark.asyncio -async def test_chat_completion_with_tools_expecting_tools(config: TestConfig): - pass -# test: providing tools and results back to model to get a non-tool response (streaming/not) + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) # test: getting the model to generate parallel tool calls (streaming/not) -# test: providing parallel tool calls back to the model to get a response (streaming/not) +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) From b9397e3af34dd1aebe691f906e5177b19b45ebcf Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 10 Aug 2024 13:06:40 -0500 Subject: [PATCH 160/222] fix: test formatting that was causing format.sh to crash --- tests/entrypoints/openai/test_tools.py | 153 +++++++++++++------------ 1 file changed, 80 insertions(+), 73 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 00bdcc89d6178..f7ee5c183bf54 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -1,9 +1,11 @@ -import pytest -from ...utils import VLLM_PATH, RemoteOpenAIServer -from typing import List, TypedDict, Dict +from typing import Dict, List, TypedDict + import openai +import pytest from openai.types.chat import ChatCompletionMessageParam +from ...utils import VLLM_PATH, RemoteOpenAIServer + class ServerConfig(TypedDict): model: str @@ -16,47 +18,48 @@ class TestConfig(TypedDict): ARGS: List[str] = [ - "--dtype", "half", # TODO change to BF16 - "--kv-cache-dtype", "fp8", + "--dtype", + "half", # TODO change to BF16 + "--kv-cache-dtype", + "fp8", "--enable-auto-tool-choice" ] CONFIGS: Dict[str, ServerConfig] = { "hermes": { - "model": "NousResearch/Hermes-2-Pro-Llama-3-8B", + "model": + "NousResearch/Hermes-2-Pro-Llama-3-8B", "arguments": [ - "--tool-call-parser", - "hermes", - "--chat-template", + "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") ] }, "mistral": { - "model": "mistralai/Mistral-7B-Instruct-v0.3", + "model": + "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--tool-call-parser", - "mistral", - "--chat-template", + "--tool-call-parser", "mistral", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") ] - } } -MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [ - { - "role": "user", - "content": "Hi! How are you?" - }, - { - "role": "assistant", - "content": "I'm doing great! How can I assist you?" - }, - { - "role": "user", - "content": "Can you write a simple 'hello world' program in python?" - } -] +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you write a simple 'hello world' program in python?" +}] WEATHER_TOOL = { "type": "function", @@ -67,23 +70,24 @@ class TestConfig(TypedDict): "type": "object", "properties": { "city": { - "type": "string", - "description": "The city to find the weather for, " - "e.g. 'San Francisco'" + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" }, "state": { - "type": "string", - "description": "the two-letter abbreviation for the state " - "that the city is in, e.g. 'CA' which would " - "mean 'California'" + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" }, "unit": { "type": "string", "description": "The unit to fetch the temperature in", - "enum": [ - "celsius", - "fahrenheit" - ] + "enum": ["celsius", "fahrenheit"] } } } @@ -93,19 +97,23 @@ class TestConfig(TypedDict): SEARCH_TOOL = { "type": "function", "function": { - "name": "web_search", - "description": "Search the internet and get a summary of the top " - "10 webpages. Should only be used if you don't know " - "the answer to a user query, and the results are likely" - "to be able to be found with a web search", + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", "parameters": { "type": "object", "properties": { "search_term": { - "type": "string", - "description": "The term to use in the search. This should" - "ideally be keywords to search for, not a" - "natural-language question" + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" } }, "required": ["search_term"] @@ -113,24 +121,25 @@ class TestConfig(TypedDict): } } +configKeys = CONFIGS.keys() + -# Parameterize with the keys in the configs dict, instead of the items, or using -# a list of configs, so that if tests fail, we can easily/nicely see which -# model/config was the param that caused the failure. -@pytest.fixture(params=CONFIGS.keys(), scope="module") -def config(request) -> ServerConfig: - server_config: ServerConfig = CONFIGS[request.param] - with (RemoteOpenAIServer(server_config["model"], - ARGS + server_config["arguments"]) as server): - yield TestConfig( - client=server.get_async_client(), - model=server_config["model"] - ) +@pytest.fixture(scope="module", params=configKeys) +def client_config(request): + print('param', request.param) + server_config: ServerConfig = CONFIGS["hermes"] + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model) as server: + client = server.get_async_client() + yield TestConfig(client=client, model=model) @pytest.mark.asyncio -async def test_get_models(config: TestConfig): - client = config["client"] +async def test_get_models(client_config: TestConfig): + client = client_config["client"] + model = client_config["model"] + print('Running test_get_models for ', model) assert client is not None assert isinstance(client, openai.AsyncOpenAI) @@ -142,14 +151,13 @@ async def test_get_models(config: TestConfig): # are enabled. This makes sure tool call chat templates work, AND that the tool # parser stream processing doesn't change the output of the model. @pytest.mark.asyncio -async def test_chat_completion_without_tools(config: TestConfig): - chat_completion = await config["client"].chat.completions.create( +async def test_chat_completion_without_tools(client_config: TestConfig): + chat_completion = await client_config["client"].chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=16, - model=config["model"], - logprobs=False - ) + model=client_config["model"], + logprobs=False) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content @@ -159,15 +167,15 @@ async def test_chat_completion_without_tools(config: TestConfig): assert len(output_text) > 0 # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None or - len(choice.message.tool_calls) == 0) + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) # make the same request, streaming - stream = await config["client"].chat.completions.create( + stream = await client_config["client"].chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=16, - model=config["model"], + model=client_config["model"], logprobs=False, stream=True, ) @@ -201,15 +209,14 @@ async def test_chat_completion_without_tools(config: TestConfig): assert len(chunks) assert "".join(chunks) == output_text + # test: conversation with tools enabled and provided that should not invoke # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools - # test: request a chat completion that should return tool calls, so we know they # are parsable - # test: providing tools and results back to model to get a non-tool response # (streaming/not) From d80ac42123b786d1ebf9fb7350b1563ad4975b99 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 10 Aug 2024 14:20:04 -0500 Subject: [PATCH 161/222] test: add test to do tool choice --- tests/entrypoints/openai/test_tools.py | 176 ++++++++++++++++++++++--- 1 file changed, 158 insertions(+), 18 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index f7ee5c183bf54..b629ca247baca 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -1,11 +1,40 @@ -from typing import Dict, List, TypedDict +import json +from typing import Dict, List, Literal, Optional import openai import pytest from openai.types.chat import ChatCompletionMessageParam +from typing_extensions import NotRequired, TypedDict from ...utils import VLLM_PATH, RemoteOpenAIServer +# we need this because this is more precise than the existing definition in +# vll.entrypoints.openai.protocol which inherits BaseModel. for literals, I need +# a dict to check against + + +class OaiToolFunctionParamProperties(TypedDict): + type: str + description: Optional[str] + enum: NotRequired[List[str]] + + +class OAiToolFunctionParams(TypedDict): + type: Literal["object"] + properties: Dict[str, OaiToolFunctionParamProperties] + required: NotRequired[List[str]] + + +class OAiFunctionDefinition(TypedDict): + name: str + description: str + parameters: OAiToolFunctionParams + + +class OpenAICompatibleToolDefinition(TypedDict): + type: Literal["function"] + function: OAiFunctionDefinition + class ServerConfig(TypedDict): model: str @@ -58,10 +87,17 @@ class TestConfig(TypedDict): "role": "user", "content": - "Can you write a simple 'hello world' program in python?" + "Can you tell me a joke?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" }] -WEATHER_TOOL = { +WEATHER_TOOL: OpenAICompatibleToolDefinition = { "type": "function", "function": { "name": "get_current_weather", @@ -94,7 +130,7 @@ class TestConfig(TypedDict): } } -SEARCH_TOOL = { +SEARCH_TOOL: OpenAICompatibleToolDefinition = { "type": "function", "function": { "name": @@ -135,18 +171,6 @@ def client_config(request): yield TestConfig(client=client, model=model) -@pytest.mark.asyncio -async def test_get_models(client_config: TestConfig): - client = client_config["client"] - model = client_config["model"] - print('Running test_get_models for ', model) - assert client is not None - assert isinstance(client, openai.AsyncOpenAI) - - models = await client.models.list() - assert len(models.data) == 1 - - # test: make sure chat completions without tools provided work even when tools # are enabled. This makes sure tool call chat templates work, AND that the tool # parser stream processing doesn't change the output of the model. @@ -155,7 +179,7 @@ async def test_chat_completion_without_tools(client_config: TestConfig): chat_completion = await client_config["client"].chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=16, + max_tokens=128, model=client_config["model"], logprobs=False) choice = chat_completion.choices[0] @@ -174,7 +198,7 @@ async def test_chat_completion_without_tools(client_config: TestConfig): stream = await client_config["client"].chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=16, + max_tokens=128, model=client_config["model"], logprobs=False, stream=True, @@ -213,9 +237,125 @@ async def test_chat_completion_without_tools(client_config: TestConfig): # test: conversation with tools enabled and provided that should not invoke # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client_config: TestConfig): + chat_completion = await client_config["client"].chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=client_config["model"], + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client_config["client"].chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=client_config["model"], + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text + # test: request a chat completion that should return tool calls, so we know they # are parsable +@pytest.mark.asyncio +async def test_tool_call(client_config: TestConfig): + chat_completion = await client_config["client"].chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + model=client_config["model"], + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + """ + # make the same request, streaming + stream = await client_config["client"].chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=client_config["model"], + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + """ + + pass + # test: providing tools and results back to model to get a non-tool response # (streaming/not) From e9a857f350ec4ee7b0ae8e0502e3b74b4202b163 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 11 Aug 2024 21:06:30 -0500 Subject: [PATCH 162/222] test: add test for tool calling with tool choice; very non-streaming and streaming outputs match --- tests/entrypoints/openai/test_tools.py | 88 +++++++++++++++++++++++--- 1 file changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index b629ca247baca..47ee33dc97c1d 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -305,7 +305,7 @@ async def test_chat_completion_with_tools(client_config: TestConfig): # test: request a chat completion that should return tool calls, so we know they # are parsable @pytest.mark.asyncio -async def test_tool_call(client_config: TestConfig): +async def test_tool_call_and_choice(client_config: TestConfig): chat_completion = await client_config["client"].chat.completions.create( messages=MESSAGES_ASKING_FOR_TOOLS, temperature=0, @@ -341,20 +341,88 @@ async def test_tool_call(client_config: TestConfig): assert parsed_arguments.get("state") == "TX" assert stop_reason == "tool_calls" - """ + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + # make the same request, streaming stream = await client_config["client"].chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, model=client_config["model"], + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - tools=[WEATHER_TOOL], - stream=True, - ) - """ + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments - pass + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args # test: providing tools and results back to model to get a non-tool response From 1a2f8b264c15753cba645712d6e2558430e4cc55 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 12 Aug 2024 19:14:55 -0500 Subject: [PATCH 163/222] fix(tests): download models before starting openai server to prevent timeout --- tests/entrypoints/openai/test_tools.py | 72 ++++++++++++-------------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 47ee33dc97c1d..6acd79cadca66 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -1,40 +1,15 @@ import json -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Optional import openai import pytest -from openai.types.chat import ChatCompletionMessageParam -from typing_extensions import NotRequired, TypedDict +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from transformers import AutoModelForCausalLM, AutoTokenizer +from typing_extensions import TypedDict from ...utils import VLLM_PATH, RemoteOpenAIServer -# we need this because this is more precise than the existing definition in -# vll.entrypoints.openai.protocol which inherits BaseModel. for literals, I need -# a dict to check against - - -class OaiToolFunctionParamProperties(TypedDict): - type: str - description: Optional[str] - enum: NotRequired[List[str]] - - -class OAiToolFunctionParams(TypedDict): - type: Literal["object"] - properties: Dict[str, OaiToolFunctionParamProperties] - required: NotRequired[List[str]] - - -class OAiFunctionDefinition(TypedDict): - name: str - description: str - parameters: OAiToolFunctionParams - - -class OpenAICompatibleToolDefinition(TypedDict): - type: Literal["function"] - function: OAiFunctionDefinition - class ServerConfig(TypedDict): model: str @@ -74,6 +49,15 @@ class TestConfig(TypedDict): } MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { "role": "user", "content": @@ -87,7 +71,7 @@ class TestConfig(TypedDict): "role": "user", "content": - "Can you tell me a joke?" + "Can you tell me a joke please?" }] MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ @@ -97,7 +81,7 @@ class TestConfig(TypedDict): "What is the weather in Dallas, Texas in Fahrenheit?" }] -WEATHER_TOOL: OpenAICompatibleToolDefinition = { +WEATHER_TOOL: ChatCompletionToolParam = { "type": "function", "function": { "name": "get_current_weather", @@ -130,7 +114,7 @@ class TestConfig(TypedDict): } } -SEARCH_TOOL: OpenAICompatibleToolDefinition = { +SEARCH_TOOL: ChatCompletionToolParam = { "type": "function", "function": { "name": @@ -157,13 +141,23 @@ class TestConfig(TypedDict): } } -configKeys = CONFIGS.keys() + +# for each server config, download the model and return the config +@pytest.fixture(scope="module", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + + print(f'downloading model for {config["model"]}') + + # download model and tokenizer using transformers + AutoTokenizer.from_pretrained(config["model"]) + AutoModelForCausalLM.from_pretrained(config["model"]) + yield CONFIGS[request.param] -@pytest.fixture(scope="module", params=configKeys) -def client_config(request): - print('param', request.param) - server_config: ServerConfig = CONFIGS["hermes"] +# run this for each server config +@pytest.fixture(scope="module") +def client_config(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] with RemoteOpenAIServer(model, ARGS + args_for_model) as server: @@ -239,6 +233,7 @@ async def test_chat_completion_without_tools(client_config: TestConfig): # and that they won't be parsed as tools @pytest.mark.asyncio async def test_chat_completion_with_tools(client_config: TestConfig): + print(f'sending prompt {MESSAGES_WITHOUT_TOOLS}') chat_completion = await client_config["client"].chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, @@ -249,6 +244,7 @@ async def test_chat_completion_with_tools(client_config: TestConfig): choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content + print(chat_completion.choices[0]) # check to make sure we got text assert output_text is not None From d24ae670188cdcdad44c0d5762c99e1a33f09fec Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 12 Aug 2024 21:45:26 -0500 Subject: [PATCH 164/222] try restructuring tests to make it work --- tests/entrypoints/openai/test_tools.py | 232 +++++++++++++++++-------- 1 file changed, 160 insertions(+), 72 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 6acd79cadca66..b75f2144ee7e6 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -32,7 +32,7 @@ class TestConfig(TypedDict): CONFIGS: Dict[str, ServerConfig] = { "hermes": { "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-2-Pro-Llama-3-8B", "arguments": [ "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") @@ -40,7 +40,7 @@ class TestConfig(TypedDict): }, "mistral": { "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ "--tool-call-parser", "mistral", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") @@ -48,39 +48,6 @@ class TestConfig(TypedDict): } } -MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "system", - "content": - "You are a helpful assistant with access to tools. If a tool" - " that you have would be helpful to answer a user query, " - "call the tool. Otherwise, answer the user's query directly " - "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." -}, { - "role": - "user", - "content": - "Hi! How are you?" -}, { - "role": - "assistant", - "content": - "I'm doing great! How can I assist you?" -}, { - "role": - "user", - "content": - "Can you tell me a joke please?" -}] - -MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}] - WEATHER_TOOL: ChatCompletionToolParam = { "type": "function", "function": { @@ -91,18 +58,18 @@ class TestConfig(TypedDict): "properties": { "city": { "type": - "string", + "string", "description": - "The city to find the weather for, " - "e.g. 'San Francisco'" + "The city to find the weather for, " + "e.g. 'San Francisco'" }, "state": { "type": - "string", + "string", "description": - "the two-letter abbreviation for the state " - "that the city is in, e.g. 'CA' which would " - "mean 'California'" + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" }, "unit": { "type": "string", @@ -118,22 +85,22 @@ class TestConfig(TypedDict): "type": "function", "function": { "name": - "web_search", + "web_search", "description": - "Search the internet and get a summary of the top " - "10 webpages. Should only be used if you don't know " - "the answer to a user query, and the results are likely" - "to be able to be found with a web search", + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", "parameters": { "type": "object", "properties": { "search_term": { "type": - "string", + "string", "description": - "The term to use in the search. This should" - "ideally be keywords to search for, not a" - "natural-language question" + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" } }, "required": ["search_term"] @@ -141,6 +108,60 @@ class TestConfig(TypedDict): } } +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": "user", + "content": "What is the weather in Dallas, Texas in Fahrenheit?" + },{ + "role": "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] + },{ + "role": "tool", + "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." + }] + # for each server config, download the model and return the config @pytest.fixture(scope="module", params=CONFIGS.keys()) @@ -157,24 +178,29 @@ def server_config(request): # run this for each server config @pytest.fixture(scope="module") -def client_config(request, server_config: ServerConfig): +def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] with RemoteOpenAIServer(model, ARGS + args_for_model) as server: - client = server.get_async_client() - yield TestConfig(client=client, model=model) + yield server + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_async_client() # test: make sure chat completions without tools provided work even when tools # are enabled. This makes sure tool call chat templates work, AND that the tool # parser stream processing doesn't change the output of the model. @pytest.mark.asyncio -async def test_chat_completion_without_tools(client_config: TestConfig): - chat_completion = await client_config["client"].chat.completions.create( +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=128, - model=client_config["model"], + model=model_name, logprobs=False) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -189,11 +215,11 @@ async def test_chat_completion_without_tools(client_config: TestConfig): or len(choice.message.tool_calls) == 0) # make the same request, streaming - stream = await client_config["client"].chat.completions.create( + stream = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=128, - model=client_config["model"], + model=model_name, logprobs=False, stream=True, ) @@ -207,6 +233,7 @@ async def test_chat_completion_without_tools(client_config: TestConfig): # make sure the role is assistant if delta.role: + assert not role_sent assert delta.role == 'assistant' role_sent = True @@ -215,6 +242,7 @@ async def test_chat_completion_without_tools(client_config: TestConfig): if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason # make sure tool call chunks aren't being streamed assert not delta.tool_calls or len(delta.tool_calls) == 0 @@ -223,7 +251,6 @@ async def test_chat_completion_without_tools(client_config: TestConfig): # were in fact sent, and that the chunks match non-streaming assert role_sent assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == stop_reason assert len(chunks) assert "".join(chunks) == output_text @@ -232,19 +259,19 @@ async def test_chat_completion_without_tools(client_config: TestConfig): # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools @pytest.mark.asyncio -async def test_chat_completion_with_tools(client_config: TestConfig): - print(f'sending prompt {MESSAGES_WITHOUT_TOOLS}') - chat_completion = await client_config["client"].chat.completions.create( +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=128, - model=client_config["model"], + model=model_name, tools=[WEATHER_TOOL], logprobs=False) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content - print(chat_completion.choices[0]) # check to make sure we got text assert output_text is not None @@ -256,11 +283,11 @@ async def test_chat_completion_with_tools(client_config: TestConfig): or len(choice.message.tool_calls) == 0) # make the same request, streaming - stream = await client_config["client"].chat.completions.create( + stream = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, max_tokens=128, - model=client_config["model"], + model=model_name, logprobs=False, tools=[WEATHER_TOOL], stream=True, @@ -301,12 +328,14 @@ async def test_chat_completion_with_tools(client_config: TestConfig): # test: request a chat completion that should return tool calls, so we know they # are parsable @pytest.mark.asyncio -async def test_tool_call_and_choice(client_config: TestConfig): - chat_completion = await client_config["client"].chat.completions.create( +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( messages=MESSAGES_ASKING_FOR_TOOLS, temperature=0, max_tokens=500, - model=client_config["model"], + model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False) @@ -345,8 +374,8 @@ async def test_tool_call_and_choice(client_config: TestConfig): finish_reason_count: int = 0 # make the same request, streaming - stream = await client_config["client"].chat.completions.create( - model=client_config["model"], + stream = await client.chat.completions.create( + model=model_name, messages=MESSAGES_ASKING_FOR_TOOLS, temperature=0, max_tokens=500, @@ -423,6 +452,65 @@ async def test_tool_call_and_choice(client_config: TestConfig): # test: providing tools and results back to model to get a non-tool response # (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False + ) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True + ) + + chunks: List[str] = [] + finish_reason: Optional[str] == None + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content + # test: getting the model to generate parallel tool calls (streaming/not) From 2c8e82f13d892ab42ba0d76b048a2bfa3b176926 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 12 Aug 2024 21:51:52 -0500 Subject: [PATCH 165/222] format: test_tools.py --- tests/entrypoints/openai/test_tools.py | 117 +++++++++++++------------ 1 file changed, 62 insertions(+), 55 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index b75f2144ee7e6..13bb8cb33d34e 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -32,7 +32,7 @@ class TestConfig(TypedDict): CONFIGS: Dict[str, ServerConfig] = { "hermes": { "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", + "NousResearch/Hermes-2-Pro-Llama-3-8B", "arguments": [ "--tool-call-parser", "hermes", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") @@ -40,7 +40,7 @@ class TestConfig(TypedDict): }, "mistral": { "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ "--tool-call-parser", "mistral", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") @@ -58,18 +58,18 @@ class TestConfig(TypedDict): "properties": { "city": { "type": - "string", + "string", "description": - "The city to find the weather for, " - "e.g. 'San Francisco'" + "The city to find the weather for, " + "e.g. 'San Francisco'" }, "state": { "type": - "string", + "string", "description": - "the two-letter abbreviation for the state " - "that the city is in, e.g. 'CA' which would " - "mean 'California'" + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" }, "unit": { "type": "string", @@ -85,22 +85,22 @@ class TestConfig(TypedDict): "type": "function", "function": { "name": - "web_search", + "web_search", "description": - "Search the internet and get a summary of the top " - "10 webpages. Should only be used if you don't know " - "the answer to a user query, and the results are likely" - "to be able to be found with a web search", + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", "parameters": { "type": "object", "properties": { "search_term": { "type": - "string", + "string", "description": - "The term to use in the search. This should" - "ideally be keywords to search for, not a" - "natural-language question" + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" } }, "required": ["search_term"] @@ -110,57 +110,65 @@ class TestConfig(TypedDict): MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ "role": - "system", + "system", "content": - "You are a helpful assistant with access to tools. If a tool" - " that you have would be helpful to answer a user query, " - "call the tool. Otherwise, answer the user's query directly " - "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." }, { "role": - "user", + "user", "content": - "Hi! How are you?" + "Hi! How are you?" }, { "role": - "assistant", + "assistant", "content": - "I'm doing great! How can I assist you?" + "I'm doing great! How can I assist you?" }, { "role": - "user", + "user", "content": - "Can you tell me a joke please?" + "Can you tell me a joke please?" }] MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ "role": - "user", + "user", "content": - "What is the weather in Dallas, Texas in Fahrenheit?" + "What is the weather in Dallas, Texas in Fahrenheit?" }] MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ - "role": "user", - "content": "What is the weather in Dallas, Texas in Fahrenheit?" - },{ - "role": "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": WEATHER_TOOL["function"]["name"], - "arguments": '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }] - },{ - "role": "tool", - "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": "The weather in Dallas is 98 degrees fahrenheit, with partly" - "cloudy skies and a low chance of rain." + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] # for each server config, download the model and return the config @@ -184,6 +192,7 @@ def server(request, server_config: ServerConfig): with RemoteOpenAIServer(model, ARGS + args_for_model) as server: yield server + @pytest.fixture(scope="module") def client(server: RemoteOpenAIServer): return server.get_async_client() @@ -209,6 +218,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): # check to make sure we got text assert output_text is not None assert len(output_text) > 0 + assert stop_reason != "tool_calls" # check to make sure no tool calls were returned assert (choice.message.tool_calls is None @@ -462,8 +472,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): max_tokens=500, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False - ) + logprobs=False) choice = chat_completion.choices[0] @@ -481,11 +490,9 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True - ) + stream=True) chunks: List[str] = [] - finish_reason: Optional[str] == None finish_reason_count = 0 role_sent: bool = False From 8e61cb30f8c1a807f53334cef2cfbc611c469884 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 12 Aug 2024 23:10:09 -0500 Subject: [PATCH 166/222] try fixing tests by removing my specific dtype=half and kv-cache-dtype=fp8 options that I use when testing --- tests/entrypoints/openai/test_tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 13bb8cb33d34e..e8fe1a3a7fd41 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -22,10 +22,10 @@ class TestConfig(TypedDict): ARGS: List[str] = [ - "--dtype", - "half", # TODO change to BF16 - "--kv-cache-dtype", - "fp8", + # "--dtype", + # "half", # TODO change to BF16 + # "--kv-cache-dtype", + # "fp8", "--enable-auto-tool-choice" ] From 7605d9f16872ba54e4e70a13128effb09af7f7ce Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 13 Aug 2024 12:48:33 -0500 Subject: [PATCH 167/222] fix: download model with hf_hub to prevent mistral timeout --- tests/entrypoints/openai/test_tools.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index e8fe1a3a7fd41..34fc091f70827 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -3,9 +3,9 @@ import openai import pytest +from huggingface_hub import snapshot_download from openai.types.chat import (ChatCompletionMessageParam, ChatCompletionToolParam) -from transformers import AutoModelForCausalLM, AutoTokenizer from typing_extensions import TypedDict from ...utils import VLLM_PATH, RemoteOpenAIServer @@ -22,10 +22,10 @@ class TestConfig(TypedDict): ARGS: List[str] = [ - # "--dtype", - # "half", # TODO change to BF16 - # "--kv-cache-dtype", - # "fp8", + "--dtype", + "half", # TODO change to BF16 + "--kv-cache-dtype", + "fp8", "--enable-auto-tool-choice" ] @@ -179,8 +179,7 @@ def server_config(request): print(f'downloading model for {config["model"]}') # download model and tokenizer using transformers - AutoTokenizer.from_pretrained(config["model"]) - AutoModelForCausalLM.from_pretrained(config["model"]) + snapshot_download(config["model"]) yield CONFIGS[request.param] From f39ae8f2a0232ea1c944efe174e7a39b62b73395 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 13 Aug 2024 17:14:21 -0500 Subject: [PATCH 168/222] test: parallel tool calls, re-trigger CI that was interrupted by huggingface outage --- tests/entrypoints/openai/test_tools.py | 135 +++++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 34fc091f70827..b330fd77a87b8 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -170,6 +170,14 @@ class TestConfig(TypedDict): "cloudy skies and a low chance of rain." }] +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + # for each server config, download the model and return the config @pytest.fixture(scope="module", params=CONFIGS.keys()) @@ -519,6 +527,133 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): # test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + print('completion choice: ', choice) + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + print('got chunk', chunk.choices[0]) + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + print('tool call names', tool_call_names) + print('tool_call args', tool_call_args) + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(0, 2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + # test: providing parallel tool calls back to the model to get a response # (streaming/not) From 1f50cef34f37870c59b5add065614c5339986d20 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 13 Aug 2024 20:50:24 -0500 Subject: [PATCH 169/222] fix: print statements --- tests/entrypoints/openai/test_tools.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index b330fd77a87b8..dfeaab5af5de8 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -188,7 +188,9 @@ def server_config(request): # download model and tokenizer using transformers snapshot_download(config["model"]) + print(f'downloaded model {config["model"]}') yield CONFIGS[request.param] + print(f'Cleaning up vLLM server for {config["model"]} ') # run this for each server config @@ -196,8 +198,10 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] + print(f"Starting server for {model}") with RemoteOpenAIServer(model, ARGS + args_for_model) as server: yield server + print(f'shutting down server for {model}') @pytest.fixture(scope="module") @@ -543,7 +547,6 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): logprobs=False) choice = chat_completion.choices[0] - print('completion choice: ', choice) stop_reason = chat_completion.choices[0].finish_reason non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls @@ -590,8 +593,6 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): async for chunk in stream: - print('got chunk', chunk.choices[0]) - # if there's a finish reason make sure it's tools if chunk.choices[0].finish_reason: finish_reason_count += 1 @@ -639,8 +640,6 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): tool_call_args[ tool_call.index] += tool_call.function.arguments - print('tool call names', tool_call_names) - print('tool_call args', tool_call_args) assert finish_reason_count == 1 assert role_name == 'assistant' From b7de9dec2674b80b0c5c10925f6fe4256ef44834 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 14 Aug 2024 09:57:01 -0500 Subject: [PATCH 170/222] fix(tests): allow passing in wait time to remote Open AI server and increase wait time --- tests/entrypoints/openai/test_tools.py | 3 ++- tests/utils.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index dfeaab5af5de8..38dc8a641127f 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -199,7 +199,8 @@ def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] print(f"Starting server for {model}") - with RemoteOpenAIServer(model, ARGS + args_for_model) as server: + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_start_wait_s=240) as server: yield server print(f'shutting down server for {model}') diff --git a/tests/utils.py b/tests/utils.py index 697bf7d93c36e..500fb4c662a30 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -53,14 +53,13 @@ class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key MAX_START_WAIT_S = 120 # wait for server to start for 120 seconds - def __init__( - self, - model: str, - cli_args: List[str], - *, - env_dict: Optional[Dict[str, str]] = None, - auto_port: bool = True, - ) -> None: + def __init__(self, + model: str, + cli_args: List[str], + *, + env_dict: Optional[Dict[str, str]] = None, + auto_port: bool = True, + max_start_wait_s: Optional[int] = None) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -75,6 +74,9 @@ def __init__( self.host = str(args.host or 'localhost') self.port = int(args.port) + if max_start_wait_s: + self.MAX_START_WAIT_S = max_start_wait_s + env = os.environ.copy() # the current process might initialize cuda, # to be safe, we should use spawn method From bff5e18efc32b759b934599b78360a75da2d4a84 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 14 Aug 2024 15:19:25 -0500 Subject: [PATCH 171/222] test: add final tests for providing tool call responses with parallel tool calling --- tests/entrypoints/openai/test_tools.py | 111 +++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py index 38dc8a641127f..0e285bb791529 100644 --- a/tests/entrypoints/openai/test_tools.py +++ b/tests/entrypoints/openai/test_tools.py @@ -178,19 +178,62 @@ class TestConfig(TypedDict): "Fahrenheit?" }] +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] + # for each server config, download the model and return the config @pytest.fixture(scope="module", params=CONFIGS.keys()) def server_config(request): config = CONFIGS[request.param] - - print(f'downloading model for {config["model"]}') - # download model and tokenizer using transformers snapshot_download(config["model"]) - print(f'downloaded model {config["model"]}') yield CONFIGS[request.param] - print(f'Cleaning up vLLM server for {config["model"]} ') # run this for each server config @@ -198,11 +241,9 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - print(f"Starting server for {model}") with RemoteOpenAIServer(model, ARGS + args_for_model, max_start_wait_s=240) as server: yield server - print(f'shutting down server for {model}') @pytest.fixture(scope="module") @@ -657,3 +698,59 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): # test: providing parallel tool calls back to the model to get a response # (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content From 1b417ba3ab2169be062242bc1334f41d48e52628 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 14 Aug 2024 18:08:33 -0500 Subject: [PATCH 172/222] refactor(tests): break tests out into multiple files for readability & in preparation of moving them out of entrypoints --- tests/entrypoints/openai/test_tools.py | 756 ------------------ tests/entrypoints/openai/tool_use/__init__.py | 0 tests/entrypoints/openai/tool_use/conftest.py | 29 + .../openai/tool_use/test_chat_completions.py | 143 ++++ .../tool_use/test_parallel_tool_calls.py | 193 +++++ .../openai/tool_use/test_tool_calls.py | 192 +++++ tests/entrypoints/openai/tool_use/util.py | 218 +++++ 7 files changed, 775 insertions(+), 756 deletions(-) delete mode 100644 tests/entrypoints/openai/test_tools.py create mode 100644 tests/entrypoints/openai/tool_use/__init__.py create mode 100644 tests/entrypoints/openai/tool_use/conftest.py create mode 100644 tests/entrypoints/openai/tool_use/test_chat_completions.py create mode 100644 tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py create mode 100644 tests/entrypoints/openai/tool_use/test_tool_calls.py create mode 100644 tests/entrypoints/openai/tool_use/util.py diff --git a/tests/entrypoints/openai/test_tools.py b/tests/entrypoints/openai/test_tools.py deleted file mode 100644 index 0e285bb791529..0000000000000 --- a/tests/entrypoints/openai/test_tools.py +++ /dev/null @@ -1,756 +0,0 @@ -import json -from typing import Dict, List, Optional - -import openai -import pytest -from huggingface_hub import snapshot_download -from openai.types.chat import (ChatCompletionMessageParam, - ChatCompletionToolParam) -from typing_extensions import TypedDict - -from ...utils import VLLM_PATH, RemoteOpenAIServer - - -class ServerConfig(TypedDict): - model: str - arguments: List[str] - - -class TestConfig(TypedDict): - client: openai.AsyncOpenAI - model: str - - -ARGS: List[str] = [ - "--dtype", - "half", # TODO change to BF16 - "--kv-cache-dtype", - "fp8", - "--enable-auto-tool-choice" -] - -CONFIGS: Dict[str, ServerConfig] = { - "hermes": { - "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "arguments": [ - "--tool-call-parser", "hermes", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") - ] - }, - "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", - "arguments": [ - "--tool-call-parser", "mistral", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") - ] - } -} - -WEATHER_TOOL: ChatCompletionToolParam = { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, " - "e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state " - "that the city is in, e.g. 'CA' which would " - "mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } - } - } - } -} - -SEARCH_TOOL: ChatCompletionToolParam = { - "type": "function", - "function": { - "name": - "web_search", - "description": - "Search the internet and get a summary of the top " - "10 webpages. Should only be used if you don't know " - "the answer to a user query, and the results are likely" - "to be able to be found with a web search", - "parameters": { - "type": "object", - "properties": { - "search_term": { - "type": - "string", - "description": - "The term to use in the search. This should" - "ideally be keywords to search for, not a" - "natural-language question" - } - }, - "required": ["search_term"] - } - } -} - -MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "system", - "content": - "You are a helpful assistant with access to tools. If a tool" - " that you have would be helpful to answer a user query, " - "call the tool. Otherwise, answer the user's query directly " - "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." -}, { - "role": - "user", - "content": - "Hi! How are you?" -}, { - "role": - "assistant", - "content": - "I'm doing great! How can I assist you?" -}, { - "role": - "user", - "content": - "Can you tell me a joke please?" -}] - -MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}] - -MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas is 98 degrees fahrenheit, with partly" - "cloudy skies and a low chance of rain." -}] - -MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}] - -MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }, { - "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Orlando", "state": "Fl", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas TX is 98 degrees fahrenheit with mostly " - "cloudy skies and a chance of rain in the evening." -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "content": - "The weather in Orlando FL is 78 degrees fahrenheit with clear" - "skies." -}] - - -# for each server config, download the model and return the config -@pytest.fixture(scope="module", params=CONFIGS.keys()) -def server_config(request): - config = CONFIGS[request.param] - # download model and tokenizer using transformers - snapshot_download(config["model"]) - yield CONFIGS[request.param] - - -# run this for each server config -@pytest.fixture(scope="module") -def server(request, server_config: ServerConfig): - model = server_config["model"] - args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_start_wait_s=240) as server: - yield server - - -@pytest.fixture(scope="module") -def client(server: RemoteOpenAIServer): - return server.get_async_client() - - -# test: make sure chat completions without tools provided work even when tools -# are enabled. This makes sure tool call chat templates work, AND that the tool -# parser stream processing doesn't change the output of the model. -@pytest.mark.asyncio -async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False) - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - output_text = chat_completion.choices[0].message.content - - # check to make sure we got text - assert output_text is not None - assert len(output_text) > 0 - assert stop_reason != "tool_calls" - - # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) - - # make the same request, streaming - stream = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False, - stream=True, - ) - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - # assemble streamed chunks - async for chunk in stream: - delta = chunk.choices[0].delta - - # make sure the role is assistant - if delta.role: - assert not role_sent - assert delta.role == 'assistant' - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - # make sure tool call chunks aren't being streamed - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - # make sure the role was sent, only 1 finish reason was sent, that chunks - # were in fact sent, and that the chunks match non-streaming - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == output_text - - -# test: conversation with tools enabled and provided that should not invoke -# tools, to make sure we can still get normal chat completion responses -# and that they won't be parsed as tools -@pytest.mark.asyncio -async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - tools=[WEATHER_TOOL], - logprobs=False) - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - output_text = chat_completion.choices[0].message.content - - # check to make sure we got text - assert output_text is not None - assert stop_reason != 'tool_calls' - assert len(output_text) > 0 - - # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) - - # make the same request, streaming - stream = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False, - tools=[WEATHER_TOOL], - stream=True, - ) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - # assemble streamed chunks - async for chunk in stream: - delta = chunk.choices[0].delta - - # make sure the role is assistant - if delta.role: - assert delta.role == 'assistant' - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - - # make sure tool call chunks aren't being streamed - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - # make sure the role was sent, only 1 finish reason was sent, that chunks - # were in fact sent, and that the chunks match non-streaming - assert role_sent - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == stop_reason - assert chunk.choices[0].finish_reason != 'tool_calls' - assert len(chunks) - assert "".join(chunks) == output_text - - -# test: request a chat completion that should return tool calls, so we know they -# are parsable -@pytest.mark.asyncio -async def test_tool_call_and_choice(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_TOOLS, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - tool_calls = chat_completion.choices[0].message.tool_calls - - # make sure a tool call is present - assert choice.message.role == 'assistant' - assert tool_calls is not None - assert len(tool_calls) == 1 - assert tool_calls[0].type == 'function' - assert tool_calls[0].function is not None - assert isinstance(tool_calls[0].id, str) - assert len(tool_calls[0].id) > 16 - - # make sure the weather tool was called (classic example) with arguments - assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] - assert tool_calls[0].function.arguments is not None - assert isinstance(tool_calls[0].function.arguments, str) - - # make sure the arguments parse properly - parsed_arguments = json.loads(tool_calls[0].function.arguments) - assert isinstance(parsed_arguments, Dict) - assert isinstance(parsed_arguments.get("city"), str) - assert isinstance(parsed_arguments.get("state"), str) - assert parsed_arguments.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" - - assert stop_reason == "tool_calls" - - function_name: Optional[str] = None - function_args_str: str = '' - tool_call_id: Optional[str] = None - role_name: Optional[str] = None - finish_reason_count: int = 0 - - # make the same request, streaming - stream = await client.chat.completions.create( - model=model_name, - messages=MESSAGES_ASKING_FOR_TOOLS, - temperature=0, - max_tokens=500, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - async for chunk in stream: - assert chunk.choices[0].index == 0 - - if chunk.choices[0].finish_reason: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' - - # if a role is being streamed make sure it wasn't already set to - # something else - if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' - - # if a tool call is streamed make sure there's exactly one - # (based on the request parameters - streamed_tool_calls = chunk.choices[0].delta.tool_calls - - if streamed_tool_calls and len(streamed_tool_calls) > 0: - assert len(streamed_tool_calls) == 1 - tool_call = streamed_tool_calls[0] - - # if a tool call ID is streamed, make sure one hasn't been already - if tool_call.id: - assert not tool_call_id - tool_call_id = tool_call.id - - # if parts of the function start being streamed - if tool_call.function: - # if the function name is defined, set it. it should be streamed - # IN ENTIRETY, exactly one time. - if tool_call.function.name: - assert function_name is None - assert isinstance(tool_call.function.name, str) - function_name = tool_call.function.name - if tool_call.function.arguments: - assert isinstance(tool_call.function.arguments, str) - function_args_str += tool_call.function.arguments - - assert finish_reason_count == 1 - assert role_name == 'assistant' - assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) - - # validate the name and arguments - assert function_name == WEATHER_TOOL["function"]["name"] - assert function_name == tool_calls[0].function.name - assert isinstance(function_args_str, str) - - # validate arguments - streamed_args = json.loads(function_args_str) - assert isinstance(streamed_args, Dict) - assert isinstance(streamed_args.get("city"), str) - assert isinstance(streamed_args.get("state"), str) - assert streamed_args.get("city") == "Dallas" - assert streamed_args.get("state") == "TX" - - # make sure everything matches non-streaming except for ID - assert function_name == tool_calls[0].function.name - assert choice.message.role == role_name - assert choice.message.tool_calls[0].function.name == function_name - - # compare streamed with non-streamed args Dict-wise, not string-wise - # because character-to-character comparison might not work e.g. the tool - # call parser adding extra spaces or something like that. we care about the - # dicts matching not byte-wise match - assert parsed_arguments == streamed_args - - -# test: providing tools and results back to model to get a non-tool response -# (streaming/not) -@pytest.mark.asyncio -async def test_tool_call_with_results(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITH_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - - assert choice.finish_reason != "tool_calls" # "stop" or "length" - assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 - assert choice.message.content is not None - assert "98" in choice.message.content # the temperature from the response - - stream = await client.chat.completions.create( - messages=MESSAGES_WITH_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - async for chunk in stream: - delta = chunk.choices[0].delta - - if delta.role: - assert not role_sent - assert delta.role == "assistant" - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == choice.message.content - - -# test: getting the model to generate parallel tool calls (streaming/not) -# when requested. NOTE that not all models may support this, so some exclusions -# may be added in the future. e.g. llama 3.1 models are not designed to support -# parallel tool calls. -@pytest.mark.asyncio -async def test_parallel_tool_calls(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - temperature=0, - max_tokens=800, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls - - # make sure 2 tool calls are present - assert choice.message.role == "assistant" - assert non_streamed_tool_calls is not None - assert len(non_streamed_tool_calls) == 2 - - for tool_call in non_streamed_tool_calls: - # make sure the tool includes a function and ID - assert tool_call.type == "function" - assert tool_call.function is not None - assert isinstance(tool_call.id, str) - assert len(tool_call.id) > 16 - - # make sure the weather tool was called correctly - assert tool_call.function.name == WEATHER_TOOL["function"]["name"] - assert isinstance(tool_call.function.arguments, str) - - parsed_arguments = json.loads(tool_call.function.arguments) - assert isinstance(parsed_arguments, Dict) - assert isinstance(parsed_arguments.get("city"), str) - assert isinstance(parsed_arguments.get("state"), str) - - assert stop_reason == "tool_calls" - - # make the same request, streaming - stream = await client.chat.completions.create( - model=model_name, - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - temperature=0, - max_tokens=800, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - role_name: Optional[str] = None - finish_reason_count: int = 0 - - tool_call_names: List[str] = [] - tool_call_args: List[str] = [] - tool_call_idx: int = -1 - tool_call_id_count: int = 0 - - async for chunk in stream: - - # if there's a finish reason make sure it's tools - if chunk.choices[0].finish_reason: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' - - # if a role is being streamed make sure it wasn't already set to - # something else - if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' - - # if a tool call is streamed make sure there's exactly one - # (based on the request parameters - streamed_tool_calls = chunk.choices[0].delta.tool_calls - - if streamed_tool_calls and len(streamed_tool_calls) > 0: - - # make sure only one diff is present - correct even for parallel - assert len(streamed_tool_calls) == 1 - tool_call = streamed_tool_calls[0] - - # if a new tool is being called, set up empty arguments - if tool_call.index != tool_call_idx: - tool_call_idx = tool_call.index - tool_call_args.append("") - - # if a tool call ID is streamed, make sure one hasn't been already - if tool_call.id: - tool_call_id_count += 1 - assert (isinstance(tool_call.id, str) - and (len(tool_call.id) > 16)) - - # if parts of the function start being streamed - if tool_call.function: - # if the function name is defined, set it. it should be streamed - # IN ENTIRETY, exactly one time. - if tool_call.function.name: - assert isinstance(tool_call.function.name, str) - tool_call_names.append(tool_call.function.name) - - if tool_call.function.arguments: - # make sure they're a string and then add them to the list - assert isinstance(tool_call.function.arguments, str) - - tool_call_args[ - tool_call.index] += tool_call.function.arguments - - assert finish_reason_count == 1 - assert role_name == 'assistant' - - assert (len(non_streamed_tool_calls) == len(tool_call_names) == - len(tool_call_args)) - - for i in range(0, 2): - assert non_streamed_tool_calls[i].function.name == tool_call_names[i] - streamed_args = json.loads(tool_call_args[i]) - non_streamed_args = json.loads( - non_streamed_tool_calls[i].function.arguments) - assert streamed_args == non_streamed_args - - -# test: providing parallel tool calls back to the model to get a response -# (streaming/not) -@pytest.mark.asyncio -async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - - assert choice.finish_reason != "tool_calls" # "stop" or "length" - assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 - assert choice.message.content is not None - assert "98" in choice.message.content # Dallas temp in tool response - assert "78" in choice.message.content # Orlando temp in tool response - - stream = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - async for chunk in stream: - delta = chunk.choices[0].delta - - if delta.role: - assert not role_sent - assert delta.role == "assistant" - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints/openai/tool_use/__init__.py b/tests/entrypoints/openai/tool_use/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/tool_use/conftest.py b/tests/entrypoints/openai/tool_use/conftest.py new file mode 100644 index 0000000000000..3d221af999fc5 --- /dev/null +++ b/tests/entrypoints/openai/tool_use/conftest.py @@ -0,0 +1,29 @@ +import pytest +from huggingface_hub import snapshot_download + +from ....utils import RemoteOpenAIServer +from .util import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="module", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="module") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_start_wait_s=240) as server: + yield server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_async_client() diff --git a/tests/entrypoints/openai/tool_use/test_chat_completions.py b/tests/entrypoints/openai/tool_use/test_chat_completions.py new file mode 100644 index 0000000000000..a5adb04252c12 --- /dev/null +++ b/tests/entrypoints/openai/tool_use/test_chat_completions.py @@ -0,0 +1,143 @@ +from typing import List + +import openai +import pytest + +from .util import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + assert stop_reason != "tool_calls" + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert not role_sent + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text diff --git a/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py b/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py new file mode 100644 index 0000000000000..5083cf394a6cf --- /dev/null +++ b/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py @@ -0,0 +1,193 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .util import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) + + +# test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(0, 2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + + +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints/openai/tool_use/test_tool_calls.py b/tests/entrypoints/openai/tool_use/test_tool_calls.py new file mode 100644 index 0000000000000..5f1a8dfff1c33 --- /dev/null +++ b/tests/entrypoints/openai/tool_use/test_tool_calls.py @@ -0,0 +1,192 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .util import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) + + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints/openai/tool_use/util.py b/tests/entrypoints/openai/tool_use/util.py new file mode 100644 index 0000000000000..a075d05583b03 --- /dev/null +++ b/tests/entrypoints/openai/tool_use/util.py @@ -0,0 +1,218 @@ +from typing import Dict, List + +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from typing_extensions import TypedDict + +from ....utils import VLLM_PATH + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +ARGS: List[str] = [ + "--dtype", + "half", # TODO change to BF16 + "--kv-cache-dtype", + "fp8", + "--enable-auto-tool-choice" +] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "arguments": [ + "--tool-call-parser", "hermes", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", "mistral", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") + ] + } +} + +WEATHER_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +} + +SEARCH_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] + +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] From 8365a256726018f064b3bf778a688b3a4f85b2ef Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 10:25:41 -0500 Subject: [PATCH 173/222] fix: add consolidated.safetensor to ignore list for mistral in tool entrypoint tests --- tests/entrypoints/openai/tool_use/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/tool_use/util.py b/tests/entrypoints/openai/tool_use/util.py index a075d05583b03..a9c3d65686fa3 100644 --- a/tests/entrypoints/openai/tool_use/util.py +++ b/tests/entrypoints/openai/tool_use/util.py @@ -34,7 +34,8 @@ class ServerConfig(TypedDict): "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ "--tool-call-parser", "mistral", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja") + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), + "--ignore-patterns=\"consolidated.safetensors\"" ] } } From 445cf59b0ec9bd43317de115082a332e4d3ee506 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 11:29:43 -0500 Subject: [PATCH 174/222] refactor(ci): move tool tests out of entrypoints fastcheck, and create new non-fastcheck stage for them --- .buildkite/test-pipeline.yaml | 9 + tests/entrypoints_extended/__init__.py | 0 tests/entrypoints_extended/openai/__init__.py | 0 .../openai/tool_use/__init__.py | 0 .../openai/tool_use/conftest.py | 29 +++ .../openai/tool_use/test_chat_completions.py | 143 ++++++++++++ .../tool_use/test_parallel_tool_calls.py | 193 +++++++++++++++ .../openai/tool_use/test_tool_calls.py | 192 +++++++++++++++ .../openai/tool_use/util.py | 219 ++++++++++++++++++ 9 files changed, 785 insertions(+) create mode 100644 tests/entrypoints_extended/__init__.py create mode 100644 tests/entrypoints_extended/openai/__init__.py create mode 100644 tests/entrypoints_extended/openai/tool_use/__init__.py create mode 100644 tests/entrypoints_extended/openai/tool_use/conftest.py create mode 100644 tests/entrypoints_extended/openai/tool_use/test_chat_completions.py create mode 100644 tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py create mode 100644 tests/entrypoints_extended/openai/tool_use/test_tool_calls.py create mode 100644 tests/entrypoints_extended/openai/tool_use/util.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e21ae6b0502f4..bbbe55a7f435b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -85,6 +85,15 @@ steps: - pytest -v -s entrypoints/llm - pytest -v -s entrypoints/openai +- label: Entrypoints Test (Tools, Extensions) # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/entrypoints_extended + commands: + - pytest -v -s entrypoints_extended/openai + - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/tests/entrypoints_extended/__init__.py b/tests/entrypoints_extended/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints_extended/openai/__init__.py b/tests/entrypoints_extended/openai/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints_extended/openai/tool_use/__init__.py b/tests/entrypoints_extended/openai/tool_use/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints_extended/openai/tool_use/conftest.py b/tests/entrypoints_extended/openai/tool_use/conftest.py new file mode 100644 index 0000000000000..3d221af999fc5 --- /dev/null +++ b/tests/entrypoints_extended/openai/tool_use/conftest.py @@ -0,0 +1,29 @@ +import pytest +from huggingface_hub import snapshot_download + +from ....utils import RemoteOpenAIServer +from .util import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="module", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="module") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_start_wait_s=240) as server: + yield server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_async_client() diff --git a/tests/entrypoints_extended/openai/tool_use/test_chat_completions.py b/tests/entrypoints_extended/openai/tool_use/test_chat_completions.py new file mode 100644 index 0000000000000..a5adb04252c12 --- /dev/null +++ b/tests/entrypoints_extended/openai/tool_use/test_chat_completions.py @@ -0,0 +1,143 @@ +from typing import List + +import openai +import pytest + +from .util import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + assert stop_reason != "tool_calls" + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert not role_sent + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=128, + model=model_name, + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text diff --git a/tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py b/tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py new file mode 100644 index 0000000000000..5083cf394a6cf --- /dev/null +++ b/tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py @@ -0,0 +1,193 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .util import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) + + +# test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=800, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(0, 2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + + +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints_extended/openai/tool_use/test_tool_calls.py b/tests/entrypoints_extended/openai/tool_use/test_tool_calls.py new file mode 100644 index 0000000000000..5f1a8dfff1c33 --- /dev/null +++ b/tests/entrypoints_extended/openai/tool_use/test_tool_calls.py @@ -0,0 +1,192 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .util import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) + + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=500, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=500, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints_extended/openai/tool_use/util.py b/tests/entrypoints_extended/openai/tool_use/util.py new file mode 100644 index 0000000000000..a9c3d65686fa3 --- /dev/null +++ b/tests/entrypoints_extended/openai/tool_use/util.py @@ -0,0 +1,219 @@ +from typing import Dict, List + +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from typing_extensions import TypedDict + +from ....utils import VLLM_PATH + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +ARGS: List[str] = [ + "--dtype", + "half", # TODO change to BF16 + "--kv-cache-dtype", + "fp8", + "--enable-auto-tool-choice" +] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "arguments": [ + "--tool-call-parser", "hermes", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", "mistral", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), + "--ignore-patterns=\"consolidated.safetensors\"" + ] + } +} + +WEATHER_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +} + +SEARCH_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] + +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] From afc41d017433221419176daa85b0890fcd617fa3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 11:55:38 -0500 Subject: [PATCH 175/222] fix: refactor tool use with CI changes --- .buildkite/test-pipeline.yaml | 6 +++--- tests/entrypoints_extended/openai/__init__.py | 0 tests/entrypoints_extended/openai/tool_use/__init__.py | 0 tests/{entrypoints_extended => tool_use}/__init__.py | 0 .../{entrypoints_extended/openai => }/tool_use/conftest.py | 2 +- .../openai => }/tool_use/test_chat_completions.py | 0 .../openai => }/tool_use/test_parallel_tool_calls.py | 0 .../openai => }/tool_use/test_tool_calls.py | 0 tests/{entrypoints_extended/openai => }/tool_use/util.py | 2 +- 9 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 tests/entrypoints_extended/openai/__init__.py delete mode 100644 tests/entrypoints_extended/openai/tool_use/__init__.py rename tests/{entrypoints_extended => tool_use}/__init__.py (100%) rename tests/{entrypoints_extended/openai => }/tool_use/conftest.py (95%) rename tests/{entrypoints_extended/openai => }/tool_use/test_chat_completions.py (100%) rename tests/{entrypoints_extended/openai => }/tool_use/test_parallel_tool_calls.py (100%) rename tests/{entrypoints_extended/openai => }/tool_use/test_tool_calls.py (100%) rename tests/{entrypoints_extended/openai => }/tool_use/util.py (99%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index bbbe55a7f435b..cb30e7493bdad 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -85,14 +85,14 @@ steps: - pytest -v -s entrypoints/llm - pytest -v -s entrypoints/openai -- label: Entrypoints Test (Tools, Extensions) # 20 min +- label: OpenAI-Compatible Tool Use # 20 min fast_check: false mirror_hardwares: [ amd ] source_file_dependencies: - vllm/ - - tests/entrypoints_extended + - tests/tool_use commands: - - pytest -v -s entrypoints_extended/openai + - pytest -v -s tool_use - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" diff --git a/tests/entrypoints_extended/openai/__init__.py b/tests/entrypoints_extended/openai/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/entrypoints_extended/openai/tool_use/__init__.py b/tests/entrypoints_extended/openai/tool_use/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/entrypoints_extended/__init__.py b/tests/tool_use/__init__.py similarity index 100% rename from tests/entrypoints_extended/__init__.py rename to tests/tool_use/__init__.py diff --git a/tests/entrypoints_extended/openai/tool_use/conftest.py b/tests/tool_use/conftest.py similarity index 95% rename from tests/entrypoints_extended/openai/tool_use/conftest.py rename to tests/tool_use/conftest.py index 3d221af999fc5..3b52d226d42ff 100644 --- a/tests/entrypoints_extended/openai/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -1,7 +1,7 @@ import pytest from huggingface_hub import snapshot_download -from ....utils import RemoteOpenAIServer +from tests.utils import RemoteOpenAIServer from .util import ARGS, CONFIGS, ServerConfig diff --git a/tests/entrypoints_extended/openai/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py similarity index 100% rename from tests/entrypoints_extended/openai/tool_use/test_chat_completions.py rename to tests/tool_use/test_chat_completions.py diff --git a/tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py similarity index 100% rename from tests/entrypoints_extended/openai/tool_use/test_parallel_tool_calls.py rename to tests/tool_use/test_parallel_tool_calls.py diff --git a/tests/entrypoints_extended/openai/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py similarity index 100% rename from tests/entrypoints_extended/openai/tool_use/test_tool_calls.py rename to tests/tool_use/test_tool_calls.py diff --git a/tests/entrypoints_extended/openai/tool_use/util.py b/tests/tool_use/util.py similarity index 99% rename from tests/entrypoints_extended/openai/tool_use/util.py rename to tests/tool_use/util.py index a9c3d65686fa3..47e8b9c87aefd 100644 --- a/tests/entrypoints_extended/openai/tool_use/util.py +++ b/tests/tool_use/util.py @@ -4,7 +4,7 @@ ChatCompletionToolParam) from typing_extensions import TypedDict -from ....utils import VLLM_PATH +from tests.utils import VLLM_PATH class ServerConfig(TypedDict): From 1fd4648dba759a9874369d142cc4b4f125efb809 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 15:05:47 -0500 Subject: [PATCH 176/222] doc: update docs to clarify recommended CLI options & available chat templates --- .../serving/openai_compatible_server.md | 47 +++++++------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 8211410e96d87..ca60ea2f8f9cd 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -125,57 +125,46 @@ vLLM will use guided decoding to ensure the response matches the tool parameter ### Automatic Function Calling -_This feature is in **beta**. It has limited model support, is not guaranteed to be stable, and does not have -well-defined failure modes._ As such, it must be explicitly enabled when desired. - -To enable this feature, you must set the following flags: +To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it -deems appropriate. +deems appropriate. +* `--tool-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +will continue to be added in the future. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages -that contain previously generated tool calls.This argument can be set to `tool_use` if your model has a tool use chat +that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) -* `--tool-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! #### Hermes Models -Supported models in this series: -* `NousResearch/Hermes-2-Pro-Llama-3-8B` -* `NousResearch/Hermes-2-Theta-Llama-3-70B` -* `NousResearch/Hermes-2-Pro-Llama-3-70B` -* `NousResearch/Hermes-2-Theta-Llama-3-8B` -* `NousResearch/Hermes-2-Pro-Mistral-7B` +All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. +* `NousResearch/Hermes-2-Pro-*` +* `NousResearch/Hermes-2-Theta-*` +* `NousResearch/Hermes-3-*` + _Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge -step in their creation_. It is recommended to use the Hermes 2 **Pro** models. +step in their creation_. -Recommended flags: `--tool-call-parser hermes --chat-template examples/tool_chat_template_hermes.jinja` +Flags: `--tool-call-parser hermes` #### Mistral Models Supported models: -* `mistralai/Mistral-7B-Instruct-v0.3` -* Possibly mistral-large and mixtral? These have not been tested at the time of this writing. +* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) +* Additional mistral function-calling models are compatible as well. Known issues: 1. Mistral 7B struggles to generate parallel tool calls correctly. 2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is -much shorter than what vLLM generates. - -To address this, the following additional chat templates are provided: +much shorter than what vLLM generates. Since an exception is thrown when this condition +is not met, the following additional chat templates are provided: * `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) * `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. -**Please note** that the model's default chat template in `tokenizer_config.json` will not work with vLLM, as it expects -tool_call_id fields to be exactly 9 digits, which is shorter than vLLM's format. You **must** do one of the following -to get tool calling to work with mistral: -1. use one of the 2 provided tool chat templates -2. provide your own tool chat template that corrects for this -3. in your client code, ignore the vLLM-generated `tool_call_id`, and manually generate and pass in your own 9-digit -`tool_call_id`s for `assistant`-role messages containing tool calls, and `tool`-role messages containing tool call -results. -Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral.jinja` +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` From e2a1b798c3597277effb59b2dea463cab871926e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 16:06:15 -0500 Subject: [PATCH 177/222] fix: hermes tool template and set fixtures to be session-scoped --- examples/tool_chat_template_hermes.jinja | 71 +++--- tests/entrypoints/openai/tool_use/__init__.py | 0 tests/entrypoints/openai/tool_use/conftest.py | 29 --- .../openai/tool_use/test_chat_completions.py | 143 ------------ .../tool_use/test_parallel_tool_calls.py | 193 --------------- .../openai/tool_use/test_tool_calls.py | 192 --------------- tests/entrypoints/openai/tool_use/util.py | 219 ------------------ tests/tool_use/conftest.py | 6 +- tests/tool_use/util.py | 7 +- 9 files changed, 45 insertions(+), 815 deletions(-) delete mode 100644 tests/entrypoints/openai/tool_use/__init__.py delete mode 100644 tests/entrypoints/openai/tool_use/conftest.py delete mode 100644 tests/entrypoints/openai/tool_use/test_chat_completions.py delete mode 100644 tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py delete mode 100644 tests/entrypoints/openai/tool_use/test_tool_calls.py delete mode 100644 tests/entrypoints/openai/tool_use/util.py diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index 7da8db52de891..35247917998a9 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -32,8 +32,8 @@ {{- bos_token }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} {%- if tools is iterable and tools | length > 0 %} - {{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} {%- for tool in tools %} {%- if tool.function is defined %} {%- set tool = tool.function %} @@ -73,49 +73,58 @@ {{- "\n" }} {%- endif %} {%- endfor %} - {{- " " }} - {{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"} +{%- endif %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} ' }} - {{- "For each function call return a json object with function name and arguments within XML tags as follows: +{{- "For each function call return a json object with function name and arguments within XML tags as follows: " }} - {{- " +{{- " " }} - {{- '{"name": , "arguments": } +{{- '{"name": , "arguments": } ' }} - {{- '<|im_end|>' }} -{%- endif %} +{{- '<|im_end|>' }} {%- for message in messages %} {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} - {%- elif message.role == "assistant" %} - {{- '<|im_start|>' + message.role }} - {%- for tool_call in message.tool_calls %} - {{- '\n\n' }} - {%- if tool_call.function is defined %} - {%- set tool_call = tool_call.function %} - {%- endif %} - {{- '{ ' }} - {%- if tool_call.arguments is defined %} - {{- '"arguments": ' }} - {{- tool_call.arguments|tojson }} + {%- elif message.role == "assistant" and message.tool_calls is defined %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"}' }} {{- ', ' }} - {%- endif %} - {{- '"name": "' }} - {{- tool_call.name }} - {{- '"}' }} - {{- '\n ' }} - {%- endfor %} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '\n' }} + {%- endfor %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} {%- if not message.name is defined %} {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }} {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' }} - {{- '{"name": "' }} - {{- message.name }} - {{- '", "content": ' }} - {{- message.content|tojson + '}' }} - {{- '\n <|im_end|>\n' }} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {%- if not loop.last %} + {{- '\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} diff --git a/tests/entrypoints/openai/tool_use/__init__.py b/tests/entrypoints/openai/tool_use/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/entrypoints/openai/tool_use/conftest.py b/tests/entrypoints/openai/tool_use/conftest.py deleted file mode 100644 index 3d221af999fc5..0000000000000 --- a/tests/entrypoints/openai/tool_use/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest -from huggingface_hub import snapshot_download - -from ....utils import RemoteOpenAIServer -from .util import ARGS, CONFIGS, ServerConfig - - -# for each server config, download the model and return the config -@pytest.fixture(scope="module", params=CONFIGS.keys()) -def server_config(request): - config = CONFIGS[request.param] - # download model and tokenizer using transformers - snapshot_download(config["model"]) - yield CONFIGS[request.param] - - -# run this for each server config -@pytest.fixture(scope="module") -def server(request, server_config: ServerConfig): - model = server_config["model"] - args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_start_wait_s=240) as server: - yield server - - -@pytest.fixture(scope="module") -def client(server: RemoteOpenAIServer): - return server.get_async_client() diff --git a/tests/entrypoints/openai/tool_use/test_chat_completions.py b/tests/entrypoints/openai/tool_use/test_chat_completions.py deleted file mode 100644 index a5adb04252c12..0000000000000 --- a/tests/entrypoints/openai/tool_use/test_chat_completions.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import List - -import openai -import pytest - -from .util import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL - - -# test: make sure chat completions without tools provided work even when tools -# are enabled. This makes sure tool call chat templates work, AND that the tool -# parser stream processing doesn't change the output of the model. -@pytest.mark.asyncio -async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False) - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - output_text = chat_completion.choices[0].message.content - - # check to make sure we got text - assert output_text is not None - assert len(output_text) > 0 - assert stop_reason != "tool_calls" - - # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) - - # make the same request, streaming - stream = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False, - stream=True, - ) - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - # assemble streamed chunks - async for chunk in stream: - delta = chunk.choices[0].delta - - # make sure the role is assistant - if delta.role: - assert not role_sent - assert delta.role == 'assistant' - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - # make sure tool call chunks aren't being streamed - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - # make sure the role was sent, only 1 finish reason was sent, that chunks - # were in fact sent, and that the chunks match non-streaming - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == output_text - - -# test: conversation with tools enabled and provided that should not invoke -# tools, to make sure we can still get normal chat completion responses -# and that they won't be parsed as tools -@pytest.mark.asyncio -async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - tools=[WEATHER_TOOL], - logprobs=False) - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - output_text = chat_completion.choices[0].message.content - - # check to make sure we got text - assert output_text is not None - assert stop_reason != 'tool_calls' - assert len(output_text) > 0 - - # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) - - # make the same request, streaming - stream = await client.chat.completions.create( - messages=MESSAGES_WITHOUT_TOOLS, - temperature=0, - max_tokens=128, - model=model_name, - logprobs=False, - tools=[WEATHER_TOOL], - stream=True, - ) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - # assemble streamed chunks - async for chunk in stream: - delta = chunk.choices[0].delta - - # make sure the role is assistant - if delta.role: - assert delta.role == 'assistant' - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - - # make sure tool call chunks aren't being streamed - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - # make sure the role was sent, only 1 finish reason was sent, that chunks - # were in fact sent, and that the chunks match non-streaming - assert role_sent - assert finish_reason_count == 1 - assert chunk.choices[0].finish_reason == stop_reason - assert chunk.choices[0].finish_reason != 'tool_calls' - assert len(chunks) - assert "".join(chunks) == output_text diff --git a/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py b/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py deleted file mode 100644 index 5083cf394a6cf..0000000000000 --- a/tests/entrypoints/openai/tool_use/test_parallel_tool_calls.py +++ /dev/null @@ -1,193 +0,0 @@ -import json -from typing import Dict, List, Optional - -import openai -import pytest - -from .util import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, - WEATHER_TOOL) - - -# test: getting the model to generate parallel tool calls (streaming/not) -# when requested. NOTE that not all models may support this, so some exclusions -# may be added in the future. e.g. llama 3.1 models are not designed to support -# parallel tool calls. -@pytest.mark.asyncio -async def test_parallel_tool_calls(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - temperature=0, - max_tokens=800, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls - - # make sure 2 tool calls are present - assert choice.message.role == "assistant" - assert non_streamed_tool_calls is not None - assert len(non_streamed_tool_calls) == 2 - - for tool_call in non_streamed_tool_calls: - # make sure the tool includes a function and ID - assert tool_call.type == "function" - assert tool_call.function is not None - assert isinstance(tool_call.id, str) - assert len(tool_call.id) > 16 - - # make sure the weather tool was called correctly - assert tool_call.function.name == WEATHER_TOOL["function"]["name"] - assert isinstance(tool_call.function.arguments, str) - - parsed_arguments = json.loads(tool_call.function.arguments) - assert isinstance(parsed_arguments, Dict) - assert isinstance(parsed_arguments.get("city"), str) - assert isinstance(parsed_arguments.get("state"), str) - - assert stop_reason == "tool_calls" - - # make the same request, streaming - stream = await client.chat.completions.create( - model=model_name, - messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - temperature=0, - max_tokens=800, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - role_name: Optional[str] = None - finish_reason_count: int = 0 - - tool_call_names: List[str] = [] - tool_call_args: List[str] = [] - tool_call_idx: int = -1 - tool_call_id_count: int = 0 - - async for chunk in stream: - - # if there's a finish reason make sure it's tools - if chunk.choices[0].finish_reason: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' - - # if a role is being streamed make sure it wasn't already set to - # something else - if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' - - # if a tool call is streamed make sure there's exactly one - # (based on the request parameters - streamed_tool_calls = chunk.choices[0].delta.tool_calls - - if streamed_tool_calls and len(streamed_tool_calls) > 0: - - # make sure only one diff is present - correct even for parallel - assert len(streamed_tool_calls) == 1 - tool_call = streamed_tool_calls[0] - - # if a new tool is being called, set up empty arguments - if tool_call.index != tool_call_idx: - tool_call_idx = tool_call.index - tool_call_args.append("") - - # if a tool call ID is streamed, make sure one hasn't been already - if tool_call.id: - tool_call_id_count += 1 - assert (isinstance(tool_call.id, str) - and (len(tool_call.id) > 16)) - - # if parts of the function start being streamed - if tool_call.function: - # if the function name is defined, set it. it should be streamed - # IN ENTIRETY, exactly one time. - if tool_call.function.name: - assert isinstance(tool_call.function.name, str) - tool_call_names.append(tool_call.function.name) - - if tool_call.function.arguments: - # make sure they're a string and then add them to the list - assert isinstance(tool_call.function.arguments, str) - - tool_call_args[ - tool_call.index] += tool_call.function.arguments - - assert finish_reason_count == 1 - assert role_name == 'assistant' - - assert (len(non_streamed_tool_calls) == len(tool_call_names) == - len(tool_call_args)) - - for i in range(0, 2): - assert non_streamed_tool_calls[i].function.name == tool_call_names[i] - streamed_args = json.loads(tool_call_args[i]) - non_streamed_args = json.loads( - non_streamed_tool_calls[i].function.arguments) - assert streamed_args == non_streamed_args - - -# test: providing parallel tool calls back to the model to get a response -# (streaming/not) -@pytest.mark.asyncio -async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - - assert choice.finish_reason != "tool_calls" # "stop" or "length" - assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 - assert choice.message.content is not None - assert "98" in choice.message.content # Dallas temp in tool response - assert "78" in choice.message.content # Orlando temp in tool response - - stream = await client.chat.completions.create( - messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - async for chunk in stream: - delta = chunk.choices[0].delta - - if delta.role: - assert not role_sent - assert delta.role == "assistant" - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints/openai/tool_use/test_tool_calls.py b/tests/entrypoints/openai/tool_use/test_tool_calls.py deleted file mode 100644 index 5f1a8dfff1c33..0000000000000 --- a/tests/entrypoints/openai/tool_use/test_tool_calls.py +++ /dev/null @@ -1,192 +0,0 @@ -import json -from typing import Dict, List, Optional - -import openai -import pytest - -from .util import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, - SEARCH_TOOL, WEATHER_TOOL) - - -# test: request a chat completion that should return tool calls, so we know they -# are parsable -@pytest.mark.asyncio -async def test_tool_call_and_choice(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_ASKING_FOR_TOOLS, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - stop_reason = chat_completion.choices[0].finish_reason - tool_calls = chat_completion.choices[0].message.tool_calls - - # make sure a tool call is present - assert choice.message.role == 'assistant' - assert tool_calls is not None - assert len(tool_calls) == 1 - assert tool_calls[0].type == 'function' - assert tool_calls[0].function is not None - assert isinstance(tool_calls[0].id, str) - assert len(tool_calls[0].id) > 16 - - # make sure the weather tool was called (classic example) with arguments - assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] - assert tool_calls[0].function.arguments is not None - assert isinstance(tool_calls[0].function.arguments, str) - - # make sure the arguments parse properly - parsed_arguments = json.loads(tool_calls[0].function.arguments) - assert isinstance(parsed_arguments, Dict) - assert isinstance(parsed_arguments.get("city"), str) - assert isinstance(parsed_arguments.get("state"), str) - assert parsed_arguments.get("city") == "Dallas" - assert parsed_arguments.get("state") == "TX" - - assert stop_reason == "tool_calls" - - function_name: Optional[str] = None - function_args_str: str = '' - tool_call_id: Optional[str] = None - role_name: Optional[str] = None - finish_reason_count: int = 0 - - # make the same request, streaming - stream = await client.chat.completions.create( - model=model_name, - messages=MESSAGES_ASKING_FOR_TOOLS, - temperature=0, - max_tokens=500, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - async for chunk in stream: - assert chunk.choices[0].index == 0 - - if chunk.choices[0].finish_reason: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' - - # if a role is being streamed make sure it wasn't already set to - # something else - if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' - - # if a tool call is streamed make sure there's exactly one - # (based on the request parameters - streamed_tool_calls = chunk.choices[0].delta.tool_calls - - if streamed_tool_calls and len(streamed_tool_calls) > 0: - assert len(streamed_tool_calls) == 1 - tool_call = streamed_tool_calls[0] - - # if a tool call ID is streamed, make sure one hasn't been already - if tool_call.id: - assert not tool_call_id - tool_call_id = tool_call.id - - # if parts of the function start being streamed - if tool_call.function: - # if the function name is defined, set it. it should be streamed - # IN ENTIRETY, exactly one time. - if tool_call.function.name: - assert function_name is None - assert isinstance(tool_call.function.name, str) - function_name = tool_call.function.name - if tool_call.function.arguments: - assert isinstance(tool_call.function.arguments, str) - function_args_str += tool_call.function.arguments - - assert finish_reason_count == 1 - assert role_name == 'assistant' - assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) - - # validate the name and arguments - assert function_name == WEATHER_TOOL["function"]["name"] - assert function_name == tool_calls[0].function.name - assert isinstance(function_args_str, str) - - # validate arguments - streamed_args = json.loads(function_args_str) - assert isinstance(streamed_args, Dict) - assert isinstance(streamed_args.get("city"), str) - assert isinstance(streamed_args.get("state"), str) - assert streamed_args.get("city") == "Dallas" - assert streamed_args.get("state") == "TX" - - # make sure everything matches non-streaming except for ID - assert function_name == tool_calls[0].function.name - assert choice.message.role == role_name - assert choice.message.tool_calls[0].function.name == function_name - - # compare streamed with non-streamed args Dict-wise, not string-wise - # because character-to-character comparison might not work e.g. the tool - # call parser adding extra spaces or something like that. we care about the - # dicts matching not byte-wise match - assert parsed_arguments == streamed_args - - -# test: providing tools and results back to model to get a non-tool response -# (streaming/not) -@pytest.mark.asyncio -async def test_tool_call_with_results(client: openai.AsyncOpenAI): - models = await client.models.list() - model_name: str = models.data[0].id - chat_completion = await client.chat.completions.create( - messages=MESSAGES_WITH_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) - - choice = chat_completion.choices[0] - - assert choice.finish_reason != "tool_calls" # "stop" or "length" - assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 - assert choice.message.content is not None - assert "98" in choice.message.content # the temperature from the response - - stream = await client.chat.completions.create( - messages=MESSAGES_WITH_TOOL_RESPONSE, - temperature=0, - max_tokens=500, - model=model_name, - tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False, - stream=True) - - chunks: List[str] = [] - finish_reason_count = 0 - role_sent: bool = False - - async for chunk in stream: - delta = chunk.choices[0].delta - - if delta.role: - assert not role_sent - assert delta.role == "assistant" - role_sent = True - - if delta.content: - chunks.append(delta.content) - - if chunk.choices[0].finish_reason is not None: - finish_reason_count += 1 - assert chunk.choices[0].finish_reason == choice.finish_reason - - assert not delta.tool_calls or len(delta.tool_calls) == 0 - - assert role_sent - assert finish_reason_count == 1 - assert len(chunks) - assert "".join(chunks) == choice.message.content diff --git a/tests/entrypoints/openai/tool_use/util.py b/tests/entrypoints/openai/tool_use/util.py deleted file mode 100644 index a9c3d65686fa3..0000000000000 --- a/tests/entrypoints/openai/tool_use/util.py +++ /dev/null @@ -1,219 +0,0 @@ -from typing import Dict, List - -from openai.types.chat import (ChatCompletionMessageParam, - ChatCompletionToolParam) -from typing_extensions import TypedDict - -from ....utils import VLLM_PATH - - -class ServerConfig(TypedDict): - model: str - arguments: List[str] - - -ARGS: List[str] = [ - "--dtype", - "half", # TODO change to BF16 - "--kv-cache-dtype", - "fp8", - "--enable-auto-tool-choice" -] - -CONFIGS: Dict[str, ServerConfig] = { - "hermes": { - "model": - "NousResearch/Hermes-2-Pro-Llama-3-8B", - "arguments": [ - "--tool-call-parser", "hermes", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") - ] - }, - "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", - "arguments": [ - "--tool-call-parser", "mistral", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), - "--ignore-patterns=\"consolidated.safetensors\"" - ] - } -} - -WEATHER_TOOL: ChatCompletionToolParam = { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, " - "e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state " - "that the city is in, e.g. 'CA' which would " - "mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } - } - } - } -} - -SEARCH_TOOL: ChatCompletionToolParam = { - "type": "function", - "function": { - "name": - "web_search", - "description": - "Search the internet and get a summary of the top " - "10 webpages. Should only be used if you don't know " - "the answer to a user query, and the results are likely" - "to be able to be found with a web search", - "parameters": { - "type": "object", - "properties": { - "search_term": { - "type": - "string", - "description": - "The term to use in the search. This should" - "ideally be keywords to search for, not a" - "natural-language question" - } - }, - "required": ["search_term"] - } - } -} - -MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "system", - "content": - "You are a helpful assistant with access to tools. If a tool" - " that you have would be helpful to answer a user query, " - "call the tool. Otherwise, answer the user's query directly " - "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." -}, { - "role": - "user", - "content": - "Hi! How are you?" -}, { - "role": - "assistant", - "content": - "I'm doing great! How can I assist you?" -}, { - "role": - "user", - "content": - "Can you tell me a joke please?" -}] - -MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}] - -MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas is 98 degrees fahrenheit, with partly" - "cloudy skies and a low chance of rain." -}] - -MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}] - -MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }, { - "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Orlando", "state": "Fl", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas TX is 98 degrees fahrenheit with mostly " - "cloudy skies and a chance of rain in the evening." -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "content": - "The weather in Orlando FL is 78 degrees fahrenheit with clear" - "skies." -}] diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index 3b52d226d42ff..c7dbaf848663b 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -6,7 +6,7 @@ # for each server config, download the model and return the config -@pytest.fixture(scope="module", params=CONFIGS.keys()) +@pytest.fixture(scope="session", params=CONFIGS.keys()) def server_config(request): config = CONFIGS[request.param] # download model and tokenizer using transformers @@ -15,7 +15,7 @@ def server_config(request): # run this for each server config -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] @@ -24,6 +24,6 @@ def server(request, server_config: ServerConfig): yield server -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def client(server: RemoteOpenAIServer): return server.get_async_client() diff --git a/tests/tool_use/util.py b/tests/tool_use/util.py index 47e8b9c87aefd..35ebc267f6bfa 100644 --- a/tests/tool_use/util.py +++ b/tests/tool_use/util.py @@ -11,12 +11,9 @@ class ServerConfig(TypedDict): model: str arguments: List[str] - +# universal args for all models go here. also good if you need to test locally +# and change type or KV cache quantization or something. ARGS: List[str] = [ - "--dtype", - "half", # TODO change to BF16 - "--kv-cache-dtype", - "fp8", "--enable-auto-tool-choice" ] From f7f8b923b467e44c1664c8a9b06f3839290ab12e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 16:08:49 -0500 Subject: [PATCH 178/222] fix: formatting --- tests/tool_use/conftest.py | 1 + tests/tool_use/util.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index c7dbaf848663b..b93eecf47cc2d 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -2,6 +2,7 @@ from huggingface_hub import snapshot_download from tests.utils import RemoteOpenAIServer + from .util import ARGS, CONFIGS, ServerConfig diff --git a/tests/tool_use/util.py b/tests/tool_use/util.py index 35ebc267f6bfa..06d1e305a373d 100644 --- a/tests/tool_use/util.py +++ b/tests/tool_use/util.py @@ -11,11 +11,10 @@ class ServerConfig(TypedDict): model: str arguments: List[str] + # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = [ - "--enable-auto-tool-choice" -] +ARGS: List[str] = ["--enable-auto-tool-choice"] CONFIGS: Dict[str, ServerConfig] = { "hermes": { From 94047c7629d0cdab2a58d701365e976c765c9d2e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 18:04:04 -0500 Subject: [PATCH 179/222] chore(tests): move tool tests out of the fastcheck section --- .buildkite/test-pipeline.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cb30e7493bdad..81828330debe1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -85,14 +85,6 @@ steps: - pytest -v -s entrypoints/llm - pytest -v -s entrypoints/openai -- label: OpenAI-Compatible Tool Use # 20 min - fast_check: false - mirror_hardwares: [ amd ] - source_file_dependencies: - - vllm/ - - tests/tool_use - commands: - - pytest -v -s tool_use - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" @@ -256,6 +248,14 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use ##### 1 GPU test ##### ##### multi gpus test ##### From 44b2e072d420bbd67a393ed29a69da5335da1363 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 18:13:46 -0500 Subject: [PATCH 180/222] cleanup: range statement --- tests/tool_use/test_parallel_tool_calls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index 5083cf394a6cf..31d531e29ae80 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -125,7 +125,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): assert (len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)) - for i in range(0, 2): + for i in range(2): assert non_streamed_tool_calls[i].function.name == tool_call_names[i] streamed_args = json.loads(tool_call_args[i]) non_streamed_args = json.loads( From 6d3650904002f0093fb28f2c3255fa191564bf8d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 18:15:31 -0500 Subject: [PATCH 181/222] cleanup: unnecessary backslash --- vllm/entrypoints/openai/cli_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 19f8fc93e38c9..d81ab9401372a 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -154,7 +154,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: choices=["mistral", "hermes"], default=None, help= - "Select the tool call parser depending on the model that you\'re using." + "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. Required for --enable-auto-tool-choice. Options: 'hermes', " "'mistral'") From fb40e5fa906bf85214e0754fefcf9d867d69a36e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 15 Aug 2024 18:16:07 -0500 Subject: [PATCH 182/222] cleanup: cli args --- vllm/entrypoints/openai/cli_args.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index d81ab9401372a..87d830f1bdf30 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -156,8 +156,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " - "format. Required for --enable-auto-tool-choice. Options: 'hermes', " - "'mistral'") + "format. Required for --enable-auto-tool-choice.") parser = AsyncEngineArgs.add_cli_args(parser) From c1d3110db869ab7263a928907ac45bbb8be0be7c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 16 Aug 2024 18:16:32 -0500 Subject: [PATCH 183/222] fix: examples --- ...penai_chat_completion_client_with_tools.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py index 5d423a8e65aa0..2bbe42b6bd2ef 100644 --- a/examples/openai_chat_completion_client_with_tools.py +++ b/examples/openai_chat_completion_client_with_tools.py @@ -103,23 +103,24 @@ for chunk in chunks: if chunk.choices[0].delta.tool_calls: - if chunk.choices[0].delta.tool_calls[0].index != tool_call_idx: + tool_call = chunk.choices[0].delta.tool_calls[0] + + if tool_call.index != tool_call_idx: if tool_call_idx >= 0: print( f"streamed tool call arguments: {arguments[tool_call_idx]}" ) tool_call_idx = chunk.choices[0].delta.tool_calls[0].index arguments.append("") - if chunk.choices[0].delta.tool_calls[0].id: - print(f"streamed tool call id: " - f"{chunk.choices[0].delta.tool_calls[0].id}") - if chunk.choices[0].delta.tool_calls[0].function: - if chunk.choices[0].delta.tool_calls[0].function.name: - print(f"streamed tool call name: " - f"{chunk.choices[0].delta.tool_calls[0].function.name}") - if chunk.choices[0].delta.tool_calls[0].function.arguments: - arguments[tool_call_idx] += chunk.choices[0].delta.tool_calls[ - 0].function.arguments + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print(f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments if len(arguments): print(f"streamed tool call arguments: {arguments[-1]}") From 5fb1a414ae2f237996bc73269c98c2aeabf8a920 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 16 Aug 2024 21:59:01 -0500 Subject: [PATCH 184/222] fix(tests): set max tokens to a lower number that is about 30-50% above what the test should require --- tests/tool_use/test_chat_completions.py | 8 ++++---- tests/tool_use/test_parallel_tool_calls.py | 8 ++++---- tests/tool_use/test_tool_calls.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index a5adb04252c12..8484dde1a75fd 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -16,7 +16,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=128, + max_tokens=150, model=model_name, logprobs=False) choice = chat_completion.choices[0] @@ -36,7 +36,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=128, + max_tokens=150, model=model_name, logprobs=False, stream=True, @@ -83,7 +83,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=128, + max_tokens=150, model=model_name, tools=[WEATHER_TOOL], logprobs=False) @@ -104,7 +104,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( messages=MESSAGES_WITHOUT_TOOLS, temperature=0, - max_tokens=128, + max_tokens=150, model=model_name, logprobs=False, tools=[WEATHER_TOOL], diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index 31d531e29ae80..f4d1cba457452 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -20,7 +20,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, temperature=0, - max_tokens=800, + max_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False) @@ -57,7 +57,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI): model=model_name, messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, temperature=0, - max_tokens=800, + max_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, stream=True) @@ -142,7 +142,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, temperature=0, - max_tokens=500, + max_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False) @@ -160,7 +160,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, temperature=0, - max_tokens=500, + max_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 5f1a8dfff1c33..d182f32bd8966 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -17,7 +17,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_ASKING_FOR_TOOLS, temperature=0, - max_tokens=500, + max_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False) @@ -61,7 +61,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): model=model_name, messages=MESSAGES_ASKING_FOR_TOOLS, temperature=0, - max_tokens=500, + max_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, stream=True) @@ -142,7 +142,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( messages=MESSAGES_WITH_TOOL_RESPONSE, temperature=0, - max_tokens=500, + max_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False) @@ -159,7 +159,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( messages=MESSAGES_WITH_TOOL_RESPONSE, temperature=0, - max_tokens=500, + max_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, From bbb5b2716b058ab57a2e86253a9ffc27324ea9d9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 15:52:00 -0500 Subject: [PATCH 185/222] fix: exceptions in tool validation should be raised, not returned, so that error messages are returned to the client --- vllm/entrypoints/openai/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83cfd8ca92439..88a865add46cc 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -409,13 +409,13 @@ def check_tool_usage(cls, data): valid_tool = False specified_function = data["tool_choice"]["function"] if not specified_function: - return ValueError( + raise ValueError( "Incorrectly formatted `tool_choice`. Should be like " "`{\"type\": \"function\"," " \"function\": {\"name\": \"my_function\"}}`") specified_function_name = specified_function["name"] if not specified_function_name: - return ValueError( + raise ValueError( "Incorrectly formatted `tool_choice`. Should be like " "`{\"type\": \"function\", " "\"function\": {\"name\": \"my_function\"}}`") @@ -424,7 +424,7 @@ def check_tool_usage(cls, data): valid_tool = True break if not valid_tool: - return ValueError( + raise ValueError( "The tool specified in `tool_choice` does not match any" " of the specified `tools`") return data From 17292b02842c0693418dfa6f9377adcf4e31ae2c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:02:59 -0500 Subject: [PATCH 186/222] chore: cleanup debug lines in tool parsers to remove less-relevant ones --- .../openai/tool_parsers/hermes_tool_parser.py | 10 +--------- .../openai/tool_parsers/mistral_tool_parser.py | 12 ++---------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index fb361e6a01dbc..410fbff6c45ff 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -178,7 +178,6 @@ def extract_tool_calls_streaming( # case -- the current tool call is being closed. elif (cur_tool_start_count == cur_tool_end_count and cur_tool_end_count > prev_tool_end_count): - logger.debug("Closing the current tool call!") diff = self.prev_tool_call_arr[self.current_tool_id].get( "arguments") if diff: @@ -199,7 +198,6 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[], content=delta_text) return delta - logger.debug("Tool call portion: %s", tool_call_portion or "") try: current_tool_call = partial_json_parser.loads( @@ -213,7 +211,6 @@ def extract_tool_calls_streaming( # case - we haven't sent the initial delta with the tool call ID # (it will be sent) if not self.current_tool_initial_sent: - logger.debug("Sending InitialDeltaToolCall") self.current_tool_initial_sent = True return DeltaMessage(tool_calls=[ InitialDeltaToolCall( @@ -226,8 +223,6 @@ def extract_tool_calls_streaming( elif not self.current_tool_name_sent: function_name: Union[str, None] = current_tool_call.get("name") if function_name: - logger.debug("Sending DeltaToolCall with function name %s", - function_name) self.current_tool_name_sent = True return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, @@ -249,8 +244,6 @@ def extract_tool_calls_streaming( # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. - if text_portion is not None: - logger.debug("Also, will send text portion: %s", text_portion) logger.debug("Trying to parse current tool call with ID %s", self.current_tool_id) @@ -259,7 +252,6 @@ def extract_tool_calls_streaming( # a placeholder for the arguments if len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - logger.debug("Pushed dummy value into tool call arr") # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON @@ -313,7 +305,7 @@ def extract_tool_calls_streaming( cur_args_json = json.dumps(cur_arguments) prev_args_json = json.dumps(prev_arguments) - logger.debug("Searching for dif between\n%s", cur_args_json) + logger.debug("Searching for diff between\n%s", cur_args_json) logger.debug("and\n%s", prev_args_json) argument_diff = extract_intermediate_diff( cur_args_json, prev_args_json) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 2bc9bfc6792ae..67dc6ef3e5f29 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -44,10 +44,6 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: make sure your tool call arguments don't ever include quotes! """ - logger.debug( - "Trying to extract mistral tool calls from the following:") - logger.debug(model_output) - # case -- if a tool call token is not present, return a text response if MistralToolParser.bot_token not in model_output: return ExtractedToolCallInformation(tools_called=False, @@ -203,7 +199,6 @@ def extract_tool_calls_streaming( # if the current tool initial data incl. the id, type=function # and idx not sent, send that if not self.current_tool_initial_sent: - logger.debug("Sending InitialDeltaToolCall") self.current_tool_initial_sent = True delta = DeltaMessage(tool_calls=[ InitialDeltaToolCall( @@ -216,8 +211,7 @@ def extract_tool_calls_streaming( elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - logger.debug("Sending DeltaToolCall with function name %s", - function_name) + delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -236,8 +230,6 @@ def extract_tool_calls_streaming( self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") - logger.debug("new text: %s", current_text) - new_text = delta_text.replace("\'", "\"") if not cur_arguments and not prev_arguments: @@ -250,7 +242,7 @@ def extract_tool_calls_streaming( delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments) - logger.debug("finding %s in |%s|", new_text, + logger.debug("finding %s in %s", new_text, cur_arguments_json) arguments_delta = cur_arguments_json[:cur_arguments_json. From 6830f904cdfa5d075e92186db0f43be2fd14e4b6 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:05:26 -0500 Subject: [PATCH 187/222] fix: remove comments --- vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 67dc6ef3e5f29..8483086081413 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -32,8 +32,8 @@ class MistralToolParser(ToolParser): # the bot_token is the token indicating tool call(s) follow. Tokens before # this token will be parsed as content; and # if not present, the entire response will be parsed as text content. - bot_token: str = "[TOOL_CALLS]" # string literal - bot_token_id: int = 5 # token ID thereof from the models" tokenizer + bot_token: str = "[TOOL_CALLS]" + bot_token_id: int = 5 tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) @staticmethod From c87a81f21b79f27d2e99de3804f1055f2ba3959f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:23:53 -0500 Subject: [PATCH 188/222] refactor: tool call parsers no longer use static methods --- vllm/entrypoints/openai/serving_chat.py | 4 +- .../openai/tool_parsers/hermes_tool_parser.py | 82 ++++++++++--------- .../tool_parsers/mistral_tool_parser.py | 54 ++++++------ 3 files changed, 69 insertions(+), 71 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e7d4fc336a3fb..1ab4d2cd3198b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -624,8 +624,8 @@ async def chat_completion_full_generator( or request.tool_choice is None) and self.enable_auto_tools \ and self.tool_parser: - tool_call_info = self.tool_parser.extract_tool_calls( - output.text) + tool_parser = self.tool_parser(tokenizer) + tool_call_info = tool_parser.extract_tool_calls(output.text) tools_called = tool_call_info.tools_called if tool_call_info.tools_called: message = ChatMessage(role=role, diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 410fbff6c45ff..12398cb173a7c 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -22,21 +22,50 @@ class Hermes2ProToolParser(ToolParser): - tool_call_start_token: str = "" - tool_call_end_token: str = "" - # regex to match between and OR between - # and EOS (happens sometimes :)) - tool_call_regex = re.compile( - r"(.*?)|(.*)", re.DOTALL) - scratch_pad_regex = re.compile(r"(.*?)", - re.DOTALL) + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): + super().__init__(tokenizer) + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_start_token] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_end_token] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + + def extract_tool_calls( + self, + model_output: str + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing - if Hermes2ProToolParser.tool_call_start_token not in model_output: + if self.tool_call_start_token not in model_output: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) @@ -49,7 +78,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # findall is an array of tuples where one is a function call and # the other is None function_call_tuples = ( - Hermes2ProToolParser.tool_call_regex.findall(model_output)) + self.tool_call_regex.findall(model_output)) # load the JSON, and then use it to build the Function and # Tool Call @@ -68,7 +97,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ] content = model_output[:model_output.find( - Hermes2ProToolParser.tool_call_start_token)] + self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -81,33 +110,6 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: tool_calls=[], content=model_output) - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): - super().__init__(tokenizer) - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: List[Dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent = False - self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [ - ] # map what has been streamed for each tool so far to a list - - if not self.model_tokenizer: - raise ValueError( - "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ - ""] - self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ - ""] - if not self.tool_call_start_token_id or not self.tool_call_end_token_id: - raise RuntimeError( - "Hermes 2 Pro Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") - def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: List[int], current_token_ids: List[int], diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 8483086081413..93db78f7f52ab 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -29,15 +29,27 @@ class MistralToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set """ - # the bot_token is the token indicating tool call(s) follow. Tokens before - # this token will be parsed as content; and - # if not present, the entire response will be parsed as text content. - bot_token: str = "[TOOL_CALLS]" - bot_token_id: int = 5 - tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + def __init__(self, + tokenizer: Optional[Union[PreTrainedTokenizer, + PreTrainedTokenizerFast, + PreTrainedTokenizerFast, + AutoTokenizer]] = None): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.current_tool_initial_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Requires find-and-replacing single quotes with double quotes for JSON parsing, @@ -45,7 +57,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: """ # case -- if a tool call token is not present, return a text response - if MistralToolParser.bot_token not in model_output: + if self.bot_token not in model_output: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) @@ -53,8 +65,8 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: # use a regex to find the tool call. remove the BOT token # and make sure to replace single quotes with double quotes - raw_tool_call = MistralToolParser.tool_call_regex.findall( - model_output.replace(MistralToolParser.bot_token, ""))[0] + raw_tool_call = self.tool_call_regex.findall( + model_output.replace(self.bot_token, ""))[0] # load the JSON, and then use it to build the Function and # Tool Call @@ -70,7 +82,7 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: ] # get any content before the tool call - content = model_output.split(MistralToolParser.bot_token)[0] + content = model_output.split(self.bot_token)[0] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -84,22 +96,6 @@ def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: tool_calls=[], content=model_output) - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): - super().__init__(tokenizer) - - # initialize properties used for state when parsing tool calls in - # streaming mode - self.prev_tool_call_arr: List[Dict] = [] - self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.current_tool_initial_sent: bool = False - self.streamed_args_for_tool: List[str] = [ - ] # map what has been streamed for each tool so far to a list - def extract_tool_calls_streaming( self, previous_text: str, From 6f5e585da4653168aece842a8383dab7a706c22a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:28:09 -0500 Subject: [PATCH 189/222] fix: remove cruft --- vllm/entrypoints/openai/tool_parsers/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index 52d2e6ed985fd..c3af0f1efee11 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -59,8 +59,6 @@ def extract_intermediate_diff(curr: str, old: str) -> str: """ suffix = find_common_suffix(curr, old) - # prevent double-counting - #s2_old = old old = old[::-1].replace(suffix[::-1], '', 1)[::-1] prefix = find_common_prefix(curr, old) diff = curr From f2b0ee861ee6a8594d7764b4dcbacac7d512b320 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:30:38 -0500 Subject: [PATCH 190/222] fix: comment on its own line --- vllm/entrypoints/openai/tool_parsers/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index c3af0f1efee11..20a08bb14e490 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -66,9 +66,10 @@ def extract_intermediate_diff(curr: str, old: str) -> str: diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] if len(prefix): + # replace the prefix only once in case it's mirrored diff = diff.replace( prefix, '', - 1) # replace the prefix only once in case it's mirrored + 1) return diff From fb2db3ab4bfd9d66fec365ce7c9d82f19ede7825 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:33:09 -0500 Subject: [PATCH 191/222] fix: remove print() --- vllm/entrypoints/openai/serving_chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ab4d2cd3198b..6d410e2c03563 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -542,7 +542,6 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) - print(e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished From 1f4eeff8a0b400583f8058d27db6fa8fe7b8fae9 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 16:34:50 -0500 Subject: [PATCH 192/222] refactor: types based on earlier refactor of hermes and mistral tool parsers --- .../openai/tool_parsers/abstract_tool_parser.py | 12 +++++------- .../openai/tool_parsers/hermes_tool_parser.py | 11 ++++------- vllm/entrypoints/openai/tool_parsers/utils.py | 4 +--- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index e0870396a69ee..5545b9b450189 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -18,10 +18,8 @@ class ToolParser: """ def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, + PreTrainedTokenizerFast, AutoTokenizer]): self.prev_tool_call_arr: List[Dict] = [] # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 @@ -31,8 +29,8 @@ def __init__(self, self.model_tokenizer = tokenizer - @staticmethod - def extract_tool_calls(model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 12398cb173a7c..d62209f3015e7 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -58,11 +58,8 @@ def __init__(self, "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - - def extract_tool_calls( - self, - model_output: str - ) -> ExtractedToolCallInformation: + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: @@ -96,8 +93,8 @@ def extract_tool_calls( for function_call in raw_function_calls ] - content = model_output[:model_output.find( - self.tool_call_start_token)] + content = model_output[:model_output. + find(self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index 20a08bb14e490..db7fc5259fc4e 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -67,9 +67,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str: if len(prefix): # replace the prefix only once in case it's mirrored - diff = diff.replace( - prefix, '', - 1) + diff = diff.replace(prefix, '', 1) return diff From f14e3e5fe9970637b1198d0a741caff98d666a5a Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 17 Aug 2024 22:10:20 -0500 Subject: [PATCH 193/222] fix(tests): make max wait timeout for RemoteOpenAIServer an instance variable only --- tests/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index 500fb4c662a30..e8c1218ad7db8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,7 +51,6 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key - MAX_START_WAIT_S = 120 # wait for server to start for 120 seconds def __init__(self, model: str, @@ -59,7 +58,7 @@ def __init__(self, *, env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, - max_start_wait_s: Optional[int] = None) -> None: + max_start_wait_s: Optional[int] = 120) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -75,7 +74,7 @@ def __init__(self, self.port = int(args.port) if max_start_wait_s: - self.MAX_START_WAIT_S = max_start_wait_s + self.max_start_wait_s = max_start_wait_s env = os.environ.copy() # the current process might initialize cuda, @@ -88,7 +87,7 @@ def __init__(self, stdout=sys.stdout, stderr=sys.stderr) self._wait_for_server(url=self.url_for("health"), - timeout=self.MAX_START_WAIT_S) + timeout=self.max_start_wait_s) def __enter__(self): return self From e21cfa858b1f01cffc0c4a3d35205a55846200a5 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 18 Aug 2024 14:46:59 -0500 Subject: [PATCH 194/222] fix: refactor RemoteOpenAIServer to use a default class var to prevent breaking tests --- tests/entrypoints/openai/test_oot_registration.py | 3 ++- tests/utils.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index 3e1c7a1456697..bc979b86b4ebc 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -82,7 +82,8 @@ def test_oot_registration_for_api_server(): except OpenAIError as e: if "Connection error" in str(e): time.sleep(3) - if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S: + if time.time( + ) - now > RemoteOpenAIServer.DEFAULT_MAX_START_WAIT_S: msg = "Server did not start in time" raise RuntimeError(msg) from e else: diff --git a/tests/utils.py b/tests/utils.py index e8c1218ad7db8..f44ccf21b0c19 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -51,6 +51,7 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key + DEFAULT_MAX_START_WAIT_S = 120 # needed as a class variable for a test def __init__(self, model: str, @@ -58,7 +59,7 @@ def __init__(self, *, env_dict: Optional[Dict[str, str]] = None, auto_port: bool = True, - max_start_wait_s: Optional[int] = 120) -> None: + max_start_wait_s: Optional[int] = None) -> None: if auto_port: if "-p" in cli_args or "--port" in cli_args: raise ValueError("You have manually specified the port" @@ -74,7 +75,8 @@ def __init__(self, self.port = int(args.port) if max_start_wait_s: - self.max_start_wait_s = max_start_wait_s + self.max_start_wait_s = max_start_wait_s or \ + RemoteOpenAIServer.DEFAULT_MAX_START_WAIT_S env = os.environ.copy() # the current process might initialize cuda, From f6fa6df3f19584ea794b67c6d23bb31620235e25 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 21 Aug 2024 20:25:17 -0500 Subject: [PATCH 195/222] fix: problems caused by resolution of merge conflicts --- vllm/entrypoints/chat_utils.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 97195354dca2f..8e4c2a08b7e61 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -231,7 +231,7 @@ def _parse_chat_message_content( role = cast(str, message["role"]) # can be iterable in some cases, so cast content = message.get("content") tool_call_id = message.get("tool_call_id") - tool_calls = message.get("tool_calls", []) + tool_calls = message.get("tool_calls", None) # no longer used by OpenAI, but some models still use it for tool calls. name = message.get("name", "") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3e1d61b1628c5..9d029bf9ec2bc 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -120,6 +120,7 @@ async def create_chat_completion( tool.model_dump() for tool in request.tools ] + print('CONVERSATION\n', conversation) prompt = apply_chat_template( tokenizer, conversation=conversation, From e4222a5227e2fd1765658cce36496e24faf56b34 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:05:12 -0500 Subject: [PATCH 196/222] fix(tests): set max model len --- tests/tool_use/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tool_use/util.py b/tests/tool_use/util.py index 06d1e305a373d..8ec9b05b2c521 100644 --- a/tests/tool_use/util.py +++ b/tests/tool_use/util.py @@ -14,7 +14,7 @@ class ServerConfig(TypedDict): # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = ["--enable-auto-tool-choice"] +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] CONFIGS: Dict[str, ServerConfig] = { "hermes": { From 8caf6f85bd6016bf89b986e2192032a50251dc08 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:15:28 -0500 Subject: [PATCH 197/222] fix: hermes system prompt in chat template was missing <|im_start|>system\n --- examples/tool_chat_template_hermes.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index 35247917998a9..6947444f135a8 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -32,7 +32,7 @@ {{- bos_token }} -{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} {%- if tools is iterable and tools | length > 0 %} {%- for tool in tools %} {%- if tool.function is defined %} From ebdcef940def49ab3e9078e98869d7e59f48b22c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:18:14 -0500 Subject: [PATCH 198/222] fix(tests): double RemoteOpenAIServer start timeout suggested by @mgoin --- tests/tool_use/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index e494b9addaa64..c940624924f82 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -21,7 +21,7 @@ def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=240) as server: + max_wait_seconds=480) as server: yield server From bdd01bbe2c743109900362babebcbf0bdc6039eb Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:18:47 -0500 Subject: [PATCH 199/222] fix: remove cruft --- vllm/entrypoints/openai/serving_chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9d029bf9ec2bc..3e1d61b1628c5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -120,7 +120,6 @@ async def create_chat_completion( tool.model_dump() for tool in request.tools ] - print('CONVERSATION\n', conversation) prompt = apply_chat_template( tokenizer, conversation=conversation, From 8f49d9e42a7fc5292e42ce8d2df9cac5f3fb4d1c Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:19:35 -0500 Subject: [PATCH 200/222] fix: spacing in doc strings --- vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index f0cd608f8622a..de76bfca9132a 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -54,8 +54,8 @@ def extract_tool_calls_streaming( Instance method that should be implemented for extracting tool calls from an incomplete response; for use when handling tool calls and streaming. Has to be an instance method because it requires state - - the current text/ tokens/diffs, but also the information about what has - previously been parsed and extracted (see constructor) + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) """ raise NotImplementedError( "AbstractToolParser.extract_tool_calls_streaming has not been " From 11c751dcc76de7beff414cf3e89e7b157228250b Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:21:41 -0500 Subject: [PATCH 201/222] fix: make enable_auto_tools a boolean not optional --- vllm/entrypoints/openai/serving_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3e1d61b1628c5..8405239093157 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -53,7 +53,7 @@ def __init__(self, request_logger: Optional[RequestLogger], chat_template: Optional[str], return_tokens_as_token_ids: bool = False, - enable_auto_tools: Optional[bool] = False, + enable_auto_tools: bool = False, tool_parser: Optional[str] = None): super().__init__(async_engine_client=async_engine_client, model_config=model_config, @@ -68,7 +68,7 @@ def __init__(self, self.chat_template = load_chat_template(chat_template) # set up tool use - self.enable_auto_tools: bool = enable_auto_tools or False + self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( "\"auto\" tool choice has been enabled please note that while" From dc5db10049f7b14b89070b83f6c8fbbc992ef1c7 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 17:22:12 -0500 Subject: [PATCH 202/222] fix: whitespace --- vllm/entrypoints/openai/serving_chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8405239093157..2fb7cc7cf03e4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -163,7 +163,6 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: - guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) From 13a4fb1fa4e5d6d8d0cc346130142e1bfe545552 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 19:41:41 -0500 Subject: [PATCH 203/222] fix: formatting --- vllm/entrypoints/openai/serving_chat.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2fb7cc7cf03e4..034514b1764eb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,9 +1,10 @@ import asyncio import json import time -from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) from typing import Sequence as GenericSequence -from typing import Type, Union +from typing import Union from fastapi import Request @@ -75,7 +76,7 @@ def __init__(self, " the parallel_tool_calls client option is preset for " "compatibility reasons, it will be ignored.") - self.tool_parser: Optional[Type[ToolParser]] = None + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: if tool_parser == "mistral": self.tool_parser = MistralToolParser From 30238f2bb8bf595c9a1178b464bda8af1bc90564 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 19:42:39 -0500 Subject: [PATCH 204/222] delete: old file --- vllm/entrypoints/openai/tool_parsers.py | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 vllm/entrypoints/openai/tool_parsers.py diff --git a/vllm/entrypoints/openai/tool_parsers.py b/vllm/entrypoints/openai/tool_parsers.py deleted file mode 100644 index b86b92ce9c057..0000000000000 --- a/vllm/entrypoints/openai/tool_parsers.py +++ /dev/null @@ -1,3 +0,0 @@ -from vllm.logger import init_logger - -logger = init_logger(__name__) From e548b2d59ffdca3a0532025018841a5a43d8b35f Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 19:48:24 -0500 Subject: [PATCH 205/222] refactor: tool parsers to use AnyTokenizer --- .../openai/tool_parsers/abstract_tool_parser.py | 8 ++------ .../openai/tool_parsers/hermes_tool_parser.py | 11 +++-------- .../openai/tool_parsers/mistral_tool_parser.py | 11 +++-------- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index de76bfca9132a..b0807e6f1e782 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,11 +1,9 @@ from typing import Dict, List, Sequence, Union -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) - from vllm.entrypoints.openai.protocol import (DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -17,9 +15,7 @@ class ToolParser: derived classes. """ - def __init__(self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - PreTrainedTokenizerFast, AutoTokenizer]): + def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_call_arr: List[Dict] = [] # the index of the tool call that is currently being parsed self.current_tool_id: int = -1 diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 03ee11edb60ee..31551c6c9c92b 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -1,11 +1,9 @@ import json import re -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Sequence, Union import partial_json_parser from partial_json_parser.core.options import Allow -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -17,17 +15,14 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) class Hermes2ProToolParser(ToolParser): - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index e7d87833a949b..c6d45235a9722 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,11 +1,9 @@ import json import re -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Sequence, Union import partial_json_parser from partial_json_parser.core.options import Allow -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, DeltaToolCall, @@ -17,6 +15,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -29,11 +28,7 @@ class MistralToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser gmistral are all set """ - def __init__(self, - tokenizer: Optional[Union[PreTrainedTokenizer, - PreTrainedTokenizerFast, - PreTrainedTokenizerFast, - AutoTokenizer]] = None): + def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) # initialize properties used for state when parsing tool calls in From ad9a8ff11282e20a34f60d6537cd5a7c0d0041d3 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 19:49:17 -0500 Subject: [PATCH 206/222] refactor: util.py -> utils.py --- tests/tool_use/{util.py => utils.py} | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) rename tests/tool_use/{util.py => utils.py} (98%) diff --git a/tests/tool_use/util.py b/tests/tool_use/utils.py similarity index 98% rename from tests/tool_use/util.py rename to tests/tool_use/utils.py index 8ec9b05b2c521..ee5a83e5ca670 100644 --- a/tests/tool_use/util.py +++ b/tests/tool_use/utils.py @@ -14,7 +14,9 @@ class ServerConfig(TypedDict): # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] +ARGS: List[str] = [ + "--enable-auto-tool-choice", "--max-model-len", "8096", "--dtype", "half" +] CONFIGS: Dict[str, ServerConfig] = { "hermes": { From 79ab9292416a69a84d588c236f327997edcf0c4e Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 19:53:33 -0500 Subject: [PATCH 207/222] refactor: utils --- tests/tool_use/conftest.py | 2 +- tests/tool_use/test_chat_completions.py | 2 +- tests/tool_use/test_parallel_tool_calls.py | 6 +++--- tests/tool_use/test_tool_calls.py | 4 ++-- tests/tool_use/utils.py | 4 +--- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index c940624924f82..9ef90d5063e78 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -3,7 +3,7 @@ from tests.utils import RemoteOpenAIServer -from .util import ARGS, CONFIGS, ServerConfig +from .utils import ARGS, CONFIGS, ServerConfig # for each server config, download the model and return the config diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index 8484dde1a75fd..038ff81d2b674 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -3,7 +3,7 @@ import openai import pytest -from .util import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL +from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL # test: make sure chat completions without tools provided work even when tools diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index f4d1cba457452..b03b5a2075a6c 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -4,9 +4,9 @@ import openai import pytest -from .util import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, - WEATHER_TOOL) +from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) # test: getting the model to generate parallel tool calls (streaming/not) diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index d182f32bd8966..c3abe9e1f5060 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -4,8 +4,8 @@ import openai import pytest -from .util import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, - SEARCH_TOOL, WEATHER_TOOL) +from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) # test: request a chat completion that should return tool calls, so we know they diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index ee5a83e5ca670..8ec9b05b2c521 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -14,9 +14,7 @@ class ServerConfig(TypedDict): # universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. -ARGS: List[str] = [ - "--enable-auto-tool-choice", "--max-model-len", "8096", "--dtype", "half" -] +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] CONFIGS: Dict[str, ServerConfig] = { "hermes": { From bba73949582db324132cdb8cbdc2747d191752ad Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 20:05:05 -0500 Subject: [PATCH 208/222] fix: cruft --- vllm/entrypoints/openai/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d1c7fea12bf32..36791c90bda79 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -272,7 +272,6 @@ async def create_chat_completion(request: ChatCompletionRequest, generator = await openai_serving_chat.create_chat_completion( request, raw_request) - # if there's an error, return it if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) From e11536f44d01a38396670461a67c886a1b8831d2 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 20:06:17 -0500 Subject: [PATCH 209/222] fix: unnecessary union in type --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6bfa689e6bcd4..d18ba221a7da0 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -160,7 +160,7 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Union[Literal["none"], Literal["auto"]], + tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" # NOTE this will be ignored by VLLM -- the model determines the behavior From bea0f5613fae77e90efaf5fb1ee38ab4dc024161 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Thu, 22 Aug 2024 20:07:22 -0500 Subject: [PATCH 210/222] fix: type --- vllm/entrypoints/openai/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d18ba221a7da0..3fc62ef7e055f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -751,7 +751,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None # per OpenAI spec this is the default - finish_reason: Optional[str] = Field(default="stop") + finish_reason: Optional[str] = "stop" # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None From d5b169b406d806e46a7bfa0d840c1f3028e06436 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 23 Aug 2024 10:36:24 -0500 Subject: [PATCH 211/222] fix: yapf might not need to be disabled --- vllm/entrypoints/openai/protocol.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3fc62ef7e055f..edd76fd2d0bf7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,8 +5,6 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch -# yapf conflicts with isort for this block -# yapf: disable from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated, Required, TypedDict @@ -19,8 +17,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid -# yapf: enable - # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) From f73e0891897f33af3c09ef402f998c3b2e0f7e07 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 25 Aug 2024 01:32:25 -0500 Subject: [PATCH 212/222] fix: exception in hermes chat template thats unnecessary cc @interstellarninja --- examples/tool_chat_template_hermes.jinja | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja index 6947444f135a8..b18b463032d4f 100644 --- a/examples/tool_chat_template_hermes.jinja +++ b/examples/tool_chat_template_hermes.jinja @@ -107,9 +107,6 @@ {%- endfor %} {{- '<|im_end|>\n' }} {%- elif message.role == "tool" %} - {%- if not message.name is defined %} - {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }} - {%- endif %} {%- if loop.previtem and loop.previtem.role != "tool" %} {{- '<|im_start|>tool\n' }} {%- endif %} From 477003c5e4f3b6256304115e02aa6f16946e19d7 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sun, 25 Aug 2024 01:37:15 -0500 Subject: [PATCH 213/222] fix: need to check if tool calls is present regardless of "content" type; perform narrowing --- vllm/entrypoints/chat_utils.py | 129 +++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 53 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 8e4c2a08b7e61..40195d264221b 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,21 +1,25 @@ import codecs from dataclasses import dataclass -from functools import lru_cache +from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, Union, cast) # yapf conflicts with isort for this block # yapf: disable -from openai.types.chat import ChatCompletionContentPartImageParam +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict, TypeAdapter +from pydantic import ConfigDict from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -51,7 +55,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam, ] + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -68,8 +73,12 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): Provides the model information to differentiate between participants of the same role. """ + tool_call_id: Optional[str] - tool_calls: Optional[List] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, @@ -78,11 +87,20 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): # TODO: Make fields ReadOnly once mypy supports it class ConversationMessage(TypedDict, total=False): - role: str + role: Required[str] + """The role of the message's author.""" + content: Optional[str] + """The contents of the message""" + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + name: Optional[str] - tool_calls: Optional[List] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" @dataclass(frozen=True) @@ -155,9 +173,11 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, return f"{placeholder_token_str}\n{text_prompt}" -_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) -_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) -_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) def _parse_chat_message_content_parts( @@ -173,7 +193,7 @@ def _parse_chat_message_content_parts( for part in parts: part_type = part["type"] if part_type == "text": - text = _TextParser.validate_python(part)["text"] + text = _TextParser(part)["text"] texts.append(text) elif part_type == "image_url": modality = "image" @@ -181,7 +201,7 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - image_url = _ImageParser.validate_python(part)["image_url"] + image_url = _ImageParser(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( @@ -196,9 +216,12 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - audio_url = _AudioParser.validate_python(part)["audio_url"] + audio_url = _AudioParser(part)["audio_url"] audio_future = async_get_and_parse_audio(audio_url["url"]) mm_futures.append(audio_future) + elif part_type == "refusal": + text = _RefusalParser(part)["refusal"] + texts.append(text) else: raise NotImplementedError(f"Unknown part type: {part_type}") @@ -223,53 +246,53 @@ def _parse_chat_message_content_parts( return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + def _parse_chat_message_content( message: ChatCompletionMessageParam, model_config: ModelConfig, tokenizer: AnyTokenizer, ) -> ChatMessageParseResult: - role = cast(str, message["role"]) # can be iterable in some cases, so cast + role = message["role"] content = message.get("content") - tool_call_id = message.get("tool_call_id") - tool_calls = message.get("tool_calls", None) - # no longer used by OpenAI, but some models still use it for tool calls. - name = message.get("name", "") - - # empty case - if content is None and tool_calls is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) - - # special case - assistant message where tool calls are provided. - if role == "assistant" and tool_calls is not None: - messages = [ - ConversationMessage(role=role, - content=cast(Optional[str], content), - tool_calls=list(tool_calls)) - ] - return ChatMessageParseResult(messages=messages, mm_futures=[]) - - # special case - tool call result message - elif role == "tool": - messages = [ - ConversationMessage(role=role, - name=name, - content=cast(Union[str, None], content), - tool_call_id=cast(Union[str, None], - tool_call_id)) - ] - return ChatMessageParseResult(messages=messages, mm_futures=[]) - - # other cases - normal assistant response, user message or system message - elif isinstance(content, str): + tool_calls = message.get("tool_calls") + + if content is None and not tool_calls: + result = ChatMessageParseResult(messages=[], mm_futures=[]) + # assistant messages, or user messages, or system messages + elif (isinstance(content, str) + or (tool_calls is not None and + (isinstance(content, str) or content is None))): messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) + result = ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts( - role, - content, # type: ignore - model_config, - tokenizer, - ) + else: + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + model_config, + tokenizer, + ) + + for result_msg in result.messages: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + if (function_call := parsed_msg.get("function_call")) is not None: + result_msg["name"] = function_call["name"] + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result def parse_chat_messages( From df85e1272ef123d6ccc85412a3f84e289cd439dd Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 30 Aug 2024 18:17:25 -0500 Subject: [PATCH 214/222] fix: type narrowing & remove unused var --- vllm/entrypoints/chat_utils.py | 36 +++++++++++++++------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 40195d264221b..5d9cf8182a072 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -12,7 +12,8 @@ from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) from openai.types.chat import (ChatCompletionContentPartRefusalParam, - ChatCompletionContentPartTextParam) + ChatCompletionContentPartTextParam, + ) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from openai.types.chat import (ChatCompletionMessageToolCallParam, @@ -258,30 +259,25 @@ def _parse_chat_message_content( ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") - tool_calls = message.get("tool_calls") - if content is None and not tool_calls: - result = ChatMessageParseResult(messages=[], mm_futures=[]) - # assistant messages, or user messages, or system messages - elif (isinstance(content, str) - or (tool_calls is not None and - (isinstance(content, str) or content is None))): - messages = [ConversationMessage(role=role, content=content)] - result = ChatMessageParseResult(messages=messages, mm_futures=[]) - - else: - result = _parse_chat_message_content_parts( - role, - content, # type: ignore - model_config, - tokenizer, - ) + if content is None: + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] + + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + model_config, + tokenizer, + ) for result_msg in result.messages: if role == 'assistant': parsed_msg = _AssistantParser(message) - if (function_call := parsed_msg.get("function_call")) is not None: - result_msg["name"] = function_call["name"] + if "tool_calls" in parsed_msg: result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": From c6d1bf14a482324a8e4378bf88d65e46ebb7a618 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 30 Aug 2024 18:19:31 -0500 Subject: [PATCH 215/222] fix: format --- vllm/entrypoints/chat_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 5d9cf8182a072..ff747771a520c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -12,8 +12,7 @@ from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) from openai.types.chat import (ChatCompletionContentPartRefusalParam, - ChatCompletionContentPartTextParam, - ) + ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) from openai.types.chat import (ChatCompletionMessageToolCallParam, From e70160dfd9d44bc63f21cb354f5314538c57d708 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Fri, 30 Aug 2024 23:47:13 -0500 Subject: [PATCH 216/222] refactor: tool arguments in assistant-role messages with tool_calls should parsed from a string into a dict before being passed into the chat template --- vllm/entrypoints/chat_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index ff747771a520c..1b194fb4bd3b4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,9 +1,10 @@ import codecs +import json from dataclasses import dataclass from functools import lru_cache, partial from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union, cast) +from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Optional, + Tuple, Union, cast) # yapf conflicts with isort for this block # yapf: disable @@ -322,6 +323,20 @@ def apply_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in conversation: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for i in range(len(message["tool_calls"])): + args: str = message["tool_calls"][i]["function"]["arguments"] + parsed_args: Dict = json.loads(args) + message["tool_calls"][i]["function"]["arguments"] = parsed_args + prompt = tokenizer.apply_chat_template( conversation=conversation, chat_template=chat_template, From 6db7f7b703e27702044027caccee2231530255b7 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 31 Aug 2024 01:02:42 -0500 Subject: [PATCH 217/222] fix: hermes tool parser issue --- .../openai/tool_parsers/hermes_tool_parser.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 31551c6c9c92b..a6429dc12aa74 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -132,11 +132,12 @@ def extract_tool_calls_streaming( cur_tool_end_count = current_token_ids.count( self.tool_call_end_token_id) - # case: if we're generating text, NOT tools, return a text delta + # case: if we're generating text, OR rounding out a tool call if (cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count): logger.debug("Generating text content! skipping tool parsing.") - return DeltaMessage(content=delta_text) + if delta_text != self.tool_call_end_token: + return DeltaMessage(content=delta_text) # case: if tool open & close tag counts don't match, we're doing # imaginary "else" block here @@ -185,6 +186,8 @@ def extract_tool_calls_streaming( logger.debug( "Finishing tool and found diff that had not " "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff return DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall( @@ -194,7 +197,10 @@ def extract_tool_calls_streaming( # case -- otherwise we're just generating text else: - delta = DeltaMessage(tool_calls=[], content=delta_text) + print("JUST GENERATING TEXT") + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) return delta try: From cb55c081dd2eacc45c9d04f72fa58d89d0168921 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 31 Aug 2024 16:43:24 -0500 Subject: [PATCH 218/222] fix: mypy stuff in tool parsers --- .../entrypoints/openai/tool_parsers/hermes_tool_parser.py | 8 +++++++- .../openai/tool_parsers/mistral_tool_parser.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index a6429dc12aa74..6aa897f8932d2 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -15,7 +15,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer logger = init_logger(__name__) @@ -24,6 +24,12 @@ class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + self.current_tool_name_sent: bool = False self.prev_tool_call_arr: List[Dict] = [] self.current_tool_id: int = -1 diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c6d45235a9722..d48770c792e98 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -15,7 +15,7 @@ from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer logger = init_logger(__name__) @@ -31,6 +31,12 @@ class MistralToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) + if isinstance(self.model_tokenizer, MistralTokenizer): + self.model_tokenizer = self.model_tokenizer.tokenizer + else: + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + # initialize properties used for state when parsing tool calls in # streaming mode self.prev_tool_call_arr: List[Dict] = [] From a70e8266f057cee4ae0ce6a17476a6c7cb539b3d Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Sat, 31 Aug 2024 16:45:26 -0500 Subject: [PATCH 219/222] fix: remove cruft --- vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 6aa897f8932d2..7afbca7162edf 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -203,7 +203,6 @@ def extract_tool_calls_streaming( # case -- otherwise we're just generating text else: - print("JUST GENERATING TEXT") text = delta_text.replace(self.tool_call_start_token, "") text = text.replace(self.tool_call_end_token, "") delta = DeltaMessage(tool_calls=[], content=text) From 165c02638cc7780478d35a629d8dfcad52b927cc Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Mon, 2 Sep 2024 21:12:02 -0500 Subject: [PATCH 220/222] fix: merge conflict + type issues --- vllm/entrypoints/chat_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 63f9e2ad1153a..1fa66da3fa148 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -280,7 +280,7 @@ def _parse_chat_message_content_parts( placeholder, 0) + 1 elif part_type == "audio_url": - audio_url = _AudioParser.validate_python(part)["audio_url"] + audio_url = _AudioParser(part)["audio_url"] audio_coro = async_get_and_parse_audio(audio_url["url"]) placeholder = mm_tracker.add("audio", audio_coro) if placeholder: From 4972a89b768053b74c7daa6e7ee929b345d6bfd5 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Tue, 3 Sep 2024 23:51:42 -0500 Subject: [PATCH 221/222] fix(tests): update pytest fixture for client based on #7565 --- tests/tool_use/conftest.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index 9ef90d5063e78..ab6a29eba1b3f 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -1,4 +1,5 @@ import pytest +import pytest_asyncio from huggingface_hub import snapshot_download from tests.utils import RemoteOpenAIServer @@ -25,6 +26,7 @@ def server(request, server_config: ServerConfig): yield server -@pytest.fixture(scope="session") -def client(server: RemoteOpenAIServer): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client From d6728b2d4b54f6345862f7cd0158fc3344a20944 Mon Sep 17 00:00:00 2001 From: Kyle Mistele Date: Wed, 4 Sep 2024 06:13:39 -0500 Subject: [PATCH 222/222] fix: docs --- docs/source/serving/openai_compatible_server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index b21096187562a..eb4ea0fb5655e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,7 +157,7 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers will continue to be added in the future. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their