From f74ccda8ec639521e6f164cfb7b7189734f7badc Mon Sep 17 00:00:00 2001 From: Aleksey Vlasenko Date: Thu, 1 Jul 2021 17:13:11 -0700 Subject: [PATCH 1/2] added support for rewriting function graphs in TF2.x models --- graph_def_editor/__init__.py | 1 + graph_def_editor/base_graph.py | 499 +++++++++++++++++ graph_def_editor/function_graph.py | 351 ++++++++++++ graph_def_editor/graph.py | 842 ++++++++++++----------------- graph_def_editor/node.py | 67 ++- graph_def_editor/rewrite.py | 29 +- graph_def_editor/util.py | 76 ++- graph_def_editor/variable.py | 10 +- tests/function_graph_test.py | 151 ++++++ tests/graph_test.py | 349 ++++++++++++ tests/match_test.py | 4 +- tests/select_test.py | 3 +- tests/transform_test.py | 3 + 13 files changed, 1830 insertions(+), 555 deletions(-) create mode 100644 graph_def_editor/base_graph.py create mode 100644 graph_def_editor/function_graph.py create mode 100644 tests/function_graph_test.py diff --git a/graph_def_editor/__init__.py b/graph_def_editor/__init__.py index 3e27f4e..2532725 100644 --- a/graph_def_editor/__init__.py +++ b/graph_def_editor/__init__.py @@ -23,6 +23,7 @@ from __future__ import print_function # pylint: disable=wildcard-import +from graph_def_editor.base_graph import * from graph_def_editor.edit import * from graph_def_editor.graph import * from graph_def_editor.match import * diff --git a/graph_def_editor/base_graph.py b/graph_def_editor/base_graph.py new file mode 100644 index 0000000..4f295e5 --- /dev/null +++ b/graph_def_editor/base_graph.py @@ -0,0 +1,499 @@ +# Copyright 2021 Google. All Rights Reserved. +# Copyright 2019 IBM. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +from distutils import dir_util +import os +from six import string_types +import tensorflow.compat.v1 as tf +import sys +if sys.version >= '3': + from typing import Tuple, Dict, FrozenSet, Iterable, Union, Set, Any + +from graph_def_editor import node, util, tensor, variable + + +__all__ = [ + "BaseGraph", +] + +class BaseGraph(object): + """ + Base class for Graph and FunctionGraph classes. + + Mutable surrogate for a `tf.GraphDef` protocol buffer message. + + Summary of internal data structures: + * _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name. + * _version: Counter that increments every time the graph is modified + * _collections: Map from collection name to collection contents for all + collections + """ + + def __init__( + self, + name = None, # type: str + ): + """ + Wrap a tf.GraphDef protocol buffer in a Graph object. + + Args: + name: Optional human-readable name for the graph. If not provided, + the constructor will generate a name. 
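+
+    Example (illustrative; in practice you would instantiate a concrete
+    subclass such as `Graph` rather than `BaseGraph` itself):
+
+      g = Graph(name="my_graph")
+      n = g.add_node("my_node", "NoOp")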
+ """ + # Populate fields of object + self._name = name # str + self._version = 0 # Must happen first; other init code needs self._version + self._frozen = False # bool + self._next_id = 1 # int + self._node_name_to_node = {} # Dict[str, node.Node]; key is node name + self._variable_name_to_variable = {} # Dict[str, Variable] + + @property + def name(self): + """ + Returns human-readable name for this graph. This name may not be unique + across graphs. + """ + return self._name + + def __getitem__(self, name): + # type: (str) -> Union[tensor.Tensor, 'node.Node'] + """ + Convenience method to retrieve a node or tensor of the graph by name + + Args: + name: Name of the node or tensor to return. Case-sensitive. + + Returns the named item as a `gde.Node` or `gde.Tensor` object. If there + is a conflict between node and tensor names, node names win. + """ + if not isinstance(name, string_types): + raise TypeError("name must be a string; got type {}".format(type(name))) + + if self.contains_node(name): + return self._node_name_to_node[name] + elif self.contains_tensor(name): + return self.get_tensor_by_name(name) + else: + raise ValueError("No node or tensor '{}' found in graph".format(name)) + + def get_node_by_name(self, name): + # type: (str) -> node.Node + """ + Retrieve a node in the graph by name. + + Args: + name: Name of the node. Case-sensitive. + + Returns the indicated node as a `gde.Node` object. + """ + if self.contains_node(name): + return self._node_name_to_node[name] + else: + raise ValueError("No node '{}' found in graph".format(name)) + + def contains_node(self, name): + # type: (str) -> bool + """ + Returns true if the graph has a node by the indicated name. Exact string + match. + """ + if not isinstance(name, string_types): + raise ValueError("Node name argument is not a string, but is of type " + "{}".format(type(name))) + return name in self._node_name_to_node.keys() + + def add_node(self, + name, # type: str + op_name, # type: str + uniquify_name = False, # type: bool + debug_info = None # type: tf.compat.v1.NodeDef.ExperimentalDebugInfo + ): + # type: (...) -> node.Node + """ + Add a new, empty node to the graph. + Args: + name: Name for the new op + op_name: Name of the type of operation for the node + uniquify_name: Generate a unique name from this name if the graph + already has a node with the indicated name. If False, raise an + exception if the name is in use. + + Returns: + `MutableNode` wrapper for the new node. + + Raises: + ValueError if the name is already in use and `uniquify_name` is False + """ + if uniquify_name: + name = self.unique_name(name) + elif self._name_in_use(name): # and not uniquify_name + raise ValueError("Graph already contains a node with name '{}' " + "(Note that this check is case-insensitive)." + .format(name)) + ret = node.Node(self, + self._get_next_id(), + name=name, + op_name=op_name, + debug_info=debug_info) + self._node_name_to_node[name] = ret + self.increment_version_counter() + return ret + + def add_node_from_node_def(self, + node_def, # type: tf.NodeDef + set_inputs = False, # type: bool + set_control_inputs = False # type: bool + ): + # type: (...) -> node.Node + """ + Adds a new node to the graph, populating fields of the node from a + `tf.NodeDef` protocol buffer. + + Equivalent to calling `add_node()`, then populating the relevant fields + of the returned MutableNode object. + + Args: + node_def: Protocol buffer describing parameters of the new node. 
+      set_inputs: If True, populate the node's inputs list from the list of
+        inputs in the `NodeDef`
+      set_control_inputs: Also set control inputs. Must be False if
+        `set_inputs` is False.
+
+    Returns:
+      `MutableNode` wrapper for the new node
+    """
+    if set_control_inputs and not set_inputs:
+      raise ValueError("set_inputs must be True if set_control_inputs is True")
+    ret = self.add_node(name=node_def.name,
+                        op_name=node_def.op,
+                        debug_info=node_def.experimental_debug_info)
+    if set_inputs:
+      ret.set_inputs_from_strings(node_def.input,
+                                  set_control_inputs=set_control_inputs)
+    ret.device = node_def.device
+    ret.clear_attrs()
+    for key in node_def.attr:
+      ret.add_attr(key, node_def.attr[key])
+
+    # Don't need to increment version counter; add_node() already did that.
+    return ret
+
+  def remove_node_by_name(self, name, check_for_refs = True):
+    # type: (str, bool) -> None
+    """
+    Removes the indicated node from this graph and from any collections in
+    this graph.
+
+    The caller is responsible for removing all links to the indicated node
+    prior to making this call.
+
+    Args:
+      name: name of the node to remove
+      check_for_refs: Optional. If True, raise an exception if there are any
+        other nodes in the graph that reference this node. If False, allow
+        removal of nodes with outstanding references to them. In the latter
+        case, the caller is responsible for cleaning up the graph afterwards.
+    """
+    n = self.get_node_by_name(name)
+    if check_for_refs:
+      for t in n.outputs:
+        if len(t.consumers()) > 0:
+          raise ValueError("Removing node '{}' would leave dangling "
+                           "references from nodes {} to tensor '{}'"
+                           "".format(name, [c.name for c in t.consumers()],
+                                     t.name))
+    # noinspection PyProtectedMember
+    n._remove_from_graph()
+    del self._node_name_to_node[name]
+    self.increment_version_counter()
+    # Don't need to update collection info because collection membership is
+    # stored in the node.
+    # Don't need to update consumers of tensors because that information is
+    # calculated dynamically by iterating over nodes.
+
+  def rename_node(self, old_name, new_name):
+    # type: (str, str) -> None
+    """
+    Change the name of a node in the graph.
+
+    Args:
+      old_name: Name of an existing node
+      new_name: New name for the node in question. Must not currently be in
+        use.
+    """
+    if self.contains_node(new_name):
+      raise ValueError("Graph already has a node under name '{}'".format(
+        new_name))
+    n = self.get_node_by_name(old_name)
+    # noinspection PyProtectedMember
+    n._change_name(new_name)
+    del self._node_name_to_node[old_name]
+    self._node_name_to_node[new_name] = n
+    self.increment_version_counter()
+
+  def add_variable(self, name):
+    # type: (str) -> variable.Variable
+    """
+    Adds a new variable to the graph.
+
+    Args:
+      name: Name of the variable. Must not already be in use.
+
+    Returns the `gde.Variable` object corresponding to the added variable.
+    """
+    if name in self._variable_name_to_variable:
+      raise ValueError("Variable name '{}' already in use".format(name))
+    v = variable.Variable(self)
+    v.name = name
+    self._variable_name_to_variable[name] = v
+    return v
+
+  def add_variable_from_variable_def(self, variable_def,
+                                     skip_if_present = False):
+    # type: (Any, bool) -> variable.Variable
+    """
+    Adds a new variable to the graph and populates the fields of the
+    corresponding Variable object according to a protocol buffer message.
+
+    Args:
+      variable_def: `tensorflow.core.framework.variable_pb2.VariableDef`
+        protobuf object. May be serialized as a `bytes` object.
+ skip_if_present: If True, silently skips inserting duplicate variables, + as long as they don't conflict with existing variables. + + Returns the `gde.Variable` object corresponding to the added variable. + """ + v = variable.Variable(self) + v.from_proto(variable_def, allow_duplicates=skip_if_present) + if v.name not in self._variable_name_to_variable: + self._variable_name_to_variable[v.name] = v + return self._variable_name_to_variable[v.name] + + @property + def variable_names(self): + return self._variable_name_to_variable.keys() + + def get_variable_by_name(self, name): + # type: (str) -> variable.Variable + """ + Fetch a variable by its variable name. + + Args: + name: Name of a variable in this graph. + + Returns the variable associated with the name. Raises an exception if + there is no variable with the indicated name. + """ + return self._variable_name_to_variable[name] + + def _name_in_use(self, name): + # type: (str) -> bool + """Check whether a name is in use, using the same collision semantics as + TensorFlow: Exact lowercase string match. + + Args: + name: Name of a potential node in the graph. + + Returns True if the indicated name is currently in use, ignoring case. + """ + return name.lower() in [k.lower() for k in self._node_name_to_node.keys()] + + def unique_name(self, name): + # type: (str) -> str + """Emulate the behavior of the method by the same name in `tf.Graph`. + + Does *not* emulate the `name_stack` field of `tf.Graph`. + + Unlike the original method, this version does *not* keep a separate table + of names currently "in use for the purposes of `unique_name()`", but instead + refers directly to internal data structures to find names that are truly + in use. + + Args: + name: The name for an operation. + + Returns: + A variant of `name` that has been made unique by appending a key to it + in the same way that `tf.Graph.unique_name()` would. + """ + # For the sake of checking for names in use, we treat names as case + # insensitive (e.g. foo = Foo). + if not self._name_in_use(name): + return name + + # Generate a unique version by appending "_1", "_2", etc. until we find + # an unused name. Note that this approach will behave slightly + # differently from the original if nodes are deleted. + i = 1 + new_name = "{}_{}".format(name, i) + while self._name_in_use(new_name): + i = i + 1 + new_name = "{}_{}".format(name, i) + return new_name + + @property + def node_names(self): + # type: () -> Iterable[str] + return self._node_name_to_node.keys() + + @property + def nodes(self): + # type: () -> Tuple[node.Node] + """ + Returns: + A list of all nodes, both immutable and mutable, present in the graph + after the edits that this object is buffering. + """ + return tuple(self._node_name_to_node.values()) + + @property + def tensors(self): + # type: () -> List[tensor.Tensor] + """ + Return a list of all the tensors which are input or output of an op in + the graph. + """ + ts = [] + for op in self.nodes: + ts += op.outputs + return ts + + def contains_tensor(self, tensor_name): + # type: (str) -> bool + """ + Returns true if the graph has a tensor by the indicated name. Exact string + match. + + Args: + tensor_name: TensorFlow-format name ('node name:input num', or 'node + name' as shorthand for 'node name:0') + + Raises ValueError if the tensor name is not properly formatted. 
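+
+    Example (illustrative; "weights" is a hypothetical node name):
+
+      g.contains_tensor("weights:0")  # True if node "weights" has output 0
+      g.contains_tensor("weights")    # shorthand for "weights:0"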
+ """ + error_msg = "Invalid tensor name '{}': {}" + node_name, output_ix = self._decode_tensor_name(tensor_name, error_msg) + if node_name not in self._node_name_to_node: + return False + else: + n = self[node_name] + if output_ix >= len(n.outputs): + return False + else: + return True + + def get_tensor_by_name(self, tensor_name, error_msg = None): + # type: (str, str) -> tensor.Tensor + """ + Retrieve a tensor by human-readable name. + + Args: + tensor_name: TensorFlow-format name ('node name:input num', or 'node + name' as shorthand for 'node name:0') + error_msg: Optional format string for raising errors. Must be able to + serve as an input to `str.format()` with two arguments: tensor name + string and reason for failure. + + Returns: gde.Tensor object corresponding to the indicated tensor. + + Raises ValueError if the name is invalid or references a tensor that does + not exist. + """ + if error_msg is None: + error_msg = "Invalid tensor name '{}': {}" + node_name, output_ix = self._decode_tensor_name(tensor_name, error_msg) + if node_name not in self._node_name_to_node: + raise ValueError(error_msg.format( + tensor_name, "Node name '{}' not found in graph.".format(node_name) + )) + n = self[node_name] + if output_ix >= len(n.outputs): + raise ValueError(error_msg.format( + tensor_name, "Requested output {}, but node '{}' has {} " + "outputs.".format(output_ix, node_name, len(n.outputs)) + )) + return n.output(output_ix) + + @property + def version(self): + # type: () -> int + """ + Returns a counter that goes up every time this graph is changed. + """ + return self._version + + @property + def frozen(self): + # type: () -> bool + """ + True if the graph is configured to raise an exception on any structural + modification. + """ + return self._frozen + + @frozen.setter + def frozen(self, value): + # type: (bool) -> None + self._frozen = value + + def increment_version_counter(self): + """ + Mark the structure of this graph as "changed" and invalidate any cached + information about the edges of the graph. + """ + if self.frozen: + raise RuntimeError("Detected a change to a frozen graph") + self._version += 1 + + def _get_next_id(self): + # type: () -> int + """Generates and returns a unique integer ID *within this graph*.""" + ret = self._next_id + self._next_id = ret + 1 + return ret + + def _decode_tensor_name(self, tensor_name, error_msg): + # type: (str, str) -> Tuple[str, int] + """ + Args: + tensor_name: TensorFlow-format name ('node name:input num', or 'node + name' as shorthand for 'node name:0') + error_msg: Format string for raising errors. Must be able to + serve as an input to `str.format()` with two arguments: tensor name + string and reason for failure. + + Returns: (node name, output index) tuple identifying the tensor + + Raises ValueError if the name is not properly formatted + """ + if ":" in tensor_name: + node_name, output_ix_str = tensor_name.split(":") + if not output_ix_str.isdigit(): + raise ValueError(error_msg.format( + tensor_name, "Invalid output index string '{}'.".format(output_ix_str) + )) + output_ix = int(output_ix_str) + else: + node_name = tensor_name + output_ix = 0 + + return node_name, output_ix + diff --git a/graph_def_editor/function_graph.py b/graph_def_editor/function_graph.py new file mode 100644 index 0000000..afc0736 --- /dev/null +++ b/graph_def_editor/function_graph.py @@ -0,0 +1,351 @@ +# Copyright 2021 Google. All Rights Reserved. +# Copyright 2019 IBM. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Objects for representing function graphs undergoing rewrite operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import Counter
+import datetime
+from distutils import dir_util
+import os
+from six import string_types
+import tensorflow.compat.v1 as tf
+import sys
+if sys.version >= "3":
+  from typing import Tuple, Dict, List, FrozenSet, Iterable, Union, Set, Any
+
+from graph_def_editor import base_graph, node, util, tensor, variable
+
+# TODO: Move this protobuf into this project so we don't depend on
+# tf.core.framework
+from tensorflow.core.framework import function_pb2, op_def_pb2
+from tensorflow.python.framework import function_def_to_graph
+
+
+__all__ = [
+  "FunctionGraph",
+]
+
+# Op type of the dummy placeholder nodes that this class inserts to stand in
+# for a function's input arguments.
+_INPUT_DUMMY_OP_NAME = "__input__"
+
+
+class FunctionGraph(base_graph.BaseGraph):
+  """Wrapper class for TensorFlow function graphs.
+
+  Summary of internal data structures:
+    * _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name.
+    * _version: Counter that increments every time the graph is modified
+  """
+
+  def __init__(
+      self,
+      name=None,  # type: str
+      parent_graph=None  # type: tf.Graph
+  ):
+    """Wrap a function from a parent graph in a FunctionGraph object.
+
+    Args:
+      name: Name of the function to wrap. Must name a function in the
+        function library of `parent_graph`.
+      parent_graph: `tf.Graph` object whose function library contains the
+        function to wrap.
+    """
+    super(FunctionGraph, self).__init__(name)
+    (self._func_graph, self._func_graph_def) = \
+      _get_func_graph_for_name(parent_graph, name)
+    output_map = _decode_graph(name, self._func_graph)
+    output_map_pairs = {}
+    for op_name, tuples in output_map.items():
+      output_map_pairs[op_name] = \
+        [(dtype, shape) for (dtype, shape, _) in tuples]
+
+    # Populate fields of object
+    self._node_name_to_node = {}  # Dict[str, node.Node]; key is node name
+    self._node_to_frame_names = None
+    self._frame_name_to_nodes = None
+    self._head_name_to_coloc_group = None  # Dict[str, FrozenList[str]]
+    self._variable_name_to_variable = {}  # Dict[str, Variable]
+    self._collection_name_to_type = None  # Dict[str, str], generated on demand
+    self._input_nodes = []
+    self._output_nodes = []
+
+    for input_arg in self._func_graph_def.signature.input_arg:
+      self._input_nodes.append(
+          self.add_node(input_arg.name, _INPUT_DUMMY_OP_NAME))
+      self[input_arg.name].set_outputs_from_pairs(
+          output_map_pairs[input_arg.name])
+
+    # Load nodes in three passes because the graph may contain cycles.
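+    # Pass 1 creates the nodes, pass 2 attaches output dtype and shape
+    # information, and pass 3 wires up data and control inputs, which may
+    # refer to nodes that appear later in the proto.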
+ for node_def in self._func_graph_def.node_def: + self.add_node_from_node_def(node_def, set_inputs=False) + for node_def in self._func_graph_def.node_def: + self[node_def.name].set_outputs_from_pairs( + output_map_pairs[node_def.name]) + for node_def in self._func_graph_def.node_def: + try: + self[node_def.name].set_inputs_from_strings( + node_def.input, + set_control_inputs=True, + output_map=output_map) + except Exception as ex: + print("can't set inputs for node: {}; reason: {}".format( + node_def.name, ex)) + + for output_tensor in self._func_graph.outputs: + self._output_nodes.append(self.get_node_by_name(output_tensor.op.name)) + + @property + def input_nodes(self): + return self._input_nodes + + @property + def output_nodes(self): + return self._output_nodes + + def get_func_graph_for_name(self, graph, func_name): + """Returns the FuncGraph associated to the given func_name if possible.""" + outer_graph = graph + while graph is not None: + # pylint: disable=protected-access + func = graph._get_function(str(func_name)) + if func is not None: + if hasattr(func, "graph"): + return func.graph + # `outer_graph` may not be the same as `ops.get_default_graph()` e.g. + # in the case of nested if ops or when the gradient is being computed + # from inside a Defun. We build the `func_graph` with `outer_graph` + # as its outer graph. + with outer_graph.as_default(): + # This is a _DefinedFunction. + func_graph = ( + function_def_to_graph.function_def_to_graph(func.definition)) + if func_graph is not None: + return func_graph + if hasattr(graph, "outer_graph"): + graph = graph.outer_graph + else: + raise ValueError( + "Function {} does not exist in the graph.".format(func_name)) + + def to_function_graph_def(self, add_shapes=True): + # type: (bool) -> function_pb2.FunctionDef + """ + Args: + add_shapes: If True, add the special "_output_shapes" attribute with + output shape information from this Node's output metadata. + + Returns the `function_pb2.FunctionDef` serialization of this function's + graph in its current form. 
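+
+    Example (illustrative; assumes `g` is the parent `Graph` object and its
+    library contains at least one function):
+
+      f_graph = g.get_function_graph_by_name(g.function_names[0])
+      fdef = f_graph.to_function_graph_def()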
+ """ + ret = function_pb2.FunctionDef() + ret.CopyFrom(self._func_graph_def) + # Leave signature as is, but replace all node_defs + del ret.node_def[:] + ret.signature.CopyFrom(self._func_graph_def.signature) + + input_args = [input_arg.name for input_arg in ret.signature.input_arg] + + for op in self.nodes: + if op.op_type == _INPUT_DUMMY_OP_NAME: + continue + + node_def = ret.node_def.add() + op.to_node_def(node_def, add_shapes) + unique_input_counter = Counter() + + for i in range(len(op.inputs)): + (input_tensor_name, global_input_index_str) = ( + op.inputs[i].name.split(":")) + + global_input_index = int(global_input_index_str) + if input_tensor_name in input_args: + # don't add index for function args + node_def.input[i] = input_tensor_name + else: + input_op_output_args, input_op_output_has_number_attr = ( + self._get_op_def_denormalized_outputs(op.inputs[i].op)) + if (len(input_op_output_args) == 1 and + input_op_output_args[0].type_list_attr): + node_def.input[i] = ( + input_tensor_name + ":" + input_op_output_args[0].name + ":" + + str(global_input_index)) + else: + input_name = ( + input_tensor_name + ":" + + input_op_output_args[global_input_index].name) + node_def.input[i] = ( + input_name + ":" + str(unique_input_counter[input_name])) + if input_op_output_has_number_attr: + # only uniquify input args with var length, + # otherwise it should be 0 + unique_input_counter[input_name] += 1 + return ret + + def to_tf_function_graph(self): + # type: () -> tf.Graph + """ + Converts this graph into a new TensorFlow `Graph`. Also takes care of + variables. + Note that function_def_to_graph.function_def_to_graph won't work if + function calls into other functions. + + Returns a fresh `tf.Graph` containing all the nodes and variables that + this object represents. + """ + return function_def_to_graph.function_def_to_graph( + self.to_function_graph_def()) + + def increment_version_counter(self): + """ + Mark the structure of this graph as "changed" and invalidate any cached + information about the edges of the graph. + """ + super(FunctionGraph, self).increment_version_counter() + self._node_to_frame_names = None + self._frame_name_to_nodes = None + self._head_name_to_coloc_group = None + self._collection_name_to_type = None + + def frame_name_to_nodes(self, frame_name): + # type: (str) -> Tuple[node.Node] + """ + Performs the inverse mapping of node_to_frame_name(). + + Args: + frame_name: Name of a control flow frame in the graph + + Returns: + All nodes that are tagged with the indicated frame, either as an + innermost frame or as a containing frame. + """ + if self._node_to_frame_names is None: + self._generate_node_to_frame_name() + return self._frame_name_to_nodes[frame_name] + + def get_frame_names(self): + # type: () -> Tuple[str] + """ + Returns: + Tuple of all the unique names of frames that occur in this graph. 
+    """
+    if self._node_to_frame_names is None:
+      self._generate_node_to_frame_name()
+    return self._frame_name_to_nodes.keys()
+
+  def _get_op_def_denormalized_outputs(self, op):
+    # type: (node.Node) -> Tuple[List[op_def_pb2.OpDef.ArgDef], bool]
+    # pylint: disable=protected-access
+    op_def = self._func_graph._get_op_def(op.op_type)
+    output_args = []
+
+    input_op_output_has_number_attr = False
+    for output_arg in op_def.output_arg:
+      if output_arg.number_attr:
+        length = op.get_attr(output_arg.number_attr)
+        input_op_output_has_number_attr = True
+        for _ in range(length):
+          output_args.append(op_def_pb2.OpDef.ArgDef(name=output_arg.name,
+                                                     type=output_arg.type))
+      else:
+        output_args.append(output_arg)
+
+    return (output_args, input_op_output_has_number_attr)
+
+################################################################################
+# Stuff below this line is private to this file.
+
+
+def _get_func_graph_for_name(graph, func_name):
+  """Returns the FuncGraph and FuncDef associated with the given func_name."""
+  outer_graph = graph
+  while graph is not None:
+    # pylint: disable=protected-access
+    func = graph._get_function(str(func_name))
+    if func is not None:
+      if hasattr(func, "graph"):
+        return (func.graph, func.definition)
+      # `outer_graph` may not be the same as `ops.get_default_graph()` e.g.
+      # in the case of nested if ops or when the gradient is being computed
+      # from inside a Defun. We build the `func_graph` with `outer_graph` as its
+      # outer graph.
+      with outer_graph.as_default():
+        # This is a _DefinedFunction.
+        func_graph = (
+            function_def_to_graph.function_def_to_graph(func.definition))
+      if func_graph is not None:
+        return (func_graph, func.definition)
+    if hasattr(graph, "outer_graph"):
+      graph = graph.outer_graph
+    else:
+      raise ValueError(
+          "Function {} does not exist in the graph.".format(func_name))
+
+
+def _decode_graph(name, func_graph):
+  # type: (str, tf.Graph) -> Dict[str, List[Tuple[tf.DType, tf.TensorShape, str]]]
+  """
+  Use public TensorFlow APIs to decode the important information that is not
+  explicitly stored in the GraphDef proto, but which must be inferred from the
+  GraphDef in conjunction with additional data structures that TensorFlow
+  generally keeps to itself.
+
+  Args:
+    name: Function name.
+    func_graph: `tf.Graph` object (a FuncGraph) that holds the function's body.
+
+  Returns:
+    A map from node name to a list of (type, shape, output_arg_name) tuples
+    that describes in turn each of the outputs of said node.
+  """
+  # The information in a NodeDef is not sufficient to determine output type
+  # information. For that kind of type inference, you need access to the
+  # corresponding OpDef protos. Unfortunately there is not a public API that
+  # allows for OpDef lookup. So instead we instantiate the graph that
+  # graph_def describes. This approach makes things easier, but there will be
+  # a reduction in forwards compatibility, because import_graph_def() does a
+  # lot of sanity checks that aren't necessary when rewriting a graph_def.
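+  # Each entry of output_map[op_name] describes one output of that op; ops
+  # whose OpDef declares list-valued or number-attr outputs are expanded so
+  # that every concrete output tensor gets its own (dtype, shape, arg name)
+  # tuple.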
+ output_map = {} + for op in func_graph.get_operations(): + # pylint: disable=protected-access + op_def = func_graph._get_op_def(op.type) + output_idx = 0 + output_map[op.name] = [] + for output_arg_idx in range(len(op_def.output_arg)): + output_arg = op_def.output_arg[output_arg_idx] + output = op.outputs[output_idx] + if output_arg.type_list_attr: + output_map[op.name] = [( + output.dtype, output.shape, op_def.output_arg[0].name) + for output in op.outputs] + break + elif output_arg.number_attr: + output_len = op.node_def.attr[output_arg.number_attr].i + for _ in range(output_len): + output = op.outputs[output_idx] + output_map[op.name].append( + (output.dtype, output.shape, output_arg.name)) + output_idx += 1 + else: + output_map[op.name].append( + (output.dtype, output.shape, output_arg.name)) + output_idx += 1 + return output_map + diff --git a/graph_def_editor/graph.py b/graph_def_editor/graph.py index 4ea0d76..54d580d 100644 --- a/graph_def_editor/graph.py +++ b/graph_def_editor/graph.py @@ -1,3 +1,4 @@ +# Copyright 2021 Google. All Rights Reserved. # Copyright 2019 IBM. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,16 +19,17 @@ from __future__ import division from __future__ import print_function +from collections import defaultdict import datetime -from distutils import dir_util import os from six import string_types import tensorflow.compat.v1 as tf import sys if sys.version >= '3': from typing import Tuple, Dict, FrozenSet, Iterable, Union, Set, Any +import queue -from graph_def_editor import node, util, tensor, variable +from graph_def_editor import base_graph, function_graph, node, util, tensor, variable # TODO: Move this protobuf into this project so we don't depend on # tf.core.framework @@ -45,6 +47,7 @@ # Special attribute in which TensorFlow stores frame names for while loops ( # see node_to_frame_name() for more information _FRAME_NAME_ATTR = "frame_name" +_INPUT_DUMMY_OP_NAME = "__input__" class GraphVisitor(object): @@ -55,6 +58,9 @@ def visit_node(self, n): # type: (node.Node) -> None raise NotImplementedError() + def __call__(self, n): + return self.visit_node(n) + class SaverInfo(object): """ @@ -99,8 +105,7 @@ def signature_defs(self): # type: () -> Dict[str, Any] return self._signature_defs - -class Graph(object): +class Graph(base_graph.BaseGraph): """ Mutable surrogate for a `tf.GraphDef` protocol buffer message @@ -114,15 +119,17 @@ class Graph(object): """ def __init__( - self, - g = None, # type: Union[tf.Graph, tf.GraphDef] - name = None, # type: str - collections = None, # type: Dict[str, meta_graph_pb2.CollectionDef] - saver_info = None, # type: SaverInfo - signature_info = None # type: SignatureInfo - ): - """ - Wrap a tf.GraphDef protocol buffer in a Graph object. + self, + g=None, # type: Union[tf.Graph, tf.GraphDef] + name=None, # type: str + collections=None, # type: Dict[str, meta_graph_pb2.CollectionDef] + saver_info=None, # type: SaverInfo + signature_info=None, # type: SignatureInfo, + object_graph_def=None, # type: saved_object_graph_pb2.SavedObjectGraph + stripped_op_list=None, # type: op_def_pb2.OpList + asset_file_def=None, # type: meta_graph_pb2.AssetFileDef + ): + """Wrap a tf.GraphDef protocol buffer in a Graph object. Args: g: a tf.Graph or tf.GraphDef protobuf that represents a @@ -130,7 +137,7 @@ def __init__( tf.GraphDef name: Optional human-readable name for the graph. If not provided, the constructor will generate a name. 
-      collections: Optional iterable of tf.MetaGraphDef.CollectionDefEntry
+      collections: Optional iterable of tf.MetaGraphDef.CollectionDefEntry
         objects containing information about collections in the graph.
         Note that this constructor will pull collection info out of `g` if
         it is a `tf.Graph` and `collections` is `None`.
@@ -138,36 +145,46 @@ def __init__(
         `tf.train.Saver` object that can save and restore variables in this
         graph.
       signature_info: Optional semi-serialized information about entry points
-        to the graph, AKA signatures
+        to the graph, AKA signatures.
+      object_graph_def: Optional SavedObjectGraph for TF2.x.
+      stripped_op_list: Optional stripped op list.
+      asset_file_def: Optional saved model assets.
     """
+    if name is None:
+      time_str = datetime.datetime.now().isoformat()
+      name = "GraphDef Editor Graph created {}".format(time_str)
+    super(Graph, self).__init__(name)
+    self._graph = None
     if g is None:
-      graph_def = tf.GraphDef()
+      self._graph_def = tf.GraphDef()
     elif isinstance(g, tf.GraphDef):
-      graph_def = g
+      self._graph_def = g
    elif isinstance(g, tf.Graph):
-      graph_def = g.as_graph_def()
+      self._graph_def = g.as_graph_def()
+      self._graph = g
       if collections is None:
         meta_gd = tf.train.export_meta_graph(graph=g)
         collections = _extract_collection_defs(meta_gd)
     else:
       raise TypeError("Graph is of type {}. Expected a tf.Graph or GraphDef "
                       "proto".format(type(g)))
-    if name is None:
-      time_str = datetime.datetime.now().isoformat()
-      name = "GraphDef Editor Graph created {}".format(time_str)
     if signature_info is None:
      signature_info = SignatureInfo()
    elif not isinstance(signature_info, SignatureInfo):
      raise ValueError("signature_info argument must be a SignatureInfo object")
 
+    # Caching tf.Graph object, so we won't have to load it again.
+    if self._graph is None:
+      self._graph = tf.Graph()
+      with self._graph.as_default():
+        tf.import_graph_def(self._graph_def, name="")
+
     # Populate fields of object
-    self._name = name  # str
     self._version = 0  # Must happen first; other init code needs self._version
     self._frozen = False  # bool
-    self._graph_def = graph_def  # tf.GraphDef
     self._next_id = 1  # int
-    output_map = _decode_graph(graph_def)
     self._node_name_to_node = {}  # Dict[str, node.Node]; key is node name
+    output_map = _decode_graph(self._graph)
     self._node_to_frame_names = None
     self._frame_name_to_nodes = None
     self._head_name_to_coloc_group = None  # Dict[str, FrozenList[str]]
@@ -175,14 +192,15 @@ def __init__(
     self._collection_name_to_type = None  # Dict[str, str], generated on demand
     self._passthrough_collections = {}  # Dict[str, List[CollectionDef]]
     self._passthrough_saver = None
-    self._passthrough_versions = graph_def.versions  # tf.VersionDef
+    self._passthrough_versions = self._graph_def.versions  # tf.VersionDef
+    self._function_graphs = dict()  # Dict[str, gde.FuncGraph], on demand
 
     # Load nodes in three passes because the g may contain cycles.
-    for node_def in graph_def.node:
+    for node_def in self._graph_def.node:
      self.add_node_from_node_def(node_def, set_inputs=False)
-    for node_def in graph_def.node:
-      self[node_def.name].set_outputs_from_pairs(output_map[node_def.name])
-    for node_def in graph_def.node:
+    for node_def in self._graph_def.node:
+      self[node_def.name].set_outputs_from_pairs(output_map[node_def.name])
+    for node_def in self._graph_def.node:
       self[node_def.name].set_inputs_from_strings(node_def.input,
                                                   set_control_inputs=True)
     # Collections reference nodes and variables
@@ -194,42 +212,14 @@ def __init__(
     # so load after variables are constituted (i.e. 
from collections) self._passthrough_saver = saver_info self._signatures = signature_info - - @property - def name(self): - """ - Returns human-readable name for this graph. This name may not be unique - across graphs. - """ - return self._name + self._object_graph_def = object_graph_def + self._stripped_op_list = stripped_op_list + self._asset_file_def = asset_file_def @property def has_passthrough_saver(self): return self._passthrough_saver is not None - def add_node_from_node_def(self, node_def, set_inputs = False): - # type: (tf.NodeDef, bool) -> node.Node - """ - Unpack a `tf.NodeDef` protobuf into a mutable `Node` object.' - - Does NOT set the outputs of the node. - - Args: - g: Graph in which the node will be created - node_def: Fully-populated NodeDef proto; all fields, including inputs, - will be used. - set_inputs: Optional. If True, also populate the data and control inputs - of the returned Node. This operation will only work if the targets of - those inputs are already present in the graph. - """ - ret = self.add_node(name=node_def.name, op_name=node_def.op) - ret.device = node_def.device - for key in node_def.attr: - ret.add_attr(key, util.attr_value_to_python_type(node_def.attr[key])) - if set_inputs: - ret.set_inputs_from_strings(node_def.input, set_control_inputs=True) - return ret - def add_collection_from_collection_def( self, collection_name, @@ -237,14 +227,14 @@ def add_collection_from_collection_def( validate_name = True): # type: (str, meta_graph_pb2.CollectionDef, bool) -> None """ - Unpack a `tf.MetaGraphDef.CollectionDefEntry` of serialized variables - into a collection of variables in this graph. The collection must not exist. + Unpack a `tf.MetaGraphDef.CollectionDefEntry` of serialized variables + into a collection of variables in this graph. The collection must not exist. Variables that do not already exist will be created. Note that this method is intended to be used to bulk-load a collection. To add individual items to a collection one-by-one, call the `add_to_collection` methods of `Node`, etc., objects. - + Args: collection_name: Name of collection collection_def: Serialized information about the collection @@ -258,7 +248,7 @@ def add_collection_from_collection_def( if collection_def.HasField("node_list"): for node_name in collection_def.node_list.value: # Check if node name is a Tensor type - if node_name.rfind(':') > -1: + if node_name.rfind(":") > -1: n = self.get_tensor_by_name(node_name) else: n = self.get_node_by_name(node_name) @@ -268,368 +258,46 @@ def add_collection_from_collection_def( var = self.add_variable_from_variable_def(serialized_var, skip_if_present=True) var.add_to_collection(collection_name) - elif (collection_def.HasField("int64_list") - or collection_def.HasField("float_list") + elif (collection_def.HasField("int64_list") \ + or collection_def.HasField("float_list") \ or collection_def.HasField("any_list")): self._passthrough_collections[collection_name] = collection_def if self._collection_name_to_type is not None: self._collection_name_to_type[collection_name] = "passthrough" else: raise ValueError("Unknown collection with name: {}".format( - collection_name)) - - def __getitem__(self, name): - # type: (str) -> Union[tensor.Tensor, 'node.Node'] - """ - Convenience method to retrieve a node or tensor of the graph by name - - Args: - name: Name of the node or tensor to return. Case-sensitive. - - Returns the named item as a `gde.Node` or `gde.Tensor` object. If there - is a conflict between node and tensor names, node names win. 
- """ - if not isinstance(name, string_types): - raise TypeError("name must be a string; got type {}".format(type(name))) - - if self.contains_node(name): - return self._node_name_to_node[name] - elif self.contains_tensor(name): - return self.get_tensor_by_name(name) - else: - raise ValueError("No node or tensor '{}' found in graph".format(name)) - - def get_node_by_name(self, name): - # type: (str) -> node.Node - """ - Retrieve a node in the graph by name. - - Args: - name: Name of the node. Case-sensitive. - - Returns the indicated node as a `gde.Node` object. - """ - if self.contains_node(name): - return self._node_name_to_node[name] - else: - raise ValueError("No node '{}' found in graph".format(name)) - - def contains_node(self, name): - # type: (str) -> bool - """ - Returns true if the graph has a node by the indicated name. Exact string - match. - """ - if not isinstance(name, string_types): - raise ValueError("Node name argument is not a string, but is of type " - "{}".format(type(name))) - return name in self._node_name_to_node.keys() - - def add_node(self, - name, # type: str - op_name, # type: str - uniquify_name = False # type: bool - ): - # type: (...) -> node.Node - """ - Add a new, empty node to the graph. - Args: - name: Name for the new op - op_name: Name of the type of operation for the node - uniquify_name: Generate a unique name from this name if the graph - already has a node with the indicated name. If False, raise an - exception if the name is in use. - - Returns: - `MutableNode` wrapper for the new node. - - Raises: - ValueError if the name is already in use and `uniquify_name` is False - """ - if uniquify_name: - name = self.unique_name(name) - elif self._name_in_use(name): # and not uniquify_name - raise ValueError("Graph already contains a node with name '{}' " - "(Note that this check is case-insensitive)." - .format(name)) - ret = node.Node(self, self._get_next_id(), name=name, op_name=op_name) - self._node_name_to_node[name] = ret - self.increment_version_counter() - return ret - - def add_node_from_node_def(self, - node_def, # type: tf.NodeDef - set_inputs = False, # type: bool - set_control_inputs = False # type: bool - ): - # type: (...) -> node.Node - """ - Adds a new node to the graph, populating fields of the node from a - `tf.NodeDef` protocol buffer. - - Equivalent to calling `add_node()`, then populating the relevant fields - of the returned MutableNode object. - - Args: - node_def: Protocol buffer describing parameters of the new node. - set_inputs: If True, populate the node's inputs list from the list of - inputs in the `NodeDef` - set_control_inputs: Also set control inputs. Must be False if - `set_inputs` is False. - - Returns: - `MutableNode` wrapper for the new node - """ - if set_control_inputs and not set_inputs: - raise ValueError("set_inputs must be True if set_control_inputs is True") - ret = self.add_node(node_def.name, node_def.op) - if set_inputs: - ret.set_inputs_from_strings(node_def.input, - set_control_inputs=set_control_inputs) - ret.device = node_def.device - ret.clear_attrs() - for key in node_def.attr: - ret.add_attr(key, node_def.attr[key]) - - # Don't need to increment version counter; add_node() already did that. - return ret - - def remove_node_by_name(self, name, check_for_refs = True): - # type: (str, str) -> None - """ - Removes the indicated node from this graph and from any collections in - this graph. - - The caller is responsible for removing all links to the indicated node - prior to making this call. 
- - Args: - name: name of the node to remove - check_for_refs: Optional. If True, raise an exception if there are any - other nodes in the graph that reference this node. If False, allow - removal of nodes with outstanding references to them. In the latter - case, the caller is responsible for cleaning up the graph afterwards. - """ - n = self.get_node_by_name(name) - if check_for_refs: - for t in n.outputs: - if len(t.consumers()) > 0: - raise ValueError("Removing node '{}' would leave dangling " - "references from nodes {} to tensor '{}'" - "".format(name, [c.name for c in t.consumers()], - t.name)) - # noinspection PyProtectedMember - n._remove_from_graph() - del self._node_name_to_node[name] - self.increment_version_counter() - # Don't need to update collection info because collection membership is - # stored in the node. - # Don't need to update consumers of tensors because that information is - # calculated dynamically by iterating over nodes. - - def rename_node(self, old_name, new_name): - # type: (str, str) -> None - """ - Change the name of a node in the graph. - - Args: - old_name: Name of an existing node - new_name: New name for the node in question. Must not currently be in use. - """ - if self.contains_node(new_name): - raise ValueError("Graph already has a node under name '{}'".format( - new_name)) - n = self.get_node_by_name(old_name) - # noinspection PyProtectedMember - n._change_name(new_name) - del self._node_name_to_node[old_name] - self._node_name_to_node[new_name] = n - self.increment_version_counter() - - def add_variable(self, name): - # type: (str) -> variable.Variable - """ - Adds a new variable to the graph. - - Args: - name: Name of the variable. Must not already be in use. - - Returns the `gde.Variable` object corresponding to the added variable. - """ - if name in self._variable_name_to_variable: - raise ValueError("Variable name '{}' already in use".format(name)) - v = variable.Variable(self) - v.name = name - self._variable_name_to_variable[name] = v - return v - - def add_variable_from_variable_def(self, variable_def, - skip_if_present = False): - # type: (Any, bool) -> None - """ - Adds a new variable to the graph and populates the fields of the - corresponding Variable object according to a protocol buffer message. - - Args: - variable_def: `tensorflow.core.framework.variable_pb2.VariableDef` - protobuf object. May be serialized as a `bytes` object. - skip_if_present: If True, silently skips inserting duplicate variables, - as long as they don't conflict with existing variables. - - Returns the `gde.Variable` object corresponding to the added variable. - """ - v = variable.Variable(self) - v.from_proto(variable_def, allow_duplicates=skip_if_present) - if v.name not in self._variable_name_to_variable: - self._variable_name_to_variable[v.name] = v - return self._variable_name_to_variable[v.name] + collection_name)) @property - def variable_names(self): - return self._variable_name_to_variable.keys() - - def get_variable_by_name(self, name): - # type: (str) -> variable.Variable - """ - Fetch a variable by its variable name. - - Args: - name: Name of a variable in this graph. - - Returns the variable associated with the name. Raises an exception if - there is no variable with the indicated name. - """ - return self._variable_name_to_variable[name] - - def _name_in_use(self, name): - # type: (str) -> bool - """Check whether a name is in use, using the same collision semantics as - TensorFlow: Exact lowercase string match. 
- - Args: - name: Name of a potential node in the graph. - - Returns True if the indicated name is currently in use, ignoring case. - """ - return name.lower() in [k.lower() for k in self._node_name_to_node.keys()] - - def unique_name(self, name): - # type: (str) -> str - """Emulate the behavior of the method by the same name in `tf.Graph`. - - Does *not* emulate the `name_stack` field of `tf.Graph`. - - Unlike the original method, this version does *not* keep a separate table - of names currently "in use for the purposes of `unique_name()`", but instead - refers directly to internal data structures to find names that are truly - in use. - - Args: - name: The name for an operation. - - Returns: - A variant of `name` that has been made unique by appending a key to it - in the same way that `tf.Graph.unique_name()` would. - """ - # For the sake of checking for names in use, we treat names as case - # insensitive (e.g. foo = Foo). - if not self._name_in_use(name): - return name - - # Generate a unique version by appending "_1", "_2", etc. until we find - # an unused name. Note that this approach will behave slightly - # differently from the original if nodes are deleted. - i = 1 - new_name = "{}_{}".format(name, i) - while self._name_in_use(new_name): - i = i + 1 - new_name = "{}_{}".format(name, i) - return new_name - - @property - def node_names(self): - # type: () -> Iterable[node.Node] - return self._node_name_to_node.keys() - - @property - def nodes(self): - # type: () -> Tuple[node.Node] - """ - Returns: - A list of all nodes, both immutable and mutable, present in the graph - after the edits that this object is buffering. - """ - return tuple(self._node_name_to_node.values()) - - @property - def tensors(self): - # type: () -> List[tensor.Tensor] - """ - Return a list of all the tensors which are input or output of an op in - the graph. - """ - ts = [] - for op in self.nodes: - ts += op.outputs - return ts + def function_names(self): + # type: () -> Iterable[str] + return [f.signature.name for f in self._graph_def.library.function] - def contains_tensor(self, tensor_name): - # type: (str) -> bool + def get_function_graph_by_name(self, function_name): """ - Returns true if the graph has a tensor by the indicated name. Exact string - match. + Retrieve a function by name and wrap it into a function_graph.FunctionGraph. Args: - tensor_name: TensorFlow-format name ('node name:input num', or 'node - name' as shorthand for 'node name:0') + function_name: Function name. - Raises ValueError if the tensor name is not properly formatted. - """ - error_msg = "Invalid tensor name '{}': {}" - node_name, output_ix = _decode_tensor_name(tensor_name, error_msg) - if node_name not in self._node_name_to_node: - return False - else: - n = self[node_name] - if output_ix >= len(n.outputs): - return False - else: - return True + Returns: function_graph.FunctionGraph object. - def get_tensor_by_name(self, tensor_name, error_msg = None): - # type: (str, str) -> tensor.Tensor + Raises: ValueError if the function with specified name is not found + in the graph. """ - Retrieve a tensor by human-readable name. + if function_name not in self.function_names: + raise ValueError("Function '{}' is not found in graph".format( + function_name)) - Args: - tensor_name: TensorFlow-format name ('node name:input num', or 'node - name' as shorthand for 'node name:0') - error_msg: Optional format string for raising errors. 
Must be able to
-        serve as an input to `str.format()` with two arguments: tensor name
-        string and reason for failure.
+      function_name: Function name.
 
-    Returns: gde.Tensor object corresponding to the indicated tensor.
+    Returns: function_graph.FunctionGraph object.
 
-    Raises ValueError if the name is invalid or references a tensor that does
-    not exist.
+    Raises: ValueError if the function with the specified name is not found
+      in the graph.
     """
-    if error_msg is None:
-      error_msg = "Invalid tensor name '{}': {}"
-    node_name, output_ix = _decode_tensor_name(tensor_name, error_msg)
-    if node_name not in self._node_name_to_node:
-      raise ValueError(error_msg.format(
-        tensor_name, "Node name '{}' not found in graph.".format(node_name)
-      ))
-    n = self[node_name]
-    if output_ix >= len(n.outputs):
-      raise ValueError(error_msg.format(
-        tensor_name, "Requested output {}, but node '{}' has {} "
-                     "outputs.".format(output_ix, node_name, len(n.outputs))
-      ))
-    return n.output(output_ix)
-
-  def to_graph_def(self, add_shapes = True):
-    # type: (bool) -> tf.GraphDef
+    if function_name not in self.function_names:
+      raise ValueError("Function '{}' is not found in graph".format(
+          function_name))
+
+    if function_name not in self._function_graphs:
+      self._function_graphs[function_name] = function_graph.FunctionGraph(
+          name=function_name,
+          parent_graph=self._graph)
+
+    return self._function_graphs[function_name]
+
+  def to_graph_def(self, add_shapes=True):
+    # type: (bool) -> tf.compat.v1.GraphDef
     """
     Args:
       add_shapes: If True, add the special "_output_shapes" attribute with
@@ -642,6 +310,22 @@ def to_graph_def(self, add_shapes = True):
     ret.versions.CopyFrom(self._passthrough_versions)
     for op in self.nodes:
       op.to_node_def(ret.node.add(), add_shapes)
+
+    # Pass through library without modifications for now.
+    if self._graph_def and self._graph_def.library:
+      ret.library.CopyFrom(self._graph_def.library)
+
+    for f_name, f_graph in self._function_graphs.items():
+      function_index_to_update = None
+      for index in range(0, len(ret.library.function)):
+        if ret.library.function[index].signature.name == f_name:
+          function_index_to_update = index
+          break
+      if function_index_to_update is None:
+        raise ValueError("Function '{}' is not found in graph".format(f_name))
+      ret.library.function[function_index_to_update].Clear()
+      ret.library.function[function_index_to_update].MergeFrom(
+          f_graph.to_function_graph_def())
     return ret
 
   def to_tf_graph(self):
@@ -676,14 +360,14 @@ def to_saved_model(self, saved_model_path, tags = None):
     """
     if tags is None:
       tags = [tf.saved_model.tag_constants.SERVING]
-    if os.path.exists(saved_model_path):
+    if tf.gfile.Exists(saved_model_path):
      raise ValueError("Output path '{}' already exists".format(
        saved_model_path))
-    if not os.path.exists(os.path.dirname(saved_model_path)):
+    if not tf.gfile.Exists(os.path.dirname(saved_model_path)):
      raise ValueError("Parent directory '{}' of output dir '{}' does not "
                       "exist".format(os.path.dirname(saved_model_path),
                                      saved_model_path))
-    os.mkdir(saved_model_path)
+    tf.gfile.MkDir(saved_model_path)
 
     # Core part of the SavedModel is a protocol buffers file containing a
     # SavedModel protocol buffer message.
@@ -732,6 +416,10 @@ def to_saved_model(self, saved_model_path, tags = None):
     # set it to False.
     meta_info_def.stripped_default_attrs = False
 
+    # Passing through stripped_op_list.
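+    # (The stripped op list holds the OpDefs of the ops that this graph uses,
+    # so that consumers of the MetaGraphDef can check op compatibility.)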
+ if self._stripped_op_list: + meta_info_def.stripped_op_list.CopyFrom(self._stripped_op_list) + meta_graph.meta_info_def.CopyFrom(meta_info_def) # After the meta_info_def comes a GraphDef proto holding all the graph @@ -745,17 +433,20 @@ def to_saved_model(self, saved_model_path, tags = None): # instance that will be used to reconstitute any variables in the graph if self.has_passthrough_saver: meta_graph.saver_def.CopyFrom(self._passthrough_saver.saver_def) + if not tf.gfile.Exists(_vars_dir_for_saved_model(saved_model_path)): + tf.gfile.MkDir(_vars_dir_for_saved_model(saved_model_path)) # Copy serialized variables checkpoint wholesale, because the checkpoint # format is a black box to us. - dir_util.copy_tree(self._passthrough_saver.path, - _vars_dir_for_saved_model(saved_model_path)) + util.copy_directory(self._passthrough_saver.path, + _vars_dir_for_saved_model(saved_model_path), + overwrite=True) elif len(self.variable_names) > 0: raise NotImplementedError("Can't generate a SaverDef.") else: # Zero variables, no passthrough SaverDef. # For this case, TensorFlow creates an empty variables directory and # doesn't set the "saver_def" field. We emulate this behavior. - os.mkdir(_vars_dir_for_saved_model(saved_model_path)) + tf.gfile.MkDir(_vars_dir_for_saved_model(saved_model_path)) # The next field, "collection_def", holds serialized information about all # collections in the MetaGraph. @@ -784,47 +475,39 @@ def to_saved_model(self, saved_model_path, tags = None): # The final field, asset_file_def, stores information about additional # assets that are packaged along with the graph in the SavedModel's - # "assets" directory. Fow now we leave this field empty. - # TODO(frreiss): Represent assets as a field in the Graph class and - # serialize them here. + # "assets" directory. + if self._asset_file_def: + meta_graph.asset_file_def.extend(self._asset_file_def) + from_assets_path = os.path.join(self._passthrough_saver.path, "..", + "assets") + if tf.gfile.Exists(from_assets_path): + to_assets_path = os.path.join(saved_model_path, "assets") + + if not tf.gfile.Exists(to_assets_path): + tf.gfile.MkDir(to_assets_path) + util.copy_directory(from_assets_path, + to_assets_path, + overwrite=True) + + # It should be fine copying object_graph_def, as function signature + # changes are not supported. + if self._object_graph_def: + meta_graph.object_graph_def.CopyFrom(self._object_graph_def) # At this point, we have created the root directory for the SavedModel, # as well as the checkpoints directory. The only thing left to write is # the SavedModel protobuf itself. - with open(saved_model_path + "/saved_model.pb", "wb") as f: + with tf.gfile.Open(saved_model_path + "/saved_model.pb", "wb") as f: f.write(saved_model.SerializeToString()) return saved_model - @property - def version(self): - # type: () -> int - """ - Returns a counter that goes up every time this graph is changed. - """ - return self._version - - @property - def frozen(self): - # type: () -> bool - """ - True if the graph is configured to raise an exception on any structural - modification. - """ - return self._frozen - - @frozen.setter - def frozen(self, value): - # type: (bool) -> None - self._frozen = value def increment_version_counter(self): """ Mark the structure of this graph as "changed" and invalidate any cached information about the edges of the graph. 
""" - if self.frozen: - raise RuntimeError("Detected a change to a frozen graph") - self._version += 1 + super(Graph, self).increment_version_counter() self._node_to_frame_names = None self._frame_name_to_nodes = None self._head_name_to_coloc_group = None @@ -1001,42 +684,218 @@ def get_frame_names(self): self._generate_node_to_frame_name() return self._frame_name_to_nodes.keys() - def breadth_first_visitor(self, - visitor, # type: GraphVisitor - starting_nodes = None # type: Iterable[node.Node] - ): + def nodes_iterator( + self, + predicate=lambda _: True, # type: (node.Node) -> bool + iterate_functions=False # type: bool + ): + # type: (...) -> Iterable[node.Node] + """ + Returns: + An iterator over nodes matching predicate in current graph and + from function graphs if iterate_functions=True. + """ + for op in self.nodes: + if predicate(op): + yield op + if iterate_functions: + for function_name in self.function_names: + for op in self.get_function_graph_by_name(function_name).nodes: + if predicate(op): + yield op + + def breadth_first_visitor( + self, + visitor, # type: callable + starting_nodes=None, # type: Iterable[node.Node] + visited_nodes=None, # type: set + iterate_functions=False, # type: bool + escape_functions=False # type: bool + ): # type: (...) -> None """ Visit all nodes reachable from a starting set in the order of a - breadth-first traversal. Invokes a callback at each node visited. + breadth-first traversal (going from node to output edges). + Invokes a callback at each node visited. Args: visitor: Possibly-stateful callback to be invoked on each node reached starting_nodes: Optional list of starting nodes. If this set is not - provided, this method will use all nodes with zero inputs as the - starting set. Search will visit these nodes first, then visit their - children in order by parent node. + provided, this method will use all nodes with zero inputs as the + starting set. Search will visit these nodes first, then visit their + children in order by parent node. + visited_nodes: Optional set of nodes to skip iterating over. + iterate_functions: Indicates if we should also go inside functions if one + is found in the graph. + escape_functions: If iteration started in a function graph, indicates + that we should also iterate through all function callers up + the stack. + Returns: + True if iteration was iterruputed by visitor, otherwise False. """ if starting_nodes is None: # Start with all of the nodes in the graph that have no inputs. # The maintainers of the TensorFlow scheduler like to call these nodes # "root nodes". 
-      starting_nodes = [n for n in self.nodes if 0 == len(n.inputs)]
+      starting_nodes = [n for n in self.nodes if not n.inputs]
+
+    if visited_nodes is None:
+      visited_nodes = set()
+
+    nodes_queue = queue.Queue()
+    function_graph_names_set = set()
+
+    starting_nodes_set = set()
+    for n in starting_nodes:
+      nodes_queue.put((n, None))
+      starting_nodes_set.add(n)
+
+    while not nodes_queue.empty():
+      (n, input_tensor) = nodes_queue.get()
+      if n in visited_nodes:
+        continue
+      if n.op_type != _INPUT_DUMMY_OP_NAME:
+        if visitor(n):
+          return True
+
+      if escape_functions and isinstance(n.graph, function_graph.FunctionGraph):
+        function_graph_names_set.add(n.graph.name)
+
+      visited_nodes.add(n)
+      if iterate_functions and n.op_type in node.PARTITIONED_CALL_OP_TYPES:
+        function_name = n.get_attr("f").name
+        f_graph = self.get_function_graph_by_name(function_name)
+
+        # Only start from the dummy input nodes that correspond to the edge
+        # we arrived through; fall back to all dummy inputs otherwise.
+        function_inputs = []
+        if input_tensor is not None:
+          for input_node in f_graph.nodes:
+            if (input_node.op_type == _INPUT_DUMMY_OP_NAME and
+                input_node.name in input_tensor.name):
+              function_inputs.append(input_node)
+
+        if not function_inputs:
+          function_inputs = [input_node for input_node in f_graph.nodes
+                             if input_node.op_type == _INPUT_DUMMY_OP_NAME]
+
+        if self.breadth_first_visitor(
+            visitor,
+            starting_nodes=function_inputs,
+            visited_nodes=visited_nodes,
+            iterate_functions=iterate_functions,
+            escape_functions=False):
+          return True
+
+      for output_tensor in n.outputs:
+        for consumer in output_tensor.consumers():
+          if consumer not in visited_nodes and \
+              consumer not in starting_nodes_set:
+            nodes_queue.put((consumer, output_tensor))
+
+    if escape_functions and function_graph_names_set:
+      function_invocation_ops = self.nodes_iterator(
+          predicate=lambda f: (f.op_type in node.PARTITIONED_CALL_OP_TYPES and
+                               f.get_attr("f").name in function_graph_names_set),
+          iterate_functions=True)
+      for function_invocation_op in function_invocation_ops:
+        function_output_consumers = set()
+        for output_tensor in function_invocation_op.outputs:
+          function_output_consumers.update(output_tensor.consumers())
+        if self.breadth_first_visitor(
+            visitor,
+            starting_nodes=function_output_consumers,
+            visited_nodes=visited_nodes,
+            iterate_functions=iterate_functions,
+            escape_functions=True):
+          return True
+    return False
+
+  def backwards_breadth_first_visitor(
+      self,
+      visitor,  # type: callable
+      starting_nodes=None,  # type: Iterable[node.Node]
+      visited_nodes=None,  # type: set
+      iterate_functions=False,  # type: bool
+      escape_functions=False  # type: bool
+  ):  # type: (...) -> bool
+    """
+    Visit all nodes reachable from a starting set in the order of a
+    backwards breadth-first traversal (going from node to input edges).
+    Invokes a callback at each node visited.
+
+    Args:
+      visitor: Possibly-stateful callback to be invoked on each node reached
+      starting_nodes: List of starting nodes.
+      visited_nodes: Optional set of nodes to skip iterating over.
+      iterate_functions: Indicates if we should also go inside functions if one
+        is found in the graph.
+      escape_functions: If iteration started in a function graph, indicates
+        that we should also iterate through all function callers up
+        the stack.
+    Returns:
+      True if iteration was iterruputed by visitor, otherwise False.
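+
+    Example (a hypothetical sketch; assumes a graph g with nodes a -> b):
+
+      visited = []
+      g.backwards_breadth_first_visitor(
+          lambda n: visited.append(n.name),
+          starting_nodes=[g.get_node_by_name("b")])
+      # visited is now ["b", "a"]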
+ """ + if not starting_nodes: + raise ValueError("starting_nodes is not provided") + + nodes_queue = queue.Queue() + function_graph_names_set = set() + + if visited_nodes is None: + visited_nodes = set() + + starting_nodes_set = set() + for n in starting_nodes: + starting_nodes_set.add(n) + nodes_queue.put(n) - # Use a Python list as a node queue for the breadth-first search. - queue = list(starting_nodes) - enqueued_nodes = set(queue) + while not nodes_queue.empty(): + n = nodes_queue.get() - while len(queue) > 0: - cur_node = queue.pop(0) - visitor.visit_node(cur_node) + if n in visited_nodes: + continue - # Prepare for next stage of search - for out_tensor in cur_node.outputs: - for out_node in out_tensor.consumers(): - if out_node not in enqueued_nodes: - queue.append(out_node) - enqueued_nodes.add(out_node) + if n.op_type != _INPUT_DUMMY_OP_NAME: + if visitor(n): + return True + + if escape_functions and isinstance(n.graph, function_graph.FunctionGraph): + function_graph_names_set.add(n.graph.name) + + if (iterate_functions and n.op_type in node.PARTITIONED_CALL_OP_TYPES and + n not in starting_nodes_set): + function_name = n.get_attr("f").name + f_graph = self.get_function_graph_by_name(function_name) + + if self.backwards_breadth_first_visitor( + visitor, + starting_nodes=f_graph.output_nodes, + visited_nodes=visited_nodes, + iterate_functions=iterate_functions, + escape_functions=False): + return True + + visited_nodes.add(n) + for input_tensor in n.inputs: + if (input_tensor.op not in visited_nodes and + input_tensor.op not in starting_nodes_set): + nodes_queue.put(input_tensor.op) + + if escape_functions and function_graph_names_set: + function_invocation_ops = self.nodes_iterator( + predicate=lambda f: (f.op_type in node.PARTITIONED_CALL_OP_TYPES and + f.get_attr("f").name in function_graph_names_set), + iterate_functions=True) + + caller_ops = list(function_invocation_ops) + + if len(caller_ops) > 0: + if self.backwards_breadth_first_visitor( + visitor, + starting_nodes=caller_ops, + visited_nodes=visited_nodes, + iterate_functions=iterate_functions, + escape_functions=True): + return True + + return False def infer_shapes_and_dtypes(self, starting_nodes = None # type: Iterable[node.Node] @@ -1166,10 +1025,10 @@ def saved_model_to_graph(saved_model_path, # type: str Returns: In-memory representation of the contents of the SavedModel as a Graph object. 
""" - if not os.path.exists(saved_model_path): + if not tf.gfile.Exists(saved_model_path): raise ValueError("SavedModel root directory {} not found".format( saved_model_path)) - if not os.path.isdir(saved_model_path): + if not tf.gfile.IsDirectory(saved_model_path): raise ValueError("SavedModel root path {} is not a directory".format( saved_model_path)) @@ -1177,7 +1036,7 @@ def saved_model_to_graph(saved_model_path, # type: str # "saved_model.pb" protobuf_file = saved_model_path + "/saved_model.pb" saved_model = saved_model_pb2.SavedModel() - with open(protobuf_file, "rb") as f: + with tf.gfile.Open(protobuf_file, "rb") as f: saved_model.ParseFromString(f.read()) # Drill down to pull out the appropriate MetaGraphDef proto @@ -1215,18 +1074,29 @@ def saved_model_to_graph(saved_model_path, # type: str for key in meta_graph.signature_def: signature_info.add_signature_def(key, meta_graph.signature_def[key]) + object_graph_def = None + if meta_graph.object_graph_def: + object_graph_def=meta_graph.object_graph_def + + stripped_op_list = None + if meta_graph.meta_info_def.stripped_op_list: + stripped_op_list = meta_graph.meta_info_def.stripped_op_list + return Graph(graph_def, name=meta_graph.meta_info_def.meta_graph_version, collections=collections, saver_info=saver_info, - signature_info=signature_info) + signature_info=signature_info, + object_graph_def=object_graph_def, + stripped_op_list=stripped_op_list, + asset_file_def=meta_graph.asset_file_def) ################################################################################ # Stuff below this line is private to this file. -def _decode_graph(graph_def): - # type: (tf.GraphDef) -> Dict[str, List[Tuple[tf.DType, tf.TensorShape]]] +def _decode_graph(graph): + # type: (tf.Graph) -> Dict[str, List[Tuple[tf.DType, tf.TensorShape]]] """ Use public TensorFlow APIs to decode the important information that is not explicitly stored in the GraphDef proto, but which must be inferred from the @@ -1250,11 +1120,8 @@ def _decode_graph(graph_def): # graph_def describes. This approach makes things easier, but there will be # a reduction in forwards compatibility, because import_graph_def() does a # lot of sanity checks that aren't necessary when rewriting a graph_def. - temp_graph = tf.Graph() - with temp_graph.as_default(): - tf.import_graph_def(graph_def, name="") output_map = {op.name: [(t.dtype, t.shape) for t in op.outputs] - for op in temp_graph.get_operations()} + for op in graph.get_operations()} return output_map @@ -1274,35 +1141,6 @@ def _extract_collection_defs(meta_graph): return collections - -def _decode_tensor_name(tensor_name, error_msg): - # type: (str, str) -> Tuple[str, int] - """ - Args: - tensor_name: TensorFlow-format name ('node name:input num', or 'node - name' as shorthand for 'node name:0') - error_msg: Format string for raising errors. Must be able to - serve as an input to `str.format()` with two arguments: tensor name - string and reason for failure. 
- - Returns: (node name, output index) tuple identifying the tensor - - Raises ValueError if the name is not properly formatted - """ - if ":" in tensor_name: - node_name, output_ix_str = tensor_name.split(":") - if not output_ix_str.isdigit(): - raise ValueError(error_msg.format( - tensor_name, "Invalid output index string '{}'.".format(output_ix_str) - )) - output_ix = int(output_ix_str) - else: - node_name = tensor_name - output_ix = 0 - - return node_name, output_ix - - def _duplicate_collection_error_str( name, # type: str passthrough_collection_names, # type: Set[str] diff --git a/graph_def_editor/node.py b/graph_def_editor/node.py index 81ed00a..913caf7 100644 --- a/graph_def_editor/node.py +++ b/graph_def_editor/node.py @@ -23,7 +23,7 @@ if sys.version >= '3': from typing import Tuple, List, Iterable, Any, AbstractSet -from graph_def_editor import graph, tensor, util +from graph_def_editor import tensor, util # Magical attribute name that TensorFlow uses to store colocation groups. # See colocation_groups property below for more information. @@ -42,6 +42,9 @@ "Node", ] +PARTITIONED_CALL_OP_TYPES = frozenset([ + "PartitionedCall", "StatefulPartitionedCall", "TPUPartitionedCall"]) + class Node(object): """ @@ -54,7 +57,8 @@ def __init__(self, node_id, # type: int name, # type: int op_name, # type: str - device = "" # type: str + device = "", # type: str + debug_info = None # type: tf.compat.v1.NodeDef.ExperimentalDebugInfo ): """ This constructor should only be called from methods of the Graph @@ -80,10 +84,14 @@ def __init__(self, self._control_inputs = [] self._colocation_groups = [] # List[str] self._collection_names = set() # Set[str] + self._debug_info = debug_info def __repr__(self): # type: () -> str - return "Node[{}|{}]".format(self.name, self.op_type) + if self.op_type in PARTITIONED_CALL_OP_TYPES and self.has_attr("f"): + return "Node[{}({})|{}]".format(self.name, self.get_attr("f").name, self.op_type) + else: + return "Node[{}|{}]".format(self.name, self.op_type) @property def name(self): @@ -416,20 +424,22 @@ def to_node_def(self, target = None, add_shapes = True): for (attr_name, attr_value) in self._attributes: # Funky syntax for setting a field of a union in a protobuf target.attr[attr_name].CopyFrom( - util.python_type_to_attr_value(attr_value)) + util.python_type_to_attr_value(attr_value, attr_name)) if len(self._colocation_groups) > 0: # Serialize colocation groups. See docstring in getter for # colocation_groups property for more information. 
transformed_names = [_COLOCATION_PREFIX + name
                           for name in self._colocation_groups]
       target.attr[_COLOCATION_ATTR_NAME].CopyFrom(
-          util.python_type_to_attr_value(transformed_names)
+          util.python_type_to_attr_value(transformed_names, _COLOCATION_ATTR_NAME)
       )
     if add_shapes and self._outputs is not None and len(self._outputs) > 0:
       shapes_list = [t.shape for t in self._outputs]
       target.attr[_OUTPUT_SHAPES_ATTR_NAME].CopyFrom(
-          util.python_type_to_attr_value(shapes_list)
+          util.python_type_to_attr_value(shapes_list, _OUTPUT_SHAPES_ATTR_NAME)
       )
+    if self._debug_info is not None:
+      target.experimental_debug_info.CopyFrom(self._debug_info)
     return target
 
   def get_attr(self, key):
@@ -458,7 +468,7 @@ def get_attr(self, key):
                        "under key '{}'".format(self, key))
     ret = matches[0]
     if isinstance(ret, tf.AttrValue):
-      return util.attr_value_to_python_type(ret)
+      return util.attr_value_to_python_type(ret, key)
     else:
       return ret
 
@@ -470,7 +480,7 @@ def has_attr(self, key):
 
     Returns True if the node has an attribute under the indicated key
     """
-    return key in self._attributes
+    return key in self.get_attr_keys()
 
   def get_attr_keys(self):
     # type: () -> Tuple[str]
@@ -689,7 +699,7 @@ def infer_outputs(self):
 
   # TODO(frreiss): If this op has a "T" attribute, set that too.
 
-  def set_inputs_from_strings(self, new_inputs, set_control_inputs = True):
+  def set_inputs_from_strings(self, new_inputs, set_control_inputs = True, output_map = None):
     # type: (Iterable[str], bool) -> None
     """
     Set all input at once, converting TensorFlow string-format inputs into
@@ -704,7 +714,7 @@ def set_inputs_from_strings(self, new_inputs, set_control_inputs = True):
     Otherwise , this method will ignore any strings that describe control
     inputs.
     """
-    self._inputs = _decode_inputs(new_inputs, self._graph)
+    self._inputs = _decode_inputs(new_inputs, self._graph, output_map)
     if set_control_inputs:
       self._control_inputs = _decode_control_inputs(new_inputs, self._graph)
     self._graph.increment_version_counter()  # New edges added to graph
@@ -757,7 +767,8 @@ def _canonicalize_output_name(name):
 
 def _decode_inputs(inputs,  # type: Iterable[str]
-                   g  # type: graph.Graph
+                   g,  # type: graph.Graph
+                   outputs_map  # type: Mapping[str, Any]
                    ):
   # type: (...) -> List[tensor.Tensor]
   """
@@ -779,6 +790,7 @@ def _decode_inputs(inputs,  # type: Iterable[str]
   #   "^node_name" --> Control input from indicated node
   #   "node_name" --> Input from output number 0 of indicated node
   #   "node_name:ix" --> Input from output number of indicated node
+  #   "node_name:arg_def_name:ix" --> Input from output number ix of the
+  #       indicated node's arg_def named arg_def_name
   # Start by filtering out the control inputs and turning "node_name" into
   # "node_name:0".
   input_names = [_canonicalize_output_name(n) for n in inputs
@@ -786,9 +798,36 @@ def _decode_inputs(inputs,  # type: Iterable[str]
   input_tensors = []
   for name in input_names:
     # Name is in form "node:output number"
-    node_name, output_ix_name = name.split(":")
-    output_ix = int(output_ix_name)
-    input_tensors.append(g[node_name].output(output_ix))
+    parts = name.split(":")
+    if len(parts) == 2:
+      node_name, output_ix_name = name.split(":")
+      output_ix = int(output_ix_name)
+      input_tensors.append(g[node_name].output(output_ix))
+    elif len(parts) == 3:
+      # FuncGraph uses a different format for input tensor definitions.
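+      # For example "mul:z:0" (op name, output arg name, output index)
+      # rather than the plain GraphDef form "mul:0".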
+ # See function_def_to_graph_def function in + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function_def_to_graph.py + node_name, arg_def_name, output_ix_name = name.split(":") + if outputs_map is None: + raise ValueError("output_map is not specified, can't decode op input") + if node_name not in outputs_map: + raise ValueError("op {} is not found in output_map".format(node_name)) + arg_def_names = [name for (_, _, name) in outputs_map[node_name]] + + local_index = int(output_ix_name) + global_index = None + for i in range(0, len(arg_def_names)): + if arg_def_names[i] == arg_def_name: + if local_index == 0: + global_index = i + break + else: + local_index = local_index-1 + if global_index is None: + raise ValueError("can't find output op corresponding to: {}".format(name)) + input_tensors.append(g[node_name].output(global_index)) + else: + raise ValueError("invalid input name format: {}".format(name)) return input_tensors diff --git a/graph_def_editor/rewrite.py b/graph_def_editor/rewrite.py index aff4b34..383b023 100644 --- a/graph_def_editor/rewrite.py +++ b/graph_def_editor/rewrite.py @@ -582,18 +582,23 @@ def compute_input_dim(n #type: node.Node raise ValueError("Unexpected op type {}".format(n.op_type)) pattern_1 = ( - TreeExpr(op="Conv2D|MatMul|DepthwiseConv2dNative", alias="conv", inputs=( - TreeExpr(op="Relu|Relu6", alias="relu", optional=True, inputs=( - TreeExpr(op="Add", alias="add", inputs=( - TreeExpr(op="Mul", alias="mul", inputs=( - TreeExpr(), - TreeExpr(op="Const", alias="mul_values") - )), - TreeExpr(op="Const", alias="add_values") - )) - )), - TreeExpr(op="Const", alias="weights"))) - ) + TreeExpr( + op="Conv2D|MatMul|DepthwiseConv2dNative", + alias="conv", + inputs=(TreeExpr( + op="Relu|Relu6", + alias="relu", + optional=True, + inputs=(TreeExpr( + op="Add|AddV2", + alias="add", + inputs=(TreeExpr( + op="Mul", + alias="mul", + inputs=(TreeExpr(), + TreeExpr(op="Const", alias="mul_values"))), + TreeExpr(op="Const", alias="add_values"))))), + TreeExpr(op="Const", alias="weights")))) def handle_relu6(relu6_op, scale): # type: (node.Node, np.ndarray) -> None diff --git a/graph_def_editor/util.py b/graph_def_editor/util.py index ec8bbc9..617d8b2 100644 --- a/graph_def_editor/util.py +++ b/graph_def_editor/util.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections +import os import re import sys if sys.version >= '3': @@ -30,7 +31,7 @@ from six import iteritems, string_types import tensorflow.compat.v1 as tf -from graph_def_editor import graph, node, tensor +from graph_def_editor import base_graph, node, tensor __all__ = [ @@ -211,7 +212,7 @@ def get_unique_graph(tops, check_types=None, none_if_empty=False): TypeError: if tops is not a iterable of `gde.Node`. ValueError: if the graph is not unique. """ - if isinstance(tops, graph.Graph): + if isinstance(tops, base_graph.BaseGraph): return tops if not is_iterable(tops): raise TypeError("{} is not iterable".format(type(tops))) @@ -249,7 +250,7 @@ def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): if `check_graph` is `True`, if all the ops do not belong to the same graph. """ - if isinstance(ops, graph.Graph): + if isinstance(ops, base_graph.BaseGraph): if allow_graph: return ops.nodes else: @@ -280,7 +281,7 @@ def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): TypeError: if `ts` cannot be converted to a list of `gde.Tensor` or, if `check_graph` is `True`, if all the ops do not belong to the same graph. 
""" - if isinstance(ts, graph.Graph): + if isinstance(ts, base_graph.BaseGraph): if allow_graph: return ts.tensors else: @@ -335,7 +336,7 @@ class ControlOutputs(object): """The control outputs topology.""" def __init__(self, - g # type: graph.Graph + g # type: base_graph.BaseGraph ): """Create a dictionary of control-output dependencies. @@ -348,7 +349,7 @@ def __init__(self, Raises: TypeError: graph is not a `gde.Graph`. """ - if not isinstance(g, graph.Graph): + if not isinstance(g, base_graph.BaseGraph): raise TypeError("Expected a gde.Graph, got: {}".format(type(g))) self._control_outputs = {} self._graph = g @@ -446,7 +447,7 @@ def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): def make_placeholder_from_tensor( - g, # type: graph.Graph + g, # type: base_graph.BaseGraph t, # type: tensor.Tensor scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX @@ -568,7 +569,8 @@ def func(top): def _python_type_to_attr_list_elem( list_value, # type: tf.AttrValue.ListValue - elem # type: Any + elem, # type: Any + attr_name # type: String ): """ Subroutine of python_type_to_attr_value(). Converts one element of a Python @@ -592,15 +594,16 @@ def _python_type_to_attr_list_elem( list_value.type.append(elem.as_datatype_enum) elif isinstance(elem, tf.TensorShape): list_value.shape.add().CopyFrom(elem.as_proto()) - elif isinstance(elem, np.ndarray): + elif isinstance(elem, np.ndarray) or isinstance(elem, list): list_value.tensor.add().CopyFrom(tf.make_tensor_proto(values=elem)) # TODO(frreiss): Populate the "func" field of the union here else: raise ValueError("Don't know how to convert a {} to " - "tf.AttrValue.ListValue".format(type(elem))) + "tf.AttrValue.ListValue for attribute {}".format(type(elem), attr_name)) -def python_type_to_attr_value(value #type: Any +def python_type_to_attr_value(value, #type: Any + attr_name #type: String ): # type (...) -> tf.AttrValue """ @@ -622,7 +625,7 @@ def python_type_to_attr_value(value #type: Any list_value = tf.AttrValue.ListValue() for elem in value: # TODO(frreiss): Should we disallow heterogeneous types in lists? - _python_type_to_attr_list_elem(list_value, elem) + _python_type_to_attr_list_elem(list_value, elem, attr_name) return tf.AttrValue(list=list_value) elif isinstance(value, tf.AttrValue): # TODO(frreiss): Should this case result in an error? @@ -647,10 +650,11 @@ def python_type_to_attr_value(value #type: Any # here else: raise ValueError("Don't know how to convert a {} to " - "tf.AttrValue".format(type(value))) + "tf.AttrValue for attribute {}".format(type(value)), attr_name) -def attr_value_to_python_type(attr_value # type: tf.AttrValue +def attr_value_to_python_type(attr_value, # type: tf.AttrValue + attr_name # type: String ): # type (...) 
  """
@@ -681,14 +685,17 @@ def attr_value_to_python_type(attr_value  # type: tf.AttrValue
     return tf.TensorShape(attr_value.shape)
   elif attr_value.HasField("tensor"):  # TensorProto
     return tf.make_ndarray(attr_value.tensor)
-  # TODO(frreiss): Convert the "func" and "placeholder" fields of the union
-  # here
+  elif attr_value.HasField("list"):  # list
+    return attr_value.list
+  elif attr_value.HasField("func"):  # func
+    return attr_value.func
+  # TODO(frreiss): Convert the "placeholder" fields of the union here
   else:
     raise ValueError("Don't know how to convert AttrValue {} to "
-                     "a Python object".format(attr_value))
+                     "a Python object for attribute {}".format(
+                         attr_value, attr_name))
 
-def load_variables_to_tf_graph(g  # type: graph.Graph
+def load_variables_to_tf_graph(g  # type: base_graph.BaseGraph
                                ):
   """
   Convenience function to load all variables present in a `gde.Graph` into
@@ -705,7 +712,7 @@ def load_variables_to_tf_graph(g  # type: graph.Graph
     tf.add_to_collections(var.collection_names, tf_var)
 
 
-def make_const(g,  # type: graph.Graph
+def make_const(g,  # type: base_graph.BaseGraph
                name,  # type: str
                value,  # type: np.ndarray
                uniquify_name=False  # type: bool
@@ -731,7 +738,7 @@ def make_const(g,  # type: graph.Graph
   return ret
 
 
-def make_placeholder(g,  # type: graph.Graph
+def make_placeholder(g,  # type: base_graph.BaseGraph
                      name,  # type: str
                      dtype,  # type: tf.DType
                      shape,  # type: tf.TensorShape
@@ -757,7 +764,7 @@ def make_placeholder(g,  # type: graph.Graph
   return ret
 
 
-def make_identity(g,  # type: graph.Graph
+def make_identity(g,  # type: base_graph.BaseGraph
                   name,  # type: str
                   input,  # type: tensor.Tensor
                   uniquify_name=False  # type: bool
@@ -782,7 +789,7 @@ def make_identity(g,  # type: graph.Graph
   return ret
 
 
-def make_simple_binary_op(g,  # type: graph.Graph
+def make_simple_binary_op(g,  # type: base_graph.BaseGraph
                           name,  # type: str
                           op_name,  # type: str
                           input_1,  # type: tensor.Tensor
@@ -817,3 +824,27 @@ def make_simple_binary_op(g,  # type: graph.Graph
   ret.set_inputs([input_1, input_2])
   ret.infer_outputs()
   return ret
+
+
+def copy_directory(oldpath, newpath, overwrite=False):
+  """Recursively copy a directory of files to a new location via tf.gfile.
+
+  Works with local paths as well as remote file systems that TensorFlow
+  supports, such as GCS.
+
+  Args:
+    oldpath: string, bytes, or os.PathLike; a pathname of a directory.
+    newpath: string, bytes, or os.PathLike; a pathname to which the directory
+      will be copied.
+    overwrite: boolean; if false, it is an error for newpath to be occupied by
+      an existing file.
+  """
+  assert tf.gfile.IsDirectory(oldpath)
+  items = tf.gfile.Walk(oldpath)
+  for dirname, subdirs, filenames in items:
+    for subdir in subdirs:
+      # Mirror each source subdirectory under the destination root.
+      full_subdir = os.path.join(dirname, subdir)
+      remote_dir_path = os.path.join(newpath, full_subdir[1 + len(oldpath):])
+      tf.gfile.MakeDirs(remote_dir_path)
+    for filename in filenames:
+      full_filename = os.path.join(dirname, filename)
+      remote_file_path = os.path.join(newpath,
+                                      full_filename[1 + len(oldpath):])
+      tf.gfile.Copy(full_filename, remote_file_path, overwrite=overwrite)
diff --git a/graph_def_editor/variable.py b/graph_def_editor/variable.py
index f585b74..5bb6226 100644
--- a/graph_def_editor/variable.py
+++ b/graph_def_editor/variable.py
@@ -19,7 +19,7 @@
 
 import sys
 if sys.version >= '3':
-  from graph_def_editor import graph
+  from graph_def_editor import base_graph
   from typing import AbstractSet, Union
 
 
@@ -43,7 +43,7 @@ class Variable(object):
   objects. This class tracks a similar set of pointers in protobuf land.
""" def __init__(self, - g # type: graph.Graph + g # type: base_graph.BaseGraph ): """ Do not call this constructor directly. @@ -173,7 +173,11 @@ def validate(self, self._variable_name)) # self._initializer_name should reference a node. Other names should # reference tensors. - if not self.graph.contains_node(self._initializer_name): + _initializer_name = self._initializer_name + if _initializer_name and _initializer_name.rfind(":") > 0: + # Adding extra check in case _initializer_name refers to a tensor. + _initializer_name = _initializer_name[:_initializer_name.rfind(":")] + if not self.graph.contains_node(_initializer_name): raise ValueError("Initializer name '{}' does not correspond to any " "node in graph".format(self._initializer_name)) _ = self.graph.get_tensor_by_name(self._initial_value_name, diff --git a/tests/function_graph_test.py b/tests/function_graph_test.py new file mode 100644 index 0000000..4a1710c --- /dev/null +++ b/tests/function_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2021 Google. All Rights Reserved. +# Copyright 2019 IBM. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for function_graph.py in the GraphDef Editor.""" + +import unittest +import tensorflow.compat.v1 as tf +tf.disable_eager_execution() +import shutil +import tempfile +import numpy as np + +import graph_def_editor as gde + + +class FunctionGraphTest(unittest.TestCase): + + def setUp(self): + # Create a temporary directory for SavedModel files. + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + # Remove the directory after the test. + # Comment out this line to prevent deleting temps. 
+    shutil.rmtree(self.temp_dir)
+    pass # In case previous line gets commented out
+
+  def build_tf_graph(self):
+    """Builds a tf graph for function (x + y) * 10.0 ."""
+    @tf.function
+    def multiplier_function(x):
+      return tf.constant(10.0, name="function_multiplier") * x
+
+    tf_g = tf.Graph()
+    with tf_g.as_default():
+      x = tf.placeholder(name="x", dtype=tf.float32, shape=[])
+      y = tf.placeholder(name="y", dtype=tf.float32, shape=[])
+      result_op = tf.add(x, y, name="add")
+      _ = multiplier_function(result_op)
+    return tf_g
+
+  def run_tf_graph(self, tf_g, x, y):
+    with tf.Session(graph=tf_g) as sess:
+      x_tensor = tf_g.get_tensor_by_name("x:0")
+      y_tensor = tf_g.get_tensor_by_name("y:0")
+      output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0")
+      return sess.run(output_tensor, {x_tensor: x, y_tensor: y})
+
+  def save_tf_graph(self, tf_g, model_dir):
+    x_tensor = tf_g.get_tensor_by_name("x:0")
+    y_tensor = tf_g.get_tensor_by_name("y:0")
+    output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0")
+    with tf.Session(graph=tf_g) as sess:
+      tf.saved_model.simple_save(sess, model_dir,
+                                 inputs={"x": x_tensor, "y": y_tensor},
+                                 outputs={"out": output_tensor})
+
+  def test_function_rewrite(self):
+    tf_g = self.build_tf_graph()
+    self.assertEqual(30.0, self.run_tf_graph(tf_g, 1.0, 2.0))
+    graph = gde.Graph(tf_g)
+    add_op = graph.get_node_by_name("add")
+    function_name = add_op.outputs[0].consumers()[0].get_attr("f").name
+    self.assertIn(function_name, graph.function_names)
+
+    function_graph = graph.get_function_graph_by_name(function_name)
+    function_multiplier_op = \
+        function_graph.get_node_by_name("function_multiplier")
+    self.assertEqual(10.0, function_multiplier_op.get_attr("value"))
+    function_multiplier_op.replace_attr("value",
+                                        np.array(1000.0, dtype=np.float32))
+
+    self.assertEqual(3000.0, self.run_tf_graph(graph.to_tf_graph(), 1.0, 2.0))
+    return graph
+
+  def test_export_saved_model(self):
+    g = self.test_function_rewrite()
+    model_dir = self.temp_dir + "/saved_model"
+    g.to_saved_model(model_dir)
+    tf_g = tf.Graph()
+    with tf.Session(graph=tf_g) as sess:
+      _ = tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING],
+                              model_dir)
+      self.assertEqual(3000.0, self.run_tf_graph(tf_g, 1.0, 2.0))
+
+  def test_import_saved_model(self):
+    g = self.test_function_rewrite()
+    model_dir = self.temp_dir + "/saved_model"
+    self.save_tf_graph(g.to_tf_graph(), model_dir)
+
+    g = gde.saved_model_to_graph(model_dir)
+    self.assertEqual(3000.0, self.run_tf_graph(g.to_tf_graph(), 1.0, 2.0))
+
+  def test_number_attr_support(self):
+    model_dir = self.temp_dir + "/saved_model"
+
+    @tf.function
+    def test_function(c):
+      cdim = tf.constant(1, tf.int32)
+      c1 = tf.constant([2, 1, 5], tf.int32, name="FuncConst")
+      c2 = tf.constant([2, 1, 5], tf.int32)
+      # ConcatOffset has a variable number of inputs and outputs; variadic
+      # arguments are represented with number_attr in function signatures.
+      concat_offset = tf.raw_ops.ConcatOffset(
+          concat_dim=cdim, shape=[c, c1, c2])
+      out = tf.math.reduce_sum(concat_offset)
+      return out
+
+    tf_g = tf.Graph()
+    with tf_g.as_default():
+      with tf.Session() as sess:
+        c = tf.placeholder(name="c", dtype=tf.int32)
+        out_func = test_function(c)
+        c = tf_g.get_tensor_by_name("c:0")
+        self.assertEqual(3, sess.run(out_func, {c: [2, 1, 5]}))
+
+        tf.saved_model.simple_save(
+            sess, model_dir, inputs={"c": c}, outputs={"out_func": out_func})
+
+    g = gde.saved_model_to_graph(model_dir)
+
+    tf_g = g.to_tf_graph()
+    with tf.Session(graph=tf_g) as sess:
+      output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0")
+      c = tf_g.get_tensor_by_name("c:0")
+      self.assertEqual(3, sess.run(output_tensor, {c: [2, 1, 5]}))
+
+    f = g.get_function_graph_by_name(g.function_names[0])
+    func_const_op = f.get_node_by_name("FuncConst")
+    func_const_op.replace_attr("value", np.array([2, 2, 5], dtype=np.int32))
+
+    tf_g = g.to_tf_graph()
+    with tf.Session(graph=tf_g) as sess:
+      output_tensor = tf_g.get_tensor_by_name("PartitionedCall:0")
+      c = tf_g.get_tensor_by_name("c:0")
+      self.assertEqual(4, sess.run(output_tensor, {c: [2, 1, 5]}))
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/tests/graph_test.py b/tests/graph_test.py
index dbd2439..913e58e 100644
--- a/tests/graph_test.py
+++ b/tests/graph_test.py
@@ -39,6 +39,51 @@ def tearDown(self):
     shutil.rmtree(self.temp_dir)
     pass # In case previous line gets commented out
 
+  def build_graph(self):
+    tf_g = tf.Graph()
+    with tf_g.as_default():
+      a = tf.constant(1, name="a")
+      b = tf.constant(2, name="b")
+      c = tf.constant(10, name="c")
+      add_res = tf.add(a, b, name="add")
+      res = tf.multiply(add_res, c, name="mult")
+    g = gde.Graph(g=tf_g)
+    return g
+
+  def build_graph_with_function(self):
+    """Builds a tf graph for function (x + y) * 10.0 ."""
+    @tf.function
+    def multiplier_function(v):
+      return tf.constant(10.0, name="function_multiplier") * v
+
+    tf_g = tf.Graph()
+    with tf_g.as_default():
+      x = tf.placeholder(name="x", dtype=tf.float32, shape=[])
+      y = tf.placeholder(name="y", dtype=tf.float32, shape=[])
+      result_op = tf.add(x, y, name="add")
+      func_call_op = multiplier_function(result_op)
+      _ = tf.identity(func_call_op, name="output")
+    return gde.Graph(g=tf_g)
+
+  def build_graph_with_nested_function_call(self):
+    """Builds a tf graph for function (x + y) * 10.0 ."""
+    @tf.function
+    def adder_function(a, b):
+      return a + b
+
+    @tf.function
+    def multiplier_function(a, b):
+      v = adder_function(a, b)
+      return tf.constant(10.0, name="function_multiplier") * v
+
+    tf_g = tf.Graph()
+    with tf_g.as_default():
+      x = tf.placeholder(name="x", dtype=tf.float32, shape=[])
+      y = tf.placeholder(name="y", dtype=tf.float32, shape=[])
+      func_call_op = multiplier_function(x, y)
+      _ = tf.identity(func_call_op, name="output")
+    return gde.Graph(g=tf_g)
+
   def test_import_saved_model(self):
     tf_g = tf.Graph()
     with tf_g.as_default():
@@ -233,6 +278,310 @@ def test_node_collection_type_unique(self):
     with self.assertRaisesRegex(TypeError,
                                 "Node collections cannot be Nodes and Tensors.*"):
       g.get_collection_by_name("mixed_collection")
+
+  def test_nodes_iterator(self):
+    g = self.build_graph_with_function()
+    self.assertEqual(
+        {g.get_node_by_name("x"),
+         g.get_node_by_name("y"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("PartitionedCall"),
+         g.get_node_by_name("output")},
+        set(g.nodes_iterator()))
+
+  def test_nodes_iterator_predicate(self):
+    g = self.build_graph_with_function()
+    self.assertEqual(
+        {g.get_node_by_name("x"),
+         g.get_node_by_name("y")},
+        set(g.nodes_iterator(predicate=lambda n: n.op_type == "Placeholder")))
+
+  def test_nodes_iterator_iterate_functions(self):
+    g = self.build_graph_with_function()
+    f = g.get_function_graph_by_name(g.function_names[0])
+    self.assertEqual(
+        {g.get_node_by_name("x"),
+         g.get_node_by_name("y"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("PartitionedCall"),
+         g.get_node_by_name("output"),
+         f.get_node_by_name("function_multiplier"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("Identity"),
+         f.get_node_by_name("v")},
+        set(g.nodes_iterator(iterate_functions=True)))
+
+  def test_breadth_first_visitor(self):
+    g = self.build_graph()
+    nodes_in_bfs = []
+    def visit(node):
+      nodes_in_bfs.append(node)
+    def visit_with_break(node):
+      nodes_in_bfs.append(node)
+      return True
+    g.breadth_first_visitor(visit)
+    self.assertEqual(
+        [g.get_node_by_name("a"),
+         g.get_node_by_name("b"),
+         g.get_node_by_name("c"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("mult")],
+        nodes_in_bfs)
+
+    nodes_in_bfs = []
+    g.breadth_first_visitor(visit,
+                            starting_nodes=[g.get_node_by_name("a")])
+    self.assertEqual(
+        [g.get_node_by_name("a"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("mult")],
+        nodes_in_bfs)
+
+    nodes_in_bfs = []
+    g.breadth_first_visitor(visit,
+                            starting_nodes=[g.get_node_by_name("c")])
+    self.assertEqual(
+        [g.get_node_by_name("c"),
+         g.get_node_by_name("mult")],
+        nodes_in_bfs)
+
+    nodes_in_bfs = []
+    g.breadth_first_visitor(visit_with_break,
+                            starting_nodes=[g.get_node_by_name("c")])
+    self.assertEqual(
+        [g.get_node_by_name("c")],
+        nodes_in_bfs)
+
+  def test_breadth_first_visitor_iterate_functions(self):
+    g = self.build_graph_with_function()
+    nodes_in_bfs = []
+    def visit(node):
+      nodes_in_bfs.append(node)
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("x"), g.get_node_by_name("y")])
+    self.assertEqual(
+        [g.get_node_by_name("x"),
+         g.get_node_by_name("y"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("PartitionedCall"),
+         g.get_node_by_name("output")],
+        nodes_in_bfs)
+
+    nodes_in_bfs = []
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("x"), g.get_node_by_name("y")],
+        iterate_functions=True)
+    self.assertEqual(
+        [g.get_node_by_name("x"),
+         g.get_node_by_name("y"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("PartitionedCall"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("Identity"),
+         g.get_node_by_name("output")],
+        nodes_in_bfs)
+
+  def test_breadth_first_visitor_escape_functions(self):
+    g = self.build_graph_with_function()
+    nodes_in_bfs = []
+    def visit(node):
+      nodes_in_bfs.append(node)
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[f.get_node_by_name("function_multiplier")])
+    self.assertEqual(
+        [f.get_node_by_name("function_multiplier"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("Identity")],
+        nodes_in_bfs)
+
+    nodes_in_bfs = []
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[f.get_node_by_name("function_multiplier")],
+        escape_functions=True)
+    self.assertEqual(
+        [f.get_node_by_name("function_multiplier"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("Identity"),
+         g.get_node_by_name("output")],
+        nodes_in_bfs)
+
+  def test_breadth_first_visitor_escape_nested_functions(self):
+    g = self.build_graph_with_nested_function_call()
+    nodes_in_bfs = []
+    def visit(node):
+      nodes_in_bfs.append(node)
+
+    nodes_in_bfs = []
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[f.get_node_by_name("function_multiplier")],
+        iterate_functions=True,
+        escape_functions=True)
+    self.assertEqual(
+        [f.get_node_by_name("function_multiplier"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("Identity"),
+         g.get_node_by_name("output")],
+        nodes_in_bfs)
+
+  def test_breadth_first_visitor_escape_nested_functions_from_inner_node(self):
+    g = self.build_graph_with_nested_function_call()
+    nodes_in_bfs = []
+    def visit(node):
+      nodes_in_bfs.append(node)
+
+    add_node = list(g.nodes_iterator(lambda n: n.name == 'add',
+                                     iterate_functions=True))[0]
+    multiplier_function_name = \
+        g.get_node_by_name("x").outputs[0].consumers()[0].get_attr('f').name
+    multiplier_function_graph = \
+        g.get_function_graph_by_name(multiplier_function_name)
+    adder_function_graph = add_node.graph
+    nodes_in_bfs = []
+    g.breadth_first_visitor(
+        visit,
+        starting_nodes=[add_node],
+        iterate_functions=True,
+        escape_functions=True)
+    self.assertEqual(
+        [add_node,
+         add_node.graph.get_node_by_name("Identity"),
+         multiplier_function_graph.get_node_by_name("mul"),
+         multiplier_function_graph.get_node_by_name("Identity"),
+         g.get_node_by_name("output")],
+        nodes_in_bfs)
+
+  def test_backwards_breadth_first_visitor(self):
+    g = self.build_graph()
+    nodes_in_backwards_bfs = []
+    def visit(node):
+      nodes_in_backwards_bfs.append(node)
+    def visit_with_break(node):
+      nodes_in_backwards_bfs.append(node)
+      return True
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("mult")])
+    self.assertEqual(
+        [g.get_node_by_name("mult"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("c"),
+         g.get_node_by_name("a"),
+         g.get_node_by_name("b")],
+        nodes_in_backwards_bfs)
+
+    nodes_in_backwards_bfs = []
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("add")])
+    self.assertEqual(
+        [g.get_node_by_name("add"),
+         g.get_node_by_name("a"),
+         g.get_node_by_name("b")],
+        nodes_in_backwards_bfs)
+
+    nodes_in_backwards_bfs = []
+    g.backwards_breadth_first_visitor(
+        visit_with_break,
+        starting_nodes=[g.get_node_by_name("add")])
+    self.assertEqual(
+        [g.get_node_by_name("add")],
+        nodes_in_backwards_bfs)
+
+  def test_backwards_breadth_first_visitor_iterate_functions(self):
+    g = self.build_graph_with_function()
+    nodes_in_backwards_bfs = []
+    def visit(node):
+      nodes_in_backwards_bfs.append(node)
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("output")])
+    self.assertEqual(
+        [g.get_node_by_name("output"),
+         g.get_node_by_name("PartitionedCall"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("x"),
+         g.get_node_by_name("y")],
+        nodes_in_backwards_bfs)
+
+    nodes_in_backwards_bfs = []
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[g.get_node_by_name("output")],
+        iterate_functions=True)
+    self.assertEqual(
+        [g.get_node_by_name("output"),
+         g.get_node_by_name("PartitionedCall"),
+         f.get_node_by_name("Identity"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("function_multiplier"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("x"),
+         g.get_node_by_name("y")],
+        nodes_in_backwards_bfs)
+
+  def test_backwards_breadth_first_visitor_escape_functions(self):
+    g = self.build_graph_with_function()
+    nodes_in_backwards_bfs = []
+    def visit(node):
+      nodes_in_backwards_bfs.append(node)
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[f.get_node_by_name("Identity")])
+    self.assertEqual(
+        [f.get_node_by_name("Identity"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("function_multiplier")],
+        nodes_in_backwards_bfs)
+
+    nodes_in_backwards_bfs = []
+    f = g.get_function_graph_by_name(g.function_names[0])
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[f.get_node_by_name("Identity")],
+        escape_functions=True)
+    self.assertEqual(
+        [f.get_node_by_name("Identity"),
+         f.get_node_by_name("mul"),
+         f.get_node_by_name("function_multiplier"),
+         g.get_node_by_name("PartitionedCall"),
+         g.get_node_by_name("add"),
+         g.get_node_by_name("x"),
+         g.get_node_by_name("y")],
+        nodes_in_backwards_bfs)
+
+  def test_backwards_breadth_first_visitor_escape_nested_functions(self):
+    g = self.build_graph_with_nested_function_call()
+    nodes_in_backwards_bfs = []
+    def visit(node):
+      nodes_in_backwards_bfs.append(node)
+
+    add_node = list(g.nodes_iterator(lambda n: n.name == 'add',
+                                     iterate_functions=True))[0]
+    multiplier_function_name = \
+        g.get_node_by_name("x").outputs[0].consumers()[0].get_attr('f').name
+    multiplier_function_graph = \
+        g.get_function_graph_by_name(multiplier_function_name)
+
+    adder_call_op = list(g.nodes_iterator(
+        lambda n: (n.op_type == 'PartitionedCall' and
+                   n.get_attr('f').name == add_node.graph.name),
+        iterate_functions=True))[0]
+    multiplier_call_op = list(g.nodes_iterator(
+        lambda n: (n.op_type == 'PartitionedCall' and
+                   n.get_attr('f').name != add_node.graph.name),
+        iterate_functions=True))[0]
+
+    g.backwards_breadth_first_visitor(
+        visit,
+        starting_nodes=[add_node],
+        iterate_functions=True,
+        escape_functions=True)
+    self.assertEqual(
+        [add_node,
+         adder_call_op,
+         multiplier_call_op,
+         g.get_node_by_name("x"),
+         g.get_node_by_name("y")],
+        nodes_in_backwards_bfs)
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tests/match_test.py b/tests/match_test.py
index 6f6303d..ab4107b 100644
--- a/tests/match_test.py
+++ b/tests/match_test.py
@@ -54,7 +54,9 @@ def test_simple_match(self):
         gde.OpMatcher("^.*/f$").input_ops(True, "^.*/d$")(self.f_op))
     self.assertTrue(
         gde.OpMatcher("^.*/f$").input_ops(
-            gde.op_type("Add"), gde.op_type("Const"))(self.f_op))
+            gde.op_type("Add"), gde.op_type("Const"))(self.f_op) or
+        gde.OpMatcher("^.*/f$").input_ops(
+            gde.op_type("AddV2"), gde.op_type("Const"))(self.f_op))
     self.assertTrue(
         gde.OpMatcher("^.*/f$").input_ops("^.*/c$", "^.*/d$")
         .output_ops(gde.OpMatcher("^.*/h$")
diff --git a/tests/select_test.py b/tests/select_test.py
index 49e9673..b014128 100644
--- a/tests/select_test.py
+++ b/tests/select_test.py
@@ -75,7 +75,8 @@ def test_get_filter(self):
         len(gde.filter_ops(self.graph, lambda op: op.op_type == "Const")), 3)
     self.assertEqual(
-        len(gde.filter_ops(self.graph, lambda op: op.op_type == "Add")), 5)
+        len(gde.filter_ops(self.graph,
+                           lambda op: op.op_type in ["Add", "AddV2"])), 5)
     self.assertEqual(
         len(gde.filter_ops_from_regex(self.graph, r"^.*\b[abc]$")), 3)
diff --git a/tests/transform_test.py b/tests/transform_test.py
index 2f7ed3d..6a08ddd 100644
--- a/tests/transform_test.py
+++ b/tests/transform_test.py
@@ -281,6 +281,7 @@ def test_graph_replace_missing(self):
     self.assertEqual(res[0].name, "b:0")
     self.assertEqual(res[1].name, "c_1:0")
 
+  @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x")
   def test_graph_replace_gradients(self):
     tmp_graph = tf.Graph()
     with tmp_graph.as_default():
@@ -310,6 +311,7 @@ def test_graph_replace_gradients(self):
     self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
     self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
 
+  @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x")
   def test_graph_while_loop(self):
     tf_graph = tf.Graph()
     with tf_graph.as_default():
@@ -340,6 +342,7 @@ def test_graph_while_loop(self):
                            feed_dict={copied_max_index_tensor.name: n})
     self.assertEqual(sum_val, 55)
 
+  @unittest.skipIf(tf.version.VERSION[0] == "2", "not supported in TF2.x")
   def test_graph_cond(self):
     tf_g = tf.Graph()
     with tf_g.as_default():

From a2993b5a5e7f71da9847444ba8041f82ac47c211 Mon Sep 17 00:00:00 2001
From: Aleksey Vlasenko
Date: Fri, 6 Aug 2021 11:17:14 -0700
Subject: [PATCH 2/2] addressed feedback

---
 graph_def_editor/base_graph.py     |  7 ++++++-
 graph_def_editor/function_graph.py |  8 +-------
 graph_def_editor/graph.py          | 13 ++++++++++---
 graph_def_editor/node.py           |  4 ++--
 4 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/graph_def_editor/base_graph.py b/graph_def_editor/base_graph.py
index 4f295e5..f06b134 100644
--- a/graph_def_editor/base_graph.py
+++ b/graph_def_editor/base_graph.py
@@ -53,7 +53,9 @@ def __init__(
       name = None,  # type: str
   ):
     """
-    Wrap a tf.GraphDef protocol buffer in a Graph object.
+    Constructor to be called by subclasses only.
+
+    Initializes attributes of this base class.
 
     Args:
       name: Optional human-readable name for the graph. If not provided,
@@ -137,6 +139,8 @@ def add_node(self,
       uniquify_name: Generate a unique name from this name if the graph
         already has a node with the indicated name. If False, raise an
         exception if the name is in use.
+      debug_info: Opaque TensorFlow NodeDef.ExperimentalDebugInfo protocol
+        buffer; it is passed through to the serialized node unmodified.
 
     Returns:
       `MutableNode` wrapper for the new node.
@@ -265,6 +269,7 @@ def add_variable(self, name):
     v = variable.Variable(self)
     v.name = name
     self._variable_name_to_variable[name] = v
+    self.increment_version_counter()
     return v
 
   def add_variable_from_variable_def(self, variable_def,
diff --git a/graph_def_editor/function_graph.py b/graph_def_editor/function_graph.py
index afc0736..dc412d0 100644
--- a/graph_def_editor/function_graph.py
+++ b/graph_def_editor/function_graph.py
@@ -47,12 +47,7 @@
 
 class FunctionGraph(base_graph.BaseGraph):
-  """Wrapper class for TensorFlow function graphs.
-
-  Summary of internal data structures:
-  * _node_name_to_node: Nodes in the graph, stored as a dictionary. Key is name.
-  * _version: Counter that increments every time the graph is modified
-  """
+  """Wrapper class for TensorFlow function graphs."""
 
   def __init__(
       self,
@@ -78,7 +73,6 @@ def __init__(
           [(dtype, shape) for (dtype, shape, _) in tuples]
 
     # Populate fields of object
-    self._node_name_to_node = {} # Dict[str, node.Node]; key is node name
     self._node_to_frame_names = None
     self._frame_name_to_nodes = None
     self._head_name_to_coloc_group = None # Dict[str, FrozenList[str]]
diff --git a/graph_def_editor/graph.py b/graph_def_editor/graph.py
index 54d580d..bf2d004 100644
--- a/graph_def_editor/graph.py
+++ b/graph_def_editor/graph.py
@@ -311,10 +311,11 @@ def to_graph_def(self, add_shapes=True):
     for op in self.nodes:
       op.to_node_def(ret.node.add(), add_shapes)
-    # Pass through library without modifications for now.
+    # Copy library as is.
     if self._graph_def and self._graph_def.library:
       ret.library.CopyFrom(self._graph_def.library)
 
+    # Update functions in library that were instantiated as function graphs.
     for f_name, f_graph in self._function_graphs.items():
       function_index_to_update = None
       for index in range(0, len(ret.library.function)):
@@ -716,6 +717,9 @@ def breadth_first_visitor(
     """
     Visit all nodes reachable from a starting set in the order of a
     breadth-first traversal (going from node to output edges).
+    If the visitor reaches a function call and iterate_functions is True,
+    it will traverse the function's body first and then continue with the
+    remaining nodes in the graph.
     Invokes a callback at each node visited.
 
     Args:
@@ -731,7 +735,7 @@ def breadth_first_visitor(
       that we should also iterate through all function callers up
       the stack.
     Returns:
-      True if iteration was iterruputed by visitor, otherwise False.
+      True if iteration was interrupted by visitor, otherwise False.
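+
+    Example (hypothetical graph with nodes a -> b -> c):
+
+      visited = []
+      g.breadth_first_visitor(lambda n: visited.append(n.name))
+      # visited is now ["a", "b", "c"]; a truthy return value from the
+      # visitor would have stopped the walk early.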
""" if starting_nodes is None: # Start with all of the nodes in the graph that have no inputs. @@ -818,6 +822,9 @@ def backwards_breadth_first_visitor( """ Visit all nodes reachable from a starting set in the order of a backwards breadth-first traversal (going from node to input edges). + If visitor gets to a function call, and iterate_functions is True, + it will iterate all function nodes first and then continue with + remaining nodes in the graph. Invokes a callback at each node visited. Args: @@ -830,7 +837,7 @@ def backwards_breadth_first_visitor( that we should also iterate through all function callers up the stack. Returns: - True if iteration was iterruputed by visitor, otherwise False. + True if iteration was interrupted by visitor, otherwise False. """ if not starting_nodes: raise ValueError("starting_nodes is not provided") diff --git a/graph_def_editor/node.py b/graph_def_editor/node.py index bb312df..943e893 100644 --- a/graph_def_editor/node.py +++ b/graph_def_editor/node.py @@ -23,7 +23,7 @@ if sys.version >= '3': from typing import Tuple, List, Iterable, Any, AbstractSet, Type -from graph_def_editor import tensor, util +from graph_def_editor import base_graph, tensor, util # Magical attribute name that TensorFlow uses to store colocation groups. # See colocation_groups property below for more information. @@ -87,7 +87,7 @@ def __init__(self, device: TensorFlow device specification string indicating where this node should be located. Default value of "" means "use the default device" """ - _type_check(g, graph.Graph, "g") + _type_check(g, base_graph.BaseGraph, "g") _type_check(node_id, int, "node_id") _type_check(name, str, "name") _type_check(op_name, str, "op_name")