Compatibility with Transformers 4.36 (#1590)
* fix sdpa

* fix bart and whisper ONNX export

* fix falcon

* use staticmethod in unpatching

* remove print

* fix Whisper generate, which is no longer compatible with our implementation in Transformers

* fix falcon

* fix tests

* update bt doc
fxmarty authored Dec 13, 2023
1 parent e840d21 commit 0645d3b
Showing 19 changed files with 606 additions and 962 deletions.
12 changes: 8 additions & 4 deletions docs/source/bettertransformer/overview.mdx
@@ -23,6 +23,10 @@ In the 2.0 version, PyTorch includes a native scaled dot-product attention opera

We provide an integration with these optimizations out of the box in 🤗 Optimum, so that you can convert any supported 🤗 Transformers model so as to use the optimized paths & `scaled_dot_product_attention` function when relevant.
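
As a rough sketch of this conversion (the checkpoint name is only an example, and the model must be in the supported list below):

```python
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("bert-base-uncased")
# Swap supported layers for their BetterTransformer / SDPA-backed equivalents.
model = BetterTransformer.transform(model)
```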

<Tip warning={true}>
The PyTorch-native `scaled_dot_product_attention` operator is progressively being [made the default and integrated into 🤗 Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention). For models that do support SDPA in Transformers, we deprecate BetterTransformer and recommend using Transformers and the latest PyTorch version directly for the attention optimizations (Flash Attention, memory-efficient attention) through SDPA.
</Tip>
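
For such models, a minimal sketch of the recommended path (the checkpoint name is only an example; `attn_implementation="sdpa"` requires Transformers >= 4.36):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # route attention through torch.nn.functional.scaled_dot_product_attention
)
```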

<Tip warning={true}>
The PyTorch-native `scaled_dot_product_attention` operator can only dispatch to Flash Attention if no `attention_mask` is provided.
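
A small illustration of this constraint, assuming a CUDA device is available:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
# With attn_mask=None (optionally is_causal=True), PyTorch may select the Flash Attention kernel;
# passing an explicit attention mask forces a different backend.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```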

@@ -50,16 +54,16 @@ The list of supported model below:
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [Falcon](https://arxiv.org/abs/2306.01116)
- [Falcon](https://arxiv.org/abs/2306.01116) (No need to use BetterTransformer, it is [directly supported by Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention))
- [FSMT](https://arxiv.org/abs/1907.06616)
- [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [GPT-j](https://huggingface.co/EleutherAI/gpt-j-6B)
- [GPT-neo](https://github.com/EleutherAI/gpt-neo)
- [GPT-neo-x](https://arxiv.org/abs/2204.06745)
- [GPT BigCode](https://arxiv.org/abs/2301.03988) (SantaCoder, StarCoder)
- [GPT BigCode](https://arxiv.org/abs/2301.03988) (SantaCoder, StarCoder - no need to use BetterTransformer, it is [directly supported by Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention))
- [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
- [LayoutLM](https://arxiv.org/abs/1912.13318)
- [Llama & Llama2](https://arxiv.org/abs/2302.13971)
- [Llama & Llama2](https://arxiv.org/abs/2302.13971) (No need to use BetterTransformer, it is [directly supported by Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention))
- [MarkupLM](https://arxiv.org/abs/2110.08518)
- [Marian](https://arxiv.org/abs/1804.00344)
- [MBart](https://arxiv.org/abs/2001.08210)
@@ -77,7 +81,7 @@ The list of supported model below:
- [ViT-MAE](https://arxiv.org/abs/2111.06377)
- [ViT-MSN](https://arxiv.org/abs/2204.07141)
- [Wav2Vec2](https://arxiv.org/abs/2006.11477)
- [Whisper](https://cdn.openai.com/papers/whisper.pdf)
- [Whisper](https://cdn.openai.com/papers/whisper.pdf) (No need to use BetterTransformer, it is [directly supported by Transformers](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention))
- [XLMRoberta](https://arxiv.org/abs/1911.02116)
- [YOLOS](https://arxiv.org/abs/2106.00666)

23 changes: 1 addition & 22 deletions optimum/bettertransformer/models/__init__.py
@@ -20,13 +20,10 @@
BlenderbotAttentionLayerBetterTransformer,
BloomAttentionLayerBetterTransformer,
CodegenAttentionLayerBetterTransformer,
FalconAttentionLayerBetterTransformer,
GPT2AttentionLayerBetterTransformer,
GPTBigCodeAttentionLayerBetterTransformer,
GPTJAttentionLayerBetterTransformer,
GPTNeoAttentionLayerBetterTransformer,
GPTNeoXAttentionLayerBetterTransformer,
LlamaAttentionLayerBetterTransformer,
M2M100AttentionLayerBetterTransformer,
MarianAttentionLayerBetterTransformer,
OPTAttentionLayerBetterTransformer,
@@ -45,17 +42,9 @@
ViltLayerBetterTransformer,
ViTLayerBetterTransformer,
Wav2Vec2EncoderLayerBetterTransformer,
WhisperEncoderLayerBetterTransformer,
)


# TODO: remove once we are much higher than 4.31
if check_if_transformers_greater("4.31"):
from .attention import _llama_prepare_decoder_attention_mask
else:
from ...utils.dummy_bettertransformer_objects import _llama_prepare_decoder_attention_mask


class BetterTransformerManager:
MODEL_MAPPING = {
"albert": {"AlbertLayer": AlbertLayerBetterTransformer},
@@ -78,15 +67,12 @@ class BetterTransformerManager:
"electra": {"ElectraLayer": BertLayerBetterTransformer},
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"falcon": {"FalconAttention": FalconAttentionLayerBetterTransformer},
"gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
"gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
"gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
"gpt_neo": {"GPTNeoSelfAttention": GPTNeoAttentionLayerBetterTransformer},
"gpt_neox": {"GPTNeoXAttention": GPTNeoXAttentionLayerBetterTransformer},
"hubert": {"HubertEncoderLayer": Wav2Vec2EncoderLayerBetterTransformer},
"layoutlm": {"LayoutLMLayer": BertLayerBetterTransformer},
"llama": {"LlamaAttention": LlamaAttentionLayerBetterTransformer},
"m2m_100": {
"M2M100EncoderLayer": MBartEncoderLayerBetterTransformer,
"M2M100Attention": M2M100AttentionLayerBetterTransformer,
@@ -115,13 +101,12 @@ class BetterTransformerManager:
"Wav2Vec2EncoderLayer": Wav2Vec2EncoderLayerBetterTransformer,
"Wav2Vec2EncoderLayerStableLayerNorm": Wav2Vec2EncoderLayerBetterTransformer,
},
"whisper": {"WhisperEncoderLayer": WhisperEncoderLayerBetterTransformer},
"xlm-roberta": {"XLMRobertaLayer": BertLayerBetterTransformer},
"yolos": {"YolosLayer": ViTLayerBetterTransformer},
}

OVERWRITE_METHODS = {
"llama": {"LlamaModel": ("_prepare_decoder_attention_mask", _llama_prepare_decoder_attention_mask)}
# "llama": {"LlamaModel": ("_prepare_decoder_attention_mask", _llama_prepare_decoder_attention_mask)}
}

EXCLUDE_FROM_TRANSFORM = {
@@ -144,15 +129,12 @@ class BetterTransformerManager:
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
"gptj",
"gpt_neo",
"gpt_neox",
"llama",
"opt",
"pegasus",
"t5",
"falcon",
}

NOT_REQUIRES_STRICT_VALIDATION = {
@@ -161,15 +143,12 @@
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
"gptj",
"gpt_neo",
"gpt_neox",
"llama",
"opt",
"pegasus",
"t5",
"falcon",
}

@staticmethod
