Skip to content

Commit

Permalink
Merge pull request arabian9ts#18 from arabian9ts/nms
Browse files Browse the repository at this point in the history
Non Maximum Suppression
  • Loading branch information
arabian9ts authored Feb 28, 2018
2 parents 52e3d1f + 0572464 commit 7e06587
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
2 changes: 1 addition & 1 deletion matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class label loss is evaled by loss_conf
matches[near_index] = Box(gt_box, gt_label)
pos += 1

indicies = self.extract_highest_indicies(pred_confs, pos*3)
indicies = self.extract_highest_indicies(pred_confs, pos*5)

for i in indicies:
if matches[i] is None and classes-1 != np.argmax(pred_confs[i]):
Expand Down
9 changes: 5 additions & 4 deletions model/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
author: arabian9ts
"""

import numpy as np

def intersection(rect1, rect2):
"""
Expand Down Expand Up @@ -49,7 +50,7 @@ def corner2center(rect):
center_x = (2 * rect[0] + rect[2]) * 0.5
center_y = (2 * rect[1] + rect[3]) * 0.5

return [center_x, center_y, abs(rect[2]), abs(rect[3])]
return np.array([center_x, center_y, abs(rect[2]), abs(rect[3])])


def center2corner(rect):
Expand All @@ -59,7 +60,7 @@ def center2corner(rect):
corner_x = rect[0] - rect[2] * 0.5
corner_y = rect[1] - rect[3] * 0.5

return [corner_x, corner_y, abs(rect[2]), abs(rect[3])]
return np.array([corner_x, corner_y, abs(rect[2]), abs(rect[3])])


def convert2diagonal_points(rect):
Expand All @@ -73,7 +74,7 @@ def convert2diagonal_points(rect):
output format is...
[ top_left_x, top_left_y, bottom_right_x, bottom_right_y ]
"""
return [rect[0], rect[1], rect[0]+rect[2], rect[1]+rect[3]]
return np.array([rect[0], rect[1], rect[0]+rect[2], rect[1]+rect[3]])


def convert2wh(rect):
Expand All @@ -87,4 +88,4 @@ def convert2wh(rect):
output format is...
[ top_left_x, top_left_y, width, height ]
"""
return [rect[0], rect[1], rect[2]-rect[0], rect[3]-rect[1]]
return np.array([rect[0], rect[1], rect[2]-rect[0], rect[3]-rect[1]])
2 changes: 1 addition & 1 deletion model/default_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def generate_boxes(fmap_shapes):
for i, ratio in enumerate(ratios):
s = s_k

if 0 == i:
if 1.0 == s:
s = np.sqrt(s_k*s_k1)

box_width = s * np.sqrt(ratio)
Expand Down
8 changes: 4 additions & 4 deletions model/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
classes = 21

# the number of boxes per feature map
boxes = [6, 3, 6, 3, 6, 3,]
boxes = [4, 6, 6, 6, 6, 6,]

# default box ratios
# each length should be matches boxes[index]
box_ratios = [
[1.0, 1.0, 2.0, 1.0/2.0],
[1.0, 1.0, 2.0, 1.0/2.0, 3.0, 1.0/3.0],
[1.0, 1.0, 2.0, 1.0/2.0, 3.0, 1.0/3.0],
[1.0, 1.0, 2.0, 1.0/2.0, 3.0, 1.0/3.0],
[1.0, 2.0, 1.0/2.0],
[1.0, 1.0, 2.0, 1.0/2.0, 3.0, 1.0/3.0],
[1.0, 2.0, 1.0/2.0],
[1.0, 1.0, 2.0, 1.0/2.0, 3.0, 1.0/3.0],
[1.0, 2.0, 1.0/2.0],
]
74 changes: 69 additions & 5 deletions model/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,63 @@ def jacc_filter(plabel, ploc):
det_labels.append(plabel)

return det_locs, det_labels


def non_maximum_suppression(self, candidates, overlap_threshold):
"""
this is nms(non maximum_suppression) which filters predicted objects.
Args:
predicted bounding boxes
Returns:
detected bounding boxes and its label
"""

label = candidates[:,4]
boxes = candidates[label<classes-1]

if len(boxes) == 0:
return []

picked = []

x1 = boxes[:,0]
y1 = boxes[:,1]
x2 = boxes[:,2] + x1
y2 = boxes[:,3] + y1

area = (x2 - x1 + 1) * (y2 - y1 + 1)
idxs = np.argsort(y2)

while len(idxs) > 0:
last = len(idxs) - 1
i = idxs[last]
picked.append(i)
suppress = [last]

for pos in range(0, last):
j = idxs[pos]

# extract smallest and largest bouding boxes
xx1 = max(x1[i], x1[j])
yy1 = max(y1[i], y1[j])
xx2 = min(x2[i], x2[j])
yy2 = min(y2[i], y2[j])

w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)

# overlap of current box and those in area list
overlap = float(w * h) / area[j]

# suppress current box
if overlap > overlap_threshold:
suppress.append(pos)

# delete suppressed indexes
idxs = np.delete(idxs, suppress)

return boxes[picked]


def detect_objects(self, pred_confs, pred_locs):
Expand All @@ -219,9 +276,16 @@ def detect_objects(self, pred_confs, pred_locs):
# extract top 200 by confidence
possibilities = [np.amax(np.exp(conf)) / (np.sum(np.exp(conf)) + 1e-3) for conf in pred_confs[0]]
indicies = np.argpartition(possibilities, -200)[-200:]
top200 = np.asarray(possibilities)[indicies]
slicer = indicies[0.7 < top200]

locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])
# top200 = np.asarray(possibilities)[indicies]
# slicer = indicies[0.7 < top200]
# locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])

locations, labels = pred_locs[0][indicies], np.argmax(pred_confs[0][indicies], axis=1)
labels = np.asarray(labels).reshape(len(labels), 1)
with_labels = np.concatenate((locations, labels), axis=1)

return locations, labels
# labels, locations = image.non_max_suppression(boxes, possibilities, 10)
filtered = self.non_maximum_suppression(with_labels, 0.5)
# locations, labels = pred_confs[0][indices], pred_locs[0][indices]

return filtered[:,:4], filtered[:,4]

0 comments on commit 7e06587

Please sign in to comment.