From d88b49fe8a92e1a00138063f164608704e778762 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 3 Oct 2024 07:06:16 +0000 Subject: [PATCH 01/12] Support Pixtral models in the HF Transformers format --- examples/offline_inference_vision_language.py | 19 ++ vllm/model_executor/models/llava.py | 43 ++- vllm/model_executor/models/pixtral.py | 322 +++++++++++++++++- 3 files changed, 380 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index b94ef537d783f..878a7cb9451c0 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -252,6 +252,24 @@ def run_qwen2_vl(question, modality): return llm, prompt, stop_token_ids +# Pixtral +def run_pixtral(question, modality): + assert modality == "image" + + model_name = "mistral-community/pixtral-12b" + + llm = LLM( + model=model_name, + max_model_len=10000, + max_num_seqs=16, + enforce_eager=True, + ) + + prompt = f"[INST]{question}\n[IMG][/INST]" + stop_token_ids = None + return llm, prompt, stop_token_ids + + # LLama def run_mllama(question, modality): assert modality == "image" @@ -289,6 +307,7 @@ def run_mllama(question, modality): "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "pixtral": run_pixtral, "mllama": run_mllama, } diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 69eb177a7dea8..9214b9d2430c1 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn from PIL import Image -from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig +from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, + SiglipVisionConfig) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -22,6 +23,10 @@ dummy_seq_data_for_clip, get_max_clip_image_tokens, input_processor_for_clip) from .interfaces import SupportsMultiModal +from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, + dummy_seq_data_for_pixtral_hf, + get_max_pixtral_hf_image_tokens, + input_processor_for_pixtral_hf) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) @@ -77,6 +82,8 @@ def get_max_llava_image_tokens(ctx: InputContext): num_image_tokens = get_max_clip_image_tokens(vision_config) elif isinstance(vision_config, SiglipVisionConfig): num_image_tokens = get_max_siglip_image_tokens(vision_config) + elif isinstance(vision_config, PixtralVisionConfig): + num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config) else: msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -120,6 +127,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int, mm_data = dummy_image_for_siglip(vision_config, num_images) return seq_data, mm_data + elif isinstance(vision_config, PixtralVisionConfig): + seq_data = dummy_seq_data_for_pixtral_hf( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) + return seq_data, mm_data msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -163,6 +181,14 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): image_token_id=hf_config.image_token_index, image_feature_size_override=image_feature_size, ) + elif isinstance(vision_config, PixtralVisionConfig): + return input_processor_for_pixtral_hf( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -189,6 +215,9 @@ def _init_vision_tower(hf_config: LlavaConfig): vision_config, num_hidden_layers_override=num_hidden_layers, ) + elif isinstance(vision_config, PixtralVisionConfig): + # TODO: allow layer override? + return PixtralHFVisionModel(vision_config) msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) @@ -210,6 +239,14 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config + # HACK: Special cases for pixtral + if (config.text_config.architectures is None + and config.text_config.model_type == "mistral"): + config.text_config.architectures = ["MistralForCausalLM"] + if (config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu"): + config.projector_hidden_act = "gelu" + # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = _init_vision_tower(config) self.multi_modal_projector = LlavaMultiModalProjector( @@ -223,6 +260,10 @@ def __init__(self, def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) + # HACK due to: + # expected shape of pixel values is ('batch_size', '3', '1024', '1024') + # You supplied (2, 1, 3, 1024, 1024). + # data = data.reshape(-1, *data.shape[-3:]) actual_dims = tuple(data.shape[1:]) if actual_dims != expected_dims: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index aa92e62a30d3f..eae7da8d4fc78 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -7,12 +7,12 @@ import torch.nn.functional as F from mistral_common.protocol.instruct.messages import ImageChunk from PIL import Image -from transformers import PretrainedConfig +from transformers import PixtralVisionConfig, PretrainedConfig from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig @@ -22,7 +22,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.utils import cached_get_tokenizer +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal @@ -560,3 +561,318 @@ def __init__(self, args: VisionEncoderArgs, dim: int): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) + + +#### HF Transformers version of Pixtral #### +# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py +# This model follows the Llava family, meaning image embeddings are placed +# instead of the `[IMG]` token placeholders. +# The model uses [`PixtralVisionModel`] for its vision encoder, +# and [`MistralForCausalLM`] for its language decoder. + + +def get_pixtral_hf_patch_grid_length(*, image_size: int, + patch_size: int) -> int: + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig) -> int: + return get_pixtral_hf_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: + return get_pixtral_hf_image_feature_size(hf_config) + + +def dummy_seq_data_for_pixtral_hf( + hf_config: PixtralVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_pixtral_hf_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_pixtral_hf( + hf_config: PixtralVisionConfig, + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def input_processor_for_pixtral_hf( + model_config: ModelConfig, + hf_config: PixtralVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[Union[int, List[int]]] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_pixtral_hf_image_feature_size(hf_config) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class PixtralHFFeedForward(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + assert config.intermediate_size is not None + self.gate_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False) + self.down_proj = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) + self.up_proj = nn.Linear(config.hidden_size, + config.intermediate_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class PixtralHFAttention(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.config = config + assert not config.hidden_size % config.num_attention_heads + self.n_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + + self.q_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.k_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.v_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + self.o_proj = nn.Linear(config.hidden_size, + config.hidden_size, + bias=False) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch, patches, self.n_heads, self.head_dim) + + q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + out = memory_efficient_attention(q, k, v, attn_bias=mask) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return self.o_proj(out) + + +class PixtralHFTransformerBlock(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.attention = PixtralHFAttention(config) + self.feed_forward = PixtralHFFeedForward(config) + self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), + mask=mask, + freqs_cis=freqs_cis) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class PixtralHFTransformer(nn.Module): + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(config.num_hidden_layers): + self.layers.append(PixtralHFTransformerBlock(config)) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +class PixtralHFVisionModel(nn.Module): + + config_class = PixtralVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: PixtralVisionConfig): + super().__init__() + + self.config = config + self.patch_conv = nn.Conv2d( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) + self.transformer = PixtralHFTransformer(config) + + head_dim = self.config.hidden_size // self.config.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @property + def max_patches_per_side(self) -> int: + return self.config.image_size // self.config.patch_size + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.device: + return next(self.parameters()).dtype + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.config.hidden_size // self.config.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.config.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + pixel_values: List[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + pixel_values: tensor of token features for + all tokens of all images of shape (N_toks, D) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) + for img in pixel_values + ] + + # flatten to a single sequence + patch_embeds = torch.cat( + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + mask = BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # remove batch dimension of the single sequence + return out.squeeze(0) + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 6df111e3cbb6101dfec97c69aba8e4a9e4278c71 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 10 Oct 2024 19:35:56 +0000 Subject: [PATCH 02/12] Working for cherry blossom? --- vllm/model_executor/models/llava.py | 37 ++++++++++++++-- vllm/model_executor/models/pixtral.py | 62 +++++++++++++++++++++------ vllm/model_executor/models/utils.py | 8 +++- 3 files changed, 91 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 1dcc8f6be10b4..b3575301e5e1b 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -37,8 +37,13 @@ class LlavaImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that `height` or `width` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ class LlavaImageEmbeddingInputs(TypedDict): @@ -183,12 +188,13 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size_override=image_feature_size, ) elif isinstance(vision_config, PixtralVisionConfig): + # We ignore image_feature_size_override since we have non-uniform + # image sizes for Pixtral return input_processor_for_pixtral_hf( model_config, vision_config, llm_inputs, image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, ) msg = f"Unsupported vision config: {type(vision_config)}" @@ -288,6 +294,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -298,6 +305,30 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Case for models like PixtralHF that have dynamic image sizes + if image_sizes is not None: + images = pixel_values + if isinstance(images, torch.Tensor): + # if passed as batch take all images + NN, N, B, C, W, H = images.shape + images = images.reshape(NN * N * B, C, W, H) + images = [images[i] for i in range(images.size(0))] + elif isinstance(images, list): + # if passed as list flatten lists of tensors + def flatten(lst): + while isinstance(lst, list) and len(lst) == 1: + lst = lst[0] + return lst + + images = flatten(images) + # print("flattened", [img.shape for img in images]) + + # TODO: Add validation based on image_sizes + return LlavaImagePixelInputs( + type="pixel_values", + data=images, + ) + return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d5723085bc967..e8eb543929b90 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -26,6 +26,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors, SequenceData +from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import init_vllm_registered_model @@ -539,6 +540,7 @@ def forward( image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ + print("VisionTransformer.forward", [img.shape for img in images]) # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images @@ -600,13 +602,14 @@ def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int: return grid_length * grid_length -def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig) -> int: +def get_max_pixtral_hf_image_feature_size( + hf_config: PixtralVisionConfig) -> int: return get_pixtral_hf_num_patches(image_size=hf_config.image_size, patch_size=hf_config.patch_size) def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int: - return get_pixtral_hf_image_feature_size(hf_config) + return get_max_pixtral_hf_image_feature_size(hf_config) def dummy_seq_data_for_pixtral_hf( @@ -618,7 +621,7 @@ def dummy_seq_data_for_pixtral_hf( image_feature_size_override: Optional[int] = None, ): if image_feature_size_override is None: - image_feature_size = get_pixtral_hf_image_feature_size(hf_config) + image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config) else: image_feature_size = image_feature_size_override @@ -645,6 +648,28 @@ def dummy_image_for_pixtral_hf( return {"image": image if num_images == 1 else [image] * num_images} +def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, + image_width: int, image_height: int): + # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 + # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 + max_height, max_width = hf_config.image_size, hf_config.image_size + patch_height, patch_width = hf_config.patch_size, hf_config.patch_size + + ratio = max(image_height / max_height, image_width / max_width) + + if ratio > 1: + import numpy + image_height = int(numpy.ceil(image_height / ratio)) + image_width = int(numpy.ceil(image_width / ratio)) + + from transformers.models.pixtral.image_processing_pixtral import ( # noqa: E501 + _num_image_tokens) + num_height_tokens, num_width_tokens = _num_image_tokens( + (image_height, image_width), (patch_height, patch_width)) + + return num_height_tokens * num_width_tokens + + def input_processor_for_pixtral_hf( model_config: ModelConfig, hf_config: PixtralVisionConfig, @@ -653,22 +678,33 @@ def input_processor_for_pixtral_hf( image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, ): + assert image_feature_size_override is None, ( + "image_feature_size_override is not supported for Pixtral") + multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) - if image_feature_size_override is None: - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_pixtral_hf_image_feature_size(hf_config) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - else: - raise TypeError(f"Invalid image type: {type(image_data)}") + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + w, h = image_data.size + image_feature_size = get_pixtral_hf_image_feature_size(hf_config, + image_width=w, + image_height=h) + elif is_list_of(image_data, Image.Image): + image_feature_size = [] + for image in image_data: + w, h = image.size + image_feature_size.append( + get_pixtral_hf_image_feature_size(hf_config, + image_width=w, + image_height=h)) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape else: - image_feature_size = image_feature_size_override + raise TypeError(f"Invalid image type: {type(image_data)}") new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, @@ -849,6 +885,8 @@ def forward( image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ + print("PixtralHFVisionModel.forward", + [img.shape for img in pixel_values]) # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 761f0406b1333..f95d81674d7cd 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -119,10 +119,16 @@ def flatten_bn( The input tensor should have shape ``(B, N, ...)```. """ if isinstance(x, torch.Tensor): + print(f"flatten_bn isinstance(x, torch.Tensor), before:{x.shape}, " + f"after:{x.flatten(0, 1).shape}") return x.flatten(0, 1) if concat: - return torch.cat(x) + print("flatten_bn concat") + for xi in x[0][0]: + print(f"before:{xi.shape}") + print(f"after:{torch.cat(x).squeeze().shape}") + return torch.cat(x).squeeze() return [x_n for x_b in x for x_n in x_b] From 69f47fa522e2991bd62cd0dcefd71d781147d8b6 Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 10 Oct 2024 19:36:10 +0000 Subject: [PATCH 03/12] Test script --- pixtral_hf.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 pixtral_hf.py diff --git a/pixtral_hf.py b/pixtral_hf.py new file mode 100644 index 0000000000000..8530827445e8a --- /dev/null +++ b/pixtral_hf.py @@ -0,0 +1,69 @@ +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset + +def reference_one_image(): + model_name = "mistral-community/pixtral-12b" + llm = LLM( + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, + ) + + image1 = ImageAsset("stop_sign").pil_image.convert("RGB") + inputs = { + "prompt": f"[INST]Describe the image.\n[IMG][/INST]", + "multi_modal_data": {"image": image1}, + } + outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + + print(outputs[0].outputs[0].text) + """ + This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image, creating a vibrant and colorful display. The blue sky provides a striking contrast to the pink flowers, enhancing their visual appeal. The image does not contain any discernible text or other objects. It focuses solely on the flowers and the sky, capturing a moment of natural beauty. + """ + +def reference_two_image(): + model_name = "mistral-community/pixtral-12b" + llm = LLM( + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, + limit_mm_per_prompt={"image": 2} + ) + + image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") + image2 = ImageAsset("stop_sign").pil_image.convert("RGB") + inputs = { + "prompt": f"[INST]Describe the images.\n[IMG][IMG][/INST]", + "multi_modal_data": { + "image": [image1, image2] + }, + } + outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + + print(outputs[0].outputs[0].text) + +def fp8_one_image(): + model_name = "nm-testing/pixtral-12b-FP8-dynamic" + llm = LLM( + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, + ) + + image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") + inputs = { + "prompt": f"[INST]Describe the image.\n[IMG][/INST]", + "multi_modal_data": {"image": image1}, + } + outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + + print(outputs[0].outputs[0].text) + """ + This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image. The vibrant pink color of the flowers contrasts beautifully with the clear blue sky, creating a visually striking scene. The image likely captures a moment of natural beauty, possibly during the spring season when cherry blossoms are in full bloom. + """ + +reference_one_image() +# reference_two_image() From c0f815d9cdfabb2688bba4b4c8a0d1a04ce4d036 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 01:06:01 +0000 Subject: [PATCH 04/12] Pixtral is basically working! --- pixtral_hf.py | 119 ++++++++-- vllm/model_executor/models/llava.py | 2 +- vllm/model_executor/models/pixtral.py | 290 +++++++++++++++++-------- vllm/model_executor/models/qwen2_vl.py | 4 +- vllm/transformers_utils/processor.py | 3 + 5 files changed, 304 insertions(+), 114 deletions(-) diff --git a/pixtral_hf.py b/pixtral_hf.py index 8530827445e8a..c67f10d8f1f91 100644 --- a/pixtral_hf.py +++ b/pixtral_hf.py @@ -1,7 +1,40 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import encode_image_base64 +import torch +torch.set_printoptions(sci_mode=False) + +sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + +def image_url(asset: str): + image = ImageAsset(asset) + base64 = encode_image_base64(image.pil_image) + return f"data:image/jpeg;base64,{base64}" + def reference_one_image(): + from transformers import AutoProcessor, LlavaForConditionalGeneration + model_id = "mistral-community/pixtral-12b" + model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").to("cuda") + processor = AutoProcessor.from_pretrained(model_id) + + IMGS = [ + ImageAsset("stop_sign").pil_image.convert("RGB"), + ] + PROMPT = f"[INST][IMG]Describe the image.[/INST]" + + inputs = processor(text=PROMPT, images=IMGS, return_tensors="pt").to(model.device) + print("inputs['input_ids']", inputs["input_ids"].shape) + print("detok(inputs['input_ids'])", processor.batch_decode(inputs["input_ids"])[0]) + print("inputs['pixel_values']", [i.shape for i in inputs["pixel_values"]]) + generate_ids = model.generate(**inputs, do_sample=False, max_new_tokens=100) + output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + print(output) + """ + The image features a beautiful cherry blossoms in foreground, creating a picturesque frame around the Washington Monument. The monument is illuminated with a warm, golden light, standing tall against a clear, blue sky. The blossoms are in full bloom, with delicate, pink petals, adding a soft, romantic touch to the scene. The perspective is from a low angle, looking up at the monument, emphasizing its height and grandeur. The overall mood of the image is serene and peaceful, celebrating the + """ + +def pixtralhf_one_image(): model_name = "mistral-community/pixtral-12b" llm = LLM( model=model_name, @@ -10,19 +43,74 @@ def reference_one_image(): max_model_len=10000, ) - image1 = ImageAsset("stop_sign").pil_image.convert("RGB") - inputs = { - "prompt": f"[INST]Describe the image.\n[IMG][/INST]", - "multi_modal_data": {"image": image1}, - } - outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + chat_template = "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\n\n\" }}\n {%- else %}\n {{- \"[INST]\" }}\n {%- endif %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"content\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message[\"content\"] }}\n {%- endif %}\n {{- \"[/INST]\" }}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}" + + # image1 = ImageAsset("stop_sign").pil_image.convert("RGB") + # inputs = { + # "prompt": "[INST][IMG]Describe the image.[/INST]", + # "multi_modal_data": {"image": image1}, + # } + # outputs = llm.generate(inputs, sampling_params=sampling_params) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe the image." + }, + { + "type": "image_url", + "image_url": { + "url": image_url("stop_sign") + } + }, + ], + }, + ] + outputs = llm.chat(messages, sampling_params=sampling_params, chat_template=chat_template) print(outputs[0].outputs[0].text) """ This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image, creating a vibrant and colorful display. The blue sky provides a striking contrast to the pink flowers, enhancing their visual appeal. The image does not contain any discernible text or other objects. It focuses solely on the flowers and the sky, capturing a moment of natural beauty. """ -def reference_two_image(): +def pixtral_one_image(): + model_name = "mistralai/Pixtral-12B-2409" + llm = LLM( + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, + tokenizer_mode="mistral", + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe the image." + }, + { + "type": "image_url", + "image_url": { + "url": image_url("stop_sign") + } + }, + ], + }, + ] + outputs = llm.chat(messages, sampling_params=sampling_params) + + print(outputs[0].outputs[0].text) + """ + This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image, creating a vibrant and colorful display. The blue sky provides a striking contrast to the pink flowers, enhancing their visual appeal. The image does not contain any discernible text or other objects. It focuses solely on the flowers and the sky, capturing a moment of natural beauty. + """ + +def pixtralhf_two_image(): model_name = "mistral-community/pixtral-12b" llm = LLM( model=model_name, @@ -35,16 +123,16 @@ def reference_two_image(): image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") image2 = ImageAsset("stop_sign").pil_image.convert("RGB") inputs = { - "prompt": f"[INST]Describe the images.\n[IMG][IMG][/INST]", + "prompt": "[INST][IMG][IMG]Describe the images.[/INST]", "multi_modal_data": { "image": [image1, image2] }, } - outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + outputs = llm.generate(inputs, sampling_params=sampling_params) print(outputs[0].outputs[0].text) -def fp8_one_image(): +def pixtralhf_fp8_one_image(): model_name = "nm-testing/pixtral-12b-FP8-dynamic" llm = LLM( model=model_name, @@ -55,15 +143,18 @@ def fp8_one_image(): image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") inputs = { - "prompt": f"[INST]Describe the image.\n[IMG][/INST]", + "prompt": "[INST][IMG]Describe the image.[/INST]", "multi_modal_data": {"image": image1}, } - outputs = llm.generate(inputs, sampling_params=SamplingParams(temperature=0.0, max_tokens=100)) + outputs = llm.generate(inputs, sampling_params=sampling_params) print(outputs[0].outputs[0].text) """ This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image. The vibrant pink color of the flowers contrasts beautifully with the clear blue sky, creating a visually striking scene. The image likely captures a moment of natural beauty, possibly during the spring season when cherry blossoms are in full bloom. """ -reference_one_image() -# reference_two_image() +# reference_one_image() +pixtralhf_one_image() +# pixtral_one_image() + +# pixtralhf_two_image() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index b3575301e5e1b..dc40c335d96f2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -359,7 +359,7 @@ def _select_image_features(self, image_features: torch.Tensor, *, def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index e8eb543929b90..c3149eb43606f 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -9,8 +9,11 @@ from mistral_common.protocol.instruct.messages import ImageChunk from PIL import Image from transformers import PixtralVisionConfig, PretrainedConfig +from transformers.models.pixtral.image_processing_pixtral import ( + _num_image_tokens) from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask +import numpy from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig @@ -23,8 +26,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.transformers_utils.processor import cached_get_processor from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of @@ -540,7 +544,6 @@ def forward( image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ - print("VisionTransformer.forward", [img.shape for img in images]) # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images @@ -649,25 +652,22 @@ def dummy_image_for_pixtral_hf( def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, - image_width: int, image_height: int): + image_width: int, image_height: int) -> Tuple[int, int]: # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 - max_height, max_width = hf_config.image_size, hf_config.image_size - patch_height, patch_width = hf_config.patch_size, hf_config.patch_size + max_width, max_height = hf_config.image_size, hf_config.image_size + patch_width, patch_height = hf_config.patch_size, hf_config.patch_size - ratio = max(image_height / max_height, image_width / max_width) + ratio = max(image_width / max_width, image_height / max_height) if ratio > 1: - import numpy - image_height = int(numpy.ceil(image_height / ratio)) image_width = int(numpy.ceil(image_width / ratio)) + image_height = int(numpy.ceil(image_height / ratio)) - from transformers.models.pixtral.image_processing_pixtral import ( # noqa: E501 - _num_image_tokens) num_height_tokens, num_width_tokens = _num_image_tokens( (image_height, image_width), (patch_height, patch_width)) - return num_height_tokens * num_width_tokens + return num_width_tokens, num_height_tokens def input_processor_for_pixtral_hf( @@ -686,41 +686,117 @@ def input_processor_for_pixtral_hf( return llm_inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) + processor = cached_get_processor(model_config.model) image_data = multi_modal_data["image"] if isinstance(image_data, Image.Image): - w, h = image_data.size - image_feature_size = get_pixtral_hf_image_feature_size(hf_config, - image_width=w, - image_height=h) - elif is_list_of(image_data, Image.Image): - image_feature_size = [] - for image in image_data: - w, h = image.size - image_feature_size.append( - get_pixtral_hf_image_feature_size(hf_config, - image_width=w, - image_height=h)) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - else: + image_data = [image_data] + elif not is_list_of(image_data, Image.Image): raise TypeError(f"Invalid image type: {type(image_data)}") - new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( - tokenizer, - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], - placeholder_token_id=image_token_id, - repeat_count=image_feature_size, - ) + replace_strings = [] + new_prompt = llm_inputs.get("prompt") + for image in image_data: + w, h = image.size + + num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(hf_config, + image_width=w, + image_height=h) + + replace_tokens = [ + [processor.image_token] * num_width_tokens + [processor.image_break_token] + ] * num_height_tokens + # Flatten list + replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens[-1] = processor.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + new_prompt = new_prompt.replace(processor.image_token, "", 1) + + while "" in new_prompt: + replace_str = replace_strings.pop(0) + new_prompt = new_prompt.replace("", replace_str, 1) + + new_token_ids = tokenizer(new_prompt)["input_ids"] + + # print("new_token_ids", new_token_ids) + # print("new_token_ids", len(new_token_ids)) + # print("new_prompt", new_prompt) # NOTE: Create a defensive copy of the original inputs return LLMInputs(prompt_token_ids=new_token_ids, prompt=new_prompt, multi_modal_data=multi_modal_data) +class PixtralHFRotaryEmbedding(nn.Module): + """ + The key with pixtral embedding is just that you have a frequency for each pixel positions. + If you have height x width pixels (or embedding pixels), then the frequency used for ROPE + is given by indexing the pre_computed frequency on the width and height. + + What you output is of dimension (batch, height * width, dim) with dim the embed dim. + + This simply means that for each image hidden state, you are going to add + a corresponding positional embedding, based on its index in the grid. + """ -class PixtralHFFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.rope_type = "default" + self.dim = config.head_dim + self.base = config.rope_theta + max_patches_per_side = config.image_size // config.patch_size + freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + + h = torch.arange(max_patches_per_side, device=freqs.device) + w = torch.arange(max_patches_per_side, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + inv_freq = torch.cat( + [ + freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), + freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), + ], + dim=-1, + ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes + # Different from paper, but it uses a different permutation in order to obtain the same calculation + + # TODO maybe make it torch compatible later on. We can also just slice + self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + # Core RoPE block + freqs = self.inv_freq[position_ids] + # position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + emb = freqs + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class PixtralHFMLP(nn.Module): def __init__(self, config: PixtralVisionConfig): super().__init__() @@ -728,12 +804,13 @@ def __init__(self, config: PixtralVisionConfig): self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) - self.down_proj = nn.Linear(config.intermediate_size, - config.hidden_size, - bias=False) self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) + self.act_fn = get_act_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) @@ -748,6 +825,8 @@ def __init__(self, config: PixtralVisionConfig): self.n_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads + self.scale = self.head_dim**-0.5 + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) @@ -763,42 +842,59 @@ def __init__(self, config: PixtralVisionConfig): def forward( self, - x: torch.Tensor, - mask: BlockDiagonalMask, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - batch, patches, _ = x.shape + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" - q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) - q = q.reshape(batch, patches, self.n_heads, self.head_dim) - k = k.reshape(batch, patches, self.n_heads, self.head_dim) - v = v.reshape(batch, patches, self.n_heads, self.head_dim) + batch_size, patches, _ = hidden_states.size() - q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) - out = memory_efficient_attention(q, k, v, attn_bias=mask) - out = out.reshape(batch, patches, self.n_heads * self.head_dim) - return self.o_proj(out) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, patches, -1) + + return self.o_proj(attn_output) class PixtralHFTransformerBlock(nn.Module): def __init__(self, config: PixtralVisionConfig): super().__init__() - self.attention = PixtralHFAttention(config) - self.feed_forward = PixtralHFFeedForward(config) self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) + self.attention = PixtralHFAttention(config) + self.feed_forward = PixtralHFMLP(config) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) def forward( self, - x: torch.Tensor, - mask: BlockDiagonalMask, - freqs_cis: torch.Tensor, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(x), - mask=mask, - freqs_cis=freqs_cis) - h = x + r + r = self.attention.forward(self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings) + h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r return out @@ -815,12 +911,37 @@ def __init__(self, config: PixtralVisionConfig): def forward( self, x: torch.Tensor, - mask: BlockDiagonalMask, - freqs_cis: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, ) -> torch.Tensor: for layer in self.layers: - x = layer(x, mask=mask, freqs_cis=freqs_cis) + x = layer(x, attention_mask, position_embeddings) return x + +def position_ids_in_meshgrid(patch_embeds_list, max_width): + positions = [] + for patch in patch_embeds_list: + height, width = patch.shape[-2:] + mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) + ids = h_grid * max_width + v_grid + positions.append(ids[:, 0]) + return torch.cat(positions) + +def generate_block_attention_mask(patch_embeds_list, tensor): + dtype = tensor.dtype + device = tensor.device + seq_len = tensor.shape[1] + d_min = torch.finfo(dtype).min + causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + + block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) + block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) + for start, end in zip(block_start_idx, block_end_idx): + causal_mask[start:end, start:end] = 0 + + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + return causal_mask class PixtralHFVisionModel(nn.Module): @@ -841,14 +962,7 @@ def __init__(self, config: PixtralVisionConfig): ) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralHFTransformer(config) - - head_dim = self.config.hidden_size // self.config.num_attention_heads - assert head_dim % 2 == 0, "ROPE requires even head_dim" - self._freqs_cis: Optional[torch.Tensor] = None - - @property - def max_patches_per_side(self) -> int: - return self.config.image_size // self.config.patch_size + self.patch_positional_embedding = PixtralHFRotaryEmbedding(config) @property def device(self) -> torch.device: @@ -858,21 +972,6 @@ def device(self) -> torch.device: def dtype(self) -> torch.device: return next(self.parameters()).dtype - @property - def freqs_cis(self) -> torch.Tensor: - if self._freqs_cis is None: - self._freqs_cis = precompute_freqs_cis_2d( - dim=self.config.hidden_size // self.config.num_attention_heads, - height=self.max_patches_per_side, - width=self.max_patches_per_side, - theta=self.config.rope_theta, - ) - - if self._freqs_cis.device != self.device: - self._freqs_cis = self._freqs_cis.to(device=self.device) - - return self._freqs_cis - def forward( self, pixel_values: List[torch.Tensor], @@ -885,8 +984,6 @@ def forward( image_features: tensor of token features for all tokens of all images of shape (N_toks, D) """ - print("PixtralHFVisionModel.forward", - [img.shape for img in pixel_values]) # pass images through initial convolution independently patch_embeds_list = [ self.patch_conv(img.unsqueeze(0).to(self.dtype)) @@ -899,16 +996,17 @@ def forward( patch_embeds = self.ln_pre(patch_embeds) # positional embeddings - positions = position_meshgrid(patch_embeds_list).to(self.device) - freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + position_ids = position_ids_in_meshgrid( + patch_embeds_list, max_width=self.config.image_size // self.config.patch_size + ).to(self.device) - # pass through Transformer with a block diagonal mask delimiting images - mask = BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) - out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + attention_mask = generate_block_attention_mask( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) + out = self.transformer(patch_embeds, attention_mask, position_embedding) - # remove batch dimension of the single sequence - return out.squeeze(0) + return out # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 24fd5152ecd09..6029fb089beeb 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -64,7 +64,7 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, Qwen2VLVisionConfig) -from vllm.transformers_utils.processor import get_processor +from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_cpu from .interfaces import SupportsMultiModal, SupportsPP @@ -570,8 +570,6 @@ def forward( # === Vision input helpers === # -cached_get_processor = lru_cache(get_processor) - def mm_input_mapper_for_qwen2_vl( ctx: InputContext, diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 98663f7f0bd07..ada51ffd60f06 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,4 +1,5 @@ from typing import Any, cast +from functools import lru_cache def get_processor( @@ -37,6 +38,8 @@ def get_processor( return cast(ProcessorMixin, processor) +cached_get_processor = lru_cache(get_processor) + def get_image_processor( processor_name: str, *args: Any, From cdc075d78becc047f216114923ae080e3f94121e Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 01:06:15 +0000 Subject: [PATCH 05/12] Format --- pixtral_hf.py | 91 ++++++++++--------- vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/pixtral.py | 117 ++++++++++++++++--------- vllm/model_executor/models/qwen2_vl.py | 2 +- vllm/transformers_utils/processor.py | 3 +- 5 files changed, 130 insertions(+), 86 deletions(-) diff --git a/pixtral_hf.py b/pixtral_hf.py index c67f10d8f1f91..d60cccb464842 100644 --- a/pixtral_hf.py +++ b/pixtral_hf.py @@ -1,11 +1,14 @@ +import torch + from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.multimodal.utils import encode_image_base64 -import torch + torch.set_printoptions(sci_mode=False) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + def image_url(asset: str): image = ImageAsset(asset) base64 = encode_image_base64(image.pil_image) @@ -15,35 +18,40 @@ def image_url(asset: str): def reference_one_image(): from transformers import AutoProcessor, LlavaForConditionalGeneration model_id = "mistral-community/pixtral-12b" - model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto").to("cuda") + model = LlavaForConditionalGeneration.from_pretrained( + model_id, torch_dtype="auto").to("cuda") processor = AutoProcessor.from_pretrained(model_id) IMGS = [ ImageAsset("stop_sign").pil_image.convert("RGB"), ] - PROMPT = f"[INST][IMG]Describe the image.[/INST]" + PROMPT = "[INST][IMG]Describe the image.[/INST]" - inputs = processor(text=PROMPT, images=IMGS, return_tensors="pt").to(model.device) + inputs = processor(text=PROMPT, images=IMGS, + return_tensors="pt").to(model.device) print("inputs['input_ids']", inputs["input_ids"].shape) - print("detok(inputs['input_ids'])", processor.batch_decode(inputs["input_ids"])[0]) + print("detok(inputs['input_ids'])", + processor.batch_decode(inputs["input_ids"])[0]) print("inputs['pixel_values']", [i.shape for i in inputs["pixel_values"]]) - generate_ids = model.generate(**inputs, do_sample=False, max_new_tokens=100) - output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + generate_ids = model.generate(**inputs, + do_sample=False, + max_new_tokens=100) + output = processor.batch_decode(generate_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)[0] print(output) - """ - The image features a beautiful cherry blossoms in foreground, creating a picturesque frame around the Washington Monument. The monument is illuminated with a warm, golden light, standing tall against a clear, blue sky. The blossoms are in full bloom, with delicate, pink petals, adding a soft, romantic touch to the scene. The perspective is from a low angle, looking up at the monument, emphasizing its height and grandeur. The overall mood of the image is serene and peaceful, celebrating the - """ + def pixtralhf_one_image(): model_name = "mistral-community/pixtral-12b" llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, ) - chat_template = "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\n\n\" }}\n {%- else %}\n {{- \"[INST]\" }}\n {%- endif %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"content\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message[\"content\"] }}\n {%- endif %}\n {{- \"[/INST]\" }}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}" + chat_template = "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\n\n\" }}\n {%- else %}\n {{- \"[INST]\" }}\n {%- endif %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"content\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message[\"content\"] }}\n {%- endif %}\n {{- \"[/INST]\" }}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}" # noqa # image1 = ImageAsset("stop_sign").pil_image.convert("RGB") # inputs = { @@ -54,7 +62,8 @@ def pixtralhf_one_image(): messages = [ { - "role": "user", + "role": + "user", "content": [ { "type": "text", @@ -69,26 +78,27 @@ def pixtralhf_one_image(): ], }, ] - outputs = llm.chat(messages, sampling_params=sampling_params, chat_template=chat_template) + outputs = llm.chat(messages, + sampling_params=sampling_params, + chat_template=chat_template) print(outputs[0].outputs[0].text) - """ - This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image, creating a vibrant and colorful display. The blue sky provides a striking contrast to the pink flowers, enhancing their visual appeal. The image does not contain any discernible text or other objects. It focuses solely on the flowers and the sky, capturing a moment of natural beauty. - """ + def pixtral_one_image(): model_name = "mistralai/Pixtral-12B-2409" llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, + model=model_name, + max_num_seqs=1, + enforce_eager=True, max_model_len=10000, tokenizer_mode="mistral", ) messages = [ { - "role": "user", + "role": + "user", "content": [ { "type": "text", @@ -106,19 +116,15 @@ def pixtral_one_image(): outputs = llm.chat(messages, sampling_params=sampling_params) print(outputs[0].outputs[0].text) - """ - This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image, creating a vibrant and colorful display. The blue sky provides a striking contrast to the pink flowers, enhancing their visual appeal. The image does not contain any discernible text or other objects. It focuses solely on the flowers and the sky, capturing a moment of natural beauty. - """ + def pixtralhf_two_image(): model_name = "mistral-community/pixtral-12b" - llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, - limit_mm_per_prompt={"image": 2} - ) + llm = LLM(model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, + limit_mm_per_prompt={"image": 2}) image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") image2 = ImageAsset("stop_sign").pil_image.convert("RGB") @@ -132,26 +138,27 @@ def pixtralhf_two_image(): print(outputs[0].outputs[0].text) + def pixtralhf_fp8_one_image(): model_name = "nm-testing/pixtral-12b-FP8-dynamic" llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, + model=model_name, + max_num_seqs=1, + enforce_eager=True, + max_model_len=10000, ) image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") inputs = { "prompt": "[INST][IMG]Describe the image.[/INST]", - "multi_modal_data": {"image": image1}, + "multi_modal_data": { + "image": image1 + }, } outputs = llm.generate(inputs, sampling_params=sampling_params) print(outputs[0].outputs[0].text) - """ - This image appears to be a close-up view of a large number of pink flowers, possibly cherry blossoms, against a blue sky background. The flowers are densely packed and fill the entire frame of the image. The vibrant pink color of the flowers contrasts beautifully with the clear blue sky, creating a visually striking scene. The image likely captures a moment of natural beauty, possibly during the spring season when cherry blossoms are in full bloom. - """ + # reference_one_image() pixtralhf_one_image() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dc40c335d96f2..5fd219f0bf336 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -359,7 +359,8 @@ def _select_image_features(self, image_features: torch.Tensor, *, def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, + PixtralHFVisionModel], pixel_values: torch.Tensor, ) -> torch.Tensor: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c3149eb43606f..e8c113c5afe25 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -3,6 +3,7 @@ from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union +import numpy import torch import torch.nn as nn import torch.nn.functional as F @@ -10,14 +11,14 @@ from PIL import Image from transformers import PixtralVisionConfig, PretrainedConfig from transformers.models.pixtral.image_processing_pixtral import ( - _num_image_tokens) + _num_image_tokens) from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask -import numpy from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -26,10 +27,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.model_executor.layers.activation import get_act_fn from vllm.multimodal.utils import cached_get_tokenizer -from vllm.transformers_utils.processor import cached_get_processor from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP @@ -652,7 +652,8 @@ def dummy_image_for_pixtral_hf( def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, - image_width: int, image_height: int) -> Tuple[int, int]: + image_width: int, + image_height: int) -> Tuple[int, int]: # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501 # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501 max_width, max_height = hf_config.image_size, hf_config.image_size @@ -698,20 +699,21 @@ def input_processor_for_pixtral_hf( new_prompt = llm_inputs.get("prompt") for image in image_data: w, h = image.size - - num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(hf_config, - image_width=w, - image_height=h) - replace_tokens = [ - [processor.image_token] * num_width_tokens + [processor.image_break_token] - ] * num_height_tokens + num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size( + hf_config, image_width=w, image_height=h) + + replace_tokens = [[processor.image_token] * num_width_tokens + + [processor.image_break_token]] * num_height_tokens # Flatten list - replace_tokens = [item for sublist in replace_tokens for item in sublist] + replace_tokens = [ + item for sublist in replace_tokens for item in sublist + ] replace_tokens[-1] = processor.image_end_token replace_str = "".join(replace_tokens) replace_strings.append(replace_str) - new_prompt = new_prompt.replace(processor.image_token, "", 1) + new_prompt = new_prompt.replace(processor.image_token, "", + 1) while "" in new_prompt: replace_str = replace_strings.pop(0) @@ -728,13 +730,16 @@ def input_processor_for_pixtral_hf( prompt=new_prompt, multi_modal_data=multi_modal_data) + class PixtralHFRotaryEmbedding(nn.Module): """ - The key with pixtral embedding is just that you have a frequency for each pixel positions. - If you have height x width pixels (or embedding pixels), then the frequency used for ROPE - is given by indexing the pre_computed frequency on the width and height. + The key with pixtral embedding is just that you have a frequency for each + pixel positions. If you have height x width pixels (or embedding pixels), + then the frequency used for ROPE is given by indexing the pre_computed + frequency on the width and height. - What you output is of dimension (batch, height * width, dim) with dim the embed dim. + What you output is of dimension (batch, height * width, dim) with dim the + embed dim. This simply means that for each image hidden state, you are going to add a corresponding positional embedding, based on its index in the grid. @@ -746,7 +751,8 @@ def __init__(self, config): self.dim = config.head_dim self.base = config.rope_theta max_patches_per_side = config.image_size // config.patch_size - freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) + freqs = 1.0 / (self.base + **(torch.arange(0, self.dim, 2).float() / self.dim)) h = torch.arange(max_patches_per_side, device=freqs.device) w = torch.arange(max_patches_per_side, device=freqs.device) @@ -759,11 +765,16 @@ def __init__(self, config): freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), ], dim=-1, - ).reshape(-1, self.dim // 2) # we reshape to only index on the position indexes, not tuple of indexes - # Different from paper, but it uses a different permutation in order to obtain the same calculation + ).reshape( + -1, self.dim // 2 + ) # we reshape to only index on the position indexes, not tuple of + # indexes. Different from paper, but it uses a different permutation + # in order to obtain the same calculation # TODO maybe make it torch compatible later on. We can also just slice - self.register_buffer("inv_freq", torch.cat((inv_freq, inv_freq), dim=-1), persistent=False) + self.register_buffer("inv_freq", + torch.cat((inv_freq, inv_freq), dim=-1), + persistent=False) @torch.no_grad() def forward(self, x, position_ids): @@ -772,7 +783,8 @@ def forward(self, x, position_ids): # position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + device_type = device_type if isinstance( + device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): emb = freqs cos = emb.cos() @@ -783,8 +795,8 @@ def forward(self, x, position_ids): # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) @@ -796,6 +808,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + class PixtralHFMLP(nn.Module): def __init__(self, config: PixtralVisionConfig): @@ -854,20 +867,31 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, patches, self.n_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.n_heads, + self.head_dim).transpose(1, 2) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0) + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + unsqueeze_dim=0) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale if attention_mask is not None: attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() @@ -917,30 +941,38 @@ def forward( for layer in self.layers: x = layer(x, attention_mask, position_embeddings) return x - + + def position_ids_in_meshgrid(patch_embeds_list, max_width): positions = [] for patch in patch_embeds_list: height, width = patch.shape[-2:] - mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") + mesh = torch.meshgrid(torch.arange(height), + torch.arange(width), + indexing="ij") h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) ids = h_grid * max_width + v_grid positions.append(ids[:, 0]) return torch.cat(positions) + def generate_block_attention_mask(patch_embeds_list, tensor): dtype = tensor.dtype device = tensor.device seq_len = tensor.shape[1] d_min = torch.finfo(dtype).min - causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device) + causal_mask = torch.full((seq_len, seq_len), + fill_value=d_min, + dtype=dtype, + device=device) block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) for start, end in zip(block_start_idx, block_end_idx): causal_mask[start:end, start:end] = 0 - causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1) + causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, + -1) return causal_mask @@ -997,14 +1029,17 @@ def forward( # positional embeddings position_ids = position_ids_in_meshgrid( - patch_embeds_list, max_width=self.config.image_size // self.config.patch_size - ).to(self.device) + patch_embeds_list, + max_width=self.config.image_size // self.config.patch_size).to( + self.device) - position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) + position_embedding = self.patch_positional_embedding( + patch_embeds, position_ids) attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds - ) - out = self.transformer(patch_embeds, attention_mask, position_embedding) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + patch_embeds) + out = self.transformer(patch_embeds, attention_mask, + position_embedding) return out diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 6029fb089beeb..bdb5493c85dbd 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import lru_cache, partial +from functools import partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Tuple, Type, TypedDict, Union) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index ada51ffd60f06..f1523667b0466 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,5 +1,5 @@ -from typing import Any, cast from functools import lru_cache +from typing import Any, cast def get_processor( @@ -40,6 +40,7 @@ def get_processor( cached_get_processor = lru_cache(get_processor) + def get_image_processor( processor_name: str, *args: Any, From 592b69219e0c23f8392e0b048bb4ab8afa45ce1d Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 14:02:49 +0000 Subject: [PATCH 06/12] Merge and format --- vllm/model_executor/models/llava.py | 2 +- vllm/model_executor/models/pixtral.py | 22 +++++++++------------- vllm/model_executor/models/qwen2_vl.py | 2 +- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index afbb2728f2b63..5520bd5a8922e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -192,7 +192,7 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): return input_processor_for_pixtral_hf( model_config, vision_config, - llm_inputs, + inputs, image_token_id=hf_config.image_token_index, ) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 7b5479cef0218..09438813556e3 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,7 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext, LLMInputs +from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -673,17 +673,17 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig, def input_processor_for_pixtral_hf( model_config: ModelConfig, hf_config: PixtralVisionConfig, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, image_token_id: int, image_feature_size_override: Optional[Union[int, List[int]]] = None, -): +) -> DecoderOnlyInputs: assert image_feature_size_override is None, ( "image_feature_size_override is not supported for Pixtral") - multi_modal_data = llm_inputs.get("multi_modal_data") + multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs + return inputs tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = cached_get_processor(model_config.model) @@ -695,7 +695,7 @@ def input_processor_for_pixtral_hf( raise TypeError(f"Invalid image type: {type(image_data)}") replace_strings = [] - new_prompt = llm_inputs.get("prompt") + new_prompt = inputs.get("prompt") for image in image_data: w, h = image.size @@ -720,14 +720,10 @@ def input_processor_for_pixtral_hf( new_token_ids = tokenizer(new_prompt)["input_ids"] - # print("new_token_ids", new_token_ids) - # print("new_token_ids", len(new_token_ids)) - # print("new_prompt", new_prompt) - # NOTE: Create a defensive copy of the original inputs - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return DecoderOnlyInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class PixtralHFRotaryEmbedding(nn.Module): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 997dd33ea741f..a3540abdc23d3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -62,8 +62,8 @@ from vllm.multimodal.base import MultiModalData from vllm.multimodal.image import cached_get_image_processor from vllm.sequence import IntermediateTensors, SequenceData -from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.config import uses_mrope +from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, get_vit_attn_backend, From 7279180e1b0abcb9286966476606ccecfc6df61c Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 15:16:44 +0000 Subject: [PATCH 07/12] Clean up --- examples/offline_inference_vision_language.py | 10 +- pixtral_hf.py | 167 ------------------ vllm/model_executor/models/llava.py | 14 +- vllm/model_executor/models/pixtral.py | 2 +- vllm/model_executor/models/utils.py | 6 - 5 files changed, 8 insertions(+), 191 deletions(-) delete mode 100644 pixtral_hf.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 4dcb3c9b47191..06b424abd50b5 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -277,17 +277,15 @@ def run_qwen2_vl(question: str, modality: str): return llm, prompt, stop_token_ids -# Pixtral -def run_pixtral(question: str, modality: str): +# Pixtral HF-format +def run_pixtral_hf(question: str, modality: str): assert modality == "image" model_name = "mistral-community/pixtral-12b" llm = LLM( model=model_name, - max_model_len=10000, - max_num_seqs=16, - enforce_eager=True, + max_model_len=8192, ) prompt = f"[INST]{question}\n[IMG][/INST]" @@ -365,7 +363,7 @@ def run_glm4v(question: str, modality: str): "NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, - "pixtral": run_pixtral, + "pixtral_hf": run_pixtral_hf, "mllama": run_mllama, "molmo": run_molmo, "glm4v": run_glm4v, diff --git a/pixtral_hf.py b/pixtral_hf.py deleted file mode 100644 index d60cccb464842..0000000000000 --- a/pixtral_hf.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch - -from vllm import LLM, SamplingParams -from vllm.assets.image import ImageAsset -from vllm.multimodal.utils import encode_image_base64 - -torch.set_printoptions(sci_mode=False) - -sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - - -def image_url(asset: str): - image = ImageAsset(asset) - base64 = encode_image_base64(image.pil_image) - return f"data:image/jpeg;base64,{base64}" - - -def reference_one_image(): - from transformers import AutoProcessor, LlavaForConditionalGeneration - model_id = "mistral-community/pixtral-12b" - model = LlavaForConditionalGeneration.from_pretrained( - model_id, torch_dtype="auto").to("cuda") - processor = AutoProcessor.from_pretrained(model_id) - - IMGS = [ - ImageAsset("stop_sign").pil_image.convert("RGB"), - ] - PROMPT = "[INST][IMG]Describe the image.[/INST]" - - inputs = processor(text=PROMPT, images=IMGS, - return_tensors="pt").to(model.device) - print("inputs['input_ids']", inputs["input_ids"].shape) - print("detok(inputs['input_ids'])", - processor.batch_decode(inputs["input_ids"])[0]) - print("inputs['pixel_values']", [i.shape for i in inputs["pixel_values"]]) - generate_ids = model.generate(**inputs, - do_sample=False, - max_new_tokens=100) - output = processor.batch_decode(generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False)[0] - print(output) - - -def pixtralhf_one_image(): - model_name = "mistral-community/pixtral-12b" - llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, - ) - - chat_template = "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\n\n\" }}\n {%- else %}\n {{- \"[INST]\" }}\n {%- endif %}\n {%- if message[\"content\"] is not string %}\n {%- for chunk in message[\"content\"] %}\n {%- if chunk[\"type\"] == \"text\" %}\n {{- chunk[\"content\"] }}\n {%- elif chunk[\"type\"] == \"image\" %}\n {{- \"[IMG]\" }}\n {%- else %}\n {{- raise_exception(\"Unrecognized content type!\") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message[\"content\"] }}\n {%- endif %}\n {{- \"[/INST]\" }}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}" # noqa - - # image1 = ImageAsset("stop_sign").pil_image.convert("RGB") - # inputs = { - # "prompt": "[INST][IMG]Describe the image.[/INST]", - # "multi_modal_data": {"image": image1}, - # } - # outputs = llm.generate(inputs, sampling_params=sampling_params) - - messages = [ - { - "role": - "user", - "content": [ - { - "type": "text", - "text": "Describe the image." - }, - { - "type": "image_url", - "image_url": { - "url": image_url("stop_sign") - } - }, - ], - }, - ] - outputs = llm.chat(messages, - sampling_params=sampling_params, - chat_template=chat_template) - - print(outputs[0].outputs[0].text) - - -def pixtral_one_image(): - model_name = "mistralai/Pixtral-12B-2409" - llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, - tokenizer_mode="mistral", - ) - - messages = [ - { - "role": - "user", - "content": [ - { - "type": "text", - "text": "Describe the image." - }, - { - "type": "image_url", - "image_url": { - "url": image_url("stop_sign") - } - }, - ], - }, - ] - outputs = llm.chat(messages, sampling_params=sampling_params) - - print(outputs[0].outputs[0].text) - - -def pixtralhf_two_image(): - model_name = "mistral-community/pixtral-12b" - llm = LLM(model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, - limit_mm_per_prompt={"image": 2}) - - image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") - image2 = ImageAsset("stop_sign").pil_image.convert("RGB") - inputs = { - "prompt": "[INST][IMG][IMG]Describe the images.[/INST]", - "multi_modal_data": { - "image": [image1, image2] - }, - } - outputs = llm.generate(inputs, sampling_params=sampling_params) - - print(outputs[0].outputs[0].text) - - -def pixtralhf_fp8_one_image(): - model_name = "nm-testing/pixtral-12b-FP8-dynamic" - llm = LLM( - model=model_name, - max_num_seqs=1, - enforce_eager=True, - max_model_len=10000, - ) - - image1 = ImageAsset("cherry_blossom").pil_image.convert("RGB") - inputs = { - "prompt": "[INST][IMG]Describe the image.[/INST]", - "multi_modal_data": { - "image": image1 - }, - } - outputs = llm.generate(inputs, sampling_params=sampling_params) - - print(outputs[0].outputs[0].text) - - -# reference_one_image() -pixtralhf_one_image() -# pixtral_one_image() - -# pixtralhf_two_image() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 5520bd5a8922e..77099838263fb 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -276,10 +276,6 @@ def sampler(self): def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) - # HACK due to: - # expected shape of pixel values is ('batch_size', '3', '1024', '1024') - # You supplied (2, 1, 3, 1024, 1024). - # data = data.reshape(-1, *data.shape[-3:]) actual_dims = tuple(data.shape[1:]) if actual_dims != expected_dims: @@ -305,6 +301,7 @@ def _parse_and_validate_image_input( f"Got type: {type(pixel_values)}") # Case for models like PixtralHF that have dynamic image sizes + # so we need to produce a list of tensors if image_sizes is not None: images = pixel_values if isinstance(images, torch.Tensor): @@ -314,13 +311,8 @@ def _parse_and_validate_image_input( images = [images[i] for i in range(images.size(0))] elif isinstance(images, list): # if passed as list flatten lists of tensors - def flatten(lst): - while isinstance(lst, list) and len(lst) == 1: - lst = lst[0] - return lst - - images = flatten(images) - # print("flattened", [img.shape for img in images]) + while isinstance(images, list) and len(images) == 1: + images = images[0] # TODO: Add validation based on image_sizes return LlavaImagePixelInputs( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 09438813556e3..340bcc2f5d40d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -627,7 +627,7 @@ def dummy_seq_data_for_pixtral_hf( else: image_feature_size = image_feature_size_override - return SequenceData.from_token_counts( + return SequenceData.from_prompt_token_counts( (image_token_id, image_feature_size * num_images), (0, seq_len - image_feature_size * num_images), ) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 37ad53b73cfc6..48c7509c97a81 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -260,15 +260,9 @@ def flatten_bn( The input tensor should have shape ``(B, N, ...)```. """ if isinstance(x, torch.Tensor): - print(f"flatten_bn isinstance(x, torch.Tensor), before:{x.shape}, " - f"after:{x.flatten(0, 1).shape}") return x.flatten(0, 1) if concat: - print("flatten_bn concat") - for xi in x[0][0]: - print(f"before:{xi.shape}") - print(f"after:{torch.cat(x).squeeze().shape}") return torch.cat(x).squeeze() return [x_n for x_b in x for x_n in x_b] From f1cc5691fa58b0382086f4856034f5af89f2ce56 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 15:57:54 +0000 Subject: [PATCH 08/12] Remove flatten_bn change --- vllm/model_executor/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 48c7509c97a81..9e2f5476f3aff 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -263,7 +263,7 @@ def flatten_bn( return x.flatten(0, 1) if concat: - return torch.cat(x).squeeze() + return torch.cat(x) return [x_n for x_b in x for x_n in x_b] From 466ea3e57435106d91d6c567fc092da76fb2a1a5 Mon Sep 17 00:00:00 2001 From: mgoin Date: Wed, 16 Oct 2024 20:26:39 +0000 Subject: [PATCH 09/12] Better comments --- vllm/model_executor/models/llava.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 77099838263fb..a83b7d05df7aa 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -245,7 +245,8 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # HACK: Special cases for pixtral + # NOTE: These are special cases for Pixtral-12B in the HF-format + # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa if (config.text_config.architectures is None and config.text_config.model_type == "mistral"): config.text_config.architectures = ["MistralForCausalLM"] From 7f1eec39c912afef24bae5b48998fb5e12caf41d Mon Sep 17 00:00:00 2001 From: mgoin Date: Thu, 17 Oct 2024 19:47:56 +0000 Subject: [PATCH 10/12] Review comments --- vllm/model_executor/models/pixtral.py | 83 +++++---------------------- 1 file changed, 14 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 340bcc2f5d40d..93f7a3027be71 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -12,12 +12,17 @@ from transformers import PixtralVisionConfig, PretrainedConfig from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens) +from transformers.models.pixtral.modeling_pixtral import ( + apply_rotary_pos_emb, generate_block_attention_mask, + position_ids_in_meshgrid) from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -721,9 +726,9 @@ def input_processor_for_pixtral_hf( new_token_ids = tokenizer(new_prompt)["input_ids"] # NOTE: Create a defensive copy of the original inputs - return DecoderOnlyInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) + return token_inputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) class PixtralHFRotaryEmbedding(nn.Module): @@ -787,23 +792,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class PixtralHFMLP(nn.Module): def __init__(self, config: PixtralVisionConfig): @@ -818,9 +806,10 @@ def __init__(self, config: PixtralVisionConfig): self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.act = get_act_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) class PixtralHFAttention(nn.Module): @@ -937,44 +926,8 @@ def forward( return x -def position_ids_in_meshgrid(patch_embeds_list, max_width): - positions = [] - for patch in patch_embeds_list: - height, width = patch.shape[-2:] - mesh = torch.meshgrid(torch.arange(height), - torch.arange(width), - indexing="ij") - h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1) - ids = h_grid * max_width + v_grid - positions.append(ids[:, 0]) - return torch.cat(positions) - - -def generate_block_attention_mask(patch_embeds_list, tensor): - dtype = tensor.dtype - device = tensor.device - seq_len = tensor.shape[1] - d_min = torch.finfo(dtype).min - causal_mask = torch.full((seq_len, seq_len), - fill_value=d_min, - dtype=dtype, - device=device) - - block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1) - block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1) - for start, end in zip(block_start_idx, block_end_idx): - causal_mask[start:end, start:end] = 0 - - causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, - -1) - return causal_mask - - class PixtralHFVisionModel(nn.Module): - config_class = PixtralVisionConfig - main_input_name = "pixel_values" - def __init__(self, config: PixtralVisionConfig): super().__init__() @@ -990,14 +943,6 @@ def __init__(self, config: PixtralVisionConfig): self.transformer = PixtralHFTransformer(config) self.patch_positional_embedding = PixtralHFRotaryEmbedding(config) - @property - def device(self) -> torch.device: - return next(self.parameters()).device - - @property - def dtype(self) -> torch.device: - return next(self.parameters()).dtype - def forward( self, pixel_values: List[torch.Tensor], @@ -1011,9 +956,9 @@ def forward( all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently + dtype = next(self.parameters()).dtype patch_embeds_list = [ - self.patch_conv(img.unsqueeze(0).to(self.dtype)) - for img in pixel_values + self.patch_conv(img.unsqueeze(0).to(dtype)) for img in pixel_values ] # flatten to a single sequence @@ -1025,7 +970,7 @@ def forward( position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size).to( - self.device) + patch_embeds.device) position_embedding = self.patch_positional_embedding( patch_embeds, position_ids) From a8c0f3540d4f909b216f745470b7fa4926d4fff2 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 18 Oct 2024 16:01:28 +0000 Subject: [PATCH 11/12] Fix new_token_ids --- vllm/model_executor/layers/activation.py | 2 + vllm/model_executor/models/pixtral.py | 68 ++++++++++++++++++------ 2 files changed, 54 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index cf99306c9caef..8de3385a257f8 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -264,6 +264,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): lambda: nn.ReLU(), "relu2": lambda: ReLUSquaredActivation(), + "silu": + lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), }) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 93f7a3027be71..0a033fa8eb497 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -690,7 +690,6 @@ def input_processor_for_pixtral_hf( if multi_modal_data is None or "image" not in multi_modal_data: return inputs - tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = cached_get_processor(model_config.model) image_data = multi_modal_data["image"] @@ -699,31 +698,66 @@ def input_processor_for_pixtral_hf( elif not is_list_of(image_data, Image.Image): raise TypeError(f"Invalid image type: {type(image_data)}") - replace_strings = [] new_prompt = inputs.get("prompt") + new_token_ids = inputs["prompt_token_ids"] + + # Update new_prompt if present + if new_prompt: + replace_strings = [] + for image in image_data: + w, h = image.size + + (num_width_tokens, + num_height_tokens) = get_pixtral_hf_image_feature_size( + hf_config, image_width=w, image_height=h) + + replace_tokens = [[processor.image_token] * num_width_tokens + + [processor.image_break_token] + ] * num_height_tokens + # Flatten list + replace_tokens = [ + item for sublist in replace_tokens for item in sublist + ] + replace_tokens[-1] = processor.image_end_token + replace_str = "".join(replace_tokens) + replace_strings.append(replace_str) + new_prompt = new_prompt.replace(processor.image_token, + "", 1) + + while "" in new_prompt: + replace_str = replace_strings.pop(0) + new_prompt = new_prompt.replace("", replace_str, 1) + + # Update new_token_ids + image_token_id = 10 + image_break_id = 12 + image_end_id = 13 + placeholder_token_id = -999 + replace_tokens_list = [] for image in image_data: w, h = image.size num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size( hf_config, image_width=w, image_height=h) - replace_tokens = [[processor.image_token] * num_width_tokens + - [processor.image_break_token]] * num_height_tokens + replace_tokens = [[image_token_id] * num_width_tokens + + [image_break_id]] * num_height_tokens # Flatten list replace_tokens = [ item for sublist in replace_tokens for item in sublist ] - replace_tokens[-1] = processor.image_end_token - replace_str = "".join(replace_tokens) - replace_strings.append(replace_str) - new_prompt = new_prompt.replace(processor.image_token, "", - 1) - - while "" in new_prompt: - replace_str = replace_strings.pop(0) - new_prompt = new_prompt.replace("", replace_str, 1) - - new_token_ids = tokenizer(new_prompt)["input_ids"] + replace_tokens[-1] = image_end_id + replace_tokens_list.append(replace_tokens) + # Replace image id with placeholder id + next_image_index = new_token_ids.index(image_token_id) + new_token_ids[next_image_index] = placeholder_token_id + + while placeholder_token_id in new_token_ids: + replace_tokens = replace_tokens_list.pop(0) + next_image_index = new_token_ids.index(placeholder_token_id) + prefix = new_token_ids[:next_image_index] + postfix = new_token_ids[next_image_index + 1:] + new_token_ids = prefix + replace_tokens + postfix # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, @@ -958,7 +992,9 @@ def forward( # pass images through initial convolution independently dtype = next(self.parameters()).dtype patch_embeds_list = [ - self.patch_conv(img.unsqueeze(0).to(dtype)) for img in pixel_values + self.patch_conv( + img.reshape(-1, img.shape[-3], img.shape[-2], + img.shape[-1]).to(dtype)) for img in pixel_values ] # flatten to a single sequence From 775ff5ed3163c1d31a41da787d00105780294647 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 18 Oct 2024 17:12:15 +0000 Subject: [PATCH 12/12] Reuse HF PixtralRotaryEmbedding --- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/pixtral.py | 76 +++---------------------- 2 files changed, 10 insertions(+), 68 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b5fa83b437ac4..6cebd7f29fd85 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -433,7 +433,7 @@ Text Generation * - :code:`PixtralForConditionalGeneration` - Pixtral - T + I\ :sup:`+` - - :code:`mistralai/Pixtral-12B-2409` + - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - ✅︎ * - :code:`QWenLMHeadModel` diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 0a033fa8eb497..d09cbe5ca02e9 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -13,8 +13,8 @@ from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens) from transformers.models.pixtral.modeling_pixtral import ( - apply_rotary_pos_emb, generate_block_attention_mask, - position_ids_in_meshgrid) + PixtralRotaryEmbedding, apply_rotary_pos_emb, + generate_block_attention_mask, position_ids_in_meshgrid) from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -765,67 +765,6 @@ def input_processor_for_pixtral_hf( multi_modal_data=multi_modal_data) -class PixtralHFRotaryEmbedding(nn.Module): - """ - The key with pixtral embedding is just that you have a frequency for each - pixel positions. If you have height x width pixels (or embedding pixels), - then the frequency used for ROPE is given by indexing the pre_computed - frequency on the width and height. - - What you output is of dimension (batch, height * width, dim) with dim the - embed dim. - - This simply means that for each image hidden state, you are going to add - a corresponding positional embedding, based on its index in the grid. - """ - - def __init__(self, config): - super().__init__() - self.rope_type = "default" - self.dim = config.head_dim - self.base = config.rope_theta - max_patches_per_side = config.image_size // config.patch_size - freqs = 1.0 / (self.base - **(torch.arange(0, self.dim, 2).float() / self.dim)) - - h = torch.arange(max_patches_per_side, device=freqs.device) - w = torch.arange(max_patches_per_side, device=freqs.device) - - freqs_h = torch.outer(h, freqs[::2]).float() - freqs_w = torch.outer(w, freqs[1::2]).float() - inv_freq = torch.cat( - [ - freqs_h[:, None, :].repeat(1, max_patches_per_side, 1), - freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1), - ], - dim=-1, - ).reshape( - -1, self.dim // 2 - ) # we reshape to only index on the position indexes, not tuple of - # indexes. Different from paper, but it uses a different permutation - # in order to obtain the same calculation - - # TODO maybe make it torch compatible later on. We can also just slice - self.register_buffer("inv_freq", - torch.cat((inv_freq, inv_freq), dim=-1), - persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids): - # Core RoPE block - freqs = self.inv_freq[position_ids] - # position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance( - device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - emb = freqs - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class PixtralHFMLP(nn.Module): def __init__(self, config: PixtralVisionConfig): @@ -975,7 +914,10 @@ def __init__(self, config: PixtralVisionConfig): ) self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralHFTransformer(config) - self.patch_positional_embedding = PixtralHFRotaryEmbedding(config) + self.dtype = next(self.parameters()).dtype + self.device = next(self.parameters()).device + self.patch_positional_embedding = PixtralRotaryEmbedding( + config, self.device) def forward( self, @@ -990,11 +932,11 @@ def forward( all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently - dtype = next(self.parameters()).dtype patch_embeds_list = [ self.patch_conv( img.reshape(-1, img.shape[-3], img.shape[-2], - img.shape[-1]).to(dtype)) for img in pixel_values + img.shape[-1]).to(self.dtype)) + for img in pixel_values ] # flatten to a single sequence @@ -1006,7 +948,7 @@ def forward( position_ids = position_ids_in_meshgrid( patch_embeds_list, max_width=self.config.image_size // self.config.patch_size).to( - patch_embeds.device) + self.device) position_embedding = self.patch_positional_embedding( patch_embeds, position_ids)