diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 52681e6..1ed800e 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -851,7 +851,7 @@ def __init__( groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - self.in_features = in_features + self.in_features = positive_int_or_none_cb(in_features) self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") self.strides = canonicalize(strides, self.spatial_ndim, "strides") diff --git a/serket/nn/utils.py b/serket/nn/utils.py index e493155..a305145 100644 --- a/serket/nn/utils.py +++ b/serket/nn/utils.py @@ -324,21 +324,23 @@ def maybe_lazy_call( is_lazy: Callable[..., bool], updates: dict[str, Callable[..., Any]], ) -> Callable[P, T]: + """Reinitialize the instance if it is lazy.""" + @ft.wraps(func) - def inner(self, *a, **k): - if not is_lazy(self): - return func(self, *a, **k) + def inner(instance, *a, **k): + if not is_lazy(instance, *a, **k): + return func(instance, *a, **k) - kwargs = dict(vars(self)) + kwargs = dict(vars(instance)) for key, update in updates.items(): - kwargs[key] = update(self, *a, **k) + kwargs[key] = update(instance, *a, **k) # clear the instance information for key in kwargs: - delattr(self, key) + delattr(instance, key) # re-initialize the instance - getattr(type(self), "__init__")(self, **kwargs) + getattr(type(instance), "__init__")(instance, **kwargs) # call the decorated function - return func(self, *a, **k) + return func(instance, *a, **k) return inner