Skip to content

Commit

Permalink
Hardcode virtual token count for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
g-eoj committed Jul 15, 2024
1 parent 6b03c49 commit 6d1e428
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# imports for guided decoding tests
import json
import pathlib
import re
from typing import List

Expand All @@ -22,6 +21,9 @@
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8


@pytest.fixture(scope="module")
Expand All @@ -34,14 +36,6 @@ def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)


@pytest.fixture(scope="module")
def zephyr_pa_num_virtual_tokens(zephyr_pa_files):
with pathlib.Path(zephyr_pa_files, "adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
return num_virtual_tokens


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
with RemoteOpenAIServer([
Expand Down Expand Up @@ -87,8 +81,8 @@ def client(server):
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", zephyr_pa_num_virtual_tokens),
("zephyr-pa2", zephyr_pa_num_virtual_tokens)],
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
Expand Down

0 comments on commit 6d1e428

Please sign in to comment.