Skip to content

Commit

Permalink
Account for virtual tokens in token counts
Browse files Browse the repository at this point in the history
  • Loading branch information
g-eoj committed Jul 15, 2024
1 parent 8503d2e commit 6b03c49
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# imports for guided decoding tests
import json
import pathlib
import re
from typing import List

Expand Down Expand Up @@ -33,6 +34,14 @@ def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)


@pytest.fixture(scope="module")
def zephyr_pa_num_virtual_tokens(zephyr_pa_files) -> int:
    """Return the prompt adapter's ``num_virtual_tokens`` setting.

    The value is read from ``adapter_config.json`` located under the path
    supplied by the ``zephyr_pa_files`` fixture.

    NOTE: a test must request this fixture as an argument to receive the
    value. Referencing the fixture function by name inside
    ``pytest.mark.parametrize`` passes the function object itself, not the
    number it returns.

    Raises:
        FileNotFoundError: if ``adapter_config.json`` is missing.
        KeyError: if the config has no ``num_virtual_tokens`` entry.
    """
    config_path = pathlib.Path(zephyr_pa_files, "adapter_config.json")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with config_path.open(encoding="utf-8") as f:
        adapter_config = json.load(f)
    return adapter_config["num_virtual_tokens"]


@pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_pa_files):
with RemoteOpenAIServer([
Expand Down Expand Up @@ -76,10 +85,13 @@ def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", zephyr_pa_num_virtual_tokens),
("zephyr-pa2", zephyr_pa_num_virtual_tokens)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
Expand All @@ -92,7 +104,9 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)

# test using token IDs
completion = await client.completions.create(
Expand Down

0 comments on commit 6b03c49

Please sign in to comment.