Skip to content

Commit

Permalink
Merge pull request arabian9ts#17 from arabian9ts/fix-model
Browse files Browse the repository at this point in the history
Fix-model
  • Loading branch information
arabian9ts authored Feb 28, 2018
2 parents 7e06587 + 42b6158 commit 8065904
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 30 deletions.
20 changes: 6 additions & 14 deletions matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,17 @@ class label loss is evaled by loss_conf
near_jacc = 0.
near_index = None
for i in range(len(matches)):
jacc = jaccard(center2corner(gt_box), center2corner(self.default_boxes[i]))
jacc = jaccard(gt_box, self.default_boxes[i])
if 0.5 <= jacc:
matches[i] = Box(gt_box, gt_label)
pos += 1
matched.append(gt_label)
else:
if near_jacc < jacc:
near_miss = jacc
near_index = i

# prevent pos from becoming 0 <=> loss_loc is 0
# force to match most near box to ground truth box
if 0 == len(matched) and near_index is not None and matches[near_index] is None:
matches[near_index] = Box(gt_box, gt_label)
pos += 1

indicies = self.extract_highest_indicies(pred_confs, pos*5)

neg_pos = 5
indicies = self.extract_highest_indicies(pred_confs, pos*neg_pos)
for i in indicies:
if neg > pos*neg_pos:
break
if matches[i] is None and classes-1 != np.argmax(pred_confs[i]):
matches[i] = Box([], classes-1)
neg += 1
Expand Down Expand Up @@ -152,4 +144,4 @@ class label loss is evaled by loss_conf
expanded_gt_locs.append(box.loc)


return pos_list, neg_list, expanded_gt_labels, expanded_gt_locs
return pos_list, neg_list, expanded_gt_labels, expanded_gt_locs
10 changes: 5 additions & 5 deletions model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,18 @@ def build(self, input, is_training=True):

self.conv4_1 = convolution(self.pool3, 'conv4_1')
self.conv4_2 = convolution(self.conv4_1, 'conv4_2')
self.conv4_3 = convolution(self.conv4_2, 'conv4_3')
self.pool4 = pooling(self.conv4_3, 'pool4')
# self.conv4_3 = convolution(self.conv4_2, 'conv4_3')
# self.pool4 = pooling(self.conv4_3, 'pool4')

self.conv5_1 = convolution(self.pool4, 'conv5_1')
self.conv5_2 = convolution(self.conv5_1, 'conv5_2')
# self.conv5_1 = convolution(self.pool4, 'conv5_1')
# self.conv5_2 = convolution(self.conv5_1, 'conv5_2')
# self.conv5_3 = convolution(self.conv5_2, 'conv5_3')
# self.pool5 = self.pooling(self.conv5_3, 'pool5')

# self.fc6 = self.fully_connection(self.pool5, Activation.relu, 'fc6')
# self.fc7 = self.fully_connection(self.fc6, Activation.relu, 'fc7')
# self.fc8 = self.fully_connection(self.fc7, Activation.softmax, 'fc8')

self.prob = self.conv5_2
self.prob = self.conv4_2

return self.prob
26 changes: 15 additions & 11 deletions model/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def build(self, input, is_training=True):
self.base = super().build(input, is_training)

self.conv6 = convolution(self.base, 'conv6')
self.conv7 = convolution(self.conv6, 'conv7')
self.pool6 = pooling(self.conv6, 'pool6')
self.conv7 = convolution(self.pool6, 'conv7')

self.conv8_1 = convolution(self.conv7, 'conv8_1')
self.conv8_2 = convolution(self.conv8_1, 'conv8_2', ksize=3, stride=2)
Expand Down Expand Up @@ -220,8 +221,8 @@ def non_maximum_suppression(self, candidates, overlap_threshold):
x2 = boxes[:,2] + x1
y2 = boxes[:,3] + y1

area = (x2 - x1 + 1) * (y2 - y1 + 1)
idxs = np.argsort(y2)
area = (boxes[:,2]) * (boxes[:,3])
idxs = np.argsort(x1)

while len(idxs) > 0:
last = len(idxs) - 1
Expand All @@ -238,8 +239,8 @@ def non_maximum_suppression(self, candidates, overlap_threshold):
xx2 = min(x2[i], x2[j])
yy2 = min(y2[i], y2[j])

w = max(0, xx2 - xx1 + 1)
h = max(0, yy2 - yy1 + 1)
w = max(0, xx2 - xx1)
h = max(0, yy2 - yy1)

# overlap of current box and those in area list
overlap = float(w * h) / area[j]
Expand Down Expand Up @@ -271,21 +272,24 @@ def detect_objects(self, pred_confs, pred_locs):
hist = [0 for _ in range(classes)]
for conf, loc in zip(pred_confs[0], pred_locs[0]):
hist[np.argmax(conf)] += 1
print(hist)
# print(hist)

# extract top 200 by confidence
possibilities = [np.amax(np.exp(conf)) / (np.sum(np.exp(conf)) + 1e-3) for conf in pred_confs[0]]
indicies = np.argpartition(possibilities, -200)[-200:]
# top200 = np.asarray(possibilities)[indicies]
# slicer = indicies[0.7 < top200]
# locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])
top200 = np.asarray(possibilities)[indicies]
slicer = indicies[0.9 < top200]
locations, labels = self._filter(pred_confs[0][slicer], pred_locs[0][slicer])

locations, labels = pred_locs[0][indicies], np.argmax(pred_confs[0][indicies], axis=1)
locations, labels = pred_locs[0][slicer], np.argmax(pred_confs[0][slicer], axis=1)
labels = np.asarray(labels).reshape(len(labels), 1)
with_labels = np.concatenate((locations, labels), axis=1)

# labels, locations = image.non_max_suppression(boxes, possibilities, 10)
filtered = self.non_maximum_suppression(with_labels, 0.5)
filtered = self.non_maximum_suppression(with_labels, 0.1)
# locations, labels = pred_confs[0][indices], pred_locs[0][indices]
if len(filtered) == 0:
filtered = np.zeros((4, 5))

return filtered[:,:4], filtered[:,4]
#return locations, labels
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def draw_marker(image_name, save):
img = cv2.imread('./voc2007/'+image_name, 1)
h = img.shape[0]
w = img.shape[1]
fontType = cv2.FONT_HERSHEY_SIMPLEX
reshaped = cv2.resize(img, (300, 300))
reshaped = reshaped / 255
pred_confs, pred_locs = ssd.eval(images=[reshaped], actual_data=None, is_training=False)
Expand All @@ -81,6 +82,7 @@ def draw_marker(image_name, save):
loc = center2corner(loc)
loc = convert2diagonal_points(loc)
cv2.rectangle(img, (int(loc[0]*w), int(loc[1]*h)), (int(loc[2]*w), int(loc[3]*h)), (0, 0, 255), 1)
cv2.putText(img, str(int(label)), (int(loc[0]*w), int(loc[1]*h)), fontType, 0.7, (0, 0, 255), 1)

if save:
if not os.path.exists('./evaluated'):
Expand Down

0 comments on commit 8065904

Please sign in to comment.