From 4154713ef8c76898658fa46f1301bed8e885a99c Mon Sep 17 00:00:00 2001
From: ivy-branch
Date: Sun, 15 Oct 2023 16:56:32 +0000
Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ivy/functional/frontends/torchvision/ops.py  | 30 ++++----
 .../test_torchvision/test_ops.py             | 68 +++++++++----------
 2 files changed, 49 insertions(+), 49 deletions(-)

diff --git a/ivy/functional/frontends/torchvision/ops.py b/ivy/functional/frontends/torchvision/ops.py
index 74f89f5e3b062..051776624adbb 100644
--- a/ivy/functional/frontends/torchvision/ops.py
+++ b/ivy/functional/frontends/torchvision/ops.py
@@ -3,21 +3,6 @@
 from ivy.func_wrapper import with_supported_dtypes, with_unsupported_device_and_dtypes
 
 
-@to_ivy_arrays_and_back
-def nms(boxes, scores, iou_threshold):
-    return ivy.nms(boxes, scores, iou_threshold)
-
-
-@with_supported_dtypes({"2.1.0 and below": ("float32", "float64")}, "torch")
-@to_ivy_arrays_and_back
-def roi_align(
-    input, boxes, output_size, spatial_scale=1.0, sampling_ratio=1, aligned=False
-):
-    return ivy.roi_align(
-        input, boxes, output_size, spatial_scale, sampling_ratio, aligned
-    )
-
-
 @with_unsupported_device_and_dtypes(
     {
         "2.1.0 and below": {
@@ -33,3 +18,18 @@ def clip_boxes_to_image(boxes, size):
     boxes_y = boxes[..., 1::2].clip(0, height)
     clipped_boxes = ivy.stack([boxes_x, boxes_y], axis=-1)
     return clipped_boxes.reshape(boxes.shape).astype(boxes.dtype)
+
+
+@to_ivy_arrays_and_back
+def nms(boxes, scores, iou_threshold):
+    return ivy.nms(boxes, scores, iou_threshold)
+
+
+@with_supported_dtypes({"2.1.0 and below": ("float32", "float64")}, "torch")
+@to_ivy_arrays_and_back
+def roi_align(
+    input, boxes, output_size, spatial_scale=1.0, sampling_ratio=1, aligned=False
+):
+    return ivy.roi_align(
+        input, boxes, output_size, spatial_scale, sampling_ratio, aligned
+    )
diff --git a/ivy_tests/test_ivy/test_frontends/test_torchvision/test_ops.py b/ivy_tests/test_ivy/test_frontends/test_torchvision/test_ops.py
index 752002e763073..ec646cc3084f0 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torchvision/test_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torchvision/test_ops.py
@@ -82,6 +82,40 @@ def _roi_align_helper(draw):
 # ------------ #
 
 
+@handle_frontend_test(
+    fn_tree="torchvision.ops.clip_boxes_to_image",
+    boxes=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("valid"),
+        shape=st.tuples(helpers.ints(min_value=1, max_value=5), st.just(4)),
+    ),
+    size=st.tuples(
+        helpers.ints(min_value=1, max_value=256),
+        helpers.ints(min_value=1, max_value=256),
+    ),
+)
+def test_torchvision_clip_boxes_to_image(
+    *,
+    boxes,
+    size,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+):
+    dtype, boxes = boxes
+    helpers.test_frontend_function(
+        input_dtypes=dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        boxes=boxes[0],
+        size=size,
+    )
+
+
 # nms
 @handle_frontend_test(
     fn_tree="torchvision.ops.nms",
@@ -143,37 +177,3 @@ def test_torchvision_roi_align(
         rtol=1e-5,
         atol=1e-5,
     )
-
-
-@handle_frontend_test(
-    fn_tree="torchvision.ops.clip_boxes_to_image",
-    boxes=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("valid"),
-        shape=st.tuples(helpers.ints(min_value=1, max_value=5),
-                        st.just(4))
-    ),
-    size=st.tuples(
-        helpers.ints(min_value=1, max_value=256),
-        helpers.ints(min_value=1, max_value=256)),
-)
-def test_torchvision_clip_boxes_to_image(
-    *,
-    boxes,
-    size,
-    on_device,
-    fn_tree,
-    frontend,
-    test_flags,
-    backend_fw,
-):
-    dtype, boxes = boxes
-    helpers.test_frontend_function(
-        input_dtypes=dtype,
-        backend_to_test=backend_fw,
-        frontend=frontend,
-        test_flags=test_flags,
-        fn_tree=fn_tree,
-        on_device=on_device,
-        boxes=boxes[0],
-        size=size
-    )