diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 18dbee94e10b0..a9dbb3823743a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -14,8 +14,6 @@
     _num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
-from xformers.ops.fmha import memory_efficient_attention
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
@@ -38,6 +36,12 @@
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import init_vllm_registered_model
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
     tokenizer = cached_get_tokenizer(
@@ -416,7 +420,7 @@ def __init__(self, args: VisionEncoderArgs):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: torch.Tensor,
     ) -> torch.Tensor:
         batch, patches, _ = x.shape
@@ -427,7 +431,7 @@ def forward(
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-        out = memory_efficient_attention(q, k, v, attn_bias=mask)
+        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
         return self.wo(out)
 
@@ -444,7 +448,7 @@ def __init__(self, args: VisionEncoderArgs):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: torch.Tensor,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x),
@@ -467,7 +471,7 @@ def __init__(self, args: VisionEncoderArgs):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: Optional[torch.Tensor],
     ) -> torch.Tensor:
         for layer in self.layers:
@@ -562,8 +566,12 @@ def forward(
         freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
 
         # pass through Transformer with a block diagonal mask delimiting images
-        mask = BlockDiagonalMask.from_seqlens(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        if USE_XFORMERS_OPS:
+            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        else:
+            raise ImportError("Xformers is required for Pixtral inference "
+                              "with the Mistral format")
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
 
         # remove batch dimension of the single sequence
@@ -828,7 +836,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         batch, patches, _ = hidden_states.size()
@@ -843,12 +851,23 @@ def forward(
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
 
-        # Transpose q and k back for attention
-        q = q.transpose(1, 2).contiguous()
-        k = k.transpose(1, 2).contiguous()
-        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
+        if USE_XFORMERS_OPS:
+            # Transpose q and k back for attention
+            q = q.transpose(1, 2).contiguous()
+            k = k.transpose(1, 2).contiguous()
+            v = v.reshape(batch, patches, self.n_heads, self.head_dim)
+
+            out = xops.memory_efficient_attention(q,
+                                                  k,
+                                                  v,
+                                                  attn_bias=attention_mask)
+        else:
+            v = v.reshape(batch, patches, self.n_heads,
+                          self.head_dim).transpose(1, 2)
+            out = nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=attention_mask)
+            out = out.transpose(1, 2)
 
-        out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
 
         return self.o_proj(out)
@@ -877,7 +896,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(hidden_states),
@@ -916,7 +935,7 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layers:
@@ -1000,11 +1019,19 @@ def forward(
             patch_embeds_list,
             max_width=self.config.image_size // self.config.patch_size).to(
                 self.device)
-
         position_embedding = self.patch_positional_embedding(
             patch_embeds, position_ids)
-        attention_mask = BlockDiagonalMask.from_seqlens(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+
+        if USE_XFORMERS_OPS:
+            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        else:
+            from transformers.models.pixtral.modeling_pixtral import (
+                generate_block_attention_mask)
+            attention_mask = generate_block_attention_mask(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+                patch_embeds)
+
         out = self.transformer(patch_embeds, attention_mask,
                                position_embedding)
 
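
Note (not part of the diff): the change above comes down to one pattern, import `xformers` if it is available and otherwise fall back to PyTorch's `scaled_dot_product_attention` with a dense block-diagonal mask. The sketch below shows that pattern in isolation; `vision_attention` and `seqlens` are illustrative names rather than vLLM or xformers APIs, and the tensors are assumed to follow the packed single-sequence layout (batch of 1) used in pixtral.py.

```python
# Minimal sketch of the optional-xformers fallback, assuming a packed batch
# of size 1 where seqlens[i] is the number of patch tokens for image i.
from typing import List

import torch
import torch.nn.functional as F

try:
    from xformers import ops as xops
    USE_XFORMERS_OPS = True
except ImportError:
    USE_XFORMERS_OPS = False


def vision_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     seqlens: List[int]) -> torch.Tensor:
    # q, k, v: (1, total_patches, n_heads, head_dim), total_patches == sum(seqlens)
    batch, patches, n_heads, head_dim = q.shape

    if USE_XFORMERS_OPS:
        # xformers consumes (batch, seq, heads, head_dim) directly, plus a
        # block-diagonal bias so each image only attends to its own patches.
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
    else:
        # SDPA wants (batch, heads, seq, head_dim) and a boolean mask where
        # True marks positions that are allowed to attend to each other.
        mask = torch.zeros(patches, patches, dtype=torch.bool, device=q.device)
        start = 0
        for length in seqlens:
            mask[start:start + length, start:start + length] = True
            start += length
        out = F.scaled_dot_product_attention(q.transpose(1, 2),
                                             k.transpose(1, 2),
                                             v.transpose(1, 2),
                                             attn_mask=mask)
        out = out.transpose(1, 2)

    # Merge heads back: (1, total_patches, n_heads * head_dim)
    return out.reshape(batch, patches, n_heads * head_dim)
```

The HF-format path in the diff builds the same dense block-diagonal mask via `generate_block_attention_mask` from `transformers.models.pixtral.modeling_pixtral` instead of the manual loop above; the Mistral-format path keeps xformers as a hard requirement and raises `ImportError` when it is missing.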