Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ASTLayer support for BetterTransformer #548

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ You can now use this feature in 🤗 Optimum together with Transformers and use
The list of supported model below:

- [AlBERT](https://arxiv.org/abs/1909.11942)
- [ASTLayer](https://arxiv.org/abs/2104.01778)
- [BART](https://arxiv.org/abs/1910.13461)
- [BERT](https://arxiv.org/abs/1810.04805)
- [BERT-generation](https://arxiv.org/abs/1907.12461)
Expand Down
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .encoder_models import (
AlbertLayerBetterTransformer,
ASTLayerBetterTransformer,
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
CLIPLayerBetterTransformer,
Expand All @@ -30,6 +31,7 @@

class BetterTransformerManager:
MODEL_MAPPING = {
"audio-spectrogram-transformer": ("ASTLayer", ASTLayerBetterTransformer),
"albert": ("AlbertLayer", AlbertLayerBetterTransformer),
"bart": ("BartEncoderLayer", BartEncoderLayerBetterTransformer),
"bert": ("BertLayer", BertLayerBetterTransformer),
Expand Down
94 changes: 94 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,100 @@ def forward(self, hidden_states, attention_mask, *_):
return (hidden_states,)


class ASTLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, ast_layer, config):
r"""
A simple conversion of the `ASTLayer` to its `BetterTransformer` implementation.
Args:
ast_layer (`torch.nn.Module`):
The original `ASTLayer` where the weights needs to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
ast_layer.attention.attention.query.weight,
ast_layer.attention.attention.key.weight,
ast_layer.attention.attention.value.weight,
]
)
)

self.in_proj_bias = nn.Parameter(
torch.cat(
[
ast_layer.attention.attention.query.bias,
ast_layer.attention.attention.key.bias,
ast_layer.attention.attention.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = ast_layer.attention.output.dense.weight
self.out_proj_bias = ast_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = ast_layer.intermediate.dense.weight
self.linear1_bias = ast_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = ast_layer.output.dense.weight
self.linear2_bias = ast_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = ast_layer.layernorm_before.eps
self.norm1_weight = ast_layer.layernorm_before.weight
self.norm1_bias = ast_layer.layernorm_before.bias

# Layer norm 2
self.norm2_eps = ast_layer.layernorm_after.eps
self.norm2_weight = ast_layer.layernorm_after.weight
self.norm2_bias = ast_layer.layernorm_after.bias

# Model hyper parameters
self.num_heads = ast_layer.attention.attention.num_attention_heads
self.embed_dim = int(ast_layer.attention.attention.attention_head_size * self.num_heads)

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False
self.norm_first = True
self.validate_bettertransformer()

def forward(self, hidden_states, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()
attention_mask = None
hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)


class BertLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, bert_layer, config):
r"""
Expand Down
22 changes: 22 additions & 0 deletions tests/bettertransformer/test_bettertransformer_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
"ybelkada/hubert-tiny-random",
]

AST_TO_TEST = [
"Ericwang/tiny-random-ast",
]


class BetterTransformersWhisperTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Expand Down Expand Up @@ -56,6 +60,24 @@ def prepare_inputs_for_class(self, model_id):
return input_dict


class BetterTransformersASTTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Testing suite for AST - tests all the tests defined in `BetterTransformersTestMixin`
Since `AST` uses slightly different preprocessor than other audio models, it is preferrable
to define its own testing class.
"""
all_models_to_test = AST_TO_TEST

def prepare_inputs_for_class(self, model_id):
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

input_dict = feature_extractor(input_features, return_tensors="pt", padding=True)
return input_dict


class BetterTransformersAudioTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Testing suite for Audio models - tests all the tests defined in `BetterTransformersTestMixin`
Expand Down