From 62e0b380365224cfc23b34f27438207e35336c0a Mon Sep 17 00:00:00 2001
From: Austin Veselka <50646302+FurtherAI@users.noreply.github.com>
Date: Wed, 13 Nov 2024 02:28:13 -0600
Subject: [PATCH] [Model] Adding Support for Qwen2VL as an Embedding Model.
 Using MrLight/dse-qwen2-2b-mrl-v1 (#9944)

Signed-off-by: FurtherAI
Co-authored-by: FurtherAI
Signed-off-by: Tyler Michael Smith
---
 docs/source/models/supported_models.rst       |   6 +
 docs/source/models/vlm.rst                    |  17 ++
 ...ai_chat_embedding_client_for_multimodal.py | 123 +++++++++--
 examples/template_dse_qwen2_vl.jinja          |   7 +
 tests/conftest.py                             |   3 +
 .../vision_language/test_dse_qwen2_vl.py      | 209 ++++++++++++++++++
 vllm/model_executor/models/qwen2_vl.py        |  17 +-
 vllm/model_executor/models/registry.py        |   1 +
 8 files changed, 364 insertions(+), 19 deletions(-)
 create mode 100644 examples/template_dse_qwen2_vl.jinja
 create mode 100644 tests/models/embedding/vision_language/test_dse_qwen2_vl.py

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index ca894819f2c26..58ec3acc6aea5 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -584,6 +584,12 @@ Multimodal Embedding
     - :code:`TIGER-Lab/VLM2Vec-Full`
     - 🚧
     - ✅︎
+  * - :code:`Qwen2VLForConditionalGeneration`
+    - Qwen2-VL-based
+    - T + I
+    - :code:`MrLight/dse-qwen2-2b-mrl-v1`
+    -
+    - ✅︎
 
 .. important::
    Some model architectures support both generation and embedding tasks.
diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst
index 112e9db6a41de..bcbe50a25fa09 100644
--- a/docs/source/models/vlm.rst
+++ b/docs/source/models/vlm.rst
@@ -310,4 +310,21 @@ Since the request schema is not defined by OpenAI client, we post a request to t
     response_json = response.json()
     print("Embedding output:", response_json["data"][0]["embedding"])
 
+Here is an example of how to serve the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
+
+.. code-block:: bash
+
+    vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
+      --trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
+
+.. important::
+
+    As with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
+    which is handled by the jinja template.
+
+.. important::
+
+    ``MrLight/dse-qwen2-2b-mrl-v1`` also requires a placeholder image of the minimum image size for text query embeddings. See the full code
+    example below for details.
+
 A full code example can be found in `examples/openai_chat_embedding_client_for_multimodal.py `_.
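For quick reference, the text-query flow documented above boils down to the following minimal sketch, condensed from the full client example added by this patch below. It assumes the ``vllm serve`` command shown above is running locally on port 8000; everything else mirrors the patch's own example.

.. code-block:: python

    import base64
    import io

    import requests
    from PIL import Image

    # dse-qwen2-2b-mrl-v1 expects an image in every request, so a text query
    # attaches a minimum-size placeholder image as a base64 data URL.
    buffer = io.BytesIO()
    Image.new("RGB", (56, 56)).save(buffer, "png")
    placeholder = base64.b64encode(buffer.getvalue()).decode("utf-8")

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "MrLight/dse-qwen2-2b-mrl-v1",
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/png;base64,{placeholder}"}},
                    {"type": "text", "text": "Query: What is the weather like today?"},
                ],
            }],
            "encoding_format": "float",
        },
    )
    response.raise_for_status()
    print(response.json()["data"][0]["embedding"][:8])  # first few dimensions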
diff --git a/examples/openai_chat_embedding_client_for_multimodal.py b/examples/openai_chat_embedding_client_for_multimodal.py
index effb588e1387f..fff82020d9a30 100644
--- a/examples/openai_chat_embedding_client_for_multimodal.py
+++ b/examples/openai_chat_embedding_client_for_multimodal.py
@@ -1,33 +1,120 @@
+import argparse
+import base64
+import io
+
 import requests
+from PIL import Image
 
 image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
 
-response = requests.post(
-    "http://localhost:8000/v1/embeddings",
-    json={
-        "model":
-        "TIGER-Lab/VLM2Vec-Full",
-        "messages": [{
+
+def vlm2vec():
+    response = requests.post(
+        "http://localhost:8000/v1/embeddings",
+        json={
+            "model":
+            "TIGER-Lab/VLM2Vec-Full",
+            "messages": [{
+                "role":
+                "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": image_url
+                        }
+                    },
+                    {
+                        "type": "text",
+                        "text": "Represent the given image."
+                    },
+                ],
+            }],
+            "encoding_format":
+            "float",
+        },
+    )
+    response.raise_for_status()
+    response_json = response.json()
+
+    print("Embedding output:", response_json["data"][0]["embedding"])
+
+
+def dse_qwen2_vl(inp: dict):
+    # Embedding an Image
+    if inp["dtype"] == "image":
+        messages = [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image_url",
+                "image_url": {
+                    "url": inp["image_url"],
+                }
+            }, {
+                "type": "text",
+                "text": "What is shown in this image?"
+            }]
+        }]
+    # Embedding a Text Query
+    else:
+        # MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
+        # of the minimum input size
+        buffer = io.BytesIO()
+        image_placeholder = Image.new("RGB", (56, 56))
+        image_placeholder.save(buffer, "png")
+        buffer.seek(0)
+        image_placeholder = base64.b64encode(buffer.read()).decode('utf-8')
+        messages = [{
             "role":
             "user",
             "content": [
                 {
                     "type": "image_url",
                     "image_url": {
-                        "url": image_url
+                        "url": f"data:image/png;base64,{image_placeholder}",
                     }
                 },
                 {
                     "type": "text",
-                    "text": "Represent the given image."
+                    "text": f"Query: {inp['content']}"
                 },
-            ],
-        }],
-        "encoding_format":
-        "float",
-    },
-)
-response.raise_for_status()
-response_json = response.json()
-
-print("Embedding output:", response_json["data"][0]["embedding"])
+            ]
+        }]
+
+    response = requests.post(
+        "http://localhost:8000/v1/embeddings",
+        json={
+            "model": "MrLight/dse-qwen2-2b-mrl-v1",
+            "messages": messages,
+            "encoding_format": "float",
+        },
+    )
+    response.raise_for_status()
+    response_json = response.json()
+
+    print("Embedding output:", response_json["data"][0]["embedding"])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        "Script to call a specified VLM through the API. Make sure to serve "
+        "the model with --task embedding before running this.")
+    parser.add_argument("model",
+                        type=str,
+                        choices=["vlm2vec", "dse_qwen2_vl"],
+                        help="Which model to call.")
+    args = parser.parse_args()
+
+    if args.model == "vlm2vec":
+        vlm2vec()
+    elif args.model == "dse_qwen2_vl":
+        dse_qwen2_vl({
+            "dtype": "image",
+            "image_url": image_url,
+        })
+        dse_qwen2_vl({
+            "dtype": "text",
+            "content": "What is the weather like today?",
+        })
diff --git a/examples/template_dse_qwen2_vl.jinja b/examples/template_dse_qwen2_vl.jinja
new file mode 100644
index 0000000000000..e7b93fae31770
--- /dev/null
+++ b/examples/template_dse_qwen2_vl.jinja
@@ -0,0 +1,7 @@
+{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
+You are a helpful assistant.<|im_end|>
+{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
+{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
+{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
+{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
+{% endraw %}{% endif %}<|endoftext|>
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 6cf791dc62ce5..0dc1cc6e83c18 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -243,6 +243,9 @@ def video_assets() -> _VideoAssets:
 
 class HfRunner:
 
     def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
+        if x is None or isinstance(x, (bool, )):
+            return x
+
         if device is None:
             device = "cpu" if current_platform.is_cpu() else "cuda"
diff --git a/tests/models/embedding/vision_language/test_dse_qwen2_vl.py b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py
new file mode 100644
index 0000000000000..3dd8cb729f8a6
--- /dev/null
+++ b/tests/models/embedding/vision_language/test_dse_qwen2_vl.py
@@ -0,0 +1,209 @@
+from functools import partial
+from typing import Callable, Dict, List, Type
+
+import pytest
+import torch
+from PIL import Image
+from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
+
+from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....utils import large_gpu_test
+from ..utils import check_embeddings_close
+
+HF_TEXT_PROMPTS = [
+    # T -> X
+    (
+        "Query: Find me an everyday image that matches the given caption: The label of the object is stop sign",  # noqa: E501,
+        Image.new("RGB", (56, 56))),
+    # T -> X
+    ("Query: Retrieve an image of this caption: cherry blossom",
+     Image.new("RGB", (56, 56))),
+]
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "What is shown in this image?",
+    "cherry_blossom":
+    "What is shown in this image?"
+})
+
+MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]
+
+
+def get_messages(image: Image.Image, text: str, embed_text: bool):
+    if embed_text:
+        messages = [{
+            "role":
+            "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": Image.new("RGB", (56, 56)),
+                    "resized_height": 1,
+                    "resized_width": 1
+                },  # dse requires a placeholder image for text-only queries.
+                {
+                    "type": "text",
+                    "text": text
+                },
+            ]
+        }]
+    else:
+        messages = [{
+            "role":
+            "user",
+            "content": [{
+                "type": "image",
+                "image": image
+            }, {
+                "type": "text",
+                "text": text
+            }]
+        }]
+    return messages
+
+
+def apply_chat_template_and_add_eos(
+    messages: List[Dict],
+    apply_chat_template_fn: Callable,
+):
+    prompt = apply_chat_template_fn(
+        messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
+    return prompt
+
+
+def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
+    return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)
+
+
+def _run_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    input_texts: List[str],
+    input_images: PromptImageInput,
+    embed_texts: List[bool],
+    model: str,
+    *,
+    dtype: str,
+) -> None:
+    '''Compare vLLM and HF embeddings for the given text/image inputs.'''
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+    with vllm_runner(model,
+                     task="embedding",
+                     dtype=dtype,
+                     enforce_eager=True,
+                     max_model_len=8192) as vllm_model:
+        tokenizer = vllm_model.model.get_tokenizer()
+        texts = [
+            # this is necessary because vllm_model.encode will not apply any
+            # templating to the prompt, and therefore lacks an image_pad
+            # token unless one is inserted beforehand (the (56, 56) image
+            # above is converted to an image pad token by the chat template).
+            apply_chat_template_and_add_eos(
+                get_messages(image, text, False),
+                apply_chat_template_fn=tokenizer.apply_chat_template,
+            ) for text, image in zip(input_texts, input_images)
+            # vllm will replace the pad token with the actual image,
+            # which may be a placeholder image, later.
+        ]
+        vllm_outputs = vllm_model.encode(texts, images=input_images)
+
+    hf_outputs = []
+    with hf_runner(model,
+                   dtype=dtype,
+                   auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
+        hf_model.postprocess_inputs = partial(
+            postprocess_inputs,
+            hf_model,
+            cache_position=torch.arange(
+                0,
+                1,  # 1 for batch size
+                requires_grad=False),
+            use_cache=False)
+        for text, image, embed_text in zip(input_texts, input_images,
+                                           embed_texts):
+            # dse requires non-standard input processing
+            # because it needs an image_pad token
+            messages = get_messages(image, text, embed_text)
+            prompt = apply_chat_template_and_add_eos(
+                messages, hf_model.processor.apply_chat_template)
+            inputs = hf_model.get_inputs(
+                prompts=[[prompt]],
+                images=[[image]],
+            )
+            with torch.no_grad():
+                outputs = hf_model.model(
+                    **hf_model.wrap_device(inputs[0],
+                                           device=hf_model.model.device.type),
+                    return_dict=True,
+                    output_hidden_states=True,
+                )
+                pooled_output = torch.nn.functional.normalize(
+                    outputs.hidden_states[-1][0, -1], p=2, dim=-1)
+            hf_outputs.append(pooled_output.tolist())
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_models_text(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+    dtype: str,
+) -> None:
+    input_texts_images = [(text, image_placeholder)
+                          for text, image_placeholder in HF_TEXT_PROMPTS]
+    input_texts = [text for text, _ in input_texts_images]
+    input_images = [image for _, image in input_texts_images]
+    embed_texts = [True] * len(input_texts)
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        input_texts,
+        input_images,  # type: ignore
+        embed_texts,
+        model,
+        dtype=dtype,
+    )
+
+
+@large_gpu_test(min_gb=48)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+    dtype: str,
+) -> None:
+    input_texts_images = [
+        (text, asset.pil_image)
+        for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
+    ]
+    input_texts = [text for text, _ in input_texts_images]
+    input_images = [image for _, image in input_texts_images]
+    embed_texts = [False] * len(input_texts)
+
+    _run_test(
+        hf_runner,
+        vllm_runner,
+        input_texts,
+        input_images,
+        embed_texts,
+        model,
+        dtype=dtype,
+    )
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 1b162e7df8578..9a19ccbca3f1e 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -51,6 +51,7 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import (GPTQConfig,
                                                      GPTQMarlinConfig,
                                                      QuantizationConfig)
@@ -58,12 +59,13 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
 from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.processor import cached_get_processor
@@ -1067,6 +1069,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        pooler_config = vllm_config.model_config.pooler_config
         multimodal_config = vllm_config.model_config.multimodal_config
         assert not cache_config.enable_prefix_caching, \
             "Qwen2-VL currently does not support prefix caching"
@@ -1098,6 +1101,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.LAST,
+            normalize=True,
+            softmax=False)
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
@@ -1318,6 +1326,13 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 32750602b988c..f172c06c4a26a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -109,6 +109,7 @@
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+    "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration")  # noqa: E501,
 }
 
 _MULTIMODAL_MODELS = {
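As a note on what the ``pooler()`` added above returns: with ``PoolingType.LAST`` and ``normalize=True``, the embedding is the final token's hidden state, L2-normalized, which is the same quantity the HF reference path in the test computes by hand. Below is a minimal, self-contained sketch of that pooling step; the helper name and tensor shapes are chosen arbitrarily for illustration.

.. code-block:: python

    import torch

    def last_token_pool(last_hidden_state: torch.Tensor) -> torch.Tensor:
        # last_hidden_state: [batch, seq_len, hidden_size]; take the final
        # (EOS) position and L2-normalize it, mirroring PoolingType.LAST
        # with normalize=True.
        pooled = last_hidden_state[:, -1]
        return torch.nn.functional.normalize(pooled, p=2, dim=-1)

    hidden = torch.randn(1, 32, 1536)  # dummy activations
    embedding = last_token_pool(hidden)
    print(embedding.shape)  # torch.Size([1, 1536])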