Commit
[BugFix]: properly deserialize tool_calls iterator before processing by mistral-common when MistralTokenizer is used (vllm-project#9951)

Signed-off-by: Guillaume Calmettes <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
gcalmettes authored and tlrmchlsmth committed Nov 23, 2024
1 parent 77d1ab0 commit e4a392b
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions vllm/entrypoints/openai/serving_chat.py
@@ -127,6 +127,42 @@ async def create_chat_completion(
return self.create_error_response(
"tool_choice = \"required\" is not supported!")

# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in the instances by a
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in the
# ChatCompletionAssistantMessageParam model. See:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affects messages passed to the MistralTokenizer,
# since no chat template is applied and therefore the tool_calls
# iterator is not directly consumed.
# The issue is tracked on the Pydantic side, with resolution planned
# for the v2.11 release. In the meantime, the official workaround is
# to consume the iterator so the tool_calls are correctly
# deserialized in the OpenAI ChatCompletionAssistantMessageParam
# object:
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic issue:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
if isinstance(tokenizer, MistralTokenizer):
    for i, message in enumerate(request.messages):
        if message.get("role") == 'assistant':
            tool_calls_validator = message.get(
                "tool_calls", ().__iter__())
            validated_tool_calls = []
            while True:
                try:
                    tool_call = next(
                        tool_calls_validator)  # type: ignore
                    validated_tool_calls.append(tool_call)
                except StopIteration:
                    break
            request.messages[i][
                "tool_calls"] = validated_tool_calls

if (request.tool_choice == "auto" and
        not (self.enable_auto_tools and tool_parser is not None)
        and not isinstance(tokenizer, MistralTokenizer)):
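
The workaround in this hunk comes down to draining a one-shot iterator into a list before anything downstream reads it. The sketch below illustrates that idea in isolation. It is not vLLM code: `materialize_tool_calls` and `lazy_tool_calls` are hypothetical names, plain dicts stand in for the request messages, and a generator stands in for the lazy pydantic-core validator. It uses `list(...)` where the commit uses an explicit `while`/`next` loop, on the assumption that the two are equivalent for any well-behaved iterator.

```python
from typing import Any, Dict, Iterator, List


def materialize_tool_calls(messages: List[Dict[str, Any]]) -> None:
    """Replace each assistant message's tool_calls iterator with a list.

    Hypothetical helper, not part of vLLM. Mutates `messages` in place.
    """
    for message in messages:
        if message.get("role") == "assistant":
            # list() drains any iterable, including a one-shot lazy
            # iterator. A missing key becomes an empty list, which
            # mirrors the behavior of the loop in the diff.
            message["tool_calls"] = list(message.get("tool_calls", ()))


def lazy_tool_calls() -> Iterator[Dict[str, Any]]:
    """Stand-in for the lazy validator: can be consumed only once."""
    yield {
        "id": "call_0",
        "type": "function",
        "function": {"name": "get_weather", "arguments": "{}"},
    }


messages: List[Dict[str, Any]] = [
    {"role": "user", "content": "What is the weather?"},
    {"role": "assistant", "content": None, "tool_calls": lazy_tool_calls()},
]
materialize_tool_calls(messages)
```

After the call, the assistant message holds a plain list that can be read any number of times, while the user message is left untouched.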