From a2ab86d17dad8d46fb0235ae29bbab725328552b Mon Sep 17 00:00:00 2001 From: Issam Arabi Date: Sun, 1 Oct 2023 17:28:01 -0400 Subject: [PATCH] make style Signed-off-by: Issam Arabi --- .../models/encoder_models.py | 20 ++++++------------- tests/bettertransformer/test_vision.py | 15 ++++++++++++-- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py index b8839c6a916..0151f67d55e 100644 --- a/optimum/bettertransformer/models/encoder_models.py +++ b/optimum/bettertransformer/models/encoder_models.py @@ -1251,16 +1251,8 @@ def __init__(self, detr_layer, config): self.norm_first = True self.original_layers_mapping = { - "in_proj_weight": [ - "self_attn.q_proj.weight", - "self_attn.k_proj.weight", - "self_attn.v_proj.weight" - ], - "in_proj_bias": [ - "self_attn.q_proj.bias", - "self_attn.k_proj.bias", - "self_attn.v_proj.bias" - ], + "in_proj_weight": ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], + "in_proj_bias": ["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], "out_proj_weight": "self_attn.out_proj.weight", "out_proj_bias": "self_attn.out_proj.bias", "linear1_weight": "fc1.weight", @@ -1272,7 +1264,7 @@ def __init__(self, detr_layer, config): "norm2_weight": "final_layer_norm.weight", "norm2_bias": "final_layer_norm.bias", } - + self.validate_bettertransformer() def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, **__): @@ -1303,15 +1295,15 @@ def forward(self, hidden_states, attention_mask, output_attentions: bool, *_, ** self.linear2_bias, attention_mask, ) - + if hidden_states.is_nested and self.is_last_layer: hidden_states = hidden_states.to_padded_tensor(0.0) - + else: raise NotImplementedError( "Training and Autocast are not implemented for BetterTransformer + Detr. Please open an issue." ) - + return (hidden_states,) diff --git a/tests/bettertransformer/test_vision.py b/tests/bettertransformer/test_vision.py index 176ff13329b..8c61ea09ba2 100644 --- a/tests/bettertransformer/test_vision.py +++ b/tests/bettertransformer/test_vision.py @@ -27,7 +27,18 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas r""" Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin` """ - SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "detr", "vilt", "vit", "vit_mae", "vit_msn", "yolos"] + SUPPORTED_ARCH = [ + "blip-2", + "clip", + "clip_text_model", + "deit", + "detr", + "vilt", + "vit", + "vit_mae", + "vit_msn", + "yolos", + ] def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs): if model_type == "vilt": @@ -56,7 +67,7 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc if model_type == "blip-2": inputs["decoder_input_ids"] = inputs["input_ids"] - + elif model_type == "detr": # Assuming detr just needs an image url = "http://images.cocodataset.org/val2017/000000039769.jpg"