-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tool parsing] Improve / correct mistral tool parsing #10333
Changes from all commits
dd2df17
7a72ccc
0b551ba
800b376
fe39e84
f1d3cf2
ab8c7e2
5cbbff1
b694ba5
fac07af
7e9ae4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,13 @@ | |
Run `pytest tests/models/test_mistral.py`. | ||
""" | ||
import copy | ||
|
||
import pytest | ||
|
||
from vllm import SamplingParams | ||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa | ||
MistralToolParser) | ||
|
||
from ...utils import check_logprobs_close | ||
|
||
|
@@ -58,17 +62,69 @@ | |
}, | ||
"required": ["city", "state", "unit"] | ||
} | ||
}, | ||
}, { | ||
"type": "function", | ||
"function": { | ||
"name": "rewrite", | ||
"description": "Rewrites text", | ||
"parameters": { | ||
"type": "object", | ||
"required": [], | ||
"properties": { | ||
"text": { | ||
"type": "string", | ||
"description": "The input text to rewrite." | ||
} | ||
} | ||
} | ||
} | ||
}] | ||
MSGS = [{ | ||
"role": | ||
"user", | ||
"content": ("Can you tell me what the temperate" | ||
" will be in Dallas, in fahrenheit?") | ||
}] | ||
EXPECTED_FUNC_CALL = ( | ||
'[{"name": "get_current_weather", "arguments": ' | ||
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') | ||
MSGS = [ | ||
{ | ||
"role": "system", | ||
"content": "You are an assistant." | ||
}, | ||
{ | ||
"role": | ||
"user", | ||
"content": | ||
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa | ||
}, | ||
{ | ||
"role": | ||
"assistant", | ||
"content": | ||
"", | ||
"tool_calls": [{ | ||
"id": "bbc5b7ede", | ||
"type": "function", | ||
"function": { | ||
"name": | ||
"rewrite", | ||
"arguments": | ||
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa | ||
} | ||
}] | ||
}, | ||
{ | ||
"role": "tool", | ||
"content": | ||
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa | ||
"tool_call_id": "bbc5b7ede", | ||
"name": "rewrite" | ||
}, | ||
{ | ||
"role": "assistant", | ||
"content": "---\n\nMy English needs improving, maybe I make errors" | ||
}, | ||
{ | ||
"role": | ||
"user", | ||
"content": ("Can you tell me what the temperate" | ||
" will be in Dallas, in fahrenheit?") | ||
} | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("model", MODELS) | ||
|
@@ -175,8 +231,23 @@ def test_mistral_function_calling( | |
tokenizer_mode="mistral", | ||
config_format="mistral", | ||
load_format="mistral") as vllm_model: | ||
outputs = vllm_model.model.chat(MSGS, | ||
|
||
msgs = copy.deepcopy(MSGS) | ||
outputs = vllm_model.model.chat(msgs, | ||
tools=TOOLS, | ||
sampling_params=SAMPLING_PARAMS) | ||
|
||
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL | ||
tokenizer = vllm_model.model.get_tokenizer() | ||
tool_parser = MistralToolParser(tokenizer) | ||
|
||
model_output = outputs[0].outputs[0].text.strip() | ||
assert model_output.startswith(tool_parser.bot_token), model_output | ||
parsed_message = tool_parser.extract_tool_calls(model_output, None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cleaner to let the parser take care of correctly extracting the dict |
||
|
||
assert parsed_message.tools_called | ||
assert parsed_message.tool_calls[0].id == "0UAqFzWsD" | ||
assert parsed_message.tool_calls[ | ||
0].function.name == "get_current_weather" | ||
assert parsed_message.tool_calls[ | ||
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa | ||
assert parsed_message.content is None |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
from vllm.sampling_params import BeamSearchParams, SamplingParams | ||
from vllm.sequence import Logprob | ||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer | ||
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls | ||
from vllm.utils import iterate_with_cancellation | ||
|
||
logger = init_logger(__name__) | ||
|
@@ -127,41 +128,11 @@ async def create_chat_completion( | |
return self.create_error_response( | ||
"tool_choice = \"required\" is not supported!") | ||
|
||
# NOTE: There is currently a bug in pydantic where attributes | ||
# declared as iterables are replaced in in the instances by | ||
# pydantic-core ValidatorIterator instance. In particular, this | ||
# affects tool_calls defined in ChatCompletionAssistantMessageParam | ||
# model: | ||
# see: | ||
# - https://github.com/pydantic/pydantic/issues/9467 | ||
# As a result, tool_calls from assistant messages are never | ||
# deserialized in the request object if the tool_calls iterator is | ||
# not consumed. This affect messages passed to the MistralTokenizer | ||
# since no chat template is applied and therefore the tools_calls | ||
# iterator is not directly consumed. | ||
# Issue is tracked on Pydantic side, with resolution planned for | ||
# v2.11 release. In the meantime, the official workaround is to | ||
# consume the iterator so the tool_calls are correctly deserialized | ||
# in the OpenAI ChatCompletionAssistantMessageParam object | ||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 | ||
# Official Pydantic Issues: | ||
# - https://github.com/pydantic/pydantic/issues/9541 | ||
# TODO: remove when pydantic v2.11 is released | ||
# because of issues with pydantic we need to potentially | ||
# re-serialize the tool_calls field of the request | ||
# for more info: see comment in `maybe_serialize_tool_calls` | ||
if isinstance(tokenizer, MistralTokenizer): | ||
for i, message in enumerate(request.messages): | ||
if message.get("role") == 'assistant': | ||
tool_calls_validator = message.get( | ||
"tool_calls", ().__iter__()) | ||
validated_tool_calls = [] | ||
while True: | ||
try: | ||
tool_call = next( | ||
tool_calls_validator) # type: ignore | ||
validated_tool_calls.append(tool_call) | ||
except StopIteration: | ||
break | ||
request.messages[i][ | ||
"tool_calls"] = validated_tool_calls | ||
maybe_serialize_tool_calls(request) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moving this out of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! I had originally thought about putting it directly in the Mistral Tokenizer but did not in the end because the same problem would occur for any other futur models having a tokenizer not relying on jinja chat templates (none right now, so this was highly hypothetical). |
||
|
||
if (request.tool_choice == "auto" and | ||
not (self.enable_auto_tools and tool_parser is not None) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .mistral import MistralTokenizer | ||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls | ||
|
||
__all__ = ["MistralTokenizer"] | ||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
import huggingface_hub | ||
from huggingface_hub import HfApi, hf_hub_download | ||
from mistral_common.protocol.instruct.request import ChatCompletionRequest | ||
from mistral_common.tokens.tokenizers.base import SpecialTokens | ||
# yapf: disable | ||
from mistral_common.tokens.tokenizers.mistral import ( | ||
MistralTokenizer as PublicMistralTokenizer) | ||
|
@@ -29,6 +30,43 @@ class Encoding: | |
input_ids: List[int] | ||
|
||
|
||
def maybe_serialize_tool_calls(request: ChatCompletionRequest): | ||
# SEE: https://github.com/vllm-project/vllm/pull/9951 | ||
# Credits go to: @gcalmettes | ||
# NOTE: There is currently a bug in pydantic where attributes | ||
# declared as iterables are replaced in in the instances by | ||
# pydantic-core ValidatorIterator instance. In particular, this | ||
# affects tool_calls defined in ChatCompletionAssistantMessageParam | ||
# model: | ||
# see: | ||
# - https://github.com/pydantic/pydantic/issues/9467 | ||
# As a result, tool_calls from assistant messages are never | ||
# deserialized in the request object if the tool_calls iterator is | ||
# not consumed. This affect messages passed to the MistralTokenizer | ||
# since no chat template is applied and therefore the tools_calls | ||
# iterator is not directly consumed. | ||
# Issue is tracked on Pydantic side, with resolution planned for | ||
# v2.11 release. In the meantime, the official workaround is to | ||
# consume the iterator so the tool_calls are correctly deserialized | ||
# in the OpenAI ChatCompletionAssistantMessageParam object | ||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501 | ||
# Official Pydantic Issues: | ||
# - https://github.com/pydantic/pydantic/issues/9541 | ||
# TODO: remove when pydantic v2.11 is released | ||
for i, message in enumerate(request.messages): | ||
if message.get("role") == 'assistant': | ||
tool_calls_validator = message.get("tool_calls", ().__iter__()) | ||
validated_tool_calls = [] | ||
while True: | ||
try: | ||
tool_call = next(tool_calls_validator) # type: ignore | ||
validated_tool_calls.append(tool_call) | ||
except StopIteration: | ||
break | ||
|
||
request.messages[i]["tool_calls"] = validated_tool_calls | ||
|
||
|
||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As proposed by @gcalmettes here: #9059 (comment) We don't parse away the [TOOL_CALLS] token for neither tekken nor spm so that function calls can be correctly parsed. |
||
repo_cache = os.path.join( | ||
huggingface_hub.constants.HF_HUB_CACHE, | ||
|
@@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str: | |
if self.is_tekken: | ||
tokens = [ | ||
t for t in tokens | ||
if t not in self.tokenizer._all_special_tokens | ||
if (t is SpecialTokens.tool_calls | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that after further testing on my end, I found a edge case where not skipping the
If we can find a way to not filter out the I have an easy reproducible example of this problem that I can share to you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the note! Would be great if you could share an easy repro There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @patrickvonplaten please find below a scenario were it will break (and further below the small change in prompt that would make the code work, because of added guidance to the model). Note that the code requires However, after further investigation, I know now how to fix it (I'm preparing a PR, I'll tag you for your review) ! In fact the problem was present before but "masked" by the fact that the
Guiding the model to output JSON by changing the system prompt as below is enough so that the model actually does not produce a tool_call token :
|
||
or t not in self.tokenizer._all_special_tokens) | ||
] | ||
|
||
if any(isinstance(t, bytes) for t in tokens): | ||
|
@@ -246,7 +285,27 @@ def _token_to_id(t: str): | |
else: | ||
decoded = "".join(tokens) | ||
else: | ||
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type] | ||
# make sure certain special tokens like Tool calls are | ||
# not decoded | ||
special_tokens = {SpecialTokens.tool_calls} | ||
regular_tokens: List[str] = [] | ||
decoded_list = [] | ||
|
||
for token in tokens: | ||
if token in special_tokens: | ||
if regular_tokens: | ||
decoded_list.append( | ||
self.tokenizer.decode(regular_tokens)) | ||
regular_tokens = [] | ||
decoded_list.append(token) | ||
else: | ||
regular_tokens.append(token) | ||
|
||
if regular_tokens: | ||
decoded_list.append( | ||
self.decode(regular_tokens)) # type: ignore | ||
|
||
decoded = ''.join(decoded_list) | ||
|
||
return decoded | ||
|
||
|
@@ -274,8 +333,11 @@ def convert_ids_to_tokens( | |
assert self.is_tekken or self.is_spm, type(self.tokenizer) | ||
|
||
if self.is_tekken: | ||
# skip special tokens | ||
ids = [i for i in ids if i > self.tokenizer.num_special_tokens] | ||
# skip special tokens except tool call | ||
ids = [ | ||
i for i in ids if i > self.tokenizer.num_special_tokens or i == | ||
self.tokenizer.get_control_token(SpecialTokens.tool_calls) | ||
] | ||
|
||
tokens = [self.tokenizer.id_to_piece(id) for id in ids] | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make test much more difficult, complex to show the community to what extent function calling can be used with Mistral models