Skip to content

Commit

Permalink
addressed feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vlasenkoalexey committed Aug 6, 2021
1 parent 90b77a7 commit a2993b5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
7 changes: 6 additions & 1 deletion graph_def_editor/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __init__(
name = None, # type: str
):
"""
Wrap a tf.GraphDef protocol buffer in a Graph object.
Constructor to be called by subclasses only.
Initializes attributes of this base class.
Args:
name: Optional human-readable name for the graph. If not provided,
Expand Down Expand Up @@ -137,6 +139,8 @@ def add_node(self,
uniquify_name: Generate a unique name from this name if the graph
already has a node with the indicated name. If False, raise an
exception if the name is in use.
debug_info: Some internal TensorFlow debug information.
We just pass it through for safety.
Returns:
`MutableNode` wrapper for the new node.
Expand Down Expand Up @@ -265,6 +269,7 @@ def add_variable(self, name):
v = variable.Variable(self)
v.name = name
self._variable_name_to_variable[name] = v
self.increment_version_counter()
return v

def add_variable_from_variable_def(self, variable_def,
Expand Down
8 changes: 1 addition & 7 deletions graph_def_editor/function_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@


class FunctionGraph(base_graph.BaseGraph):
"""Wrapper class for TensorFlow function graphs.
Summary of internal data structures:
* _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name.
* _version: Counter that increments every time the graph is modified
"""
"""Wrapper class for TensorFlow function graphs."""

def __init__(
self,
Expand All @@ -78,7 +73,6 @@ def __init__(
[(dtype, shape) for (dtype, shape, _) in tuples]

# Populate fields of object
self._node_name_to_node = {} # Dict[str, node.Node]; key is node name
self._node_to_frame_names = None
self._frame_name_to_nodes = None
self._head_name_to_coloc_group = None # Dict[str, FrozenList[str]]
Expand Down
13 changes: 10 additions & 3 deletions graph_def_editor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,11 @@ def to_graph_def(self, add_shapes=True):
for op in self.nodes:
op.to_node_def(ret.node.add(), add_shapes)

# Pass through library without modifications for now.
# Copy library as is.
if self._graph_def and self._graph_def.library:
ret.library.CopyFrom(self._graph_def.library)

# Update functions in library that were instantiated as function graphs.
for f_name, f_graph in self._function_graphs.items():
function_index_to_update = None
for index in range(0, len(ret.library.function)):
Expand Down Expand Up @@ -716,6 +717,9 @@ def breadth_first_visitor(
"""
Visit all nodes reachable from a starting set in the order of a
breadth-first traversal (going from node to output edges).
If visitor gets to a function call, and iterate_functions is True,
it will iterate all function nodes first and then continue with
remaining nodes in the graph.
Invokes a callback at each node visited.
Args:
Expand All @@ -731,7 +735,7 @@ def breadth_first_visitor(
that we should also iterate through all function callers up
the stack.
Returns:
True if iteration was iterruputed by visitor, otherwise False.
True if iteration was interrupted by visitor, otherwise False.
"""
if starting_nodes is None:
# Start with all of the nodes in the graph that have no inputs.
Expand Down Expand Up @@ -818,6 +822,9 @@ def backwards_breadth_first_visitor(
"""
Visit all nodes reachable from a starting set in the order of a
backwards breadth-first traversal (going from node to input edges).
If visitor gets to a function call, and iterate_functions is True,
it will iterate all function nodes first and then continue with
remaining nodes in the graph.
Invokes a callback at each node visited.
Args:
Expand All @@ -830,7 +837,7 @@ def backwards_breadth_first_visitor(
that we should also iterate through all function callers up
the stack.
Returns:
True if iteration was iterruputed by visitor, otherwise False.
True if iteration was interrupted by visitor, otherwise False.
"""
if not starting_nodes:
raise ValueError("starting_nodes is not provided")
Expand Down
4 changes: 2 additions & 2 deletions graph_def_editor/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
if sys.version >= '3':
from typing import Tuple, List, Iterable, Any, AbstractSet, Type

from graph_def_editor import tensor, util
from graph_def_editor import base_graph, tensor, util

# Magical attribute name that TensorFlow uses to store colocation groups.
# See colocation_groups property below for more information.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self,
device: TensorFlow device specification string indicating where this node
should be located. Default value of "" means "use the default device"
"""
_type_check(g, graph.Graph, "g")
_type_check(g, base_graph.BaseGraph, "g")
_type_check(node_id, int, "node_id")
_type_check(name, str, "name")
_type_check(op_name, str, "op_name")
Expand Down

0 comments on commit a2993b5

Please sign in to comment.