From 32e46e000f77499f4dd7c0bed194e33856f2df24 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 16 Nov 2024 13:35:40 +0800 Subject: [PATCH] [Frontend] Automatic detection of chat content format from AST (#9919) Signed-off-by: DarkLight1337 --- .../serving/openai_compatible_server.md | 18 +- tests/entrypoints/openai/test_serving_chat.py | 3 +- tests/entrypoints/test_chat_utils.py | 619 +++++++++++------- vllm/config.py | 2 - vllm/engine/arg_utils.py | 10 - vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/chat_utils.py | 246 ++++++- vllm/entrypoints/llm.py | 44 +- vllm/entrypoints/openai/api_server.py | 13 +- vllm/entrypoints/openai/cli_args.py | 17 +- vllm/entrypoints/openai/protocol.py | 71 +- vllm/entrypoints/openai/run_batch.py | 2 + vllm/entrypoints/openai/serving_chat.py | 40 +- vllm/entrypoints/openai/serving_embedding.py | 12 +- vllm/entrypoints/openai/serving_engine.py | 17 +- .../openai/serving_tokenization.py | 20 +- 16 files changed, 788 insertions(+), 350 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 78965813b1213..79d032bf8b211 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -172,12 +172,20 @@ completion = client.chat.completions.create( ] ) ``` -Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like -`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which -format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify -between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match -this, unless explicitly specified. +Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like +`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the +request. vLLM provides best-effort support to detect this automatically, which is logged as a string like +*"Detected the chat template content format to be..."*, and internally converts incoming requests to match +the detected format, which can be one of: + +- `"string"`: A string. + - Example: `"Hello world"` +- `"openai"`: A list of dictionaries, similar to OpenAI schema. + - Example: `[{"type": "text", "text": "Hello world!"}]` + +If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument +to override which format to use. ## Command line arguments for the server diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index e969d33775d86..93660e6118ca8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -26,7 +26,6 @@ class MockModelConfig: tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" - chat_template_text_format = "string" max_model_len = 100 tokenizer_revision = None multimodal_config = MultiModalConfig() @@ -49,6 +48,7 @@ async def _async_serving_chat_init(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) @@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5fa466f8f041f..72477e048eafa 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,15 +6,24 @@ from vllm.assets.image import ImageAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (parse_chat_messages, - parse_chat_messages_futures) +from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + resolve_chat_template_content_format) from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..utils import VLLM_PATH + +EXAMPLES_DIR = VLLM_PATH / "examples" + PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" +LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" @pytest.fixture(scope="function") @@ -26,7 +35,6 @@ def phi3v_model_config(): trust_remote_code=True, dtype="bfloat16", seed=0, - chat_template_text_format="string", limit_mm_per_prompt={ "image": 2, }) @@ -94,19 +102,24 @@ def test_parse_chat_messages_single_image( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -121,19 +134,24 @@ async def test_parse_chat_messages_single_image_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -147,24 +165,29 @@ def test_parse_chat_messages_multiple_images( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -181,24 +204,29 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -214,27 +242,31 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - }], phi3v_model_config, phi3v_tokenizer) - + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -249,26 +281,35 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -285,34 +326,39 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "text", - "text": "What about this one?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [ { @@ -335,7 +381,6 @@ def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - phi3v_model_config.chat_template_text_format = "openai" conversation, mm_data = parse_chat_messages( [{ "role": "user", @@ -349,7 +394,11 @@ def test_parse_chat_messages_context_text_format( }, { "role": "user", "content": "What about this one?" - }], phi3v_model_config, phi3v_tokenizer) + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="openai", + ) assert conversation == [ { @@ -389,29 +438,34 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_rejects_too_many_images_across_messages( @@ -427,39 +481,44 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What about these two?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_multiple_images_uncommon_input( @@ -467,17 +526,22 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - "What's in these images?", { - "image_url": image_url - }, { - "image_url": image_url - } - ] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + "What's in these images?", { + "image_url": image_url + }, { + "image_url": image_url + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -495,16 +559,21 @@ def test_mllama_single_image( image_url, ): """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + "image_url": image_url + }] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 1) assert conversation == [{ 'role': @@ -524,26 +593,31 @@ def test_mllama_interleaved_images( image_url, ): """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + { + "image_url": image_url + }, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + { + "image_url": image_url + }, + ] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 2) assert conversation == [{ 'role': @@ -626,6 +700,7 @@ def get_conversation(is_hf: bool): vllm_conversation, model_config, tokenizer_group, + content_format="openai", ) vllm_result = apply_hf_chat_template( @@ -636,3 +711,89 @@ def get_conversation(is_hf: bool): ) assert hf_result == vllm_result + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [(PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (MLLAMA_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai")], +) +# yapf: enable +def test_resolve_content_format_hf_defined(model, expected_format): + tokenizer_group = TokenizerGroup( + model, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + chat_template = tokenizer.chat_template + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + None, # Test detecting the tokenizer's chat_template + "auto", + tokenizer, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("template_path", "expected_format"), + [("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_blip2.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_llava.jinja", "string"), + ("template_vlm2vec.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "string"), + ("tool_chat_template_llama3.2_json.jinja", "string"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string")], +) +# yapf: enable +def test_resolve_content_format_examples(template_path, expected_format): + tokenizer_group = TokenizerGroup( + PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + dummy_tokenizer = tokenizer_group.tokenizer + dummy_tokenizer.chat_template = None + + chat_template = load_chat_template(EXAMPLES_DIR / template_path) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + chat_template, + "auto", + dummy_tokenizer, + ) + + assert resolved_format == expected_format diff --git a/vllm/config.py b/vllm/config.py index 1c190da1d327e..64b2f75e092de 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -155,7 +155,6 @@ def __init__( limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_text_format: str = "string", hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, override_neuron_config: Optional[Dict[str, Any]] = None, @@ -216,7 +215,6 @@ def __init__( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.chat_template_text_format = chat_template_text_format self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d73f95f59c71f..92fa87c7fa45b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -90,7 +90,6 @@ class EngineArgs: task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' - chat_template_text_format: str = 'string' trust_remote_code: bool = False allowed_local_media_path: str = "" download_dir: Optional[str] = None @@ -258,14 +257,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') - parser.add_argument( - '--chat-template-text-format', - type=str, - default=EngineArgs.chat_template_text_format, - choices=['string', 'openai'], - help='The format to render text content within a chat template. ' - '"string" will keep the content field as a string whereas ' - '"openai" will parse content in the current OpenAI format.') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -894,7 +885,6 @@ def create_model_config(self) -> ModelConfig: # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, - chat_template_text_format=self.chat_template_text_format, trust_remote_code=self.trust_remote_code, allowed_local_media_path=self.allowed_local_media_path, dtype=self.dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index aa9c7893c4cfe..9a2d73a020c8f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -262,8 +262,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_text_format=%s, mm_processor_kwargs=%s, " - "pooler_config=%r)", + "mm_processor_kwargs=%s, pooler_config=%r)", VLLM_VERSION, model_config.model, speculative_config, @@ -296,7 +295,6 @@ def __init__( cache_config.enable_prefix_caching, model_config.use_async_output_proc, use_cached_outputs, - model_config.chat_template_text_format, model_config.mm_processor_kwargs, model_config.pooler_config, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3ca460c47c3bd..abee5ac46391c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,12 +2,14 @@ import codecs import json from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, @@ -153,6 +155,199 @@ class ConversationMessage(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = Literal["string", "openai"] + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (_is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if (isinstance(node, jinja2.nodes.Getitem) + and isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access(node, varname) + ) # yapf: enable + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + tokenizer_chat_template = tokenizer.chat_template + else: + tokenizer_chat_template = None + + jinja_text: Optional[str] + if isinstance(tokenizer_chat_template, str) and chat_template is None: + jinja_text = tokenizer_chat_template + elif (isinstance(tokenizer_chat_template, dict) + and chat_template in tokenizer_chat_template): + jinja_text = tokenizer_chat_template[chat_template] + else: + jinja_text = load_chat_template(chat_template, is_literal=True) + + detected_format = ("string" if jinja_text is None else + _detect_content_format(jinja_text, default="string")) + + return detected_format if given_format == "auto" else given_format + + +@lru_cache +def resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + given_format, + tokenizer, + ) + + logger.info( + "Detected the chat template content format to be '%s'. " + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + return detected_format + + ModalityStr = Literal["image", "audio", "video"] _T = TypeVar("_T") @@ -407,12 +602,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): def load_chat_template( - chat_template: Optional[Union[Path, str]]) -> Optional[str]: + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: if chat_template is None: return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return codecs.decode(chat_template, "unicode_escape") + try: with open(chat_template) as f: - resolved_chat_template = f.read() + return f.read() except OSError as e: if isinstance(chat_template, Path): raise @@ -426,10 +632,7 @@ def load_chat_template( # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - resolved_chat_template = codecs.decode(chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", resolved_chat_template) - return resolved_chat_template + return load_chat_template(chat_template, is_literal=True) # TODO: Let user specify how to insert multimodal tokens into prompt @@ -464,7 +667,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) -MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} # Define a mapping from part types to their corresponding parsing functions. MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { @@ -542,18 +744,12 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, + *, + wrap_dicts: bool, ) -> List[ConversationMessage]: content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - model_config = mm_tracker.model_config - - wrap_dicts = (chat_template_text_format == "openai" - or (model_config.task == "embedding" - and model_config.is_multimodal_model) - or (model_config.hf_config.model_type - in MODEL_KEEP_MULTI_MODAL_CONTENT)) for part in parts: parse_res = _parse_chat_message_content_part( @@ -578,9 +774,11 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content_part( - part: ChatCompletionContentPartParam, - mm_parser: BaseMultiModalContentParser, - wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + *, + wrap_dicts: bool, +) -> Optional[Union[str, Dict[str, str]]]: """Parses a single part of a conversation. If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -629,7 +827,7 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, + content_format: _ChatTemplateContentFormat, ) -> List[ConversationMessage]: role = message["role"] content = message.get("content") @@ -645,7 +843,7 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, - chat_template_text_format, + wrap_dicts=(content_format == "openai"), ) for result_msg in result: @@ -684,6 +882,7 @@ def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: conversation: List[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -692,7 +891,7 @@ def parse_chat_messages( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, + content_format, ) conversation.extend(sub_messages) @@ -706,6 +905,7 @@ def parse_chat_messages_futures( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -714,7 +914,7 @@ def parse_chat_messages_futures( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, + content_format, ) conversation.extend(sub_messages) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4b33fc1458ee3..86b0b6893f1d9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,9 +13,11 @@ TaskOption) from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, apply_hf_chat_template, apply_mistral_chat_template, - parse_chat_messages) + parse_chat_messages, + resolve_chat_template_content_format) from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -523,6 +525,7 @@ def chat( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, @@ -539,9 +542,11 @@ def chat( to the OpenAI API. Args: - messages: A list of conversations or a single conversation. - - Each conversation is represented as a list of messages. - - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -551,11 +556,19 @@ def chat( lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. + chat_template_content_format: The format to render message content. + + - "string" will render the content as a string. + Example: ``"Who are you?"`` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: ``[{"type": "text", "text": "Who are you?"}]`` + add_generation_prompt: If True, adds a generation template to each message. continue_final_message: If True, continues the final message in - the conversation instead of starting a new one. Cannot be `True` - if `add_generation_prompt` is also `True`. + the conversation instead of starting a new one. Cannot be + ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. @@ -576,17 +589,26 @@ def chat( cast(List[ChatCompletionMessageParam], messages) ] + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) + prompts: List[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - # NOTE: _parse_chat_message_content_parts() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. conversation, mm_data = parse_chat_messages( - msgs, model_config, tokenizer) + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) prompt_data: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): @@ -737,7 +759,7 @@ def encode( generation, if any. Returns: - A list of `EmbeddingRequestOutput` objects containing the + A list of ``EmbeddingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. Note: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index b13f6a228b4c6..b0fe061f5db4a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -29,6 +29,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, @@ -529,6 +530,9 @@ def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -537,7 +541,8 @@ def init_app_state( lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, @@ -557,7 +562,8 @@ def init_app_state( model_config, base_model_paths, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, @@ -565,7 +571,8 @@ def init_app_state( base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index eb08a89293370..24c206a1261f2 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -7,10 +7,11 @@ import argparse import json import ssl -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + validate_chat_template) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="The file path to the chat template, " "or the template in single-line form " "for the specified model") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: "Hello World"\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: [{"type": "text", "text": "Hello world!"}]') parser.add_argument("--response-role", type=nullable_str, default="assistant", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 820aefd8800d9..b7b064ae01f05 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch -from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated, Required, TypedDict +from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.pooling_params import PoolingParams @@ -35,26 +34,6 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max -class CustomChatCompletionMessageParam(TypedDict, total=False): - """Enables custom roles in the Chat Completion API.""" - role: Required[str] - """The role of the message's author.""" - - content: Union[str, List[ChatCompletionContentPartParam]] - """The contents of the message.""" - - name: str - """An optional name for the participant. - - Provides the model information to differentiate between participants of the - same role. - """ - - tool_call_id: Optional[str] - - tool_calls: Optional[List[dict]] - - class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") @@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel): model: str prompt: str - add_special_tokens: bool = Field(default=True) + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) class TokenizeChatRequest(OpenAIBaseModel): model: str messages: List[ChatCompletionMessageParam] - add_generation_prompt: bool = Field(default=True) - continue_final_message: bool = Field(default=False) - add_special_tokens: bool = Field(default=False) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 1b422a93263b2..00cdb3b6839f5 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -222,6 +222,7 @@ async def main(args): prompt_adapters=None, request_logger=request_logger, chat_template=None, + chat_template_content_format="auto", enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.task == "generate" else None openai_serving_embedding = OpenAIServingEmbedding( @@ -230,6 +231,7 @@ async def main(args): base_model_paths, request_logger=request_logger, chat_template=None, + chat_template_content_format="auto", ) if model_config.task == "embedding" else None tracker = BatchProgressTracker() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 77cae00ae827f..2eef909eb9319 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -10,7 +10,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + ConversationMessage) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -38,20 +39,23 @@ class OpenAIServingChat(OpenAIServing): - def __init__(self, - engine_client: EngineClient, - model_config: ModelConfig, - base_model_paths: List[BaseModelPath], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - enable_auto_tools: bool = False, - tool_parser: Optional[str] = None, - enable_prompt_tokens_details: bool = False): + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -61,8 +65,8 @@ def __init__(self, return_tokens_as_token_ids=return_tokens_as_token_ids) self.response_role = response_role - self.use_tool_use_model_template = False - self.chat_template = load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format # set up tool use self.enable_auto_tools: bool = enable_auto_tools @@ -120,6 +124,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) + tool_parser = self.tool_parser # validation for OpenAI tools @@ -157,6 +162,7 @@ async def create_chat_completion( tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, tool_dicts=tool_dicts, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index bbe7db8f13231..74ad7389784fc 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,7 +1,7 @@ import asyncio import base64 import time -from typing import AsyncGenerator, List, Literal, Optional, Union, cast +from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast import numpy as np from fastapi import Request @@ -9,7 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, EmbeddingRequest, @@ -77,7 +77,8 @@ def __init__( *, request_logger: Optional[RequestLogger], chat_template: Optional[str], - ): + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -85,7 +86,8 @@ def __init__( prompt_adapters=None, request_logger=request_logger) - self.chat_template = load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format async def create_embedding( self, @@ -144,6 +146,8 @@ async def create_embedding( tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, truncate_prompt_tokens=truncate_prompt_tokens, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index fa315fa516632..cae2877ea7e99 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,14 +11,16 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, - parse_chat_messages_futures) + parse_chat_messages_futures, + resolve_chat_template_content_format) from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, @@ -426,7 +428,8 @@ async def _preprocess_chat( request: ChatLikeRequest, tokenizer: AnyTokenizer, messages: List[ChatCompletionMessageParam], - chat_template: Optional[str] = None, + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, add_generation_prompt: bool = True, continue_final_message: bool = False, tool_dicts: Optional[List[Dict[str, Any]]] = None, @@ -437,10 +440,16 @@ async def _preprocess_chat( add_special_tokens: bool = False, ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], List[TokensPrompt]]: + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) conversation, mm_data_future = parse_chat_messages_futures( messages, self.model_config, tokenizer, + content_format=resolved_content_format, ) _chat_template_kwargs: Dict[str, Any] = dict( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 1fd82304f7a4d..59b3b1311f881 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Union +from typing import Final, List, Optional, Union from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -33,7 +33,8 @@ def __init__( lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], - ): + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -41,12 +42,8 @@ def __init__( prompt_adapters=None, request_logger=request_logger) - # If this is None we use the tokenizer's default chat template - # the list of commonly-used chat template names for HF named templates - hf_chat_templates: List[str] = ['default', 'tool_use'] - self.chat_template = chat_template \ - if chat_template in hf_chat_templates \ - else load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format async def create_tokenize( self, @@ -75,9 +72,12 @@ async def create_tokenize( request, tokenizer, request.messages, - chat_template=self.chat_template, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, + chat_template_kwargs=request.chat_template_kwargs, add_special_tokens=request.add_special_tokens, ) else: