Skip to content

Commit

Permalink
[FRONTEND] OpenAI tools support named functions (vllm-project#5032)
Browse files Browse the repository at this point in the history
  • Loading branch information
br3no authored and joerunde committed Jun 4, 2024
1 parent b7de754 commit aa19635
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 11 deletions.
13 changes: 12 additions & 1 deletion docs/source/serving/openai_compatible_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,15 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/)
:module: vllm.entrypoints.openai.cli_args
:func: make_arg_parser
:prog: -m vllm.entrypoints.openai.api_server
```
```

## Tool calling in the chat completion API
vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap.

To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter.

It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.**

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.

Please refer to the OpenAI API reference documentation for more information.
185 changes: 185 additions & 0 deletions tests/entrypoints/test_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,191 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token in top_logprobs)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

# non-streaming

chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
})
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
json1 = json.loads(json_string)
jsonschema.validate(instance=json1, schema=TEST_SCHEMA)

messages.append({"role": "assistant", "content": json_string})
messages.append({
"role":
"user",
"content":
"Give me another one with a different name and age"
})

# streaming

stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "dummy_function_name"
}
},
stream=True)

output = []
finish_reason_count = 0
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert delta.role == "assistant"
assert delta.content is None or len(delta.content) == 0
if delta.tool_calls:
output.append(delta.tool_calls[0].function.arguments)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish reason should only return in last block
assert finish_reason_count == 1
json2 = json.loads("".join(output))
jsonschema.validate(instance=json2, schema=TEST_SCHEMA)
assert json1["name"] != json2["name"]
assert json1["age"] != json2["age"]


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="required")

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice="auto")


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role":
"user",
"content":
f"Give an example JSON for an employee profile that "
f"fits this schema: {TEST_SCHEMA}"
}]

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tool_choice={
"type": "function",
"function": {
"name":
"dummy_function_name"
}
})

with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1000,
tools=[{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": TEST_SCHEMA
}
}],
tool_choice={
"type": "function",
"function": {
"name": "nondefined_function_name"
}
})


@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
[sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
Expand Down
57 changes: 55 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ class ResponseFormat(OpenAIBaseModel):
type: Literal["text", "json_object"]


class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
Expand All @@ -122,6 +142,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"],
ChatCompletionNamedToolChoiceParam]] = "none"
user: Optional[str] = None

# doc: begin-chat-completion-sampling-params
Expand Down Expand Up @@ -245,10 +268,27 @@ def check_guided_decoding_count(cls, data):
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and "tool_choice" in data and data[
"tool_choice"] != "none":
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data

@model_validator(mode="before")
@classmethod
def check_tool_choice(cls, data):
if "tool_choice" in data and data["tool_choice"] != "none":
if not isinstance(data["tool_choice"], dict):
raise ValueError("Currently only named tools are supported.")
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
return data

@model_validator(mode="before")
Expand Down Expand Up @@ -506,9 +546,21 @@ class EmbeddingResponse(BaseModel):
usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
name: str
arguments: str


class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall


class ChatMessage(OpenAIBaseModel):
role: str
content: str
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
Expand All @@ -535,7 +587,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):

class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
Expand All @@ -545,6 +597,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
Expand All @@ -557,7 +610,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):

class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
Expand Down
Loading

0 comments on commit aa19635

Please sign in to comment.