[Frontend] Use a proper chat template for VLM2Vec #9912

Merged · 5 commits · Nov 1, 2024
14 changes: 10 additions & 4 deletions docs/source/models/vlm.rst
@@ -240,8 +240,7 @@ To consume the server, you can use the OpenAI client like in the example below:
)
print("Chat completion output:", chat_response.choices[0].message.content)


A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.
A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.

.. tip::
There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
@@ -269,14 +268,19 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
.. code-block:: bash

vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
--trust-remote-code --max-model-len 4096
--trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja

.. important::

Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
to run this model in embedding mode instead of text generation mode.

Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
.. important::

VLM2Vec does not expect chat-based input. We use a `custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`_
to combine the text and images together.

Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:

.. code-block:: python

@@ -301,3 +305,5 @@
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])

A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
33 changes: 33 additions & 0 deletions examples/openai_chat_embedding_client_for_multimodal.py
@@ -0,0 +1,33 @@
import requests

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                },
                {
                    "type": "text",
                    "text": "Represent the given image."
                },
            ],
        }],
        "encoding_format": "float",
    },
)
response.raise_for_status()
response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])
16 changes: 16 additions & 0 deletions examples/template_vlm2vec.jinja
@@ -0,0 +1,16 @@
{%- if messages | length > 1 -%}
{{ raise_exception('Embedding models should only embed one message at a time') }}
{%- endif -%}

{% set vars = namespace(parts=[], next_image_id=1) %}
{%- for message in messages -%}
{%- for content in message['content'] -%}
{%- if content['type'] == 'text' -%}
{%- set vars.parts = vars.parts + [content['text']] %}
{%- elif content['type'] == 'image' -%}
{%- set vars.parts = vars.parts + ['<|image_{i:d}|>'.format(i=vars.next_image_id)] %}
{%- set vars.next_image_id = vars.next_image_id + 1 %}
{%- endif -%}
{%- endfor -%}
{%- endfor -%}
{{ vars.parts | join(' ') }}
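
To make the template's behavior concrete, here is a small illustrative sketch (not part of this PR) that renders it for a single user message. The ``raise_exception`` helper is supplied manually since it is a chat-template convention rather than a Jinja2 built-in, and the template path is assumed to be relative to the vLLM repository root. Note that the part type handed to the template is ``image``; the server resolves ``image_url`` parts into such placeholders before templating.

import jinja2

TEMPLATE_PATH = "examples/template_vlm2vec.jinja"  # assumed repo-relative path

def raise_exception(message: str):
    # Stand-in for the helper that chat templates expect to be available.
    raise ValueError(message)

with open(TEMPLATE_PATH) as f:
    template = jinja2.Environment().from_string(f.read())

messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Represent the given image."},
    ],
}]

# Prints (modulo surrounding whitespace): <|image_1|> Represent the given image.
print(template.render(messages=messages, raise_exception=raise_exception))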
11 changes: 8 additions & 3 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -6,11 +6,14 @@

from vllm.multimodal.utils import encode_image_base64, fetch_image

from ...utils import RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer

MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
MAXIMUM_IMAGES = 2

vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja"
assert vlm2vec_jinja_path.exists()

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@@ -35,6 +38,8 @@ def server():
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
"--chat-template",
str(vlm2vec_jinja_path),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -90,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings["data"]) == 1
assert len(embeddings["data"][0]["embedding"]) == 3072
assert embeddings["usage"]["completion_tokens"] == 0
assert embeddings["usage"]["prompt_tokens"] == 771
assert embeddings["usage"]["total_tokens"] == 771
assert embeddings["usage"]["prompt_tokens"] == 762
assert embeddings["usage"]["total_tokens"] == 762
15 changes: 11 additions & 4 deletions vllm/entrypoints/chat_utils.py
@@ -156,6 +156,10 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):

self._items: List[_T] = []

@property
def model_config(self) -> ModelConfig:
return self._model_config

@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -491,10 +495,13 @@ def _parse_chat_message_content_parts(
content: List[Union[str, Dict[str, str]]] = []

mm_parser = mm_tracker.create_parser()
wrap_dicts = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT or \
(chat_template_text_format == "openai")
model_config = mm_tracker.model_config

wrap_dicts = (chat_template_text_format == "openai"
or (model_config.task == "embedding"
and model_config.is_multimodal_model)
or (model_config.hf_config.model_type
in MODEL_KEEP_MULTI_MODAL_CONTENT))
Comment on lines +500 to +504 (Member Author):

Maybe we should auto-detect this in the future...

for part in parts:
parse_res = _parse_chat_message_content_part(
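
For readers skimming the change above: the new ``wrap_dicts`` condition keeps content parts as OpenAI-style dicts whenever a multimodal model is served in embedding mode, so a custom chat template such as ``examples/template_vlm2vec.jinja`` can combine text and image placeholders itself. A heavily simplified, hypothetical sketch of the two modes follows; it is not the actual vLLM code, which handles many more part types.

from typing import Dict, List, Union

def parse_content_parts(parts: List[Dict],
                        wrap_dicts: bool) -> Union[List[Dict], str]:
    # Simplified illustration only; the real logic lives in
    # vllm/entrypoints/chat_utils.py.
    texts: List[str] = []
    num_images = 0
    for part in parts:
        if part["type"] == "text":
            texts.append(part["text"])
        elif part["type"] == "image_url":
            num_images += 1

    if wrap_dicts:
        # Keep structured parts; the chat template decides how to combine them.
        return ([{"type": "image"}] * num_images
                + [{"type": "text", "text": t} for t in texts])

    # Otherwise flatten into a single prompt string with image placeholders
    # (the placeholder format here is only an example).
    placeholders = "\n".join(f"<|image_{i + 1}|>" for i in range(num_images))
    text = " ".join(texts)
    return f"{placeholders}\n{text}" if num_images else text

With the VLM2Vec request from the example client above, the wrapped form is ``[{'type': 'image'}, {'type': 'text', 'text': 'Represent the given image.'}]``, which the Jinja template renders as ``<|image_1|> Represent the given image.``.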