Skip to content

Commit

Permalink
enhance: non maximum suppression
Browse files Browse the repository at this point in the history
  • Loading branch information
arabian9ts committed Feb 6, 2018
1 parent f89d088 commit 6f5ffad
Showing 1 changed file with 69 additions and 5 deletions.
74 changes: 69 additions & 5 deletions model/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,63 @@ def jacc_filter(plabel, ploc):
det_labels.append(plabel)

return det_locs, det_labels


def non_maximum_suppression(self, candidates, overlap_threshold):
"""
this is nms(non maximum_suppression) which filters predicted objects.
Args:
predicted bounding boxes
Returns:
detected bounding boxes and its label
"""

label = candidates[:,4]
boxes = candidates[label<classes]

if len(boxes) == 0:
return []

picked = []

x1 = boxes[:,0]
y1 = boxes[:,1]
x2 = boxes[:,2] + x1
y2 = boxes[:,3] + y1

area = (x2 - x1 + 1) * (y2 - y1 + 1)
idxs = np.argsort(y2)

while len(idxs) > 0:
last = len(idxs) - 1
i = idxs[last]
picked.append(i)
suppress = [last]

for pos in range(0, last):
j = idxs[pos]

# extract smallest and largest bouding boxes
xx1 = max(x1[i], x1[j])
yy1 = max(y1[i], y1[j])
xx2 = min(x2[i], x2[j])
yy2 = min(y2[i], y2[j])

w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)

# overlap of current box and those in area list
overlap = float(w * h) / area[j]

# suppress current box
if overlap > overlap_threshold:
suppress.append(pos)

# delete suppressed indexes
idxs = np.delete(idxs, suppress)

return boxes[picked]


def detect_objects(self, pred_confs, pred_locs):
Expand All @@ -219,9 +276,16 @@ def detect_objects(self, pred_confs, pred_locs):
# extract top 200 by confidence
possibilities = [np.amax(np.exp(conf)) / (np.sum(np.exp(conf)) + 1e-3) for conf in pred_confs[0]]
indicies = np.argpartition(possibilities, -200)[-200:]
top200 = np.asarray(possibilities)[indicies]
slicer = indicies[0.7 < top200]

locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])
# top200 = np.asarray(possibilities)[indicies]
# slicer = indicies[0.7 < top200]
# locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])

locations, labels = pred_locs[0][indicies], np.argmax(pred_confs[0][indicies], axis=1)
labels = np.asarray(labels).reshape(len(labels), 1)
with_labels = np.concatenate((locations, labels), axis=1)

return locations, labels
# labels, locations = image.non_max_suppression(boxes, possibilities, 10)
filtered = self.non_maximum_suppression(with_labels, 0.5)
# locations, labels = pred_confs[0][indices], pred_locs[0][indices]

return filtered[:,:4], filtered[:,4]

0 comments on commit 6f5ffad

Please sign in to comment.