diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx
index 465ed014934..19b56275b17 100644
--- a/docs/source/bettertransformer/overview.mdx
+++ b/docs/source/bettertransformer/overview.mdx
@@ -24,6 +24,7 @@ You can now use this feature in 🤗 Optimum together with Transformers and use
 The list of supported model below:
 
 - [AlBERT](https://arxiv.org/abs/1909.11942)
+- [AST](https://arxiv.org/abs/2104.01778)
 - [BART](https://arxiv.org/abs/1910.13461)
 - [BERT](https://arxiv.org/abs/1810.04805)
 - [BERT-generation](https://arxiv.org/abs/1907.12461)
diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index f34766a2e43..5df1dffdfe7 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -15,6 +15,7 @@
 from .encoder_models import (
     AlbertLayerBetterTransformer,
+    ASTLayerBetterTransformer,
     BartEncoderLayerBetterTransformer,
     BertLayerBetterTransformer,
     CLIPLayerBetterTransformer,
@@ -30,6 +31,7 @@ class BetterTransformerManager:
     MODEL_MAPPING = {
         "albert": ("AlbertLayer", AlbertLayerBetterTransformer),
+        "audio-spectrogram-transformer": ("ASTLayer", ASTLayerBetterTransformer),
         "bart": ("BartEncoderLayer", BartEncoderLayerBetterTransformer),
         "bert": ("BertLayer", BertLayerBetterTransformer),
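Once the layer class is registered in `MODEL_MAPPING`, the conversion is driven entirely by the existing `BetterTransformer.transform` API. A minimal sketch of how the new mapping is exercised end to end, using the tiny test checkpoint added later in this PR (any AST checkpoint should behave the same):

```python
from transformers import AutoModel

from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("Ericwang/tiny-random-ast")

# `transform` looks up the model type "audio-spectrogram-transformer" in
# `MODEL_MAPPING` and swaps every `ASTLayer` for an `ASTLayerBetterTransformer`.
model = BetterTransformer.transform(model)
```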
+ """ + super().__init__(config) + # In_proj layer + self.in_proj_weight = nn.Parameter( + torch.cat( + [ + ast_layer.attention.attention.query.weight, + ast_layer.attention.attention.key.weight, + ast_layer.attention.attention.value.weight, + ] + ) + ) + + self.in_proj_bias = nn.Parameter( + torch.cat( + [ + ast_layer.attention.attention.query.bias, + ast_layer.attention.attention.key.bias, + ast_layer.attention.attention.value.bias, + ] + ) + ) + + # Out proj layer + self.out_proj_weight = ast_layer.attention.output.dense.weight + self.out_proj_bias = ast_layer.attention.output.dense.bias + + # Linear layer 1 + self.linear1_weight = ast_layer.intermediate.dense.weight + self.linear1_bias = ast_layer.intermediate.dense.bias + + # Linear layer 2 + self.linear2_weight = ast_layer.output.dense.weight + self.linear2_bias = ast_layer.output.dense.bias + + # Layer norm 1 + self.norm1_eps = ast_layer.layernorm_before.eps + self.norm1_weight = ast_layer.layernorm_before.weight + self.norm1_bias = ast_layer.layernorm_before.bias + + # Layer norm 2 + self.norm2_eps = ast_layer.layernorm_after.eps + self.norm2_weight = ast_layer.layernorm_after.weight + self.norm2_bias = ast_layer.layernorm_after.bias + + # Model hyper parameters + self.num_heads = ast_layer.attention.attention.num_attention_heads + self.embed_dim = int(ast_layer.attention.attention.attention_head_size * self.num_heads) + + # Last step: set the last layer to `False` -> this will be set to `True` when converting the model + self.is_last_layer = False + self.norm_first = True + self.validate_bettertransformer() + + def forward(self, hidden_states, *_, **__): + r""" + This is just a wrapper around the forward function proposed in: + https://github.com/huggingface/transformers/pull/19553 + """ + super().forward_checker() + attention_mask = None + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + return (hidden_states,) + + class BertLayerBetterTransformer(BetterTransformerBaseLayer): def __init__(self, bert_layer, config): r""" diff --git a/tests/bettertransformer/test_bettertransformer_audio.py b/tests/bettertransformer/test_bettertransformer_audio.py index 563bd4d8202..143677ce728 100644 --- a/tests/bettertransformer/test_bettertransformer_audio.py +++ b/tests/bettertransformer/test_bettertransformer_audio.py @@ -28,6 +28,10 @@ "ybelkada/hubert-tiny-random", ] +AST_TO_TEST = [ + "Ericwang/tiny-random-ast", +] + class BetterTransformersWhisperTest(BetterTransformersTestMixin, unittest.TestCase): r""" @@ -56,6 +60,24 @@ def prepare_inputs_for_class(self, model_id): return input_dict +class BetterTransformersASTTest(BetterTransformersTestMixin, unittest.TestCase): + r""" + Testing suite for AST - tests all the tests defined in `BetterTransformersTestMixin` + Since `AST` uses slightly different preprocessor than other audio models, it is preferrable + to define its own testing class. 
+ """ + all_models_to_test = AST_TO_TEST + + def prepare_inputs_for_class(self, model_id): + batch_duration_in_seconds = [1, 3, 2, 6] + input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds] + + feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + + input_dict = feature_extractor(input_features, return_tensors="pt", padding=True) + return input_dict + + class BetterTransformersAudioTest(BetterTransformersTestMixin, unittest.TestCase): r""" Testing suite for Audio models - tests all the tests defined in `BetterTransformersTestMixin`