Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
vlasenkoalexey authored Aug 6, 2021
2 parents f74ccda + c9e5a26 commit 90b77a7
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion graph_def_editor/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tensorflow.compat.v1 as tf
import sys
if sys.version >= '3':
from typing import Tuple, List, Iterable, Any, AbstractSet
from typing import Tuple, List, Iterable, Any, AbstractSet, Type

from graph_def_editor import tensor, util

Expand All @@ -45,6 +45,20 @@
PARTITIONED_CALL_OP_TYPES = frozenset([
"PartitionedCall", "StatefulPartitionedCall", "TPUPartitionedCall"])

def _type_check(obj: Any, expected_type: Type, arg_name: str):
"""
Subroutine to validate the type of a function argument.
Args:
obj: Argument to validate
expected_type: Expected type of `obj`
arg_name: Name of the function argument where the caller passed `obj`
"""
if not isinstance(obj, expected_type):
raise TypeError("'{}' argument should be of type {}, but got object of "
"type {} instead. Value received was {}."
"".format(arg_name, expected_type, type(obj), obj))


class Node(object):
"""
Expand Down Expand Up @@ -73,6 +87,11 @@ def __init__(self,
device: TensorFlow device specification string indicating where this node
should be located. Default value of "" means "use the default device"
"""
_type_check(g, graph.Graph, "g")
_type_check(node_id, int, "node_id")
_type_check(name, str, "name")
_type_check(op_name, str, "op_name")
_type_check(device, str, "device")
self._graph = g
self._id = node_id
self._name = name
Expand Down

0 comments on commit 90b77a7

Please sign in to comment.