Add BetterTransformer support for FlavaModel #538

Closed
Changes from all commits (26 commits):
- e2fa9b4: add BetterTransformer support for FlavaModel and test (Dec 2, 2022)
- c76b1c4: Update optimum/bettertransformer/models/encoder_models.py (katiele47, Dec 3, 2022)
- fadd2ed: Merge branch 'main' of https://github.com/huggingface/optimum into ad… (Dec 3, 2022)
- f91a928: Merge branch 'add-better-transformers-support-for-flava' of https://g… (Dec 3, 2022)
- 072420c: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 5, 2022)
- 4966b96: Merge branch 'main' of https://github.com/huggingface/optimum into ad… (Dec 5, 2022)
- 6c80658: Merge branch 'add-better-transformers-support-for-flava' of https://g… (Dec 5, 2022)
- 2a4b790: Optimum ONNX Runtime API improvement (#515) (michaelbenayoun, Dec 6, 2022)
- 8b559db: Add IO binding support for custom ORTModel (#447) (JingyaHuang, Dec 6, 2022)
- d026fdc: fix import (#553) (fxmarty, Dec 7, 2022)
- 037467d: Update readme (#550) (echarlaix, Dec 7, 2022)
- f6eb417: Refactor of 2 functions used in ORTModel (#551) (michaelbenayoun, Dec 7, 2022)
- 382077d: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 7, 2022)
- 422f3d7: Update tests/bettertransformer/test_bettertransformer_vision.py (katiele47, Dec 7, 2022)
- fac2694: Update readme (#556) (echarlaix, Dec 7, 2022)
- 08d7917: applied make style (Dec 7, 2022)
- 6063fc4: Fix ORTTrainer wrapper duplication / PyTorch evaluate / update with t… (JingyaHuang, Dec 7, 2022)
- d169fc3: fix test (younesbelkada, Dec 8, 2022)
- 1588a2e: Add CLIP BetterTransformer (#534) (fxmarty, Dec 8, 2022)
- 86375c4: Fix flaky BetterTransformer test (#564) (fxmarty, Dec 8, 2022)
- 0970ec4: Support decoder generated with `--for-ort` from `optimum.exporters.on… (fxmarty, Dec 8, 2022)
- 5c90cf1: enable FP16Optimizer for fp16 deepspeed training. (#547) (AdamLouly, Dec 8, 2022)
- 1521f1b: fixed merge conflict due to rebase with upstream main (Dec 9, 2022)
- f48da36: merge conflict clip and flava (Dec 9, 2022)
- 0b8ed50: installed missing dependencies (Dec 9, 2022)
- a012ec7: applied make style (Dec 9, 2022)

231 changes: 122 additions & 109 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/_toctree.yml
@@ -114,7 +114,7 @@
- local: bettertransformer/tutorials/contribute
title: How to add support for new architectures?
title: Tutorials
title: BetterTransformer integration
title: BetterTransformer
isExpanded: false
- sections:
- local: utils/dummy_input_generators
4 changes: 4 additions & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -28,11 +28,13 @@ The list of supported model below:
- [BERT](https://arxiv.org/abs/1810.04805)
- [BERT-generation](https://arxiv.org/abs/1907.12461)
- [CamemBERT](https://arxiv.org/abs/1911.03894)
- [CLIP](https://arxiv.org/abs/2103.00020)
- [Data2VecText](https://arxiv.org/abs/2202.03555)
- [DistilBert](https://arxiv.org/abs/1910.01108)
- [DeiT](https://arxiv.org/abs/2012.12877)
- [Electra](https://arxiv.org/abs/2003.10555)
- [Ernie](https://arxiv.org/abs/1904.09223)
- [Flava](https://arxiv.org/abs/2112.04482)
- [FSMT](https://arxiv.org/abs/1907.06616)
- [HuBERT](https://arxiv.org/pdf/2106.07447.pdf)
- [LayoutLM](https://arxiv.org/abs/1912.13318)
@@ -52,6 +54,8 @@ The list of supported model below:
- [XLMRoberta](https://arxiv.org/abs/1911.02116)
- [YOLOS](https://arxiv.org/abs/2106.00666)

Note that for encoder-decoder models, only the encoder part is supported by PyTorch's BetterTransformer for now.

Let us know by opening an issue in 🤗 Optimum if you want more models to be supported, or check out the contribution guideline if you want to add it by yourself!

### Quick usage
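For reference, the quick-usage flow this section documents boils down to a one-line conversion. A minimal sketch, assuming `bert-base-uncased` as an example checkpoint (any architecture from the list above works the same way):

```python
from transformers import AutoModel

from optimum.bettertransformer import BetterTransformer

# "bert-base-uncased" is only an assumed example checkpoint.
model = AutoModel.from_pretrained("bert-base-uncased")
# keep_original_model=True returns a converted copy and leaves `model` untouched.
better_model = BetterTransformer.transform(model, keep_original_model=True)
```
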
4 changes: 4 additions & 0 deletions docs/source/onnxruntime/package_reference/modeling_ort.mdx
@@ -20,6 +20,10 @@ specific language governing permissions and limitations under the License.

[[autodoc]] onnxruntime.ORTModelForCausalLM

## ORTModelForCustomTasks

[[autodoc]] onnxruntime.ORTModelForCustomTasks

## ORTModelForFeatureExtraction

[[autodoc]] onnxruntime.ORTModelForFeatureExtraction
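The `ORTModelForCustomTasks` reference added above documents the class for arbitrary ONNX graphs. A minimal loading sketch, assuming a local directory (or Hub repo) that contains a `model.onnx` plus tokenizer files; the path is a placeholder:

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCustomTasks

# "path/to/onnx-model" is a placeholder for any directory or repo holding model.onnx and tokenizer files.
model = ORTModelForCustomTasks.from_pretrained("path/to/onnx-model")
tokenizer = AutoTokenizer.from_pretrained("path/to/onnx-model")

inputs = tokenizer("BetterTransformer support for FlavaModel", return_tensors="pt")
outputs = model(**inputs)  # output keys mirror whatever the ONNX graph exposes (e.g. pooled embeddings)
```
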
10 changes: 10 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -17,7 +17,9 @@
AlbertLayerBetterTransformer,
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
CLIPLayerBetterTransformer,
DistilBertLayerBetterTransformer,
FlavaLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
ViltLayerBetterTransformer,
@@ -75,6 +77,14 @@
# FSMTModel:
"EncoderLayer": FSMTEncoderLayerBetterTransformer,
"ViltLayer": ViltLayerBetterTransformer,
# Flava:
"FlavaLayer": FlavaLayerBetterTransformer,
# CLIP
"CLIPEncoderLayer": CLIPLayerBetterTransformer,
}

EXCLUDE_FROM_TRANSFORM = {
"clip": ["text_model"], # text model uses causal attention, that is most likely not supported in BetterTransformer
}


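For orientation, the mapping above registers the new layer classes under the name of the module they replace, while `EXCLUDE_FROM_TRANSFORM` blacklists submodules by model type (CLIP's causally-masked text tower). An illustrative sketch of how such a registry could be consumed during conversion; this is not the actual `BetterTransformer.transform` implementation, just the idea:

```python
import torch.nn as nn


def convert_layers(model: nn.Module, mapping: dict, exclude: dict, config) -> nn.Module:
    """Recursively swap layers whose class name is registered in `mapping` (illustrative only)."""
    excluded = exclude.get(getattr(config, "model_type", ""), [])
    for name, module in model.named_children():
        if name in excluded:
            continue  # e.g. skip "text_model" for CLIP, which relies on a causal attention mask
        if module.__class__.__name__ in mapping:
            # The BetterTransformer layer classes take (original_layer, config).
            setattr(model, name, mapping[module.__class__.__name__](module, config))
        else:
            convert_layers(module, mapping, exclude, config)
    return model
```
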
17 changes: 16 additions & 1 deletion optimum/bettertransformer/models/base.py
@@ -14,12 +14,18 @@
import torch
import torch.nn as nn

from ...utils import logging


KNOWN_ACTIVATION_ATTRIBUTES = ["hidden_act", "activation", "act_fn", "activation_function"]
KNOWN_POS_EMB_ATTRIBUTES = ["position_embedding_type"]
KNOWN_NUM_LAYERS = ["num_hidden_layers", "num_layers", "encoder_layers", "n_layers"]

SUPPORTED_ACTIVATION_FUNCTIONS = ["gelu", "relu", "gelu_new"]
USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS = ["quick_gelu"]


logger = logging.get_logger(__name__)


class BetterTransformerBaseLayer(nn.Module):
@@ -39,6 +45,10 @@ def __init__(self, config):
self.act_fn = getattr(config, attr)
break

# if act_fn not found in the config, fall back to the private `_get_activation_function` if available
if self.act_fn is None and hasattr(self, "_get_activation_function"):
self.act_fn = self._get_activation_function(config)

# Get pos emb type
for attr in KNOWN_POS_EMB_ATTRIBUTES:
if hasattr(config, attr):
@@ -77,7 +87,12 @@ def validate_bettertransformer(self):
raise ValueError("norm1_eps and norm2_eps must be equal for `BetterTransformer` integration.")

# Check activation function
if self.act_fn not in SUPPORTED_ACTIVATION_FUNCTIONS:
if self.act_fn in USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS:
logger.warning(
f"Overridding {self.act_fn} activation with gelu. Use the transformed model at your own risk, the output logits could be significantly different."
)
self.act_fn = "gelu"
elif self.act_fn not in SUPPORTED_ACTIVATION_FUNCTIONS:
raise ValueError(
f"Activation function {self.act_fn} not supported" " for `BetterTransformer` integration."
)
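In practice, the `USE_AT_OWN_RISK_ACTIVATION_FUNCTIONS` branch above is what lets CLIP checkpoints (which use `quick_gelu`) pass validation with a warning instead of raising. A small sketch of the expected behavior; the checkpoint name is an assumption:

```python
from transformers import CLIPModel

from optimum.bettertransformer import BetterTransformer

# "openai/clip-vit-base-patch32" is an assumed checkpoint whose config uses hidden_act="quick_gelu".
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = BetterTransformer.transform(model)
# Expect a warning along the lines of "Overriding quick_gelu activation with gelu";
# the converted vision tower approximates quick_gelu with gelu, so logits can differ slightly.
```
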
209 changes: 209 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from .base import BetterTransformerBaseLayer


if TYPE_CHECKING:
from transformers import PretrainedConfig


class AlbertLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, albert_layer, config):
r"""
@@ -1095,3 +1101,206 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states, attention_mask)


class CLIPLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, layer, config):
r"""
A simple conversion of the CLIPEncoderLayer to its `BetterTransformer` implementation.
**The implementation is valid only for the vision model, which does not use `causal_attention_mask`.**
Args:
layer (`torch.nn.Module`):
The original `CLIPEncoderLayer` where the weights need to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
layer.self_attn.q_proj.weight,
layer.self_attn.k_proj.weight,
layer.self_attn.v_proj.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
layer.self_attn.q_proj.bias,
layer.self_attn.k_proj.bias,
layer.self_attn.v_proj.bias,
]
)
)

# Out proj layer
self.out_proj_weight = layer.self_attn.out_proj.weight
self.out_proj_bias = layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = layer.mlp.fc1.weight
self.linear1_bias = layer.mlp.fc1.bias

# Linear layer 2
self.linear2_weight = layer.mlp.fc2.weight
self.linear2_bias = layer.mlp.fc2.bias

# Layer norm 1
self.norm1_eps = layer.layer_norm1.eps
self.norm1_weight = layer.layer_norm1.weight
self.norm1_bias = layer.layer_norm1.bias

# Layer norm 2
self.norm2_eps = layer.layer_norm2.eps
self.norm2_weight = layer.layer_norm2.weight
self.norm2_bias = layer.layer_norm2.bias

# Model hyper parameters
self.num_heads = layer.self_attn.num_heads
self.embed_dim = layer.self_attn.embed_dim

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

# we expect attention_mask to be None in the vision model
if attention_mask is not None:
raise ValueError(
"Please do not use attention masks when using `BetterTransformer` converted vision models"
)

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)

return (hidden_states,)

def _get_activation_function(self, config: "PretrainedConfig"):
if hasattr(config, "vision_config") and hasattr(config, "text_config"):
assert config.vision_config.hidden_act == config.text_config.hidden_act
return config.vision_config.hidden_act
else:
return config.hidden_act


class FlavaLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, flava_layer, config):
r"""
A simple conversion of the FlavaLayer to its `BetterTransformer` implementation.

Args:
flava_layer (`torch.nn.Module`):
The original `FlavaLayer` where the weights need to be retrieved.
"""
super().__init__(config.image_config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.weight,
flava_layer.attention.attention.key.weight,
flava_layer.attention.attention.value.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
flava_layer.attention.attention.query.bias,
flava_layer.attention.attention.key.bias,
flava_layer.attention.attention.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = flava_layer.attention.output.dense.weight
self.out_proj_bias = flava_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = flava_layer.intermediate.dense.weight
self.linear1_bias = flava_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = flava_layer.output.dense.weight
self.linear2_bias = flava_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = flava_layer.layernorm_before.eps
self.norm1_weight = flava_layer.layernorm_before.weight
self.norm1_bias = flava_layer.layernorm_before.bias

# Layer norm 2
self.norm2_eps = flava_layer.layernorm_after.eps
self.norm2_weight = flava_layer.layernorm_after.weight
self.norm2_bias = flava_layer.layernorm_after.bias

# Model hyper parameters
self.num_heads = flava_layer.attention.attention.num_attention_heads
self.embed_dim = int(flava_layer.attention.attention.attention_head_size * self.num_heads)

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True

self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)
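
To close the loop on the FLAVA support added above, a short inference sketch; the checkpoint name and the sample image URL are assumptions, and the processor/model classes come from `transformers`:

```python
import requests
import torch
from PIL import Image
from transformers import FlavaModel, FlavaProcessor

from optimum.bettertransformer import BetterTransformer

# "facebook/flava-full" is an assumed checkpoint with both image and text towers.
processor = FlavaProcessor.from_pretrained("facebook/flava-full")
model = FlavaModel.from_pretrained("facebook/flava-full").eval()
model = BetterTransformer.transform(model)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # assumed sample image
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=["a photo of two cats"], images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs)
# Both towers now run through FlavaLayerBetterTransformer; see e.g. outputs.image_embeddings,
# outputs.text_embeddings and outputs.multimodal_embeddings.
```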