From 2510d9301fa855e00d53b950ef6f45da660aaea4 Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Mon, 2 Dec 2024 10:28:05 +0000 Subject: [PATCH 1/7] Add Qwen2-VL --- .../generation/stopping_criteria.py | 2 + optimum/habana/transformers/modeling_utils.py | 20 + .../habana/transformers/models/__init__.py | 9 + .../transformers/models/qwen2_vl/__init__.py | 9 + .../models/qwen2_vl/modeling_qwen2_vl.py | 774 ++++++++++++++++++ 5 files changed, 814 insertions(+) create mode 100644 optimum/habana/transformers/models/qwen2_vl/__init__.py create mode 100644 optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index 844ffa50f2..e442f58d72 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -17,6 +17,7 @@ import time from typing import Union +import habana_frameworks.torch.core as htcore import torch from optimum.utils import logging @@ -67,6 +68,7 @@ def gaudi_MaxTimeCriteria_call( def gaudi_EosTokenCriteria_call( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> Union[torch.BoolTensor, bool]: + htcore.mark_step() self.eos_token_id = self.eos_token_id.to(input_ids.device) token_idx = kwargs.get("token_idx", None) if token_idx is not None: diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 27f4de8820..efc10ed1cb 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -139,6 +139,13 @@ GaudiQwen2MoeForCausalLM, GaudiQwen2MoeMLP, GaudiQwen2MoeModel, + GaudiQwen2VisionSdpaAttention, + GaudiQwen2VisionTransformerPretrainedModel, + GaudiQwen2VLDecoderLayer, + GaudiQwen2VLForConditionalGeneration, + GaudiQwen2VLModel, + GaudiQwen2VLSdpaAttention, + GaudiQwen2VLVisionBlock, GaudiStableLmAttention, 
GaudiStableLmDecoderLayer, GaudiStableLmForCausalLM, @@ -648,6 +655,19 @@ def adapt_transformers_to_gaudi(): gaudi_qwen2moe_block_sparse_moe_forward ) + # Optimization for qwen2-vl Gaudi + transformers.models.qwen2_vl.modeling_qwen2_vl.VisionSdpaAttention = GaudiQwen2VisionSdpaAttention + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLVisionBlock = GaudiQwen2VLVisionBlock + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel = ( + GaudiQwen2VisionTransformerPretrainedModel + ) + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention = GaudiQwen2VLSdpaAttention + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLDecoderLayer = GaudiQwen2VLDecoderLayer + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLModel = GaudiQwen2VLModel + transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration = ( + GaudiQwen2VLForConditionalGeneration + ) + # Optimization for stablelm on Gaudi transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index ffcfa4ccbb..80fdbe9c02 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -256,6 +256,15 @@ gaudi_qwen2moe_block_sparse_moe_forward, gaudi_qwen2moe_rmsnorm_forward, ) +from .qwen2_vl import ( + GaudiQwen2VisionSdpaAttention, + GaudiQwen2VisionTransformerPretrainedModel, + GaudiQwen2VLDecoderLayer, + GaudiQwen2VLForConditionalGeneration, + GaudiQwen2VLModel, + GaudiQwen2VLSdpaAttention, + GaudiQwen2VLVisionBlock, +) from .seamless_m4t import ( gaudi_SeamlessM4TAttention_forward, gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths, diff --git a/optimum/habana/transformers/models/qwen2_vl/__init__.py b/optimum/habana/transformers/models/qwen2_vl/__init__.py 
new file mode 100644 index 0000000000..72a587c799 --- /dev/null +++ b/optimum/habana/transformers/models/qwen2_vl/__init__.py @@ -0,0 +1,9 @@ +from .modeling_qwen2_vl import ( + GaudiQwen2VisionSdpaAttention, + GaudiQwen2VisionTransformerPretrainedModel, + GaudiQwen2VLDecoderLayer, + GaudiQwen2VLForConditionalGeneration, + GaudiQwen2VLModel, + GaudiQwen2VLSdpaAttention, + GaudiQwen2VLVisionBlock, +) diff --git a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py new file mode 100644 index 0000000000..ade546fa17 --- /dev/null +++ b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -0,0 +1,774 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Gaudi Qwen2-VL model.""" + +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, +) +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VisionTransformerPretrainedModel, + Qwen2VLAttention, + Qwen2VLCausalLMOutputWithPast, + Qwen2VLConfig, + Qwen2VLDecoderLayer, + Qwen2VLFlashAttention2, + Qwen2VLForConditionalGeneration, + Qwen2VLModel, + Qwen2VLSdpaAttention, + Qwen2VLVisionBlock, + VisionAttention, + VisionFlashAttention2, + VisionSdpaAttention, + _prepare_4d_causal_attention_mask_with_cache_position, + apply_multimodal_rotary_pos_emb, + apply_rotary_pos_emb_vision, + repeat_kv, +) +from transformers.utils import logging + + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +logger = logging.get_logger(__name__) + + +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode) + + +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +class GaudiQwen2VisionSdpaAttention(VisionSdpaAttention): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__(dim, num_heads) + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + use_flash_attention: Optional[bool] = False, + )
-> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + if FusedSDPA is not None and use_flash_attention: + attn_output = self.fused_scaled_dot_product_attention(q, k, v, attention_mask, 0.0, False, None, "None") + else: + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + del attention_mask + return attn_output + + +GAUDI_QWEN2_VL_VISION_ATTENTION_CLASSES = { + "eager": VisionAttention, + "flash_attention_2": VisionFlashAttention2, + "sdpa": GaudiQwen2VisionSdpaAttention, +} + + +class GaudiQwen2VLVisionBlock(Qwen2VLVisionBlock): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__(config, attn_implementation) + + self.attn = GAUDI_QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.embed_dim, num_heads=config.num_heads + ) + + def forward( + self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + use_flash_attention: Optional[bool] = False, + ) -> torch.Tensor: + """ + Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418 + The only differences are: + - add new args use_flash_attention + """ + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + 
cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + use_flash_attention=use_flash_attention, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class GaudiQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel): + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + use_flash_attention: Optional[bool] = False, + ) -> torch.Tensor: + """ + Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118 + The only differences are: + - add new args use_flash_attention + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for blk in self.blocks: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + use_flash_attention=use_flash_attention, + ) + + return self.merger(hidden_states) + + +class GaudiQwen2VLSdpaAttention(Qwen2VLSdpaAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + use_flash_attention: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821 + The only differences are: + - add new args use_flash_attention + - add work around for bfloat16 FusedSDPA accuracy issue by using float16 + """ + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal = True if causal_mask is None and q_len > 1 else False + + if FusedSDPA is not None and use_flash_attention: + input_dtype = query_states.dtype + # For accuracy + target_dtype = torch.float16 + if input_dtype != target_dtype: + warnings.warn("FusedSDPA Type conversion for Accuracy") + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + causal_mask, + self.attention_dropout if self.training else 0.0, + is_causal, + None, # scale + "None", #'fast' + ) + if input_dtype != target_dtype: + attn_output = attn_output.to(input_dtype) + else: + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +GAUDI_QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": GaudiQwen2VLSdpaAttention, +} + + +class GaudiQwen2VLDecoderLayer(Qwen2VLDecoderLayer): + def __init__(self, config: Qwen2VLConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = GAUDI_QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, 
torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821 + The only differences are: + - add new args use_flash_attention + """ + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + use_flash_attention = kwargs.get("use_flash_attention", None) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + use_flash_attention=use_flash_attention, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class GaudiQwen2VLModel(Qwen2VLModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + use_flash_attention: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Inherits from Qwen2VLModel https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1137 + The only differences are: + - add new arg use_flash_attention + - fixes graph 
recompilation due to torch.arange + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + # causes graph recompilations + # cache_position = torch.arange( + # past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + + # the hard coded `3` is for temporal, height and width. 
+ if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + use_flash_attention=use_flash_attention, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + 
last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class GaudiQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + use_flash_attention: Optional[bool] = False, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + """ + Inherits from Qwen2VLForConditionalGeneration::https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + The only differences are: + - add new arg token_idx + - add new arg use_flash_attention + - add Gaudi Example + """ + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor + >>> from optimum.habana.transformers.models import GaudiQwen2VLForConditionalGeneration + >>> from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + >>> from habana_frameworks.torch.hpu import wrap_in_hpu_graph + >>> adapt_transformers_to_gaudi() + >>> model = GaudiQwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> model = model.to("hpu") + >>> wrap_in_hpu_graph(model) + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], return_tensors="pt") + >>> inputs = inputs.to("hpu") + >>> generate_kwargs = { + "lazy_mode": True, + "hpu_graphs": True, + "static_shapes": True, + "use_cache": True, + "cache_implementation": "static", + "use_flash_attention": True + } + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=30, **generate_kwargs) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene in what appears to be a Chinatown area. The focal point is a red stop sign on the left side of the..." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual( + pixel_values, grid_thw=image_grid_thw, use_flash_attention=use_flash_attention + ) + image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual( + pixel_values_videos, grid_thw=video_grid_thw, use_flash_attention=use_flash_attention + ) + video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_flash_attention=use_flash_attention, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels =
labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + """ + Inherits from Qwen2VLForConditionalGeneration https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1748 + The only differences are: + - handle new args token_idx + - handle new args use_flash_attention + """ + token_idx = kwargs.get("token_idx", None) + use_flash_attention = kwargs.get("use_flash_attention", False) + if token_idx is not None: + if isinstance(past_key_values, StaticCache): + if cache_position.shape[0] > 1: + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + cache_position = cache_position[:token_idx] + else: + # over-write with token idx + cache_position[0] = token_idx - 1 + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values 
is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + rope_deltas = kwargs.get("rope_deltas", None) + if attention_mask is not None and position_ids is None: + if cache_position is None or (cache_position is not None and cache_position[0] == 0): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + else: + batch_size, seq_length = input_ids.shape + delta = ( + cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + ) + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + 
batch_size=batch_size, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "rope_deltas": rope_deltas, + "token_idx": token_idx, + "use_flash_attention": use_flash_attention, + } + ) + + return model_inputs From 8693bae7bf7ca574d2d4cdc47ef988b8f80bb6b7 Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Wed, 4 Dec 2024 05:55:47 +0000 Subject: [PATCH 2/7] Remove change in EosTokenCriteria --- optimum/habana/transformers/generation/stopping_criteria.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/habana/transformers/generation/stopping_criteria.py b/optimum/habana/transformers/generation/stopping_criteria.py index e442f58d72..844ffa50f2 100644 --- a/optimum/habana/transformers/generation/stopping_criteria.py +++ b/optimum/habana/transformers/generation/stopping_criteria.py @@ -17,7 +17,6 @@ import time from typing import Union -import habana_frameworks.torch.core as htcore import torch from optimum.utils import logging @@ -68,7 +67,6 @@ def gaudi_MaxTimeCriteria_call( def gaudi_EosTokenCriteria_call( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs ) -> Union[torch.BoolTensor, bool]: - htcore.mark_step() self.eos_token_id = self.eos_token_id.to(input_ids.device) token_idx = kwargs.get("token_idx", None) if token_idx is not None: From ae0dd23caf72ba6271b60fd164cda5069a55154b Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Thu, 12 Dec 2024 09:17:51 +0000 Subject: [PATCH 3/7] Remove accuracy work around --- .../transformers/models/qwen2_vl/modeling_qwen2_vl.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 
ade546fa17..fe05ea4adf 100644 --- a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Gaudi Qwen2-VL model.""" -import warnings from typing import List, Optional, Tuple, Union import torch @@ -266,14 +265,6 @@ def forward( is_causal = True if causal_mask is None and q_len > 1 else False if FusedSDPA is not None and use_flash_attention: - input_dtype = query_states.dtype - # For accuracy - target_dtype = torch.float16 - if input_dtype != target_dtype: - warnings.warn("FusedSDPA Type conversion for Accuracy") - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, @@ -284,8 +275,6 @@ def forward( None, # scale "None", #'fast' ) - if input_dtype != target_dtype: - attn_output = attn_output.to(input_dtype) else: attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, From d9af41c9b1d5532d797afccd4260269dee4b5c64 Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Thu, 12 Dec 2024 09:18:58 +0000 Subject: [PATCH 4/7] Add qwen2-vl to image-to-text pipeline example --- examples/image-to-text/README.md | 1 - examples/image-to-text/run_pipeline.py | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 7a8ad04664..9259e7b126 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -33,7 +33,6 @@ python3 run_pipeline.py \ ``` > SDPA may introduce [reduced precison](https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-reduction-for-fp16-and-bf16-in-scaled-dot-product-attention-sdpa) - ### Multi-cards inference with BF16 Use the following commands to run Llama-3.2-90B-Vision-Instruct BF16 
inference with FusedSDPA on 8 HPUs: diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index cc19de3b83..236abdf7c4 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -222,7 +222,8 @@ def main(): config = AutoConfig.from_pretrained(args.model_name_or_path) model_type = config.model_type - if args.image_path is None and model_type in ["llava", "idefics2", "mllama"]: + + if args.image_path is None and model_type in ["llava", "idefics2", "mllama", "qwen2_vl"]: args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"] elif args.image_path is None and model_type == "paligemma": args.image_path = [ @@ -233,7 +234,7 @@ def main(): "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" ] - if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma"]: + if model_type in ["llava", "idefics2", "llava_next", "mllama", "paligemma", "qwen2_vl"]: processor = AutoProcessor.from_pretrained(args.model_name_or_path, padding_side="left") if args.prompt is None: if processor.chat_template is not None: @@ -312,6 +313,9 @@ def main(): generator = pipeline( "image-to-text", model=args.model_name_or_path, + config=args.model_name_or_path, + tokenizer=args.model_name_or_path, + image_processor=args.model_name_or_path, torch_dtype=model_dtype, device="hpu", ) @@ -340,13 +344,18 @@ def main(): if args.use_kv_cache: generate_kwargs["use_cache"] = args.use_kv_cache + if model_type == "qwen2_vl": + generate_kwargs["use_cache"] = True + generate_kwargs["cache_implementation"] = "static" + generate_kwargs["static_shapes"] = True + if args.quant_config: generator.model = setup_quantization(generator.model, args) htcore.hpu_initialize(generator.model) # delete once pipeline integrate AutoProcessor as preprocess engine # could use "image-text-to-text" pipeline in transformers 4.47 - if model_type in ["idefics2", 
"mllama", "paligemma"]: + if model_type in ["idefics2", "mllama", "paligemma", "qwen2_vl"]: from transformers.image_utils import load_image def preprocess(self, image, prompt=None, timeout=None): From b3337918c78513d4bddfeef830d5876d145e4b8a Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Thu, 12 Dec 2024 11:03:12 +0000 Subject: [PATCH 5/7] Add Qwen2-VL to image-to-text example tests --- tests/test_image_to_text_example.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py index 921f59ad68..538ca8c182 100644 --- a/tests/test_image_to_text_example.py +++ b/tests/test_image_to_text_example.py @@ -23,6 +23,8 @@ ("HuggingFaceM4/idefics2-8b", 1, 21.89944593215077), ("meta-llama/Llama-3.2-11B-Vision-Instruct", 1, 18.974541922240313), ("tiiuae/falcon-11B-vlm", 1, 23.69260849957278), + ("Qwen/Qwen2-VL-2B-Instruct", 1, 28.755882208438422), + ("Qwen/Qwen2-VL-7B-Instruct", 1, 19.32562189532818), ], "fp8": [ # ("llava-hf/llava-1.5-7b-hf", 1, 98.72578382705062), From 7a354b41ed9f470176d4e88e10cd0cc4f6323dfd Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Mon, 13 Jan 2025 12:43:58 +0000 Subject: [PATCH 6/7] Add source and changes, Remove other attn classes --- .../models/qwen2_vl/modeling_qwen2_vl.py | 54 ++++++++----------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py index fe05ea4adf..d2f0706dd6 100644 --- a/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -25,17 +25,13 @@ ) from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VisionTransformerPretrainedModel, - Qwen2VLAttention, Qwen2VLCausalLMOutputWithPast, Qwen2VLConfig, Qwen2VLDecoderLayer, - Qwen2VLFlashAttention2, Qwen2VLForConditionalGeneration, Qwen2VLModel, Qwen2VLSdpaAttention, 
Qwen2VLVisionBlock, - VisionAttention, - VisionFlashAttention2, VisionSdpaAttention, _prepare_4d_causal_attention_mask_with_cache_position, apply_multimodal_rotary_pos_emb, @@ -63,7 +59,7 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) -# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L383 class GaudiQwen2VisionSdpaAttention(VisionSdpaAttention): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__(dim, num_heads) @@ -76,6 +72,12 @@ def forward( rotary_pos_emb: torch.Tensor = None, use_flash_attention: Optional[bool] = False, ) -> torch.Tensor: + """ + Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390 + The only differences are: + - add new args use_flash_attention + - add FusedSDPA + """ seq_length = hidden_states.shape[0] q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) @@ -100,20 +102,12 @@ def forward( return attn_output -GAUDI_QWEN2_VL_VISION_ATTENTION_CLASSES = { - "eager": VisionAttention, - "flash_attention_2": VisionFlashAttention2, - "sdpa": GaudiQwen2VisionSdpaAttention, -} - - +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418 class GaudiQwen2VLVisionBlock(Qwen2VLVisionBlock): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__(config, attn_implementation) - self.attn = GAUDI_QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.embed_dim, num_heads=config.num_heads - ) + 
self.attn = GaudiQwen2VisionSdpaAttention(config.embed_dim, num_heads=config.num_heads) def forward( self, @@ -123,7 +117,7 @@ def forward( use_flash_attention: Optional[bool] = False, ) -> torch.Tensor: """ - Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418 + Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430 The only differences are: - add new args use_flash_attention """ @@ -137,6 +131,7 @@ def forward( return hidden_states +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1058 class GaudiQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel): def forward( self, @@ -168,6 +163,7 @@ def forward( return self.merger(hidden_states) +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821 class GaudiQwen2VLSdpaAttention(Qwen2VLSdpaAttention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -194,10 +190,10 @@ def forward( use_flash_attention: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821 + Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L829 The only differences are: - add new args use_flash_attention - - add work around for bfloat16 FusedSDPA accuracy issue by using float16 + - add FusedSDPA """ if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -293,17 +289,11 @@ def forward( return attn_output, None, past_key_value -GAUDI_QWEN2_VL_ATTENTION_CLASSES = { - "eager": Qwen2VLAttention, - "flash_attention_2": Qwen2VLFlashAttention2, - "sdpa": GaudiQwen2VLSdpaAttention, -} - - +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L930 class GaudiQwen2VLDecoderLayer(Qwen2VLDecoderLayer): def __init__(self, config: Qwen2VLConfig, layer_idx: int): super().__init__(config, layer_idx) - self.self_attn = GAUDI_QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = GaudiQwen2VLSdpaAttention(config, layer_idx) def forward( self, @@ -318,9 +308,9 @@ def forward( **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ - Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821 + Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L946 The only differences are: - - add new args use_flash_attention + - add new kwargs use_flash_attention """ """ Args: @@ -380,6 +370,7 @@ def forward( return outputs +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1137 class GaudiQwen2VLModel(Qwen2VLModel): def forward( self, @@ -396,7 +387,7 @@ def forward( use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ - Inherits from Qwen2VLModel https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1137 + Copied from Qwen2VLModel https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1161 The only differences are: - add new arg use_flash_attention - fixes graph 
recompilation due to torch.arange @@ -504,6 +495,7 @@ def forward( ) +# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1420 class GaudiQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): def forward( self, @@ -526,7 +518,7 @@ def forward( use_flash_attention: Optional[bool] = False, ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: """ - Inherits from Qwen2VLForConditionalGeneration::https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py + Copied from Qwen2VLForConditionalGeneration https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1623 The only differences are: - add new arg token_idx - add new arg use_flash_attention @@ -670,7 +662,7 @@ def prepare_inputs_for_generation( **kwargs, ): """ - Inherits from Qwen2VLForConditionalGeneration https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1748 + Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1748 The only differences are: - handle new args token_idx - handle new args use_flash_attention From 5a178089397008123f5a0ca012977b50e29a3b7f Mon Sep 17 00:00:00 2001 From: Neelesh Gokhale Date: Fri, 24 Jan 2025 15:35:55 +0000 Subject: [PATCH 7/7] Add Qwen2-VL to docs and MODELS_OPTIMIZED_WITH_STATIC_SHAPES --- README.md | 2 ++ docs/source/index.mdx | 1 + examples/image-to-text/README.md | 1 + examples/image-to-text/run_pipeline.py | 1 - optimum/habana/transformers/generation/utils.py | 1 + 5 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 211c0b3956..79c16a4212 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,8 @@ The following model architectures, tasks and device distributions have 
been vali | Baichuan2 |
  • DeepSpeed
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | DeepSeek-V2 | | :heavy_check_mark: |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | ChatGLM |
  • DeepSpeed
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| Qwen2-VL | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | + ### Diffusers: diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 4144f2e5f1..a3b811b850 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -109,6 +109,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Baichuan2 |
  • DeepSpeed
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | DeepSeek-V2 | | ✅ |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | | ChatGLM |
  • DeepSpeed
  • |
  • Single card
  • |
  • [language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| Qwen2-VL | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | - Diffusers diff --git a/examples/image-to-text/README.md b/examples/image-to-text/README.md index 9259e7b126..7a8ad04664 100644 --- a/examples/image-to-text/README.md +++ b/examples/image-to-text/README.md @@ -33,6 +33,7 @@ python3 run_pipeline.py \ ``` > SDPA may introduce [reduced precison](https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-reduction-for-fp16-and-bf16-in-scaled-dot-product-attention-sdpa) + ### Multi-cards inference with BF16 Use the following commands to run Llama-3.2-90B-Vision-Instruct BF16 inference with FusedSDPA on 8 HPUs: diff --git a/examples/image-to-text/run_pipeline.py b/examples/image-to-text/run_pipeline.py index 236abdf7c4..81fb910cdb 100644 --- a/examples/image-to-text/run_pipeline.py +++ b/examples/image-to-text/run_pipeline.py @@ -347,7 +347,6 @@ def main(): if model_type == "qwen2_vl": generate_kwargs["use_cache"] = True generate_kwargs["cache_implementation"] = "static" - generate_kwargs["static_shapes"] = True if args.quant_config: generator.model = setup_quantization(generator.model, args) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ba9d17dc7c..dc1ba6f7db 100644 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -116,6 +116,7 @@ "baichuan", "deepseek_v2", "chatglm", + "qwen2_vl", ] # Initial generated token index is set to 1 to accomodate SOS (start of string) token.