[Frontend] Gracefully handle missing chat template and fix CI failure #7238

Merged
21 changes: 8 additions & 13 deletions tests/async_engine/test_chat_template.py
@@ -1,22 +1,16 @@
import os
import pathlib

import pytest

from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
__file__))).parent.parent / "examples/template_chatml.jinja"
from ..utils import VLLM_PATH

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", None, True,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", None, False,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)

# Call the function and get the result
result = tokenizer.apply_chat_template(
result = apply_chat_template(
tokenizer,
conversation=mock_request.messages,
tokenize=False,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
)

# Test assertion
assert result == expected_output, (
10 changes: 7 additions & 3 deletions tests/async_engine/test_openapi_server_ray.py
@@ -1,10 +1,12 @@
import openai # use the official client for correctness check
import pytest

from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


@pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
"--engine-use-ray",
"--chat-template",
str(chatml_jinja_path),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=13, total_tokens=23)
completion_tokens=10, prompt_tokens=55, total_tokens=65)

message = choice.message
assert message.content is not None and len(message.content) >= 10
30 changes: 28 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -1,8 +1,8 @@
import codecs
from dataclasses import dataclass
from functools import lru_cache
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)

# yapf conflicts with isort for this block
# yapf: disable
@@ -22,6 +22,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)

@@ -208,3 +209,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)

return conversation, mm_futures


def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")

prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)

return prompt
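
For context, a minimal sketch of how the new `apply_chat_template` helper behaves. This is illustrative only: it assumes the `facebook/opt-125m` tokenizer is available locally, that the script runs from the vLLM repo root so the example template path resolves, and that a recent transformers release (4.44+) no longer supplies a default template.

```python
from transformers import AutoTokenizer

from vllm.entrypoints.chat_utils import apply_chat_template

# facebook/opt-125m's tokenizer defines no chat template of its own.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
conversation = [{"role": "user", "content": "Hello"}]

try:
    # Neither the caller nor the tokenizer supplies a template, so the helper
    # raises ValueError instead of silently falling back to a default.
    apply_chat_template(tokenizer, conversation, chat_template=None)
except ValueError as err:
    print(err)

# Supplying a template explicitly (here the repo's ChatML example) works.
with open("examples/template_chatml.jinja") as f:
    prompt = apply_chat_template(
        tokenizer,
        conversation,
        chat_template=f.read(),
        add_generation_prompt=True,
    )
print(prompt)
```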
5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default=None,
description=(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
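
On the client side, a request against a template-less model can supply the template per request through the `chat_template` field described above. A hedged sketch using the official `openai` client against a locally running vLLM server; the base URL, API key, model, and template path are illustrative, and the vLLM-specific field is passed via `extra_body` because the OpenAI SDK does not define it.

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("examples/template_chatml.jinja") as f:
    chatml_template = f.read()

chat_completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=10,
    # vLLM-specific request field; not part of the OpenAI schema, so it goes
    # through extra_body.
    extra_body={"chat_template": chatml_template},
)
print(chat_completion.choices[0].message.content)
```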
8 changes: 4 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -9,6 +9,7 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
@@ -98,16 +99,15 @@ async def create_chat_completion(
tool.model_dump() for tool in request.tools
]

prompt = tokenizer.apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
assert isinstance(prompt, str)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
14 changes: 8 additions & 6 deletions vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
from vllm.entrypoints.chat_utils import (apply_chat_template,
load_chat_template,
parse_chat_messages)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
logger.warning(
"Multi-modal inputs are ignored during tokenization")

prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
prompt = apply_chat_template(
tokenizer,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = request.prompt

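
The `/tokenize` endpoint now goes through the same template handling when it receives chat messages. A minimal sketch of such a request; the URL, request body, and response shape are assumptions based on the handler above and vLLM's protocol definitions, not verified server output.

```python
import requests

# Tokenize a chat conversation through the OpenAI-compatible server, letting
# the server-side --chat-template (or the tokenizer's own template) be applied.
response = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "Hello"}],
        "add_generation_prompt": True,
    },
)
response.raise_for_status()
print(response.json())
```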
8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async

from .tokenizer_group import AnyTokenizer

logger = init_logger(__name__)


def get_cached_tokenizer(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
"""Get tokenizer with cached properties.

This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if VLLM_USE_MODELSCOPE: