[Frontend] Gracefully handle missing chat template and fix CI failure (vllm-project#7238)

Co-authored-by: Roger Wang <[email protected]>
2 people authored and kylesayrs committed Aug 17, 2024
1 parent ab9e7a5 commit c646702
Showing 9 changed files with 125 additions and 69 deletions.
21 changes: 8 additions & 13 deletions tests/async_engine/test_chat_template.py
@@ -1,22 +1,16 @@
import os
import pathlib

import pytest

from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
__file__))).parent.parent / "examples/template_chatml.jinja"
from ..utils import VLLM_PATH

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", None, True,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", None, False,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)

# Call the function and get the result
result = tokenizer.apply_chat_template(
result = apply_chat_template(
tokenizer,
conversation=mock_request.messages,
tokenize=False,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
)

# Test assertion
assert result == expected_output, (
10 changes: 7 additions & 3 deletions tests/async_engine/test_openapi_server_ray.py
@@ -1,10 +1,12 @@
import openai # use the official client for correctness check
import pytest

from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


@pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
"--engine-use-ray",
"--chat-template",
str(chatml_jinja_path),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=13, total_tokens=23)
completion_tokens=10, prompt_tokens=55, total_tokens=65)

message = choice.message
assert message.content is not None and len(message.content) >= 10
87 changes: 55 additions & 32 deletions tests/entrypoints/openai/test_oot_registration.py
@@ -9,6 +9,11 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port

from ...utils import VLLM_PATH, RemoteOpenAIServer

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


class MyOPTForCausalLM(OPTForCausalLM):

@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
return logits


def server_function(port):
def server_function(port: int):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f"--dtype float32 --api-key token-abc123 --port {port}").split()

sys.argv = ["placeholder.py"] + [
"--model",
"facebook/opt-125m",
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float32",
"--api-key",
"token-abc123",
"--port",
str(port),
"--chat-template",
str(chatml_jinja_path),
]

import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
now = time.time()
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server did not start in time") from e
else:
raise e
server.kill()

try:
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
now = time.time()
while True:
try:
completion = client.chat.completions.create(
model="facebook/opt-125m",
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Hello!"
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
msg = "Server did not start in time"
raise RuntimeError(msg) from e
else:
raise e
finally:
server.terminate()

generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
rest = generated_text.replace("<s>", "")
assert rest == ""
4 changes: 2 additions & 2 deletions tests/utils.py
@@ -50,7 +50,7 @@ def _nvml():

class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 120 # wait for server to start for 120 seconds
MAX_START_WAIT_S = 120 # wait for server to start for 120 seconds

def __init__(
self,
@@ -85,7 +85,7 @@ def __init__(
stdout=sys.stdout,
stderr=sys.stderr)
self._wait_for_server(url=self.url_for("health"),
timeout=self.MAX_SERVER_START_WAIT_S)
timeout=self.MAX_START_WAIT_S)

def __enter__(self):
return self
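
A minimal sketch of how a test can combine the renamed `MAX_START_WAIT_S` constant with the `--chat-template` flag, mirroring the fixtures updated above; `url_for("v1")` is only an assumption made by analogy with the `url_for("health")` call in `__init__`:

```python
import openai

from tests.utils import VLLM_PATH, RemoteOpenAIServer

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"

args = [
    "--max-model-len", "2048",
    "--enforce-eager",
    "--chat-template", str(chatml_jinja_path),
]

# Entering the context manager waits on the /health endpoint for up to
# RemoteOpenAIServer.MAX_START_WAIT_S seconds before yielding the server.
with RemoteOpenAIServer("facebook/opt-125m", args) as remote_server:
    client = openai.AsyncOpenAI(
        base_url=remote_server.url_for("v1"),  # assumed helper, see url_for("health") above
        api_key=RemoteOpenAIServer.DUMMY_API_KEY,
    )
```
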
37 changes: 34 additions & 3 deletions vllm/entrypoints/chat_utils.py
@@ -1,8 +1,9 @@
import codecs
from dataclasses import dataclass
from functools import lru_cache
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)

# yapf conflicts with isort for this block
# yapf: disable
@@ -22,6 +23,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)

@@ -69,13 +71,17 @@ class ChatMessageParseResult:
mm_futures: List[Awaitable[MultiModalDataDict]]


def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise

JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)

return conversation, mm_futures


def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")

prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)

return prompt
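
A minimal sketch (not part of this diff) of how the new `apply_chat_template` helper is meant to be called; the model name, message content, and relative template path are illustrative only, and the error branch assumes the tokenizer defines no chat template of its own:

```python
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
conversation = [{"role": "user", "content": "Hello!"}]

# load_chat_template accepts either a filesystem path or a literal Jinja string.
template = load_chat_template("examples/template_chatml.jinja")

# Unlike HF's method, tokenize defaults to False, so a string prompt is returned.
prompt = apply_chat_template(
    tokenizer,
    conversation=conversation,
    chat_template=template,
    add_generation_prompt=True,
)

# If neither the caller nor the tokenizer provides a template, the helper now
# raises ValueError instead of relying on a transformers default template.
try:
    apply_chat_template(tokenizer, conversation=conversation, chat_template=None)
except ValueError as err:
    print(err)
```
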
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
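
Because the server no longer falls back to a transformers default template, a client talking to a model whose tokenizer lacks one can supply a template per request. A hedged sketch with the official OpenAI client, assuming a local server on port 8000; vLLM-specific request fields such as `chat_template` are passed via `extra_body`:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

# Illustrative ChatML-style template; any valid Jinja chat template works.
chatml = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=10,
    # vLLM-specific request fields go through extra_body.
    extra_body={"chat_template": chatml},
)
print(completion.choices[0].message.content)
```
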
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -10,6 +10,7 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
@@ -99,16 +100,15 @@ async def create_chat_completion(
tool.model_dump() for tool in request.tools
]

prompt = tokenizer.apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
14 changes: 8 additions & 6 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
logger.warning(
"Multi-modal inputs are ignored during tokenization")

prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = request.prompt

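
The tokenization route now goes through the same helper, so `/tokenize` with `messages` also needs a resolvable chat template. A rough sketch, assuming a server started with `--chat-template` on localhost:8000 and a response carrying a `tokens` field:

```python
import requests

# The /tokenize route applies the chat template to `messages` before encoding.
response = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "Hello!"}],
        "add_generation_prompt": True,
    },
)
response.raise_for_status()
print(response.json()["tokens"])
```
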
8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async

from .tokenizer_group import AnyTokenizer

logger = init_logger(__name__)


def get_cached_tokenizer(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.
This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE:
