Skip to content

Commit

Permalink
feat: added non_max_suppression to ivy (ivy-llc#23668)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShreyanshBardia authored and iababio committed Sep 27, 2023
1 parent d5d9a61 commit a034740
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 5 deletions.
5 changes: 3 additions & 2 deletions docker/requirement_mappings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
"jax": ["jax[cpu]","dm-haiku", "flax", "jaxlib"],
"numpy": ["numpy"],
"paddle": ["paddlepaddle"],
"mxnet": ["mxnet"]
}
"mxnet": ["mxnet"],
"torch": ["torchvision"]
}
4 changes: 2 additions & 2 deletions docker/requirement_mappings_gpu.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
"jax": ["dm-haiku", "flax"],
"numpy": ["numpy"],
"mxnet": ["mxnet"],
"torch": ["torch-scatter"]
}
"torch": ["torch-scatter", "torchvision"]
}
2 changes: 1 addition & 1 deletion docker/requirement_mappings_multiversion.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
"mxnet"
],
"torch": [
"torch-scatter"
"torch-scatter", "torchvision"
]
}
18 changes: 18 additions & 0 deletions ivy/functional/backends/tensorflow/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,21 @@ def conv_general_transpose(
if data_format == "channel_first":
res = tf.transpose(res, (0, dims + 1, *range(1, dims + 1)))
return res


def nms(
    boxes,
    scores=None,
    iou_threshold=0.5,
    max_output_size=None,
    score_threshold=float("-inf"),
):
    """Run non-maximum suppression via ``tf.image.non_max_suppression``.

    Parameters
    ----------
    boxes
        2-D tensor with four coordinate columns. The columns are reordered
        below from (x1, y1, x2, y2) to the (y1, x1, y2, x2) layout that
        TensorFlow documents for this op.
    scores
        One score per box. When ``None`` every box gets a score of one.
    iou_threshold
        Overlap threshold forwarded to TensorFlow.
    max_output_size
        Maximum number of indices to return. ``None`` means "as many as
        there are boxes".
    score_threshold
        Score threshold forwarded to TensorFlow.

    Returns
    -------
    ret
        Indices of the selected boxes, cast to ``tf.int64``.
    """
    if scores is None:
        # Match the dtype of ``boxes`` (as the other implementations of this
        # function do) so both operands of the TensorFlow op share one type.
        scores = tf.ones(boxes.shape[0], dtype=boxes.dtype)

    # Swap the x/y columns: (x1, y1, x2, y2) -> (y1, x1, y2, x2).
    boxes = tf.gather(boxes, [1, 0, 3, 2], axis=1)

    # Test for ``None`` explicitly: ``max_output_size or len(boxes)`` would
    # silently turn a requested size of 0 into "return everything".
    if max_output_size is None:
        max_output_size = len(boxes)

    ret = tf.image.non_max_suppression(
        boxes, scores, max_output_size, iou_threshold, score_threshold
    )

    return tf.cast(ret, dtype=tf.int64)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# import torchvision

from . import layers
from .layers import *


# Identifier of this module; presumably the name ivy uses to refer to this
# sub-backend — TODO confirm against the sub-backend loading code.
name = "torchvision"

# Empty tuple: nothing is listed as incompatible with this sub-backend.
incompatible_sub_backends = ()
37 changes: 37 additions & 0 deletions ivy/functional/backends/torch/sub_backends/torchvision/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from torchvision.ops import nms as torch_nms
import torch


def nms(
    boxes,
    scores=None,
    iou_threshold=0.5,
    max_output_size=None,
    score_threshold=float("-inf"),
):
    """Run non-maximum suppression with ``torchvision.ops.nms``.

    Parameters
    ----------
    boxes
        Tensor[N, 4] of boxes in (x1, y1, x2, y2) format
        with 0 <= x1 < x2 and 0 <= y1 < y2.
    scores
        Tensor[N] with one score per box. When ``None`` every box gets a
        score of one.
    iou_threshold
        Overlap threshold forwarded to ``torchvision.ops.nms``.
    max_output_size
        Maximum number of indices to return; ``None`` returns all of them.
    score_threshold
        Boxes whose score is not strictly greater than this value are removed
        before suppression. The default of negative infinity disables the
        filter.

    Returns
    -------
    ret
        1-D ``torch.int64`` tensor of indices into the *original* ``boxes``.
    """
    # Positions (in the caller's indexing) of the boxes that survive the score
    # filter; stays ``None`` when no filtering took place.
    kept_positions = None
    # Compare by value: ``is not float("-inf")`` tests object identity against
    # a freshly created float and is therefore always true.
    if score_threshold != float("-inf") and scores is not None:
        keep_idx = scores > score_threshold
        boxes = boxes[keep_idx]
        scores = scores[keep_idx]
        kept_positions = torch.nonzero(keep_idx).flatten()

    if scores is None:
        # Create the default scores next to ``boxes`` so the torchvision op
        # does not receive tensors living on different devices.
        scores = torch.ones(
            (boxes.shape[0],), dtype=boxes.dtype, device=boxes.device
        )

    if len(boxes) < 2:
        # Zero or one box: the result is [] or [0]; the op is not called.
        ret = torch.arange(len(boxes), dtype=torch.int64, device=boxes.device)
    else:
        ret = torch_nms(boxes, scores, iou_threshold)

    if kept_positions is not None and len(ret) > 0:
        # Map indices of the filtered boxes back to the original indexing.
        # Index the tensor directly; wrapping it in ``torch.tensor`` would
        # copy-construct from a tensor and emit a UserWarning.
        ret = kept_positions[ret].to(torch.int64)

    return ret.flatten()[:max_output_size]
83 changes: 83 additions & 0 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,3 +2340,86 @@ def _get_num_padded_values(i, p, n, k, s):
return max(0, left_padding - current_index) + max(
0, current_index + k - n - left_padding
)


# TODO add paddle backend implementation back,
# once paddle.argsort uses a stable algorithm
# https://github.com/PaddlePaddle/Paddle/issues/57508
@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def nms(
    boxes,
    scores=None,
    iou_threshold=0.5,
    max_output_size=None,
    score_threshold=float("-inf"),
):
    """Perform non-maximum suppression on a set of boxes.

    Boxes are visited from the highest score to the lowest. Each visited box
    is kept, and every remaining box whose intersection-over-union with it
    exceeds ``iou_threshold`` is discarded.

    Parameters
    ----------
    boxes
        Array of shape (N, 4); the columns are read as (x1, y1, x2, y2).
    scores
        One score per box. When ``None`` every box gets a score of one.
    iou_threshold
        A box is discarded when its overlap with an already kept box is
        greater than this value.
    max_output_size
        Maximum number of indices to return; ``None`` returns all of them.
    score_threshold
        Boxes whose score is not strictly greater than this value are removed
        before suppression. The default of negative infinity disables the
        filter.

    Returns
    -------
    ret
        1-D ``int64`` array of indices into the original ``boxes``.
    """
    change_id = False
    # Compare by value: ``is not float("-inf")`` tests object identity against
    # a freshly created float and is therefore always true.
    if score_threshold != float("-inf") and scores is not None:
        keep_idx = scores > score_threshold
        boxes = boxes[keep_idx]
        scores = scores[keep_idx]
        change_id = True
        # Original positions of the boxes that survived the score filter.
        nonzero = ivy.nonzero(keep_idx)[0].flatten()
    if scores is None:
        scores = ivy.ones((boxes.shape[0],), dtype=boxes.dtype)

    if len(boxes) < 2:
        # Zero or one box: nothing can be suppressed.
        if len(boxes) == 1:
            ret = ivy.array([0], dtype=ivy.int64)
        else:
            ret = ivy.array([], dtype=ivy.int64)
    else:
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]

        areas = (x2 - x1) * (y2 - y1)
        # Highest score first; the stable sort keeps ties in index order.
        order = ivy.argsort((-1 * scores), stable=True)
        keep = []

        while order.size > 0:
            i = order[0]  # remaining box with the highest score
            keep.append(i)
            rest = order[1:]  # candidates that may be suppressed by box i

            # Corners of the intersection between box i and each candidate.
            xx1 = ivy.maximum(x1[i], x1[rest])
            yy1 = ivy.maximum(y1[i], y1[rest])
            xx2 = ivy.minimum(x2[i], x2[rest])
            yy2 = ivy.minimum(y2[i], y2[rest])

            w = ivy.maximum(0.0, xx2 - xx1)  # overlap width, 0 if disjoint
            h = ivy.maximum(0.0, yy2 - yy1)  # overlap height, 0 if disjoint
            inter = w * h
            ovr = inter / (areas[i] + areas[rest] - inter)
            inds = ivy.nonzero(ovr <= iou_threshold)[0]

            # ``inds`` indexes ``rest``; shift by one to index ``order``.
            order = order[inds + 1]

        ret = ivy.array(keep)

    if len(ret) > 1 and scores is not None:
        # Order by descending score, lower index first among equal scores.
        ret = sorted(
            ret.flatten().tolist(), reverse=True, key=lambda x: (scores[x], -x)
        )
        ret = ivy.array(ret, dtype=ivy.int64).flatten()

    if change_id and len(ret) > 0:
        # Map indices of the filtered boxes back to the original indexing.
        ret = ivy.array(nonzero[ret], dtype=ivy.int64).flatten()

    return ret.flatten()[:max_output_size]


# Wrapper configuration attached to ``nms`` as a function attribute.
# NOTE(review): presumably the "to_add" wrappers are applied, and the
# "to_skip" wrappers dropped, when a backend-specific implementation is used
# in place of the compositional one above — confirm against ivy's wrapper
# handling.
nms.mixed_backend_wrappers = {
    "to_add": (
        "handle_backend_invalid",
        "inputs_to_native_arrays",
        "outputs_to_ivy_arrays",
        "handle_device_shifting",
    ),
    "to_skip": ("inputs_to_ivy_arrays",),
}
55 changes: 55 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,32 @@ def _mha_helper(draw):
)


@st.composite
def _nms_helper(draw):
    """Draw boxes, scores and thresholds for the ``nms`` test.

    Returns a 5-tuple ``(boxes, scores, iou_threshold, max_output_size,
    score_threshold)`` where ``boxes`` and ``scores`` are float32 NumPy
    arrays and each box is stored as (x1, y1, x2, y2).
    """
    width = draw(st.integers(250, 1250))
    height = draw(st.integers(250, 1250))
    box_count = draw(st.integers(5, 50))

    # Keyed by the box corners, so a box drawn twice is stored only once
    # (the later score replaces the earlier one).
    score_by_box = {}
    for _ in range(box_count):
        left = draw(st.integers(0, width - 20))
        box_width = draw(st.integers(5, width - left))
        top = draw(st.integers(0, height - 20))
        box_height = draw(st.integers(5, height - top))

        corners = (left, top, left + box_width, top + box_height)
        score_by_box[corners] = draw(st.floats(0.2, 1))

    iou_threshold = draw(st.floats(0.2, 1))
    max_output_size = draw(st.integers(1, box_count))
    score_threshold = draw(st.floats(0, 1))

    boxes = np.array(list(score_by_box), dtype=np.float32)
    scores = np.array(list(score_by_box.values()), dtype=np.float32)
    return boxes, scores, iou_threshold, max_output_size, score_threshold


# Convolutions #
# -------------#

Expand Down Expand Up @@ -1269,6 +1295,35 @@ def test_multi_head_attention(
)


@handle_test(
    fn_tree="functional.ivy.nms",
    inputs=_nms_helper(),
    test_instance_method=st.just(False),
    test_with_out=st.just(False),
)
def test_nms(
    *,
    inputs,
    test_flags,
    backend_fw,
    fn_name,
    on_device,
):
    """Pass inputs drawn by ``_nms_helper`` to ``helpers.test_function``."""
    # Unpack explicitly so a malformed ``inputs`` tuple fails right here.
    boxes, scores, iou_thr, max_out, score_thr = inputs
    nms_arguments = {
        "boxes": boxes,
        "scores": scores,
        "iou_threshold": iou_thr,
        "max_output_size": max_out,
        "score_threshold": score_thr,
    }
    helpers.test_function(
        input_dtypes=[ivy.float32, ivy.float32],
        test_flags=test_flags,
        backend_to_test=backend_fw,
        fn_name=fn_name,
        on_device=on_device,
        **nms_arguments,
    )


# scaled_dot_product_attention
@handle_test(
fn_tree="functional.ivy.scaled_dot_product_attention",
Expand Down
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ autoflake # for backend generation
snakeviz # for profiling
cryptography
xgboost
torchvision
1 change: 1 addition & 0 deletions requirements/optional_apple_silicon_1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ tensorflow-macos # mod_name=tensorflow_macos
tensorflow-probability # mod_name=tensorflow_probability
torch
paddlepaddle # unpinned , mod_name=paddle
torchvision
1 change: 1 addition & 0 deletions requirements/optional_apple_silicon_gpu_1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ tensorflow-metal
tensorflow-probability # mod_name=tensorflow_probability
torch
paddlepaddle # unpinned , mod_name=paddle
torchvision
1 change: 1 addition & 0 deletions requirements/optional_gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ pandas
pyspark
autoflake # for backend generation
snakeviz # for profiling
torchvision

0 comments on commit a034740

Please sign in to comment.