
Commit

Write out skeleton code
DarkLight1337 committed Nov 1, 2024
1 parent ba0d892 commit d484401
Showing 8 changed files with 41 additions and 28 deletions.
3 changes: 2 additions & 1 deletion docs/source/models/vlm.rst
@@ -268,7 +268,8 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
.. code-block:: bash
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
--trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
--trust-remote-code --max-model-len 4096 \
--chat-template examples/template_vlm2vec.jinja --chat-template-content-format openai
.. important::

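(Illustrative aside, not part of this diff.) A minimal sketch of how a client could query the server started above, assuming vLLM's /v1/embeddings route accepts chat-style messages when running with --task embedding; the exact request schema may vary between versions:

import requests

# Hypothetical image URL, used only for illustration.
image_url = "http://images.cocodataset.org/test2017/000000155781.jpg"

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "Represent the given image."},
            ],
        }],
        "encoding_format": "float",
    },
)
response.raise_for_status()
print(response.json()["data"][0]["embedding"][:8])

Because the content here is a list of typed parts rather than a plain string, the updated command above also passes --chat-template-content-format openai so the template receives those parts unmodified.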
2 changes: 1 addition & 1 deletion docs/source/serving/openai_compatible_server.md
@@ -148,7 +148,7 @@ completion = client.chat.completions.create(
```
Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which
format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify
format the content needs to be parsed in by vLLM, please use the `--chat-template-content-format` argument to specify
between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
this, unless explicitly specified.

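(Illustrative aside, not part of this diff.) The same user turn in the two content formats the flag distinguishes:

# "string": the chat template sees message["content"] as a plain string.
string_message = {"role": "user", "content": "Hello world!"}

# "openai": the chat template sees a list of typed parts, following the
# OpenAI chat completions schema.
openai_message = {
    "role": "user",
    "content": [{"type": "text", "text": "Hello world!"}],
}

With the default string format, vLLM converts the second shape into the first before rendering the template; with openai, the list of parts is preserved so templates such as Llama Guard's can inspect each part.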
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_serving_chat.py
@@ -26,7 +26,7 @@ class MockModelConfig:
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
chat_template_text_format = "string"
chat_template_content_format = "string"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
4 changes: 2 additions & 2 deletions tests/entrypoints/test_chat_utils.py
@@ -26,7 +26,7 @@ def phi3v_model_config():
trust_remote_code=True,
dtype="bfloat16",
seed=0,
chat_template_text_format="string",
chat_template_content_format="string",
limit_mm_per_prompt={
"image": 2,
})
@@ -335,7 +335,7 @@ def test_parse_chat_messages_context_text_format(
phi3v_model_config,
phi3v_tokenizer,
):
phi3v_model_config.chat_template_text_format = "openai"
phi3v_model_config.chat_template_content_format = "openai"
conversation, mm_data = parse_chat_messages(
[{
"role": "user",
4 changes: 2 additions & 2 deletions vllm/config.py
@@ -153,7 +153,7 @@ def __init__(
use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
chat_template_content_format: str = "string",
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
pooling_type: Optional[str] = None,
pooling_norm: Optional[bool] = None,
@@ -190,7 +190,7 @@ def __init__(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.chat_template_text_format = chat_template_text_format
self.chat_template_content_format = chat_template_content_format
self.mm_processor_kwargs = mm_processor_kwargs

# Set enforce_eager to False if the value is unset.
17 changes: 10 additions & 7 deletions vllm/engine/arg_utils.py
@@ -89,7 +89,7 @@ class EngineArgs:
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
chat_template_content_format: str = 'string'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
@@ -258,13 +258,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument(
'--chat-template-text-format',
'--chat-template-content-format',
type=str,
default=EngineArgs.chat_template_text_format,
default=EngineArgs.chat_template_content_format,
choices=['string', 'openai'],
help='The format to render text content within a chat template. '
'"string" will keep the content field as a string whereas '
'"openai" will parse content in the current OpenAI format.')
help='The format to render message content within a chat template.'
'\n\n'
'* "string" will render `message["content"]` as a string. '
'Example: "Hello World"\n'
'* "openai" will render `message["content"]` like OpenAI schema. '
'Example: {"type": "text", "text": "Hello world!"}')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
@@ -917,7 +920,7 @@ def create_model_config(self) -> ModelConfig:
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
chat_template_content_format=self.chat_template_content_format,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
seed=self.seed,
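(Illustrative aside, not part of this diff.) The CLI flag maps onto the EngineArgs field renamed above, so offline usage would presumably set the field directly; a hypothetical sketch, assuming the field is accepted as a keyword argument at this point in the history:

from vllm import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-Guard-3-1B",
    # Renamed in this commit from chat_template_text_format.
    chat_template_content_format="openai",
)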
4 changes: 2 additions & 2 deletions vllm/engine/llm_engine.py
@@ -257,7 +257,7 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s, "
"chat_template_content_format=%s, mm_processor_kwargs=%s, "
"pooler_config=%r)",
VLLM_VERSION,
model_config.model,
@@ -293,7 +293,7 @@ def __init__(
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.chat_template_text_format,
model_config.chat_template_content_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
)
33 changes: 21 additions & 12 deletions vllm/entrypoints/chat_utils.py
@@ -8,6 +8,7 @@
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)

import jinja2

Check failure on line 11 in vllm/entrypoints/chat_utils.py (GitHub Actions / ruff 3.8-3.11): Ruff (F401) at 11:8: `jinja2` imported but unused
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
@@ -24,6 +25,7 @@
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.utils.chat_template_utils import _compile_jinja_template
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
@@ -417,7 +419,6 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}

# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
@@ -490,18 +491,13 @@ def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []

mm_parser = mm_tracker.create_parser()
model_config = mm_tracker.model_config
model_config = mm_tracer.model_config

Check failure on line 498 in vllm/entrypoints/chat_utils.py (GitHub Actions / ruff 3.8-3.11 and mypy 3.8-3.12): Ruff (F821) at 498:20: Undefined name `mm_tracer`; mypy: Name "mm_tracer" is not defined [name-defined]

wrap_dicts = (chat_template_text_format == "openai"
or (model_config.task == "embedding"
and model_config.is_multimodal_model)
or (model_config.hf_config.model_type
in MODEL_KEEP_MULTI_MODAL_CONTENT))
wrap_dicts = model_config.chat_template_content_format == "openai"

for part in parts:
parse_res = _parse_chat_message_content_part(
@@ -573,7 +569,6 @@ def _parse_chat_message_content_part(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -589,7 +584,6 @@ def _parse_chat_message_content(
role,
content, # type: ignore
mm_tracker,
chat_template_text_format,
)

for result_msg in result:
@@ -636,7 +630,6 @@ def parse_chat_messages(
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)

conversation.extend(sub_messages)
@@ -658,7 +651,6 @@ def parse_chat_messages_futures(
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)

conversation.extend(sub_messages)
@@ -668,6 +660,23 @@
return conversation, mm_tracker.all_mm_data()


def validate_hf_chat_template_content_format(
model_config: ModelConfig,
chat_template: str,
) -> None:
content_format = model_config.chat_template_content_format

compiled = _compile_jinja_template(chat_template)

Check failure on line 669 in vllm/entrypoints/chat_utils.py (GitHub Actions / ruff 3.8-3.11): Ruff (F841) at 669:5: Local variable `compiled` is assigned to but never used
# Somehow parse out the AST and find how messages[int]['content'] is used?

if content_format == "string":
pass
elif content_format == "openai":
pass

Check failure on line 675 in vllm/entrypoints/chat_utils.py (GitHub Actions / ruff 3.8-3.11): Ruff (SIM114) at 672:5: Combine `if` branches using logical `or` operator
else:
raise ValueError(f"Invalid format: {content_format}")


def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
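(Illustrative aside, not part of this commit.) The validate_hf_chat_template_content_format skeleton above only gestures at the intended approach ("Somehow parse out the AST and find how messages[int]['content'] is used?"). A rough sketch of one way that idea could be filled in, assuming plain Jinja2 templates and using only the public jinja2 parsing API; this is not the implementation that eventually landed:

import jinja2
from jinja2 import nodes


def _detect_content_format(chat_template: str) -> str:
    """Best-effort guess at the content format a template expects:
    "openai" if it iterates over the content parts, else "string"."""
    env = jinja2.Environment()
    ast = env.parse(chat_template)

    # Templates written against the OpenAI schema typically loop over the
    # content parts, e.g. {% for part in message['content'] %} ... {% endfor %}
    for loop in ast.find_all(nodes.For):
        it = loop.iter
        if (isinstance(it, nodes.Getitem)
                and isinstance(it.arg, nodes.Const)
                and it.arg.value == "content"):
            return "openai"
        if isinstance(it, nodes.Getattr) and it.attr == "content":
            return "openai"

    # Otherwise assume the template renders the content directly as a
    # string, e.g. {{ message['content'] }}
    return "string"

A validator shaped like the skeleton could then compare this guess against model_config.chat_template_content_format and warn or raise on a mismatch.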
