Skip to content

Commit

Permalink
add lazy error
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 1, 2023
1 parent f941176 commit edc4a04
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions serket/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,23 @@ def inner(instance, *a, **k):
return inner


LAZY_CALL_ERROR = """
Cannot call ``{func_name}`` directly on a lazy layer.
use ``layer.at['{func_name}'](...)`` instead to return a tuple of:
- The layer output.
- Materialized layer.
Example:
>>> layer = {class_name}(...)
>>> layer(x) # this will raise an error
Traceback (most recent call last):
...
>>> _, materialized_layer = layer.at['{func_name}'](x)
>>> materialized_layer(x)
...
""".lstrip()


def maybe_lazy_call(
func: Callable[P, T],
is_lazy: Callable[..., bool],
Expand Down Expand Up @@ -395,9 +412,16 @@ def inner(instance, *a, **k):
kwargs[key] = update(instance, *a, **k)

for key in kwargs:
# clear the instance information (i.e. the initial input arguments)
# use ``delattr`` to raise an error if the instance is immutable
delattr(instance, key)
try:
# clear the instance information (i.e. the initial input arguments)
# use ``delattr`` to raise an error if the instance is immutable
delattr(instance, key)
except AttributeError:
# the instance is lazy and immutable
func_name = func.__name__
class_name = type(instance).__name__
kwargs = dict(func_name=func_name, class_name=class_name)
raise RuntimeError(LAZY_CALL_ERROR.format(**kwargs))

# re-initialize the instance with the resolved arguments
getattr(type(instance), "__init__")(instance, **kwargs)
Expand Down

0 comments on commit edc4a04

Please sign in to comment.