fix tflite inference (#1036)
* fix tflite inference and add keras tutorial

* update requirements

* update ema support

* add continue training

* add box_net

* add init epoch

* add new features

* format code
fsx950223 authored Dec 15, 2021
1 parent 39c39e5 commit d5afa2a
Showing 18 changed files with 1,171 additions and 416 deletions.
2 changes: 1 addition & 1 deletion efficientdet/README.md
@@ -422,7 +422,7 @@ To visualize training tfrecords with input dataloader use.
```
python dataset/inspect_tfrecords.py --file_pattern dataset/sample.record\
--model_name "efficientdet-d0" --samples 10\
- --save_samples_dir train_samples/ -hparams="label_map={1:'label1'}, autoaugmentation_policy=v3"
+ --save_samples_dir train_samples/ --hparams="label_map={1:'label1'}, autoaugmentation_policy=v3"
```

1 change: 0 additions & 1 deletion efficientdet/hparams_config.py
@@ -243,7 +243,6 @@ def default_detection_configs():
h.strategy = None # 'tpu', 'gpus', None
h.mixed_precision = False # If False, use float32.
h.loss_scale = None # set to 2**16 enables dynamic loss scale
- h.model_optimizations = {} # 'prune':{}

# For detection.
h.box_class_repeats = 3
3 changes: 1 addition & 2 deletions efficientdet/inference.py
@@ -60,8 +60,7 @@ def image_preprocess(image, image_size, mean_rgb, stddev_rgb):
def batch_image_files_decode(image_files):
raw_images = tf.TensorArray(tf.uint8, size=0, dynamic_size=True)
for i in tf.range(tf.shape(image_files)[0]):
- image = tf.io.decode_image(image_files[i])
- image.set_shape([None, None, None])
+ image = tf.io.decode_image(image_files[i], expand_animations=False)
raw_images = raw_images.write(i, image)
return raw_images.stack()
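Context for the change above: `tf.io.decode_image` can return a 4-D tensor for animated GIFs unless `expand_animations=False` is passed, in which case the result is always a rank-3 `[height, width, channels]` tensor, so the explicit `set_shape` call is no longer needed. A minimal standalone sketch (not the repo's code; the helper name is made up):

```
import tensorflow as tf

# Minimal sketch: with expand_animations=False, decode_image always returns a
# rank-3 [height, width, channels] tensor (a GIF decodes to its first frame),
# so no explicit set_shape() is needed afterwards.
def decode_to_rank3(image_bytes):
  return tf.io.decode_image(image_bytes, expand_animations=False)
```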

34 changes: 16 additions & 18 deletions efficientdet/object_detection/tf_example_decoder.py
@@ -18,7 +18,7 @@
protos for object detection.
"""

- import tensorflow.compat.v1 as tf
+ import tensorflow as tf


def _get_source_id_from_encoded_image(parsed_tensors):
@@ -34,29 +34,27 @@ def __init__(self, include_mask=False, regenerate_source_id=False):
self._include_mask = include_mask
self._regenerate_source_id = regenerate_source_id
self._keys_to_features = {
- 'image/encoded': tf.FixedLenFeature((), tf.string),
- 'image/source_id': tf.FixedLenFeature((), tf.string, ''),
- 'image/height': tf.FixedLenFeature((), tf.int64, -1),
- 'image/width': tf.FixedLenFeature((), tf.int64, -1),
- 'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
- 'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
- 'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
- 'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
- 'image/object/class/label': tf.VarLenFeature(tf.int64),
- 'image/object/area': tf.VarLenFeature(tf.float32),
- 'image/object/is_crowd': tf.VarLenFeature(tf.int64),
+ 'image/encoded': tf.io.FixedLenFeature((), tf.string),
+ 'image/source_id': tf.io.FixedLenFeature((), tf.string, ''),
+ 'image/height': tf.io.FixedLenFeature((), tf.int64, -1),
+ 'image/width': tf.io.FixedLenFeature((), tf.int64, -1),
+ 'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
+ 'image/object/class/label': tf.io.VarLenFeature(tf.int64),
+ 'image/object/area': tf.io.VarLenFeature(tf.float32),
+ 'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
}
if include_mask:
self._keys_to_features.update({
'image/object/mask':
- tf.VarLenFeature(tf.string),
+ tf.io.VarLenFeature(tf.string),
})

def _decode_image(self, parsed_tensors):
"""Decodes the image and set its static shape."""
- image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
- image.set_shape([None, None, 3])
- return image
+ return tf.io.decode_image(parsed_tensors['image/encoded'], channels=3, expand_animations=False)

def _decode_boxes(self, parsed_tensors):
"""Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
@@ -118,10 +116,10 @@ def decode(self, serialized_example):
for k in parsed_tensors:
if isinstance(parsed_tensors[k], tf.SparseTensor):
if parsed_tensors[k].dtype == tf.string:
- parsed_tensors[k] = tf.sparse_tensor_to_dense(
+ parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value='')
else:
- parsed_tensors[k] = tf.sparse_tensor_to_dense(
+ parsed_tensors[k] = tf.sparse.to_dense(
parsed_tensors[k], default_value=0)

image = self._decode_image(parsed_tensors)
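The decoder now uses the TF2 names (`tf.io.FixedLenFeature`, `tf.io.VarLenFeature`, `tf.sparse.to_dense`) instead of the `tensorflow.compat.v1` aliases. An illustrative, self-contained version of the same parse-then-densify pattern (feature keys shortened, helper names made up):

```
import tensorflow as tf

# Illustrative only: parse a serialized tf.Example with the TF2 tf.io API and
# densify the sparse (VarLen) features, mirroring the decode() logic above.
FEATURES = {
    'image/encoded': tf.io.FixedLenFeature((), tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}

def parse_example(serialized_example):
  parsed = tf.io.parse_single_example(serialized_example, FEATURES)
  for key, value in parsed.items():
    if isinstance(value, tf.SparseTensor):
      default = '' if value.dtype == tf.string else 0
      parsed[key] = tf.sparse.to_dense(value, default_value=default)
  return parsed
```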
5 changes: 2 additions & 3 deletions efficientdet/requirements.txt
@@ -5,10 +5,9 @@ numpy>=1.19.4
Pillow>=6.0.0
PyYAML>=5.1
six>=1.15.0
- tensorflow>=2.4.0
- tensorflow-addons>=0.12
+ tensorflow>=2.7.0
+ tensorflow-addons>=0.15
tensorflow-hub>=0.11
neural-structured-learning>=1.3.1
- tensorflow-model-optimization>=0.5
Cython>=0.29.13
git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
2 changes: 1 addition & 1 deletion efficientdet/tf2/README.md
@@ -3,7 +3,7 @@
[1] Mingxing Tan, Ruoming Pang, Quoc V. Le. EfficientDet: Scalable and Efficient Object Detection. CVPR 2020.
Arxiv link: https://arxiv.org/abs/1911.09070

- **Quick start tutorial: [tutorial.ipynb](tutorial.ipynb)**
+ **Quick start tutorial: [tutorial.ipynb](./tutorial.ipynb)**

**Quick install dependencies: ```pip install -r requirements.txt```**

35 changes: 10 additions & 25 deletions efficientdet/tf2/efficientdet_keras.py
@@ -25,7 +25,6 @@
from backbone import efficientnet_builder
from tf2 import fpn_configs
from tf2 import postprocess
- from tf2 import tfmot
from tf2 import util_keras
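Note that this commit also drops the `model_optimizations` hook (and the `tf2/tfmot` helper import) from the layers below, so convolutions are no longer wrapped for pruning at construction time; `tensorflow-model-optimization` is removed from requirements.txt accordingly. For reference only, wrapping a layer for magnitude pruning with that toolkit looks roughly like the sketch below; the removed `tfmot.get_method` indirection is not reproduced here, and the import alias is made up:

```
import tensorflow as tf
# Assumes the tensorflow_model_optimization package is installed; this is a
# generic illustration of the kind of wrapping the removed 'prune' hook
# enabled, not the repo's exact code.
import tensorflow_model_optimization as tfmot_lib

conv = tf.keras.layers.Conv2D(64, 3, padding='same', name='conv')
pruned_conv = tfmot_lib.sparsity.keras.prune_low_magnitude(conv)
```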


@@ -56,7 +55,6 @@ def __init__(self,
strategy,
weight_method,
data_format,
- model_optimizations,
name='fnode'):
super().__init__(name=name)
self.feat_level = feat_level
@@ -73,7 +71,6 @@ def __init__(self,
self.conv_bn_act_pattern = conv_bn_act_pattern
self.resample_layers = []
self.vars = []
- self.model_optimizations = model_optimizations

def fuse_features(self, nodes):
"""Fuse features from different resolutions and return a weighted sum.
@@ -141,7 +138,6 @@ def build(self, feats_shape):
self.conv_after_downsample,
strategy=self.strategy,
data_format=self.data_format,
- model_optimizations=self.model_optimizations,
name=name))
if self.weight_method == 'attn':
self._add_wsm('ones')
@@ -161,7 +157,6 @@ def build(self, feats_shape):
self.act_type,
self.data_format,
self.strategy,
- self.model_optimizations,
name='op_after_combine{}'.format(len(feats_shape)))
self.built = True
super().build(feats_shape)
@@ -188,7 +183,6 @@ def __init__(self,
act_type,
data_format,
strategy,
- model_optimizations,
name='op_after_combine'):
super().__init__(name=name)
self.conv_bn_act_pattern = conv_bn_act_pattern
@@ -211,10 +205,6 @@ def __init__(self,
use_bias=not self.conv_bn_act_pattern,
data_format=self.data_format,
name='conv')
- if model_optimizations:
- for method in model_optimizations.keys():
- self.conv_op = (
- tfmot.get_method(method)(self.conv_op))
self.bn = util_keras.build_batch_norm(
is_training_bn=self.is_training_bn,
data_format=self.data_format,
@@ -244,7 +234,6 @@ def __init__(self,
data_format=None,
pooling_type=None,
upsampling_type=None,
- model_optimizations=None,
name='resample_p0'):
super().__init__(name=name)
self.apply_bn = apply_bn
@@ -262,9 +251,6 @@ def __init__(self,
padding='same',
data_format=self.data_format,
name='conv2d')
- if model_optimizations:
- for method in model_optimizations.keys():
- self.conv2d = tfmot.get_method(method)(self.conv2d)
self.bn = util_keras.build_batch_norm(
is_training_bn=self.is_training_bn,
data_format=self.data_format,
@@ -291,14 +277,14 @@ def _pool2d(self, inputs, height, width, target_height, target_width):

def _upsample2d(self, inputs, target_height, target_width):
if self.data_format == 'channels_first':
- inputs = tf.compat.v1.transpose(inputs, perm=[0, 2, 3, 1])
- outputs = tf.cast(
+ inputs = tf.transpose(inputs, [0, 2, 3, 1])
+ resized = tf.cast(
tf.compat.v1.image.resize_nearest_neighbor(
tf.cast(inputs, tf.float32), [target_height, target_width]),
inputs.dtype)
if self.data_format == 'channels_first':
- outputs = tf.compat.v1.transpose(outputs, perm=[0, 3, 1, 2])
- return outputs
+ resized = tf.transpose(resized, [0, 3, 1, 2])
+ return resized

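The rewritten `_upsample2d` keeps the resize in NHWC: a channels-first input is transposed to NHWC, resized with nearest neighbor, and transposed back. A standalone sketch of the same pattern, using TF2's `tf.image.resize(..., method='nearest')` in place of the compat-v1 op kept above; the function name is made up:

```
import tensorflow as tf

# Standalone sketch of nearest-neighbor upsampling that tolerates
# channels_first inputs (assumes a 4-D batch tensor).
def upsample_nearest(x, target_height, target_width, data_format='channels_last'):
  if data_format == 'channels_first':
    x = tf.transpose(x, [0, 2, 3, 1])   # NCHW -> NHWC
  dtype = x.dtype
  x = tf.image.resize(tf.cast(x, tf.float32),
                      [target_height, target_width], method='nearest')
  x = tf.cast(x, dtype)
  if data_format == 'channels_first':
    x = tf.transpose(x, [0, 3, 1, 2])   # NHWC -> NCHW
  return x
```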
def _maybe_apply_1x1(self, feat, training, num_channels):
"""Apply 1x1 conv to change layer width if necessary."""
@@ -428,15 +414,14 @@ def __init__(self,
def _conv_bn_act(self, image, i, level_id, training):
conv_op = self.conv_ops[i]
bn = self.bns[i][level_id]
- act_type = self.act_type

@utils.recompute_grad(self.grad_checkpoint)
def _call(image):
original_image = image
image = conv_op(image)
image = bn(image, training=training)
if self.act_type:
- image = utils.activation_fn(image, act_type)
+ image = utils.activation_fn(image, self.act_type)
if i > 0 and self.survival_prob:
image = utils.drop_connect(image, training, self.survival_prob)
image = image + original_image
@@ -590,15 +575,14 @@ def __init__(self,
def _conv_bn_act(self, image, i, level_id, training):
conv_op = self.conv_ops[i]
bn = self.bns[i][level_id]
- act_type = self.act_type

@utils.recompute_grad(self.grad_checkpoint)
def _call(image):
original_image = image
image = conv_op(image)
image = bn(image, training=training)
if self.act_type:
- image = utils.activation_fn(image, act_type)
+ image = utils.activation_fn(image, self.act_type)
if i > 0 and self.survival_prob:
image = utils.drop_connect(image, training, self.survival_prob)
image = image + original_image
@@ -754,6 +738,7 @@ class FPNCell(tf.keras.layers.Layer):

def __init__(self, config, name='fpn_cell'):
super().__init__(name=name)
+ logging.info('building FPNCell %s', name)
self.config = config
if config.fpn_config:
self.fpn_config = config.fpn_config
@@ -778,7 +763,6 @@ def __init__(self, config, name='fpn_cell'):
strategy=config.strategy,
weight_method=self.fpn_config.weight_method,
data_format=config.data_format,
- model_optimizations=config.model_optimizations,
name='fnode%d' % i)
self.fnodes.append(fnode)

@@ -839,7 +823,6 @@ def __init__(self,
conv_after_downsample=config.conv_after_downsample,
strategy=config.strategy,
data_format=config.data_format,
- model_optimizations=config.model_optimizations,
name='resample_p%d' % level,
))
self.fpn_cells = FPNCells(config)
@@ -953,7 +936,7 @@ def map_fn(image):
if raw_images.shape.as_list()[0]: # fixed batch size.
batch_size = raw_images.shape.as_list()[0]
outputs = [map_fn(raw_images[i]) for i in range(batch_size)]
- return [tf.stack(y) for y in zip(*outputs)]
+ return [tf.stop_gradient(tf.stack(y)) for y in zip(*outputs)]

# otherwise treat it as dynamic batch size.
return tf.vectorized_map(map_fn, raw_images)
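For readers skimming the hunk above: a statically known batch size is unrolled in Python and the per-image outputs are stacked (now under `tf.stop_gradient`, since this preprocessing path needs no gradients), while an unknown batch size falls back to `tf.vectorized_map`. A toy sketch of the same dispatch, with a made-up per-image function:

```
import tensorflow as tf

def per_image_fn(image):
  # Stand-in for the real per-image preprocessing; returns two tensors.
  return (tf.reduce_mean(tf.cast(image, tf.float32)),
          tf.cast(tf.shape(image)[0], tf.float32))

def batch_apply(raw_images):
  batch_size = raw_images.shape.as_list()[0]
  if batch_size:  # static batch size: unroll, stack, and block gradients.
    outputs = [per_image_fn(raw_images[i]) for i in range(batch_size)]
    return [tf.stop_gradient(tf.stack(y)) for y in zip(*outputs)]
  # dynamic batch size: let vectorized_map handle the batching.
  return tf.vectorized_map(per_image_fn, raw_images)
```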
@@ -999,6 +982,8 @@ def call(self, inputs, training=False, pre_mode='infer', post_mode='global'):
config.mean_rgb, config.stddev_rgb,
pre_mode)
# network.
+ if config.data_format == 'channels_first':
+ inputs = tf.transpose(inputs, [0, 3, 1, 2])
outputs = super().call(inputs, training)

if 'object_detection' in config.heads and post_mode:
8 changes: 5 additions & 3 deletions efficientdet/tf2/eval_tflite.py
@@ -16,6 +16,7 @@
from absl import app
from absl import flags
from absl import logging
+ import multiprocessing
import numpy as np
import tensorflow as tf

@@ -46,7 +47,7 @@ def define_flags():
'only_network', False,
'TFLite model only contains EfficientDetNet without post-processing NMS op.'
)
- flags.DEFINE_bool('pre_class_nms', False, 'Use pre_class_nms for evaluation.')
+ flags.DEFINE_bool('per_class_nms', False, 'Use per_class_nms for evaluation.')
flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file')
flags.mark_flag_as_required('val_file_pattern')
flags.mark_flag_as_required('tflite_path')
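The flag (and the matching keyword argument in `main`) is renamed from the misspelled `pre_class_nms` to `per_class_nms`, i.e. NMS run independently for each class rather than across all classes at once; the repo's actual implementation lives in `tf2/postprocess.py` and differs in detail. A hedged, eager-mode sketch of the distinction using the stock TF op, with a made-up helper name:

```
import tensorflow as tf

# Illustrative per-class NMS (eager mode): boxes only suppress other boxes of
# the same class. Not the repo's implementation.
def per_class_nms(boxes, scores, classes, max_out=100, iou_thresh=0.5):
  kept = []
  for cls in tf.unique(classes)[0]:
    idx = tf.reshape(tf.where(tf.equal(classes, cls)), [-1])
    selected = tf.image.non_max_suppression(
        tf.gather(boxes, idx), tf.gather(scores, idx), max_out, iou_thresh)
    kept.append(tf.gather(idx, selected))
  return tf.concat(kept, axis=0)  # indices of kept detections
```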
@@ -64,7 +65,8 @@ def __init__(self, tflite_model_path, only_network=False):
without post-processing NMS op. If False, TFLite model contains custom
NMS op.
"""
- self.interpreter = tf.lite.Interpreter(tflite_model_path)
+ self.interpreter = tf.lite.Interpreter(tflite_model_path,
+ num_threads=multiprocessing.cpu_count())
self.interpreter.allocate_tensors()
# Get input and output tensors.
self.input_details = self.interpreter.get_input_details()
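The interpreter is now constructed with `num_threads=multiprocessing.cpu_count()`, letting TFLite's CPU kernels use all available cores during evaluation. A minimal standalone invocation loop for a TFLite model (the path and dummy input are placeholders, not the repo's defaults):

```
import multiprocessing
import numpy as np
import tensorflow as tf

# Minimal sketch: run a TFLite model with a multi-threaded interpreter.
interpreter = tf.lite.Interpreter(model_path='model.tflite',
                                  num_threads=multiprocessing.cpu_count())
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

dummy = np.zeros(input_details[0]['shape'], dtype=input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
outputs = [interpreter.get_tensor(d['index']) for d in output_details]
```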
@@ -175,7 +177,7 @@ def main(_):
box_outputs,
labels['image_scales'],
labels['source_ids'],
- pre_class_nms=FLAGS.pre_class_nms)
+ per_class_nms=FLAGS.per_class_nms)

detections = postprocess.transform_detections(detections)
evaluator.update_state(labels['groundtruth_data'].numpy(),
2 changes: 1 addition & 1 deletion efficientdet/tf2/infer.py
@@ -61,7 +61,7 @@ def main(_):
model = efficientdet_keras.EfficientDetModel(config=config)
model.build((None, None, None, 3))
model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
- model.summary()
+ model.summary(expand_nested=True)

class ExportModel(tf.Module):

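`expand_nested=True` makes `model.summary()` list the layers inside nested Keras models (backbone, FPN cells, heads) instead of collapsing each into a single row; the keyword is only available in recent TF/Keras releases. A toy illustration, unrelated to EfficientDet:

```
import tensorflow as tf

# Toy illustration: the inner model's layers appear individually in the
# summary when expand_nested=True (requires a recent Keras version).
inner = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))],
                            name='inner_block')
outer = tf.keras.Sequential([inner, tf.keras.layers.Dense(1)], name='outer')
outer.build((None, 4))
outer.summary(expand_nested=True)
```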