refactor: divide into train and inference scripts
arabian9ts committed Mar 17, 2018
1 parent c5e2438 commit 9089b69
Showing 3 changed files with 58 additions and 61 deletions.
48 changes: 48 additions & 0 deletions inference.py
@@ -0,0 +1,48 @@
"""
inference script
date: 3/17
author: arabian9ts
"""

import cv2
import sys
import tensorflow as tf
from util.util import *
from model.SSD300 import *

def inference(image_name):
    if image_name is None:
        raise Exception('not specified image name to be drawn')

    fontType = cv2.FONT_HERSHEY_SIMPLEX
    img, w, h, _, = preprocess('./voc2007/'+image_name)
    pred_confs, pred_locs = ssd.infer(images=[img])
    locs, labels = ssd.ssd.detect_objects(pred_confs, pred_locs)
    img = deprocess(img)
    if len(labels) and len(locs):
        for label, loc in zip(labels, locs):
            loc = center2corner(loc)
            loc = convert2diagonal_points(loc)
            cv2.rectangle(img, (int(loc[0]*w), int(loc[1]*h)), (int(loc[2]*w), int(loc[3]*h)), (0, 0, 255), 1)
            cv2.putText(img, str(int(label)), (int(loc[0]*w), int(loc[1]*h)), fontType, 0.7, (0, 0, 255), 1)

    return img


# detect objects on a specified image.
if 2 == len(sys.argv):
    # tensorflow session
    sess = tf.Session()
    ssd = SSD300(sess)
    sess.run(tf.global_variables_initializer())

    # parameter saver
    saver = tf.train.Saver()
    saver.restore(sess, './checkpoints/params.ckpt')
    img = inference(sys.argv[1])
    cv2.imwrite('./evaluated/'+sys.argv[1], img)
    cv2.namedWindow("img", cv2.WINDOW_NORMAL)
    cv2.imshow("img", img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    sys.exit()
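
One note on the new script: it writes the annotated image to ./evaluated/ but never creates that directory (the equivalent check lived in the draw_marker helper that this commit removes from train.py below), and cv2.imwrite will not create it, typically just returning False. A minimal guard that could sit before the cv2.imwrite call, shown as an illustrative addition rather than part of this commit:

import os

# create the output directory if it is missing before saving the annotated image
os.makedirs('./evaluated', exist_ok=True)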
13 changes: 7 additions & 6 deletions model/SSD300.py
@@ -47,12 +47,13 @@ def __init__(self, sess):
        # provides matching method
        self.matcher = Matcher(fmap_shapes, self.dboxes)

-    # evaluate loss
-    def eval(self, images, actual_data, is_training):
-        if not is_training:
-            feature_maps, pred_confs, pred_locs = self.sess.run(self.pred_set, feed_dict={self.input: images})
-            return pred_confs, pred_locs
+    # inference process
+    def infer(self, images):
+        feature_maps, pred_confs, pred_locs = self.sess.run(self.pred_set, feed_dict={self.input: images})
+        return pred_confs, pred_locs

+    # training process
+    def train(self, images, actual_data):
        # ================ RESET / EVAL ================ #
        positives = []
        negatives = []
@@ -96,4 +97,4 @@ def prepare_loss(pred_confs, pred_locs, actual_labels, actual_locs):
        self.sess.run(self.train_step, \
            feed_dict={self.input: images, self.pos: positives, self.neg: negatives, self.gt_labels: ex_gt_labels, self.gt_boxes: ex_gt_boxes})

-        return pred_confs, pred_locs, batch_loc, batch_conf, batch_loss
+        return pred_confs, pred_locs, batch_loc, batch_conf, batch_loss
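
In short, the old eval(images, actual_data, is_training) entry point is split in two: infer(images) runs the forward pass only, and train(images, actual_data) keeps the matching and loss logic. A minimal sketch of the resulting call sites, assuming an SSD300 instance whose session and checkpoint are already set up as in inference.py (the helper names here are illustrative, not from the repository):

def detect(ssd, image):
    # inference path: only predicted confidences and box offsets are needed
    pred_confs, pred_locs = ssd.infer(images=[image])
    return ssd.ssd.detect_objects(pred_confs, pred_locs)

def train_step(ssd, batch_images, batch_annotations):
    # training path: ground-truth annotations are still required
    _, _, batch_loc, batch_conf, batch_loss = ssd.train(batch_images, batch_annotations)
    return batch_loss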
58 changes: 3 additions & 55 deletions train.py
@@ -13,8 +13,6 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

-import cv2
-import sys
import datetime
import tensorflow as tf
import numpy as np
@@ -28,7 +26,7 @@

# ====================== Training Parameters ====================== #
BATCH_SIZE = 10
-EPOCH = 100
+EPOCH = 200
EPOCH_LOSSES = []
SHUFFLED_INDECES = []
# ============================== END ============================== #
@@ -57,58 +55,21 @@ def next_batch():
    for idx in indices:
        # make images mini batch

-        img = load_image('voc2007/'+keys[idx])
+        img, _, _, _, = preprocess('voc2007/'+keys[idx])

        actual_data.append(data[keys[idx]])
        mini_batch.append(img)

    buff.append((mini_batch, actual_data))


-def draw_marker(image_name, save):
-    if image_name is None:
-        return Exception('not specified image name to be drawed')
-
-    img = cv2.imread('./voc2007/'+image_name, 1)
-    h = img.shape[0]
-    w = img.shape[1]
-    fontType = cv2.FONT_HERSHEY_SIMPLEX
-    reshaped = cv2.resize(img, (300, 300))
-    reshaped = reshaped / 255
-    pred_confs, pred_locs = ssd.eval(images=[reshaped], actual_data=None, is_training=False)
-    locs, labels = ssd.ssd.detect_objects(pred_confs, pred_locs)
-    if len(labels) and len(locs):
-        for label, loc in zip(labels, locs):
-            loc = center2corner(loc)
-            loc = convert2diagonal_points(loc)
-            cv2.rectangle(img, (int(loc[0]*w), int(loc[1]*h)), (int(loc[2]*w), int(loc[3]*h)), (0, 0, 255), 1)
-            cv2.putText(img, str(int(label)), (int(loc[0]*w), int(loc[1]*h)), fontType, 0.7, (0, 0, 255), 1)
-
-    if save:
-        if not os.path.exists('./evaluated'):
-            os.mkdir('./evaluated')
-        cv2.imwrite('./evaluated/'+image_name, img)
-
-    return img


# tensorflow session
ssd = SSD300(sess)
sess.run(tf.global_variables_initializer())

# parameter saver
saver = tf.train.Saver()

-# eval and predict object on a specified image.
-if 2 == len(sys.argv):
-    saver.restore(sess, './checkpoints/params.ckpt')
-    img = draw_marker(sys.argv[1], save=False)
-    cv2.namedWindow("img", cv2.WINDOW_NORMAL)
-    cv2.imshow("img", img)
-    cv2.waitKey(0)
-    cv2.destroyAllWindows()
-    sys.exit()

# saver.restore(sess, './checkpoints/params.ckpt')

SHUFFLED_INDECES = list(np.random.permutation(len(keys)))
@@ -124,30 +85,17 @@ def draw_marker(image_name, save):
    for ba in trange(BATCH):
        batch, actual = buff.pop(0)
        threading.Thread(name='load', target=next_batch).start()
-        _, _, batch_loc, batch_conf, batch_loss = ssd.eval(batch, actual, True)
+        _, _, batch_loc, batch_conf, batch_loss = ssd.train(batch, actual)
        BATCH_LOSSES.append(batch_loss)

        # print('BATCH: {0} / EPOCH: {1}, LOSS: {2}'.format(ba+1, ep+1, batch_loss))
    EPOCH_LOSSES.append(np.mean(BATCH_LOSSES))
    print('\n*** AVERAGE: '+str(EPOCH_LOSSES[-1])+' ***')

    saver.save(sess, './checkpoints/params.ckpt')


-    print('\n*** TEST ***')
-    id = np.random.choice(len(keys))
-    name = keys[id]
-    draw_marker(image_name=name, save=True)
-    print('\nSaved Evaled Image')


    print('\n========== EPOCH: '+str(ep+1)+' END ==========')

print('\nEND LEARNING')


saver.save(sess, './params_final.ckpt')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(np.array(range(EPOCH)), EPOCH_LOSSES)
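
A closing note on the training loop above: each iteration pops a mini-batch that was prepared earlier and immediately starts a background thread loading the next one, overlapping preprocessing with the ssd.train step. A stripped-down sketch of that pattern, with names mirroring train.py, an explicit join added for safety, and the actual loader reduced to a placeholder:

import threading

buff = []   # shared buffer of (mini_batch, actual_data) tuples, as in train.py

def next_batch():
    # in train.py this preprocesses BATCH_SIZE images and their annotations
    buff.append(make_minibatch())   # make_minibatch is a placeholder, not from the repository

def run_epoch(ssd, num_batches):
    batch_losses = []
    loader = threading.Thread(name='load', target=next_batch)
    loader.start()                      # prepare the first mini-batch
    for _ in range(num_batches):
        loader.join()                   # wait until the pending mini-batch is ready
        batch, actual = buff.pop(0)
        loader = threading.Thread(name='load', target=next_batch)
        loader.start()                  # load the next batch while this one trains
        _, _, _, _, batch_loss = ssd.train(batch, actual)
        batch_losses.append(batch_loss)
    return batch_losses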
