From d4844011cb5abe2c13a7fd68eeec7b6e77feab6b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 1 Nov 2024 15:09:55 +0000 Subject: [PATCH 01/18] Write out skeleton code --- docs/source/models/vlm.rst | 3 +- .../serving/openai_compatible_server.md | 2 +- tests/entrypoints/openai/test_serving_chat.py | 2 +- tests/entrypoints/test_chat_utils.py | 4 +-- vllm/config.py | 4 +-- vllm/engine/arg_utils.py | 17 ++++++---- vllm/engine/llm_engine.py | 4 +-- vllm/entrypoints/chat_utils.py | 33 ++++++++++++------- 8 files changed, 41 insertions(+), 28 deletions(-) diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 3377502a6db28..8ea5f776d6cd6 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -268,7 +268,8 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model. .. code-block:: bash vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \ - --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja + --trust-remote-code --max-model-len 4096 \ + --chat-template examples/template_vlm2vec.jinja --chat-template-content-format openai .. important:: diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 0b5f75caf2475..d6fc3405f517c 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -148,7 +148,7 @@ completion = client.chat.completions.create( ``` Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like `meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which -format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify +format the content needs to be parsed in by vLLM, please use the `--chat-template-content-format` argument to specify between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match this, unless explicitly specified. 
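To make the `string`/`openai` distinction above concrete, here is a minimal editorial sketch (not part of this diff) of how the same user turn looks to a chat template under each setting; the shapes mirror the examples used in the `--chat-template-content-format` help text added later in this patch, and the doc text above notes that vLLM converts incoming requests to whichever format is selected.

```python
# Editorial sketch (not part of this patch): the same user turn as the chat
# template sees it under each --chat-template-content-format setting.

# "string": text parts are flattened into a single string before rendering,
# which is what most chat templates expect.
string_format_message = {"role": "user", "content": "Hello world!"}

# "openai": the list-of-parts structure from the OpenAI request is kept, as
# expected by models such as meta-llama/Llama-Guard-3-1B.
openai_format_message = {
    "role": "user",
    "content": [{"type": "text", "text": "Hello world!"}],
}
```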
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index e969d33775d86..2dd925fbc38aa 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -26,7 +26,7 @@ class MockModelConfig: tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" - chat_template_text_format = "string" + chat_template_content_format = "string" max_model_len = 100 tokenizer_revision = None multimodal_config = MultiModalConfig() diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5fa466f8f041f..a3688d64e3aae 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -26,7 +26,7 @@ def phi3v_model_config(): trust_remote_code=True, dtype="bfloat16", seed=0, - chat_template_text_format="string", + chat_template_content_format="string", limit_mm_per_prompt={ "image": 2, }) @@ -335,7 +335,7 @@ def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - phi3v_model_config.chat_template_text_format = "openai" + phi3v_model_config.chat_template_content_format = "openai" conversation, mm_data = parse_chat_messages( [{ "role": "user", diff --git a/vllm/config.py b/vllm/config.py index c2a8c956b374a..480791ce5d464 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -153,7 +153,7 @@ def __init__( use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_text_format: str = "string", + chat_template_content_format: str = "string", mm_processor_kwargs: Optional[Dict[str, Any]] = None, pooling_type: Optional[str] = None, pooling_norm: Optional[bool] = None, @@ -190,7 +190,7 @@ def __init__( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.chat_template_text_format = chat_template_text_format + self.chat_template_content_format = chat_template_content_format self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b1f0f8b9df925..33cebe5c98dc2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -89,7 +89,7 @@ class EngineArgs: task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' - chat_template_text_format: str = 'string' + chat_template_content_format: str = 'string' trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' @@ -258,13 +258,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') parser.add_argument( - '--chat-template-text-format', + '--chat-template-content-format', type=str, - default=EngineArgs.chat_template_text_format, + default=EngineArgs.chat_template_content_format, choices=['string', 'openai'], - help='The format to render text content within a chat template. ' - '"string" will keep the content field as a string whereas ' - '"openai" will parse content in the current OpenAI format.') + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render `message["content"]` as a string. ' + 'Example: "Hello World"\n' + '* "openai" will render `message["content"]` like OpenAI schema. 
' + 'Example: {"type": "text", "text": "Hello world!"}') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -917,7 +920,7 @@ def create_model_config(self) -> ModelConfig: # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, - chat_template_text_format=self.chat_template_text_format, + chat_template_content_format=self.chat_template_content_format, trust_remote_code=self.trust_remote_code, dtype=self.dtype, seed=self.seed, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index edef1f30a9e91..cba5038295aa5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -257,7 +257,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_text_format=%s, mm_processor_kwargs=%s, " + "chat_template_content_format=%s, mm_processor_kwargs=%s, " "pooler_config=%r)", VLLM_VERSION, model_config.model, @@ -293,7 +293,7 @@ def __init__( cache_config.enable_prefix_caching, model_config.use_async_output_proc, use_cached_outputs, - model_config.chat_template_text_format, + model_config.chat_template_content_format, model_config.mm_processor_kwargs, model_config.pooler_config, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index bc2de2d162473..649f331996f88 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -8,6 +8,7 @@ from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) +import jinja2 # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, @@ -24,6 +25,7 @@ # pydantic needs the TypedDict from typing_extensions from pydantic import ConfigDict from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.utils.chat_template_utils import _compile_jinja_template from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -417,7 +419,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) -MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} # Define a mapping from part types to their corresponding parsing functions. 
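As an aside on why `jinja2` is now imported directly in `chat_utils.py`: the two content formats correspond to two ways a template can consume `message['content']`. A toy, self-contained illustration (not from this patch; both template strings are invented for the example):

```python
# Toy illustration (not part of this patch): the same greeting rendered by a
# template that treats content as a string vs. one that iterates over parts.
import jinja2

string_template = jinja2.Template(
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}")
openai_template = jinja2.Template(
    "{% for m in messages %}{{ m['role'] }}: "
    "{% for part in m['content'] %}{{ part['text'] }}{% endfor %}\n"
    "{% endfor %}")

# Both print "user: Hello world!" given content in the matching format.
print(string_template.render(
    messages=[{"role": "user", "content": "Hello world!"}]))
print(openai_template.render(
    messages=[{"role": "user",
               "content": [{"type": "text", "text": "Hello world!"}]}]))
```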
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { @@ -490,18 +491,13 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, ) -> List[ConversationMessage]: content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - model_config = mm_tracker.model_config + model_config = mm_tracer.model_config - wrap_dicts = (chat_template_text_format == "openai" - or (model_config.task == "embedding" - and model_config.is_multimodal_model) - or (model_config.hf_config.model_type - in MODEL_KEEP_MULTI_MODAL_CONTENT)) + wrap_dicts = model_config.chat_template_content_format == "openai" for part in parts: parse_res = _parse_chat_message_content_part( @@ -573,7 +569,6 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, - chat_template_text_format: str, ) -> List[ConversationMessage]: role = message["role"] content = message.get("content") @@ -589,7 +584,6 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, - chat_template_text_format, ) for result_msg in result: @@ -636,7 +630,6 @@ def parse_chat_messages( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, ) conversation.extend(sub_messages) @@ -658,7 +651,6 @@ def parse_chat_messages_futures( sub_messages = _parse_chat_message_content( msg, mm_tracker, - model_config.chat_template_text_format, ) conversation.extend(sub_messages) @@ -668,6 +660,23 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data() +def validate_hf_chat_template_content_format( + model_config: ModelConfig, + chat_template: str, +) -> None: + content_format = model_config.chat_template_content_format + + compiled = _compile_jinja_template(chat_template) + # Somehow parse out the AST and find how messages[int]['content'] is used? 
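One way the TODO in the skeleton above could be answered, sketched here as a standalone, simplified illustration (the actual detection logic lands in a later commit of this series and is stricter about which variable is being subscripted): walk the Jinja AST and look for a `for` loop whose iterable is `something['content']`, which implies the OpenAI-style list-of-parts format.

```python
# Simplified sketch of AST-based detection (illustration only; the real
# implementation added later in this series is more precise).
import jinja2
import jinja2.nodes


def detect_content_format(chat_template: str) -> str:
    ast = jinja2.Environment().parse(chat_template)
    # Look for loops whose iterable is a subscript by the constant "content",
    # e.g. {% for part in message['content'] %}.
    for loop in ast.find_all(jinja2.nodes.For):
        loop_iter = loop.iter
        if (isinstance(loop_iter, jinja2.nodes.Getitem)
                and isinstance(loop_iter.arg, jinja2.nodes.Const)
                and loop_iter.arg.value == "content"):
            return "openai"
    return "string"


print(detect_content_format("{{ messages[0]['content'] }}"))  # string
print(detect_content_format(
    "{% for p in messages[0]['content'] %}{{ p['text'] }}{% endfor %}"
))  # openai
```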
+ + if content_format == "string": + pass + elif content_format == "openai": + pass + else: + raise ValueError(f"Invalid format: {content_format}") + + def apply_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: List[ConversationMessage], From b91de40798aff3c0c193eac2cccb3d46a36459a2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 Nov 2024 10:16:11 +0000 Subject: [PATCH 02/18] Iterate --- tests/entrypoints/openai/test_serving_chat.py | 2 +- tests/entrypoints/test_chat_utils.py | 520 ++++++++++-------- vllm/config.py | 2 - vllm/engine/arg_utils.py | 13 - vllm/engine/llm_engine.py | 4 +- vllm/entrypoints/chat_utils.py | 148 ++++- vllm/entrypoints/llm.py | 28 +- vllm/entrypoints/openai/api_server.py | 13 +- vllm/entrypoints/openai/cli_args.py | 17 +- vllm/entrypoints/openai/protocol.py | 48 +- vllm/entrypoints/openai/run_batch.py | 2 + vllm/entrypoints/openai/serving_chat.py | 37 +- vllm/entrypoints/openai/serving_embedding.py | 12 +- vllm/entrypoints/openai/serving_engine.py | 13 +- .../openai/serving_tokenization.py | 20 +- 15 files changed, 566 insertions(+), 313 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 2dd925fbc38aa..535e910e66707 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -26,7 +26,6 @@ class MockModelConfig: tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" - chat_template_content_format = "string" max_model_len = 100 tokenizer_revision = None multimodal_config = MultiModalConfig() @@ -49,6 +48,7 @@ async def _async_serving_chat_init(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index a3688d64e3aae..5adb7760337fe 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -26,7 +26,6 @@ def phi3v_model_config(): trust_remote_code=True, dtype="bfloat16", seed=0, - chat_template_content_format="string", limit_mm_per_prompt={ "image": 2, }) @@ -94,19 +93,24 @@ def test_parse_chat_messages_single_image( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -121,19 +125,24 @@ async def test_parse_chat_messages_single_image_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in the image?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" 
+ }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -147,24 +156,29 @@ def test_parse_chat_messages_multiple_images( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -181,24 +195,29 @@ async def test_parse_chat_messages_multiple_images_async( phi3v_tokenizer, image_url, ): - conversation, mm_future = parse_chat_messages_futures([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_future = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -214,27 +233,31 @@ def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?" - }] - }], phi3v_model_config, phi3v_tokenizer) - + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": "user", @@ -249,26 +272,35 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?" 
- }] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" # noqa: E501 + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -285,34 +317,39 @@ def test_parse_chat_messages_multiple_images_across_messages( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "text", - "text": "What's in this image?" - }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "text", - "text": "What about this one?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [ { @@ -335,7 +372,6 @@ def test_parse_chat_messages_context_text_format( phi3v_model_config, phi3v_tokenizer, ): - phi3v_model_config.chat_template_content_format = "openai" conversation, mm_data = parse_chat_messages( [{ "role": "user", @@ -349,7 +385,11 @@ def test_parse_chat_messages_context_text_format( }, { "role": "user", "content": "What about this one?" - }], phi3v_model_config, phi3v_tokenizer) + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="openai", + ) assert conversation == [ { @@ -389,29 +429,34 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in these images?" - }] - }], phi3v_model_config, phi3v_tokenizer) + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_rejects_too_many_images_across_messages( @@ -427,39 +472,44 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ValueError, match="At most 2 image\\(s\\) may be provided in one request\\." ): - parse_chat_messages([{ - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } - }, { - "type": "text", - "text": "What's in this image?" 
- }] - }, { - "role": "assistant", - "content": "Some stuff." - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": image_url - } + parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] }, { - "type": "image_url", - "image_url": { - "url": image_url - } + "role": "assistant", + "content": "Some stuff." }, { - "type": "text", - "text": "What about these two?" - }] - }], phi3v_model_config, phi3v_tokenizer) + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) def test_parse_chat_messages_multiple_images_uncommon_input( @@ -467,17 +517,22 @@ def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_tokenizer, image_url, ): - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - "What's in these images?", { - "image_url": image_url - }, { - "image_url": image_url - } - ] - }], phi3v_model_config, phi3v_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + "What's in these images?", { + "image_url": image_url + }, { + "image_url": image_url + } + ] + }], + phi3v_model_config, + phi3v_tokenizer, + content_format="string", + ) assert conversation == [{ "role": @@ -495,16 +550,21 @@ def test_mllama_single_image( image_url, ): """Ensures that a single image is parsed correctly mllama.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [{ - 'type': 'text', - 'text': 'The content of this image is:' - }, { - "image_url": image_url - }] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + "image_url": image_url + }] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 1) assert conversation == [{ 'role': @@ -524,26 +584,31 @@ def test_mllama_interleaved_images( image_url, ): """Ensures that multiple image are parsed as interleaved dicts.""" - conversation, mm_data = parse_chat_messages([{ - "role": - "user", - "content": [ - { - 'type': 'text', - 'text': 'The content of the first image is:' - }, - { - "image_url": image_url - }, - { - 'type': 'text', - 'text': 'The content of the second image is:' - }, - { - "image_url": image_url - }, - ] - }], mllama_model_config, mllama_tokenizer) + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + { + "image_url": image_url + }, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + { + "image_url": image_url + }, + ] + }], + mllama_model_config, + mllama_tokenizer, + content_format="openai", + ) _assert_mm_data_is_image_input(mm_data, 2) assert conversation == [{ 'role': @@ -626,6 +691,7 @@ def get_conversation(is_hf: bool): vllm_conversation, model_config, tokenizer_group, + content_format="openai", ) vllm_result = apply_hf_chat_template( diff --git a/vllm/config.py b/vllm/config.py index 480791ce5d464..223981d630552 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -153,7 +153,6 
@@ def __init__( use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_content_format: str = "string", mm_processor_kwargs: Optional[Dict[str, Any]] = None, pooling_type: Optional[str] = None, pooling_norm: Optional[bool] = None, @@ -190,7 +189,6 @@ def __init__( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc - self.chat_template_content_format = chat_template_content_format self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 33cebe5c98dc2..16bd5a9056a67 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -89,7 +89,6 @@ class EngineArgs: task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' - chat_template_content_format: str = 'string' trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' @@ -257,17 +256,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') - parser.add_argument( - '--chat-template-content-format', - type=str, - default=EngineArgs.chat_template_content_format, - choices=['string', 'openai'], - help='The format to render message content within a chat template.' - '\n\n' - '* "string" will render `message["content"]` as a string. ' - 'Example: "Hello World"\n' - '* "openai" will render `message["content"]` like OpenAI schema. ' - 'Example: {"type": "text", "text": "Hello world!"}') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -920,7 +908,6 @@ def create_model_config(self) -> ModelConfig: # We know this is not None because we set it in __post_init__ tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, - chat_template_content_format=self.chat_template_content_format, trust_remote_code=self.trust_remote_code, dtype=self.dtype, seed=self.seed, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cba5038295aa5..51aaa77a75d72 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -257,8 +257,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "chat_template_content_format=%s, mm_processor_kwargs=%s, " - "pooler_config=%r)", + "mm_processor_kwargs=%s, pooler_config=%r)", VLLM_VERSION, model_config.model, speculative_config, @@ -293,7 +292,6 @@ def __init__( cache_config.enable_prefix_caching, model_config.use_async_output_proc, use_cached_outputs, - model_config.chat_template_content_format, model_config.mm_processor_kwargs, model_config.pooler_config, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 649f331996f88..c24370d5f8ec6 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -9,6 +9,8 @@ Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) import jinja2 +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block # yapf: disable from openai.types.chat import (ChatCompletionAssistantMessageParam, @@ -25,7 +27,6 @@ # 
pydantic needs the TypedDict from typing_extensions from pydantic import ConfigDict from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from transformers.utils.chat_template_utils import _compile_jinja_template from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig @@ -136,6 +137,58 @@ class ConversationMessage(TypedDict, total=False): """The tool calls generated by the model, such as function calls.""" +# Passed in by user +ChatTemplateContentFormatOption = Literal["auto", "string", "openai"] + +# Used internally +_ChatTemplateContentFormat = Literal["string", "openai"] + + +@lru_cache +def _resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + if chat_template is None: + return "string" if given_format == "auto" else given_format + + if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + tokenizer_chat_template = tokenizer.chat_template + else: + tokenizer_chat_template = None + + if (isinstance(tokenizer_chat_template, dict) + and chat_template in tokenizer_chat_template): + jinja_text = tokenizer_chat_template[chat_template] + else: + jinja_text = load_chat_template(chat_template, is_literal=True) + + if jinja_text is None: + return "string" if given_format == "auto" else given_format + + detected_format = _detect_chat_template_content_format(jinja_text) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s`, " + "but we detected that it should be '%s'.", + given_format, + detected_format, + ) + + return detected_format if given_format == "auto" else given_format + + +def resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + return _resolve_chat_template_content_format(chat_template, given_format, + tokenizer) + + ModalityStr = Literal["image", "audio", "video"] _T = TypeVar("_T") @@ -363,12 +416,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): def load_chat_template( - chat_template: Optional[Union[Path, str]]) -> Optional[str]: + chat_template: Optional[Union[Path, str]], + *, + is_literal: bool = False, +) -> Optional[str]: if chat_template is None: return None + + if is_literal: + if isinstance(chat_template, Path): + raise TypeError("chat_template is expected to be read directly " + "from its value") + + return codecs.decode(chat_template, "unicode_escape") + try: with open(chat_template, "r") as f: - resolved_chat_template = f.read() + return f.read() except OSError as e: if isinstance(chat_template, Path): raise @@ -382,10 +446,7 @@ def load_chat_template( # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - resolved_chat_template = codecs.decode(chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", resolved_chat_template) - return resolved_chat_template + return load_chat_template(chat_template, is_literal=True) # TODO: Let user specify how to insert multimodal tokens into prompt @@ -491,13 +552,12 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, ) -> List[ConversationMessage]: content: List[Union[str, 
Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - model_config = mm_tracer.model_config - - wrap_dicts = model_config.chat_template_content_format == "openai" + wrap_dicts = content_format == "openai" for part in parts: parse_res = _parse_chat_message_content_part( @@ -569,6 +629,7 @@ def _parse_chat_message_content_part( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, + content_format: _ChatTemplateContentFormat, ) -> List[ConversationMessage]: role = message["role"] content = message.get("content") @@ -584,6 +645,7 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, + content_format, ) for result_msg in result: @@ -622,6 +684,7 @@ def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: conversation: List[ConversationMessage] = [] mm_tracker = MultiModalItemTracker(model_config, tokenizer) @@ -630,6 +693,7 @@ def parse_chat_messages( sub_messages = _parse_chat_message_content( msg, mm_tracker, + content_format, ) conversation.extend(sub_messages) @@ -643,6 +707,7 @@ def parse_chat_messages_futures( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, + content_format: _ChatTemplateContentFormat, ) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) @@ -651,6 +716,7 @@ def parse_chat_messages_futures( sub_messages = _parse_chat_message_content( msg, mm_tracker, + content_format, ) conversation.extend(sub_messages) @@ -660,21 +726,59 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data() -def validate_hf_chat_template_content_format( - model_config: ModelConfig, - chat_template: str, -) -> None: - content_format = model_config.chat_template_content_format +def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): + # Search for {%- for message in messages -%} loops + for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + if not (isinstance(loop_iter, jinja2.nodes.Name) + and loop_iter.ctx == "load" and loop_iter.name == "messages"): + continue - compiled = _compile_jinja_template(chat_template) - # Somehow parse out the AST and find how messages[int]['content'] is used? 
+ loop_target = loop_ast.target + if not isinstance(loop_target, jinja2.nodes.Name): + continue - if content_format == "string": - pass - elif content_format == "openai": - pass + yield loop_ast, loop_target.name + + +def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): + for node, message_varname in _iter_nodes_define_message(chat_template_ast): + # Search for {%- for content in message['content'] -%} loops + for loop_ast in node.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + if not (isinstance(loop_iter, jinja2.nodes.Getitem) + and loop_iter.ctx == "load"): + continue + + getitem_src = loop_iter.node + if not (isinstance(getitem_src, jinja2.nodes.Name) + and getitem_src.ctx == "load" + and getitem_src.name == message_varname): + continue + + getitem_idx = loop_iter.arg + if not (isinstance(getitem_idx, jinja2.nodes.Const) + and getitem_idx.value == "content"): + continue + + loop_target = loop_ast.target + if not isinstance(loop_target, jinja2.nodes.Name): + continue + + yield loop_iter, loop_target.name + + +def _detect_chat_template_content_format( + chat_template: str) -> _ChatTemplateContentFormat: + jinjacompiled = hf_chat_utils._compile_jinja_template(chat_template) + jinja_ast = jinjacompiled.environment.parse(chat_template) + + try: + next(_iter_nodes_define_content_item(jinja_ast)) + except StopIteration: + return "string" else: - raise ValueError(f"Invalid format: {content_format}") + return "openai" def apply_hf_chat_template( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3d62cb3598477..b5f0a549c1d5b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -11,9 +11,11 @@ BeamSearchSequence, get_beam_search_score) from vllm.engine.arg_utils import EngineArgs, TaskOption from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, apply_hf_chat_template, apply_mistral_chat_template, - parse_chat_messages) + parse_chat_messages, + resolve_chat_template_content_format) from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -507,6 +509,7 @@ def chat( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, continue_final_message: bool = False, tools: Optional[List[Dict[str, Any]]] = None, @@ -535,6 +538,12 @@ def chat( lora_request: LoRA request to use for generation, if any. chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. + chat_template_content_format: The format to render message content. + - "string" will render the content as a string. + Example: "Hello World" + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: [{"type": "text", "text": "Hello world!"}] add_generation_prompt: If True, adds a generation template to each message. 
continue_final_message: If True, continues the final message in @@ -560,17 +569,26 @@ def chat( cast(List[ChatCompletionMessageParam], messages) ] + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) + prompts: List[Union[TokensPrompt, TextPrompt]] = [] for msgs in list_of_messages: - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - # NOTE: _parse_chat_message_content_parts() currently doesn't # handle mm_processor_kwargs, since there is no implementation in # the chat message parsing for it. conversation, mm_data = parse_chat_messages( - msgs, model_config, tokenizer) + msgs, + model_config, + tokenizer, + content_format=resolved_content_format, + ) prompt_data: Union[str, List[int]] if isinstance(tokenizer, MistralTokenizer): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 95fd56d916050..9a5e44468e335 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -29,6 +29,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, @@ -508,6 +509,9 @@ def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats + resolved_chat_template = load_chat_template(args.chat_template) + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -516,7 +520,8 @@ def init_app_state( lora_modules=args.lora_modules, prompt_adapters=args.prompt_adapters, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, @@ -535,7 +540,8 @@ def init_app_state( model_config, base_model_paths, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, @@ -543,7 +549,8 @@ def init_app_state( base_model_paths, lora_modules=args.lora_modules, request_logger=request_logger, - chat_template=args.chat_template, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, ) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a089985ac9758..d775ef7da1b04 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -7,10 +7,11 @@ import argparse import json import ssl -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str -from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.chat_utils import 
(ChatTemplateContentFormatOption, + validate_chat_template) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="The file path to the chat template, " "or the template in single-line form " "for the specified model") + parser.add_argument( + '--chat-template-content-format', + type=str, + default="auto", + choices=get_args(ChatTemplateContentFormatOption), + help='The format to render message content within a chat template.' + '\n\n' + '* "string" will render the content as a string. ' + 'Example: "Hello World"\n' + '* "openai" will render the content as a list of dictionaries, ' + 'similar to OpenAI schema. ' + 'Example: [{"type": "text", "text": "Hello world!"}]') parser.add_argument("--response-role", type=nullable_str, default="assistant", diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1335e51bd152c..af8190a82ba1d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1043,16 +1043,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel): model: str prompt: str - add_special_tokens: bool = Field(default=True) + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) class TokenizeChatRequest(OpenAIBaseModel): model: str messages: List[ChatCompletionMessageParam] - add_generation_prompt: bool = Field(default=True) - continue_final_message: bool = Field(default=False) - add_special_tokens: bool = Field(default=False) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. 
" + "Will be accessible by the chat template."), + ) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index a64467a311523..e2245e22c8c5e 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -217,6 +217,7 @@ async def main(args): prompt_adapters=None, request_logger=request_logger, chat_template=None, + chat_template_content_format="auto", ) if model_config.task == "generate" else None openai_serving_embedding = OpenAIServingEmbedding( engine, @@ -224,6 +225,7 @@ async def main(args): base_model_paths, request_logger=request_logger, chat_template=None, + chat_template_content_format="auto", ) if model_config.task == "embedding" else None tracker = BatchProgressTracker() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9551b4f2091dd..1cef011c0b10a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -10,7 +10,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template +from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, + ConversationMessage) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, @@ -37,19 +38,22 @@ class OpenAIServingChat(OpenAIServing): - def __init__(self, - engine_client: EngineClient, - model_config: ModelConfig, - base_model_paths: List[BaseModelPath], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - enable_auto_tools: bool = False, - tool_parser: Optional[str] = None): + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -60,7 +64,8 @@ def __init__(self, self.response_role = response_role self.use_tool_use_model_template = False - self.chat_template = load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format # set up tool use self.enable_auto_tools: bool = enable_auto_tools @@ -111,6 +116,7 @@ async def create_chat_completion( ) = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) + tool_parser = self.tool_parser # validation for OpenAI tools @@ -142,6 +148,7 @@ async def create_chat_completion( tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, tool_dicts=tool_dicts, diff --git 
a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 917856cd2b2dd..dd84e96e0754e 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,7 +1,7 @@ import asyncio import base64 import time -from typing import AsyncGenerator, List, Literal, Optional, Union, cast +from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast import numpy as np from fastapi import Request @@ -9,7 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, EmbeddingRequest, @@ -77,7 +77,8 @@ def __init__( *, request_logger: Optional[RequestLogger], chat_template: Optional[str], - ): + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -85,7 +86,8 @@ def __init__( prompt_adapters=None, request_logger=request_logger) - self.chat_template = load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format async def create_embedding( self, @@ -144,6 +146,8 @@ async def create_embedding( tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, truncate_prompt_tokens=truncate_prompt_tokens, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e7aeac8f8c018..12afb7c313423 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -12,10 +12,12 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption, ConversationMessage, apply_hf_chat_template, apply_mistral_chat_template, - parse_chat_messages_futures) + parse_chat_messages_futures, + resolve_chat_template_content_format) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -426,7 +428,8 @@ async def _preprocess_chat( request: ChatLikeRequest, tokenizer: AnyTokenizer, messages: List[ChatCompletionMessageParam], - chat_template: Optional[str] = None, + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, add_generation_prompt: bool = True, continue_final_message: bool = False, tool_dicts: Optional[List[Dict[str, Any]]] = None, @@ -437,10 +440,16 @@ async def _preprocess_chat( add_special_tokens: bool = False, ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt], List[TokensPrompt]]: + resolved_content_format = resolve_chat_template_content_format( + chat_template, + chat_template_content_format, + tokenizer, + ) conversation, mm_data_future = parse_chat_messages_futures( messages, self.model_config, tokenizer, + content_format=resolved_content_format, ) request_prompt: Union[str, List[int]] diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 
1fd82304f7a4d..59b3b1311f881 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Union +from typing import Final, List, Optional, Union from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import load_chat_template +from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -33,7 +33,8 @@ def __init__( lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], - ): + chat_template_content_format: ChatTemplateContentFormatOption, + ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, base_model_paths=base_model_paths, @@ -41,12 +42,8 @@ def __init__( prompt_adapters=None, request_logger=request_logger) - # If this is None we use the tokenizer's default chat template - # the list of commonly-used chat template names for HF named templates - hf_chat_templates: List[str] = ['default', 'tool_use'] - self.chat_template = chat_template \ - if chat_template in hf_chat_templates \ - else load_chat_template(chat_template) + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format async def create_tokenize( self, @@ -75,9 +72,12 @@ async def create_tokenize( request, tokenizer, request.messages, - chat_template=self.chat_template, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self. + chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, + chat_template_kwargs=request.chat_template_kwargs, add_special_tokens=request.add_special_tokens, ) else: From 569c076bd0bdeb5ac52dd28451070ff6164c2510 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 Nov 2024 10:29:42 +0000 Subject: [PATCH 03/18] Update docs and logs --- docs/source/serving/openai_compatible_server.md | 11 ++++++----- vllm/entrypoints/chat_utils.py | 12 ++++++++++-- vllm/entrypoints/openai/serving_engine.py | 4 ++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index d6fc3405f517c..3e52170817b87 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -146,11 +146,12 @@ completion = client.chat.completions.create( ] ) ``` -Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like -`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which -format the content needs to be parsed in by vLLM, please use the `--chat-template-content-format` argument to specify -between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match -this, unless explicitly specified. +Most chat templates for LLMs expect the `content` field to be a string but there are some newer models like +`meta-llama/Llama-Guard-3-1B` that expect the content to be according to the OpenAI schema in the request. 
+vLLM provides best-effort support to detect this automatically, which is logged as a string like +*"Detected the chat template content format to be..."*, and internally converts incoming requests to match +the detected format. If the result is not what you expect, you can use the `--chat-template-content-format` +CLI argument to explicitly specify which format to use (`"string"` or `"openai"`). ## Command line arguments for the server diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c24370d5f8ec6..ee862a2cb3997 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -168,11 +168,19 @@ def _resolve_chat_template_content_format( return "string" if given_format == "auto" else given_format detected_format = _detect_chat_template_content_format(jinja_text) + logger.info( + "Detected the chat template content format to be '%s'. " + "Set `--chat-template-content-format` to explicitly specifiy this.", + detected_format, + ) if given_format != "auto" and given_format != detected_format: logger.warning( - "You specified `--chat-template-content-format %s`, " - "but we detected that it should be '%s'.", + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "To help us improve automatic detection, please consider " + "opening a GitHub issue at: " + "https://github.com/vllm-project/vllm/issues/new/choose", given_format, detected_format, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 12afb7c313423..2fd24640bf7f6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,6 +11,8 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ChatTemplateContentFormatOption, ConversationMessage, @@ -19,8 +21,6 @@ parse_chat_messages_futures, resolve_chat_template_content_format) from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, From 4a1b1e0ecc2d34299eb242665dbf669d1272d88e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 Nov 2024 11:10:38 +0000 Subject: [PATCH 04/18] Improve detection --- vllm/entrypoints/chat_utils.py | 44 ++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index ee862a2cb3997..debc0e3dbef94 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -144,33 +144,46 @@ class ConversationMessage(TypedDict, total=False): _ChatTemplateContentFormat = Literal["string", "openai"] -@lru_cache def _resolve_chat_template_content_format( chat_template: Optional[str], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, ) -> _ChatTemplateContentFormat: - if chat_template is None: - return "string" if given_format == "auto" else given_format - if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): tokenizer_chat_template = tokenizer.chat_template else: tokenizer_chat_template = None + jinja_text: Optional[str] + if isinstance(tokenizer_chat_template, str) and chat_template is None: + jinja_text = tokenizer_chat_template if (isinstance(tokenizer_chat_template, dict) and 
chat_template in tokenizer_chat_template): jinja_text = tokenizer_chat_template[chat_template] else: jinja_text = load_chat_template(chat_template, is_literal=True) - if jinja_text is None: - return "string" if given_format == "auto" else given_format + detected_format = ("string" if jinja_text is None else + _detect_chat_template_content_format(jinja_text)) + + return detected_format if given_format == "auto" else given_format + + +@lru_cache +def resolve_chat_template_content_format( + chat_template: Optional[str], + given_format: ChatTemplateContentFormatOption, + tokenizer: AnyTokenizer, +) -> _ChatTemplateContentFormat: + detected_format = _resolve_chat_template_content_format( + chat_template, + given_format, + tokenizer, + ) - detected_format = _detect_chat_template_content_format(jinja_text) logger.info( "Detected the chat template content format to be '%s'. " - "Set `--chat-template-content-format` to explicitly specifiy this.", + "Set `--chat-template-content-format` to explicitly specify this.", detected_format, ) @@ -178,23 +191,14 @@ def _resolve_chat_template_content_format( logger.warning( "You specified `--chat-template-content-format %s` " "which is different from the detected format '%s'. " - "To help us improve automatic detection, please consider " - "opening a GitHub issue at: " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " "https://github.com/vllm-project/vllm/issues/new/choose", given_format, detected_format, ) - return detected_format if given_format == "auto" else given_format - - -def resolve_chat_template_content_format( - chat_template: Optional[str], - given_format: ChatTemplateContentFormatOption, - tokenizer: AnyTokenizer, -) -> _ChatTemplateContentFormat: - return _resolve_chat_template_content_format(chat_template, given_format, - tokenizer) + return detected_format ModalityStr = Literal["image", "audio", "video"] From 0410d9fe956d34b2d68f55063da8b9fd4235aca1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 2 Nov 2024 12:15:21 +0000 Subject: [PATCH 05/18] Add and fix tests --- docs/source/models/vlm.rst | 3 +- .../serving/openai_compatible_server.md | 6 +- tests/entrypoints/openai/test_serving_chat.py | 1 + tests/entrypoints/test_chat_utils.py | 83 ++++++++++- vllm/entrypoints/chat_utils.py | 134 ++++++++++-------- vllm/entrypoints/llm.py | 26 ++-- 6 files changed, 177 insertions(+), 76 deletions(-) diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 8ea5f776d6cd6..3377502a6db28 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -268,8 +268,7 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model. .. code-block:: bash vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \ - --trust-remote-code --max-model-len 4096 \ - --chat-template examples/template_vlm2vec.jinja --chat-template-content-format openai + --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja .. 
important:: diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 3e52170817b87..e6cb41b245cf3 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -147,11 +147,11 @@ completion = client.chat.completions.create( ) ``` Most chat templates for LLMs expect the `content` field to be a string but there are some newer models like -`meta-llama/Llama-Guard-3-1B` that expect the content to be according to the OpenAI schema in the request. -vLLM provides best-effort support to detect this automatically, which is logged as a string like +`meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the +request. vLLM provides best-effort support to detect this automatically, which is logged as a string like *"Detected the chat template content format to be..."*, and internally converts incoming requests to match the detected format. If the result is not what you expect, you can use the `--chat-template-content-format` -CLI argument to explicitly specify which format to use (`"string"` or `"openai"`). +CLI argument to override which format to use (`"string"` or `"openai"`). ## Command line arguments for the server diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 535e910e66707..93660e6118ca8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens(): BASE_MODEL_PATHS, response_role="assistant", chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", lora_modules=None, prompt_adapters=None, request_logger=None) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 5adb7760337fe..f4df82021cfeb 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,15 +6,24 @@ from vllm.assets.image import ImageAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (parse_chat_messages, - parse_chat_messages_futures) +from vllm.entrypoints.chat_utils import (load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + resolve_chat_template_content_format) from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..utils import VLLM_PATH + +EXAMPLES_DIR = VLLM_PATH / "examples" + PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3" +QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" +LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" @pytest.fixture(scope="function") @@ -702,3 +711,73 @@ def get_conversation(is_hf: bool): ) assert hf_result == vllm_result + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [(PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (MLLAMA_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai")], +) +# yapf: enable +def test_resolve_content_format_hf_defined(model, expected_format): + tokenizer_group = TokenizerGroup( + model, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + resolved_format = 
resolve_chat_template_content_format( + tokenizer.chat_template, + "auto", + tokenizer, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("template_path", "expected_format"), + [("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_blip2.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_llava.jinja", "string"), + ("template_vlm2vec.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "string"), + ("tool_chat_template_llama3.2_json.jinja", "string"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string")], +) +# yapf: enable +def test_resolve_content_format_examples(template_path, expected_format): + tokenizer_group = TokenizerGroup( + PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + dummy_tokenizer = tokenizer_group.tokenizer + dummy_tokenizer.chat_template = None + + resolved_format = resolve_chat_template_content_format( + load_chat_template(EXAMPLES_DIR / template_path), + "auto", + dummy_tokenizer, + ) + + assert resolved_format == expected_format diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index debc0e3dbef94..8a5f18c860cb9 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -8,7 +8,6 @@ from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) -import jinja2 import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils # yapf conflicts with isort for this block @@ -144,6 +143,80 @@ class ConversationMessage(TypedDict, total=False): _ChatTemplateContentFormat = Literal["string", "openai"] +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return (node.ctx == "load" and _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key) + + if isinstance(node, jinja2.nodes.Getattr): + return (node.ctx == "load" and _is_var_access(node.node, varname) + and node.attr == key) + + return False + + +def _iter_self_and_descendants(node: jinja2.nodes.Node): + yield node + yield from node.find_all(jinja2.nodes.Node) + + +def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): + # Search for {%- for message in messages -%} loops + for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): + loop_target = loop_ast.target + + # yapf: disable + if any( + _is_var_access(loop_iter_desc, "messages") for loop_iter_desc + in _iter_self_and_descendants(loop_ast.iter) + ): # yapf: enable + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + + +def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): + for node, message_varname in _iter_nodes_define_message(chat_template_ast): + # Search for {%- for content 
in message['content'] -%} loops + for loop_ast in node.find_all(jinja2.nodes.For): + loop_target = loop_ast.target + + # yapf: disable + if any( + _is_attr_access(loop_iter_desc, message_varname, "content") + for loop_iter_desc in _iter_self_and_descendants(loop_ast.iter) + ): # yapf: enable + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + + +def _detect_content_format( + chat_template: str, + *, + default: _ChatTemplateContentFormat, +) -> _ChatTemplateContentFormat: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + jinja_ast = jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return default + + try: + next(_iter_nodes_define_content_item(jinja_ast)) + except StopIteration: + return "string" + else: + return "openai" + + def _resolve_chat_template_content_format( chat_template: Optional[str], given_format: ChatTemplateContentFormatOption, @@ -164,7 +237,7 @@ def _resolve_chat_template_content_format( jinja_text = load_chat_template(chat_template, is_literal=True) detected_format = ("string" if jinja_text is None else - _detect_chat_template_content_format(jinja_text)) + _detect_content_format(jinja_text, default="string")) return detected_format if given_format == "auto" else given_format @@ -183,7 +256,7 @@ def resolve_chat_template_content_format( logger.info( "Detected the chat template content format to be '%s'. " - "Set `--chat-template-content-format` to explicitly specify this.", + "You can set `--chat-template-content-format` to override this.", detected_format, ) @@ -738,61 +811,6 @@ def parse_chat_messages_futures( return conversation, mm_tracker.all_mm_data() -def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): - # Search for {%- for message in messages -%} loops - for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - if not (isinstance(loop_iter, jinja2.nodes.Name) - and loop_iter.ctx == "load" and loop_iter.name == "messages"): - continue - - loop_target = loop_ast.target - if not isinstance(loop_target, jinja2.nodes.Name): - continue - - yield loop_ast, loop_target.name - - -def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): - for node, message_varname in _iter_nodes_define_message(chat_template_ast): - # Search for {%- for content in message['content'] -%} loops - for loop_ast in node.find_all(jinja2.nodes.For): - loop_iter = loop_ast.iter - if not (isinstance(loop_iter, jinja2.nodes.Getitem) - and loop_iter.ctx == "load"): - continue - - getitem_src = loop_iter.node - if not (isinstance(getitem_src, jinja2.nodes.Name) - and getitem_src.ctx == "load" - and getitem_src.name == message_varname): - continue - - getitem_idx = loop_iter.arg - if not (isinstance(getitem_idx, jinja2.nodes.Const) - and getitem_idx.value == "content"): - continue - - loop_target = loop_ast.target - if not isinstance(loop_target, jinja2.nodes.Name): - continue - - yield loop_iter, loop_target.name - - -def _detect_chat_template_content_format( - chat_template: str) -> _ChatTemplateContentFormat: - jinjacompiled = hf_chat_utils._compile_jinja_template(chat_template) - jinja_ast = jinjacompiled.environment.parse(chat_template) - - try: - next(_iter_nodes_define_content_item(jinja_ast)) - except StopIteration: - return "string" - else: - return "openai" - - def apply_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: 
List[ConversationMessage], diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b5f0a549c1d5b..0cb56261ebd77 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -526,9 +526,11 @@ def chat( to the OpenAI API. Args: - messages: A list of conversations or a single conversation. - - Each conversation is represented as a list of messages. - - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -539,16 +541,18 @@ def chat( chat_template: The template to use for structuring the chat. If not provided, the model's default chat template will be used. chat_template_content_format: The format to render message content. - - "string" will render the content as a string. - Example: "Hello World" - - "openai" will render the content as a list of dictionaries, - similar to OpenAI schema. - Example: [{"type": "text", "text": "Hello world!"}] + + - "string" will render the content as a string. + Example: ``"Who are you?"`` + - "openai" will render the content as a list of dictionaries, + similar to OpenAI schema. + Example: ``[{"type": "text", "text": "Who are you?"}]`` + add_generation_prompt: If True, adds a generation template to each message. continue_final_message: If True, continues the final message in - the conversation instead of starting a new one. Cannot be `True` - if `add_generation_prompt` is also `True`. + the conversation instead of starting a new one. Cannot be + ``True`` if ``add_generation_prompt`` is also ``True``. mm_processor_kwargs: Multimodal processor kwarg overrides for this chat request. Only used for offline requests. @@ -739,7 +743,7 @@ def encode( generation, if any. Returns: - A list of `EmbeddingRequestOutput` objects containing the + A list of ``EmbeddingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. 
Note: From 54c25c37793e9fb86991e59414e763ebc14764d2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 Nov 2024 11:42:12 +0000 Subject: [PATCH 06/18] Improve error handling Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 53ee7ac5a02f7..a82840b67710a 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -213,6 +213,9 @@ def _detect_content_format( next(_iter_nodes_define_content_item(jinja_ast)) except StopIteration: return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default else: return "openai" From fea748128e3b8cc180040c76f8bd5687f4111089 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 5 Nov 2024 11:46:12 +0000 Subject: [PATCH 07/18] Add example Signed-off-by: DarkLight1337 --- docs/source/serving/openai_compatible_server.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index e6cb41b245cf3..1043bdea2af4e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -146,13 +146,20 @@ completion = client.chat.completions.create( ] ) ``` -Most chat templates for LLMs expect the `content` field to be a string but there are some newer models like + +Most chat templates for LLMs expect the `content` field to be a string, but there are some newer models like `meta-llama/Llama-Guard-3-1B` that expect the content to be formatted according to the OpenAI schema in the request. vLLM provides best-effort support to detect this automatically, which is logged as a string like *"Detected the chat template content format to be..."*, and internally converts incoming requests to match -the detected format. If the result is not what you expect, you can use the `--chat-template-content-format` -CLI argument to override which format to use (`"string"` or `"openai"`). +the detected format, which can be one of: + +- `"string"`: A string. + - Example: `"Hello world"` +- `"openai"`: A list of dictionaries, similar to OpenAI schema. + - Example: `[{"type": "text", "text": "Hello world!"}]` +If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument +to override which format to use. 
## Command line arguments for the server From 4afe2547718b344c787dd5620c7f5edd04094928 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 8 Nov 2024 03:56:34 +0000 Subject: [PATCH 08/18] Remove repeated definition Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/protocol.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index af8190a82ba1d..06b666ea30945 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,9 +5,8 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch -from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated, Required, TypedDict +from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.pooling_params import PoolingParams @@ -35,26 +34,6 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max -class CustomChatCompletionMessageParam(TypedDict, total=False): - """Enables custom roles in the Chat Completion API.""" - role: Required[str] - """The role of the message's author.""" - - content: Union[str, List[ChatCompletionContentPartParam]] - """The contents of the message.""" - - name: str - """An optional name for the participant. - - Provides the model information to differentiate between participants of the - same role. - """ - - tool_call_id: Optional[str] - - tool_calls: Optional[List[dict]] - - class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") From 231b4d967735a2f30dadcbce00b86782539d27aa Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 02:34:08 +0000 Subject: [PATCH 09/18] Remove unused attribute Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/serving_chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 04182375d54ef..a58b40b7118c7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -64,7 +64,6 @@ def __init__( return_tokens_as_token_ids=return_tokens_as_token_ids) self.response_role = response_role - self.use_tool_use_model_template = False self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format From 656b334ca233d7b3d41de4031674c8f7475739a0 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 04:06:51 +0000 Subject: [PATCH 10/18] Consider variable reassignment Signed-off-by: DarkLight1337 --- tests/entrypoints/test_chat_utils.py | 22 +++++- vllm/entrypoints/chat_utils.py | 110 +++++++++++++++++++-------- 2 files changed, 97 insertions(+), 35 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index f4df82021cfeb..6465e27a807d7 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,7 +6,7 @@ from vllm.assets.image import ImageAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (load_chat_template, +from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, parse_chat_messages, parse_chat_messages_futures, resolve_chat_template_content_format) @@ -732,8 +732,16 @@ def test_resolve_content_format_hf_defined(model, expected_format): ) tokenizer = tokenizer_group.tokenizer + 
chat_template = tokenizer.chat_template + assert isinstance(chat_template, str) + + print("[ORIGINAL]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + resolved_format = resolve_chat_template_content_format( - tokenizer.chat_template, + chat_template, "auto", tokenizer, ) @@ -774,8 +782,16 @@ def test_resolve_content_format_examples(template_path, expected_format): dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None + chat_template = load_chat_template(EXAMPLES_DIR / template_path) + assert isinstance(chat_template, str) + + print("[ORIGINAL]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + resolved_format = resolve_chat_template_content_format( - load_chat_template(EXAMPLES_DIR / template_path), + chat_template, "auto", dummy_tokenizer, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index d592dd128555c..270d89e7ba004 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -182,38 +182,88 @@ def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: return False -def _iter_self_and_descendants(node: jinja2.nodes.Node): - yield node - yield from node.find_all(jinja2.nodes.Node) +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: Optional[str] = None, +) -> bool: + if key: + if _is_attr_access(node, varname, key): + return True + else: + if _is_var_access(node, varname): + return True + + if isinstance(node, jinja2.nodes.Getitem): + return (node.ctx == "load" + and _is_var_or_elems_access(node.node, varname, key) + and isinstance(node.arg, jinja2.nodes.Slice)) + if isinstance(node, jinja2.nodes.Filter): + return (node.node is not None + and _is_var_or_elems_access(node.node, varname, key)) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + return False + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + yield root, varname + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + if _is_var_or_elems_access(rhs, varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname + for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] -def _iter_nodes_define_message(chat_template_ast: jinja2.nodes.Template): # Search for {%- for message in messages -%} loops - for loop_ast in chat_template_ast.find_all(jinja2.nodes.For): + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter loop_target = loop_ast.target - # yapf: disable - if any( - _is_var_access(loop_iter_desc, "messages") for loop_iter_desc - in _iter_self_and_descendants(loop_ast.iter) - ): # yapf: enable - assert isinstance(loop_target, jinja2.nodes.Name) - yield loop_ast, loop_target.name - - -def _iter_nodes_define_content_item(chat_template_ast: jinja2.nodes.Template): - for node, message_varname in _iter_nodes_define_message(chat_template_ast): - # Search for {%- for content in message['content'] -%} loops - for loop_ast in node.find_all(jinja2.nodes.For): - loop_target = loop_ast.target - - # yapf: disable - if any( - _is_attr_access(loop_iter_desc, message_varname, "content") - for 
loop_iter_desc in _iter_self_and_descendants(loop_ast.iter) - ): # yapf: enable + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): assert isinstance(loop_target, jinja2.nodes.Name) yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname + for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if _is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]: + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None def _detect_content_format( @@ -221,15 +271,12 @@ def _detect_content_format( *, default: _ChatTemplateContentFormat, ) -> _ChatTemplateContentFormat: - try: - jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) - jinja_ast = jinja_compiled.environment.parse(chat_template) - except Exception: - logger.exception("Error when compiling Jinja template") + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: return default try: - next(_iter_nodes_define_content_item(jinja_ast)) + next(_iter_nodes_assign_content_item(jinja_ast)) except StopIteration: return "string" except Exception: @@ -616,7 +663,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam) - # Define a mapping from part types to their corresponding parsing functions. 
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { "text": From 3cec3915c6d897c5cc74e79f9402e272cda30697 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 04:11:34 +0000 Subject: [PATCH 11/18] Cleanup Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 270d89e7ba004..6d1ac7adee426 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -187,13 +187,6 @@ def _is_var_or_elems_access( varname: str, key: Optional[str] = None, ) -> bool: - if key: - if _is_attr_access(node, varname, key): - return True - else: - if _is_var_access(node, varname): - return True - if isinstance(node, jinja2.nodes.Getitem): return (node.ctx == "load" and _is_var_or_elems_access(node.node, varname, key) @@ -204,7 +197,11 @@ def _is_var_or_elems_access( if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) - return False + # yapf: disable + return ( + _is_attr_access(node, varname, key) if key + else _is_var_access( node, varname) + ) # yapf: enable def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): @@ -241,8 +238,7 @@ def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): message_varnames = [ - varname - for _, varname in _iter_nodes_assign_messages_item(root) + varname for _, varname in _iter_nodes_assign_messages_item(root) ] # Search for {%- for content in message['content'] -%} loops From f00419da0e9871b4afa57834ff878a39bc43291e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 06:28:58 +0000 Subject: [PATCH 12/18] Fix Signed-off-by: DarkLight1337 --- tests/entrypoints/test_chat_utils.py | 4 ++-- vllm/entrypoints/chat_utils.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6465e27a807d7..e20d75f2e3acc 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -735,7 +735,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): chat_template = tokenizer.chat_template assert isinstance(chat_template, str) - print("[ORIGINAL]") + print("[TEXT]") print(chat_template) print("[AST]") print(_try_extract_ast(chat_template)) @@ -785,7 +785,7 @@ def test_resolve_content_format_examples(template_path, expected_format): chat_template = load_chat_template(EXAMPLES_DIR / template_path) assert isinstance(chat_template, str) - print("[ORIGINAL]") + print("[TEXT]") print(chat_template) print("[AST]") print(_try_extract_ast(chat_template)) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6d1ac7adee426..62964d989c6ec 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -187,20 +187,20 @@ def _is_var_or_elems_access( varname: str, key: Optional[str] = None, ) -> bool: - if isinstance(node, jinja2.nodes.Getitem): - return (node.ctx == "load" - and _is_var_or_elems_access(node.node, varname, key) - and isinstance(node.arg, jinja2.nodes.Slice)) if isinstance(node, jinja2.nodes.Filter): return (node.node is not None and _is_var_or_elems_access(node.node, varname, key)) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) + if (isinstance(node, jinja2.nodes.Getitem) + and 
isinstance(node.arg, jinja2.nodes.Slice)): + return _is_var_or_elems_access(node.node, varname, key) + # yapf: disable return ( _is_attr_access(node, varname, key) if key - else _is_var_access( node, varname) + else _is_var_access(node, varname) ) # yapf: enable From 7c594e1a115d323678842cc68bf8468de07f0a13 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 06:35:46 +0000 Subject: [PATCH 13/18] format Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 62964d989c6ec..59c0341085eae 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -194,7 +194,7 @@ def _is_var_or_elems_access( return _is_var_or_elems_access(node.node, varname, key) if (isinstance(node, jinja2.nodes.Getitem) - and isinstance(node.arg, jinja2.nodes.Slice)): + and isinstance(node.arg, jinja2.nodes.Slice)): return _is_var_or_elems_access(node.node, varname, key) # yapf: disable From 5b87baf692d2e2f28fcfbf7fa74b4df8201b9027 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 10:11:13 +0000 Subject: [PATCH 14/18] Simplify the code Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 59c0341085eae..4782c233edcf8 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -735,12 +735,12 @@ def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, - content_format: _ChatTemplateContentFormat, + *, + wrap_dicts: bool, ) -> List[ConversationMessage]: content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - wrap_dicts = content_format == "openai" for part in parts: parse_res = _parse_chat_message_content_part( @@ -765,9 +765,11 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content_part( - part: ChatCompletionContentPartParam, - mm_parser: BaseMultiModalContentParser, - wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + *, + wrap_dicts: bool, +) -> Optional[Union[str, Dict[str, str]]]: """Parses a single part of a conversation. 
If wrap_dicts is True, structured dictionary pieces for texts and images will be wrapped in dictionaries, i.e., {"type": "text", "text", ...} and @@ -832,7 +834,7 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, - content_format, + wrap_dicts=(content_format == "openai"), ) for result_msg in result: From 4b3dd759b736a8efb99d569596c27fb9663b7ae1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 10:34:26 +0000 Subject: [PATCH 15/18] Fix bug when chat_template is None Signed-off-by: DarkLight1337 --- tests/entrypoints/test_chat_utils.py | 2 +- vllm/entrypoints/chat_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index e20d75f2e3acc..72477e048eafa 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -741,7 +741,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( - chat_template, + None, # Test detecting the tokenizer's chat_template "auto", tokenizer, ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4782c233edcf8..4f4cb5a02abfe 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -295,7 +295,7 @@ def _resolve_chat_template_content_format( jinja_text: Optional[str] if isinstance(tokenizer_chat_template, str) and chat_template is None: jinja_text = tokenizer_chat_template - if (isinstance(tokenizer_chat_template, dict) + elif (isinstance(tokenizer_chat_template, dict) and chat_template in tokenizer_chat_template): jinja_text = tokenizer_chat_template[chat_template] else: From 03f6e98095dfd0f1d4787183a2c221d8240fd313 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 13 Nov 2024 13:42:59 +0000 Subject: [PATCH 16/18] Recurse into var assignment Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4f4cb5a02abfe..d6ab3c04e7e2c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -205,15 +205,18 @@ def _is_var_or_elems_access( def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root yield root, varname + related_varnames: List[str] = [varname] for assign_ast in root.find_all(jinja2.nodes.Assign): lhs = assign_ast.target rhs = assign_ast.node - if _is_var_or_elems_access(rhs, varname): + if any(_is_var_or_elems_access(rhs, name) for name in related_varnames): assert isinstance(lhs, jinja2.nodes.Name) yield assign_ast, lhs.name + related_varnames.append(lhs.name) # NOTE: The proper way to handle this is to build a CFG so that we can handle From c8a6a75a80f1cf8aae8fdd8a1144534bec838b2e Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 Nov 2024 00:06:51 +0000 Subject: [PATCH 17/18] Fix redundant check Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index d6ab3c04e7e2c..b3e3e5e8b0784 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -171,13 +171,12 @@ def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: 
str) -> bool: if isinstance(node, jinja2.nodes.Getitem): - return (node.ctx == "load" and _is_var_access(node.node, varname) + return (_is_var_access(node.node, varname) and isinstance(node.arg, jinja2.nodes.Const) and node.arg.value == key) if isinstance(node, jinja2.nodes.Getattr): - return (node.ctx == "load" and _is_var_access(node.node, varname) - and node.attr == key) + return _is_var_access(node.node, varname) and node.attr == key return False From 1ea0b371f7187387e348660f86edffe12c82f508 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 15 Nov 2024 00:13:42 +0000 Subject: [PATCH 18/18] Use iterative BFS Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index b3e3e5e8b0784..abee5ac46391c 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -2,7 +2,7 @@ import codecs import json from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from functools import lru_cache, partial from pathlib import Path from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, @@ -207,15 +207,22 @@ def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # Global variable that is implicitly defined at the root yield root, varname - related_varnames: List[str] = [varname] - for assign_ast in root.find_all(jinja2.nodes.Assign): - lhs = assign_ast.target - rhs = assign_ast.node + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() - if any(_is_var_or_elems_access(rhs, name) for name in related_varnames): - assert isinstance(lhs, jinja2.nodes.Name) - yield assign_ast, lhs.name - related_varnames.append(lhs.name) + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) # NOTE: The proper way to handle this is to build a CFG so that we can handle
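
The detection these patches add boils down to parsing the chat template into a Jinja AST and checking whether the template ever iterates over `message["content"]`. Below is a minimal standalone sketch of that idea, using only the public `jinja2` API; it is illustrative only, and deliberately simpler than the helpers in `vllm/entrypoints/chat_utils.py` above, which additionally handle filters, slices, and variables reassigned from `messages`.

```python
# Sketch: classify a chat template as "string" vs "openai" content format
# by walking its Jinja2 AST. Not the vLLM implementation; an approximation.
import jinja2
from jinja2 import nodes


def detect_content_format(chat_template: str) -> str:
    ast = jinja2.Environment().parse(chat_template)

    # Pass 1: find {% for message in messages %} loops and record the
    # loop variable names ("message", "m", ...).
    message_vars = []
    for loop in ast.find_all(nodes.For):
        if isinstance(loop.iter, nodes.Name) and loop.iter.name == "messages":
            if isinstance(loop.target, nodes.Name):
                message_vars.append(loop.target.name)

    # Pass 2: look for {% for part in message["content"] %} (or
    # message.content) loops; their presence implies OpenAI-style parts.
    for loop in ast.find_all(nodes.For):
        it = loop.iter
        if (isinstance(it, nodes.Getitem)
                and isinstance(it.node, nodes.Name)
                and it.node.name in message_vars
                and isinstance(it.arg, nodes.Const)
                and it.arg.value == "content"):
            return "openai"
        if (isinstance(it, nodes.Getattr)
                and isinstance(it.node, nodes.Name)
                and it.node.name in message_vars
                and it.attr == "content"):
            return "openai"

    return "string"


# Templates that render content directly are detected as "string";
# templates that loop over the content parts are detected as "openai".
print(detect_content_format(
    "{% for m in messages %}{{ m['content'] }}{% endfor %}"))   # string
print(detect_content_format(
    "{% for m in messages %}{% for p in m['content'] %}"
    "{{ p['text'] }}{% endfor %}{% endfor %}"))                  # openai
```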