diff --git a/serket/nn/utils.py b/serket/nn/utils.py index 85ff45c..a43822c 100644 --- a/serket/nn/utils.py +++ b/serket/nn/utils.py @@ -364,6 +364,23 @@ def inner(instance, *a, **k): return inner +LAZY_CALL_ERROR = """ +Cannot call ``{func_name}`` directly on a lazy layer. +use ``layer.at['{func_name}'](...)`` instead to return a tuple of: + - The layer output. + - Materialized layer. + +Example: + >>> layer = {class_name}(...) + >>> layer(x) # this will raise an error + Traceback (most recent call last): + ... + >>> _, materialized_layer = layer.at['{func_name}'](x) + >>> materialized_layer(x) + ... +""".lstrip() + + def maybe_lazy_call( func: Callable[P, T], is_lazy: Callable[..., bool], @@ -395,9 +412,16 @@ def inner(instance, *a, **k): kwargs[key] = update(instance, *a, **k) for key in kwargs: - # clear the instance information (i.e. the initial input arguments) - # use ``delattr`` to raise an error if the instance is immutable - delattr(instance, key) + try: + # clear the instance information (i.e. the initial input arguments) + # use ``delattr`` to raise an error if the instance is immutable + delattr(instance, key) + except AttributeError: + # the instance is lazy and immutable + func_name = func.__name__ + class_name = type(instance).__name__ + kwargs = dict(func_name=func_name, class_name=class_name) + raise RuntimeError(LAZY_CALL_ERROR.format(**kwargs)) # re-initialize the instance with the resolved arguments getattr(type(instance), "__init__")(instance, **kwargs)