
[Frontend] Use a proper chat template for VLM2Vec (vllm-project#9912)
Signed-off-by: Loc Huynh <[email protected]>
DarkLight1337 authored and JC1DA committed Nov 11, 2024
1 parent 58e5ad3 commit 60464e8
Showing 6 changed files with 78 additions and 11 deletions.
14 changes: 10 additions & 4 deletions docs/source/models/vlm.rst
@@ -240,8 +240,7 @@ To consume the server, you can use the OpenAI client like in the example below:
)
print("Chat completion output:", chat_response.choices[0].message.content)
-A full code example can be found in `examples/openai_api_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_api_client_for_multimodal.py>`_.
+A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.

.. tip::
There is no need to place image placeholders in the text content of the API request - they are already represented by the image content.
@@ -269,14 +268,19 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
.. code-block:: bash

    vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
-       --trust-remote-code --max-model-len 4096
+       --trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
.. important::

    Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
    to run this model in embedding mode instead of text generation mode.

-Since this schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:

+.. important::
+
+    VLM2Vec does not expect chat-based input. We use a `custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`_
+    to combine the text and images together.
+
+Since the request schema is not defined by the OpenAI client, we post a request to the server using the lower-level ``requests`` library:

.. code-block:: python
@@ -301,3 +305,5 @@ Since this schema is not defined by OpenAI client, we post a request to the serv
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
+A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_embedding_client_for_multimodal.py>`_.
33 changes: 33 additions & 0 deletions examples/openai_chat_embedding_client_for_multimodal.py
@@ -0,0 +1,33 @@
import requests

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
                {
                    "type": "text",
                    "text": "Represent the given image.",
                },
            ],
        }],
        "encoding_format": "float",
    },
)
response.raise_for_status()
response_json = response.json()

print("Embedding output:", response_json["data"][0]["embedding"])
16 changes: 16 additions & 0 deletions examples/template_vlm2vec.jinja
@@ -0,0 +1,16 @@
{%- if messages | length > 1 -%}
    {{ raise_exception('Embedding models should only embed one message at a time') }}
{%- endif -%}

{% set vars = namespace(parts=[], next_image_id=1) %}
{%- for message in messages -%}
    {%- for content in message['content'] -%}
        {%- if content['type'] == 'text' -%}
            {%- set vars.parts = vars.parts + [content['text']] %}
        {%- elif content['type'] == 'image' -%}
            {%- set vars.parts = vars.parts + ['<|image_{i:d}|>'.format(i=vars.next_image_id)] %}
            {%- set vars.next_image_id = vars.next_image_id + 1 %}
        {%- endif -%}
    {%- endfor -%}
{%- endfor -%}
{{ vars.parts | join(' ') }}
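
For reference, here is a minimal editorial sketch (not part of this commit) that renders the template above with plain jinja2 to show the prompt it produces. It assumes the image parts have already been reduced to {"type": "image"} placeholders, as the template expects, and it emulates the raise_exception helper that chat-template renderers normally provide.

# Editorial sketch (not from the commit): render template_vlm2vec.jinja with
# plain jinja2. Assumes image parts arrive as {"type": "image"} placeholders
# and emulates the raise_exception helper available during template rendering.
import jinja2


def raise_exception(message: str):
    raise jinja2.exceptions.TemplateError(message)


env = jinja2.Environment()
env.globals["raise_exception"] = raise_exception

with open("examples/template_vlm2vec.jinja") as f:
    template = env.from_string(f.read())

messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Represent the given image."},
    ],
}]

print(repr(template.render(messages=messages)))
# Expected output: '<|image_1|> Represent the given image.'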
11 changes: 8 additions & 3 deletions tests/entrypoints/openai/test_vision_embedding.py
@@ -6,11 +6,14 @@

from vllm.multimodal.utils import encode_image_base64, fetch_image

-from ...utils import RemoteOpenAIServer
+from ...utils import VLLM_PATH, RemoteOpenAIServer

MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
MAXIMUM_IMAGES = 2

+vlm2vec_jinja_path = VLLM_PATH / "examples/template_vlm2vec.jinja"
+assert vlm2vec_jinja_path.exists()

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@@ -35,6 +38,8 @@ def server():
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
"--chat-template",
str(vlm2vec_jinja_path),
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -90,5 +95,5 @@ async def test_image_embedding(server: RemoteOpenAIServer, model_name: str,
assert len(embeddings["data"]) == 1
assert len(embeddings["data"][0]["embedding"]) == 3072
assert embeddings["usage"]["completion_tokens"] == 0
assert embeddings["usage"]["prompt_tokens"] == 771
assert embeddings["usage"]["total_tokens"] == 771
assert embeddings["usage"]["prompt_tokens"] == 762
assert embeddings["usage"]["total_tokens"] == 762
15 changes: 11 additions & 4 deletions vllm/entrypoints/chat_utils.py
@@ -156,6 +160,10 @@ def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):

        self._items: List[_T] = []

+    @property
+    def model_config(self) -> ModelConfig:
+        return self._model_config
+
    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -491,10 +495,13 @@ def _parse_chat_message_content_parts(
    content: List[Union[str, Dict[str, str]]] = []

    mm_parser = mm_tracker.create_parser()
-    wrap_dicts = \
-        mm_tracker._model_config.hf_config.model_type in \
-            MODEL_KEEP_MULTI_MODAL_CONTENT or \
-        (chat_template_text_format == "openai")
+    model_config = mm_tracker.model_config
+
+    wrap_dicts = (chat_template_text_format == "openai"
+                  or (model_config.task == "embedding"
+                      and model_config.is_multimodal_model)
+                  or (model_config.hf_config.model_type
+                      in MODEL_KEEP_MULTI_MODAL_CONTENT))

for part in parts:
parse_res = _parse_chat_message_content_part(
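
As an editorial aside, here is a minimal sketch (an emulation with illustrative names, not vLLM's actual code) of the two content representations the wrap_dicts flag selects between: flattened placeholder text for ordinary chat templates, or typed dicts that a template such as template_vlm2vec.jinja can iterate over.

# Editorial sketch (not vLLM's implementation): emulate the effect of
# wrap_dicts on parsed message content. Names and placeholder format
# here are illustrative only.
from typing import Union


def parse_parts(parts: list[dict], wrap_dicts: bool) -> Union[str, list[dict]]:
    texts: list[str] = []
    dicts: list[dict] = []
    image_id = 0
    for part in parts:
        if part["type"] == "text":
            texts.append(part["text"])
            dicts.append({"type": "text", "text": part["text"]})
        elif part["type"] == "image_url":
            # Images become numbered placeholder tokens when flattened,
            # or bare {"type": "image"} markers when kept as dicts.
            image_id += 1
            texts.append(f"<|image_{image_id}|>")
            dicts.append({"type": "image"})
    return dicts if wrap_dicts else "\n".join(texts)


parts = [
    {"type": "image_url", "image_url": {"url": "https://example.com/x.jpg"}},
    {"type": "text", "text": "Represent the given image."},
]
print(parse_parts(parts, wrap_dicts=False))  # "<|image_1|>\nRepresent the given image."
print(parse_parts(parts, wrap_dicts=True))   # [{'type': 'image'}, {'type': 'text', ...}]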
