[Bugfix][CI/Build] Test prompt adapters in openai entrypoint tests #6419

Merged 8 commits on Jul 16, 2024
Changes from 5 commits
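To reproduce the CI check locally, the updated suite can be run directly with pytest. The invocation below is an assumed local workflow, not something added by this PR; note that running it will download the zephyr model and adapter weights from the Hub.

# Assumed local invocation of the updated entrypoint tests (not part of this
# diff); CI uses its own wiring to run the same file.
import pytest

pytest.main(["tests/entrypoints/openai/test_completion.py", "-q"])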
tests/entrypoints/openai/test_completion.py (39 additions, 25 deletions)
@@ -17,9 +17,10 @@

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
-# technically this needs Mistral-7B-v0.1 as base, but we're not testing
-# generation quality here
+# technically these adapters use a different base model,
+# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
+PA_NAME = "swapnilbp/llama_tweet_ptune"


@pytest.fixture(scope="module")
@@ -28,7 +29,12 @@ def zephyr_lora_files():


@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def zephyr_pa_files():
+    return snapshot_download(repo_id=PA_NAME)
+
+
+@pytest.fixture(scope="module")
+def server(zephyr_lora_files, zephyr_pa_files):
with RemoteOpenAIServer([
"--model",
MODEL_NAME,
@@ -37,8 +43,10 @@ def server(zephyr_lora_files):
"bfloat16",
"--max-model-len",
"8192",
+"--max-num-seqs",
+"128",
"--enforce-eager",
-# lora config below
+# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
@@ -47,7 +55,14 @@ def server(zephyr_lora_files):
"64",
"--max-cpu-loras",
"2",
-"--max-num-seqs",
+# pa config
+"--enable-prompt-adapter",
+"--prompt-adapters",
+f"zephyr-pa={zephyr_pa_files}",
+f"zephyr-pa2={zephyr_pa_files}",
+"--max-prompt-adapters",
+"2",
+"--max-prompt-adapter-token",
"128",
]) as remote_server:
yield remote_server
@@ -60,9 +75,9 @@ def client(server):

@pytest.mark.asyncio
@pytest.mark.parametrize(
-# first test base model, then test loras
+# first test base model, then test loras, then test prompt adapters
"model_name",
-[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
@@ -81,7 +96,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):

# test using token IDs
completion = await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -91,14 +106,14 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize(
-# first test base model, then test loras
+# first test base model, then test loras, then test prompt adapters
"model_name",
-[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -110,14 +125,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize(
-# just test 1 lora hereafter
+# just test 1 lora and 1 pa hereafter
"model_name",
-[MODEL_NAME, "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -133,12 +148,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
-[MODEL_NAME, "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -154,15 +169,15 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
-[MODEL_NAME, "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):

with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -174,7 +189,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
-model=MODEL_NAME,
+model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -199,7 +214,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
-[MODEL_NAME, "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
@@ -233,7 +248,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
-["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
@@ -369,9 +384,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
@pytest.mark.parametrize(
-# just test 1 lora hereafter
"model_name",
-[MODEL_NAME, "zephyr-lora"],
+[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
@@ -623,7 +637,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
)
async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3].strip("/")
-tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "This is a test prompt."
@@ -650,7 +664,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3]
-tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
+tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=False)
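As an aside, the new "zephyr-pa" / "zephyr-pa2" parametrizations mean prompt adapters are addressed through the OpenAI-compatible API exactly like the LoRA names. Below is a minimal client-side sketch of what these tests exercise; the base URL, port, and API key are assumptions for a locally running server, not values taken from this PR.

# Minimal sketch: query a prompt adapter registered at startup via
# --enable-prompt-adapter / --prompt-adapters zephyr-pa=<path>, using its
# name as the model. Server address and key below are assumptions.
import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # assumed local vLLM OpenAI server
    api_key="EMPTY",                      # placeholder; not checked by a local server
)

completion = client.completions.create(
    model="zephyr-pa",          # adapter name from --prompt-adapters
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)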
vllm/entrypoints/openai/serving_engine.py (3 additions, 2 deletions)
@@ -1,4 +1,5 @@
import json
+import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -74,8 +75,8 @@ def __init__(
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
-with open(f"./{prompt_adapter.local_path}"
-    f"/adapter_config.json") as f:
+with pathlib.Path(prompt_adapter.local_path,
+    "adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
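A note on the serving_engine.py change: snapshot_download returns an absolute path, and prefixing an absolute path with "./" silently turns it into a path relative to the current working directory, which the pathlib.Path join avoids. The snippet below is an illustration with a hypothetical path, not code from this PR.

# Illustration only (hypothetical path): why joining with pathlib.Path is
# safer than building the path with an f-string and a "./" prefix.
import pathlib

local_path = "/tmp/adapters/zephyr-pa"  # hypothetical absolute adapter path

old_style = f"./{local_path}/adapter_config.json"
new_style = pathlib.Path(local_path, "adapter_config.json")

print(old_style)  # .//tmp/adapters/zephyr-pa/adapter_config.json -> resolved relative to CWD
print(new_style)  # /tmp/adapters/zephyr-pa/adapter_config.json   -> stays absolute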