From 11cd1ae6ad6fa7d35060fea35133e08c0a1c227c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 15 Nov 2024 01:42:49 +0100 Subject: [PATCH] [Tool parsing] Improve / correct mistral tool parsing (#10333) --- .../decoder_only/language/test_mistral.py | 93 ++++++++++++++++--- vllm/entrypoints/openai/serving_chat.py | 39 +------- .../tool_parsers/mistral_tool_parser.py | 25 +++-- .../transformers_utils/tokenizers/__init__.py | 4 +- vllm/transformers_utils/tokenizers/mistral.py | 70 +++++++++++++- 5 files changed, 172 insertions(+), 59 deletions(-) diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 6ec4b7e7e3f71..99b5d5694f9f7 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -2,9 +2,13 @@ Run `pytest tests/models/test_mistral.py`. """ +import copy + import pytest from vllm import SamplingParams +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa + MistralToolParser) from ...utils import check_logprobs_close @@ -58,17 +62,69 @@ }, "required": ["city", "state", "unit"] } + }, +}, { + "type": "function", + "function": { + "name": "rewrite", + "description": "Rewrites text", + "parameters": { + "type": "object", + "required": [], + "properties": { + "text": { + "type": "string", + "description": "The input text to rewrite." + } + } + } } }] -MSGS = [{ - "role": - "user", - "content": ("Can you tell me what the temperate" - " will be in Dallas, in fahrenheit?") -}] -EXPECTED_FUNC_CALL = ( - '[{"name": "get_current_weather", "arguments": ' - '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') +MSGS = [ + { + "role": "system", + "content": "You are an assistant." + }, + { + "role": + "user", + "content": + "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa + }, + { + "role": + "assistant", + "content": + "", + "tool_calls": [{ + "id": "bbc5b7ede", + "type": "function", + "function": { + "name": + "rewrite", + "arguments": + '{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa + } + }] + }, + { + "role": "tool", + "content": + "{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa + "tool_call_id": "bbc5b7ede", + "name": "rewrite" + }, + { + "role": "assistant", + "content": "---\n\nMy English needs improving, maybe I make errors" + }, + { + "role": + "user", + "content": ("Can you tell me what the temperate" + " will be in Dallas, in fahrenheit?") + } +] @pytest.mark.parametrize("model", MODELS) @@ -175,8 +231,23 @@ def test_mistral_function_calling( tokenizer_mode="mistral", config_format="mistral", load_format="mistral") as vllm_model: - outputs = vllm_model.model.chat(MSGS, + + msgs = copy.deepcopy(MSGS) + outputs = vllm_model.model.chat(msgs, tools=TOOLS, sampling_params=SAMPLING_PARAMS) - assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL + tokenizer = vllm_model.model.get_tokenizer() + tool_parser = MistralToolParser(tokenizer) + + model_output = outputs[0].outputs[0].text.strip() + assert model_output.startswith(tool_parser.bot_token), model_output + parsed_message = tool_parser.extract_tool_calls(model_output, None) + + assert parsed_message.tools_called + assert parsed_message.tool_calls[0].id == "0UAqFzWsD" + assert parsed_message.tool_calls[ + 0].function.name == "get_current_weather" + assert parsed_message.tool_calls[ + 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa + assert parsed_message.content is None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5178481c737b4..77cae00ae827f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -30,6 +30,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls from vllm.utils import iterate_with_cancellation logger = init_logger(__name__) @@ -127,41 +128,11 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # NOTE: There is currently a bug in pydantic where attributes - # declared as iterables are replaced in in the instances by - # pydantic-core ValidatorIterator instance. In particular, this - # affects tool_calls defined in ChatCompletionAssistantMessageParam - # model: - # see: - # - https://github.com/pydantic/pydantic/issues/9467 - # As a result, tool_calls from assistant messages are never - # deserialized in the request object if the tool_calls iterator is - # not consumed. This affect messages passed to the MistralTokenizer - # since no chat template is applied and therefore the tools_calls - # iterator is not directly consumed. - # Issue is tracked on Pydantic side, with resolution planned for - # v2.11 release. In the meantime, the official workaround is to - # consume the iterator so the tool_calls are correctly deserialized - # in the OpenAI ChatCompletionAssistantMessageParam object - # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 - # Official Pydantic Issues: - # - https://github.com/pydantic/pydantic/issues/9541 - # TODO: remove when pydantic v2.11 is released + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` if isinstance(tokenizer, MistralTokenizer): - for i, message in enumerate(request.messages): - if message.get("role") == 'assistant': - tool_calls_validator = message.get( - "tool_calls", ().__iter__()) - validated_tool_calls = [] - while True: - try: - tool_call = next( - tool_calls_validator) # type: ignore - validated_tool_calls.append(tool_call) - except StopIteration: - break - request.messages[i][ - "tool_calls"] = validated_tool_calls + maybe_serialize_tool_calls(request) if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index f5c0d92f3f9bd..5caac84138e3b 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -62,7 +62,7 @@ def __init__(self, tokenizer: AnyTokenizer): ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) - self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if self.bot_token_id is None: raise RuntimeError( "Mistral Tool Parser could not locate the tool call token in " @@ -84,16 +84,25 @@ def extract_tool_calls( return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) + + # first remove the BOT token + tool_content = model_output.replace(self.bot_token, "").strip() + try: - # use a regex to find the tool call. remove the BOT token - # and make sure to replace single quotes with double quotes - raw_tool_call = self.tool_call_regex.findall( - model_output.replace(self.bot_token, ""))[0] + # we first try to directly load the json as parsing very nested + # jsons is difficult + try: + function_call_arr = json.loads(tool_content) + except json.JSONDecodeError: + # use a regex to find the part corresponding to the tool call. + # NOTE: This use case should not happen if the model is trained + # correctly. It's a easy possible fix so it's included, but + # can be brittle for very complex / highly nested tool calls + raw_tool_call = self.tool_call_regex.findall(tool_content)[0] + function_call_arr = json.loads(raw_tool_call) - # load the JSON, and then use it to build the Function and # Tool Call - function_call_arr = json.loads(raw_tool_call) tool_calls: List[MistralToolCall] = [ MistralToolCall( type="function", @@ -116,7 +125,7 @@ def extract_tool_calls( # return information to just treat the tool call as regular JSON return ExtractedToolCallInformation(tools_called=False, tool_calls=[], - content=model_output) + content=tool_content) def extract_tool_calls_streaming( self, diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 5f437d414e181..e68ad79b296b8 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,3 +1,3 @@ -from .mistral import MistralTokenizer +from .mistral import MistralTokenizer, maybe_serialize_tool_calls -__all__ = ["MistralTokenizer"] +__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 1b273c6b120ea..b1cb9a15b943b 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -7,6 +7,7 @@ import huggingface_hub from huggingface_hub import HfApi, hf_hub_download from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.tokens.tokenizers.base import SpecialTokens # yapf: disable from mistral_common.tokens.tokenizers.mistral import ( MistralTokenizer as PublicMistralTokenizer) @@ -29,6 +30,43 @@ class Encoding: input_ids: List[int] +def maybe_serialize_tool_calls(request: ChatCompletionRequest): + # SEE: https://github.com/vllm-project/vllm/pull/9951 + # Credits go to: @gcalmettes + # NOTE: There is currently a bug in pydantic where attributes + # declared as iterables are replaced in in the instances by + # pydantic-core ValidatorIterator instance. In particular, this + # affects tool_calls defined in ChatCompletionAssistantMessageParam + # model: + # see: + # - https://github.com/pydantic/pydantic/issues/9467 + # As a result, tool_calls from assistant messages are never + # deserialized in the request object if the tool_calls iterator is + # not consumed. This affect messages passed to the MistralTokenizer + # since no chat template is applied and therefore the tools_calls + # iterator is not directly consumed. + # Issue is tracked on Pydantic side, with resolution planned for + # v2.11 release. In the meantime, the official workaround is to + # consume the iterator so the tool_calls are correctly deserialized + # in the OpenAI ChatCompletionAssistantMessageParam object + # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 + # Official Pydantic Issues: + # - https://github.com/pydantic/pydantic/issues/9541 + # TODO: remove when pydantic v2.11 is released + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls_validator = message.get("tool_calls", ().__iter__()) + validated_tool_calls = [] + while True: + try: + tool_call = next(tool_calls_validator) # type: ignore + validated_tool_calls.append(tool_call) + except StopIteration: + break + + request.messages[i]["tool_calls"] = validated_tool_calls + + def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: repo_cache = os.path.join( huggingface_hub.constants.HF_HUB_CACHE, @@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: if self.is_tekken: tokens = [ t for t in tokens - if t not in self.tokenizer._all_special_tokens + if (t is SpecialTokens.tool_calls + or t not in self.tokenizer._all_special_tokens) ] if any(isinstance(t, bytes) for t in tokens): @@ -246,7 +285,27 @@ def _token_to_id(t: str): else: decoded = "".join(tokens) else: - decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type] + # make sure certain special tokens like Tool calls are + # not decoded + special_tokens = {SpecialTokens.tool_calls} + regular_tokens: List[str] = [] + decoded_list = [] + + for token in tokens: + if token in special_tokens: + if regular_tokens: + decoded_list.append( + self.tokenizer.decode(regular_tokens)) + regular_tokens = [] + decoded_list.append(token) + else: + regular_tokens.append(token) + + if regular_tokens: + decoded_list.append( + self.decode(regular_tokens)) # type: ignore + + decoded = ''.join(decoded_list) return decoded @@ -274,8 +333,11 @@ def convert_ids_to_tokens( assert self.is_tekken or self.is_spm, type(self.tokenizer) if self.is_tekken: - # skip special tokens - ids = [i for i in ids if i > self.tokenizer.num_special_tokens] + # skip special tokens except tool call + ids = [ + i for i in ids if i > self.tokenizer.num_special_tokens or i == + self.tokenizer.get_control_token(SpecialTokens.tool_calls) + ] tokens = [self.tokenizer.id_to_piece(id) for id in ids]