Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Qwen2-VL #1542

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions optimum/habana/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
from typing import Union

import habana_frameworks.torch.core as htcore
import torch

from optimum.utils import logging
Expand Down Expand Up @@ -67,6 +68,7 @@ def gaudi_MaxTimeCriteria_call(
def gaudi_EosTokenCriteria_call(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> Union[torch.BoolTensor, bool]:
htcore.mark_step()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a general change, and may affect performance in other models.

Can you add some numbers for the compile times u have with and without this change?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to add teh markstep at teh end of the qwen model or something like that so the effect of this is isolated or qwen only?

otherwise we'd have to do a larger study to make sure it doesnt have any bad effect on any other model/usecase

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mark_step at the end of the model doesn't have the same effect.
There is 2x improvement in warmup time.
I agree this should be independent PR with more tests.
I will remove this particular change.

self.eos_token_id = self.eos_token_id.to(input_ids.device)
token_idx = kwargs.get("token_idx", None)
if token_idx is not None:
Expand Down
20 changes: 20 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@
GaudiQwen2MoeForCausalLM,
GaudiQwen2MoeMLP,
GaudiQwen2MoeModel,
GaudiQwen2VisionSdpaAttention,
GaudiQwen2VisionTransformerPretrainedModel,
GaudiQwen2VLDecoderLayer,
GaudiQwen2VLForConditionalGeneration,
GaudiQwen2VLModel,
GaudiQwen2VLSdpaAttention,
GaudiQwen2VLVisionBlock,
GaudiStableLmAttention,
GaudiStableLmDecoderLayer,
GaudiStableLmForCausalLM,
Expand Down Expand Up @@ -630,6 +637,19 @@ def adapt_transformers_to_gaudi():
gaudi_qwen2moe_block_sparse_moe_forward
)

# Optimization for qwen2-vl Gaudi
transformers.models.qwen2_vl.modeling_qwen2_vl.VisionSdpaAttention = GaudiQwen2VisionSdpaAttention
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLVisionBlock = GaudiQwen2VLVisionBlock
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VisionTransformerPretrainedModel = (
GaudiQwen2VisionTransformerPretrainedModel
)
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLSdpaAttention = GaudiQwen2VLSdpaAttention
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLDecoderLayer = GaudiQwen2VLDecoderLayer
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLModel = GaudiQwen2VLModel
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration = (
GaudiQwen2VLForConditionalGeneration
)

# Optimization for stablelm on Gaudi
transformers.models.stablelm.modeling_stablelm.StableLmAttention = GaudiStableLmAttention
transformers.models.stablelm.modeling_stablelm.StableLmDecoderLayer = GaudiStableLmDecoderLayer
Expand Down
9 changes: 9 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,15 @@
gaudi_qwen2moe_block_sparse_moe_forward,
gaudi_qwen2moe_rmsnorm_forward,
)
from .qwen2_vl import (
GaudiQwen2VisionSdpaAttention,
GaudiQwen2VisionTransformerPretrainedModel,
GaudiQwen2VLDecoderLayer,
GaudiQwen2VLForConditionalGeneration,
GaudiQwen2VLModel,
GaudiQwen2VLSdpaAttention,
GaudiQwen2VLVisionBlock,
)
from .seamless_m4t import (
gaudi_SeamlessM4TAttention_forward,
gaudi_SeamlessM4TCodeHifiGan_get_output_hifigan_lengths,
Expand Down
9 changes: 9 additions & 0 deletions optimum/habana/transformers/models/qwen2_vl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .modeling_qwen2_vl import (
GaudiQwen2VisionSdpaAttention,
GaudiQwen2VisionTransformerPretrainedModel,
GaudiQwen2VLDecoderLayer,
GaudiQwen2VLForConditionalGeneration,
GaudiQwen2VLModel,
GaudiQwen2VLSdpaAttention,
GaudiQwen2VLVisionBlock,
)
Loading