[Bugfix][CI/Build] Test prompt adapters in openai entrypoint tests (v…
g-eoj authored and dtrifiro committed Jul 17, 2024
1 parent 2efff72 commit e2f9a09
Showing 2 changed files with 54 additions and 31 deletions.
80 changes: 51 additions & 29 deletions tests/entrypoints/openai/test_completion.py
@@ -17,9 +17,13 @@

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8


@pytest.fixture(scope="module")
@@ -28,7 +32,12 @@ def zephyr_lora_files():


@pytest.fixture(scope="module")
def server(zephyr_lora_files):
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
with RemoteOpenAIServer([
"--model",
MODEL_NAME,
@@ -37,8 +46,10 @@ def server(zephyr_lora_files):
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enforce-eager",
# lora config below
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
@@ -47,7 +58,14 @@ def server(zephyr_lora_files):
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
]) as remote_server:
yield remote_server
@@ -60,11 +78,14 @@ def client(server):

@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
@@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)

# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 5
assert len(completion.choices[0].text) >= 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
# first test base model, then test loras, then test prompt adapters
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):

@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
# just test 1 lora and 1 pa hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -154,15 +177,15 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):

with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
@@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
@@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
@@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,

@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
@@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
)
async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3].strip("/")
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

for add_special in [False, True]:
prompt = "This is a test prompt."
@@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
base_url = str(client.base_url)[:-3]
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")

prompt = "This is a test prompt."
tokens = tokenizer.encode(prompt, add_special_tokens=False)
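For reference, a minimal sketch (not part of the commit) of the request pattern the new `zephyr-pa` parametrizations exercise; the base URL and API key below are placeholder assumptions standing in for what the `RemoteOpenAIServer` fixture provides in the tests.

```python
import asyncio

import openai

# Hypothetical connection details; in the test suite these come from the
# RemoteOpenAIServer fixture rather than being hard-coded.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def single_completion_example() -> None:
    # "zephyr-pa" is one of the prompt adapters registered via
    # --prompt-adapters in the server fixture above.
    completion = await client.completions.create(model="zephyr-pa",
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)
    # The adapter prepends PA_NUM_VIRTUAL_TOKENS (8) virtual tokens, so the
    # reported prompt_tokens is 6 + 8 rather than 6 for this prompt.
    print(completion.usage)


if __name__ == "__main__":
    asyncio.run(single_completion_example())
```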
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -1,4 +1,5 @@
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -74,8 +75,8 @@ def __init__(
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with open(f"./{prompt_adapter.local_path}"
f"/adapter_config.json") as f:
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
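For context, a standalone sketch (not part of the commit) of the config read that this hunk switches to pathlib; the helper name and example path are hypothetical, but the `adapter_config.json` key is the one the diff reads.

```python
import json
import pathlib


def read_num_virtual_tokens(adapter_dir: str) -> int:
    """Read num_virtual_tokens from a prompt adapter's adapter_config.json.

    Building the path with pathlib.Path instead of the old
    f"./{path}/adapter_config.json" string keeps absolute paths (such as a
    Hugging Face cache directory returned by snapshot_download) working
    unchanged.
    """
    config_path = pathlib.Path(adapter_dir, "adapter_config.json")
    with config_path.open() as f:
        adapter_config = json.load(f)
    return adapter_config["num_virtual_tokens"]


# Example usage with a hypothetical local adapter directory:
# num_virtual_tokens = read_num_virtual_tokens("/path/to/llama_tweet_ptune")
```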
