From 1aac0b8c416276e4f0914521aafe63d930202ada Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Wed, 30 Oct 2024 16:12:47 -0700 Subject: [PATCH 01/10] pythonic tool parser Signed-off-by: Mike Depinet --- .../openai/tool_parsers/__init__.py | 0 .../tool_parsers/test_pythonic_tool_parser.py | 160 +++++++++++ .../entrypoints/openai/tool_parsers/utils.py | 113 ++++++++ .../openai/tool_parsers/__init__.py | 3 +- .../tool_parsers/pythonic_tool_parser.py | 262 ++++++++++++++++++ 5 files changed, 537 insertions(+), 1 deletion(-) create mode 100644 tests/entrypoints/openai/tool_parsers/__init__.py create mode 100644 tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py create mode 100644 tests/entrypoints/openai/tool_parsers/utils.py create mode 100644 vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py diff --git a/tests/entrypoints/openai/tool_parsers/__init__.py b/tests/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py new file mode 100644 index 0000000000000..3ce7e9db35a9d --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -0,0 +1,160 @@ +from typing import List +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + +# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1 +SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "San Francisco", "metric": "celsius"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ( + "register_user(name='John Doe', " + "age=37, " + "address={'city': 'San Francisco', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])") +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "John Doe", ' + '"age": 37, ' + '"address": {"city": "San Francisco", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{}', +) +EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')") +ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( + mock_tokenizer) + model_output = "How can I help you today?" + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert content == model_output + assert len(tool_calls) == 0 + + +TEST_CASES = [ + pytest.param(True, + f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], + id="simple_streaming"), + pytest.param(False, + f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming"), + pytest.param(True, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming"), + pytest.param(False, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming"), + pytest.param(True, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming"), + pytest.param(False, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming"), + pytest.param(True, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming"), + pytest.param(False, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming"), + pytest.param(True, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming"), + pytest.param(False, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming"), + pytest.param(True, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming"), + pytest.param(False, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming"), + pytest.param(True, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_streaming"), + pytest.param(False, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_nonstreaming"), +] + + +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", + TEST_CASES) +def test_tool_call(streaming: bool, model_output: str, + expected_tool_calls: List[FunctionCall]): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( + mock_tokenizer) + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert content is None + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( + mock_tokenizer) + model_output_deltas = [ + "[get_weather(city='San", + " Francisco', metric='celsius'), " + f"{PARAMETERLESS_FUNCTION_OUTPUT}, " + f"{EMPTY_LIST_FUNCTION_OUTPUT}]", + ] + + reconstructor = run_tool_extraction_streaming(tool_parser, + model_output_deltas) + + assert reconstructor.other_content == "" + assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000..d282b74f97990 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,113 @@ +from typing import Iterable, List, Tuple, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers import ToolParser + + +class StreamingToolReconstructor: + + def __init__(self): + self.tool_calls: List[ToolCall] = [] + self.other_content: str = "" + + def append_delta(self, delta: DeltaMessage): + if delta.content is not None: + self.other_content += delta.content + else: + assert delta.tool_calls, ( + "Streaming results should have either content or tool calls " + "(or both)") + for call_delta in delta.tool_calls: + assert call_delta.type == "function", ( + "Streaming tool calls should only emit function calls. Got " + f"{call_delta.type}") + current_tool_call = self.tool_calls[ + -1] if self.tool_calls and call_delta.id in { + None, self.tool_calls[-1].id + } else None + if current_tool_call: + assert ( + current_tool_call.function.name == call_delta.function.name + or not call_delta.function.name + ), ("Streaming tool calls should not emit partial function " + f"names. Got {call_delta.function.name}") + assert ( + current_tool_call.id == call_delta.id or not call_delta.id + ), ("Streaming tool calls must not change function ids. Got " + f"{call_delta.id}, expected {current_tool_call.id} or None" + ) + assert (call_delta.index == len(self.tool_calls) - 1), ( + f"Incorrect index for tool delta. Got {call_delta.index}, " + f"expected {len(self.tool_calls) - 1}") + current_tool_call.function.arguments += ( + call_delta.function.arguments) + else: + assert call_delta.id is not None, ( + "Streaming tool calls must have an id on first appearance") + assert call_delta.function.name is not None, ( + "Streaming tool calls must have a function name on first " + "appearance") + assert call_delta.index == len(self.tool_calls), ( + f"Incorrect index for tool delta. Got {call_delta.index}, " + f"expected {len(self.tool_calls)}") + self.tool_calls.append( + ToolCall(id=call_delta.id, + function=FunctionCall( + name=call_delta.function.name, + arguments=call_delta.function.arguments + or ""))) + + +def run_tool_extraction( + tool_parser: ToolParser, + model_output: str, + request: Union[ChatCompletionRequest, None] = None, + streaming: bool = False) -> Tuple[Union[str, None], List[ToolCall]]: + if streaming: + reconstructor = run_tool_extraction_streaming(tool_parser, + model_output, request) + return reconstructor.other_content or None, reconstructor.tool_calls + else: + extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, + request) + assert extracted.tools_called == bool(extracted.tool_calls) + return extracted.content, extracted.tool_calls + + +def run_tool_extraction_nonstreaming( + tool_parser: ToolParser, + model_output: str, + request: Union[ChatCompletionRequest, None] = None +) -> ExtractedToolCallInformation: + request = request or ChatCompletionRequest(messages=[], model="test-model") + return tool_parser.extract_tool_calls(model_output, request) + + +def run_tool_extraction_streaming( + tool_parser: ToolParser, + model_deltas: Iterable[str], + request: Union[ChatCompletionRequest, None] = None +) -> StreamingToolReconstructor: + request = request or ChatCompletionRequest(messages=[], model="test-model") + reconstructor = StreamingToolReconstructor() + previous_text = "" + previous_tokens: List[int] = [] + for delta in model_deltas: + token_delta = [ + tool_parser.vocab.get(token) + for token in tool_parser.model_tokenizer.tokenize(delta) + if token in tool_parser.vocab + ] + current_text = previous_text + delta + current_tokens = previous_tokens + token_delta + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text, current_text, delta, previous_tokens, + current_tokens, token_delta, request) + if delta_message is not None: + reconstructor.append_delta(delta_message) + previous_text = current_text + previous_tokens = current_tokens + return reconstructor diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 1b299ce655570..7a3f5a9081f46 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -5,9 +5,10 @@ from .jamba_tool_parser import JambaToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser +from .pythonic_tool_parser import PythonicToolParser __all__ = [ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", - "Llama3JsonToolParser", "JambaToolParser" + "Llama3JsonToolParser", "JambaToolParser", "PythonicToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py new file mode 100644 index 0000000000000..8ce527c2fdebd --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -0,0 +1,262 @@ +import ast +import json +import re +from collections import defaultdict +from typing import Any, Dict, Sequence, Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("pythonic") +class PythonicToolParser(ToolParser): + """ + Tool call parser for models that produce tool calls in a pythonic style, + such as Llama 3.2 models. + + Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set + """ + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self._completed_tools: int = 0 + self._current_tool_id: Union[str, None] = None + self._sent_args_by_tool_id: Dict[str, str] = defaultdict(str) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("["): + return DeltaMessage(content=delta_text) + + try: + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self._completed_tools: + continue + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if self._current_tool_id is not None: + new_call.id = self._current_tool_id + if new_call_complete: + self._current_tool_id = None + self._completed_tools += 1 + else: + self._current_tool_id = new_call.id + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta( + self._sent_args_by_tool_id[new_call.id], new_call, index, + withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self._sent_args_by_tool_id[ + delta.id] += delta.function.arguments + + return DeltaMessage( + tool_calls=tool_deltas) if tool_deltas else None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + + if bracket_stack and bracket_stack[-1] == "(" and not text.endswith("("): + return None # Incomplete parameter name + if bracket_stack and ( + (bracket_stack[-1] == "{" and not text.endswith("{")) or + (len(bracket_stack) > 1 and bracket_stack[-1] in {"'", '"'} + and bracket_stack[-2] == "{")): + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith("["): + return None # Incomplete function name + + if text.endswith(","): + text = text[:-1] + elif text.endswith("="): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall(id=new_call.id, + index=index, + function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None From 683eb27db1f74fd7eea9250c5758ddeb7c6e7199 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Fri, 1 Nov 2024 16:07:27 -0700 Subject: [PATCH 02/10] Add an entry to openai_compatible_server.md Signed-off-by: Mike Depinet --- .../serving/openai_compatible_server.md | 69 ++++++++++++------- 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a1f93a9a28578..53e7799937029 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -103,7 +103,7 @@ vllm serve --chat-template ./path-to-chat-template.jinja vLLM community provides a set of chat templates for popular models. You can find them in the examples directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) -With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies +With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies both a `type` and a `text` field. An example is provided below: ```python completion = client.chat.completions.create( @@ -113,10 +113,10 @@ completion = client.chat.completions.create( ] ) ``` -Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like +Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like `meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify -between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match +between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match this, unless explicitly specified. @@ -129,18 +129,18 @@ this, unless explicitly specified. ``` ## Tool Calling in the Chat Completion API ### Named Function Calling -vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is -enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a -high-quality one. +vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. -To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and -specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. ### Config file The `serve` module can also accept arguments from a config file in -`yaml` format. The arguments in the yaml must be specified using the -long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server): +`yaml` format. The arguments in the yaml must be specified using the +long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server): For example: @@ -156,7 +156,7 @@ uvicorn-log-level: "info" $ vllm serve SOME_MODEL --config config.yaml ``` --- -**NOTE** +**NOTE** In case an argument is supplied simultaneously using command line and the config file, the value from the commandline will take precedence. The order of priorities is `command line > config file values > defaults`. @@ -172,18 +172,18 @@ vLLM will use guided decoding to ensure the response matches the tool parameter ### Automatic Function Calling To enable this feature, you should set the following flags: -* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers +* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`. * `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`. -* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages -that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their -`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) -If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! #### Hermes Models (`hermes`) @@ -194,8 +194,8 @@ All Nous Research Hermes-series models newer than Hermes 2 Pro should be support * `NousResearch/Hermes-3-*` -_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge -step in their creation_. +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. Flags: `--tool-call-parser hermes` @@ -207,9 +207,9 @@ Supported models: * Additional mistral function-calling models are compatible as well. Known issues: -1. Mistral 7B struggles to generate parallel tool calls correctly. -2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is -much shorter than what vLLM generates. Since an exception is thrown when this condition +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. Since an exception is thrown when this condition is not met, the following additional chat templates are provided: * `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that @@ -229,11 +229,11 @@ Supported models: * `meta-llama/Meta-Llama-3.1-405B-Instruct` * `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8` -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) in Llama-3.2 models, see the `pythonic` tool parser below. Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: -1. Parallel tool calls are not supported. +1. Parallel tool calls are not supported. 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array. @@ -274,6 +274,27 @@ Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_tem The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. +#### Models with Pythonic Tool Calls (`pythonic`) + +A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. + +As a concrete example, these models may look up the weather in San Francisco and Seattle by generating: +```python +[get_weather(city='San Francisco', metric='celsius'), get_weather(city='Seattle', metric='celsius')] +``` + +Limitations: +* The model must not generate both text and tool calls in the same generation. This may not be hard to change for a specific model, but the community currently lacks consensus on which tokens to emit when starting and ending tool calls. (In particular, the Llama 3.2 models emit no such tokens.) +* Llama's smaller models struggle to use tools effectively. + +Example supported models: +* `meta-llama/Llama-3.2-1B-Instruct` +* `meta-llama/Llama-3.2-3B-Instruct` +* `Team-ACE/ToolACE-8B` + +Flags: `--tool-call-parser pythonic` + + ### How to write a tool parser plugin A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. From b64ca59364db58cbc90d2f3aa61ebe9f32b56fea Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Mon, 4 Nov 2024 09:42:51 -0800 Subject: [PATCH 03/10] Integration test attempt (can't run on my poor 4070) Signed-off-by: Mike Depinet --- ...tool_chat_template_llama3.2_pythonic.jinja | 97 +++++++++++++++++++ tests/tool_use/utils.py | 13 ++- 2 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 examples/tool_chat_template_llama3.2_pythonic.jinja diff --git a/examples/tool_chat_template_llama3.2_pythonic.jinja b/examples/tool_chat_template_llama3.2_pythonic.jinja new file mode 100644 index 0000000000000..791b9ff6a3479 --- /dev/null +++ b/examples/tool_chat_template_llama3.2_pythonic.jinja @@ -0,0 +1,97 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call functions, please respond with a python list of the calls. " }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a python list for function calls " }} + {{- "with their proper arguments to best answer the given prompt.\n\n" }} + {{- 'Respond in the format [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param.name + '=' + param.value -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ']<|eot_id|>' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index d9ee0b1d54b0a..a21b4e39fed5e 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -111,7 +111,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], ], "supports_parallel": False, - } + }, + "pythonic": { + "model": + "meta-llama/Llama-3.2-3B-Instruct", + "arguments": [ + "--tool-call-parser", "pythonic", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_llama3.2_pythonic.jinja") + ], + "supports_parallel": + True, + }, } WEATHER_TOOL: ChatCompletionToolParam = { From 9553df6cf249f11cd7461d49e8c4a97b78d88fa6 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Wed, 6 Nov 2024 20:37:14 +0000 Subject: [PATCH 04/10] checkpoint: fix most tool_use tests Signed-off-by: Mike Depinet --- ...tool_chat_template_llama3.2_pythonic.jinja | 3 +- .../tool_parsers/test_pythonic_tool_parser.py | 4 +- .../entrypoints/openai/tool_parsers/utils.py | 54 ++++++++++-------- .../tool_parsers/pythonic_tool_parser.py | 56 +++++++++++-------- 4 files changed, 68 insertions(+), 49 deletions(-) diff --git a/examples/tool_chat_template_llama3.2_pythonic.jinja b/examples/tool_chat_template_llama3.2_pythonic.jinja index 791b9ff6a3479..8c38de6c6a907 100644 --- a/examples/tool_chat_template_llama3.2_pythonic.jinja +++ b/examples/tool_chat_template_llama3.2_pythonic.jinja @@ -75,7 +75,8 @@ {%- endif %} {{- tool_call.name + '(' -}} {%- for param in tool_call.arguments %} - {{- param.name + '=' + param.value -}} + {{- param + '=' -}} + {{- "%sr" | format(tool_call.arguments[param]) -}} {% if not loop.last %}, {% endif %} {%- endfor %} {{- ')' -}} diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index 3ce7e9db35a9d..47b0b6bb80ffe 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -150,8 +150,8 @@ def test_streaming_tool_call_with_large_steps(): f"{EMPTY_LIST_FUNCTION_OUTPUT}]", ] - reconstructor = run_tool_extraction_streaming(tool_parser, - model_output_deltas) + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index d282b74f97990..f0a2a32c16786 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -9,9 +9,10 @@ class StreamingToolReconstructor: - def __init__(self): + def __init__(self, assert_one_tool_per_delta: bool = True): self.tool_calls: List[ToolCall] = [] self.other_content: str = "" + self._assert_one_tool_per_delta = assert_one_tool_per_delta def append_delta(self, delta: DeltaMessage): if delta.content is not None: @@ -20,25 +21,27 @@ def append_delta(self, delta: DeltaMessage): assert delta.tool_calls, ( "Streaming results should have either content or tool calls " "(or both)") + if self._assert_one_tool_per_delta: + # Note: This isn't strictly required by the API and may not be + # possible to adhere to depending on the token space and number of + # tokens per streamed response from the model, but it is required + # by tool_use tests, so we enforce it here by default also. + assert len(delta.tool_calls) < 2, ( + "Streaming should include only one tool call per update.") for call_delta in delta.tool_calls: assert call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " f"{call_delta.type}") current_tool_call = self.tool_calls[ - -1] if self.tool_calls and call_delta.id in { - None, self.tool_calls[-1].id - } else None + call_delta.index] if call_delta.index < len( + self.tool_calls) else None if current_tool_call: - assert ( - current_tool_call.function.name == call_delta.function.name - or not call_delta.function.name - ), ("Streaming tool calls should not emit partial function " - f"names. Got {call_delta.function.name}") - assert ( - current_tool_call.id == call_delta.id or not call_delta.id - ), ("Streaming tool calls must not change function ids. Got " - f"{call_delta.id}, expected {current_tool_call.id} or None" - ) + assert (not call_delta.function.name), ( + "Streaming tool calls should emit the full function name " + f"exactly once. Got {call_delta.function.name}") + assert (not call_delta.id), ( + "Streaming tool calls must emit function id only once. Got " + f"{call_delta.id}") assert (call_delta.index == len(self.tool_calls) - 1), ( f"Incorrect index for tool delta. Got {call_delta.index}, " f"expected {len(self.tool_calls) - 1}") @@ -62,13 +65,18 @@ def append_delta(self, delta: DeltaMessage): def run_tool_extraction( - tool_parser: ToolParser, - model_output: str, - request: Union[ChatCompletionRequest, None] = None, - streaming: bool = False) -> Tuple[Union[str, None], List[ToolCall]]: + tool_parser: ToolParser, + model_output: str, + request: Union[ChatCompletionRequest, None] = None, + streaming: bool = False, + assert_one_tool_per_delta: bool = True, +) -> Tuple[Union[str, None], List[ToolCall]]: if streaming: - reconstructor = run_tool_extraction_streaming(tool_parser, - model_output, request) + reconstructor = run_tool_extraction_streaming( + tool_parser, + model_output, + request, + assert_one_tool_per_delta=assert_one_tool_per_delta) return reconstructor.other_content or None, reconstructor.tool_calls else: extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, @@ -89,10 +97,12 @@ def run_tool_extraction_nonstreaming( def run_tool_extraction_streaming( tool_parser: ToolParser, model_deltas: Iterable[str], - request: Union[ChatCompletionRequest, None] = None + request: Union[ChatCompletionRequest, None] = None, + assert_one_tool_per_delta: bool = True, ) -> StreamingToolReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") - reconstructor = StreamingToolReconstructor() + reconstructor = StreamingToolReconstructor( + assert_one_tool_per_delta=assert_one_tool_per_delta) previous_text = "" previous_tokens: List[int] = [] for delta in model_deltas: diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 8ce527c2fdebd..54e089cb5c22f 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -1,8 +1,7 @@ import ast import json import re -from collections import defaultdict -from typing import Any, Dict, Sequence, Tuple, Union +from typing import Any, Sequence, Tuple, Union from transformers import PreTrainedTokenizerBase @@ -37,9 +36,15 @@ class PythonicToolParser(ToolParser): def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) - self._completed_tools: int = 0 - self._current_tool_id: Union[str, None] = None - self._sent_args_by_tool_id: Dict[str, str] = defaultdict(str) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value def extract_tool_calls( self, model_output: str, @@ -108,17 +113,17 @@ def extract_tool_calls_streaming( tool_deltas = [] for index, new_call in enumerate(tool_calls): - if index < self._completed_tools: + if index < self.current_tool_index: continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + new_call_complete = index < len( tool_calls) - 1 or ")]" not in added_text - if self._current_tool_id is not None: - new_call.id = self._current_tool_id if new_call_complete: - self._current_tool_id = None - self._completed_tools += 1 - else: - self._current_tool_id = new_call.id + self.current_tool_index += 1 withheld_suffix = (added_text[:-2] if not new_call_complete else "") @@ -128,16 +133,15 @@ def extract_tool_calls_streaming( # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta( - self._sent_args_by_tool_id[new_call.id], new_call, index, - withheld_suffix) + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) if delta is not None: tool_deltas.append(delta) if (delta.function is not None and delta.function.arguments is not None): - self._sent_args_by_tool_id[ - delta.id] += delta.function.arguments + self.streamed_args_for_tool[ + index] += delta.function.arguments return DeltaMessage( tool_calls=tool_deltas) if tool_deltas else None @@ -207,14 +211,19 @@ def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: text = text.rstrip() - if bracket_stack and bracket_stack[-1] == "(" and not text.endswith("("): - return None # Incomplete parameter name + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name if bracket_stack and ( (bracket_stack[-1] == "{" and not text.endswith("{")) or (len(bracket_stack) > 1 and bracket_stack[-1] in {"'", '"'} and bracket_stack[-2] == "{")): return None # Incomplete property name within parameter value - if bracket_stack and bracket_stack[-1] == "[" and not text.endswith("["): + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): return None # Incomplete function name if text.endswith(","): @@ -256,7 +265,6 @@ def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, )) arg_diff = new_call_args[len(previously_sent_args):] - return DeltaToolCall(id=new_call.id, - index=index, - function=DeltaFunctionCall( - arguments=arg_diff)) if arg_diff else None + return DeltaToolCall( + id="", index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None From 4db5b39d950cfc13ee1aff4ea2a9f1d24b065d1b Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Thu, 7 Nov 2024 22:04:50 +0000 Subject: [PATCH 05/10] Get remaining tool_use tests passing Signed-off-by: Mike Depinet --- .../tool_parsers/pythonic_tool_parser.py | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 54e089cb5c22f..2d9f72b20e014 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -143,8 +143,22 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ index] += delta.function.arguments - return DeltaMessage( - tool_calls=tool_deltas) if tool_deltas else None + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( @@ -210,29 +224,28 @@ def _make_valid_python(text: str) -> Union[Tuple[str, str], None]: bracket_stack.append(char) text = text.rstrip() - + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value if bracket_stack and bracket_stack[-1] == "(": trailing_params_text = text[:text.rfind("(")] num_full_param_names = trailing_params_text.count("=") num_full_param_values = trailing_params_text.count(",") if num_full_param_names <= num_full_param_values: return None # Incomplete parameter name - if bracket_stack and ( - (bracket_stack[-1] == "{" and not text.endswith("{")) or - (len(bracket_stack) > 1 and bracket_stack[-1] in {"'", '"'} - and bracket_stack[-2] == "{")): - return None # Incomplete property name within parameter value + if text.endswith(","): + text = text[:-1] if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( "[") and not text.endswith(")"): return None # Incomplete function name - if text.endswith(","): - text = text[:-1] - elif text.endswith("="): - # Since we have no type information for this property/parameter value, - # we can't fill in a valid value. - return None - added_text = "" for char in reversed(bracket_stack): if char == "[": From 644f8bef4d433477ec878d0fd112568f73cf1f17 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Thu, 7 Nov 2024 23:34:49 +0000 Subject: [PATCH 06/10] Add ToolACE template and tool_use test entry Signed-off-by: Mike Depinet --- examples/tool_chat_template_toolace.jinja | 65 +++++++++++++++++++++++ tests/tool_use/utils.py | 12 ++++- 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 examples/tool_chat_template_toolace.jinja diff --git a/examples/tool_chat_template_toolace.jinja b/examples/tool_chat_template_toolace.jinja new file mode 100644 index 0000000000000..a9b3b7189dddf --- /dev/null +++ b/examples/tool_chat_template_toolace.jinja @@ -0,0 +1,65 @@ +{{- bos_token }} + +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language." %} +{%- endif %} + +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.\n" }} + {{- "If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out.\n" }} + {{- "You should only return the function call in tools call sections.\n\n" }} + {{- "If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\n" }} + {{- "You SHOULD NOT include any other text in the response.\n" }} + {{- "Here is a list of functions in JSON format that you can invoke.\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- "\n" }} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n[' -}} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- tool_call.name + '(' -}} + {%- for param in tool_call.arguments %} + {{- param + '=' -}} + {{- "%sr" | format(tool_call.arguments[param]) -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ')' -}} + {% if not loop.last %}, {% endif %} + {%- endfor %} + {{- ']<|eot_id|>' -}} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} + +{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index efccc9fe4e731..80f0e14545065 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -123,7 +123,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "supports_parallel": False, }, - "pythonic": { + "llama3.2_pythonic": { "model": "meta-llama/Llama-3.2-3B-Instruct", "arguments": [ @@ -134,6 +134,16 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "supports_parallel": True, }, + "toolACE": { + "model": + "Team-ACE/ToolACE-8B", + "arguments": [ + "--tool-call-parser", "pythonic", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja") + ], + "supports_parallel": + True, + }, } WEATHER_TOOL: ChatCompletionToolParam = { From 29d62aca4f433b06f783f905a7a5ae516cd5d45e Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Fri, 8 Nov 2024 00:22:41 +0000 Subject: [PATCH 07/10] update docs Signed-off-by: Mike Depinet --- docs/source/serving/openai_compatible_server.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 174bce7603554..4d8d29f331b12 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -329,11 +329,12 @@ Limitations: * Llama's smaller models struggle to use tools effectively. Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct` -* `meta-llama/Llama-3.2-3B-Instruct` -* `Team-ACE/ToolACE-8B` +* `meta-llama/Llama-3.2-1B-Instruct` (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `meta-llama/Llama-3.2-3B-Instruct` (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) +* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -Flags: `--tool-call-parser pythonic` +Flags: `--tool-call-parser pythonic --chat-template {see_above}` ### How to write a tool parser plugin From a5795693699a13cab88515f07d7752da02c96b28 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Fri, 8 Nov 2024 17:26:59 +0000 Subject: [PATCH 08/10] Warn about Llama3.2 models, add TODO for future work Signed-off-by: Mike Depinet --- docs/source/serving/openai_compatible_server.md | 6 ++++-- tests/tool_use/utils.py | 11 ----------- vllm/entrypoints/openai/serving_chat.py | 5 +++++ .../openai/tool_parsers/pythonic_tool_parser.py | 6 ++++++ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 4d8d29f331b12..20405cbfbe531 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -329,13 +329,15 @@ Limitations: * Llama's smaller models struggle to use tools effectively. Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct` (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `meta-llama/Llama-3.2-3B-Instruct` (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) +* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) * `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) * `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) Flags: `--tool-call-parser pythonic --chat-template {see_above}` +\* Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. + ### How to write a tool parser plugin diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 80f0e14545065..6818ac44b2478 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -123,17 +123,6 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "supports_parallel": False, }, - "llama3.2_pythonic": { - "model": - "meta-llama/Llama-3.2-3B-Instruct", - "arguments": [ - "--tool-call-parser", "pythonic", "--chat-template", - str(VLLM_PATH / - "examples/tool_chat_template_llama3.2_pythonic.jinja") - ], - "supports_parallel": - True, - }, "toolACE": { "model": "Team-ACE/ToolACE-8B", diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9551b4f2091dd..04e10cbbd6766 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -75,6 +75,11 @@ def __init__(self, try: self.tool_parser = ToolParserManager.get_tool_parser( tool_parser) + if (self.tool_parser.__name__ == "PythonicToolParser" and + model_config.model.startswith("meta-llama/Llama-3.2")): + logger.warning( + "Llama3.2 models may struggle to emit valid pythonic" + " tool calls") except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " f"tool_parser:'{tool_parser}' which has not " diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 2d9f72b20e014..26da4d689fb8b 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -29,6 +29,12 @@ class PythonicToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set """ + # TODO(mdepinet): Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. TOOL_CALL_REGEX = re.compile( r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", From 0eaf0a6550e305662bbb24b922188543d8c64ee9 Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Wed, 13 Nov 2024 11:52:15 -0800 Subject: [PATCH 09/10] PR comments Signed-off-by: Mike Depinet --- docs/source/serving/openai_compatible_server.md | 2 +- vllm/entrypoints/openai/serving_chat.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 20405cbfbe531..5827ecc83491a 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -336,7 +336,7 @@ Example supported models: Flags: `--tool-call-parser pythonic --chat-template {see_above}` -\* Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. +WARNING: Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. ### How to write a tool parser plugin diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 04e10cbbd6766..67573d32454ac 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -73,13 +73,13 @@ def __init__(self, self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: try: - self.tool_parser = ToolParserManager.get_tool_parser( - tool_parser) - if (self.tool_parser.__name__ == "PythonicToolParser" and + if (tool_parser == "pythonic" and model_config.model.startswith("meta-llama/Llama-3.2")): logger.warning( "Llama3.2 models may struggle to emit valid pythonic" " tool calls") + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " f"tool_parser:'{tool_parser}' which has not " From 29a97042a12285b21e72a3079be99f739f3317db Mon Sep 17 00:00:00 2001 From: Mike Depinet Date: Wed, 13 Nov 2024 12:24:39 -0800 Subject: [PATCH 10/10] Alter warning block based on other "note" block in this file Signed-off-by: Mike Depinet --- docs/source/serving/openai_compatible_server.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 5827ecc83491a..22b1d774bac75 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -336,7 +336,11 @@ Example supported models: Flags: `--tool-call-parser pythonic --chat-template {see_above}` -WARNING: Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. +--- +**WARNING** +Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. + +--- ### How to write a tool parser plugin