diff --git a/.gitignore b/.gitignore index ddfd6ad..c6fc533 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ env *.swp */__pycache__ -graph_def_editor.iml +*.iml test.out example.out diff --git a/examples/batch_size_example.py b/examples/batch_size_example.py index dbc03d3..59418c0 100644 --- a/examples/batch_size_example.py +++ b/examples/batch_size_example.py @@ -47,6 +47,7 @@ def _indent(s): "/savedmodels/resnet_v2_fp16_savedmodel_NHWC.tar.gz" _MODEL_TARBALL = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC.tar.gz" _SAVED_MODEL_DIR = _TMP_DIR + "/resnet_v2_fp16_savedmodel_NHWC/1538686978" +_AFTER_MODEL_DIR = _TMP_DIR + "/rewritten_model" def main(_): @@ -67,7 +68,7 @@ def main(_): tf_g = tf.Graph() with tf.Session(graph=tf_g) as sess: tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], - _SAVED_MODEL_DIR) + _SAVED_MODEL_DIR) # print("Graph is:\n{}".format(tf_g.as_graph_def())) @@ -78,27 +79,26 @@ def main(_): print(" Softmax tensor is {}".format(tf_g.get_tensor_by_name( "softmax_tensor:0"))) - # Convert the graph to a gde.Graph and rewrite the batch size to None - # TODO(frreiss): Perform this step over SavedModel files - g = gde.Graph(tf_g) + # Convert the SavedModel to a gde.Graph and rewrite the batch size to None + g = gde.saved_model_to_graph(_SAVED_MODEL_DIR) gde.rewrite.change_batch_size(g, new_size=None, inputs=[g["input_tensor"]]) + if os.path.exists(_AFTER_MODEL_DIR): + shutil.rmtree(_AFTER_MODEL_DIR) + g.to_saved_model(_AFTER_MODEL_DIR) - # Convert back to a TensorFlow graph - after_tf_g = g.to_tf_graph() - print("AFTER:") - print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name( - "input_tensor:0"))) - print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name( - "softmax_tensor:0"))) - - # Feed a single array of zeros through the graph - print("Restoring variables and running inference on dummy data") + # Load the rewritten SavedModel into a TensorFlow graph + after_tf_g = tf.Graph() with 
tf.Session(graph=after_tf_g) as sess: - # Load the variables checkpoint from the SavedModel file - saver = tf.train.Saver() - saver.restore(sess, _SAVED_MODEL_DIR + "/variables/variables") - # TODO(frreiss): Load variables with tf.saved_model.load() once the - # rewrite reads and writes SavedModel files + tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], + _AFTER_MODEL_DIR) + print("AFTER:") + print(" Input tensor is {}".format(after_tf_g.get_tensor_by_name( + "input_tensor:0"))) + print(" Softmax tensor is {}".format(after_tf_g.get_tensor_by_name( + "softmax_tensor:0"))) + + # Feed a single array of zeros through the graph + print("Running inference on dummy data") result = sess.run("softmax_tensor:0", {"input_tensor:0": np.zeros([1, 224, 224, 3])}) print("Result is {}".format(result)) diff --git a/graph_def_editor/graph.py b/graph_def_editor/graph.py index 8f02b30..ac31f23 100644 --- a/graph_def_editor/graph.py +++ b/graph_def_editor/graph.py @@ -18,13 +18,25 @@ from __future__ import division from __future__ import print_function +import datetime +from distutils import dir_util +import os import tensorflow as tf -from typing import Tuple, Dict, FrozenSet, Iterable, Union +from typing import Tuple, Dict, FrozenSet, Iterable, Union, Set, Any from graph_def_editor import node, util, tensor, variable +# TODO: Move this protobuf into this project so we don't depend on +# tf.core.framework +from tensorflow.core.protobuf import saved_model_pb2, meta_graph_pb2 + + __all__ = [ "Graph", + "SaverInfo", + "SignatureInfo", + "GraphVisitor", + "saved_model_to_graph", ] # Special attribute in which TensorFlow stores frame names for while loops ( @@ -40,6 +52,49 @@ def visit_node(self, n: 'node.Node'): raise NotImplementedError() +class SaverInfo(object): + """ + Object to encapsulate information about a `tf.train.Saver` object that can + reconstitute the variable values for this graph. 
+ """ + def __init__(self, path: str, saver_def: tf.train.SaverDef): + """ + Args: + path: Path to the location of serialized variable information on disk + saver_def: Serialized version of `tf.train.Saver` object + """ + self.path = path + self.saver_def = saver_def + + +class SignatureInfo(object): + """ + Object that encapsulates information about entry points to the graph, + AKA signatures. + """ + def __init__(self): + self._signature_defs = {} # Dict[str, meta_graph_pb2.SignatureDef] + + def add_signature_def(self, + name: str, + signature_def: meta_graph_pb2.SignatureDef): + """ + Add a signature to the set of entry points. + + Args: + name: Name for the entry point + signature_def: Definition of the entry point; specifies input and + output nodes and maps them to input and output names + """ + if name in self._signature_defs: + raise ValueError("Already have a signature with name '{}'".format(name)) + self._signature_defs[name] = signature_def + + @property + def signature_defs(self): + return self._signature_defs + + class Graph(object): """ Mutable surrogate for a `tf.GraphDef` protocol buffer message @@ -53,8 +108,11 @@ class Graph(object): collections """ - def __init__(self, g: tf.GraphDef = None, collections: - Iterable[tf.MetaGraphDef.CollectionDefEntry] = None): + def __init__(self, g: Union[tf.Graph, tf.GraphDef] = None, + name: str = None, + collections: Iterable[tf.MetaGraphDef.CollectionDefEntry] = None, + saver_info: SaverInfo = None, + signature_info: SignatureInfo = None): """ Wrap a tf.GraphDef protocol buffer in a Graph object. @@ -62,10 +120,17 @@ def __init__(self, g: tf.GraphDef = None, collections: g: a tf.Graph or tf.GraphDef protobuf that represents a TensorFlow graph. If set to None, generate an empty tf.GraphDef + name: Optional human-readable name for the graph. If not provided, + the constructor will generate a name. 
collections: Optional iterable of tf.MetaGraphDef.CollectionDefEntry objects containing information about collections in the graph. Note that this constructor will pull collection info out of `g` if it is a `tf.Graph` and `collections` is `None`. + saver_info: Optional serialiazed information about the + `tf.train.Saver` object that can save and restore variables in this + graph. + signature_info: Optional semi-serialized information about entry points + to the graph, AKA signatures """ if g is None: graph_def = tf.GraphDef() @@ -78,16 +143,29 @@ def __init__(self, g: tf.GraphDef = None, collections: else: raise TypeError("Graph is of type {}. Expected a tf.Graph or GraphDef " "proto".format(type(g))) + if name is None: + time_str = datetime.datetime.now().isoformat() + name = "GraphDef Editor Graph created {}".format(time_str) + if signature_info is None: + signature_info = SignatureInfo() + elif not isinstance(signature_info, SignatureInfo): + raise ValueError("signature_info argument must be a SignatureInfo object") + + # Populate fields of object + self._name = name # str self._version = 0 # Must happen first; other init code needs self._version - self._frozen = False - self._graph_def = graph_def - self._next_id = 1 + self._frozen = False # bool + self._graph_def = graph_def # tf.GraphDef + self._next_id = 1 # int output_map = _decode_graph(graph_def) self._node_name_to_node = {} # Dict[str, node.Node]; key is node name self._node_to_frame_names = None self._frame_name_to_nodes = None self._head_name_to_coloc_group = None # Dict[str, FrozenList[str]] self._variable_name_to_variable = {} # Dict[str, Variable] + self._collection_name_to_type = None # Dict[str, str], generated on demand + self._passthrough_collections = {} # Dict[str, List[CollectionDef]] + self._passthrough_saver = None # Load nodes in three passes because the g may contain cycles. 
for node_def in graph_def.node: @@ -97,12 +175,28 @@ def __init__(self, g: tf.GraphDef = None, collections: for node_def in graph_def.node: self[node_def.name].set_inputs_from_strings(node_def.input, set_control_inputs=True) - - self._collections = {} + # Collections reference nodes and variables if collections is not None: for c in collections: self.add_collection_from_collection_def(c) + # Presence of a passthrough saver prevents adding additional variables, + # so load after variables are constituted (i.e. from collections) + self._passthrough_saver = saver_info + self._signatures = signature_info + + @property + def name(self): + """ + Returns human-readable name for this graph. This name may not be unique + across graphs. + """ + return self._name + + @property + def has_passthrough_saver(self): + return self._passthrough_saver is not None + def add_node_from_node_def(self, node_def: tf.NodeDef, set_inputs: bool = False) -> 'node.Node': """ @@ -114,9 +208,9 @@ def add_node_from_node_def(self, node_def: tf.NodeDef, g: Graph in which the node will be created node_def: Fully-populated NodeDef proto; all fields, including inputs, will be used. - set_inputs: Optional. If True, also populate the data and control inputs of - the returned Node. This operation will only work if the targets of those - inputs are already present in the graph. + set_inputs: Optional. If True, also populate the data and control inputs + of the returned Node. This operation will only work if the targets of + those inputs are already present in the graph. 
""" ret = self.add_node(name=node_def.name, op_name=node_def.op) ret.device = node_def.device @@ -126,18 +220,43 @@ def add_node_from_node_def(self, node_def: tf.NodeDef, ret.set_inputs_from_strings(node_def.input, set_control_inputs=True) return ret - def add_collection_from_collection_def(self, collection_def: - tf.MetaGraphDef.CollectionDefEntry): + def add_collection_from_collection_def( + self, + collection_def: tf.MetaGraphDef.CollectionDefEntry, + validate_name: bool = True): """ Unpack a `tf.MetaGraphDef.CollectionDefEntry` of serialized variables into a collection of variables in this graph. The collection must not exist. Variables that do not already exist will be created. + + Args: + collection_def: Serialized information about the collection + validate_name: Verify that a collection by this name doesn't already + exist. Set this argument to False to avoid O(n^2) behavior when + bulk-loading known-good collection metadata. """ collection_name = collection_def.key - for serialized_var in collection_def.value.bytes_list.value: - var = self.add_variable_from_variable_def(serialized_var, - skip_if_present=True) - var.add_to_collection(collection_name) + if validate_name and collection_name in self.get_all_collection_keys(): + raise ValueError("Collection '{}' already exists".format(collection_name)) + collection = collection_def.value + # The collection is stored in exactly one of five different formats. 
+ if collection.HasField("node_list"): + for node_name in collection_def.value.node_list.value: + n = self.get_node_by_name(node_name) + n.add_to_collection(collection_name) + elif collection.HasField("bytes_list"): + for serialized_var in collection_def.value.bytes_list.value: + var = self.add_variable_from_variable_def(serialized_var, + skip_if_present=True) + var.add_to_collection(collection_name) + elif (collection.HasField("int64_list") + or collection.HasField("float_list") + or collection.HasField("any_list")): + self._passthrough_collections[collection_name] = collection + if self._collection_name_to_type is not None: + self._collection_name_to_type[collection_name] = "passthrough" + else: + raise ValueError("Unknown collection type: {}".format(collection)) def __getitem__(self, name: str) -> Union[tensor.Tensor, 'node.Node']: """ @@ -282,7 +401,7 @@ def add_variable_from_variable_def(self, variable_def, def variable_names(self): return self._variable_name_to_variable.keys() - def name_to_variable(self, name: str) -> variable.Variable: + def get_variable_by_name(self, name: str) -> variable.Variable: """ Fetch a variable by its variable name. @@ -414,14 +533,18 @@ def get_tensor_by_name(self, tensor_name: str, error_msg: str = None): )) return n.output(output_ix) - def to_graph_def(self): + def to_graph_def(self, add_shapes: bool = True): """ + Args: + add_shapes: If True, add the special "_output_shapes" attribute with + output shape information from this Node's output metadata. + Returns the `tf.GraphDef` serialization of this graph in its current form. 
""" ret = tf.GraphDef() for op in self.nodes: - op.to_node_def(ret.node.add()) + op.to_node_def(ret.node.add(), add_shapes) return ret def to_tf_graph(self): @@ -438,6 +561,143 @@ def to_tf_graph(self): util.load_variables_to_tf_graph(self) return ret + def to_saved_model(self, saved_model_path: str, + tags: Iterable[str] = None) -> \ + saved_model_pb2.SavedModel: + """ + Writes this graph out as a TensorFlow SavedModel on disk. + + Args: + saved_model_path: Location where the root directory of the SavedModel + should reside. + tags: What tag strings should be associated with the MetaGraph that this + method puts inside the SavedModel. If None, use the + tag `tf.saved_model.tag_constants.SERVING` + + Returns the SavedModel protocol buffer message that it wrote to the + specified location. + """ + if tags is None: + tags = [tf.saved_model.tag_constants.SERVING] + if os.path.exists(saved_model_path): + raise ValueError("Output path '{}' already exists".format( + saved_model_path)) + if not os.path.exists(os.path.dirname(saved_model_path)): + raise ValueError("Parent directory '{}' of output dir '{}' does not " + "exist".format(os.path.dirname(saved_model_path), + saved_model_path)) + os.mkdir(saved_model_path) + + # Core part of the SavedModel is a protocol buffers file containing a + # SavedModel protocol buffer message. + # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/ + # core/protobuf/saved_model.proto + saved_model = saved_model_pb2.SavedModel() + saved_model.saved_model_schema_version = 1 + + # Inside the SavedModel protobuf is a list of MetaGraphDef protobufs. In + # this case there is only one. + # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/ + # core/protobuf/meta_graph.proto + meta_graph = saved_model.meta_graphs.add() + + # The MetaGraphDef message contains a nested header called a MetaInfoDef. + # The first field of the MetaInfoDef is called "meta_graph_version". 
+ # This field does not actually hold the version of the MetaGraph. Instead + # it holds an arbitrary string that can be whatever you want. + meta_info_def = tf.MetaGraphDef.MetaInfoDef() + meta_info_def.meta_graph_version = self.name + + # The second field, "stripped_op_list" holds "A copy fo the OpDefs used by + # the producer of this graph_def". According to the docs for + # tf.import_graph_def, this field is deprecated. This field does not + # appear to have ever accomplished anything useful. + # TensorFlow fills this field with a deluge of OpDef information. We leave + # this field out. + + # The third field, "any_info", provides a place for holding additional + # arbitrary information. We also leave this field out. + + # The fourth field holds the string tags for this MetaGraph + meta_info_def.tags.extend(tags) + + # The fifth and sixth fields hold TensorFlow version information. + # We punt here and populate these fields with the version info from + # the current Python session's copy of TensorFlow. + meta_info_def.tensorflow_version = tf.VERSION + meta_info_def.tensorflow_git_version = tf.GIT_VERSION + + # The final field, "stripped_default_attrs", is "A flag to denote whether + # default-valued attrs have been stripped from the nodes in this graph_def" + # The TensorFlow authors appear to have added this field in the hopes + # that future versions of the system might be able to use it for forwards + # compatibility. No code in TensorFlow currently reads this attribute. We + # set it to False. + meta_info_def.stripped_default_attrs = False + + meta_graph.meta_info_def.CopyFrom(meta_info_def) + + # After the meta_info_def comes a GraphDef proto holding all the graph + # nodes that this MetaGraph uses. If an op in the original TensorFlow + # graph is in multiple MetaGraphs, that op will be stored ONCE PER + # METAGRAPH under this field. In our case there is exactly one + # MetaGraph in the SavedModel. 
+ meta_graph.graph_def.CopyFrom(self.to_graph_def()) + + # The next field, "saver_def", holds information about the tf.Saver + # instance that will be used to reconstitute any variables in the graph + if self.has_passthrough_saver: + meta_graph.saver_def.CopyFrom(self._passthrough_saver.saver_def) + # Copy serialized variables checkpoint wholesale, because the checkpoint + # format is a black box to us. + dir_util.copy_tree(self._passthrough_saver.path, + _vars_dir_for_saved_model(saved_model_path)) + elif len(self.variable_names) > 0: + raise NotImplementedError("Can't generate a SaverDef.") + else: + # Zero variables, no passthrough SaverDef. + # For this case, TensorFlow creates an empty variables directory and + # doesn't set the "saver_def" field. We emulate this behavior. + os.mkdir(_vars_dir_for_saved_model(saved_model_path)) + + # The next field, "collection_def", holds serialized information about all + # collections in the MetaGraph. + if self._collection_name_to_type is None: + self._build_collection_name_to_type() + for coll_name, coll_type in self._collection_name_to_type.items(): + if coll_type == "passthrough": + meta_graph.collection_def[coll_name] = self._passthrough_collections[ + coll_name] + elif coll_type == "variable": + vars_list = self.get_collection_by_name(coll_name) + serialized_vars = [v.to_proto().SerializeToString() for v in vars_list] + meta_graph.collection_def[coll_name].bytes_list.value.extend( + serialized_vars) + elif coll_type == "node": + nodes_list = self.get_collection_by_name(coll_name) + meta_graph.collection_def[coll_name].node_list.value.extend( + [n.name for n in nodes_list]) + else: + raise ValueError("Unknown collection type '{}'".format(coll_type)) + + # The next field, "signature_def", contains information about + # input/output signatures that this MetaGraph supports. 
+ for sig_name, sig_def in self.signatures.items(): + meta_graph.signature_def[sig_name].CopyFrom(sig_def) + + # The final field, asset_file_def, stores information about additional + # assets that are packaged along with the graph in the SavedModel's + # "assets" directory. Fow now we leave this field empty. + # TODO(frreiss): Represent assets as a field in the Graph class and + # serialize them here. + + # At this point, we have created the root directory for the SavedModel, + # as well as the checkpoints directory. The only thing left to write is + # the SavedModel protobuf itself. + with open(saved_model_path + "/saved_model.pb", "wb") as f: + f.write(saved_model.SerializeToString()) + return saved_model + @property def version(self): """ @@ -468,8 +728,9 @@ def increment_version_counter(self): self._node_to_frame_names = None self._frame_name_to_nodes = None self._head_name_to_coloc_group = None + self._collection_name_to_type = None - def get_collection(self, name: str): + def get_collection_by_name(self, name: str) -> Iterable[Any]: """Fetch the contents of a collection, similarly to the method in `tf.Graph` by the same name. @@ -480,11 +741,60 @@ def get_collection(self, name: str): The values in the collection. Currently any type is allowed in these values, following the conventions of the TensorFlow APIs. 
""" - return self._collections[name] + if self._collection_name_to_type is None: + self._build_collection_name_to_type() + if name not in self._collection_name_to_type: + raise ValueError("No collection with name '{}'".format(name)) + coll_type = self._collection_name_to_type[name] + if coll_type == "passthrough": + return self._passthrough_collections[name] + elif coll_type == "variable": + ret = [] + for v_name in self.variable_names: + v = self.get_variable_by_name(v_name) + if name in v.collection_names: + ret.append(v) + return ret + elif coll_type == "node": + ret = [] + for n in self.nodes: + if name in n.collection_names: + ret.append(n) + return ret + else: + raise ValueError("Unknown collection type '{}'".format(coll_type)) + + def _build_collection_name_to_type(self): + self._collection_name_to_type = {} + passthrough_collection_names = set(self._passthrough_collections.keys()) + variable_collection_names = set() + node_collection_names = set() + for var_name in self.variable_names: + v = self.get_variable_by_name(var_name) + for name in v.collection_names: + variable_collection_names.add(name) + for n in self.nodes: + for name in n.collection_names: + node_collection_names.add(name) + + def _add(names, type_name): + for coll_name in names: + if coll_name in self._collection_name_to_type: + raise ValueError(( + _duplicate_collection_error_str(coll_name, + passthrough_collection_names, + variable_collection_names, + node_collection_names))) + self._collection_name_to_type[coll_name] = type_name + _add(passthrough_collection_names, "passthrough") + _add(variable_collection_names, "variable") + _add(node_collection_names, "node") def get_all_collection_keys(self): """Returns the keys associated with all collections stored in this object""" - return self._collections.keys() + if self._collection_name_to_type is None: + self._build_collection_name_to_type() + return self._collection_name_to_type.keys() def _get_next_id(self) -> int: """Generates and returns a 
unique integer ID *within this graph*.""" @@ -690,7 +1000,96 @@ def colocation_groups(self) -> Dict[str, FrozenSet['node.Node']]: k: frozenset(v) for k, v in head_name_to_coloc_group.items()} return self._head_name_to_coloc_group + @property + def signatures(self) -> Dict[str, meta_graph_pb2.SignatureDef]: + """ + Returns a map from signature name to signature definition. Changes to + this map will be reflected in this object. + """ + return self._signatures.signature_defs + +def saved_model_to_graph(saved_model_path: str, tag: str = None, + include_saver: bool = True, + include_signatures: bool = True) -> 'Graph': + """ + Load the contents of a TensorFlow SavedModel into a Graph object. + + Args: + saved_model_path: Path to the SavedModel's directory on disk + tag: User-specified tag attached to the MetaGraphDef that should be + loaded from the SavedModel. If None, verify that there is only one + MetaGraphDef in the model and load that one. + include_saver: If True, attach black-box information about the SavedModel's + serialized `tf.Saver` object in the returned Graph object. Otherwise the + returned Graph will not contain any serialized variable values, though + it will contain variable initializers. + include_signatures: If True, attach signature information from the + SavedModel to the returned Graph object. Otherwise the returned graph + will have no signatures. + + Returns: In-memory representation of the contents of the SavedModel as a + Graph object. 
+ """ + if not os.path.exists(saved_model_path): + raise ValueError("SavedModel root directory {} not found".format( + saved_model_path)) + if not os.path.isdir(saved_model_path): + raise ValueError("SavedModel root path {} is not a directory".format( + saved_model_path)) + + # By convention, the main protobuf for the SavedModel is in a file called + # "saved_model.pb" + protobuf_file = saved_model_path + "/saved_model.pb" + saved_model = saved_model_pb2.SavedModel() + with open(protobuf_file, "rb") as f: + saved_model.ParseFromString(f.read()) + + # Drill down to pull out the appropriate MetaGraphDef proto + if tag is None: + if len(saved_model.meta_graphs) != 1: + raise ValueError("No tags specified and there are multiple " + "MetaGraphDefs in the SavedModel. Please specify a " + "tag to select a specific MetaGraphDef") + meta_graph = saved_model.meta_graphs[0] + else: + matching_ixs = [ + i for i in range(len(saved_model.meta_graphs)) + if tag in saved_model.meta_graphs[i].meta_info_def.tags + ] + if len(matching_ixs) == 0: + raise ValueError("No MetaGraphDef in SavedModel at {} contains tag " + "'{}'".format(saved_model_path, tag)) + if len(matching_ixs) > 1: + raise ValueError("{} different MetaGraphDef in SavedModel at {} " + "contain tag '{}'. 
Please specify a tag that " + "uniquely identifies a MetaGraphDef" + "".format(len(matching_ixs), saved_model_path, tag)) + meta_graph = saved_model.meta_graphs[matching_ixs[0]] + + # Decompose the MetaGraphDef into the serialized components of the graph + graph_def = meta_graph.graph_def + collections = [] + for collection_name in meta_graph.collection_def: + collection_proto = tf.MetaGraphDef.CollectionDefEntry() + collection_proto.key = collection_name + collection_proto.value.CopyFrom( + meta_graph.collection_def[collection_name]) + if include_saver and meta_graph.HasField("saver_def"): + saver_info = SaverInfo(_vars_dir_for_saved_model(saved_model_path), + meta_graph.saver_def) + else: + saver_info = None + signature_info = SignatureInfo() + if include_signatures: + for key in meta_graph.signature_def: + signature_info.add_signature_def(key, meta_graph.signature_def[key]) + + return Graph(graph_def, + name=meta_graph.meta_info_def.meta_graph_version, + collections=collections, + saver_info=saver_info, + signature_info=signature_info) ################################################################################ # Stuff below this line is private to this file. @@ -733,14 +1132,11 @@ def _make_collection_defs(tf_g: tf.Graph) -> Iterable[ """ Convenience function to serialize all the collections in a TensorFlow graph. - **NOTE:** Currently this function only captures collections of variables. - Args: tf_g: TensorFlow graph from which to harvest collections Returns a list of `tf.MetaGraphDef.CollectionDefEntry` protobuf containing - the serialized - contents of the collections. + the serialized contents of the collections. 
""" ret = [] for collection_name in tf_g.collections: @@ -798,7 +1194,6 @@ def _make_collection_defs(tf_g: tf.Graph) -> Iterable[ raise NotImplementedError("Unexpected collection type {}".format( collection_type)) - ret.append(collection_proto) return ret @@ -827,4 +1222,35 @@ def _decode_tensor_name(tensor_name: str, error_msg: str): node_name = tensor_name output_ix = 0 - return node_name, output_ix \ No newline at end of file + return node_name, output_ix + + +def _duplicate_collection_error_str(name: str, + passthrough_collection_names: Set[str], + variable_collection_names: Set[str], + node_collection_names: Set[str]): + """ + Generate an error string for the case where a collection ends up being of + multiple types simultaneously. + """ + types = [] + if name in passthrough_collection_names: + types.append("passthrough") + if name in variable_collection_names: + types.append("variable") + if name in node_collection_names: + types.append("node") + return ( + "Collection name '{}' maps to multiple collection types: " + "{}".format(name, types)) + + +def _vars_dir_for_saved_model(saved_model_path: str) -> str: + """ + Args: + saved_model_path: Root directory of a SavedModel on disk + + Returns the location of the directory where the indicated SavedModel will + store its variables checkpoint. + """ + return saved_model_path + "/variables" diff --git a/graph_def_editor/node.py b/graph_def_editor/node.py index 3b2f8e3..398958d 100644 --- a/graph_def_editor/node.py +++ b/graph_def_editor/node.py @@ -19,7 +19,7 @@ from __future__ import print_function import tensorflow as tf -from typing import Tuple, List, Iterable, Any +from typing import Tuple, List, Iterable, Any, AbstractSet, Sized from graph_def_editor import graph, tensor, util @@ -31,6 +31,11 @@ # group names. _COLOCATION_PREFIX = "loc:@" +# Magical attribute name that TensorFLow uses to store shape information. 
+# Note that TensorFlow will treat the value of this field as truth and will +# skip shape inference if it is present. +_OUTPUT_SHAPES_ATTR_NAME = "_output_shapes" + __all__ = [ "Node", ] @@ -62,11 +67,12 @@ def __init__(self, g: 'graph.Graph', node_id: int, name: str, op_name: str, self._name = name self._op_name = op_name self._device = device - self._attributes = [] # List[Tuple[str,Any]] + self._attributes = [] # List[Tuple[str,Any]] self._inputs = [] - self._outputs = [] + self._outputs = None # List[Tensor] self._control_inputs = [] - self._colocation_groups = [] + self._colocation_groups = [] # List[str] + self._collection_names = set() # Set[str] def __repr__(self): return "Node[{}]".format(self.name) @@ -87,6 +93,17 @@ def op_type(self) -> str: """ return self._op_name + def change_op_type(self, new_op_type: str): + """ + Change the op type of this node. Does NOT rerun shape or type inference. + + Args: + new_op_type: New string value for the operator type. Should correspond + to the name of a TensorFlow op, although this method does not validate + the string. + """ + self._op_name = new_op_type + @property def graph(self) -> 'graph.Graph': """ @@ -111,6 +128,8 @@ def outputs(self): current outputs of this node. Note that this tuple does not change if the underlying node is mutable and gets edited. 
""" + if self._outputs is None: + raise ValueError("Outputs have not been set") return tuple(self._outputs) def output(self, index: int): @@ -120,6 +139,8 @@ def output(self, index: int): Returns: The Tensor corresponding to the indicated output of the node """ + if self._outputs is None: + raise ValueError("Outputs have not been set") return self._outputs[index] @property @@ -322,11 +343,13 @@ def add_colocation_group(self, head_node_name: str, validate: bool = True): head_node_name)) self._colocation_groups.append(head_node_name) - def to_node_def(self, target: tf.NodeDef = None): + def to_node_def(self, target: tf.NodeDef = None, add_shapes: bool = True): """ Args: target: optional preallocated, empty NodeDef object to fill in. If not provided, this method will allocate a new `tf.NodeDef` object. + add_shapes: If True, add the special "_output_shapes" attribute with + output shape information from this Node's output metadata. Returns: A copy of the contents of this node as a NodeDef proto. The returned proto will *not* change if this node is changed after the call, and @@ -350,9 +373,14 @@ def to_node_def(self, target: tf.NodeDef = None): # colocation_groups property for more information. transformed_names = [_COLOCATION_PREFIX + name for name in self._colocation_groups] - target.attr["_class"].CopyFrom( + target.attr[_COLOCATION_ATTR_NAME].CopyFrom( util.python_type_to_attr_value(transformed_names) ) + if add_shapes and self._outputs is not None and len(self._outputs) > 0: + shapes_list = [t.shape for t in self._outputs] + target.attr[_OUTPUT_SHAPES_ATTR_NAME].CopyFrom( + util.python_type_to_attr_value(shapes_list) + ) return target def get_attr(self, key: str) -> Any: @@ -400,6 +428,10 @@ def add_attr(self, key: str, value: Any, will redirect to a call to the setter for the `colocation_groups` property. + If you use this method to set the special "_output_shapes" attribute, + the method will redirect to a call to the set_outputs_from_pairs() + method. 
+ Args: key: Name of the attribute. Must be unique. value: Value to put in place for the attribute. Can be a Python type or a @@ -415,43 +447,37 @@ def add_attr(self, key: str, value: Any, raise ValueError("Tried to set special '{}' attribute when the " "Node already has colocation " "groups".format(_COLOCATION_ATTR_NAME)) - for group_name in self._validate_colocation_group_attr(value): + for group_name in _validate_colocation_group_attr(value): self.add_colocation_group(group_name, validate=validate_colocation_groups) + elif key == _OUTPUT_SHAPES_ATTR_NAME: + # Special magic key name for output shapes. + new_shapes = _validate_output_shapes_attr(value) + self._update_shapes(new_shapes) elif key in self._attr_names(): raise ValueError("Already have an attribute called '{}'".format(key)) else: self._attributes.append((key, value)) - @staticmethod - def _validate_colocation_group_attr(value: str) -> List[str]: - """Validate a potential value for the special "_class" attribute that - holds collocation groups. - - Returns a list of node names that comprise the group.""" - if isinstance(value, tf.AttrValue): - # Internal TF type; convert to iterable of Python strings - if value.list.s is None: - raise ValueError("Tried to set special '{}' attribute using " - "tf.AttrValue object, and the object's 'list.s' " - "attribute was not populated. Value: " - "'{}'".format(_COLOCATION_ATTR_NAME, str(value))) - value = [tf.compat.as_str(s_i) for s_i in value.list.s] - elif not isinstance(value, list) and not isinstance(value, tuple): - raise ValueError("Tried to set special '{}' attribute with a type " - "other than list or tuple. 
Type is '{}' and value " - "is '{}'".format(_COLOCATION_ATTR_NAME, type(value), - str(value))) - ret = [] - for elem in value: - if not elem.startswith(_COLOCATION_PREFIX): - raise ValueError("Tried to set special '{}' attribute with " - "something other than a string starting with " - "'{}' (value used: " - "'{}')".format(_COLOCATION_ATTR_NAME, - _COLOCATION_PREFIX, elem)) - ret.append(elem[len(_COLOCATION_PREFIX):]) - return ret + def _update_shapes(self, new_shapes: List[tf.TensorShape]): + """ + Put a set of output shapes in place without changing dtypes. Raises an + error if doing so would change the number of outputs. Sets dtypes to None + if no output information is present. + """ + if self._outputs is None: + pairs = [(None, s) for s in new_shapes] + self.set_outputs_from_pairs(pairs) + else: + if len(new_shapes) != len(self._outputs): + raise ValueError("Attempted to put in place {} output shapes, " + "but node has {} outputs.".format(len(new_shapes), + len(self._outputs))) + # Update shapes in place to avoid creating new Tensor objects. + # If we created new Tensor objects, we would need to update all the + # downstream ops that used those Tensors as inputs. + for i in range(len(new_shapes)): + self._outputs[i].shape = new_shapes[i] def replace_attr(self, key: str, value: Any, validate_colocation_groups: bool = False): @@ -481,6 +507,10 @@ def replace_attr(self, key: str, value: Any, - for group_name in self._validate_colocation_group_attr(value): + for group_name in _validate_colocation_group_attr(value): self.add_colocation_group(group_name, validate=validate_colocation_groups) + elif key == _OUTPUT_SHAPES_ATTR_NAME: + # Special magic key name for output shapes. 
+ new_shapes = _validate_output_shapes_attr(value) + self._update_shapes(new_shapes) elif key not in self._attr_names(): raise ValueError("No attribute called '{}'".format(key)) else: @@ -509,8 +539,8 @@ def set_control_inputs(self, new_control_inputs: Iterable['Node']): self._control_inputs = list(new_control_inputs) def set_outputs_from_pairs(self, - new_outputs: Iterable[Tuple[tf.DType, - tf.TensorShape]]): + new_outputs: List[Tuple[tf.DType, + tf.TensorShape]]): """ Set all outputs at once, removing anything that was there previously. @@ -519,13 +549,26 @@ def set_outputs_from_pairs(self, inference to infer the number, type, and shape of the node's outputs. Args: - new_outputs: Iterable of (dtype, shape) pairs that describe the outputs - """ - self._outputs = [] - i = 0 - for (dtype, shape) in new_outputs: - self._outputs.append(tensor.Tensor(self, i, dtype, shape)) - i += 1 + new_outputs: List of (dtype, shape) pairs that describe the outputs + """ + if self._outputs is not None and len(new_outputs) != len(self._outputs): + # TODO(frreiss): Implement changing the number of outputs. This + # implementation will require walking the graph and dealing with pointers + # to Tensors that don't exist. + raise NotImplementedError("Attempted to change number of output tensors " + "on node {} from {} to {}. Changing the " + "number of output tensors is not currently " + "supported.".format(self.name, + len(self._outputs), + len(new_outputs))) + elif self._outputs is None: + self._outputs = [tensor.Tensor(self, i, None, None) + for i in range(len(new_outputs))] + + # At this point, self._outputs is initialized. Update dtypes and shapes + # in place. 
+ for i in range(len(new_outputs)): + self._outputs[i].dtype, self._outputs[i].shape = new_outputs[i] self._graph.increment_version_counter() # Just in case def infer_outputs(self): @@ -544,23 +587,33 @@ def infer_outputs(self): Raises: TBD """ - # TF lack a supported API for invoking shape inference directly, - # so we instantiate a dummy graph and create a dummy Operation object - temp_graph = tf.Graph() - with temp_graph.as_default(): - input_placeholders = [tf.placeholder(shape=t.shape, dtype=t.dtype) for - t in self._inputs] - # See the docs for tf.Operation for important notes about the semantics - # of each arg to the following constructor. - dummy_op = tf.Operation(self.to_node_def(), temp_graph, - inputs=input_placeholders) - self.set_outputs_from_pairs([(o.dtype, o.shape) - for o in dummy_op.outputs]) - # set_outputs_from_pairs() increments the version counter, so we don't - # need to. Also, we haven't added edges to the graph until these - # outputs are connected to another node's inputs. - - # TODO(frreiss): If this op has a "T" attribute, set that too. + if self.op_type == "Assign": + # SPECIAL CASE: Assign op takes a reference as input. Don't build up a + # graph and invoke shape inference, because the APIs for references are + # in flux. Instead, just trust the attributes. + # First input is the reference, second is the value to put in place. + # Assign op returns the reference that it just assigned to. + input_ref = self._inputs[0] + self.set_outputs_from_pairs([(input_ref.dtype, input_ref.shape)]) + else: + # Common case: Use shape inference. + # TF lacks a supported API for invoking shape inference directly, + # so we instantiate a dummy graph and create a dummy Operation object. + temp_graph = tf.Graph() + with temp_graph.as_default(): + input_placeholders = [tf.placeholder(shape=t.shape, dtype=t.dtype) for + t in self._inputs] + # See the docs for tf.Operation for important notes about the semantics + # of each arg to the following constructor. 
+ dummy_op = tf.Operation(self.to_node_def(), temp_graph, + inputs=input_placeholders) + self.set_outputs_from_pairs([(o.dtype, o.shape) + for o in dummy_op.outputs]) + # set_outputs_from_pairs() increments the version counter, so we don't + # need to. Also, we haven't added edges to the graph until these + # outputs are connected to another node's inputs. + + # TODO(frreiss): If this op has a "T" attribute, set that too. def set_inputs_from_strings(self, new_inputs: Iterable[str], set_control_inputs: bool = True): @@ -582,7 +635,25 @@ def set_inputs_from_strings(self, new_inputs: Iterable[str], self._control_inputs = _decode_control_inputs(new_inputs, self._graph) self._graph.increment_version_counter() # New edges added to graph + @property + def collection_names(self) -> AbstractSet[str]: + """ + Returns the names of all collections this node is a member of in the + parent graph. + """ + return frozenset(self._collection_names) + def add_to_collection(self, collection_name: str): + """ + Add the node to the indicated collection. + """ + if collection_name in self._collection_names: + raise ValueError("Node '{}' already in collection '{}'".format( + self.name, collection_name)) + self._collection_names.add(collection_name) + # Invalidate any information the parent graph may have cached about + # collections. + self._graph.increment_version_counter() ################################################################################ @@ -660,3 +731,51 @@ def _decode_control_inputs(inputs: Iterable[str], g: 'graph.Graph') -> List[ return [g[name] for name in control_input_names] +def _validate_colocation_group_attr(value: Any) -> List[str]: + """Validate a potential value for the special "_class" attribute that + holds collocation groups. 
+ + Returns a list of node names that comprise the group.""" + if isinstance(value, tf.AttrValue): + # Internal TF type; convert to iterable of Python strings + if value.list.s is None: + raise ValueError("Tried to set special '{}' attribute using " + "tf.AttrValue object, and the object's 'list.s' " + "attribute was not populated. Value: " + "'{}'".format(_COLOCATION_ATTR_NAME, str(value))) + value = [tf.compat.as_str(s_i) for s_i in value.list.s] + elif not isinstance(value, list) and not isinstance(value, tuple): + raise ValueError("Tried to set special '{}' attribute with a type " + "other than list or tuple. Type is '{}' and value " + "is '{}'".format(_COLOCATION_ATTR_NAME, type(value), + str(value))) + ret = [] + for elem in value: + if not elem.startswith(_COLOCATION_PREFIX): + raise ValueError("Tried to set special '{}' attribute with " + "something other than a string starting with " + "'{}' (value used: " + "'{}')".format(_COLOCATION_ATTR_NAME, + _COLOCATION_PREFIX, elem)) + ret.append(elem[len(_COLOCATION_PREFIX):]) + return ret + + +def _validate_output_shapes_attr(value: Any) -> List[tf.TensorShape]: + """ + Validate a potential value for the special "_output_shapes" attribute. + + Returns a list of output shapes extracted from the attribute value. + """ + if isinstance(value, tf.AttrValue): + if value.list.shape is None: + raise ValueError("Tried to set special '{}' attribute using " + "tf.AttrValue object, and the object's 'list.shape' " + "attribute was not populated. Value: " + "'{}'".format(_OUTPUT_SHAPES_ATTR_NAME, str(value))) + return [tf.TensorShape(shape_i) for shape_i in value.list.shape] + else: + raise ValueError("Tried to set special '{}' attribute with a type " + "other than a tf.AttrValue. 
Type is '{}' and value " + "is '{}'".format(_OUTPUT_SHAPES_ATTR_NAME, type(value), + str(value))) diff --git a/graph_def_editor/rewrite.py b/graph_def_editor/rewrite.py index 039994a..dac4ebc 100644 --- a/graph_def_editor/rewrite.py +++ b/graph_def_editor/rewrite.py @@ -77,3 +77,4 @@ def change_batch_size(g: graph.Graph, + diff --git a/graph_def_editor/tensor.py b/graph_def_editor/tensor.py index 749a2df..3679ad6 100644 --- a/graph_def_editor/tensor.py +++ b/graph_def_editor/tensor.py @@ -70,10 +70,18 @@ def value_index(self): def dtype(self) -> tf.DType: return self._dtype + @dtype.setter + def dtype(self, value: tf.DType): + self._dtype = value + @property - def shape(self): + def shape(self) -> tf.TensorShape: return self._shape + @shape.setter + def shape(self, value: tf.TensorShape): + self._shape = value + @property def graph(self): """Returns the `gde.Graph` object representing the graph in which the diff --git a/graph_def_editor/transform.py b/graph_def_editor/transform.py index 5ad65e2..15a2488 100644 --- a/graph_def_editor/transform.py +++ b/graph_def_editor/transform.py @@ -338,7 +338,7 @@ def __init__(self, sgv, dst_graph: Graph, dst_scope, src_scope): self.scope_ = dst_scope self.transformed_ops = {} self.transformed_ts = {} - self.collections = dict((key, self.graph.get_collection(key)) + self.collections = dict((key, self.graph.get_collection_by_name(key)) for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler diff --git a/graph_def_editor/util.py b/graph_def_editor/util.py index 4b4d25f..c5931b9 100644 --- a/graph_def_editor/util.py +++ b/graph_def_editor/util.py @@ -578,9 +578,9 @@ def _python_type_to_attr_list_elem(list_value: tf.AttrValue.ListValue, elif isinstance(elem, tf.DType): list_value.type.append(elem.as_datatype_enum) elif isinstance(elem, tf.TensorShape): - list_value.shape.append(elem.as_proto()) + list_value.shape.add().CopyFrom(elem.as_proto()) elif 
isinstance(elem, np.ndarray): - list_value.tensor.append(tf.make_tensor_proto(values=elem)) + list_value.tensor.add().CopyFrom(tf.make_tensor_proto(values=elem)) # TODO(frreiss): Populate the "func" field of the union here else: raise ValueError("Don't know how to convert a {} to " @@ -681,7 +681,7 @@ def load_variables_to_tf_graph(g: 'graph.Graph'): should be loaded """ for var_name in g.variable_names: - var = g.name_to_variable(var_name) + var = g.get_variable_by_name(var_name) tf_var = tf.Variable.from_proto(var.to_proto()) tf.add_to_collections(var.collection_names, tf_var) diff --git a/graph_def_editor/variable.py b/graph_def_editor/variable.py index ef060a2..7f1b746 100644 --- a/graph_def_editor/variable.py +++ b/graph_def_editor/variable.py @@ -50,8 +50,13 @@ def __init__(self, g: 'graph.Graph'): Args: g: gde.Graph object representing the containing graph """ + if g.has_passthrough_saver: + # The internals of a tf.Saver are opaque to us. + raise ValueError("Attempted to add a variable to Graph '{}', which has " + "an immutable serialized tf.Saver " + "object.".format(g.name)) self._graph = g - self._collection_names = set() + self._collection_names = set() # Set[str] # Core fields are modeled after those of VariableDef. self._variable_name = None # str @@ -122,6 +127,21 @@ def from_proto(self, variable_def: Union[variable_pb2.VariableDef, bytes], if validate: self.validate(allow_duplicates) + def to_proto(self): + """ + Inverse of `from_proto()` method. + + Returns a `VariableDef` protocol buffer message that represents this + variable. 
+ """ + ret = variable_pb2.VariableDef() + ret.variable_name = self._variable_name + ret.initial_value_name = self._initial_value_name + ret.initializer_name = self._initializer_name + ret.snapshot_name = self._snapshot_name + ret.trainable = self._trainable + return ret + def validate(self, allow_duplicate: bool = False): """ Verify that all the names this variable references are valid in the @@ -132,7 +152,7 @@ def validate(self, allow_duplicate: bool = False): same name, provided that the two variables are equal. """ if self._variable_name in self.graph.variable_names: - other_var = self.graph.name_to_variable(self._variable_name) + other_var = self.graph.get_variable_by_name(self._variable_name) if other_var is not self: if not self.is_same_variable(other_var): raise ValueError("Existing '{}' in graph conflicts with this one " diff --git a/setup.py b/setup.py index d7742dd..7bf9392 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ long_description=long_description, long_description_content_type='text/markdown', packages=find_packages(), - install_requires=['numpy<=1.14.5', 'tensorflow'], + install_requires=['numpy<=1.14.5', 'tensorflow', 'six'], include_package_data=True, zip_safe=False, classifiers=[ diff --git a/tests/graph_test.py b/tests/graph_test.py new file mode 100644 index 0000000..4e3b690 --- /dev/null +++ b/tests/graph_test.py @@ -0,0 +1,125 @@ +# Copyright 2019 IBM. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +""" +Tests for graph.py in the GraphDef Editor +""" + +import unittest +import tensorflow as tf +import numpy as np +import shutil +import tempfile + + +import graph_def_editor as gde + + +class GraphTest(unittest.TestCase): + + def setUp(self): + # Create a temporary directory for SavedModel files. + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + # Remove the directory after the test. + # Comment out this line to prevent deleting temps. + shutil.rmtree(self.temp_dir) + pass # In case previous line gets commented out + + def test_import_saved_model(self): + tf_g = tf.Graph() + with tf_g.as_default(): + input_tensor = tf.placeholder(dtype=tf.int32, shape=[], + name="Input") + result_tensor = input_tensor + 42 + + model_dir = self.temp_dir + "/saved_model" + with tf.Session() as sess: + tf.saved_model.simple_save(sess, model_dir, + inputs={"in": input_tensor}, + outputs={"out": result_tensor}) + + g = gde.saved_model_to_graph(model_dir) + with g.to_tf_graph().as_default(): + with tf.Session() as sess: + result = sess.run(result_tensor.name, {input_tensor.name: 1}) + self.assertEqual(result, 43) + + def test_export_saved_model_no_vars(self): + """Generate a graph in memory with no variables and export as SavedModel + (with empty checkpoint)""" + tf_g = tf.Graph() + with tf_g.as_default(): + input_tensor = tf.placeholder(dtype=tf.int32, shape=[], + name="Input") + result_tensor = input_tensor + 42 + g = gde.Graph(tf_g) + model_dir = self.temp_dir + "/saved_model" + g.to_saved_model(model_dir) + + # Load the model we just saved and do a test run + after_tf_g = tf.Graph() + with after_tf_g.as_default(): + with tf.Session() as sess: + tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], + model_dir) + result = sess.run(result_tensor.name, {input_tensor.name: 1}) + self.assertEqual(result, 43) + + def test_export_saved_model_with_var(self): + """Import a 
SavedModel with a variable, modify the resulting graph, + and write it out as a second SavedModel""" + tf_g = tf.Graph() + with tf_g.as_default(): + input_tensor = tf.placeholder(dtype=tf.int32, shape=[], + name="Input") + var_tensor = tf.Variable(initial_value=42, name="FortyTwo") + result_tensor = input_tensor + var_tensor + + with tf.Session() as sess: + sess.run(var_tensor.initializer) + model_dir = self.temp_dir + "/saved_model" + tf.saved_model.simple_save(sess, model_dir, + inputs={"in": input_tensor}, + outputs={"out": result_tensor}) + + g = gde.saved_model_to_graph(model_dir) + + # Verify that the import went ok + with g.to_tf_graph().as_default(): + with tf.Session() as sess: + sess.run(var_tensor.initializer.name) + result = sess.run(result_tensor.name, {input_tensor.name: 1}) + self.assertEqual(result, 43) + + # Now rewrite plus to minus. + result_op = g.get_node_by_name(result_tensor.op.name) + result_op.change_op_type("Sub") + + second_model_dir = self.temp_dir + "/saved_model_after" + g.to_saved_model(second_model_dir) + + after_tf_g = tf.Graph() + with after_tf_g.as_default(): + with tf.Session() as sess: + tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], + second_model_dir) + result = sess.run(result_tensor.name, {input_tensor.name: 1}) + self.assertEqual(result, -41) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/rewrite_test.py b/tests/rewrite_test.py index 4e2b86e..e44dc2e 100644 --- a/tests/rewrite_test.py +++ b/tests/rewrite_test.py @@ -17,6 +17,8 @@ Tests for rewrite.py in the GraphDef Editor """ +import shutil +import tempfile import unittest import tensorflow as tf import numpy as np @@ -30,7 +32,7 @@ def test_change_batch_size(self): """Basic test for gde.rewrite.change_batch_size.""" tf_g = tf.Graph() with tf_g.as_default(): - input_tensor = tf.placeholder(dtype=tf.int32, shape=[32,1], + input_tensor = tf.placeholder(dtype=tf.int32, shape=[32, 1], name="Input") result_tensor = input_tensor + 42 g = 
gde.Graph(tf_g) @@ -51,7 +53,7 @@ def test_change_batch_size_variable_size(self): """ tf_g = tf.Graph() with tf_g.as_default(): - input_tensor = tf.placeholder(dtype=tf.float32, shape=[32,1], + input_tensor = tf.placeholder(dtype=tf.float32, shape=[32, 1], name="Input") result_tensor = input_tensor + 42.0 g = gde.Graph(tf_g) @@ -67,7 +69,50 @@ def test_change_batch_size_variable_size(self): np.array([42.]).reshape([1, 1]))) result = sess.run(result_tensor.name, {input_tensor.name: - np.array([0, 1]).reshape([2, 1])}) + np.array([0, 1]).reshape([2, 1])}) self.assertTrue(np.array_equal(result, np.array([42., 43.]).reshape([2, 1]))) + def test_change_batch_size_saved_model(self): + """ + Verifies that changes of batch size survive serializing the graph as a + SavedModel + """ + temp_dir = tempfile.mkdtemp() + try: + tf_g = tf.Graph() + with tf_g.as_default(): + input_tensor = tf.placeholder(dtype=tf.float32, shape=[32, 1], + name="Input") + result_tensor = input_tensor + 42.0 + with tf.Session() as sess: + tf.saved_model.simple_save(sess, temp_dir + "/model_before", + inputs={"in": input_tensor}, + outputs={"out": result_tensor}) + + # Make sure the original SavedModel loads properly + with tf.Session(graph=tf.Graph()) as sess: + tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], + temp_dir + "/model_before") + + g = gde.saved_model_to_graph(temp_dir + "/model_before") + gde.rewrite.change_batch_size(g, None, [g[input_tensor.name]]) + g.to_saved_model(temp_dir + "/model_after") + + with tf.Session(graph=tf.Graph()) as sess: + tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], + temp_dir + "/model_after") + result = sess.run(result_tensor.name, + {input_tensor.name: + np.array([0]).reshape([1, 1])}) + self.assertTrue(np.array_equal(result, + np.array([42.]).reshape([1, 1]))) + result = sess.run(result_tensor.name, + {input_tensor.name: + np.array([0, 1]).reshape([2, 1])}) + self.assertTrue(np.array_equal(result, + np.array([42., 
43.]).reshape([2, 1]))) + finally: + # Remove temp dir unconditionally. Comment out try and finally if you + # want the directory to stick around after a test failure. + shutil.rmtree(temp_dir)