diff --git a/graph_def_editor/base_graph.py b/graph_def_editor/base_graph.py index 4f295e5..f06b134 100644 --- a/graph_def_editor/base_graph.py +++ b/graph_def_editor/base_graph.py @@ -53,7 +53,9 @@ def __init__( name = None, # type: str ): """ - Wrap a tf.GraphDef protocol buffer in a Graph object. + Constructor to be called by subclasses only. + + Initializes attributes of this base class. Args: name: Optional human-readable name for the graph. If not provided, @@ -137,6 +139,8 @@ def add_node(self, uniquify_name: Generate a unique name from this name if the graph already has a node with the indicated name. If False, raise an exception if the name is in use. + debug_info: Some internal TensorFlow debug information. + We just pass it through for safety. Returns: `MutableNode` wrapper for the new node. @@ -265,6 +269,7 @@ def add_variable(self, name): v = variable.Variable(self) v.name = name self._variable_name_to_variable[name] = v + self.increment_version_counter() return v def add_variable_from_variable_def(self, variable_def, diff --git a/graph_def_editor/function_graph.py b/graph_def_editor/function_graph.py index afc0736..dc412d0 100644 --- a/graph_def_editor/function_graph.py +++ b/graph_def_editor/function_graph.py @@ -47,12 +47,7 @@ class FunctionGraph(base_graph.BaseGraph): - """Wrapper class for TensorFlow function graphs. - - Summary of internal data structures: - * _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name. - * _version: Counter that increments every time the graph is modified - """ + """Wrapper class for TensorFlow function graphs.""" def __init__( self, @@ -78,7 +73,6 @@ def __init__( [(dtype, shape) for (dtype, shape, _) in tuples] # Populate fields of object - self._node_name_to_node = {} # Dict[str, node.Node]; key is node name self._node_to_frame_names = None self._frame_name_to_nodes = None self._head_name_to_coloc_group = None # Dict[str, FrozenList[str]] diff --git a/graph_def_editor/graph.py b/graph_def_editor/graph.py index 54d580d..bf2d004 100644 --- a/graph_def_editor/graph.py +++ b/graph_def_editor/graph.py @@ -311,10 +311,11 @@ def to_graph_def(self, add_shapes=True): for op in self.nodes: op.to_node_def(ret.node.add(), add_shapes) - # Pass through library without modifications for now. + # Copy library as is. if self._graph_def and self._graph_def.library: ret.library.CopyFrom(self._graph_def.library) + # Update functions in library that were instantiated as function graphs. for f_name, f_graph in self._function_graphs.items(): function_index_to_update = None for index in range(0, len(ret.library.function)): @@ -716,6 +717,9 @@ def breadth_first_visitor( """ Visit all nodes reachable from a starting set in the order of a breadth-first traversal (going from node to output edges). + If visitor gets to a function call, and iterate_functions is True, + it will iterate all function nodes first and then continue with + remaining nodes in the graph. Invokes a callback at each node visited. Args: @@ -731,7 +735,7 @@ def breadth_first_visitor( that we should also iterate through all function callers up the stack. Returns: - True if iteration was iterruputed by visitor, otherwise False. + True if iteration was interrupted by visitor, otherwise False. """ if starting_nodes is None: # Start with all of the nodes in the graph that have no inputs. @@ -818,6 +822,9 @@ def backwards_breadth_first_visitor( """ Visit all nodes reachable from a starting set in the order of a backwards breadth-first traversal (going from node to input edges). + If visitor gets to a function call, and iterate_functions is True, + it will iterate all function nodes first and then continue with + remaining nodes in the graph. Invokes a callback at each node visited. Args: @@ -830,7 +837,7 @@ def backwards_breadth_first_visitor( that we should also iterate through all function callers up the stack. Returns: - True if iteration was iterruputed by visitor, otherwise False. + True if iteration was interrupted by visitor, otherwise False. """ if not starting_nodes: raise ValueError("starting_nodes is not provided") diff --git a/graph_def_editor/node.py b/graph_def_editor/node.py index bb312df..943e893 100644 --- a/graph_def_editor/node.py +++ b/graph_def_editor/node.py @@ -23,7 +23,7 @@ if sys.version >= '3': from typing import Tuple, List, Iterable, Any, AbstractSet, Type -from graph_def_editor import tensor, util +from graph_def_editor import base_graph, tensor, util # Magical attribute name that TensorFlow uses to store colocation groups. # See colocation_groups property below for more information. @@ -87,7 +87,7 @@ def __init__(self, device: TensorFlow device specification string indicating where this node should be located. Default value of "" means "use the default device" """ - _type_check(g, graph.Graph, "g") + _type_check(g, base_graph.BaseGraph, "g") _type_check(node_id, int, "node_id") _type_check(name, str, "name") _type_check(op_name, str, "op_name")