From 099ddb3482ac0f3e4f06c635a5abb4ac97adcb67 Mon Sep 17 00:00:00 2001
From: Breno Faria
Date: Thu, 23 May 2024 15:22:32 +0200
Subject: [PATCH 01/25] WIP added tool classes

---
 vllm/entrypoints/openai/protocol.py | 40 ++++++++++++++++++++++++-----
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 41e2f77fe56f1..ff70a4eb580ff 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -100,6 +100,22 @@ class ResponseFormat(OpenAIBaseModel):
     # type must be "json_object" or "text"
     type: Literal["text", "json_object"]
 
+class FunctionDefinition(OpenAIBaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, Any]] = None
+
+class ChatCompletionToolsParam(OpenAIBaseModel):
+    type: Literal["function"] = "function"
+    function: FunctionDefinition
+
+class ChatCompletionNamedFunction(OpenAIBaseModel):
+    name: str
+
+class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
+    function: ChatCompletionNamedFunction
+    type: Literal["function"] = "function"
+
 
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -121,6 +137,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    tools: Optional[List[ChatCompletionToolsParam]] = None
+    tool_choice: Optional[Union[Literal["none", "required"],
+                                ChatCompletionNamedToolChoiceParam]] = "none"
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
@@ -484,22 +503,31 @@ class EmbeddingResponse(BaseModel):
     usage: UsageInfo
 
 
+class FunctionCall(OpenAIBaseModel):
+    name: str
+    arguments: str
+
+class ToolCall(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
+    type: Literal["function"] = "function"
+    function: FunctionCall
+
 class ChatMessage(OpenAIBaseModel):
     role: str
     content: str
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
-    stop_reason: Optional[Union[int, str]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
 
 
 class ChatCompletionResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion"
+    object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -509,19 +537,19 @@ class ChatCompletionResponse(OpenAIBaseModel):
 class DeltaMessage(OpenAIBaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
-    stop_reason: Optional[Union[int, str]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
 
 
 class ChatCompletionStreamResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
-    object: str = "chat.completion.chunk"
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
From
5934e221bfdfc1560d93574008aab7f19b85e75c Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 23 May 2024 23:30:30 +0200 Subject: [PATCH 02/25] added correct models. Tests still missing --- vllm/entrypoints/openai/protocol.py | 27 ++++++++--- vllm/entrypoints/openai/serving_chat.py | 63 +++++++++++++++++++++---- 2 files changed, 75 insertions(+), 15 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 41e2f77fe56f1..aa42eb4c50cd6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -489,12 +489,28 @@ class ChatMessage(OpenAIBaseModel): content: str +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: List[ChatCompletionLogProbsContent] = Field(default_factory=list) + + class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None class ChatCompletionResponse(OpenAIBaseModel): @@ -514,9 +530,8 @@ class DeltaMessage(OpenAIBaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None - stop_reason: Optional[Union[int, str]] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None class ChatCompletionStreamResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7e179362eef8a..cd15e574263b3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,8 @@ import codecs import time from dataclasses import dataclass -from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, - TypedDict, Union, cast, final) +from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, + Optional, TypedDict, Union, cast, final) from fastapi import Request from openai.types.chat import ChatCompletionContentPartTextParam @@ -10,8 +10,9 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( - ChatCompletionContentPartParam, ChatCompletionMessageParam, - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionContentPartParam, ChatCompletionLogProb, + ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) @@ -21,6 +22,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import random_uuid logger = init_logger(__name__) @@ -277,11 +279,10 @@ async def chat_completion_stream_generator( previous_num_tokens[i]:] if output.logprobs else None if request.logprobs: - 
logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=len(previous_texts[i]), + num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None @@ -364,10 +365,10 @@ async def chat_completion_full_generator( top_logprobs = output.logprobs if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, + num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None @@ -408,3 +409,47 @@ async def chat_completion_full_generator( ) return response + + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], + top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb( + token=p.decoded_token, + logprob=p.logprob, + bytes=list(p.decoded_token.encode("utf-8", errors="replace"))) + for i, p in enumerate(logprobs.values()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: List[int], + top_logprobs: List[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs = ChatCompletionLogProbs() + + for i, token_id in enumerate(token_ids): + logprob_response_object = ChatCompletionLogProbsContent() + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprob_response_object.token = token + logprob_response_object.bytes = list( + token.encode("utf-8", errors="replace")) + assert logprob_response_object.top_logprobs is not None and len( + logprob_response_object.top_logprobs) == 0 + else: + token = step_top_logprobs[token_id].decoded_token + logprob_response_object.token = token + logprob_response_object.logprob = step_top_logprobs[ + token_id].logprob + logprob_response_object.bytes = list( + token.encode("utf-8", errors="replace")) + logprob_response_object.top_logprobs = self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs) + logprobs.content.append(logprob_response_object) + + return logprobs From 1248bc193e363c003a41ca353dfebaa10f6c9d68 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 12:28:34 +0200 Subject: [PATCH 03/25] fix implementation and tests --- tests/entrypoints/test_openai_server.py | 13 ++++---- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 41 ++++++++++++------------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1b04e3205c4b8..fa43b064c4ab6 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -225,8 +225,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role 
== "assistant" @@ -723,13 +725,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, top_logprobs=5, extra_body=dict(guided_choice=TEST_CHOICE, guided_decoding_backend=guided_decoding_backend)) - top_logprobs = chat_completion.choices[0].logprobs.top_logprobs + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs # -9999.0 is the minimum logprob returned by OpenAI assert all( - isinstance(logprob, float) and logprob >= -9999.0 - for token_dict in top_logprobs - for token, logprob in token_dict.items()) + isinstance(token.logprob, float) and token.logprob >= -9999.0 + for token in top_logprobs) async def test_response_format_json_object(server, client: openai.AsyncOpenAI): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index aa42eb4c50cd6..1b5461d9935f6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -498,7 +498,7 @@ class ChatCompletionLogProb(OpenAIBaseModel): class ChatCompletionLogProbsContent(OpenAIBaseModel): token: str logprob: float = -9999.0 - bytes: Optional[List[int]] + bytes: Optional[List[int]] = None top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cd15e574263b3..43577af072d0e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -318,8 +318,7 @@ async def chat_completion_stream_generator( index=i, delta=DeltaMessage(content=delta_text), logprobs=logprobs, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason) + finish_reason=output.finish_reason) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, @@ -377,9 +376,7 @@ async def chat_completion_full_generator( index=output.index, message=ChatMessage(role=role, content=output.text), logprobs=logprobs, - finish_reason=output.finish_reason, - stop_reason=output.stop_reason, - ) + finish_reason=output.finish_reason) choices.append(choice_data) if request.echo: @@ -416,7 +413,7 @@ def _get_top_logprobs( return [ ChatCompletionLogProb( token=p.decoded_token, - logprob=p.logprob, + logprob=max(p.logprob, -9999.0), bytes=list(p.decoded_token.encode("utf-8", errors="replace"))) for i, p in enumerate(logprobs.values()) if top_logprobs and i < top_logprobs @@ -429,27 +426,27 @@ def _create_chat_logprobs( num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" + logprobs = ChatCompletionLogProbs() for i, token_id in enumerate(token_ids): - logprob_response_object = ChatCompletionLogProbsContent() step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: - token = self.tokenizer.decode(token_id) - logprob_response_object.token = token - logprob_response_object.bytes = list( - token.encode("utf-8", errors="replace")) - assert logprob_response_object.top_logprobs is not None and len( - logprob_response_object.top_logprobs) == 0 + logprobs.content.append( + ChatCompletionLogProbsContent( + token=self.tokenizer.decode(token_id), + bytes=list( + self.tokenizer.decode(token_id).encode( + "utf-8", errors="replace")))) else: - token = step_top_logprobs[token_id].decoded_token - logprob_response_object.token = token - logprob_response_object.logprob = step_top_logprobs[ - token_id].logprob - logprob_response_object.bytes = list( - token.encode("utf-8", errors="replace")) - logprob_response_object.top_logprobs = self._get_top_logprobs( - 
step_top_logprobs, num_output_top_logprobs) - logprobs.content.append(logprob_response_object) + logprobs.content.append( + ChatCompletionLogProbsContent( + token=step_top_logprobs[token_id].decoded_token, + logprob=max(step_top_logprobs[token_id].logprob, -9999.0), + bytes=list( + step_top_logprobs[token_id].decoded_token.encode( + "utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs))) return logprobs From 1b2b453d5a934ccaa94b5b309c8a79f104ea0372 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 12:35:42 +0200 Subject: [PATCH 04/25] fix formatting --- vllm/entrypoints/openai/serving_chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 43577af072d0e..15ce7049e58e3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -442,7 +442,8 @@ def _create_chat_logprobs( logprobs.content.append( ChatCompletionLogProbsContent( token=step_top_logprobs[token_id].decoded_token, - logprob=max(step_top_logprobs[token_id].logprob, -9999.0), + logprob=max(step_top_logprobs[token_id].logprob, + -9999.0), bytes=list( step_top_logprobs[token_id].decoded_token.encode( "utf-8", errors="replace")), From 07af0cce2ba417bd5d4047abf4b16c497f00ae0d Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 13:20:06 +0200 Subject: [PATCH 05/25] fix test --- tests/async_engine/test_openapi_server_ray.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index ace4c53916c71..1d5d03a9a77a4 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" From 755625f69b275dd11c3170e7285329a83096e0db Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 14:42:52 +0200 Subject: [PATCH 06/25] named tool working --- tests/entrypoints/test_openai_server.py | 70 +++++++++++++++++++ tests/utils.py | 3 +- vllm/entrypoints/openai/protocol.py | 10 +++ .../guided_decoding/__init__.py | 37 +++++++++- 4 files changed, 117 insertions(+), 3 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index fa43b064c4ab6..8d261c1daa7fd 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -733,6 +733,76 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token in top_logprobs) +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_named_tool_use(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful 
assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "dummy_function_name" + } + }) + message = chat_completion.choices[0].message + assert message.content is not None + json1 = json.loads(message.content) + jsonschema.validate(instance=json1, schema=TEST_SCHEMA) + + messages.append({"role": "assistant", "content": message.content}) + messages.append({ + "role": + "user", + "content": + "Give me another one with a different name and age" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "dummy_function_name" + } + }) + message = chat_completion.choices[0].message + assert message.content is not None + json2 = json.loads(message.content) + jsonschema.validate(instance=json2, schema=TEST_SCHEMA) + assert json1["name"] != json2["name"] + assert json1["age"] != json2["age"] + + async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8a..eaa9720628fe3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,7 +22,8 @@ def __init__(self, args): env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] + + args, env=env, stdout=sys.stdout, stderr=sys.stderr, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5f72e26550f45..4675c1d642387 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -100,18 +100,22 @@ class ResponseFormat(OpenAIBaseModel): # type must be "json_object" or "text" type: Literal["text", "json_object"] + class FunctionDefinition(OpenAIBaseModel): name: str description: Optional[str] = None parameters: Optional[Dict[str, Any]] = None + class ChatCompletionToolsParam(OpenAIBaseModel): type: Literal["function"] = "function" function: FunctionDefinition + class ChatCompletionNamedFunction(OpenAIBaseModel): name: str + class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): function: ChatCompletionNamedFunction type: Literal["function"] = "function" @@ -264,6 +268,10 @@ def check_guided_decoding_count(cls, data): "guided_regex" in data and data["guided_regex"] is not None, "guided_choice" in data and data["guided_choice"] is not None ]) + if guide_count > 1 and "tool_choice" in data and data[ + "tool_choice"] != "none": + raise ValueError( + "You can only either use guided decoding or tools, not both.") if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " @@ -507,11 +515,13 @@ class FunctionCall(OpenAIBaseModel): name: str arguments: str + class ToolCall(OpenAIBaseModel): id: str = Field(default_factory=lambda: 
f"chatcmpl-tool-{random_uuid()}") type: Literal["function"] = "function" function: FunctionCall + class ChatMessage(OpenAIBaseModel): role: str content: str diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 0558d6c95d97b..35f3806c16bf0 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,7 +1,8 @@ from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( get_lm_format_enforcer_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_decoding import ( @@ -13,6 +14,8 @@ async def get_guided_decoding_logits_processor( guided_decoding_backend: str, request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Optional[LogitsProcessor]: + request = _adapt_request_for_tool_use(request) + if guided_decoding_backend == 'outlines': return await get_outlines_guided_decoding_logits_processor( request, tokenizer) @@ -23,3 +26,33 @@ async def get_guided_decoding_logits_processor( raise ValueError( f"Unknown guided decoding backend '{guided_decoding_backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") + + +def _adapt_request_for_tool_use(request: Union[CompletionRequest, + ChatCompletionRequest]): + # the legacy completion API does not support tool use + if type(request) == CompletionRequest: + return request + + # user has chosen to not use any tool + if request.tool_choice == "none": + return request + + if request.tool_choice == "auto": + raise ValueError("Tool choice 'auto' is not yet supported by vLLM.") + + if request.tool_choice == "required": + raise ValueError( + "Tool choice 'required' is not yet supported by vLLM.") + + if type(request.tool_choice) == ChatCompletionNamedToolChoiceParam: + tool_name = request.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in request.tools} + if not tool_name in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in the `tools` parameter." + ) + tool = tools[tool_name] + request.guided_json = tool.parameters + + return request From 193e6ec4247ea614f31ebfb576635a10f95d4801 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 14:59:01 +0200 Subject: [PATCH 07/25] fix formatting complaint --- vllm/model_executor/guided_decoding/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 35f3806c16bf0..8c1ef59434f48 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -48,10 +48,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request.tool_choice) == ChatCompletionNamedToolChoiceParam: tool_name = request.tool_choice.function.name tools = {tool.function.name: tool.function for tool in request.tools} - if not tool_name in tools: + if tool_name not in tools: raise ValueError( - f"Tool '{tool_name}' has not been passed in the `tools` parameter." 
- ) + f"Tool '{tool_name}' has not been passed in `tools`.") tool = tools[tool_name] request.guided_json = tool.parameters From 46d5f2753d3a61c1eaf08ba0c7c5ed686bd116d9 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 15:58:34 +0200 Subject: [PATCH 08/25] correct output format and support streaming --- tests/entrypoints/test_openai_server.py | 33 ++++++++++++++++++----- vllm/entrypoints/openai/serving_chat.py | 35 +++++++++++++++++++++---- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 8d261c1daa7fd..ec0b68c6b1694 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -747,6 +747,9 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, f"Give an example JSON for an employee profile that " f"fits this schema: {TEST_SCHEMA}" }] + + # non-streaming + chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, @@ -766,8 +769,8 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, } }) message = chat_completion.choices[0].message - assert message.content is not None - json1 = json.loads(message.content) + assert len(message.content) == 0 + json1 = json.loads(message.tool_calls[0].function.arguments) jsonschema.validate(instance=json1, schema=TEST_SCHEMA) messages.append({"role": "assistant", "content": message.content}) @@ -777,7 +780,10 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, "content": "Give me another one with a different name and age" }) - chat_completion = await client.chat.completions.create( + + # streaming + + stream = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_tokens=1000, @@ -794,10 +800,23 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, "function": { "name": "dummy_function_name" } - }) - message = chat_completion.choices[0].message - assert message.content is not None - json2 = json.loads(message.content) + }, + stream=True) + + output = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + assert delta.content == None or len(delta.content) == 0 + if delta.tool_calls: + output.append(delta.tool_calls[0].function.arguments) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + json2 = json.loads("".join(output)) jsonschema.validate(instance=json2, schema=TEST_SCHEMA) assert json1["name"] != json2["name"] assert json1["age"] != json2["age"] diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 15ce7049e58e3..b58fde3f0a29a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,10 +12,11 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionContentPartParam, ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, - ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - UsageInfo) + FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, 
OpenAIServing) from vllm.logger import init_logger @@ -290,11 +291,22 @@ async def chat_completion_stream_generator( delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) + is_function_call = request.tool_choice and type( + request.tool_choice + ) is ChatCompletionNamedToolChoiceParam + delta_message = DeltaMessage( + content=delta_text + ) if not is_function_call else DeltaMessage(tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text)) + ]) if output.finish_reason is None: # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(content=delta_text), + delta=delta_message, logprobs=logprobs, finish_reason=None) chunk = ChatCompletionStreamResponse( @@ -316,7 +328,7 @@ async def chat_completion_stream_generator( ) choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(content=delta_text), + delta=delta_message, logprobs=logprobs, finish_reason=output.finish_reason) chunk = ChatCompletionStreamResponse( @@ -372,9 +384,22 @@ async def chat_completion_full_generator( else: logprobs = None + if request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( + role=role, + content="", + tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=output.text)) + ]) + elif not request.tool_choice or request.tool_choice == "none": + message = ChatMessage(role=role, content=output.text) + choice_data = ChatCompletionResponseChoice( index=output.index, - message=ChatMessage(role=role, content=output.text), + message=message, logprobs=logprobs, finish_reason=output.finish_reason) choices.append(choice_data) From b59e1b3fe6dc93bdf2bdf67efc435031fe39fdf8 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 15:59:54 +0200 Subject: [PATCH 09/25] fix ruff complaint --- tests/entrypoints/test_openai_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index ec0b68c6b1694..129942d76cec5 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -809,7 +809,7 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, delta = chunk.choices[0].delta if delta.role: assert delta.role == "assistant" - assert delta.content == None or len(delta.content) == 0 + assert delta.content is None or len(delta.content) == 0 if delta.tool_calls: output.append(delta.tool_calls[0].function.arguments) if chunk.choices[0].finish_reason is not None: From f0dc5b8fbfcb5707b75bafdfa7181ca6de752691 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 24 May 2024 16:03:02 +0200 Subject: [PATCH 10/25] fix mypy complaint --- vllm/entrypoints/openai/serving_chat.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b58fde3f0a29a..1e419bb9b8c42 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -291,16 +291,18 @@ async def chat_completion_stream_generator( delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) - is_function_call = request.tool_choice and type( - request.tool_choice - ) is 
ChatCompletionNamedToolChoiceParam - delta_message = DeltaMessage( - content=delta_text - ) if not is_function_call else DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( - name=request.tool_choice.function.name, - arguments=delta_text)) - ]) + + if request.tool_choice and type( + request.tool_choice + ) is ChatCompletionNamedToolChoiceParam: + delta_message = DeltaMessage(tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=delta_text)) + ]) + else: + delta_message = DeltaMessage(content=delta_text) + if output.finish_reason is None: # Send token-by-token response for each request.n From 80e66cf1f82d63e48f21f503afe0cbd6d5b92994 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 08:58:55 +0200 Subject: [PATCH 11/25] reverting removal of --- vllm/entrypoints/openai/protocol.py | 2 ++ vllm/entrypoints/openai/serving_chat.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1b5461d9935f6..8e66ce1646354 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -511,6 +511,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + stop_reason: Optional[Union[int, str]] = None class ChatCompletionResponse(OpenAIBaseModel): @@ -532,6 +533,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel): delta: DeltaMessage logprobs: Optional[ChatCompletionLogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None + stop_reason: Optional[Union[int, str]] = None class ChatCompletionStreamResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 15ce7049e58e3..6cec55aecd2f3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -318,7 +318,8 @@ async def chat_completion_stream_generator( index=i, delta=DeltaMessage(content=delta_text), logprobs=logprobs, - finish_reason=output.finish_reason) + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, @@ -376,7 +377,8 @@ async def chat_completion_full_generator( index=output.index, message=ChatMessage(role=role, content=output.text), logprobs=logprobs, - finish_reason=output.finish_reason) + finish_reason=output.finish_reason, + stop_reason=output.stop_reason) choices.append(choice_data) if request.echo: From 3ca5fce5bf1ea6898060aef2d4e99077c8dda3b5 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 08:59:56 +0200 Subject: [PATCH 12/25] =?UTF-8?q?refactoring=20=E2=80=93=20move=20'create?= =?UTF-8?q?=5Flogprobs'=20for=20completion=20out=20of=20serving=5Fengine.p?= =?UTF-8?q?y?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vllm/entrypoints/openai/serving_completion.py | 45 ++++++++++++++++++ vllm/entrypoints/openai/serving_engine.py | 47 +------------------ 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 158d8ed7fbbf5..b3a21d3c32211 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -18,6 +18,7 @@ from 
vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -346,3 +347,47 @@ def request_output_to_completion_response( choices=choices, usage=usage, ) + + def _create_logprobs( + self, + token_ids: List[int], + top_logprobs: List[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, + ) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + if num_output_top_logprobs: + logprobs.top_logprobs = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(None) + assert logprobs.top_logprobs is not None + logprobs.top_logprobs.append(None) + else: + token_logprob = step_top_logprobs[token_id].logprob + token = step_top_logprobs[token_id].decoded_token + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + + if num_output_top_logprobs: + assert logprobs.top_logprobs is not None + logprobs.top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + p.decoded_token: max(p.logprob, -9999.0) + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + return logprobs diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index db3fc85decd70..b161b1d85db2e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,11 +11,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ErrorResponse, - LogProbs, ModelCard, ModelList, + ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) @@ -75,50 +74,6 @@ async def show_available_models(self) -> ModelList: model_cards.extend(lora_cards) return ModelList(data=model_cards) - def _create_logprobs( - self, - token_ids: List[int], - top_logprobs: List[Optional[Dict[int, Logprob]]], - num_output_top_logprobs: Optional[int] = None, - initial_text_offset: int = 0, - ) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() - last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] - - for i, token_id in enumerate(token_ids): - step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None: - token = self.tokenizer.decode(token_id) - logprobs.tokens.append(token) - logprobs.token_logprobs.append(None) - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append(None) - else: - token_logprob = step_top_logprobs[token_id].logprob - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) - - if num_output_top_logprobs: - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - 
p.decoded_token: max(p.logprob, -9999.0) - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) - else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) - last_token_len = len(token) - return logprobs - def create_error_response( self, message: str, From 06519c7dd233c95b9ecc0911386d55844e5e04c3 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 09:00:25 +0200 Subject: [PATCH 13/25] fix formatting --- vllm/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b3a21d3c32211..5da2c41d8e198 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -347,7 +347,7 @@ def request_output_to_completion_response( choices=choices, usage=usage, ) - + def _create_logprobs( self, token_ids: List[int], From 3c5457ac8873b673d28a143154be79f51ca85313 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 14:30:53 +0200 Subject: [PATCH 14/25] adding changes after review from @DarkLight1337 --- tests/entrypoints/test_openai_server.py | 190 ++++++++++++++++-- vllm/entrypoints/openai/protocol.py | 32 ++- vllm/entrypoints/openai/serving_chat.py | 14 +- vllm/entrypoints/openai/serving_completion.py | 82 ++++---- vllm/entrypoints/openai/serving_engine.py | 6 + 5 files changed, 255 insertions(+), 69 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index fa43b064c4ab6..2c026d3406948 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -178,6 +178,25 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -196,7 +215,70 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, choice = completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 1 + + +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, + 
model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 @pytest.mark.parametrize( @@ -245,9 +327,89 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=False) + + choice = chat_completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=0) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 1 + + +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=5) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 6 + + @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, - model_name: str): +async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -256,13 +418,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, "content": "what is 1+1?" 
}] - # Default max_logprobs is 5, so this should raise an error + # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): stream = await client.chat.completions.create(model=model_name, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=21, stream=True) async for chunk in stream: ... @@ -272,25 +434,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=30, stream=False) - with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=True) - async for chunk in stream: - ... - - with pytest.raises(openai.BadRequestError): - await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=False) - # the server should still work afterwards chat_completion = await client.chat.completions.create(model=model_name, messages=messages, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8e66ce1646354..d59ce1252d610 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -192,8 +192,6 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def to_sampling_params(self) -> SamplingParams: - if self.logprobs and not self.top_logprobs: - raise ValueError("Top logprobs must be set when logprobs is.") logits_processors = None if self.logit_bias: @@ -251,6 +249,19 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "top_logprobs" in data and data["top_logprobs"] is not None: + if "logprobs" not in data or data["logprobs"] is False: + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." 
+ ) + elif not 0 <= data["top_logprobs"] <= 20: + raise ValueError( + "`top_logprobs` must be a value in the interval [0, 20].") + return data + class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -397,6 +408,15 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 0 <= data["logprobs"] <= 5: + raise ValueError(("if passed, `logprobs` must be a value", + " in the interval [0, 5].")) + return data + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation @@ -416,7 +436,7 @@ def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) -class LogProbs(OpenAIBaseModel): +class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) @@ -426,7 +446,7 @@ class LogProbs(OpenAIBaseModel): class CompletionResponseChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -449,7 +469,7 @@ class CompletionResponse(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -503,7 +523,7 @@ class ChatCompletionLogProbsContent(OpenAIBaseModel): class ChatCompletionLogProbs(OpenAIBaseModel): - content: List[ChatCompletionLogProbsContent] = Field(default_factory=list) + content: Optional[List[ChatCompletionLogProbsContent]] = None class ChatCompletionResponseChoice(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6cec55aecd2f3..4f7ea13eaac62 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -414,9 +414,11 @@ def _get_top_logprobs( top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb( - token=p.decoded_token, + token=self._get_decoded_token_from_logprob(p), logprob=max(p.logprob, -9999.0), - bytes=list(p.decoded_token.encode("utf-8", errors="replace"))) + bytes=list( + self._get_decoded_token_from_logprob(p).encode( + "utf-8", errors="replace"))) for i, p in enumerate(logprobs.values()) if top_logprobs and i < top_logprobs ] @@ -429,19 +431,19 @@ def _create_chat_logprobs( ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" - logprobs = ChatCompletionLogProbs() + logprobs_content = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: - logprobs.content.append( + logprobs_content.append( ChatCompletionLogProbsContent( token=self.tokenizer.decode(token_id), bytes=list( self.tokenizer.decode(token_id).encode( "utf-8", errors="replace")))) else: - logprobs.content.append( + logprobs_content.append( ChatCompletionLogProbsContent( token=step_top_logprobs[token_id].decoded_token, logprob=max(step_top_logprobs[token_id].logprob, @@ -452,4 +454,4 @@ def _create_chat_logprobs( top_logprobs=self._get_top_logprobs( step_top_logprobs, num_output_top_logprobs))) 
- return logprobs + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 5da2c41d8e198..76a1fcc443189 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -6,12 +6,13 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import (CompletionRequest, +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - LogProbs, UsageInfo) + UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger @@ -26,7 +27,7 @@ TypeTokenIDs = List[int] TypeTopLogProbs = List[Optional[Dict[int, float]]] TypeCreateLogProbsFn = Callable[ - [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] def parse_prompt_format(prompt) -> Tuple[bool, list]: @@ -231,7 +232,7 @@ async def completion_stream_generator( i]:] if output.logprobs else None if request.logprobs is not None: - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -313,7 +314,7 @@ def request_output_to_completion_response( assert top_logprobs is not None, ( "top_logprobs must be provided when logprobs " "is requested") - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -348,46 +349,57 @@ def request_output_to_completion_response( usage=usage, ) - def _create_logprobs( + def _create_completion_logprobs( self, token_ids: List[int], top_logprobs: List[Optional[Dict[int, Logprob]]], - num_output_top_logprobs: Optional[int] = None, + num_output_top_logprobs: int, initial_text_offset: int = 0, - ) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: List[int] = [] + out_token_logprobs: List[Optional[float]] = [] + out_tokens: List[str] = [] + out_top_logprobs: List[Optional[Dict[str, float]]] = [] + last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: token = self.tokenizer.decode(token_id) - logprobs.tokens.append(token) - logprobs.token_logprobs.append(None) - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append(None) + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) else: - token_logprob = step_top_logprobs[token_id].logprob - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) - - if num_output_top_logprobs: - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - p.decoded_token: max(p.logprob, -9999.0) - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) + token = 
self._get_decoded_token_from_logprob( + step_top_logprobs[token_id]) + token_logprob = max(step_top_logprobs[token_id].logprob, + -9999.0) + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token_from_logprob(top_lp[1]): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) + out_text_offset.append(out_text_offset[-1] + last_token_len) last_token_len = len(token) - return logprobs + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b161b1d85db2e..b13eb008b91db 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -15,6 +15,7 @@ ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) @@ -187,3 +188,8 @@ def _validate_prompt_and_tokenize( f"Please reduce the length of the messages or completion.", ) else: return input_ids, input_text + + def _get_decoded_token_from_logprob(self, logprob: Logprob) -> str: + if logprob.decoded_token is not None: + return logprob.decoded_token + return self.tokenizer.decode(logprob.token_id) From c37d5a970b5ac13f8274fea38c90acdd86b743bd Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 21:37:09 +0200 Subject: [PATCH 15/25] review iteration 2 --- vllm/entrypoints/openai/protocol.py | 7 ++----- vllm/entrypoints/openai/serving_chat.py | 19 +++++++++++-------- vllm/entrypoints/openai/serving_completion.py | 15 +++++++++------ vllm/entrypoints/openai/serving_engine.py | 4 ++-- 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d59ce1252d610..611b8f17ae4cf 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -511,14 +511,11 @@ class ChatMessage(OpenAIBaseModel): class ChatCompletionLogProb(OpenAIBaseModel): token: str - logprob: float + logprob: float = -9999.0 bytes: Optional[List[int]] = None -class ChatCompletionLogProbsContent(OpenAIBaseModel): - token: str - logprob: float = -9999.0 - bytes: Optional[List[int]] = None +class ChatCompletionLogProbsContent(ChatCompletionLogProb): top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4f7ea13eaac62..d60371e932733 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -2,7 +2,9 @@ import time from dataclasses import dataclass from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, - Optional, TypedDict, Union, cast, final) + Optional) +from typing import Sequence as 
GenericSequence +from typing import TypedDict, Union, cast, final from fastapi import Request from openai.types.chat import ChatCompletionContentPartTextParam @@ -414,19 +416,20 @@ def _get_top_logprobs( top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb( - token=self._get_decoded_token_from_logprob(p), - logprob=max(p.logprob, -9999.0), + token=self._get_decoded_token(p[1], p[0]), + logprob=max(p[1].logprob, -9999.0), bytes=list( - self._get_decoded_token_from_logprob(p).encode( - "utf-8", errors="replace"))) - for i, p in enumerate(logprobs.values()) + self._get_decoded_token(p[1], + p[0]).encode("utf-8", + errors="replace"))) + for i, p in enumerate(logprobs.items()) if top_logprobs and i < top_logprobs ] def _create_chat_logprobs( self, - token_ids: List[int], - top_logprobs: List[Optional[Dict[int, Logprob]]], + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 76a1fcc443189..7bc5a127fa16a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,6 +1,8 @@ import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional, Tuple) + Optional) +from typing import Sequence as GenericSequence +from typing import Tuple from fastapi import Request @@ -351,8 +353,8 @@ def request_output_to_completion_response( def _create_completion_logprobs( self, - token_ids: List[int], - top_logprobs: List[Optional[Dict[int, Logprob]]], + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], num_output_top_logprobs: int, initial_text_offset: int = 0, ) -> CompletionLogProbs: @@ -372,8 +374,8 @@ def _create_completion_logprobs( out_token_logprobs.append(None) out_top_logprobs.append(None) else: - token = self._get_decoded_token_from_logprob( - step_top_logprobs[token_id]) + token = self._get_decoded_token(step_top_logprobs[token_id], + token_id) token_logprob = max(step_top_logprobs[token_id].logprob, -9999.0) out_tokens.append(token) @@ -386,7 +388,7 @@ def _create_completion_logprobs( out_top_logprobs.append({ # Convert float("-inf") to the # JSON-serializable float that OpenAI uses - self._get_decoded_token_from_logprob(top_lp[1]): + self._get_decoded_token(top_lp[1], top_lp[0]): max(top_lp[1].logprob, -9999.0) for i, top_lp in enumerate(step_top_logprobs.items()) if num_output_top_logprobs >= i @@ -397,6 +399,7 @@ def _create_completion_logprobs( else: out_text_offset.append(out_text_offset[-1] + last_token_len) last_token_len = len(token) + return CompletionLogProbs( text_offset=out_text_offset, token_logprobs=out_token_logprobs, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b13eb008b91db..f0fa612d31b0e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -189,7 +189,7 @@ def _validate_prompt_and_tokenize( else: return input_ids, input_text - def _get_decoded_token_from_logprob(self, logprob: Logprob) -> str: + def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str: if logprob.decoded_token is not None: return logprob.decoded_token - return self.tokenizer.decode(logprob.token_id) + return self.tokenizer.decode(token_id) From 
adcdc3118c3d3078135623fd201fe7fb67aeb736 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Tue, 28 May 2024 21:49:44 +0200 Subject: [PATCH 16/25] =?UTF-8?q?formatting=20=E2=80=93=20isort=20breaks?= =?UTF-8?q?=20it=20again..=3F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vllm/entrypoints/openai/serving_completion.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7bc5a127fa16a..cfa2e8ae9f177 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,13 +8,10 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import (CompletionLogProbs, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - UsageInfo) +from vllm.entrypoints.openai.protocol import ( + CompletionLogProbs, CompletionRequest, CompletionResponse, + CompletionResponseChoice, CompletionResponseStreamChoice, + CompletionStreamResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger From 91b4cfa4d049fae7a8732c7dc6853fe97b833704 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Wed, 29 May 2024 09:27:10 +0200 Subject: [PATCH 17/25] disable yapf in import to avoid conflict with isort --- vllm/entrypoints/openai/serving_completion.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index cfa2e8ae9f177..e52442b62c0e9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -8,10 +8,15 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import ( - CompletionLogProbs, CompletionRequest, CompletionResponse, - CompletionResponseChoice, CompletionResponseStreamChoice, - CompletionStreamResponse, UsageInfo) +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + UsageInfo) +# yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger From 496fb25e1bb04aaf35adfbe1432e8aa9a233b0d9 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Wed, 29 May 2024 09:47:24 +0200 Subject: [PATCH 18/25] fix formatting --- tests/entrypoints/test_openai_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6604ea13ffd91..972137030f46f 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -183,6 +183,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, assert completion.choices[0].text is not None and len( completion.choices[0].text) >= 5 + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras @@ -202,6 +203,7 @@ async def test_no_logprobs(server, client: openai.AsyncOpenAI, choice = completion.choices[0] assert choice.logprobs is None + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, 
then test loras @@ -224,6 +226,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is not None assert len(choice.logprobs.top_logprobs[0]) <= 1 + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", From e7c74507d12646f1b1ddf9cd1eda3731f836024e Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 30 May 2024 21:40:03 +0200 Subject: [PATCH 19/25] remove tool_choice 'required' --- vllm/entrypoints/openai/protocol.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 5369c25093cd5..f1408592a8e2e 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -142,7 +142,7 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Literal["none", "required"], + tool_choice: Optional[Union[Literal["none"], ChatCompletionNamedToolChoiceParam]] = "none" user: Optional[str] = None @@ -267,14 +267,28 @@ def check_guided_decoding_count(cls, data): "guided_regex" in data and data["guided_regex"] is not None, "guided_choice" in data and data["guided_choice"] is not None ]) - if guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none": - raise ValueError( - "You can only either use guided decoding or tools, not both.") + # you can only use one kind of guided decoding if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and "tool_choice" in data and data[ + "tool_choice"] != "none": + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_choice(cls, data): + if "tool_choice" in data and data["tool_choice"] != "none": + if type(data["tool_choice"] + ) is not ChatCompletionNamedToolChoiceParam: + raise ValueError("Currently only named tools are supported.") + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") return data @model_validator(mode="before") From b77e60a046eeda8442e1d91cafb1ae951d02bca7 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 30 May 2024 21:54:33 +0200 Subject: [PATCH 20/25] add sad path test --- tests/entrypoints/test_openai_server.py | 47 +++++++++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 3 +- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1af9317554d51..6892b2d343ec7 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -996,6 +996,53 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", ["outlines"]) +async def test_required_tool_use_not_yet_supported( + server, client: openai.AsyncOpenAI, guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + + with 
pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice="required") + ... + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice="auto") + + @pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f1408592a8e2e..f62eead82ac0f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -283,8 +283,7 @@ def check_guided_decoding_count(cls, data): @classmethod def check_tool_choice(cls, data): if "tool_choice" in data and data["tool_choice"] != "none": - if type(data["tool_choice"] - ) is not ChatCompletionNamedToolChoiceParam: + if not isinstance(data["tool_choice"], dict): raise ValueError("Currently only named tools are supported.") if "tools" not in data or data["tools"] is None: raise ValueError( From 9f33687c1a200cebb516e64cea83d3373bc1ef66 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 30 May 2024 22:04:44 +0200 Subject: [PATCH 21/25] add more sad path tests --- tests/entrypoints/test_openai_server.py | 50 +++++++++++++++++++ .../guided_decoding/__init__.py | 8 +-- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6892b2d343ec7..0187856c0ebb4 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1043,6 +1043,56 @@ async def test_required_tool_use_not_yet_supported( tool_choice="auto") +@pytest.mark.asyncio +@pytest.mark.parametrize("guided_decoding_backend", ["outlines"]) +async def test_inconsistent_tool_choice_and_tools( + server, client: openai.AsyncOpenAI, guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + f"Give an example JSON for an employee profile that " + f"fits this schema: {TEST_SCHEMA}" + }] + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tool_choice={ + "type": "function", + "function": { + "name": + "dummy_function_name" + } + }) + ... + + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": TEST_SCHEMA + } + }], + tool_choice={ + "type": "function", + "function": { + "name": "nondefined_function_name" + } + }) + ... 
+ + @pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 8c1ef59434f48..1f94d30e34052 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -38,13 +38,7 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if request.tool_choice == "none": return request - if request.tool_choice == "auto": - raise ValueError("Tool choice 'auto' is not yet supported by vLLM.") - - if request.tool_choice == "required": - raise ValueError( - "Tool choice 'required' is not yet supported by vLLM.") - + # user has chosen to use a named tool if type(request.tool_choice) == ChatCompletionNamedToolChoiceParam: tool_name = request.tool_choice.function.name tools = {tool.function.name: tool.function for tool in request.tools} From 5f0c3ae8c12e925eb019524c56a3f8ba22383bc6 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 31 May 2024 11:20:02 +0200 Subject: [PATCH 22/25] fix test --- tests/entrypoints/test_openai_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 0187856c0ebb4..0143e0531f86e 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -947,7 +947,7 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, json1 = json.loads(message.tool_calls[0].function.arguments) jsonschema.validate(instance=json1, schema=TEST_SCHEMA) - messages.append({"role": "assistant", "content": message.content}) + messages.append({"role": "assistant", "content": json1}) messages.append({ "role": "user", From 15da872549830a4424f221a8aaadf2d2d2e23071 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Fri, 31 May 2024 17:10:31 +0200 Subject: [PATCH 23/25] fix test --- tests/entrypoints/test_openai_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 0143e0531f86e..c9c4f49abb9c2 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -944,10 +944,11 @@ async def test_named_tool_use(server, client: openai.AsyncOpenAI, }) message = chat_completion.choices[0].message assert len(message.content) == 0 - json1 = json.loads(message.tool_calls[0].function.arguments) + json_string = message.tool_calls[0].function.arguments + json1 = json.loads(json_string) jsonschema.validate(instance=json1, schema=TEST_SCHEMA) - messages.append({"role": "assistant", "content": json1}) + messages.append({"role": "assistant", "content": json_string}) messages.append({ "role": "user", From bdf0dcf120e8bc70c840af5fedff4b829a9cee10 Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Mon, 3 Jun 2024 21:32:34 +0200 Subject: [PATCH 24/25] after review --- tests/entrypoints/test_openai_server.py | 3 --- vllm/model_executor/guided_decoding/__init__.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index c9c4f49abb9c2..bff2487117837 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1026,7 +1026,6 @@ async def test_required_tool_use_not_yet_supported( } }], tool_choice="required") - ... 
     with pytest.raises(openai.BadRequestError):
@@ -1070,7 +1069,6 @@ async def test_inconsistent_tool_choice_and_tools(
                                                      "dummy_function_name"
                                                  }
                                              })
-        ...
 
     with pytest.raises(openai.BadRequestError):
         await client.chat.completions.create(
@@ -1091,7 +1089,6 @@ async def test_inconsistent_tool_choice_and_tools(
                     "name": "nondefined_function_name"
                 }
             })
-        ...
 
 
 @pytest.mark.asyncio
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 1f94d30e34052..50aa3ec379f4a 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -31,7 +31,7 @@ async def get_guided_decoding_logits_processor(
 def _adapt_request_for_tool_use(request: Union[CompletionRequest,
                                                ChatCompletionRequest]):
     # the legacy completion API does not support tool use
-    if type(request) == CompletionRequest:
+    if type(request) is CompletionRequest:
         return request
 
     # user has chosen to not use any tool
@@ -39,7 +39,7 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
         return request
 
     # user has chosen to use a named tool
-    if type(request.tool_choice) == ChatCompletionNamedToolChoiceParam:
+    if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
         tool_name = request.tool_choice.function.name
         tools = {tool.function.name: tool.function for tool in request.tools}
         if tool_name not in tools:

From 37130f7af400fd66f4a201fbe9f51cd2db0e1962 Mon Sep 17 00:00:00 2001
From: Breno Faria
Date: Mon, 3 Jun 2024 22:07:08 +0200
Subject: [PATCH 25/25] adding docs for named function calling in tool use

---
 docs/source/serving/openai_compatible_server.md | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 15a8761eb5738..a912949352b86 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
 :module: vllm.entrypoints.openai.cli_args
 :func: make_arg_parser
 :prog: -m vllm.entrypoints.openai.api_server
-```
\ No newline at end of file
+```
+
+## Tool calling in the chat completion API
+vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported**, but they are on the roadmap.
+
+To use a named function, define the function in the `tools` parameter and select it in the `tool_choice` parameter.
+
+It is the caller's responsibility to prompt the model with the tool information; vLLM will not automatically manipulate the prompt. **This may change in the future.**
+
+vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
+
+Please refer to the OpenAI API reference documentation for more information.
\ No newline at end of file
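
For illustration, here is a minimal client-side sketch of the named-tool-call flow described in the documentation above. It mirrors the request shape exercised by the `test_named_tool_use` test added earlier in this series; the model name, schema, server address, and function name below are placeholders, not values taken from the patches.

```python
# Sketch of a named tool call against a running vLLM OpenAI-compatible server.
# Assumes the server is reachable at localhost:8000 and that MODEL_NAME
# matches the model it was started with; both are placeholders.
import json

import jsonschema
import openai

MODEL_NAME = "my-served-model"  # placeholder
PERSON_SCHEMA = {  # placeholder JSON schema for the tool's parameters
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model=MODEL_NAME,
    messages=[
        {"role": "system", "content": "you are a helpful assistant"},
        # The caller is responsible for describing the tool in the prompt.
        {"role": "user",
         "content": f"Give an example JSON for a person that fits this "
                    f"schema: {PERSON_SCHEMA}"},
    ],
    tools=[{
        "type": "function",
        "function": {
            "name": "describe_person",
            "description": "Build a person record",
            "parameters": PERSON_SCHEMA,
        },
    }],
    # Only named tool choice is accepted; "auto" and "required" return a 400.
    tool_choice={"type": "function", "function": {"name": "describe_person"}},
)

# Guided decoding constrains the arguments to the declared JSON schema.
arguments = chat_completion.choices[0].message.tool_calls[0].function.arguments
jsonschema.validate(instance=json.loads(arguments), schema=PERSON_SCHEMA)
print(arguments)
```

Requests that set `tool_choice` to `"auto"` or `"required"`, or that name a function not listed in `tools`, are rejected by the validators added in this series and surface as `openai.BadRequestError` on the client side, as covered by the sad-path tests above.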