Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 1, 2023
1 parent ed369a5 commit 70c4f2d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ def __init__(
groups: int = 1,
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = in_features
self.in_features = positive_int_or_none_cb(in_features)
self.out_features = positive_int_cb(out_features)
self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
self.strides = canonicalize(strides, self.spatial_ndim, "strides")
Expand Down
18 changes: 10 additions & 8 deletions serket/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,21 +324,23 @@ def maybe_lazy_call(
is_lazy: Callable[..., bool],
updates: dict[str, Callable[..., Any]],
) -> Callable[P, T]:
"""Reinitialize the instance if it is lazy."""

@ft.wraps(func)
def inner(self, *a, **k):
if not is_lazy(self):
return func(self, *a, **k)
def inner(instance, *a, **k):
if not is_lazy(instance, *a, **k):
return func(instance, *a, **k)

kwargs = dict(vars(self))
kwargs = dict(vars(instance))
for key, update in updates.items():
kwargs[key] = update(self, *a, **k)
kwargs[key] = update(instance, *a, **k)

# clear the instance information
for key in kwargs:
delattr(self, key)
delattr(instance, key)
# re-initialize the instance
getattr(type(self), "__init__")(self, **kwargs)
getattr(type(instance), "__init__")(instance, **kwargs)
# call the decorated function
return func(self, *a, **k)
return func(instance, *a, **k)

return inner

0 comments on commit 70c4f2d

Please sign in to comment.