[Model][VLM] Add Qwen2-VL model support (vllm-project#7905)
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
Signed-off-by: Alvant <[email protected]>
3 people authored and Alvant committed Oct 26, 2024
1 parent a384bd9 commit c4425ab
Showing 14 changed files with 1,531 additions and 31 deletions.
10 changes: 7 additions & 3 deletions docs/source/models/supported_models.rst
@@ -252,6 +252,11 @@ Multimodal Language Models
- Image\ :sup:`E`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL (see note)
- Image\ :sup:`+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
* - :code:`UltravoxModel`
- Ultravox
- Audio\ :sup:`E+`
@@ -265,15 +270,14 @@ Multimodal Language Models
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
.. note::
For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
This can be installed by running the following command:


.. code-block:: bash
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
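
As a quick sanity check (an assumption on our part, not something the docs require), you can confirm that the installed development build of transformers actually exposes the Qwen2-VL classes before starting vLLM:

    # Sanity check that the installed transformers build is new enough for
    # Qwen2-VL; the class below only exists from the 4.45 development line on.
    import transformers

    print(transformers.__version__)  # expect a 4.45.0.dev0-style version
    from transformers import Qwen2VLForConditionalGeneration  # noqa: F401
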
18 changes: 18 additions & 0 deletions examples/offline_inference_vision_language.py
@@ -179,6 +179,23 @@ def run_qwen_vl(question):
return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question):
model_name = "Qwen/Qwen2-VL-7B-Instruct"

llm = LLM(
model=model_name,
max_num_seqs=5,
)

prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
@@ -191,6 +208,7 @@ def run_qwen_vl(question):
"blip-2": run_blip2,
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
}


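For context, here is a minimal sketch of how the tuple returned by run_qwen2_vl (the function added above) is consumed, following the generate pattern this example script already uses; the image path is a placeholder:

    from PIL import Image

    from vllm import SamplingParams

    llm, prompt, stop_token_ids = run_qwen2_vl("What is in this image?")
    image = Image.open("example.jpg")  # placeholder path for a local image

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)
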
68 changes: 61 additions & 7 deletions examples/offline_inference_vision_language_multi_image.py
@@ -6,7 +6,7 @@
from argparse import Namespace
from typing import List

from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image
@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None
return llm, prompt, stop_token_ids
return llm, prompt, stop_token_ids, None


def load_internvl(question, image_urls: List[str]):
@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids

return llm, prompt, stop_token_ids, None


def load_qwen2_vl(question, image_urls: List[str]):
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
process_vision_info = None

model_name = "Qwen/Qwen2-VL-7B-Instruct"

llm = LLM(
model=model_name,
max_num_seqs=5,
max_model_len=32768 if process_vision_info is None else 4096,
limit_mm_per_prompt={"image": len(image_urls)},
)

placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]

processor = AutoProcessor.from_pretrained(model_name)

prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

stop_token_ids = None

if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages)

return llm, prompt, stop_token_ids, image_data


model_example_map = {
"phi3_v": load_phi3v,
"internvl_chat": load_internvl,
"qwen2_vl": load_qwen2_vl,
}


def run_generate(model, question: str, image_urls: List[str]):
llm, prompt, stop_token_ids = model_example_map[model](question,
image_urls)
llm, prompt, stop_token_ids, image_data = model_example_map[model](
question, image_urls)
if image_data is None:
image_data = [fetch_image(url) for url in image_urls]

sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
{
"prompt": prompt,
"multi_modal_data": {
"image": [fetch_image(url) for url in image_urls]
"image": image_data
},
},
sampling_params=sampling_params)
@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):


def run_chat(model: str, question: str, image_urls: List[str]):
llm, _, stop_token_ids = model_example_map[model](question, image_urls)
llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)

sampling_params = SamplingParams(temperature=0.0,
max_tokens=128,
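
A hypothetical invocation of the updated multi-image helper for Qwen2-VL, assuming the functions defined in this example file; the URLs are placeholders and any reachable images would do:

    IMAGE_URLS = [
        "https://example.com/first-image.jpg",
        "https://example.com/second-image.jpg",
    ]
    run_generate("qwen2_vl", "What do these two images have in common?",
                 IMAGE_URLS)
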
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -28,3 +28,4 @@ importlib_metadata
mistral_common >= 1.3.4
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.
5 changes: 5 additions & 0 deletions tests/models/test_registry.py
@@ -1,9 +1,14 @@
import pytest
import transformers

from vllm.model_executor.models import _MODELS, ModelRegistry


@pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls):
if (model_cls == "Qwen2VLForConditionalGeneration"
and transformers.__version__ < "4.45"):
pytest.skip("Waiting for next transformers release")

# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls([model_cls])
9 changes: 6 additions & 3 deletions vllm/config.py
@@ -773,7 +773,7 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -1733,8 +1733,11 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can "
"investigate.")

assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "mrope":
scaling_factor = 1
else:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
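
A simplified sketch (not vLLM's actual implementation) of the scaling-factor logic this hunk changes: mrope, the multimodal rotary embedding used by Qwen2-VL, does not stretch the context window, so the derived maximum length is left untouched, whereas YaRN-style configs multiply their original length by the factor.

    from typing import Optional


    def derive_max_len(max_position_embeddings: int,
                       rope_type: Optional[str],
                       rope_scaling: Optional[dict]) -> int:
        # No rope scaling configured: use the plain context length.
        if rope_scaling is None:
            return max_position_embeddings
        if rope_type == "mrope":
            # mrope rearranges rotary dimensions for text/image/video positions
            # but applies no length multiplier.
            scaling_factor = 1
        else:
            scaling_factor = rope_scaling["factor"]
        if rope_type == "yarn":
            # YaRN configs report the pre-scaling length.
            return int(rope_scaling["original_max_position_embeddings"] *
                       scaling_factor)
        return int(max_position_embeddings * scaling_factor)


    # Qwen2-VL ships rope_scaling with type "mrope", so its derived maximum
    # length stays equal to max_position_embeddings.
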
8 changes: 7 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""


ModalityStr = Literal["image", "audio"]
ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")


@@ -158,12 +158,18 @@ def _placeholder_str(self, modality: ModalityStr,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"

raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")

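A rough illustration (not the chat_utils implementation) of how the placeholder strings added here end up in a Qwen2-VL prompt when a user turn carries both an image and a video:

    from typing import List

    QWEN2_VL_PLACEHOLDERS = {
        "image": "<|vision_start|><|image_pad|><|vision_end|>",
        "video": "<|vision_start|><|video_pad|><|vision_end|>",
    }


    def build_user_turn(text: str, modalities: List[str]) -> str:
        """Prefix the text with one placeholder per attached image/video."""
        tokens = "".join(QWEN2_VL_PLACEHOLDERS[m] for m in modalities)
        return ("<|im_start|>user\n"
                f"{tokens}{text}<|im_end|>\n"
                "<|im_start|>assistant\n")


    print(build_user_turn("Describe the clip and the photo.", ["image", "video"]))
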
(Diffs for the remaining changed files are not shown.)
