From ef91ba810e4fdf900fe4e7b13bcf92985cbcd15f Mon Sep 17 00:00:00 2001 From: hezhihui Date: Tue, 30 Jul 2024 17:56:55 +0800 Subject: [PATCH 01/25] fix bug for image pos embedding. --- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/minicpmv.py | 984 +++++++++++++++++++++++- vllm/model_executor/models/qwen2.py | 12 +- 3 files changed, 969 insertions(+), 29 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4fe33e5ab5d80..7821214bf70be 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -222,7 +222,7 @@ Vision Language Models - * - :code:`MiniCPM-V` - MiniCPM-V - - :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. + - :code:`HwwwH/MiniCPM-V-2(Temporary)`, :code:`openbmb/MiniCPM-V-2(incoming...)`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - ---- diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8563216d9c392..7b15ee4006728 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,18 +23,30 @@ """Inference-only MiniCPM-V-2 model compatible with HuggingFace weights.""" import math import re +import os +import logging +import warnings from functools import partial -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from PIL import Image from torch import nn -from torch.nn.init import trunc_normal_ +from torch.nn.init import trunc_normal_, _calculate_fan_in_and_fan_out +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + replace_return_docstrings, +) +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig -from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -45,6 +57,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -57,6 +70,899 @@ "language_model.model": "language_model", } +logger = logging.getLogger("vllm") + + +# For Siglip: copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +class SiglipVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
+ ) + + return cls.from_dict(config_dict, **kwargs) + + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor]=None) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, 
dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.self_attn = ( + SiglipAttention(config) + if not self._use_flash_attention_2 + else SiglipFlashAttention2(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SiglipEncoderLayer`]. 
+ Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING +) +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + 
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask=None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C @@ -407,10 +1313,14 @@ def init_llm(self, config, cache_config, quant_config): return MiniCPMForCausalLM(config, cache_config=cache_config, quant_config=quant_config) - else: + elif self.version == 2.5: return LlamaForCausalLM(config, cache_config=cache_config, quant_config=quant_config) + else: + return Qwen2ForCausalLM(config, + cache_config=cache_config, + quant_config=quant_config) def init_vision_module(self): if self.version == 2.0: @@ -433,10 +1343,21 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.blocks = model.blocks[:-1] - else: + elif self.version == 2.5: + from transformers.models.idefics2.modeling_idefics2 import ( + Idefics2VisionTransformer) model = Idefics2VisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] + else: + if self.config._attn_implementation == 'flash_attention_2': + self.config.vision_config._attn_implementation = 'flash_attention_2' + else: + # not suport sdpa + self.config.vision_config._attn_implementation = 'eager' + model = SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] return model def init_resampler(self, embed_dim, vision_dim): @@ -484,21 +1405,31 @@ def get_vision_embedding(self, num_prefix_tokens:] res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - else: + elif version == 2.5: vision_embedding = self.vpm( pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) + else: + vision_embedding = self.vpm( + pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state def get_image_bounds(self, input_ids): tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - im_start_token_id = tokenizer.im_start_id - im_end_token_id = tokenizer.im_end_id - image_start_tokens = torch.where(input_ids == im_start_token_id)[0] + if not hasattr(tokenizer, "slice_start_id"): + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + else: + start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) + + image_start_tokens = torch.where(start_cond)[0] image_start_tokens += 1 - image_end_tokens = torch.where(input_ids == 
im_end_token_id)[0] - valid_image_nums = min(len(image_start_tokens), len(image_end_tokens)) + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) if valid_image_nums == 0: return [] image_bound = torch.hstack([ @@ -534,17 +1465,25 @@ def get_vision_hidden_states(self, data): B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute( 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True - - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state + dtype=torch.bool, + device=device) + if self.version == 2.5: + for i in range(B): + patch_attn_mask[i, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask).last_hidden_state + else: + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state + vision_hidden_states = self.resampler( vision_embedding, tgt_sizes) @@ -613,7 +1552,6 @@ def forward( "input_ids": input_ids, "tgt_sizes": kwargs.pop("tgt_sizes", None), } - inputs = self.process_multimodal_inputs(inputs) vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3deb3d8840cc4..33e32bd11710c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -253,10 +253,10 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + input_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + if input_embeds is not None: + hidden_states = input_embeds else: hidden_states = self.embed_tokens(input_ids) residual = None @@ -340,9 +340,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + input_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors, + input_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From 24543d1024c9a8b9937a00a14d273cd29fa3e033 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Tue, 30 Jul 2024 18:14:22 +0800 Subject: [PATCH 02/25] format --- vllm/model_executor/models/minicpmv.py | 272 +++++++++++++++---------- vllm/model_executor/models/qwen2.py | 2 +- 2 files changed, 167 insertions(+), 107 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 
7b15ee4006728..0424b8cb61c2b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -145,23 +145,28 @@ def __init__( self.hidden_act = hidden_act @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": cls._set_token_in_kwargs(kwargs) - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) # get the vision config dict if we are loading from SiglipConfig if config_dict.get("model_type") == "siglip": config_dict = config_dict["vision_config"] - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) return cls.from_dict(config_dict, **kwargs) - + _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" @@ -180,7 +185,8 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( indices, cu_seqlens, @@ -237,9 +243,11 @@ def norm_cdf(x): tensor.clamp_(min=a, max=b) -def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 -) -> torch.Tensor: +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. 
The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` @@ -321,6 +329,7 @@ class SiglipVisionModelOutput(ModelOutput): class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config @@ -339,9 +348,13 @@ def __init__(self, config: SiglipVisionConfig): self.num_patches_per_side = self.image_size // self.patch_size self.num_patches = self.num_patches_per_side**2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) - def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor]=None) -> torch.Tensor: + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: batch_size = pixel_values.size(0) patch_embeds = self.patch_embedding(pixel_values) @@ -349,7 +362,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) position_ids = torch.full( size=( batch_size, @@ -369,10 +383,15 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) @@ -394,8 +413,7 @@ def __init__(self, config): if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." 
- ) + f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -409,7 +427,8 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" batch_size, q_len, _ = hidden_states.size() @@ -418,18 +437,22 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale - if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + if attn_weights.size() != (batch_size, self.num_heads, q_len, + k_v_seq_len): raise ValueError( f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}" - ) + f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): @@ -439,15 +462,20 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) attn_output = torch.matmul(attn_weights, value_states) - if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + if attn_output.size() != (batch_size, self.num_heads, q_len, + self.head_dim): raise ValueError( f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) @@ -477,7 +505,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -489,13 +518,17 @@ def forward( # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, 
q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -530,18 +563,21 @@ def forward( logger.warning_once( "The input hidden states seems to be silently casted in float32, this might be related to the fact" " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + f" {target_dtype}.") query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) - attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: @@ -549,9 +585,14 @@ def forward( return attn_output, attn_weights - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token first unpad the input, then computes the attention scores and pad the final attention scores. 
@@ -578,8 +619,8 @@ def _flash_attention_forward( if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) + query_states, key_states, value_states, attention_mask, + query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -597,28 +638,34 @@ def _flash_attention_forward( causal=causal, ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal - ) + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k indices_q = indices_k @@ -632,7 +679,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query else: # The -q_len: slice assumes left padding. 
attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) return ( query_layer, @@ -646,6 +694,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): + def __init__(self, config): super().__init__() self.config = config @@ -662,18 +711,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipVisionConfig): super().__init__() self.embed_dim = config.hidden_size self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.self_attn = ( - SiglipAttention(config) - if not self._use_flash_attention_2 - else SiglipFlashAttention2(config) - ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = (SiglipAttention(config) + if not self._use_flash_attention_2 else + SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) def forward( self, @@ -706,10 +756,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) return outputs @@ -729,7 +779,8 @@ def _init_weights(self, module): if isinstance(module, SiglipVisionEmbeddings): width = self.config.hidden_size - nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, SiglipAttention): @@ -768,7 +819,6 @@ def _init_weights(self, module): configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ - SIGLIP_VISION_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): @@ -797,7 +847,9 @@ class SiglipEncoder(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) self.gradient_checkpointing = False # Ignore copy @@ -830,9 +882,9 @@ def forward( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None @@ -841,7 +893,7 @@ def forward( hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, @@ -859,21 +911,23 @@ def forward( hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[1], ) if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + @add_start_docstrings( """The vision model from SigLIP without any head or projection on top.""", - SIGLIP_START_DOCSTRING -) + SIGLIP_START_DOCSTRING) class SiglipVisionTransformer(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" @@ -886,7 +940,8 @@ def __init__(self, config: SiglipVisionConfig): self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" # Initialize weights and apply final processing @@ -896,7 +951,8 @@ def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embedding @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) def forward( self, pixel_values, @@ -910,9 +966,9 @@ def forward( Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict batch_size = pixel_values.size(0) @@ -927,20 +983,22 @@ def forward( device=pixel_values.device, ) - hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes) + hidden_states = self.embeddings( + pixel_values=pixel_values, + 
patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): - attention_mask=None + attention_mask = None else: - attention_mask = ( - _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - if not self._use_flash_attention_2 - else patch_attention_mask - ) + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 else + patch_attention_mask) encoder_outputs = self.encoder( inputs_embeds=hidden_states, @@ -1411,20 +1469,21 @@ def get_vision_embedding(self, patch_attention_mask=patch_attn_mask).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) else: - vision_embedding = self.vpm( - pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state + vision_embedding = self.vpm(pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state def get_image_bounds(self, input_ids): tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) if not hasattr(tokenizer, "slice_start_id"): start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id + end_cond = input_ids == tokenizer.im_end_id else: - start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) + start_cond = (input_ids == tokenizer.im_start_id) | ( + input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | ( + input_ids == tokenizer.slice_end_id) image_start_tokens = torch.where(start_cond)[0] image_start_tokens += 1 @@ -1466,23 +1525,24 @@ def get_vision_hidden_states(self, data): all_pixel_values = all_pixel_values.permute( 0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros((B, 1, max_patches), - dtype=torch.bool, - device=device) + dtype=torch.bool, + device=device) if self.version == 2.5: for i in range(B): patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_embedding = self.vpm( all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state + patch_attention_mask=patch_attn_mask + ).last_hidden_state else: for i in range(B): patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_embedding = self.vpm( all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes).last_hidden_state + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state vision_hidden_states = self.resampler( vision_embedding, tgt_sizes) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 33e32bd11710c..cea170a3d4f3a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -343,7 +343,7 @@ def forward( input_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, + attn_metadata, intermediate_tensors, input_embeds) return hidden_states From 326c0e9e1f0d83d1832792dedfa8806cea85f23e Mon Sep 17 00:00:00 
2001 From: hezhihui Date: Tue, 30 Jul 2024 18:36:08 +0800 Subject: [PATCH 03/25] format --- vllm/model_executor/models/minicpmv.py | 309 ++++++------------------- vllm/model_executor/models/qwen2.py | 2 +- 2 files changed, 71 insertions(+), 240 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 0424b8cb61c2b..dceee5174f147 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -21,10 +21,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-V-2 model compatible with HuggingFace weights.""" +import logging import math -import re import os -import logging +import re import warnings from functools import partial from typing import Iterable, List, Optional, Tuple, Union @@ -34,19 +34,15 @@ import torch.nn.functional as F from PIL import Image from torch import nn -from torch.nn.init import trunc_normal_, _calculate_fan_in_and_fan_out -from transformers.utils import ( - ModelOutput, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - replace_return_docstrings, -) +from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_ from transformers.activations import ACT2FN -from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, is_flash_attn_2_available, + replace_return_docstrings) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -57,8 +53,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, @@ -73,47 +69,10 @@ logger = logging.getLogger("vllm") -# For Siglip: copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +# For Siglip: copied from +# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +# Remove hints as there's little possibility to change these code. class SiglipVisionConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a - Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip - [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- Args: - hidden_size (`int`, *optional*, defaults to 768): - Dimensionality of the encoder layers and the pooler layer. - intermediate_size (`int`, *optional*, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - num_hidden_layers (`int`, *optional*, defaults to 12): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 12): - Number of attention heads for each attention layer in the Transformer encoder. - num_channels (`int`, *optional*, defaults to 3): - Number of channels in the input images. - image_size (`int`, *optional*, defaults to 224): - The size (resolution) of each image. - patch_size (`int`, *optional*, defaults to 16): - The size (resolution) of each patch. - hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): - The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, - `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. - layer_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the layer normalization layers. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - Example: - ```python - >>> from transformers import SiglipVisionConfig, SiglipVisionModel - >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration - >>> configuration = SiglipVisionConfig() - >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration - >>> model = SiglipVisionModel(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" model_type = "siglip_vision_model" @@ -161,9 +120,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to " + "instantiate a model of type %s. " + "This is not supported for all configurations" + "of models and can yield errors.", config_dict['model_type'], + cls.model_type) return cls.from_dict(config_dict, **kwargs) @@ -177,7 +138,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa + unpad_input) # Copied from transformers.models.llama.modeling_llama._get_unpad_data @@ -195,8 +157,7 @@ def _get_unpad_data(attention_mask): def _trunc_normal_(tensor, mean, std, a, b): - # Cut & paste from PyTorch official master until it's in a few official releases - RW - # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 @@ -211,12 +172,12 @@ def norm_cdf(x): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. 
# Get upper and lower cdf values - l = norm_cdf((a - mean) / std) + l_ = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * l_ - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -248,22 +209,6 @@ def trunc_normal_tf_(tensor: torch.Tensor, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor: - """Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \\leq \text{mean} \\leq b`. - NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the - bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsquently scaled and shifted by the mean and std args. - Args: - tensor: an n-dimensional `torch.Tensor` - mean: the mean of the normal distribution - std: the standard deviation of the normal distribution - a: the minimum cutoff value - b: the maximum cutoff value - """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) @@ -302,26 +247,7 @@ def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") -# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip class SiglipVisionModelOutput(ModelOutput): - """ - Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. - Args: - image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): - The image embeddings obtained by applying the projection layer to the pooler_output. - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. 
- """ - image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -361,7 +287,8 @@ def forward(self, embeddings = patch_embeds.flatten(2).transpose(1, 2) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) - max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, + max_im_w // self.patch_size) boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full( @@ -412,7 +339,8 @@ def __init__(self, config): self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + "embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`:" f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -451,14 +379,16 @@ def forward( if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): raise ValueError( - f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): raise ValueError( - f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" - ) + "Attention mask should be of size " + f"{(batch_size, 1, q_len, k_v_seq_len)}", + f"but is {attention_mask.size()}") attn_weights = attn_weights + attention_mask # upcast attention to fp32 @@ -474,7 +404,9 @@ def forward( if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, " + "but is" f" {attn_output.size()}") attn_output = attn_output.transpose(1, 2).contiguous() @@ -486,11 +418,6 @@ def forward( class SiglipFlashAttention2(SiglipAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. 
- """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -515,9 +442,6 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_heads, @@ -529,27 +453,13 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length( kv_seq_len, self.layer_idx) - # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - # if past_key_value is not None: - # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.dropout if self.training else 0.0 - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): @@ -560,10 +470,13 @@ def forward( else: target_dtype = self.q_proj.weight.dtype - logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to the fact" - " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}.") + logger.warning( + "The input hidden states seems to be " + "silently casted in float32, " + "this might be related to the fact " + "you have upcasted embedding or layer norm layers in float32. " + "We will cast back the input in" + " %s.", target_dtype) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -593,34 +506,15 @@ def _flash_attention_forward(self, query_length, dropout=0.0, softmax_scale=None): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. 
- dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. causal = self.is_causal and query_length != 1 # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, - query_length) + (query_states, key_states, value_states, indices_q, cu_seq_lens, + max_seq_lens) = self._upad_input(query_states, key_states, + value_states, attention_mask, + query_length) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens @@ -679,8 +573,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) + (query_layer, indices_q, cu_seqlens_q, + max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -709,13 +603,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer +# with CLIP->Siglip class SiglipEncoderLayer(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.embed_dim = config.hidden_size - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") self.self_attn = (SiglipAttention(config) if not self._use_flash_attention_2 else SiglipFlashAttention2(config)) @@ -731,16 +627,6 @@ def forward( attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(batch, seq_len, embed_dim)`. - attention_mask (`torch.FloatTensor`): - Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -765,11 +651,6 @@ def forward( class SiglipPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - config_class = SiglipVisionConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True @@ -806,43 +687,9 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -SIGLIP_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) 
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - Parameters: - config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -SIGLIP_VISION_INPUTS_DOCSTRING = r""" - Args: - pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder +# with CLIP->Siglip class SiglipEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`SiglipEncoderLayer`]. - Args: - config: SiglipConfig - """ def __init__(self, config: SiglipVisionConfig): super().__init__() @@ -861,31 +708,13 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -925,9 +754,6 @@ def forward( attentions=all_attentions) -@add_start_docstrings( - """The vision model from SigLIP without any head or projection on top.""", - SIGLIP_START_DOCSTRING) class SiglipVisionTransformer(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" @@ -942,7 +768,8 @@ def __init__(self, config: SiglipVisionConfig): self.encoder = SiglipEncoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") # Initialize weights and apply final processing self.post_init() @@ -950,7 +777,6 @@ def __init__(self, config: SiglipVisionConfig): def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embedding - @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) def forward( @@ -965,11 +791,13 @@ def forward( r""" Returns: """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict batch_size = pixel_values.size(0) if patch_attention_mask is None: @@ -990,8 +818,10 @@ def forward( patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), - # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + # So when the `patch_attention_mask` is full of 1s + # (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, + # which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): attention_mask = None else: @@ -1409,9 +1239,10 @@ def init_vision_module(self): model.encoder.layers = model.encoder.layers[:-1] else: if self.config._attn_implementation == 'flash_attention_2': - self.config.vision_config._attn_implementation = 'flash_attention_2' + self.config.vision_config._attn_implementation \ + = 'flash_attention_2' else: - # not suport sdpa + # not support sdpa self.config.vision_config._attn_implementation = 'eager' model = SiglipVisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index cea170a3d4f3a..dc5787d1ff118 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, From df838fff68c6dbad2fdfd2b35f78e580290a18f5 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Tue, 30 Jul 2024 19:27:36 +0800 Subject: [PATCH 04/25] split Siglip && Adjust doc --- docs/source/models/supported_models.rst | 6 +- vllm/model_executor/models/minicpmv.py | 801 +---------------------- vllm/model_executor/models/na_vit.py | 805 ++++++++++++++++++++++++ 3 files changed, 813 insertions(+), 799 deletions(-) create mode 100644 vllm/model_executor/models/na_vit.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 7821214bf70be..a918d90dbf943 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -222,9 +222,13 @@ Vision Language Models - * - :code:`MiniCPM-V` - MiniCPM-V - - :code:`HwwwH/MiniCPM-V-2(Temporary)`, :code:`openbmb/MiniCPM-V-2(incoming...)`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. + - :code:`openbmb/MiniCPM-V-2(Incoming...)`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - +.. note:: + For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork(:code:`HwwwH/MiniCPM-V-2`) for now. + For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index dceee5174f147..108213f44a60e 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -20,29 +20,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights.""" -import logging import math -import os import re -import warnings from functools import partial -from typing import Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from PIL import Image from torch import nn -from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_ -from transformers.activations import ACT2FN +from torch.nn.init import trunc_normal_ from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask -from transformers.modeling_outputs import (BaseModelOutput, - BaseModelOutputWithPooling) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import (ModelOutput, is_flash_attn_2_available, - replace_return_docstrings) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -66,791 +55,6 @@ "language_model.model": "language_model", } -logger = logging.getLogger("vllm") - - -# For Siglip: copied from -# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes -# Remove hints as there's little possibility to change these code. -class SiglipVisionConfig(PretrainedConfig): - - model_type = "siglip_vision_model" - - def __init__( - self, - hidden_size=768, - intermediate_size=3072, - num_hidden_layers=12, - num_attention_heads=12, - num_channels=3, - image_size=224, - patch_size=16, - hidden_act="gelu_pytorch_tanh", - layer_norm_eps=1e-6, - attention_dropout=0.0, - **kwargs, - ): - super().__init__(**kwargs) - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_channels = num_channels - self.patch_size = patch_size - self.image_size = image_size - self.attention_dropout = attention_dropout - self.layer_norm_eps = layer_norm_eps - self.hidden_act = hidden_act - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, - os.PathLike], - **kwargs) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) - - # get the vision config dict if we are loading from SiglipConfig - if config_dict.get("model_type") == "siglip": - config_dict = config_dict["vision_config"] - - if "model_type" in config_dict and hasattr( - cls, - "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - "You are using a model of type %s to " - "instantiate a model of type %s. 
" - "This is not supported for all configurations" - "of models and can yield errors.", config_dict['model_type'], - cls.model_type) - - return cls.from_dict(config_dict, **kwargs) - - -_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" - -SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "google/siglip-base-patch16-224", - # See all SigLIP models at https://huggingface.co/models?filter=siglip -] - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa - unpad_input) - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _trunc_normal_(tensor, mean, std, a, b): - - def norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - if (mean < a - 2 * std) or (mean > b + 2 * std): - warnings.warn( - "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " - "The distribution of values may be incorrect.", - stacklevel=2, - ) - - # Values are generated by using a truncated uniform distribution and - # then using the inverse CDF for the normal distribution. - # Get upper and lower cdf values - l_ = norm_cdf((a - mean) / std) - u = norm_cdf((b - mean) / std) - - # Uniformly fill tensor with values from [l, u], then translate to - # [2l-1, 2u-1]. - tensor.uniform_(2 * l_ - 1, 2 * u - 1) - - # Use inverse cdf transform for normal distribution to get truncated - # standard normal - if tensor.dtype in [torch.float16, torch.bfloat16]: - # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu - og_dtype = tensor.dtype - tensor = tensor.to(torch.float32) - tensor.erfinv_() - tensor = tensor.to(og_dtype) - else: - tensor.erfinv_() - - # Transform to proper mean, std - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - # Clamp to ensure it's in the proper range - if tensor.dtype == torch.float16: - # The `clamp_` op is not (yet?) 
defined in float16+cpu - tensor = tensor.to(torch.float32) - tensor.clamp_(min=a, max=b) - tensor = tensor.to(torch.float16) - else: - tensor.clamp_(min=a, max=b) - - -def trunc_normal_tf_(tensor: torch.Tensor, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0) -> torch.Tensor: - with torch.no_grad(): - _trunc_normal_(tensor, 0, 1.0, a, b) - tensor.mul_(std).add_(mean) - - -def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): - fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) - if mode == "fan_in": - denom = fan_in - elif mode == "fan_out": - denom = fan_out - elif mode == "fan_avg": - denom = (fan_in + fan_out) / 2 - - variance = scale / denom - - if distribution == "truncated_normal": - # constant is stddev of standard normal truncated to (-2, 2) - trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) - elif distribution == "normal": - with torch.no_grad(): - tensor.normal_(std=math.sqrt(variance)) - elif distribution == "uniform": - bound = math.sqrt(3 * variance) - with torch.no_grad(): - tensor.uniform_(-bound, bound) - else: - raise ValueError(f"invalid distribution {distribution}") - - -def lecun_normal_(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") - - -def default_flax_embed_init(tensor): - variance_scaling_(tensor, mode="fan_in", distribution="normal") - - -class SiglipVisionModelOutput(ModelOutput): - image_embeds: Optional[torch.FloatTensor] = None - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - - -class SiglipVisionEmbeddings(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.image_size = config.image_size - self.patch_size = config.patch_size - - self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size, - padding="valid", - ) - - self.num_patches_per_side = self.image_size // self.patch_size - self.num_patches = self.num_patches_per_side**2 - self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) - - def forward(self, - pixel_values: torch.FloatTensor, - patch_attention_mask: torch.BoolTensor, - tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: - batch_size = pixel_values.size(0) - - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) - - max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) - max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, - max_im_w // self.patch_size) - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, - 1 / self.num_patches_per_side) - position_ids = torch.full( - size=( - batch_size, - max_nb_patches_h * max_nb_patches_w, - ), - fill_value=0, - ) - - for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: - nb_patches_h = tgt_sizes[batch_idx][0] - nb_patches_w = tgt_sizes[batch_idx][1] - else: - nb_patches_h = p_attn_mask[:, 0].sum() - nb_patches_w = p_attn_mask[0].sum() - - fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) - fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - - bucket_coords_h = torch.bucketize(fractional_coords_h, - boundaries, - right=True) - bucket_coords_w = 
torch.bucketize(fractional_coords_w, - boundaries, - right=True) - - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + - bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids - - position_ids = position_ids.to(self.position_embedding.weight.device) - - embeddings = embeddings + self.position_embedding(position_ids) - return embeddings - - -class SiglipAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - "embed_dim must be divisible by num_heads (got `embed_dim`: " - f"{self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - batch_size, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = key_states.shape[-2] - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) * self.scale - - if attn_weights.size() != (batch_size, self.num_heads, q_len, - k_v_seq_len): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): - raise ValueError( - "Attention mask should be of size " - f"{(batch_size, 1, q_len, k_v_seq_len)}", - f"but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to( - query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (batch_size, self.num_heads, q_len, - self.head_dim): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, " - "but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights - - -class SiglipFlashAttention2(SiglipAttention): - - def __init__(self, *args, 
**kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False # Hack to make sure we don't use a causal mask - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length( - kv_seq_len, self.layer_idx) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.dropout if self.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning( - "The input hidden states seems to be " - "silently casted in float32, " - "this might be related to the fact " - "you have upcasted embedding or layer norm layers in float32. 
" - "We will cast back the input in" - " %s.", target_dtype) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward(query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate) - - attn_output = attn_output.reshape(bsz, q_len, - self.embed_dim).contiguous() - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights - - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - (query_states, key_states, value_states, indices_q, cu_seq_lens, - max_seq_lens) = self._upad_input(query_states, key_states, - value_states, attention_mask, - query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, - query_length) - else: - attn_output = flash_attn_func(query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, - query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( - attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, - head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - (query_layer, indices_q, cu_seqlens_q, - max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip -class SiglipMLP(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer -# with CLIP->Siglip -class SiglipEncoderLayer(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self._use_flash_attention_2 = ( - config._attn_implementation == "flash_attention_2") - self.self_attn = (SiglipAttention(config) - if not self._use_flash_attention_2 else - SiglipFlashAttention2(config)) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor]: - residual = hidden_states - - hidden_states = self.layer_norm1(hidden_states) - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.layer_norm2(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (attn_weights, ) - - return outputs - - -class SiglipPreTrainedModel(PreTrainedModel): - config_class = SiglipVisionConfig - base_model_prefix = "siglip" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - """Initialize the weights""" - - if isinstance(module, SiglipVisionEmbeddings): - width = self.config.hidden_size - nn.init.normal_(module.position_embedding.weight, - std=1 / np.sqrt(width)) - elif isinstance(module, nn.Embedding): - default_flax_embed_init(module.weight) - elif isinstance(module, SiglipAttention): - nn.init.normal_(module.q_proj.weight) - nn.init.normal_(module.k_proj.weight) - nn.init.normal_(module.v_proj.weight) - nn.init.normal_(module.out_proj.weight) - nn.init.zeros_(module.q_proj.bias) - nn.init.zeros_(module.k_proj.bias) - nn.init.zeros_(module.v_proj.bias) - nn.init.zeros_(module.out_proj.bias) - elif isinstance(module, SiglipMLP): - nn.init.normal_(module.fc1.weight) - nn.init.normal_(module.fc2.weight) - nn.init.normal_(module.fc1.bias, std=1e-6) - nn.init.normal_(module.fc2.bias, std=1e-6) - elif isinstance(module, (nn.Linear, nn.Conv2d)): - lecun_normal_(module.weight) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - -# 
Copied from transformers.models.clip.modeling_clip.CLIPEncoder -# with CLIP->Siglip -class SiglipEncoder(nn.Module): - - def __init__(self, config: SiglipVisionConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([ - SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) - ]) - self.gradient_checkpointing = False - - # Ignore copy - def forward( - self, - inputs_embeds, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutput]: - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None \ - else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - - hidden_states = inputs_embeds - for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1], ) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states, ) - - if not return_dict: - return tuple( - v for v in [hidden_states, encoder_states, all_attentions] - if v is not None) - return BaseModelOutput(last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions) - - -class SiglipVisionTransformer(SiglipPreTrainedModel): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - _supports_flash_attn_2 = True - - def __init__(self, config: SiglipVisionConfig): - super().__init__(config) - self.config = config - embed_dim = config.hidden_size - - self.embeddings = SiglipVisionEmbeddings(config) - self.encoder = SiglipEncoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) - self._use_flash_attention_2 = ( - config._attn_implementation == "flash_attention_2") - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> nn.Module: - return self.embeddings.patch_embedding - - @replace_return_docstrings(output_type=BaseModelOutputWithPooling, - config_class=SiglipVisionConfig) - def forward( - self, - pixel_values, - patch_attention_mask: Optional[torch.BoolTensor] = None, - tgt_sizes: Optional[torch.IntTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - """ - output_attentions = output_attentions if output_attentions is not None \ - else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None \ - else self.config.use_return_dict - - batch_size = pixel_values.size(0) - if patch_attention_mask is None: - 
patch_attention_mask = torch.ones( - size=( - batch_size, - pixel_values.size(2) // self.config.patch_size, - pixel_values.size(3) // self.config.patch_size, - ), - dtype=torch.bool, - device=pixel_values.device, - ) - - hidden_states = self.embeddings( - pixel_values=pixel_values, - patch_attention_mask=patch_attention_mask, - tgt_sizes=tgt_sizes) - - patch_attention_mask = patch_attention_mask.view(batch_size, -1) - # The call to `_upad_input` in `_flash_attention_forward` is expensive - # So when the `patch_attention_mask` is full of 1s - # (i.e. attending to the whole sequence), - # avoiding passing the attention_mask, - # which is equivalent to attending to the full sequence - if not torch.any(~patch_attention_mask): - attention_mask = None - else: - attention_mask = (_prepare_4d_attention_mask( - patch_attention_mask, hidden_states.dtype) - if not self._use_flash_attention_2 else - patch_attention_mask) - - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = self.post_layernorm(last_hidden_state) - - if not return_dict: - return (last_hidden_state, None) + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=None, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C @@ -1238,6 +442,7 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] else: + from vllm.model_executor.models.na_vit import SiglipVisionTransformer if self.config._attn_implementation == 'flash_attention_2': self.config.vision_config._attn_implementation \ = 'flash_attention_2' diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py new file mode 100644 index 0000000000000..87d45d26f27f8 --- /dev/null +++ b/vllm/model_executor/models/na_vit.py @@ -0,0 +1,805 @@ +import os +import math +import logging +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, is_flash_attn_2_available, + replace_return_docstrings) +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + + +logger = logging.getLogger("vllm") + + +# For Siglip: copied from +# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +# Remove hints as there's little possibility to change these code. 
+class SiglipVisionConfig(PretrainedConfig):
+
+    model_type = "siglip_vision_model"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        num_channels=3,
+        image_size=224,
+        patch_size=16,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.image_size = image_size
+        self.attention_dropout = attention_dropout
+        self.layer_norm_eps = layer_norm_eps
+        self.hidden_act = hidden_act
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
+                                                                  os.PathLike],
+                        **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(
+            pretrained_model_name_or_path, **kwargs)
+
+        # get the vision config dict if we are loading from SiglipConfig
+        if config_dict.get("model_type") == "siglip":
+            config_dict = config_dict["vision_config"]
+
+        if "model_type" in config_dict and hasattr(
+                cls,
+                "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                "You are using a model of type %s to "
+                "instantiate a model of type %s. "
+                "This is not supported for all configurations "
+                "of models and can yield errors.", config_dict['model_type'],
+                cls.model_type)
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
+_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
+
+SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "google/siglip-base-patch16-224",
+    # See all SigLIP models at https://huggingface.co/models?filter=siglip
+]
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import (index_first_axis, pad_input,  # noqa
+                                         unpad_input)
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def _trunc_normal_(tensor, mean, std, a, b):
+
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )
+
+    # Values are generated by using a truncated uniform distribution and
+    # then using the inverse CDF for the normal distribution.
+    # Get upper and lower cdf values
+    l_ = norm_cdf((a - mean) / std)
+    u = norm_cdf((b - mean) / std)
+
+    # Uniformly fill tensor with values from [l, u], then translate to
+    # [2l-1, 2u-1].
+    tensor.uniform_(2 * l_ - 1, 2 * u - 1)
+
+    # Use inverse cdf transform for normal distribution to get truncated
+    # standard normal
+    if tensor.dtype in [torch.float16, torch.bfloat16]:
+        # The `erfinv_` op is not (yet?) 
defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class SiglipVisionModelOutput(ModelOutput): + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, + max_im_w // self.patch_size) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + 
nb_patches_w = tgt_sizes[batch_idx][1]
+            else:
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h,
+                                              boundaries,
+                                              right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w,
+                                              boundaries,
+                                              right=True)
+
+            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
+                       bucket_coords_w).flatten()
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
+        return embeddings
+
+
+class SiglipAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                "embed_dim must be divisible by num_heads (got `embed_dim`: "
+                f"{self.embed_dim} and `num_heads`:"
+                f" {self.num_heads}).")
+        self.scale = self.head_dim**-0.5
+        self.dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(batch_size, q_len, self.num_heads,
+                                         self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, q_len, self.num_heads,
+                                     self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, q_len, self.num_heads,
+                                         self.head_dim).transpose(1, 2)
+
+        k_v_seq_len = key_states.shape[-2]
+        attn_weights = torch.matmul(query_states, key_states.transpose(
+            2, 3)) * self.scale
+
+        if attn_weights.size() != (batch_size, self.num_heads, q_len,
+                                   k_v_seq_len):
+            raise ValueError(
+                "Attention weights should be of size "
+                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+                f" {attn_weights.size()}")
+
+        if attention_mask is not None:
+            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+                raise ValueError(
+                    "Attention mask should be of size "
+                    f"{(batch_size, 1, q_len, k_v_seq_len)}, "
+                    f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights,
+                                             dim=-1,
+                                             dtype=torch.float32).to(
+                                                 query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights,
+                                             p=self.dropout,
+                                             training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (batch_size, self.num_heads, q_len,
+                                  self.head_dim):
+            raise ValueError(
+                "`attn_output` should be of size "
+                f"{(batch_size, 
self.num_heads, q_len, self.head_dim)}, " + "but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning( + "The input hidden states seems to be " + "silently casted in float32, " + "this might be related to the fact " + "you have upcasted embedding or layer norm layers in float32. 
" + "We will cast back the input in" + " %s.", target_dtype) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + (query_states, key_states, value_states, indices_q, cu_seq_lens, + max_seq_lens) = self._upad_input(query_states, key_states, + value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + (query_layer, indices_q, cu_seqlens_q, + max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer +# with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + self.self_attn = (SiglipAttention(config) + if not self._use_flash_attention_2 else + SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +# 
Copied from transformers.models.clip.modeling_clip.CLIPEncoder +# with CLIP->Siglip +class SiglipEncoder(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + 
patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s + # (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, + # which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 else + patch_attention_mask) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) From 84ac3e45837b1976c5a9ba8a68665972d42e2ca4 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 15:20:39 +0800 Subject: [PATCH 05/25] add version inference --- vllm/model_executor/models/minicpmv.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 108213f44a60e..e25699d769366 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -20,6 +20,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re from functools import partial @@ -388,7 +389,13 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - self.version = float(self.config.version) + if not hasattr(self.config, "version"): + if self.config.hidden_size == 2304: + self.version = 2.0 + else: + self.version = 2.5 + else: + self.version = float(self.config.version) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() From 08d609b43c125dcc592dd5056161bd27082a33c9 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 15:22:32 +0800 Subject: [PATCH 06/25] format --- vllm/model_executor/models/minicpmv.py | 3 ++- vllm/model_executor/models/na_vit.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index e25699d769366..4536d6bf22bbf 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -449,7 +449,8 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] else: - from vllm.model_executor.models.na_vit import SiglipVisionTransformer + from vllm.model_executor.models.na_vit import ( + SiglipVisionTransformer) if self.config._attn_implementation == 'flash_attention_2': self.config.vision_config._attn_implementation \ = 'flash_attention_2' diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py index 87d45d26f27f8..871e4128b66e1 100644 --- a/vllm/model_executor/models/na_vit.py +++ b/vllm/model_executor/models/na_vit.py @@ -1,23 +1,22 @@ -import os -import math import logging +import math +import os import warnings from typing import Optional, Tuple, Union import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPooling) from transformers.modeling_utils import PreTrainedModel from transformers.utils import (ModelOutput, is_flash_attn_2_available, replace_return_docstrings) -from transformers.activations import ACT2FN -from transformers.configuration_utils import PretrainedConfig - logger = logging.getLogger("vllm") @@ -91,8 +90,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa - unpad_input) + from flash_attn.bert_padding import pad_input # noqa + from flash_attn.bert_padding import index_first_axis, unpad_input # Copied from transformers.models.llama.modeling_llama._get_unpad_data From a1b85c044c5652c1084a1b971ef2d256d629103b Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 15:54:05 +0800 Subject: [PATCH 07/25] add type annotations --- vllm/model_executor/models/minicpmv.py | 88 ++++++++++++++++---------- 1 file changed, 53 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 4536d6bf22bbf..39067ec74470b 100644 --- 
a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,7 +24,7 @@ import math import re from functools import partial -from typing import Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -57,7 +57,7 @@ } -def get_abs_pos(abs_pos, tgt_size): +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -74,10 +74,10 @@ def get_abs_pos(abs_pos, tgt_size): # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim, - grid_size, - cls_token=False, - version=2.0): +def get_2d_sincos_pos_embed(embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: float = 2.0): """ grid_size: int of the grid height and width return: @@ -105,7 +105,9 @@ def get_2d_sincos_pos_embed(embed_dim, return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, + grid: Union[int, Tuple[int, int]], + version: float = 2.0): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -121,7 +123,9 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0): +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, + pos: int, + version: float = 2.0): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -157,15 +161,15 @@ class Resampler(nn.Module): default_norm_layer = partial(nn.LayerNorm, eps=1e-6) def __init__(self, - num_queries, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=default_norm_layer, - adaptive=False, - max_size=(70, 70), - version=2.0): + num_queries: int, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: nn.Module = default_norm_layer, + adaptive: bool = False, + max_size: Tuple[int, int] = (70, 70), + version: float = 2.0): super().__init__() self.version = version @@ -205,14 +209,16 @@ def __init__(self, self.apply(self._init_weights) - def _set_2d_pos_cache(self, max_size, device='cpu'): + def _set_2d_pos_cache(self, + max_size: Tuple[int, int], + device: torch.device = 'cpu'): pos_embed = torch.from_numpy( get_2d_sincos_pos_embed(self.embed_dim, max_size, version=self.version)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes, device): + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: @@ -222,7 +228,7 @@ def _adjust_pos_cache(self, tgt_sizes, device): ] self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -231,7 +237,7 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward_2_5(self, x, tgt_sizes=None): + def forward_2_5(self, x: torch.Tensor, tgt_sizes: torch.Tensor = None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -277,7 +283,10 @@ def forward_2_5(self, x, tgt_sizes=None): x = x @ self.proj return x - def forward_2(self, x, 
tgt_sizes=None, attn_mask=None): + def forward_2(self, + x: torch.Tensor, + tgt_sizes: torch.Tensor = None, + attn_mask: torch.Tensor = None): if self.adaptive: pos_embed = torch.Tensor( get_2d_sincos_pos_embed(self.embed_dim, @@ -301,7 +310,10 @@ def forward_2(self, x, tgt_sizes=None, attn_mask=None): x = x @ self.proj return x - def forward(self, x, tgt_sizes=None, attn_mask=None): + def forward(self, + x: torch.Tensor, + tgt_sizes: torch.Tensor = None, + attn_mask: torch.Tensor = None): if self.version == 2.0: return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) else: @@ -321,7 +333,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int): return SequenceData(token_ids) -def dummy_image_for_minicpmv(hf_config): +def dummy_image_for_minicpmv(hf_config: PretrainedConfig): width = height = hf_config.image_size image = Image.new("RGB", (width, height), color=0) return {"image": image} @@ -380,7 +392,7 @@ class MiniCPMV(nn.Module, SupportsVision): def __init__( self, - config, + config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -407,7 +419,10 @@ def __init__( self.resampler.to(device="cuda", dtype=param_dtype) self.sampler = Sampler() - def init_llm(self, config, cache_config, quant_config): + def init_llm(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): if self.version == 2.0: return MiniCPMForCausalLM(config, cache_config=cache_config, @@ -462,7 +477,7 @@ def init_vision_module(self): model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim, vision_dim): + def init_resampler(self, embed_dim: int, vision_dim: int): default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float16) if self.version == 2.0: @@ -486,10 +501,10 @@ def init_resampler(self, embed_dim, vision_dim): return resampler def get_vision_embedding(self, - pixel_values, - patch_attn_mask=None, - tgt_sizes=None, - version=2.0): + pixel_values: List[List[torch.Tensor]], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + version: float = 2.0): if version == 2.0: res = [] dtype = self.vpm.pos_embed.data.dtype @@ -517,7 +532,7 @@ def get_vision_embedding(self, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state - def get_image_bounds(self, input_ids): + def get_image_bounds(self, input_ids: torch.Tensor): tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) if not hasattr(tokenizer, "slice_start_id"): @@ -542,7 +557,8 @@ def get_image_bounds(self, input_ids): return image_bound - def get_vision_hidden_states(self, data): + def get_vision_hidden_states(self, data: Dict[str, Union[List, + torch.Tensor]]): if "vision_hidden_states" not in data: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -599,7 +615,7 @@ def get_vision_hidden_states(self, data): return vision_hidden_states - def get_embedding(self, data): + def get_embedding(self, data: Dict[str, Union[List, torch.Tensor]]): input_ids = data["input_ids"] vision_hidden_states = self.get_vision_hidden_states(data) @@ -630,7 +646,9 @@ def get_embedding(self, data): vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) return vlm_embedding, vision_hidden_states - def process_multimodal_inputs(self, inputs): + def process_multimodal_inputs(self, inputs: Dict[str, + Union[List, + torch.Tensor]]): 
pixel_values = [] tgt_sizes = [] for b in range(len(inputs["pixel_values"])): From 928a6fa83eae856f835c45d2d5717507d0d375eb Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 16:07:02 +0800 Subject: [PATCH 08/25] add minicpmv version inference --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 39067ec74470b..5566fbb2cb922 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -402,7 +402,7 @@ def __init__( self.multimodal_config = multimodal_config if not hasattr(self.config, "version"): - if self.config.hidden_size == 2304: + if self.config.hidden_size == 2304 and self.config.query_num == 64: self.version = 2.0 else: self.version = 2.5 From b7dc51ef90a9322da9d0cbab462bdd9942d2b868 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 16:40:03 +0800 Subject: [PATCH 09/25] change input_embeds to inputs_embeds --- vllm/model_executor/models/llama.py | 4 ++-- vllm/model_executor/models/minicpm.py | 4 ++-- vllm/model_executor/models/minicpmv.py | 2 +- vllm/model_executor/models/qwen2.py | 10 +++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..f567b018ea30d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -418,11 +418,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - input_embeds) + inputs_embeds) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7a8ac0bb1f949..695bfc071d11e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -463,11 +463,11 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - input_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, input_embeds) + attn_metadata, inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5566fbb2cb922..5407ac990be0b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -683,7 +683,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, - input_embeds=vlm_embeddings) + inputs_embeds=vlm_embeddings) return output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index dc5787d1ff118..1e53b5b0a1a4a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -253,10 +253,10 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: 
Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if input_embeds is not None: - hidden_states = input_embeds + if inputs_embeds is not None: + hidden_states = inputs_embeds else: hidden_states = self.embed_tokens(input_ids) residual = None @@ -340,11 +340,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - input_embeds) + inputs_embeds) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From f63c1e7d4534246f2f281cc58f0023ca91cc94c2 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 19:40:59 +0800 Subject: [PATCH 10/25] fix: use *Model instead of *CausalModel for llm --- vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/minicpm.py | 4 +- vllm/model_executor/models/minicpmv.py | 57 +++++++++++++++----------- vllm/model_executor/models/qwen2.py | 4 +- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f567b018ea30d..2052c443a8885 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -418,11 +418,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + attn_metadata, intermediate_tensors) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 695bfc071d11e..b46e88f5fc584 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -370,6 +370,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: @@ -463,11 +464,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - inputs_embeds: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5407ac990be0b..284087cc60e98 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -40,11 +40,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision -from 
vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.minicpm import MiniCPMForCausalLM -from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, @@ -52,8 +54,8 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData _KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", + "llm.lm_head": "lm_head", + "llm.model": "llm", } @@ -414,9 +416,13 @@ def __init__( self.vpm.to(dtype=param_dtype) self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \ else self.vpm.embeddings.embed_dim - self.embed_dim = self.llm.config.hidden_size + self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() def init_llm(self, @@ -424,17 +430,17 @@ def init_llm(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): if self.version == 2.0: - return MiniCPMForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + return MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) elif self.version == 2.5: - return LlamaForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) else: - return Qwen2ForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) def init_vision_module(self): if self.version == 2.0: @@ -624,11 +630,11 @@ def get_embedding(self, data: Dict[str, Union[List, torch.Tensor]]): else: image_bounds = [] - if hasattr(self.llm.config, 'scale_emb'): - vlm_embedding = self.llm.model.embed_tokens( - input_ids) * self.llm.config.scale_emb + if hasattr(self.config, 'scale_emb'): + vlm_embedding = self.llm.embed_tokens( + input_ids) * self.config.scale_emb else: - vlm_embedding = self.llm.model.embed_tokens(input_ids) + vlm_embedding = self.llm.embed_tokens(input_ids) vision_hidden_states = [ i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states @@ -688,14 +694,17 @@ def forward( def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.llm.sample(logits, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -709,9 +718,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] 
params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - # if key_to_modify in name: - # name = name.replace(key_to_modify, new_key) + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 1e53b5b0a1a4a..35fd6f37589a0 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -340,11 +340,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, From 4e5872c6119bf35c2211f2916ff7c910c03ee517 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 19:44:20 +0800 Subject: [PATCH 11/25] format --- vllm/model_executor/models/minicpmv.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 284087cc60e98..8fd0a8a21c62f 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,16 +37,16 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, @@ -698,7 +698,6 @@ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) return logits - def sample( self, logits: torch.Tensor, From 7153277fce32389016bcea7a3fcab9e59087bc50 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:37:28 +0800 Subject: [PATCH 12/25] Update docs/source/models/supported_models.rst Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a918d90dbf943..2a86bfbf7ff66 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -222,7 +222,7 @@ Vision Language Models - * - :code:`MiniCPM-V` - MiniCPM-V - - 
:code:`openbmb/MiniCPM-V-2(Incoming...)`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. + - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - .. note:: From 78fe2787817697d26e13f376aa6ea39ce86abaf7 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:37:58 +0800 Subject: [PATCH 13/25] Update docs/source/models/supported_models.rst Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 2a86bfbf7ff66..d21be65bcb95c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -226,7 +226,7 @@ Vision Language Models - .. note:: - For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork(:code:`HwwwH/MiniCPM-V-2`) for now. + For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 ---- From 50d78fd9a084fd7abcf67fd376bcb3dbead9f8ce Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:38:09 +0800 Subject: [PATCH 14/25] Update docs/source/models/supported_models.rst Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d21be65bcb95c..a1ea366b82b04 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -227,7 +227,7 @@ Vision Language Models .. note:: For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. 
- For more details, please see: See: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 + For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 ---- From ff8db35d05932ad3b426bae9307fbfa9836897d2 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:38:36 +0800 Subject: [PATCH 15/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8fd0a8a21c62f..35f18077a5292 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -213,7 +213,7 @@ def __init__(self, def _set_2d_pos_cache(self, max_size: Tuple[int, int], - device: torch.device = 'cpu'): + device: torch.types.Device = 'cpu'): pos_embed = torch.from_numpy( get_2d_sincos_pos_embed(self.embed_dim, max_size, From 794c9b1efcef07b95a59cc9d0836a94996a1faf0 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:40:09 +0800 Subject: [PATCH 16/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 35f18077a5292..ae6c314abece6 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -314,8 +314,8 @@ def forward_2(self, def forward(self, x: torch.Tensor, - tgt_sizes: torch.Tensor = None, - attn_mask: torch.Tensor = None): + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): if self.version == 2.0: return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) else: From 113238f440da0854100f9de164cd9eeb96483d31 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:40:25 +0800 Subject: [PATCH 17/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ae6c314abece6..8031f9529c12d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -28,6 +28,7 @@ import numpy as np import torch +import torch.types import torch.nn.functional as F from PIL import Image from torch import nn From 92bc3167d3cb1199496d32fee2ece1b8da894482 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:40:48 +0800 Subject: [PATCH 18/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8031f9529c12d..cf370ad3e32b7 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -221,7 +221,7 @@ def _set_2d_pos_cache(self, version=self.version)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.device): + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, 
device: torch.types.Device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: From 0d15a8c2b33edcdd0f6be386c5159869522d0041 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:41:31 +0800 Subject: [PATCH 19/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index cf370ad3e32b7..8adb97a0b7c4b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -240,7 +240,7 @@ def _init_weights(self, m: nn.Module): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward_2_5(self, x: torch.Tensor, tgt_sizes: torch.Tensor = None): + def forward_2_5(self, x: torch.Tensor, tgt_sizes: Optional[torch.Tensor] = None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] From aa837b7f73906e60c2d5bb094b9d0447748e57db Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:41:40 +0800 Subject: [PATCH 20/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8adb97a0b7c4b..167b6e510d5d0 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -288,8 +288,8 @@ def forward_2_5(self, x: torch.Tensor, tgt_sizes: Optional[torch.Tensor] = None) def forward_2(self, x: torch.Tensor, - tgt_sizes: torch.Tensor = None, - attn_mask: torch.Tensor = None): + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): if self.adaptive: pos_embed = torch.Tensor( get_2d_sincos_pos_embed(self.embed_dim, From 65bef49f48914bd6c533af335cf7f30858906a2a Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:41:58 +0800 Subject: [PATCH 21/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 167b6e510d5d0..0786d51fc83d5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -622,7 +622,7 @@ def get_vision_hidden_states(self, data: Dict[str, Union[List, return vision_hidden_states - def get_embedding(self, data: Dict[str, Union[List, torch.Tensor]]): + def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], torch.Tensor]]): input_ids = data["input_ids"] vision_hidden_states = self.get_vision_hidden_states(data) From f71fb296a03f0e991be9a80ec3dedb840bc54f7f Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:42:06 +0800 Subject: [PATCH 22/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 0786d51fc83d5..55ef511eaec1f 100644 --- a/vllm/model_executor/models/minicpmv.py +++ 
b/vllm/model_executor/models/minicpmv.py @@ -564,7 +564,7 @@ def get_image_bounds(self, input_ids: torch.Tensor): return image_bound - def get_vision_hidden_states(self, data: Dict[str, Union[List, + def get_vision_hidden_states(self, data: Dict[str, Union[List[torch.Tensor], torch.Tensor]]): if "vision_hidden_states" not in data: pixel_values = data["pixel_values"] From bbc49e4c248ffcf594d0ff215094f35d172e2288 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:43:20 +0800 Subject: [PATCH 23/25] Update vllm/model_executor/models/minicpmv.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/minicpmv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 55ef511eaec1f..e7e1f4758ed21 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -654,7 +654,7 @@ def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], torch.Tensor]] return vlm_embedding, vision_hidden_states def process_multimodal_inputs(self, inputs: Dict[str, - Union[List, + Union[List[torch.Tensor], torch.Tensor]]): pixel_values = [] tgt_sizes = [] From 5115b5a7de07d93c3f26a20ad74d221ee35a9ce1 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 20:48:16 +0800 Subject: [PATCH 24/25] format --- vllm/model_executor/models/minicpmv.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index e7e1f4758ed21..ae64c17b5bcac 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -28,8 +28,8 @@ import numpy as np import torch -import torch.types import torch.nn.functional as F +import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ @@ -221,7 +221,8 @@ def _set_2d_pos_cache(self, version=self.version)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, device: torch.types.Device): + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, + device: torch.types.Device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: @@ -240,7 +241,9 @@ def _init_weights(self, m: nn.Module): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward_2_5(self, x: torch.Tensor, tgt_sizes: Optional[torch.Tensor] = None): + def forward_2_5(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -564,8 +567,9 @@ def get_image_bounds(self, input_ids: torch.Tensor): return image_bound - def get_vision_hidden_states(self, data: Dict[str, Union[List[torch.Tensor], - torch.Tensor]]): + def get_vision_hidden_states(self, data: Dict[str, + Union[List[torch.Tensor], + torch.Tensor]]): if "vision_hidden_states" not in data: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -622,7 +626,8 @@ def get_vision_hidden_states(self, data: Dict[str, Union[List[torch.Tensor], return vision_hidden_states - def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], torch.Tensor]]): + def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], + torch.Tensor]]): input_ids = data["input_ids"] vision_hidden_states = self.get_vision_hidden_states(data) From 
50148ad26924e91298e70abfda5649a232df0de7 Mon Sep 17 00:00:00 2001 From: hezhihui Date: Wed, 31 Jul 2024 21:08:59 +0800 Subject: [PATCH 25/25] change to (int, int) --- vllm/model_executor/models/minicpmv.py | 49 +++++++++++++------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index ae64c17b5bcac..2a7fe7ba0ebac 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -80,7 +80,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], cls_token: bool = False, - version: float = 2.0): + version: Tuple[int, int] = (2, 0)): """ grid_size: int of the grid height and width return: @@ -97,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) - if version == 2.0: + if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) if cls_token: @@ -110,7 +110,7 @@ def get_2d_sincos_pos_embed(embed_dim: int, def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: Union[int, Tuple[int, int]], - version: float = 2.0): + version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -119,7 +119,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, emb_w = get_1d_sincos_pos_embed_from_grid( embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) - if version == 2.0: + if version == (2, 0): emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) else: emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) @@ -128,7 +128,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: int, - version: float = 2.0): + version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -139,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, omega /= embed_dim / 2. omega = 1. 
/ 10000**omega # (D/2,) - if version == 2.0: + if version == (2, 0): pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) @@ -172,11 +172,11 @@ def __init__(self, norm_layer: nn.Module = default_norm_layer, adaptive: bool = False, max_size: Tuple[int, int] = (70, 70), - version: float = 2.0): + version: Tuple[int, int] = (2, 0)): super().__init__() self.version = version - if self.version == 2.0: + if self.version == (2, 0): self.num_queries = grid_size**2 else: self.num_queries = num_queries @@ -201,7 +201,7 @@ def __init__(self, self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == 2.0: + if self.version == (2, 0): self.pos_embed = nn.Parameter( torch.from_numpy( get_2d_sincos_pos_embed( @@ -320,7 +320,7 @@ def forward(self, x: torch.Tensor, tgt_sizes: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None): - if self.version == 2.0: + if self.version == (2, 0): return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) else: return self.forward_2_5(x, tgt_sizes=tgt_sizes) @@ -409,16 +409,17 @@ def __init__( if not hasattr(self.config, "version"): if self.config.hidden_size == 2304 and self.config.query_num == 64: - self.version = 2.0 + self.version = (2, 0) else: - self.version = 2.5 + self.version = (2, 5) else: - self.version = float(self.config.version) + self.version = str(self.config.version).split(".") + self.version = tuple([int(x) for x in self.version]) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \ + self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ else self.vpm.embeddings.embed_dim self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) @@ -433,11 +434,11 @@ def init_llm(self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): - if self.version == 2.0: + if self.version == (2, 0): return MiniCPMModel(config, cache_config=cache_config, quant_config=quant_config) - elif self.version == 2.5: + elif self.version == (2, 5): return LlamaModel(config, cache_config=cache_config, quant_config=quant_config) @@ -447,7 +448,7 @@ def init_llm(self, quant_config=quant_config) def init_vision_module(self): - if self.version == 2.0: + if self.version == (2, 0): try: import timm except ImportError: @@ -467,7 +468,7 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.blocks = model.blocks[:-1] - elif self.version == 2.5: + elif self.version == (2, 5): from transformers.models.idefics2.modeling_idefics2 import ( Idefics2VisionTransformer) model = Idefics2VisionTransformer(self.config.vision_config) @@ -490,7 +491,7 @@ def init_vision_module(self): def init_resampler(self, embed_dim: int, vision_dim: int): default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float16) - if self.version == 2.0: + if self.version == (2, 0): resampler = Resampler(grid_size=int( math.sqrt(self.config.query_num)), num_queries=None, @@ -514,8 +515,8 @@ def get_vision_embedding(self, pixel_values: List[List[torch.Tensor]], patch_attn_mask: Optional[torch.Tensor] = None, tgt_sizes: Optional[torch.Tensor] = None, - version: float = 2.0): - if version == 2.0: + version: Tuple[int, int] = (2, 0)): + if version == 
(2, 0): res = [] dtype = self.vpm.pos_embed.data.dtype for pixel_value in pixel_values: @@ -532,7 +533,7 @@ def get_vision_embedding(self, num_prefix_tokens:] res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - elif version == 2.5: + elif version == (2, 5): vision_embedding = self.vpm( pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state @@ -574,7 +575,7 @@ def get_vision_hidden_states(self, data: Dict[str, pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] vision_hidden_states = [] - if self.version == 2.0: + if self.version == (2, 0): if pixel_values is not None and len(pixel_values) > 0: vision_hidden_states = self.get_vision_embedding( pixel_values) @@ -598,7 +599,7 @@ def get_vision_hidden_states(self, data: Dict[str, patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) - if self.version == 2.5: + if self.version == (2, 5): for i in range(B): patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
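The last commit shown here ([PATCH 25/25] "change to (int, int)") switches the model's `version` attribute from a float such as `2.0` to an `(int, int)` tuple parsed from `config.version`, and compares against `(2, 0)` / `(2, 5)` throughout. A minimal self-contained sketch (not taken from the patch) of why the tuple representation is the safer choice:

    from typing import Tuple


    def parse_version(raw: object) -> Tuple[int, ...]:
        # Mirrors the idea in the patch: split the dotted string, keep ints.
        return tuple(int(part) for part in str(raw).split("."))


    assert parse_version("2.0") == (2, 0)
    assert parse_version(2.5) == (2, 5)
    # Tuples compare field by field, so hypothetical future versions order
    # correctly, whereas float("2.10") would sort *before* float("2.5").
    assert parse_version("2.5") < parse_version("2.10")
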