Commit

🤖 Lint code
ivy-branch committed Oct 15, 2023
1 parent cf97926 commit 4154713
Showing 2 changed files with 49 additions and 49 deletions.
ivy/functional/frontends/torchvision/ops.py (30 changes: 15 additions & 15 deletions)
@@ -3,21 +3,6 @@
from ivy.func_wrapper import with_supported_dtypes, with_unsupported_device_and_dtypes


@to_ivy_arrays_and_back
def nms(boxes, scores, iou_threshold):
    return ivy.nms(boxes, scores, iou_threshold)


@with_supported_dtypes({"2.1.0 and below": ("float32", "float64")}, "torch")
@to_ivy_arrays_and_back
def roi_align(
    input, boxes, output_size, spatial_scale=1.0, sampling_ratio=1, aligned=False
):
    return ivy.roi_align(
        input, boxes, output_size, spatial_scale, sampling_ratio, aligned
    )


@with_unsupported_device_and_dtypes(
    {
        "2.1.0 and below": {
@@ -33,3 +18,18 @@ def clip_boxes_to_image(boxes, size):
    boxes_y = boxes[..., 1::2].clip(0, height)
    clipped_boxes = ivy.stack([boxes_x, boxes_y], axis=-1)
    return clipped_boxes.reshape(boxes.shape).astype(boxes.dtype)


@to_ivy_arrays_and_back
def nms(boxes, scores, iou_threshold):
    return ivy.nms(boxes, scores, iou_threshold)


@with_supported_dtypes({"2.1.0 and below": ("float32", "float64")}, "torch")
@to_ivy_arrays_and_back
def roi_align(
    input, boxes, output_size, spatial_scale=1.0, sampling_ratio=1, aligned=False
):
    return ivy.roi_align(
        input, boxes, output_size, spatial_scale, sampling_ratio, aligned
    )
ivy_tests/test_ivy/test_frontends/test_torchvision/test_ops.py (68 changes: 34 additions & 34 deletions)
@@ -82,6 +82,40 @@ def _roi_align_helper(draw):
# ------------ #


@handle_frontend_test(
    fn_tree="torchvision.ops.clip_boxes_to_image",
    boxes=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("valid"),
        shape=st.tuples(helpers.ints(min_value=1, max_value=5), st.just(4)),
    ),
    size=st.tuples(
        helpers.ints(min_value=1, max_value=256),
        helpers.ints(min_value=1, max_value=256),
    ),
)
def test_torchvision_clip_boxes_to_image(
    *,
    boxes,
    size,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
):
    dtype, boxes = boxes
    helpers.test_frontend_function(
        input_dtypes=dtype,
        backend_to_test=backend_fw,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        boxes=boxes[0],
        size=size,
    )


# nms
@handle_frontend_test(
fn_tree="torchvision.ops.nms",
@@ -143,37 +177,3 @@ def test_torchvision_roi_align(
        rtol=1e-5,
        atol=1e-5,
    )


@handle_frontend_test(
    fn_tree="torchvision.ops.clip_boxes_to_image",
    boxes=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("valid"),
        shape=st.tuples(helpers.ints(min_value=1, max_value=5),
                        st.just(4))
    ),
    size=st.tuples(
        helpers.ints(min_value=1, max_value=256),
        helpers.ints(min_value=1, max_value=256)),
)
def test_torchvision_clip_boxes_to_image(
    *,
    boxes,
    size,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
):
    dtype, boxes = boxes
    helpers.test_frontend_function(
        input_dtypes=dtype,
        backend_to_test=backend_fw,
        frontend=frontend,
        test_flags=test_flags,
        fn_tree=fn_tree,
        on_device=on_device,
        boxes=boxes[0],
        size=size
    )
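
For context, a minimal usage sketch of the two frontend ops touched by this commit. It is illustrative only: it assumes ivy is installed, that the selected backend implements ivy.nms, and that the frontend module path matches the file shown above; the boxes, scores, and image size are made-up values.

import ivy
from ivy.functional.frontends.torchvision import ops as tv_ops

ivy.set_backend("torch")  # assumed backend; any backend with nms support should work

# Two overlapping boxes in (x1, y1, x2, y2) format, with confidence scores.
boxes = ivy.array([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = ivy.array([0.9, 0.8])

kept = tv_ops.nms(boxes, scores, iou_threshold=0.5)        # indices of boxes kept after NMS
clipped = tv_ops.clip_boxes_to_image(boxes, size=(8, 8))   # clamp coordinates to an 8x8 image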
