Skip to content

Commit

Permalink
pythonic tool parser
Browse files Browse the repository at this point in the history
Signed-off-by: Mike Depinet <[email protected]>
  • Loading branch information
mdepinet committed Oct 30, 2024
1 parent 00d91c8 commit 1aac0b8
Show file tree
Hide file tree
Showing 5 changed files with 537 additions and 1 deletion.
Empty file.
160 changes: 160 additions & 0 deletions tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from typing import List
from unittest.mock import MagicMock

import pytest

from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager

# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
SIMPLE_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "San Francisco", "metric": "celsius"}',
)
MORE_TYPES_FUNCTION_OUTPUT = (
"register_user(name='John Doe', "
"age=37, "
"address={'city': 'San Francisco', 'state': 'CA'}, "
"role=None, "
"passed_test=True, "
"aliases=['John', 'Johnny'])")
MORE_TYPES_FUNCTION_CALL = FunctionCall(
name="register_user",
arguments='{"name": "John Doe", '
'"age": 37, '
'"address": {"city": "San Francisco", "state": "CA"}, '
'"role": null, '
'"passed_test": true, '
'"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{}',
)
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
name="do_something_cool",
arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')")
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
name="get_weather",
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)


@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
model_output = "How can I help you today?"

content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)

assert content == model_output
assert len(tool_calls) == 0


TEST_CASES = [
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL],
id="simple_nonstreaming"),
pytest.param(True,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_streaming"),
pytest.param(False,
f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL],
id="more_types_nonstreaming"),
pytest.param(True,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_streaming"),
pytest.param(False,
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
[PARAMETERLESS_FUNCTION_CALL],
id="parameterless_nonstreaming"),
pytest.param(True,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_streaming"),
pytest.param(False,
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL],
id="empty_dict_nonstreaming"),
pytest.param(True,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_streaming"),
pytest.param(False,
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL],
id="empty_list_nonstreaming"),
pytest.param(True,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_streaming"),
pytest.param(False,
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
[ESCAPED_STRING_FUNCTION_CALL],
id="escaped_string_nonstreaming"),
pytest.param(True,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_streaming"),
pytest.param(False,
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
id="parallel_calls_nonstreaming"),
]


@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
expected_tool_calls: List[FunctionCall]):
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)

content, tool_calls = run_tool_extraction(tool_parser,
model_output,
streaming=streaming)

assert content is None
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function == expected


def test_streaming_tool_call_with_large_steps():
mock_tokenizer = MagicMock()
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
mock_tokenizer)
model_output_deltas = [
"[get_weather(city='San",
" Francisco', metric='celsius'), "
f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
f"{EMPTY_LIST_FUNCTION_OUTPUT}]",
]

reconstructor = run_tool_extraction_streaming(tool_parser,
model_output_deltas)

assert reconstructor.other_content == ""
assert len(reconstructor.tool_calls) == 3
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
113 changes: 113 additions & 0 deletions tests/entrypoints/openai/tool_parsers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Iterable, List, Tuple, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers import ToolParser


class StreamingToolReconstructor:

def __init__(self):
self.tool_calls: List[ToolCall] = []
self.other_content: str = ""

def append_delta(self, delta: DeltaMessage):
if delta.content is not None:
self.other_content += delta.content
else:
assert delta.tool_calls, (
"Streaming results should have either content or tool calls "
"(or both)")
for call_delta in delta.tool_calls:
assert call_delta.type == "function", (
"Streaming tool calls should only emit function calls. Got "
f"{call_delta.type}")
current_tool_call = self.tool_calls[
-1] if self.tool_calls and call_delta.id in {
None, self.tool_calls[-1].id
} else None
if current_tool_call:
assert (
current_tool_call.function.name == call_delta.function.name
or not call_delta.function.name
), ("Streaming tool calls should not emit partial function "
f"names. Got {call_delta.function.name}")
assert (
current_tool_call.id == call_delta.id or not call_delta.id
), ("Streaming tool calls must not change function ids. Got "
f"{call_delta.id}, expected {current_tool_call.id} or None"
)
assert (call_delta.index == len(self.tool_calls) - 1), (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls) - 1}")
current_tool_call.function.arguments += (
call_delta.function.arguments)
else:
assert call_delta.id is not None, (
"Streaming tool calls must have an id on first appearance")
assert call_delta.function.name is not None, (
"Streaming tool calls must have a function name on first "
"appearance")
assert call_delta.index == len(self.tool_calls), (
f"Incorrect index for tool delta. Got {call_delta.index}, "
f"expected {len(self.tool_calls)}")
self.tool_calls.append(
ToolCall(id=call_delta.id,
function=FunctionCall(
name=call_delta.function.name,
arguments=call_delta.function.arguments
or "")))


def run_tool_extraction(
tool_parser: ToolParser,
model_output: str,
request: Union[ChatCompletionRequest, None] = None,
streaming: bool = False) -> Tuple[Union[str, None], List[ToolCall]]:
if streaming:
reconstructor = run_tool_extraction_streaming(tool_parser,
model_output, request)
return reconstructor.other_content or None, reconstructor.tool_calls
else:
extracted = run_tool_extraction_nonstreaming(tool_parser, model_output,
request)
assert extracted.tools_called == bool(extracted.tool_calls)
return extracted.content, extracted.tool_calls


def run_tool_extraction_nonstreaming(
tool_parser: ToolParser,
model_output: str,
request: Union[ChatCompletionRequest, None] = None
) -> ExtractedToolCallInformation:
request = request or ChatCompletionRequest(messages=[], model="test-model")
return tool_parser.extract_tool_calls(model_output, request)


def run_tool_extraction_streaming(
tool_parser: ToolParser,
model_deltas: Iterable[str],
request: Union[ChatCompletionRequest, None] = None
) -> StreamingToolReconstructor:
request = request or ChatCompletionRequest(messages=[], model="test-model")
reconstructor = StreamingToolReconstructor()
previous_text = ""
previous_tokens: List[int] = []
for delta in model_deltas:
token_delta = [
tool_parser.vocab.get(token)
for token in tool_parser.model_tokenizer.tokenize(delta)
if token in tool_parser.vocab
]
current_text = previous_text + delta
current_tokens = previous_tokens + token_delta
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text, current_text, delta, previous_tokens,
current_tokens, token_delta, request)
if delta_message is not None:
reconstructor.append_delta(delta_message)
previous_text = current_text
previous_tokens = current_tokens
return reconstructor
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/tool_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from .jamba_tool_parser import JambaToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .pythonic_tool_parser import PythonicToolParser

__all__ = [
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
"Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser",
"Llama3JsonToolParser", "JambaToolParser"
"Llama3JsonToolParser", "JambaToolParser", "PythonicToolParser"
]
Loading

0 comments on commit 1aac0b8

Please sign in to comment.