From 76a2b7a98b5ff568ce77fde48539911c2f37b614 Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Thu, 14 Nov 2024 05:48:16 +0100 Subject: [PATCH] [BugFix]: properly deserialize `tool_calls` iterator before processing by mistral-common when MistralTokenizer is used (#9951) Signed-off-by: Guillaume Calmettes Signed-off-by: OmerD --- vllm/entrypoints/openai/serving_chat.py | 36 +++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 07cc9e94bdd03..5178481c737b4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -127,6 +127,42 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") + # NOTE: There is currently a bug in pydantic where attributes + # declared as iterables are replaced in the instances by a + # pydantic-core ValidatorIterator instance. In particular, this + # affects tool_calls defined in ChatCompletionAssistantMessageParam + # model: + # see: + # - https://github.com/pydantic/pydantic/issues/9467 + # As a result, tool_calls from assistant messages are never + # deserialized in the request object if the tool_calls iterator is + # not consumed. This affects messages passed to the MistralTokenizer + # since no chat template is applied and therefore the tool_calls + # iterator is not directly consumed. + # Issue is tracked on Pydantic side, with resolution planned for + # v2.11 release.
In the meantime, the official workaround is to + # consume the iterator so the tool_calls are correctly deserialized + # in the OpenAI ChatCompletionAssistantMessageParam object + # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 + # Official Pydantic Issues: + # - https://github.com/pydantic/pydantic/issues/9541 + # TODO: remove when pydantic v2.11 is released + if isinstance(tokenizer, MistralTokenizer): + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls_validator = message.get( + "tool_calls", ().__iter__()) + validated_tool_calls = [] + while True: + try: + tool_call = next( + tool_calls_validator) # type: ignore + validated_tool_calls.append(tool_call) + except StopIteration: + break + request.messages[i][ + "tool_calls"] = validated_tool_calls + if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) and not isinstance(tokenizer, MistralTokenizer)):