Skip to content

Commit

Permalink
Get input image shape statically in the model construction function
Browse files Browse the repository at this point in the history
Instead of using input.get_shape() fonction
We can pass input image shape to yolo_v3 and yolo_v3_tiny constructor
function.
Since the input placeholder shape is statically defined (to None, size, size, 3)
We can have access to this 'size' when constructing the yolo_v3 or
yolo_v3_tiny models.

This if more efficient for inference.
This commit should not break any anterior codes since it is only adding
1 optional argument to model constructor functions
  • Loading branch information
LucasMahieu committed Oct 23, 2020
1 parent 136fb66 commit d3daf23
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 13 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ To run demo type this in the command line:
1. Use yolov3-spp
6. `--ckpt_file`
1. Output checkpoint file
7. `--size`
1. Input image size
2. convert_weights_pb.py:
1. `--class_names`
1. Path to the class names file
Expand All @@ -51,6 +53,8 @@ To run demo type this in the command line:
1. Use yolov3-spp
6. `--output_graph`
1. Location to write the output .pb graph to
7. `--size`
1. Input image size
3. demo.py
1. `--class_names`
1. Path to the class names file
Expand All @@ -68,3 +72,5 @@ To run demo type this in the command line:
1. Desired iou threshold
8. `--gpu_memory_fraction`
1. Fraction of gpu memory to work with
9. `--size`
1. Input image size
5 changes: 4 additions & 1 deletion convert_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
'spp', False, 'Use SPP version of YOLOv3')
tf.app.flags.DEFINE_string(
'ckpt_file', './saved_model/model.ckpt', 'Chceckpoint file')
tf.app.flags.DEFINE_integer(
'size', 416, 'Input Image size')


def main(argv=None):
Expand All @@ -39,7 +41,8 @@ def main(argv=None):

with tf.variable_scope('detector'):
detections = model(inputs, len(classes),
data_format=FLAGS.data_format)
data_format=FLAGS.data_format,
img_size=[FLAGS.size, FLAGS.size])
load_ops = load_weights(tf.global_variables(
scope='detector'), FLAGS.weights_file)

Expand Down
5 changes: 3 additions & 2 deletions convert_weights_pb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
'tiny', False, 'Use tiny version of YOLOv3')
tf.app.flags.DEFINE_bool(
'spp', False, 'Use SPP version of YOLOv3')

tf.app.flags.DEFINE_integer(
'size', 416, 'Image size')
'size', 416, 'Input image size')



Expand All @@ -42,7 +43,7 @@ def main(argv=None):
inputs = tf.placeholder(tf.float32, [None, FLAGS.size, FLAGS.size, 3], "inputs")

with tf.variable_scope('detector'):
detections = model(inputs, len(classes), data_format=FLAGS.data_format)
detections = model(inputs, len(classes), data_format=FLAGS.data_format, img_size=[FLAGS.size, FLAGS.size])
load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file)

# Sets the output nodes in the current session
Expand Down
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'spp', False, 'Use SPP version of YOLOv3')

tf.app.flags.DEFINE_integer(
'size', 416, 'Image size')
'size', 416, 'Input image size')

tf.app.flags.DEFINE_float(
'conf_threshold', 0.5, 'Confidence threshold')
Expand Down
3 changes: 2 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def get_boxes_and_inputs(model, num_classes, size, data_format):

with tf.variable_scope('detector'):
detections = model(inputs, num_classes,
data_format=data_format)
data_format=data_format,
img_size=[size, size])

boxes = detections_boxes(detections)

Expand Down
8 changes: 3 additions & 5 deletions yolo_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _upsample(inputs, out_shape, data_format='NCHW'):
return inputs


def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False):
def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, with_spp=False, img_size=[416, 416]):
"""
Creates YOLO v3 model.
Expand All @@ -213,8 +213,6 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
:param with_spp: whether or not is using spp layer.
:return:
"""
# it will be needed later on
img_size = inputs.get_shape().as_list()[1:3]

# transpose the inputs to NCHW
if data_format == 'NCHW':
Expand Down Expand Up @@ -277,7 +275,7 @@ def yolo_v3(inputs, num_classes, is_training=False, data_format='NCHW', reuse=Fa
return detections


def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
"""
Creates YOLO v3 with SPP model.
Expand All @@ -289,4 +287,4 @@ def yolo_v3_spp(inputs, num_classes, is_training=False, data_format='NCHW', reus
:param reuse: whether or not the network and its variables should be reused.
:return:
"""
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True)
return yolo_v3(inputs, num_classes, is_training=is_training, data_format=data_format, reuse=reuse, with_spp=True, img_size=img_size)
4 changes: 1 addition & 3 deletions yolo_v3_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
(81, 82), (135, 169), (344, 319)]


def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False):
def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False, img_size=[416, 416]):
"""
Creates YOLO v3 tiny model.
Expand All @@ -27,8 +27,6 @@ def yolo_v3_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reu
:param reuse: whether or not the network and its variables should be reused.
:return:
"""
# it will be needed later on
img_size = inputs.get_shape().as_list()[1:3]

# transpose the inputs to NCHW
if data_format == 'NCHW':
Expand Down

0 comments on commit d3daf23

Please sign in to comment.