Skip to content

Commit

Permalink
[CI] Expand OpenAI test_chat.py guided decoding tests (vllm-project#11048)
Browse files Browse the repository at this point in the history

Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin authored Dec 23, 2024
1 parent 8cef6e0 commit 63afbe9
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions tests/entrypoints/openai/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
Expand Down Expand Up @@ -464,8 +466,7 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
# will fail on the second `guided_decoding_backend` even when I swap their order
# (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
Expand Down Expand Up @@ -506,8 +507,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
Expand Down Expand Up @@ -554,8 +554,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str, sample_regex):
messages = [{
Expand Down Expand Up @@ -613,8 +612,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
Expand Down Expand Up @@ -646,8 +644,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_named_tool_use(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
Expand Down Expand Up @@ -681,7 +678,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"function": {
"name": "dummy_function_name"
}
})
},
extra_body=dict(guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert len(message.content) == 0
json_string = message.tool_calls[0].function.arguments
Expand Down Expand Up @@ -716,6 +714,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
"name": "dummy_function_name"
}
},
extra_body=dict(guided_decoding_backend=guided_decoding_backend),
stream=True)

output = []
Expand All @@ -738,10 +737,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
client: openai.AsyncOpenAI, guided_decoding_backend: str,
sample_json_schema):
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
sample_json_schema):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
Expand Down Expand Up @@ -785,9 +782,7 @@ async def test_required_tool_use_not_yet_supported(


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
messages = [{
"role": "system",
Expand Down

0 comments on commit 63afbe9

Please sign in to comment.