Skip to content

Commit

Permalink
Merge pull request #13 from frreiss/issue-savedmodel
Browse files Browse the repository at this point in the history
Initial implementation of SavedModel I/O
  • Loading branch information
frreiss authored Jan 21, 2019
2 parents aa61f9a + 440f918 commit e0c47b0
Show file tree
Hide file tree
Showing 12 changed files with 866 additions and 122 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
env
*.swp
*/__pycache__
graph_def_editor.iml
*.iml
test.out
example.out

38 changes: 19 additions & 19 deletions examples/batch_size_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _indent(s):
"/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz"
_MODEL_TARBALL = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC.tar.gz"
_SAVED_MODEL_DIR = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC/1538686978"
_AFTER_MODEL_DIR = _TMP_DIR + "/rewritten_model"


def main(_):
Expand All @@ -67,7 +68,7 @@ def main(_):
tf_g = tf.Graph()
with tf.Session(graph=tf_g) as sess:
tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING],
_SAVED_MODEL_DIR)
_SAVED_MODEL_DIR)

# print("Graph is:\n{}".format(tf_g.as_graph_def()))

Expand All @@ -78,27 +79,26 @@ def main(_):
print(" Softmax tensor is {}".format(tf_g.get_tensor_by_name(
"softmax_tensor:0")))

# Convert the graph to a gde.Graph and rewrite the batch size to None
# TODO(frreiss): Perform this step over SavedModel files
g = gde.Graph(tf_g)
# Convert the SavedModel to a gde.Graph and rewrite the batch size to None
g = gde.saved_model_to_graph(_SAVED_MODEL_DIR)
gde.rewrite.change_batch_size(g, new_size=None, inputs=[g["input_tensor"]])
if os.path.exists(_AFTER_MODEL_DIR):
shutil.rmtree(_AFTER_MODEL_DIR)
g.to_saved_model(_AFTER_MODEL_DIR)

# Convert back to a TensorFlow graph
after_tf_g = g.to_tf_graph()
print("AFTER:")
print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name(
"input_tensor:0")))
print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name(
"softmax_tensor:0")))

# Feed a single array of zeros through the graph
print("Restoring variables and running inference on dummy data")
# Load the rewritten SavedModel into a TensorFlow graph
after_tf_g = tf.Graph()
with tf.Session(graph=after_tf_g) as sess:
# Load the variables checkpoint from the SavedModel file
saver = tf.train.Saver()
saver.restore(sess, _SAVED_MODEL_DIR + "/variables/variables")
# TODO(frreiss): Load variables with tf.saved_model.load() once the
# rewrite reads and writes SavedModel files
tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING],
_AFTER_MODEL_DIR)
print("AFTER:")
print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name(
"input_tensor:0")))
print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name(
"softmax_tensor:0")))

# Feed a single array of zeros through the graph
print("Running inference on dummy data")
result = sess.run("softmax_tensor:0",
{"input_tensor:0": np.zeros([1, 224, 224, 3])})
print("Result is {}".format(result))
Expand Down
Loading

0 comments on commit e0c47b0

Please sign in to comment.