Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][VLM] Add Qwen2-VL model support #7905

Merged
merged 44 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
0a648b2
Add support to Qwen2-VL.
fyabc Aug 23, 2024
320df57
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Aug 26, 2024
7f96df8
Reformat
fyabc Aug 27, 2024
fbf2b8b
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Aug 27, 2024
bcaff4f
Update transformers link.
fyabc Aug 27, 2024
f2185bf
Bugfix of mrope_input_positions in model_runner.py.
fyabc Aug 27, 2024
60448cb
Rename pixel_values_video to pixel_values_videos in qwen2_vl.py.
fyabc Aug 27, 2024
71a77b1
Fix the bug of MultiModalInputs.batch() when passing different modali…
fyabc Aug 27, 2024
60c4cbd
Fix the bug when running OpenAI-compatible API server.
fyabc Aug 27, 2024
e29ff54
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Aug 29, 2024
ddb7138
Refactor qwen2_vl.py based on review comments.
fyabc Aug 29, 2024
14fe12a
reformat
fyabc Aug 29, 2024
89def23
reformat
fyabc Aug 29, 2024
e721e60
Fix the bug of model_is_mrope in model_runner.py.
fyabc Aug 29, 2024
d66d167
fix type hints in qwen2_vl.py
fyabc Aug 29, 2024
acd85ed
Update mm input processors according to new MultiModalInput.batch() i…
fyabc Aug 29, 2024
8d762c6
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Aug 30, 2024
87ba5ed
Fix SamplerOutput.
fyabc Aug 30, 2024
cda300a
Fix bug of quantization.
fyabc Aug 30, 2024
da03a3f
Bugfix of type hints in qwen2_vl.py.
fyabc Aug 31, 2024
25fb189
reformat.
fyabc Aug 31, 2024
d01530d
Merge branch 'main' into add_qwen2_vl_new
ywang96 Sep 1, 2024
faebfe4
fix typo from resolving conflict
ywang96 Sep 1, 2024
e492e53
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Sep 2, 2024
2e87db7
Bugfix in qwen2_vl.py.
fyabc Sep 2, 2024
39a1069
Adding xformers implementation
fyabc Sep 5, 2024
855c78b
Fix bug of attn_bias in xformers implementation
fyabc Sep 5, 2024
091983f
Fix bug in xformers implementation, and add backend check in vision a…
fyabc Sep 6, 2024
b406571
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Sep 6, 2024
7739588
Bugfix in qwen2_vl.py.
fyabc Sep 6, 2024
5bab9ba
Bugfix in qwen2_vl.py.
fyabc Sep 6, 2024
4587346
reformat.
fyabc Sep 6, 2024
ffad79f
Refactor MRotaryEmbedding.
fyabc Sep 6, 2024
9e7a946
Merge branch 'refs/heads/main' into add_qwen2_vl_new
fyabc Sep 9, 2024
d527417
Add "video" into ModalityStr.
fyabc Sep 9, 2024
6f3116c
Add Qwen2-VL examples.
fyabc Sep 9, 2024
386f302
Optimizer Qwen2-VL input processor. Update document.
fyabc Sep 10, 2024
c64c217
Update model notes and requirements-common.txt.
fyabc Sep 10, 2024
6bdefd6
Update model notes.
fyabc Sep 10, 2024
33dd048
Skip loading model
DarkLight1337 Sep 11, 2024
369ce7d
Merge branch 'main' into add_qwen2_vl_new
DarkLight1337 Sep 11, 2024
282c66a
format
DarkLight1337 Sep 11, 2024
14ef94d
Increase `max_model_len` to fit the original image
DarkLight1337 Sep 11, 2024
09b7a4f
Merge branch 'main' into add_qwen2_vl_new
DarkLight1337 Sep 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ Multimodal Language Models
- Image
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Image / Video
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
fyabc marked this conversation as resolved.
Show resolved Hide resolved
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image
Expand Down
9 changes: 6 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.

"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
Expand Down Expand Up @@ -1697,8 +1697,11 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can "
"investigate.")

assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "mrope":
fyabc marked this conversation as resolved.
Show resolved Hide resolved
scaling_factor = 1
else:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
Expand Down
6 changes: 6 additions & 0 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,18 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"

raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")

Expand Down
196 changes: 195 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,190 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
return new_freqs


class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)

self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2

def forward(
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""

qk_ndim_in = query.ndim

query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved

query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.mrope_section and positions.ndim == query.ndim - 1:
cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this part a bit difficult to understand. Could you please write a comment explaining it (or provide a pointer to the relevant paper if any)? Especially, I found m[i] particularly confusing.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We design a Multimodal Rotary Position Embedding (M-ROPE). By deconstructing the original rotary embedding into three parts representing temporal and spatial (height and width) information,M-ROPE enables LLM to concurrently capture and integrate 1D textual, 2D visual, and 3D video positional information.
mrope_section represents the number of dimensions occupied by each modality (temporal and spatial: height and width) in the embedding (emb). i indicates which modality (dimension) it refers to. We will extract the corresponding dimensions of the embedding based on the 3D rope_index and then concatenate them.

 Examples:
                Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
                text temporal position_ids: [3, 4, 5, 6, 7]
                text height position_ids: [3, 4, 5, 6, 7]
                text width position_ids: [3, 4, 5, 6, 7]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

For a 64-channel rope_emb where mrope_section is defined as (time 16, height 24, width 24), the rope_index (1, 2, 3) corresponds to the concatenation of rope_emb[1][:16], rope_emb[2][16:16+24], and rope_emb[3][16+24:16+24+24].


if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin

if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot

query = query.flatten(-2)
key = key.flatten(-2)
if query.ndim > qk_ndim_in:
query = query.squeeze(0)
key = key.squeeze(1)

return query, key

@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""

if torch.is_tensor(image_grid_thw):
image_grid_thw = image_grid_thw.tolist()
if torch.is_tensor(video_grid_thw):
video_grid_thw = video_grid_thw.tolist()

input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []

st = 0
remain_images, remain_videos = image_nums, video_nums

image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st

st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()

return llm_positions.tolist(), mrope_position_delta

@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
]


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down Expand Up @@ -805,7 +989,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling["factor"]
scaling_factor = rope_scaling.get("factor")
if scaling_type == "llama3":
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
Expand Down Expand Up @@ -869,6 +1053,16 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs)
elif scaling_type == "mrope":
return MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
Expand Down Expand Up @@ -85,6 +87,8 @@
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
Expand Down
Loading
Loading