diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index aea8a7fed6e33..4df6c02973284 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -1,22 +1,16 @@
-import os
-import pathlib
-
 import pytest
 
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", None, True,
-     "Hello</s>Hi there!</s>What is the capital of"),
-    ("facebook/opt-125m", None, False,
-     "Hello</s>Hi there!</s>What is the capital of"),
     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)
 
     # Call the function and get the result
-    result = tokenizer.apply_chat_template(
+    result = apply_chat_template(
+        tokenizer,
         conversation=mock_request.messages,
-        tokenize=False,
+        chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
-        chat_template=mock_request.chat_template or template_content)
+    )
 
     # Test assertion
     assert result == expected_output, (
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
index 5ecd770ede836..0d53b39e7ce1c 100644
--- a/tests/async_engine/test_openapi_server_ray.py
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -1,10 +1,12 @@
 import openai  # use the official client for correctness check
 import pytest
 
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
 
 
 @pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray"
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)
 
     message = choice.message
     assert message.content is not None and len(message.content) >= 10
diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py
index 5272ac4065f1d..9f9a4cd972c51 100644
--- a/tests/entrypoints/openai/test_oot_registration.py
+++ b/tests/entrypoints/openai/test_oot_registration.py
@@ -9,6 +9,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port
 
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
@@ -21,12 +26,25 @@ def compute_logits(self, hidden_states: torch.Tensor,
         return logits
 
 
-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server',
                      run_name='__main__')
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.kill()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()
+
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""
diff --git a/tests/utils.py b/tests/utils.py
index bd431b85d2663..e3d04cc638a95 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -50,7 +50,7 @@ def _nvml():
 
 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-    MAX_SERVER_START_WAIT_S = 120  # wait for server to start for 120 seconds
+    MAX_START_WAIT_S = 120  # wait for server to start for 120 seconds
 
     def __init__(
         self,
@@ -85,7 +85,7 @@ def __init__(
                                      stdout=sys.stdout,
                                      stderr=sys.stderr)
         self._wait_for_server(url=self.url_for("health"),
-                              timeout=self.MAX_SERVER_START_WAIT_S)
+                              timeout=self.MAX_START_WAIT_S)
 
     def __enter__(self):
         return self
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 072450a6146ee..12634c3261856 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1,8 +1,9 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
-                    final)
+from pathlib import Path
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+                    cast, final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -22,6 +23,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
 
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
     mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
         return None
     try:
         with open(chat_template, "r") as f:
             resolved_chat_template = f.read()
     except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
         JINJA_CHARS = "{}\n"
         if not any(c in chat_template for c in JINJA_CHARS):
             msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
         mm_futures.extend(parse_result.mm_futures)
 
     return conversation, mm_futures
+
+
+def apply_chat_template(
+    tokenizer: AnyTokenizer,
+    conversation: List[ConversationMessage],
+    chat_template: Optional[str],
+    *,
+    tokenize: bool = False,  # Different from HF's default
+    **kwargs: Any,
+) -> str:
+    if chat_template is None and tokenizer.chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one.")
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )
+    assert isinstance(prompt, str)
+
+    return prompt
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 76318a1271229..70467bd879690 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=(
             "A Jinja template to use for this conversion. "
-            "If this is not passed, the model's default chat template will be "
-            "used instead."),
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
     )
     chat_template_kwargs: Optional[Dict[str, Any]] = Field(
         default=None,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index add1ce8acc95e..2167b967b14b5 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -10,6 +10,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
+                                         apply_chat_template,
                                          load_chat_template,
                                          parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
@@ -99,16 +100,15 @@ async def create_chat_completion(
                 tool.model_dump() for tool in request.tools
             ]
 
-            prompt = tokenizer.apply_chat_template(
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
+                chat_template=request.chat_template or self.chat_template,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
-            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 5b6b979b9b9e7..1aeabb7a7d729 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,9 @@
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+                                         load_chat_template,
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -70,12 +72,12 @@ async def create_tokenize(
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")
 
-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=request.add_generation_prompt,
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
-                chat_template=self.chat_template)
-            assert isinstance(prompt, str)
+                chat_template=self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+            )
         else:
             prompt = request.prompt
 
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index bf26d889d1388..25e4c41592c68 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -12,12 +12,12 @@
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
+from .tokenizer_group import AnyTokenizer
+
 logger = init_logger(__name__)
 
 
-def get_cached_tokenizer(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
    """Get tokenizer with cached properties.
 
     This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
     revision: Optional[str] = None,
     download_dir: Optional[str] = None,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
     """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
     """
     if VLLM_USE_MODELSCOPE:
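Taken together, the patch routes every chat-template call through the new `apply_chat_template` helper in `vllm/entrypoints/chat_utils.py`, which fails fast with a `ValueError` when neither the request nor the tokenizer supplies a template, instead of relying on the default chat template that transformers v4.44 no longer provides. The following is a minimal usage sketch, not part of the patch; the model name and template path are simply the ones the tests above use, and the path is assumed to be resolved relative to a vLLM checkout.

```python
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

# Assumed setup: the OPT tokenizer used in the tests and the ChatML template
# shipped under examples/ in the vLLM repository.
tokenizer = get_tokenizer("facebook/opt-125m")
chat_template = load_chat_template("examples/template_chatml.jinja")

conversation = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "What is the capital of"},
]

# Renders the conversation to a prompt string (tokenize=False is the wrapper's
# default) and raises ValueError if no chat template is available at all.
prompt = apply_chat_template(
    tokenizer,
    conversation=conversation,
    chat_template=chat_template,
    add_generation_prompt=True,
)
print(prompt)
```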