diff --git a/graph_def_editor/node.py b/graph_def_editor/node.py index 81ed00a..0732555 100644 --- a/graph_def_editor/node.py +++ b/graph_def_editor/node.py @@ -21,7 +21,7 @@ import tensorflow.compat.v1 as tf import sys if sys.version >= '3': - from typing import Tuple, List, Iterable, Any, AbstractSet + from typing import Tuple, List, Iterable, Any, AbstractSet, Type from graph_def_editor import graph, tensor, util @@ -42,6 +42,20 @@ "Node", ] +def _type_check(obj: Any, expected_type: Type, arg_name: str): + """ + Subroutine to validate the type of a function argument. + + Args: + obj: Argument to validate + expected_type: Expected type of `obj` + arg_name: Name of the function argument where the caller passed `obj` + """ + if not isinstance(obj, expected_type): + raise TypeError("'{}' argument should be of type {}, but got object of " + "type {} instead. Value received was {}." + "".format(arg_name, expected_type, type(obj), obj)) + class Node(object): """ @@ -50,11 +64,11 @@ class Node(object): tf.NodeDef protobuf on demand. """ def __init__(self, - g, # type: graph.Graph - node_id, # type: int - name, # type: int - op_name, # type: str - device = "" # type: str + g, # type: graph.Graph + node_id, # type: int + name, # type: str + op_name, # type: str + device="" # type: str ): """ This constructor should only be called from methods of the Graph @@ -69,6 +83,11 @@ def __init__(self, device: TensorFlow device specification string indicating where this node should be located. Default value of "" means "use the default device" """ + _type_check(g, graph.Graph, "g") + _type_check(node_id, int, "node_id") + _type_check(name, str, "name") + _type_check(op_name, str, "op_name") + _type_check(device, str, "device") self._graph = g self._id = node_id self._name = name