From a611890e05a15396be735783cda4d0179a9572a7 Mon Sep 17 00:00:00 2001 From: Joe Date: Mon, 15 Jul 2024 18:54:15 -0700 Subject: [PATCH] [Bugfix][CI/Build] Test prompt adapters in openai entrypoint tests (#6419) --- tests/entrypoints/openai/test_completion.py | 80 +++++++++++++-------- vllm/entrypoints/openai/serving_engine.py | 5 +- 2 files changed, 54 insertions(+), 31 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 6e5fdebe786e1..f9dbf69c2eaab 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -17,9 +17,13 @@ # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically this needs Mistral-7B-v0.1 as base, but we're not testing -# generation quality here +# technically these adapters use a different base model, +# but we're not testing generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" +PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 @pytest.fixture(scope="module") @@ -28,7 +32,12 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") -def server(zephyr_lora_files): +def zephyr_pa_files(): + return snapshot_download(repo_id=PA_NAME) + + +@pytest.fixture(scope="module") +def server(zephyr_lora_files, zephyr_pa_files): with RemoteOpenAIServer([ "--model", MODEL_NAME, @@ -37,8 +46,10 @@ def server(zephyr_lora_files): "bfloat16", "--max-model-len", "8192", + "--max-num-seqs", + "128", "--enforce-eager", - # lora config below + # lora config "--enable-lora", "--lora-modules", f"zephyr-lora={zephyr_lora_files}", @@ -47,7 +58,14 @@ def server(zephyr_lora_files): "64", "--max-cpu-loras", "2", - "--max-num-seqs", + # pa config + "--enable-prompt-adapter", + "--prompt-adapters", + f"zephyr-pa={zephyr_pa_files}", + f"zephyr-pa2={zephyr_pa_files}", + "--max-prompt-adapters", + "2", + "--max-prompt-adapter-token", "128", ]) as remote_server: yield remote_server @@ -60,11 +78,14 @@ def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras - "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + # first test base model, then test loras, then test prompt adapters + "model_name,num_virtual_tokens", + [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], ) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, + num_virtual_tokens: int): completion = await client.completions.create(model=model_name, prompt="Hello, my name is", max_tokens=5, @@ -77,28 +98,30 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, + prompt_tokens=6 + num_virtual_tokens, + total_tokens=11 + num_virtual_tokens) # test using token IDs completion = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, ) - assert len(completion.choices[0].text) >= 5 + assert len(completion.choices[0].text) >= 1 @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, 
then test loras + # first test base model, then test loras, then test prompt adapters "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], + [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], ) async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs completion = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, @@ -110,14 +133,14 @@ async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( - # just test 1 lora hereafter + # just test 1 lora and 1 pa hereafter "model_name", - [MODEL_NAME, "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs completion = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, @@ -133,12 +156,12 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs completion = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, @@ -154,7 +177,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, model_name: str): @@ -162,7 +185,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, with pytest.raises( (openai.BadRequestError, openai.APIError)): # test using token IDs await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, @@ -174,7 +197,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, with pytest.raises( (openai.BadRequestError, openai.APIError)): # test using token IDs stream = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=[0, 0, 0, 0, 0], max_tokens=5, temperature=0.0, @@ -199,7 +222,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_completion_streaming(client: openai.AsyncOpenAI, model_name: str): @@ -233,7 +256,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): @@ -369,9 +392,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( - # just test 1 lora hereafter "model_name", - [MODEL_NAME, "zephyr-lora"], + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], ) async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): # test both text and token IDs @@ -623,7 +645,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, ) async def 
test_tokenize(client: openai.AsyncOpenAI, model_name: str): base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") for add_special in [False, True]: prompt = "This is a test prompt." @@ -650,7 +672,7 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): ) async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): base_url = str(client.base_url)[:-3] - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") prompt = "This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=False) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 58e6571d310e6..14c1df89e064f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,4 +1,5 @@ import json +import pathlib from dataclasses import dataclass from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union @@ -74,8 +75,8 @@ def __init__( self.prompt_adapter_requests = [] if prompt_adapters is not None: for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with open(f"./{prompt_adapter.local_path}" - f"/adapter_config.json") as f: + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: adapter_config = json.load(f) num_virtual_tokens = adapter_config["num_virtual_tokens"] self.prompt_adapter_requests.append(
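Reviewer note on the usage accounting exercised in test_single_completion: when a request targets a prompt adapter, the adapter's virtual tokens are prepended to the prompt, so the reported prompt_tokens grows by PA_NUM_VIRTUAL_TOKENS while completion_tokens stays at max_tokens. The sketch below is a minimal illustration of that expectation, assuming a vLLM OpenAI-compatible server is already running locally with the same flags as the server fixture above; the base URL, API key, and standalone-script framing are assumptions, not part of this patch.

    # Minimal sketch (not part of the patch): prompt adapter virtual tokens
    # are counted toward prompt_tokens in the reported usage.
    # Assumes a local vLLM server on port 8000 with "zephyr-pa" registered,
    # as in the server fixture above.
    import asyncio

    import openai

    PA_NUM_VIRTUAL_TOKENS = 8  # must match the adapter's adapter_config.json


    async def check_usage() -> None:
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="EMPTY")
        completion = await client.completions.create(
            model="zephyr-pa",
            prompt="Hello, my name is",  # 6 prompt tokens for this model
            max_tokens=5,
            temperature=0.0,
        )
        usage = completion.usage
        # the adapter's virtual tokens are prepended to the 6 real prompt tokens
        assert usage.prompt_tokens == 6 + PA_NUM_VIRTUAL_TOKENS
        assert usage.completion_tokens == 5
        assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens


    if __name__ == "__main__":
        asyncio.run(check_usage())

If the adapter behind PA_NAME ever changes, PA_NUM_VIRTUAL_TOKENS has to be updated to match its adapter_config.json, which is what the comment added next to that constant warns about.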
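Reviewer note on the serving_engine.py hunk: the previous code built the adapter config path by string formatting with a hard-coded "./" prefix, which resolves relative to the current working directory even when prompt_adapter.local_path is absolute (as it typically is for snapshot_download caches); pathlib.Path joins the segments correctly in both cases. The helper below is an illustrative standalone sketch of the same pattern; load_num_virtual_tokens and the example paths are hypothetical names, not vLLM identifiers.

    # Illustrative sketch of the pathlib-based config loading used in the hunk.
    # load_num_virtual_tokens is a hypothetical helper, not part of vLLM.
    import json
    import pathlib


    def load_num_virtual_tokens(local_path: str) -> int:
        # Path(a, b) joins segments correctly whether local_path is
        # relative ("adapters/pa") or absolute ("/tmp/adapters/pa"),
        # unlike the old f"./{local_path}/adapter_config.json".
        with pathlib.Path(local_path, "adapter_config.json").open() as f:
            adapter_config = json.load(f)
        return adapter_config["num_virtual_tokens"]

This is the value stored on each entry of self.prompt_adapter_requests and, in turn, the prompt_tokens offset checked by the completion tests above.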