Add tools unpacking for vLLM (kserve#4035)
* Add tools to chat template

Signed-off-by: Arjun Bhalla <[email protected]>

Linting

Signed-off-by: Arjun Bhalla <[email protected]>

add test

Signed-off-by: Arjun Bhalla <[email protected]>

Fix linting manually

Signed-off-by: Arjun Bhalla <[email protected]>

* Fix linting

Signed-off-by: Arjun Bhalla <[email protected]>

* Add tools unpacking for vllm

Signed-off-by: Arjun Bhalla <[email protected]>

* Add sanity check test

Signed-off-by: Arjun Bhalla <[email protected]>

---------

Signed-off-by: Arjun Bhalla <[email protected]>
Signed-off-by: Arjun Bhalla <[email protected]>
Co-authored-by: Arjun Bhalla <[email protected]>
2 people authored and javierdlrm committed Dec 2, 2024
1 parent 7dadb51 commit 3b8c621
Showing 2 changed files with 52 additions and 4 deletions.
@@ -367,7 +367,7 @@ def apply_chat_template(
             chat_template=chat_template,
             tokenize=False,
             add_generation_prompt=True,
-            tools=tools,
+            tools=[tool.model_dump() for tool in tools] if tools else None,
         )

     async def _post_init(self):
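The one-line change above fixes a type mismatch: the OpenAI-protocol request carries tools as Pydantic ChatCompletionTool models, while the tokenizer's chat template expects plain JSON-schema dicts, so each tool is unpacked with Pydantic v2's model_dump() before rendering. A minimal sketch of the pattern, with a simplified stand-in for the tool type (not the kserve implementation):

from typing import Optional

from pydantic import BaseModel


class ChatCompletionTool(BaseModel):
    # Simplified stand-in for the OpenAI-protocol tool type.
    type: str
    function: dict


def unpack_tools(tools: Optional[list[ChatCompletionTool]]) -> Optional[list[dict]]:
    # Pydantic v2: model_dump() turns each model into a plain dict;
    # None passes through so requests without tools keep working.
    return [tool.model_dump() for tool in tools] if tools else None


tool = ChatCompletionTool(type="function", function={"name": "get_current_weather"})
assert unpack_tools([tool]) == [{"type": "function", "function": {"name": "get_current_weather"}}]
assert unpack_tools(None) is None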
54 changes: 51 additions & 3 deletions python/huggingfaceserver/tests/test_vllm_model.py
@@ -45,6 +45,7 @@
     TopLogprob,
     CreateCompletionResponse,
     Choice,
+    ChatCompletionTool,
 )
 from vllm_mock_outputs import (
     opt_chat_cmpl_chunks,
@@ -149,6 +150,56 @@ async def test_vllm_chat_completion_tokenization_facebook_opt_model(
         )
         assert compare_chatprompt_to_expected(response, expected) is True

+    async def test_vllm_chat_completion_template_tools(self, vllm_opt_model):
+        opt_model, _ = vllm_opt_model
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a friendly chatbot who will help users with weather queries.",
+            },
+            {
+                "role": "user",
+                "content": "What is the weather in Ithaca, NY?",
+            },
+        ]
+
+        tool = {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather",
+                "parameters": {
+                    "type": "dict",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the users location.",
+                        },
+                    },
+                    "required": ["location", "format"],
+                },
+            },
+        }
+
+        tools = [ChatCompletionTool(**tool)]
+
+        chat_template = (
+            "{% for message in messages %}"
+            "{{ message.content }}{{ eos_token }}"
+            "{% for tool in tools %}"
+            "{% endfor %}{% endfor %}"
+        )
+        response = opt_model.apply_chat_template(messages, chat_template, tools)
+
+        # Sanity check / usage example to ensure that no error is thrown
+        assert response.prompt is not None
+

 def compare_response_to_expected(actual, expected, fields_to_compare=None) -> bool:
     if fields_to_compare is None:
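The template in the new test loops over tools without emitting anything, since the test only checks that unpacked tools reach the template without raising. A hypothetical template that actually renders the tool schemas into the prompt (illustrative only, not part of this commit) could rely on the same dict unpacking, because Jinja resolves tool.function.name against plain dicts:

chat_template = (
    "{% for message in messages %}"
    "{{ message.content }}{{ eos_token }}"
    "{% endfor %}"
    "{% if tools %}Available tools:\n"
    "{% for tool in tools %}"
    "- {{ tool.function.name }}: {{ tool.function.description }}\n"
    "{% endfor %}"
    "{% endif %}"
)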
@@ -180,7 +231,6 @@ def compare_response_to_expected(actual, expected, fields_to_compare=None) -> bool:

 @pytest.mark.asyncio()
 class TestChatCompletions:
-
     async def test_vllm_chat_completion_facebook_opt_model_without_request_id(
         self, vllm_opt_model
     ):
@@ -1032,7 +1082,6 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:

 @pytest.mark.asyncio()
 class TestCompletions:
-
     async def test_vllm_completion_facebook_opt_model_without_request_id(
         self, vllm_opt_model
     ):
@@ -2872,7 +2921,6 @@ async def mock_generate(*args, **kwargs) -> AsyncIterator[RequestOutput]:


 class TestOpenAIServingCompletion:
-
     def test_validate_input_with_max_tokens_exceeding_model_limit(self, vllm_opt_model):
         opt_model, mock_vllm_engine = vllm_opt_model
         prompt = "Hi, I love my cat"
