diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 47769dc..468352b 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -37,7 +37,7 @@ calculate_transpose_padding, canonicalize, delayed_canonicalize_padding, - lazy_call, + maybe_lazy_call, positive_int_cb, positive_int_or_none_cb, validate_axis_shape, @@ -155,7 +155,7 @@ def fft_conv_general_dilated( def is_lazy(instance, *_, **__) -> bool: - return instance.in_features is None + return getattr(instance, "in_features", False) is None def infer_in_features(instance, x, *_, **__) -> int: @@ -205,15 +205,9 @@ def __init__( self.groups = positive_int_cb(groups) if in_features is None: - # store the key for lazy initialization only self.key = key - # going to be initialized lazily return - if "key" in vars(self): - # coming from lazy initialization - del self.key - if self.out_features % self.groups != 0: raise ValueError(f"{(out_features % groups == 0)=}") @@ -231,7 +225,7 @@ def spatial_ndim(self) -> int: class ConvND(BaseConvND): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -456,7 +450,7 @@ def spatial_ndim(self) -> int: class FFTConvND(BaseConvND): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -713,14 +707,9 @@ def __init__( self.groups = positive_int_cb(groups) if in_features is None: - # store the key for lazy initialization only self.key = key return - if "key" in vars(self): - # coming from lazy initialization - del self.key - if self.out_features % self.groups != 0: raise ValueError(f"{(self.out_features % self.groups ==0)=}") @@ -738,7 +727,7 @@ def spatial_ndim(self) -> int: class ConvNDTranspose(BaseConvNDTranspose): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -981,7 +970,7 @@ def spatial_ndim(self) -> int: class FFTConvNDTranspose(BaseConvNDTranspose): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -1249,12 +1238,8 @@ def __init__( self.bias_init = bias_init if in_features is None: - # store the key for lazy initialization only self.key = key return - if "key" in vars(self): - # coming from lazy initialization - del self.key weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW self.weight = resolve_init_func(self.weight_init)(key, weight_shape) @@ -1270,7 +1255,7 @@ def spatial_ndim(self) -> int: class DepthwiseConvND(BaseDepthwiseConvND): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -1460,7 +1445,7 @@ def spatial_ndim(self) -> int: class DepthwiseFFTConvND(BaseDepthwiseConvND): - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -1675,18 +1660,6 @@ def __init__( # going to lazy init return - if hasattr(self, "key"): - # coming from lazy init - del self.in_features - del self.out_features - del self.kernel_size - del self.depth_multiplier - del self.strides - del self.padding - del self.depthwise_weight_init - del self.pointwise_weight_init - del self.key - self.depthwise_conv = self._depthwise_convolution_layer( in_features=in_features, depth_multiplier=depth_multiplier, @@ -1709,7 +1682,7 @@ def __init__( key=key, ) - @ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) def __call__(self, x: jax.Array, **k) -> jax.Array: x = self.depthwise_conv(x) x = self.pointwise_conv(x) @@ -2191,10 +2164,11 @@ def __init__( name="kernel_size", ) - if in_size is None: - self.in_size = None - else: - self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size") + self.in_size = ( + canonicalize(in_size, self.spatial_ndim, name="in_size") + if in_size is not None + else None + ) self.strides = canonicalize(strides, self.spatial_ndim, name="strides") @@ -2214,14 +2188,9 @@ def __init__( self.bias_init = bias_init if self.in_features is None or self.in_size is None: - # going to lazy initialization self.key = key return - if "key" in vars(self): - # coming from lazy initialization - del self.key - out_size = calculate_convolution_output_shape( shape=self.in_size, kernel_size=self.kernel_size, @@ -2241,7 +2210,7 @@ def __init__( bias_shape = (self.out_features, *out_size) self.bias = resolve_init_func(self.bias_init)(key, bias_shape) - @ft.partial(lazy_call, is_lazy=is_lazy, updates=convlocal_updates) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=convlocal_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: diff --git a/serket/nn/utils.py b/serket/nn/utils.py index 5bdc271..3eb71dc 100644 --- a/serket/nn/utils.py +++ b/serket/nn/utils.py @@ -319,7 +319,7 @@ def wrapper(self, array, *a, **k): return wrapper -def lazy_call( +def maybe_lazy_call( func: Callable[P, T], is_lazy: Callable[..., bool], updates: dict[str, Callable[..., Any]], @@ -332,7 +332,12 @@ def inner(self, *a, **k): kwargs = dict(vars(self)) for key, update in updates.items(): kwargs[key] = update(self, *a, **k) + + # clear the instance information + vars(self).clear() + # re-initialize the instance getattr(type(self), "__init__")(self, **kwargs) + # call the decorated function return func(self, *a, **k) return inner