diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 6687929c0bebe..80037dda20015 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,4 +12,5 @@ torch py-cpuinfo transformers mistral_common >= 1.3.4 -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index e0eba7f09bd65..8bb7067faa97c 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers -will continue to be added in the future. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral`, `llama3_json`, or `internlm`. Additional tool parsers +will continue to be added in the future, and you can also register your own tool parsers via `--tool-parser-plugin`. +* `--tool-parser-plugin` -- **optional** path to a tool parser plugin that registers user-defined tool parsers into vLLM; the parser names registered by the plugin can then be passed to `--tool-call-parser`. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat @@ -218,4 +219,73 @@ it works better with vLLM. Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` +#### InternLM Models +Supported models: +* `internlm/internlm2_5-7b-chat` (confirmed) +* Additional InternLM2.5 function-calling models are compatible as well + +Known issues: +* Although this implementation also supports InternLM2, the tool call results are not stable when tested with the `internlm/internlm2-chat-7b` model. + +Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` + + +### How to write a tool parser plugin + +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. + +Here is a summary of a plugin file: + +```python + +# import the required packages + +# define a tool parser and register it with vLLM. +# the name list in register_module can be used +# in --tool-call-parser. You can define as many +# tool parsers as you want here.
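+
+# for illustration only, the imports used by the built-in parsers look
+# like this (adjust them to whatever your own parser actually needs):
+from typing import Sequence, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage,
+                                              ExtractedToolCallInformation)
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.transformers_utils.tokenizer import AnyTokenizer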
+@ToolParserManager.register_module(["example"]) +class ExampleToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # adjust request. e.g.: set skip special tokens + # to False for tool call output. + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + return request + + # implement the tool call parse for stream call + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + return delta + + # implement the tool parse for non-stream call + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + +``` + +Then you can use this plugin in the command line like this. +``` + --enable-auto-tool-choice \ + --tool-parser-plugin + --tool-call-parser example \ + --chat-template \ +``` diff --git a/examples/tool_chat_template_internlm2_tool.jinja b/examples/tool_chat_template_internlm2_tool.jinja new file mode 100644 index 0000000000000..ac99666e93bc4 --- /dev/null +++ b/examples/tool_chat_template_internlm2_tool.jinja @@ -0,0 +1,60 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{{- bos_token }} +{%- if system_message is defined %} +{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }} +{%- endif %} + +{%- if tools is not none %} + {{- "<|im_start|>system name=<|plugin|>\n[" }} + {%- for tool in tools %} + {{- tool.function|tojson }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "<|im_end|>\n" }} +{%- endif %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message.tool_calls is defined and message.tool_calls is not none %} + {%- set content = message["content"] if message["content"] else "" %} + {{- "<|im_start|>assistant\n" + content }} + {%- for tool_call in message.tool_calls %} + {%- set function=tool_call.function %} + {{- "<|action_start|><|plugin|>\n" }} + {{- '{"name": "' + function.name + '", '}} + {{- '"arguments": ' + function.arguments|tojson + '}' }} + {{- "<|action_end|>" }} + {%- endfor %} + {{- "<|im_end|>\n" }} + {%- elif message["role"] == "assistant" %} + {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }} + {%- else %} + {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at 
end of file diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 1a840f8a51c9f..ce36515a2381c 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -87,6 +87,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally." + }, + "internlm": { + "model": + "internlm/internlm2_5-7b-chat", + "arguments": [ + "--tool-call-parser", "internlm", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_internlm2_tool.jinja"), + "--trust_remote_code" + ], + "supports_parallel": + False, } } @@ -109,7 +121,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "type": "string", "description": - "the two-letter abbreviation for the state " + "must be the two-letter abbreviation for the state " "that the city is in, e.g. 'CA' which would " "mean 'California'" }, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5078a2654eb22..bf367482cd80c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path @@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valid_tool_parsers: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(choose from {{ {','.join(valid_tool_parsers)} }})") + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) temp_socket.bind(("", args.port)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 446769a277f58..f59ba4e30accd 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -12,6 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser @@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Enable auto tool choice for supported models. Use --tool-call-parser" "to specify which parser to use") + valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( "--tool-call-parser", type=str, - choices=["mistral", "hermes", "llama3_json"], + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", default=None, help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format.
Required for --enable-auto-tool-choice.") + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Specify the tool parser plugin used to parse the model-generated tool " + "calls into OpenAI API format; the parser names registered by this " + "plugin can then be used in --tool-call-parser.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 41f131f56b51f..ce529f6f0ff58 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -29,10 +29,7 @@ OpenAIServing, PromptAdapterPath, TextTokensPrompt) -from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, - Llama3JsonToolParser, - MistralToolParser, - ToolParser) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput @@ -82,15 +79,13 @@ def __init__(self, self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: - if tool_parser == "mistral": - self.tool_parser = MistralToolParser - elif tool_parser == "hermes": - self.tool_parser = Hermes2ProToolParser - elif tool_parser == "llama3_json": - self.tool_parser = Llama3JsonToolParser - else: + try: + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + f"a tool parser; '{tool_parser}' has not " + "been registered") from e async def create_chat_completion( self, @@ -187,6 +182,10 @@ async def create_chat_completion( raw_request.state.request_metadata = request_metadata try: + if self.enable_auto_tools and self.tool_parser: + request = self.tool_parser(tokenizer).adjust_request( + request=request) + if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -282,11 +281,11 @@ async def chat_completion_stream_generator( num_choices = 1 if request.n is None else request.n previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices - num_prompt_tokens = 0 - tool_parser: Optional[ToolParser] = self.tool_parser( - tokenizer) if self.tool_parser else None + tool_parsers: List[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) if self.tool_parser else None + ] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -324,7 +323,7 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): - + tool_parser = tool_parsers[i] choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( @@ -399,6 +398,7 @@ async def chat_completion_stream_generator( for output in res.outputs: i = output.index + tool_parser = tool_parsers[i] if finish_reason_sent[i]: continue @@ -446,7 +446,8 @@ async def chat_completion_stream_generator( delta_text=delta_text, previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, - delta_token_ids=output.token_ids)) + delta_token_ids=output.token_ids, + request=request)) # update the previous values for the next iteration previous_texts[i] = current_text @@ -685,7 +686,8 @@ async def chat_completion_full_generator( and self.tool_parser: tool_parser = self.tool_parser(tokenizer) - tool_call_info =
tool_parser.extract_tool_calls(output.text) + tool_call_info = tool_parser.extract_tool_calls( + output.text, request=request) tools_called = tool_call_info.tools_called if tool_call_info.tools_called: message = ChatMessage(role=role, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 0069a2b8044b7..309d9bede489b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,9 +1,10 @@ -from .abstract_tool_parser import ToolParser +from .abstract_tool_parser import ToolParser, ToolParserManager from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser __all__ = [ - "ToolParser", "Hermes2ProToolParser", "MistralToolParser", - "Llama3JsonToolParser" + "ToolParser", "ToolParserManager", "Hermes2ProToolParser", + "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 873f615d43257..7e55532bc7297 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,9 +1,14 @@ -from typing import Dict, List, Sequence, Union +import importlib +import importlib.util +import os +from typing import Callable, Dict, List, Optional, Sequence, Type, Union -from vllm.entrypoints.openai.protocol import (DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import is_list_of logger = init_logger(__name__) @@ -24,8 +29,16 @@ def __init__(self, tokenizer: AnyTokenizer): self.model_tokenizer = tokenizer - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. @@ -44,6 +57,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting tool calls @@ -55,3 +69,86 @@ def extract_tool_calls_streaming( raise NotImplementedError( "AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") + + +class ToolParserManager: + tool_parsers: Dict[str, Type] = {} + + @classmethod + def get_tool_parser(cls, name) -> Type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. 
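+
+        Example (illustrative; assumes the built-in "hermes" parser has been
+        registered, as hermes_tool_parser.py does in this patch):
+
+            parser_cls = ToolParserManager.get_tool_parser("hermes")
+            parser = parser_cls(tokenizer)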
+ """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool parser: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be a subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register a module with the given name or name list. It can be used as a + decorator (with module as None) or as a normal function (with module not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined tool parser from the path of the file that + defines it.
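+
+        Example (illustrative; the plugin path is hypothetical and the
+        "example" name assumes the plugin registered a parser under it):
+
+            ToolParserManager.import_tool_parser("/path/to/my_plugin.py")
+            parser_cls = ToolParserManager.get_tool_parser("example")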
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + spec = importlib.util.spec_from_file_location(module_name, plugin_path) + if spec is None or spec.loader is None: + logger.error("load %s from %s failed.", module_name, plugin_path) + return + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index ad6f536838a88..40f041767190b 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -5,12 +5,13 @@ import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -20,6 +21,7 @@ logger = init_logger(__name__) +@ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -57,8 +59,11 @@ def __init__(self, tokenizer: AnyTokenizer): "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: @@ -114,6 +119,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: logger.debug("delta_text: %s", delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000000000..905ab7db3d04c --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,208 @@ +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the 
start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_arguments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call has already been sent, return an empty delta message + # to make sure the finish_reason will be sent correctly. + if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in internlm2; + # parallel tool calls are not supported + try: + tool_call_arr: Dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_arguments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_arguments(tool_call_arr) + + # no arguments generated yet + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time parameters appear + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters exist, send the incremental diff + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # save the current state for the next iteration and return the + # delta (None is the base case when nothing new can be sent yet) + tool_call_arr["arguments"] = self.get_arguments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index f98dca16674d5..3cf34bc4928a5 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -7,12 +7,13 @@ from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix from vllm.logger import init_logger from vllm.utils import random_uuid @@ -41,6 +42,7 @@ def is_complete_json(input_str): return False +@ToolParserManager.register_module("llama3_json") class Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.1 models intended for use with the @@ -64,8 +66,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): add_special_tokens=False)[0] self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self,
model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. """ @@ -125,6 +128,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: if not (current_text.startswith(self.bot_token) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index b61ad40a697e4..1db30797ac6fc 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -8,12 +8,13 @@ from partial_json_parser.core.options import Allow from pydantic import Field -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -36,6 +37,7 @@ def generate_random_id(): return "".join(choices(ALPHANUMERIC, k=9)) +@ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with the @@ -47,9 +49,7 @@ class MistralToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) - if isinstance(self.model_tokenizer, MistralTokenizer): - self.model_tokenizer = self.model_tokenizer.tokenizer - else: + if not isinstance(self.model_tokenizer, MistralTokenizer): logger.info("Non-Mistral tokenizer detected when using a Mistral " "model...") @@ -61,11 +61,14 @@ def __init__(self, tokenizer: AnyTokenizer): self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" - self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token] self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Requires find-and-replacing single quotes with double quotes for JSON parsing, @@ -119,6 +122,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append