Initial implementation of SavedModel I/O #13
Conversation
LGTM from my limited knowledge so far :)
I just had a few minor comments
tests/graph_test.py
Outdated
# Remove the directory after the test.
# Comment out this line to prevent deleting temps.
shutil.rmtree(self.temp_dir)
pass
don't need the pass
You do if you comment out the previous line ;-) I'll add a comment to that effect.
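The pattern being discussed can be sketched as a minimal, self-contained example. This is an illustration of the idiom, not the PR's actual test class, and the names (`ExampleTest`, `temp_dir`) are placeholders:

```python
import os
import shutil
import tempfile
import unittest


class ExampleTest(unittest.TestCase):
    """Toy sketch of the tearDown pattern discussed above."""

    def setUp(self):
        self.temp_dir = tempfile.mkdtemp()

    def test_dir_exists(self):
        self.assertTrue(os.path.isdir(self.temp_dir))

    def tearDown(self):
        # Remove the directory after the test.
        # Comment out this line to prevent deleting temps.
        shutil.rmtree(self.temp_dir)
        pass  # keeps the method body valid if the line above is commented out
```

Without the `pass`, commenting out the `rmtree` call would leave `tearDown` with an empty body, which is a syntax error.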
Ah, I see. I totally missed the comment :)
tests/rewrite_test.py
Outdated
name="Input")
result_tensor = input_tensor + 42
g = gde.Graph(tf_g)
gde.rewrite.change_batch_size(g, 3, [g[input_tensor.op.name]])
print("Graph def is:\n{}".format(g.to_graph_def()))
did you mean to keep this print statement?
Ugh, I thought I got all those. Thanks for catching that one!
tests/rewrite_test.py
Outdated
np.array([42., 43.]).reshape([2, 1])))

# Remove temp dir if the test is successful
shutil.rmtree(temp_dir)
Did you want to keep the directory if the test was not successful? If not, it might be good to put this in a finally block.
Good point. I had been using the directory's contents for debugging, but it's not a good idea to check in code that pollutes /tmp.
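The fix the reviewer suggests can be sketched like this. It is a minimal standalone example of the idiom, not the PR's actual test code:

```python
import shutil
import tempfile

# Create the temp dir *outside* the try block so cleanup only
# runs if creation actually succeeded.
temp_dir = tempfile.mkdtemp()
try:
    # ... exercise the code under test, writing into temp_dir ...
    pass
finally:
    # Runs whether the test body passed or raised, so /tmp is never polluted.
    shutil.rmtree(temp_dir)
```

With the cleanup in `finally`, a failing test no longer leaves its temp directory behind; if keeping it for debugging is desired, that can be made an explicit, temporary local change.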
tests/rewrite_test.py
Outdated
SavedModel
"""
temp_dir = tempfile.mkdtemp()
print("Temp dir is {}".format(temp_dir))
also this print statement
Thanks, fixed.
graph_def_editor/graph.py
Outdated
return self._signature_defs


nit: i think extra newline here
Oops, I neglected to go through the PEP8 linter warnings before checking in. Fixed that newline and a bunch of other, similar warnings.
This PR contains my implementation of issue #5 -- reading and writing SavedModel files.
There are limitations on the use of variables when writing to SavedModel files as described in #5.
I added tests of the new functionality.
I modified the example batch_size_example.py to read and write SavedModel files. After the modification, I noticed that, although the example script correctly returns a single row in its result, the script infers a size of (64, 1001) for the softmax_tensor output. This result appears to be technically correct --- the graph contains a batch_normalization meta-operator that hard-codes a batch size of 64. Some follow-on work will be needed to make this example set the batch size properly all the way to the end.
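To illustrate the failure mode, here is a toy model in plain Python. This is not the graph_def_editor API or the real graph; it just shows why a rewrite that only updates nodes with rewritable shapes leaves a node that hard-codes its own batch size untouched, so downstream shape inference picks up the stale value:

```python
# Toy illustration only -- the node names echo the example, but the
# data structure and rewrite function are invented for this sketch.
nodes = {
    "Input": {"shape": [64, 1001], "hard_coded": False},
    "batch_normalization": {"shape": [64, 1001], "hard_coded": True},
}

def change_batch_size(graph_nodes, new_batch):
    """Naive rewrite: only touches nodes whose shape is not hard-coded."""
    for node in graph_nodes.values():
        if not node["hard_coded"]:
            node["shape"][0] = new_batch

change_batch_size(nodes, 1)
# Input now advertises batch size 1, but batch_normalization still says 64,
# so anything inferring its output shape from that node reports (64, 1001).
```

Propagating the new batch size "all the way to the end" would mean also rewriting the hard-coded shape attribute on such nodes, which is the follow-on work described above.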