[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (vllm-project#6227)
Co-authored-by: Cyrus Leung <[email protected]>
2 people authored and gnpinkert committed Jul 26, 2024
1 parent 554db1c commit e7f7549
Showing 16 changed files with 267 additions and 186 deletions.
33 changes: 6 additions & 27 deletions tests/async_engine/test_chat_template.py
@@ -1,6 +1,5 @@
 import os
 import pathlib
-from dataclasses import dataclass

 import pytest

@@ -50,23 +49,9 @@
 ]


-@dataclass
-class MockTokenizer:
-    chat_template = None
-
-
-@dataclass
-class MockServingChat:
-    tokenizer: MockTokenizer
-
-
 def test_load_chat_template():
     # Testing chatml template
-    tokenizer = MockTokenizer()
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path)
-
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=chatml_jinja_path)

     # Test assertions
     assert template_content is not None
@@ -78,22 +63,16 @@ def test_load_chat_template():
 def test_no_load_chat_template_filelike():
     # Testing chatml template
     template = "../../examples/does_not_exist"
-    tokenizer = MockTokenizer()
-
-    mock_serving_chat = MockServingChat(tokenizer)

     with pytest.raises(ValueError, match="looks like a file path"):
-        load_chat_template(mock_serving_chat, chat_template=template)
+        load_chat_template(chat_template=template)


 def test_no_load_chat_template_literallike():
     # Testing chatml template
     template = "{{ messages }}"
-    tokenizer = MockTokenizer()
-
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
-    template_content = tokenizer.chat_template
+    template_content = load_chat_template(chat_template=template)

     assert template_content == template
@@ -105,8 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
                          expected_output):
     # Initialize the tokenizer
     tokenizer = get_tokenizer(tokenizer_name=model)
-    mock_serving_chat = MockServingChat(tokenizer)
-    load_chat_template(mock_serving_chat, chat_template=template)
+    template_content = load_chat_template(chat_template=template)

     # Create a mock request object using keyword arguments
     mock_request = ChatCompletionRequest(
@@ -118,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     result = tokenizer.apply_chat_template(
         conversation=mock_request.messages,
         tokenize=False,
-        add_generation_prompt=mock_request.add_generation_prompt)
+        add_generation_prompt=mock_request.add_generation_prompt,
+        chat_template=mock_request.chat_template or template_content)

     # Test assertion
     assert result == expected_output, (
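
Note (not part of the commit): a minimal sketch of the calling convention these tests now exercise. load_chat_template returns the resolved template string instead of storing it on a serving object, and callers pass that string to apply_chat_template themselves. The import location of load_chat_template is an assumption here.

    # Hedged usage sketch; the module that exports load_chat_template is an assumption.
    from vllm.entrypoints.openai.serving_chat import load_chat_template  # assumed path
    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(tokenizer_name="HuggingFaceH4/zephyr-7b-beta")
    # A literal (non-file) template is returned unchanged by load_chat_template.
    template_content = load_chat_template(chat_template="{{ messages }}")
    prompt = tokenizer.apply_chat_template(
        conversation=[{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=template_content)
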
13 changes: 4 additions & 9 deletions tests/entrypoints/openai/test_chat.py
@@ -7,11 +7,11 @@
 import openai  # use the official client for correctness check
 import pytest
 import torch
-# downloading lora to test lora requests
-from huggingface_hub import snapshot_download
 from openai import BadRequestError

 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,12 +21,7 @@


 @pytest.fixture(scope="module")
-def zephyr_lora_files():
-    return snapshot_download(repo_id=LORA_NAME)
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -38,7 +33,7 @@ def server(zephyr_lora_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
51 changes: 49 additions & 2 deletions tests/entrypoints/openai/test_completion.py
@@ -1,6 +1,8 @@
 # imports for guided decoding tests
 import json
 import re
+import shutil
+from tempfile import TemporaryDirectory
 from typing import List

 import jsonschema
@@ -9,6 +11,7 @@
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
+from transformers import AutoTokenizer

 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -30,13 +33,29 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)


+@pytest.fixture(scope="module")
+def zephyr_lora_added_tokens_files(zephyr_lora_files):
+    tmp_dir = TemporaryDirectory()
+    tmp_model_dir = f"{tmp_dir.name}/zephyr"
+    shutil.copytree(zephyr_lora_files, tmp_model_dir)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    # Copy tokenizer to adapter and add some unique tokens
+    # 32000, 32001, 32002
+    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
+                                 special_tokens=True)
+    assert added == 3
+    tokenizer.save_pretrained(tmp_model_dir)
+    yield tmp_model_dir
+    tmp_dir.cleanup()
+
+
 @pytest.fixture(scope="module")
 def zephyr_pa_files():
     return snapshot_download(repo_id=PA_NAME)


 @pytest.fixture(scope="module")
-def server(zephyr_lora_files, zephyr_pa_files):
+def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -50,7 +69,7 @@ def server(zephyr_lora_files, zephyr_pa_files):
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
@@ -111,6 +130,34 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
     assert len(completion.choices[0].text) >= 1


+@pytest.mark.asyncio
+async def test_added_lora_tokens(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model="zephyr-lora2",
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should appear in tokenized prompt
+    assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
+
+
+@pytest.mark.asyncio
+async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 32000, 32001, 32002],
+        echo=True,
+        max_tokens=5,
+        temperature=0.0,
+    )
+    # Added tokens should not appear in tokenized prompt
+    assert "vllm" not in completion.choices[0].text
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras, then test prompt adapters
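
Note (not part of the commit): the zephyr_lora_added_tokens_files fixture above produces an adapter directory whose saved tokenizer carries the three added tokens, which is what the added-token tests rely on. A hedged sketch of that property; the directory path is a placeholder, not a value from the diff.

    # Hedged sketch; "/tmp/zephyr-adapter" stands in for the fixture's tmp_model_dir.
    from vllm.transformers_utils.tokenizer import get_tokenizer

    adapter_dir = "/tmp/zephyr-adapter"  # placeholder path
    lora_tokenizer = get_tokenizer(tokenizer_name=adapter_dir, tokenizer_mode="fast")
    ids = lora_tokenizer.encode("vllm1 vllm2 vllm3", add_special_tokens=False)
    # The adapter's tokenizer resolves the added tokens to IDs 32000-32002;
    # the base-model tokenizer would split them into sub-word pieces instead.
    assert {32000, 32001, 32002}.issubset(ids)
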
3 changes: 1 addition & 2 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -38,5 +38,4 @@ async def _async_serving_chat_init():

 def test_async_serving_chat_init():
     serving_completion = asyncio.run(_async_serving_chat_init())
-    assert serving_completion.tokenizer is not None
-    assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE
+    assert serving_completion.chat_template == CHAT_TEMPLATE
56 changes: 40 additions & 16 deletions tests/entrypoints/openai/test_tokenization.py
@@ -5,13 +5,15 @@
 from vllm.transformers_utils.tokenizer import get_tokenizer

 from ...utils import RemoteOpenAIServer
+from .test_completion import zephyr_lora_added_tokens_files  # noqa: F401
+from .test_completion import zephyr_lora_files  # noqa: F401

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


 @pytest.fixture(scope="module")
-def server():
+def server(zephyr_lora_added_tokens_files: str):  # noqa: F811
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -21,29 +23,44 @@ def server():
         "--enforce-eager",
         "--max-num-seqs",
         "128",
+        # lora config
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
+        "--max-lora-rank",
+        "64",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server


+@pytest.fixture(scope="module")
+def tokenizer_name(model_name: str,
+                   zephyr_lora_added_tokens_files: str):  # noqa: F811
+    return zephyr_lora_added_tokens_files if (
+        model_name == "zephyr-lora2") else model_name
+
+
 @pytest.fixture(scope="module")
 def client(server):
     return server.get_async_client()


 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
 async def test_tokenize_completions(client: openai.AsyncOpenAI,
-                                    model_name: str):
+                                    model_name: str, tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")

     for add_special in [False, True]:
-        prompt = "This is a test prompt."
+        prompt = "vllm1 This is a test prompt."
         tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

         response = requests.post(base_url + "/tokenize",
@@ -63,12 +80,15 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI,

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
+async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
+                             tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")

     for add_generation in [False, True]:
         for add_special in [False, True]:
@@ -80,7 +100,7 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):
                 "role": "assistant",
                 "content": "Nice to meet you!"
             }, {
                 "role": "user",
-                "content": "Can I ask a question?"
+                "content": "Can I ask a question? vllm1"
             }]

             prompt = tokenizer.apply_chat_template(
@@ -108,16 +128,20 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str):

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
+    "model_name,tokenizer_name",
+    [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
+    indirect=["tokenizer_name"],
 )
-async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
+async def test_detokenize(client: openai.AsyncOpenAI, model_name: str,
+                          tokenizer_name: str):
     base_url = str(client.base_url)[:-3].strip("/")
-    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
+                              tokenizer_mode="fast")

-    prompt = "This is a test prompt."
+    prompt = "This is a test prompt. vllm1"
     tokens = tokenizer.encode(prompt, add_special_tokens=False)

+    print(f"CALLING {base_url} FOR {model_name}")
     response = requests.post(base_url + "/detokenize",
                              json={
                                  "model": model_name,
13 changes: 9 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -480,11 +480,16 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)

-    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()  # type: ignore
-        else:
-            return self.engine.get_tokenizer()
+            return await self.engine.get_tokenizer.remote(  # type: ignore
+                lora_request)
+
+        return await (self.engine.get_tokenizer_group().
+                      get_lora_tokenizer_async(lora_request))

     def start_background_loop(self) -> None:
         """Start the background loop."""
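
Note (not part of the commit): a minimal hedged sketch of the new engine API from the caller's side, assuming engine is an AsyncLLMEngine instance; the helper name is illustrative only.

    # Hedged sketch of AsyncLLMEngine.get_tokenizer with an optional LoRA request.
    from typing import Optional

    from vllm.lora.request import LoRARequest

    async def encode_prompt(engine, prompt: str,
                            lora_request: Optional[LoRARequest] = None):
        # None -> base-model tokenizer; otherwise the adapter's tokenizer,
        # including any tokens the adapter added.
        tokenizer = await engine.get_tokenizer(lora_request)
        return tokenizer.encode(prompt, add_special_tokens=False)
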
7 changes: 5 additions & 2 deletions vllm/engine/llm_engine.py
@@ -455,8 +455,11 @@ def get_tokenizer_group(

         return self.tokenizer

-    def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.get_tokenizer_group().get_lora_tokenizer(None)
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
+        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -257,7 +257,8 @@ def run_server(args, llm_engine=None):
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
     openai_serving_tokenization = OpenAIServingTokenization(
-        engine, model_config, served_model_names, args.chat_template)
+        engine, model_config, served_model_names, args.lora_modules,
+        args.chat_template)
     app.root_path = args.root_path

     logger.info("Available routes are:")