From 344892caf1f45c7e07eeebf7d710d08dcfe05fc8 Mon Sep 17 00:00:00 2001 From: Alphi <52458637+HwwwwwwwH@users.noreply.github.com> Date: Wed, 31 Jul 2024 22:39:19 +0800 Subject: [PATCH] [Bugfix] Clean up MiniCPM-V (#6939) Co-authored-by: hezhihui Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.rst | 6 +- vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/minicpm.py | 4 +- vllm/model_executor/models/minicpmv.py | 249 +++++--- vllm/model_executor/models/na_vit.py | 804 ++++++++++++++++++++++++ vllm/model_executor/models/qwen2.py | 2 +- 6 files changed, 975 insertions(+), 94 deletions(-) create mode 100644 vllm/model_executor/models/na_vit.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 4fe33e5ab5d80..a1ea366b82b04 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -222,9 +222,13 @@ Vision Language Models - * - :code:`MiniCPM-V` - MiniCPM-V - - :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. + - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc. - +.. note:: + For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. + For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 + ---- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..2052c443a8885 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -418,11 +418,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - input_embeds: Optional[torch.Tensor] = None ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - input_embeds) + attn_metadata, intermediate_tensors) return model_output def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7a8ac0bb1f949..b46e88f5fc584 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -370,6 +370,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: @@ -463,11 +464,10 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - input_embeds: Optional[torch.Tensor] = None, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, input_embeds) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8563216d9c392..2a7fe7ba0ebac 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -20,32 +20,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only MiniCPM-V-2 model compatible with HuggingFace weights.""" +"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re from functools import partial -from typing import Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F +import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers.configuration_utils import PretrainedConfig -from transformers.models.idefics2.modeling_idefics2 import ( - Idefics2VisionTransformer) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsVision -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.minicpm import MiniCPMForCausalLM +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_image_processor, @@ -53,12 +55,12 @@ from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData _KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", + "llm.lm_head": "lm_head", + "llm.model": "llm", } -def get_abs_pos(abs_pos, tgt_size): +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -75,10 +77,10 @@ def get_abs_pos(abs_pos, tgt_size): # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim, - grid_size, - cls_token=False, - version=2.0): +def get_2d_sincos_pos_embed(embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0)): """ grid_size: int of the grid height and width return: @@ -95,7 +97,7 @@ def get_2d_sincos_pos_embed(embed_dim, grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) - if version == 2.0: + if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) if cls_token: @@ -106,7 +108,9 @@ def get_2d_sincos_pos_embed(embed_dim, return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, + grid: Union[int, Tuple[int, int]], + version: Tuple[int, int] = (2, 0)): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -115,14 +119,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version=2.0): emb_w = get_1d_sincos_pos_embed_from_grid( embed_dim // 2, grid[1], version) # (H*W, D/2) 
or (H, W, D/2) - if version == 2.0: + if version == (2, 0): emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) else: emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0): +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, + pos: int, + version: Tuple[int, int] = (2, 0)): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -133,7 +139,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, version=2.0): omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) - if version == 2.0: + if version == (2, 0): pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) @@ -158,19 +164,19 @@ class Resampler(nn.Module): default_norm_layer = partial(nn.LayerNorm, eps=1e-6) def __init__(self, - num_queries, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=default_norm_layer, - adaptive=False, - max_size=(70, 70), - version=2.0): + num_queries: int, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: nn.Module = default_norm_layer, + adaptive: bool = False, + max_size: Tuple[int, int] = (70, 70), + version: Tuple[int, int] = (2, 0)): super().__init__() self.version = version - if self.version == 2.0: + if self.version == (2, 0): self.num_queries = grid_size**2 else: self.num_queries = num_queries @@ -195,7 +201,7 @@ def __init__(self, self.proj = nn.Parameter( (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) - if self.version == 2.0: + if self.version == (2, 0): self.pos_embed = nn.Parameter( torch.from_numpy( get_2d_sincos_pos_embed( @@ -206,14 +212,17 @@ def __init__(self, self.apply(self._init_weights) - def _set_2d_pos_cache(self, max_size, device='cpu'): + def _set_2d_pos_cache(self, + max_size: Tuple[int, int], + device: torch.types.Device = 'cpu'): pos_embed = torch.from_numpy( get_2d_sincos_pos_embed(self.embed_dim, max_size, version=self.version)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes, device): + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, + device: torch.types.Device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: @@ -223,7 +232,7 @@ def _adjust_pos_cache(self, tgt_sizes, device): ] self._set_2d_pos_cache(self.max_size, device) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -232,7 +241,9 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def forward_2_5(self, x, tgt_sizes=None): + def forward_2_5(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -278,7 +289,10 @@ def forward_2_5(self, x, tgt_sizes=None): x = x @ self.proj return x - def forward_2(self, x, tgt_sizes=None, attn_mask=None): + def forward_2(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): if self.adaptive: pos_embed = torch.Tensor( get_2d_sincos_pos_embed(self.embed_dim, @@ -302,8 +316,11 @@ def forward_2(self, x, tgt_sizes=None, attn_mask=None): x = x @ self.proj return x - def forward(self, x, tgt_sizes=None, attn_mask=None): - if self.version 
== 2.0: + def forward(self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None): + if self.version == (2, 0): return self.forward_2(x, tgt_sizes=tgt_sizes, attn_mask=attn_mask) else: return self.forward_2_5(x, tgt_sizes=tgt_sizes) @@ -322,7 +339,7 @@ def dummy_seq_data_for_minicpmv(seq_len: int): return SequenceData(token_ids) -def dummy_image_for_minicpmv(hf_config): +def dummy_image_for_minicpmv(hf_config: PretrainedConfig): width = height = hf_config.image_size image = Image.new("RGB", (width, height), color=0) return {"image": image} @@ -381,7 +398,7 @@ class MiniCPMV(nn.Module, SupportsVision): def __init__( self, - config, + config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -390,30 +407,48 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - self.version = float(self.config.version) + if not hasattr(self.config, "version"): + if self.config.hidden_size == 2304 and self.config.query_num == 64: + self.version = (2, 0) + else: + self.version = (2, 5) + else: + self.version = str(self.config.version).split(".") + self.version = tuple([int(x) for x in self.version]) self.llm = self.init_llm(config, cache_config, quant_config) self.vpm = self.init_vision_module() param_dtype = torch.get_default_dtype() self.vpm.to(dtype=param_dtype) - self.vision_dim = self.vpm.embed_dim if self.version == 2.0 \ + self.vision_dim = self.vpm.embed_dim if self.version == (2, 0) \ else self.vpm.embeddings.embed_dim - self.embed_dim = self.llm.config.hidden_size + self.embed_dim = self.config.hidden_size self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) self.resampler.to(device="cuda", dtype=param_dtype) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def init_llm(self, config, cache_config, quant_config): - if self.version == 2.0: - return MiniCPMForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + def init_llm(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + if self.version == (2, 0): + return MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config) + elif self.version == (2, 5): + return LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config) else: - return LlamaForCausalLM(config, - cache_config=cache_config, - quant_config=quant_config) + return Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config) def init_vision_module(self): - if self.version == 2.0: + if self.version == (2, 0): try: import timm except ImportError: @@ -433,16 +468,30 @@ def init_vision_module(self): if self.config.drop_vision_last_layer: model.blocks = model.blocks[:-1] - else: + elif self.version == (2, 5): + from transformers.models.idefics2.modeling_idefics2 import ( + Idefics2VisionTransformer) model = Idefics2VisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] + else: + from vllm.model_executor.models.na_vit import ( + SiglipVisionTransformer) + if self.config._attn_implementation == 'flash_attention_2': + self.config.vision_config._attn_implementation \ + = 'flash_attention_2' + else: + # not support sdpa + 
self.config.vision_config._attn_implementation = 'eager' + model = SiglipVisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, embed_dim, vision_dim): + def init_resampler(self, embed_dim: int, vision_dim: int): default_dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float16) - if self.version == 2.0: + if self.version == (2, 0): resampler = Resampler(grid_size=int( math.sqrt(self.config.query_num)), num_queries=None, @@ -463,11 +512,11 @@ def init_resampler(self, embed_dim, vision_dim): return resampler def get_vision_embedding(self, - pixel_values, - patch_attn_mask=None, - tgt_sizes=None, - version=2.0): - if version == 2.0: + pixel_values: List[List[torch.Tensor]], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + version: Tuple[int, int] = (2, 0)): + if version == (2, 0): res = [] dtype = self.vpm.pos_embed.data.dtype for pixel_value in pixel_values: @@ -484,21 +533,32 @@ def get_vision_embedding(self, num_prefix_tokens:] res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - else: + elif version == (2, 5): vision_embedding = self.vpm( pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) + else: + vision_embedding = self.vpm(pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - def get_image_bounds(self, input_ids): + def get_image_bounds(self, input_ids: torch.Tensor): tokenizer = cached_get_tokenizer(self.config._name_or_path, trust_remote_code=True) - im_start_token_id = tokenizer.im_start_id - im_end_token_id = tokenizer.im_end_id - image_start_tokens = torch.where(input_ids == im_start_token_id)[0] + if not hasattr(tokenizer, "slice_start_id"): + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + else: + start_cond = (input_ids == tokenizer.im_start_id) | ( + input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | ( + input_ids == tokenizer.slice_end_id) + + image_start_tokens = torch.where(start_cond)[0] image_start_tokens += 1 - image_end_tokens = torch.where(input_ids == im_end_token_id)[0] - valid_image_nums = min(len(image_start_tokens), len(image_end_tokens)) + image_end_tokens = torch.where(end_cond)[0] + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) if valid_image_nums == 0: return [] image_bound = torch.hstack([ @@ -508,12 +568,14 @@ def get_image_bounds(self, input_ids): return image_bound - def get_vision_hidden_states(self, data): + def get_vision_hidden_states(self, data: Dict[str, + Union[List[torch.Tensor], + torch.Tensor]]): if "vision_hidden_states" not in data: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] vision_hidden_states = [] - if self.version == 2.0: + if self.version == (2, 0): if pixel_values is not None and len(pixel_values) > 0: vision_hidden_states = self.get_vision_embedding( pixel_values) @@ -534,17 +596,26 @@ def get_vision_hidden_states(self, data): B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute( 0, 2, 1).reshape(B, 3, -1, L) - patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * - tgt_sizes[i][1]] = True + if self.version == (2, 5): + for i in range(B): 
+ patch_attn_mask[i, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask + ).last_hidden_state + else: + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * + tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes).last_hidden_state - vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask).last_hidden_state vision_hidden_states = self.resampler( vision_embedding, tgt_sizes) @@ -556,7 +627,8 @@ def get_vision_hidden_states(self, data): return vision_hidden_states - def get_embedding(self, data): + def get_embedding(self, data: Dict[str, Union[List[torch.Tensor], + torch.Tensor]]): input_ids = data["input_ids"] vision_hidden_states = self.get_vision_hidden_states(data) @@ -565,11 +637,11 @@ def get_embedding(self, data): else: image_bounds = [] - if hasattr(self.llm.config, 'scale_emb'): - vlm_embedding = self.llm.model.embed_tokens( - input_ids) * self.llm.config.scale_emb + if hasattr(self.config, 'scale_emb'): + vlm_embedding = self.llm.embed_tokens( + input_ids) * self.config.scale_emb else: - vlm_embedding = self.llm.model.embed_tokens(input_ids) + vlm_embedding = self.llm.embed_tokens(input_ids) vision_hidden_states = [ i.type(vlm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states @@ -587,7 +659,9 @@ def get_embedding(self, data): vision_hidden_states.view(-1, vision_hidden_states.shape[-1])) return vlm_embedding, vision_hidden_states - def process_multimodal_inputs(self, inputs): + def process_multimodal_inputs(self, inputs: Dict[str, + Union[List[torch.Tensor], + torch.Tensor]]): pixel_values = [] tgt_sizes = [] for b in range(len(inputs["pixel_values"])): @@ -613,7 +687,6 @@ def forward( "input_ids": input_ids, "tgt_sizes": kwargs.pop("tgt_sizes", None), } - inputs = self.process_multimodal_inputs(inputs) vlm_embeddings, vision_hidden_states = self.get_embedding(inputs) @@ -623,19 +696,21 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, intermediate_tensors=intermediate_tensors, - input_embeds=vlm_embeddings) + inputs_embeds=vlm_embeddings) return output def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.llm.sample(logits, sampling_metadata) + next_tokens = self.sampler(logits, sampling_metadata) return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -649,9 +724,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - # for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - # if key_to_modify in name: - # name = name.replace(key_to_modify, new_key) + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/na_vit.py b/vllm/model_executor/models/na_vit.py new file mode 100644 index 0000000000000..871e4128b66e1 
--- /dev/null +++ b/vllm/model_executor/models/na_vit.py @@ -0,0 +1,804 @@ +import logging +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (ModelOutput, is_flash_attn_2_available, + replace_return_docstrings) + +logger = logging.getLogger("vllm") + + +# For Siglip: copied from +# HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes +# Remove hints as there's little possibility to change these code. +class SiglipVisionConfig(PretrainedConfig): + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr( + cls, + "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + "You are using a model of type %s to " + "instantiate a model of type %s. 
" + "This is not supported for all configurations" + "of models and can yield errors.", config_dict['model_type'], + cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + from flash_attn.bert_padding import index_first_axis, unpad_input + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l_ = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l_ - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) 
defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_(tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0) -> torch.Tensor: + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +class SiglipVisionModelOutput(ModelOutput): + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = (max_im_h // self.patch_size, + max_im_w // self.patch_size) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = 
torch.bucketize(fractional_coords_w, + boundaries, + right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, + k_v_seq_len): + raise ValueError( + "Attention weights should be of size " + f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}") + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + "Attention mask should be of size " + f"{(batch_size, 1, q_len, k_v_seq_len)}", + f"but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, + dim=-1, + dtype=torch.float32).to( + query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, + p=self.dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, + self.head_dim): + raise ValueError( + "`attn_output` should be of size " + f"{(batch_size, self.num_heads, q_len, self.head_dim)}, " + "but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + + def __init__(self, *args, 
**kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning( + "The input hidden states seems to be " + "silently casted in float32, " + "this might be related to the fact " + "you have upcasted embedding or layer norm layers in float32. 
" + "We will cast back the input in" + " %s.", target_dtype) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward(query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, + self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + (query_states, key_states, value_states, indices_q, cu_seq_lens, + max_seq_lens) = self._upad_input(query_states, key_states, + value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func(query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + (query_layer, indices_q, cu_seqlens_q, + max_seqlen_in_batch_q) = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer +# with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + self.self_attn = (SiglipAttention(config) + if not self._use_flash_attention_2 else + SiglipFlashAttention2(config)) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (attn_weights, ) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + config_class = SiglipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + + if isinstance(module, SiglipVisionEmbeddings): + width = self.config.hidden_size + nn.init.normal_(module.position_embedding.weight, + std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +# 
Copied from transformers.models.clip.modeling_clip.CLIPEncoder +# with CLIP->Siglip +class SiglipEncoder(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions) + + +class SiglipVisionTransformer(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + _supports_flash_attn_2 = True + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2") + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.embeddings.patch_embedding + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, + config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None \ + else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None \ + else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + 
patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s + # (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, + # which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = (_prepare_4d_attention_mask( + patch_attention_mask, hidden_states.dtype) + if not self._use_flash_attention_2 else + patch_attention_mask) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 3deb3d8840cc4..35fd6f37589a0 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -342,7 +342,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor,
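
The patch keys every dispatch decision above (LLM backbone, vision module, resampler setup, and the vision-embedding path) on a version tuple parsed from the HuggingFace config instead of the previous float comparison. The standalone Python sketch below is not part of the patch itself; it uses hypothetical SimpleNamespace objects in place of the real PretrainedConfig and simply condenses that version-detection logic for reference:

    from types import SimpleNamespace
    from typing import Tuple


    def parse_minicpmv_version(config) -> Tuple[int, ...]:
        # Mirrors MiniCPMV.__init__ above: configs without an explicit
        # "version" field are classified by their hidden_size/query_num.
        if not hasattr(config, "version"):
            if config.hidden_size == 2304 and config.query_num == 64:
                return (2, 0)
            return (2, 5)
        return tuple(int(x) for x in str(config.version).split("."))


    # Hypothetical stand-in configs; the real model receives a PretrainedConfig.
    assert parse_minicpmv_version(SimpleNamespace(hidden_size=2304, query_num=64)) == (2, 0)
    assert parse_minicpmv_version(SimpleNamespace(version="2.5")) == (2, 5)

Switching from float-based checks (version == 2.0) to tuple comparisons keeps the equality tests exact, and any version other than (2, 0) or (2, 5) falls through to the Qwen2-based branch of init_llm and the SiglipVisionTransformer defined in na_vit.py.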