From 2f0f594aebae39d8a3a93d523217b9f725a1e811 Mon Sep 17 00:00:00 2001 From: Adheesh Juvekar Date: Thu, 30 May 2024 13:09:14 -0400 Subject: [PATCH] remove unnecessary args --- softnms_pytorch.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/softnms_pytorch.py b/softnms_pytorch.py index 26b0bf0..adbcdfb 100644 --- a/softnms_pytorch.py +++ b/softnms_pytorch.py @@ -6,7 +6,7 @@ import torch -def soft_nms_pytorch(dets, box_scores, sigma=0.5, thresh=0.001, cuda=0): +def soft_nms_pytorch(dets, box_scores, sigma=0.5, thresh=0.001): """ Build a pytorch implement of Soft NMS algorithm. # Augments @@ -14,17 +14,13 @@ def soft_nms_pytorch(dets, box_scores, sigma=0.5, thresh=0.001, cuda=0): box_scores: box score tensors sigma: variance of Gaussian function thresh: score thresh - cuda: CUDA flag # Return the index of the selected boxes """ # Indexes concatenate boxes with the last column N = dets.shape[0] - if cuda: - indexes = torch.arange(0, N, dtype=torch.float).cuda().view(N, 1) - else: - indexes = torch.arange(0, N, dtype=torch.float).view(N, 1) + indexes = torch.arange(0, N, dtype=torch.float).view(N, 1).to(dets.device) dets = torch.cat((dets, indexes), dim=1) # The order of boxes coordinate is [y1,x1,y2,x2] @@ -55,7 +51,7 @@ def soft_nms_pytorch(dets, box_scores, sigma=0.5, thresh=0.001, cuda=0): w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) - inter = torch.tensor(w * h).cuda() if cuda else torch.tensor(w * h) + inter = torch.tensor(w * h).to(dets.device) ovr = torch.div(inter, (areas[i] + areas[pos:] - inter)) # Gaussian decay