forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_tflite_ssd_graph.py
143 lines (126 loc) · 5.73 KB
/
export_tflite_ssd_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Exports an SSD detection model to use with tf-lite.
Output files:
* A tflite compatible frozen graph - $output_directory/tflite_graph.pb
* The same graph in text format - $output_directory/tflite_graph.pbtxt
The exported graph has the following input and output nodes.
Inputs:
'normalized_input_image_tensor': a float32 tensor of shape
[1, height, width, 3] containing the normalized input image. Note that the
height and width must be compatible with the height and width configured in
the fixed_shape_image resizer options in the pipeline config proto.
In floating point Mobilenet model, 'normalized_input_image_tensor' has values
between [-1,1). This typically means mapping each pixel (linearly)
to a value in [-1, 1). Input image
values between 0 and 255 are scaled by (1/128.0) and then a value of
-1 is added to them to ensure the range is [-1,1).
In quantized Mobilenet model, 'normalized_input_image_tensor' has values
between [0, 255].
In general, see the `preprocess` function defined in the feature extractor class
in the object_detection/models directory.
Outputs:
If add_postprocessing_op is true: the frozen graph contains a
TFLite_Detection_PostProcess custom op node with four outputs:
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
locations
detection_classes: a float32 tensor of shape [1, num_boxes]
with class indices
detection_scores: a float32 tensor of shape [1, num_boxes]
with class scores
num_boxes: a float32 tensor of size 1 containing the number of detected boxes
else:
the graph has two outputs:
'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
containing the encoded box predictions.
'raw_outputs/class_predictions': a float32 tensor of shape
[1, num_anchors, num_classes] containing the class scores for each anchor
after applying score conversion.
Example Usage:
--------------
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path path/to/ssd_mobilenet.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
- tflite_graph.pbtxt
- tflite_graph.pb
Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or
eval config.
Example Usage (in which we change the NMS iou_threshold to be 0.5 and
NMS score_threshold to be 0.0):
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path path/to/ssd_mobilenet.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory \
--config_override " \
model{ \
ssd{ \
post_processing { \
batch_non_max_suppression { \
score_threshold: 0.0 \
iou_threshold: 0.5 \
} \
} \
} \
} \
"
"""
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection import export_tflite_ssd_graph_lib
from object_detection.protos import pipeline_pb2
flags = tf.app.flags

# Input/output locations. main() requires all three to be supplied.
flags.DEFINE_string(
    name='output_directory',
    default=None,
    help='Path to write outputs.')
flags.DEFINE_string(
    name='pipeline_config_path',
    default=None,
    help='Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
flags.DEFINE_string(
    name='trained_checkpoint_prefix',
    default=None,
    help='Checkpoint prefix.')

# Settings forwarded to the TFLite post-processing op.
flags.DEFINE_integer(
    name='max_detections',
    default=10,
    help='Maximum number of detections (boxes) to show.')
flags.DEFINE_integer(
    name='max_classes_per_detection',
    default=1,
    help='Maximum number of classes to output per detection box.')
flags.DEFINE_integer(
    name='detections_per_class',
    default=100,
    help='Number of anchors used per class in Regular Non-Max-Suppression.')
flags.DEFINE_bool(
    name='add_postprocessing_op',
    default=True,
    help='Add TFLite custom op for postprocessing to the graph.')
flags.DEFINE_bool(
    name='use_regular_nms',
    default=False,
    help=('Flag to set postprocessing op to use Regular NMS instead of '
          'Fast NMS.'))

# Optional text-proto overrides merged on top of --pipeline_config_path.
flags.DEFINE_string(
    name='config_override',
    default='',
    help=('pipeline_pb2.TrainEvalPipelineConfig text proto to override '
          'pipeline_config_path.'))

FLAGS = flags.FLAGS
def main(argv):
  """Exports the configured SSD checkpoint as a TFLite-compatible graph.

  Reads the pipeline config named by --pipeline_config_path, merges the
  --config_override text proto into it, and hands the result to
  export_tflite_ssd_graph_lib.export_tflite_graph together with the
  post-processing flags.

  Args:
    argv: Remaining command-line arguments; unused.

  Raises:
    ValueError: if --output_directory, --pipeline_config_path or
      --trained_checkpoint_prefix was not supplied.
  """
  del argv  # Unused.
  # flags.mark_flag_as_required() only registers a validator, and absl runs
  # validators while parsing argv -- which tf.app.run() has already done by
  # the time main() is called. Check the required flags explicitly so a
  # missing flag fails with a clear message instead of a None path later.
  missing = [
      name for name in ('output_directory', 'pipeline_config_path',
                        'trained_checkpoint_prefix')
      if getattr(FLAGS, name) is None
  ]
  if missing:
    raise ValueError('Missing required flag(s): %s' %
                     ', '.join('--' + name for name in missing))

  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  # Merge the (possibly empty) override text proto into the loaded config.
  text_format.Merge(FLAGS.config_override, pipeline_config)

  # detections_per_class was previously defined as a flag but never
  # forwarded, so user-supplied values were silently ignored.
  export_tflite_ssd_graph_lib.export_tflite_graph(
      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
      FLAGS.add_postprocessing_op, FLAGS.max_detections,
      FLAGS.max_classes_per_detection,
      detections_per_class=FLAGS.detections_per_class,
      use_regular_nms=FLAGS.use_regular_nms)
if __name__ == '__main__':
  # Per the absl documentation, required flags must be marked before
  # app.run() parses argv; marking them afterwards (e.g. inside main) does
  # not cause missing flags to be reported.
  flags.mark_flag_as_required('output_directory')
  flags.mark_flag_as_required('pipeline_config_path')
  flags.mark_flag_as_required('trained_checkpoint_prefix')
  tf.app.run(main)