[FRONTEND] OpenAI tools support named functions #5032

Merged Jun 3, 2024 · 30 commits (changes shown are from 28 commits)

Commits
099ddb3  WIP added tool classes (br3no, May 23, 2024)
5934e22  added correct models. Tests still missing (br3no, May 23, 2024)
1248bc1  fix implementation and tests (br3no, May 24, 2024)
1b2b453  fix formatting (br3no, May 24, 2024)
07af0cc  fix test (br3no, May 24, 2024)
49b560c  Merge branch '5008-chat-logprobs' into 1869-tools-support-step-1 (br3no, May 24, 2024)
755625f  named tool working (br3no, May 24, 2024)
193e6ec  fix formatting complaint (br3no, May 24, 2024)
46d5f27  correct output format and support streaming (br3no, May 24, 2024)
b59e1b3  fix ruff complaint (br3no, May 24, 2024)
f0dc5b8  fix mypy complaint (br3no, May 24, 2024)
80e66cf  reverting removal of (br3no, May 28, 2024)
3ca5fce  refactoring – move 'create_logprobs' for completion out of serving_en… (br3no, May 28, 2024)
06519c7  fix formatting (br3no, May 28, 2024)
3c5457a  adding changes after review from @DarkLight1337 (br3no, May 28, 2024)
e388194  Merge branch 'main' into 5008-chat-logprobs (br3no, May 28, 2024)
c37d5a9  review iteration 2 (br3no, May 28, 2024)
adcdc31  formatting – isort breaks it again..? (br3no, May 28, 2024)
91b4cfa  disable yapf in import to avoid conflict with isort (br3no, May 29, 2024)
825e0ad  Merge branch 'main' into 5008-chat-logprobs (br3no, May 29, 2024)
496fb25  fix formatting (br3no, May 29, 2024)
111548a  formatting (br3no, May 29, 2024)
2d59282  Merge branch 'main' into 1869-tools-support-step-1 (br3no, May 30, 2024)
e7c7450  remove tool_choice 'required' (br3no, May 30, 2024)
b77e60a  add sad path test (br3no, May 30, 2024)
9f33687  add more sad path tests (br3no, May 30, 2024)
5f0c3ae  fix test (br3no, May 31, 2024)
15da872  fix test (br3no, May 31, 2024)
bdf0dcf  after review (br3no, Jun 3, 2024)
37130f7  adding docs for named function calling in tool use (br3no, Jun 3, 2024)
tests/entrypoints/test_openai_server.py (188 additions, 0 deletions)

@@ -906,6 +906,194 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token in top_logprobs)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

# non-streaming

chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
})
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

messages.append({"role": "assistant", "content": json_string})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})

# streaming

stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
},
stream=True)

output = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
assert delta.content is None or len(delta.content) == 0
if delta.tool_calls:
output.append(delta.tool_calls[0].function.arguments)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="required")

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tool_choice={
"type": "function",
"function": {
"name":
"dummy_function_name"
}
})

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "nondefined_function_name"
}
})


@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
tests/utils.py (2 additions, 1 deletion)

@@ -24,7 +24,8 @@ def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
vllm/entrypoints/openai/protocol.py (55 additions, 2 deletions)

@@ -101,6 +101,26 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
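
For reference, a quick sketch (not part of this PR's diff) of how the new tool models fit together. It assumes pydantic v2's `model_validate`; the function name and parameter schema are made up for illustration:

# Illustrative only: exercising the tool models added in this PR.
from vllm.entrypoints.openai.protocol import (
    ChatCompletionNamedToolChoiceParam, ChatCompletionToolsParam)

tool = ChatCompletionToolsParam.model_validate({
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical function name
        "description": "Look up the weather for a city",
        "parameters": {  # any JSON schema dict is accepted here
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
})
tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
    {"type": "function", "function": {"name": "get_weather"}})
assert tool_choice.function.name == tool.function.name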


class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -121,6 +141,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None

# doc: begin-chat-completion-sampling-params
@@ -244,10 +267,27 @@ def check_guided_decoding_count(cls, data):
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data

@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data
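
As a quick illustration (not part of the diff) of what `check_tool_choice` rejects — assuming pydantic v2, which wraps the validator's `ValueError` in a `ValidationError`; the model name is a placeholder:

# Illustrative only: a named tool_choice without a matching `tools` list.
import pydantic
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest.model_validate({
        "model": "dummy-model",  # placeholder model name
        "messages": [{"role": "user", "content": "hi"}],
        # tool_choice is given, but `tools` is missing -> rejected.
        "tool_choice": {"type": "function",
                        "function": {"name": "get_weather"}},
    })
except pydantic.ValidationError as err:
    print(err)  # mentions: "When using `tool_choice`, `tools` must be set."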

@model_validator(mode="before")
@@ -505,9 +545,21 @@ class EmbeddingResponse(BaseModel):
usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
name: str
arguments: str


class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall


class ChatMessage(OpenAIBaseModel):
role: str
content: str
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
@@ -534,7 +586,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):

class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
- object: str = "chat.completion"
+ object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
@@ -544,6 +596,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
@@ -556,7 +609,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
- object: str = "chat.completion.chunk"
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
vllm/entrypoints/openai/serving_chat.py (32 additions, 5 deletions)

@@ -14,10 +14,11 @@
from vllm.entrypoints.openai.protocol import (
ChatCompletionContentPartParam, ChatCompletionLogProb,
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
- ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
+ ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam,
+ ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
- UsageInfo)
+ FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
@@ -298,11 +299,24 @@ async def chat_completion_stream_generator(
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)

if request.tool_choice and type(
request.tool_choice
) is ChatCompletionNamedToolChoiceParam:
delta_message = DeltaMessage(tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=delta_text))
])
else:
delta_message = DeltaMessage(content=delta_text)

if output.finish_reason is None:
# Send token-by-token response for each request.n

choice_data = ChatCompletionResponseStreamChoice(
index=i,
- delta=DeltaMessage(content=delta_text),
+ delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
Expand All @@ -324,7 +338,7 @@ async def chat_completion_stream_generator(
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
- delta=DeltaMessage(content=delta_text),
+ delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
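
With this change, a client consuming the stream receives the function arguments incrementally in `delta.tool_calls[0].function.arguments` instead of `delta.content`. A minimal reassembly sketch (not part of the diff; mirrors the streaming test above and assumes the `openai` Python client):

# Illustrative only: rebuilding the streamed tool-call arguments client-side.
import json

async def collect_tool_arguments(stream) -> dict:
    """Concatenate streamed argument fragments into one JSON object."""
    fragments = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.tool_calls:
            fragments.append(delta.tool_calls[0].function.arguments)
    return json.loads("".join(fragments))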
@@ -381,9 +395,22 @@ async def chat_completion_full_generator(
else:
logprobs = None

if request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))

This comment was marked as resolved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outlines or lm-format-enforcer guarantee that for us.

])
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)

choice_data = ChatCompletionResponseChoice(
index=output.index,
- message=ChatMessage(role=role, content=output.text),
+ message=message,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
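
In the non-streaming path, the guided-decoded JSON therefore lands in `message.tool_calls` with empty `content`. A sketch of the resulting response body (not part of the diff; ids, model name, and argument values are illustrative), written as a Python dict:

# Illustrative only: shape of a non-streaming named-tool-call response.
response_shape = {
    "id": "chatcmpl-<random_uuid>",
    "object": "chat.completion",
    "model": "<served model name>",
    "choices": [{
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [{
                "id": "chatcmpl-tool-<random_uuid>",
                "type": "function",
                "function": {
                    "name": "dummy_function_name",
                    # The generated JSON arrives as a string:
                    "arguments": '{"name": "Alice", "age": 30}',
                },
            }],
        },
        "finish_reason": "stop",
    }],
}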