Commit
Signed-off-by: Mike Depinet <[email protected]>
Showing 5 changed files with 537 additions and 1 deletion.
Empty file.
160 changes: 160 additions & 0 deletions
tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py
@@ -0,0 +1,160 @@
from typing import List
from unittest.mock import MagicMock

import pytest

from tests.entrypoints.openai.tool_parsers.utils import (
    run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
    "register_user(name='John Doe', "
    "age=37, "
    "address={'city': 'San Francisco', 'state': 'CA'}, "
    "role=None, "
    "passed_test=True, "
    "aliases=['John', 'Johnny'])")
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "John Doe", '
    '"age": 37, '
    '"address": {"city": "San Francisco", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{}',
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)


@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        mock_tokenizer)
    model_output = "How can I help you today?"

    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=streaming)

    assert content == model_output
    assert len(tool_calls) == 0


TEST_CASES = [
    pytest.param(True,
                 f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
                 id="simple_streaming"),
    pytest.param(False,
                 f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
                 id="simple_nonstreaming"),
    pytest.param(True,
                 f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
                 id="more_types_streaming"),
    pytest.param(False,
                 f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
                 id="more_types_nonstreaming"),
    pytest.param(True,
                 f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
                 [PARAMETERLESS_FUNCTION_CALL],
                 id="parameterless_streaming"),
    pytest.param(False,
                 f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
                 [PARAMETERLESS_FUNCTION_CALL],
                 id="parameterless_nonstreaming"),
    pytest.param(True,
                 f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
                 id="empty_dict_streaming"),
    pytest.param(False,
                 f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
                 id="empty_dict_nonstreaming"),
    pytest.param(True,
                 f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
                 id="empty_list_streaming"),
    pytest.param(False,
                 f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
                 id="empty_list_nonstreaming"),
    pytest.param(True,
                 f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
                 [ESCAPED_STRING_FUNCTION_CALL],
                 id="escaped_string_streaming"),
    pytest.param(False,
                 f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
                 [ESCAPED_STRING_FUNCTION_CALL],
                 id="escaped_string_nonstreaming"),
    pytest.param(True,
                 f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
                 [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
                 id="parallel_calls_streaming"),
    pytest.param(False,
                 f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
                 [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
                 id="parallel_calls_nonstreaming"),
]


@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
                         TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
                   expected_tool_calls: List[FunctionCall]):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        mock_tokenizer)

    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=streaming)

    assert content is None
    assert len(tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected


def test_streaming_tool_call_with_large_steps():
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
        mock_tokenizer)
    model_output_deltas = [
        "[get_weather(city='San",
        " Francisco', metric='celsius'), "
        f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
        f"{EMPTY_LIST_FUNCTION_OUTPUT}]",
    ]

    reconstructor = run_tool_extraction_streaming(tool_parser,
                                                  model_output_deltas)

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
    assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
    assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
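For orientation, the non-streaming path these tests exercise can also be driven through the parser's extract_tool_calls entry point directly. The following is a minimal illustrative sketch, not part of this commit's diff; it assumes vLLM is installed, the "pythonic" parser is registered as above, and that a MagicMock tokenizer and a bare ChatCompletionRequest (as constructed by the helpers below) are sufficient.

# Illustrative sketch only, not from the commit: direct non-streaming use.
from unittest.mock import MagicMock

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

parser = ToolParserManager.get_tool_parser("pythonic")(MagicMock())
request = ChatCompletionRequest(messages=[], model="test-model")
extracted = parser.extract_tool_calls(
    "[get_weather(city='San Francisco', metric='celsius')]", request)
# Per test_tool_call above: the whole output is a tool-call block, so
# content is None and the single call carries JSON-encoded arguments.
assert extracted.tools_called
assert extracted.content is None
assert extracted.tool_calls[0].function.name == "get_weather"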
tests/entrypoints/openai/tool_parsers/utils.py
@@ -0,0 +1,113 @@
from typing import Iterable, List, Tuple, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser


class StreamingToolReconstructor:

    def __init__(self):
        self.tool_calls: List[ToolCall] = []
        self.other_content: str = ""

    def append_delta(self, delta: DeltaMessage):
        if delta.content is not None:
            self.other_content += delta.content
        else:
            assert delta.tool_calls, (
                "Streaming results should have either content or tool calls "
                "(or both)")
        for call_delta in delta.tool_calls:
            assert call_delta.type == "function", (
                "Streaming tool calls should only emit function calls. Got "
                f"{call_delta.type}")
            current_tool_call = self.tool_calls[
                -1] if self.tool_calls and call_delta.id in {
                    None, self.tool_calls[-1].id
                } else None
            if current_tool_call:
                assert (
                    current_tool_call.function.name == call_delta.function.name
                    or not call_delta.function.name
                ), ("Streaming tool calls should not emit partial function "
                    f"names. Got {call_delta.function.name}")
                assert (
                    current_tool_call.id == call_delta.id or not call_delta.id
                ), ("Streaming tool calls must not change function ids. Got "
                    f"{call_delta.id}, expected {current_tool_call.id} or None"
                    )
                assert (call_delta.index == len(self.tool_calls) - 1), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls) - 1}")
                current_tool_call.function.arguments += (
                    call_delta.function.arguments)
            else:
                assert call_delta.id is not None, (
                    "Streaming tool calls must have an id on first appearance")
                assert call_delta.function.name is not None, (
                    "Streaming tool calls must have a function name on first "
                    "appearance")
                assert call_delta.index == len(self.tool_calls), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls)}")
                self.tool_calls.append(
                    ToolCall(id=call_delta.id,
                             function=FunctionCall(
                                 name=call_delta.function.name,
                                 arguments=call_delta.function.arguments
                                 or "")))


def run_tool_extraction(
        tool_parser: ToolParser,
        model_output: str,
        request: Union[ChatCompletionRequest, None] = None,
        streaming: bool = False) -> Tuple[Union[str, None], List[ToolCall]]:
    if streaming:
        reconstructor = run_tool_extraction_streaming(tool_parser,
                                                      model_output, request)
        return reconstructor.other_content or None, reconstructor.tool_calls
    else:
        extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
                                                     request)
        assert extracted.tools_called == bool(extracted.tool_calls)
        return extracted.content, extracted.tool_calls


def run_tool_extraction_nonstreaming(
        tool_parser: ToolParser,
        model_output: str,
        request: Union[ChatCompletionRequest, None] = None
) -> ExtractedToolCallInformation:
    request = request or ChatCompletionRequest(messages=[], model="test-model")
    return tool_parser.extract_tool_calls(model_output, request)


def run_tool_extraction_streaming(
        tool_parser: ToolParser,
        model_deltas: Iterable[str],
        request: Union[ChatCompletionRequest, None] = None
) -> StreamingToolReconstructor:
    request = request or ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingToolReconstructor()
    previous_text = ""
    previous_tokens: List[int] = []
    for delta in model_deltas:
        token_delta = [
            tool_parser.vocab.get(token)
            for token in tool_parser.model_tokenizer.tokenize(delta)
            if token in tool_parser.vocab
        ]
        current_text = previous_text + delta
        current_tokens = previous_tokens + token_delta
        delta_message = tool_parser.extract_tool_calls_streaming(
            previous_text, current_text, delta, previous_tokens,
            current_tokens, token_delta, request)
        if delta_message is not None:
            reconstructor.append_delta(delta_message)
        previous_text = current_text
        previous_tokens = current_tokens
    return reconstructor
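As a usage note, the streaming helper above can be driven the same way the tests drive it, with StreamingToolReconstructor stitching argument deltas back into complete calls. A minimal illustrative sketch, not part of this commit's diff, under the same assumptions as the tests (vLLM installed, run from the repo root so the tests package is importable, registered "pythonic" parser, MagicMock tokenizer):

# Illustrative sketch only, not from the commit: streaming reconstruction.
from unittest.mock import MagicMock

from tests.entrypoints.openai.tool_parsers.utils import (
    run_tool_extraction_streaming)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

parser = ToolParserManager.get_tool_parser("pythonic")(MagicMock())
deltas = ["[get_weather(city='San", " Francisco', metric='celsius')]"]
reconstructor = run_tool_extraction_streaming(parser, deltas)
# The reconstructor accumulates ToolCall objects; argument fragments that
# arrive across deltas are concatenated per call index.
for call in reconstructor.tool_calls:
    print(call.function.name, call.function.arguments)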