Skip to content

Commit

Permalink
Add test for DETR BetterTransformer
Browse files Browse the repository at this point in the history
Signed-off-by: Issam Arabi <[email protected]>
  • Loading branch information
issamarabi committed Oct 1, 2023
1 parent 937bd99 commit c9711b8
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/bettertransformer/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCas
r"""
Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin`
"""
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]
SUPPORTED_ARCH = ["blip-2", "clip", "clip_text_model", "deit", "detr", "vilt", "vit", "vit_mae", "vit_msn", "yolos"]

def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preprocessor_kwargs):
if model_type == "vilt":
Expand Down Expand Up @@ -56,6 +56,14 @@ def prepare_inputs_for_class(self, model_id, model_type, batch_size=3, **preproc

if model_type == "blip-2":
inputs["decoder_input_ids"] = inputs["input_ids"]

elif model_type == "detr":
# Assuming detr just needs an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-DetrModel")
inputs = feature_extractor(images=image, return_tensors="pt")

else:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
Expand Down

0 comments on commit c9711b8

Please sign in to comment.