From 6d1e428d8aefefe7d04ffad14f7a3583bcc03963 Mon Sep 17 00:00:00 2001 From: Joe G Date: Mon, 15 Jul 2024 10:54:43 -0700 Subject: [PATCH] Hardcode virtual token count for tests --- tests/entrypoints/openai/test_completion.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 5e215b25bcee6..955a7a7a98dcd 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -1,6 +1,5 @@ # imports for guided decoding tests import json -import pathlib import re from typing import List @@ -22,6 +21,9 @@ # but we're not testing generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" PA_NAME = "swapnilbp/llama_tweet_ptune" +# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also +# need to change to match the prompt adapter +PA_NUM_VIRTUAL_TOKENS = 8 @pytest.fixture(scope="module") @@ -34,14 +36,6 @@ def zephyr_pa_files(): return snapshot_download(repo_id=PA_NAME) -@pytest.fixture(scope="module") -def zephyr_pa_num_virtual_tokens(zephyr_pa_files): - with pathlib.Path(zephyr_pa_files, "adapter_config.json").open() as f: - adapter_config = json.load(f) - num_virtual_tokens = adapter_config["num_virtual_tokens"] - return num_virtual_tokens - - @pytest.fixture(scope="module") def server(zephyr_lora_files, zephyr_pa_files): with RemoteOpenAIServer([ @@ -87,8 +81,8 @@ def client(server): # first test base model, then test loras, then test prompt adapters "model_name,num_virtual_tokens", [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", zephyr_pa_num_virtual_tokens), - ("zephyr-pa2", zephyr_pa_num_virtual_tokens)], + ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), + ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], ) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, num_virtual_tokens: int):