fix tests
ASEM000 committed Jul 31, 2023
1 parent 5c7095b commit e438def
Showing 2 changed files with 21 additions and 47 deletions.
61 changes: 15 additions & 46 deletions serket/nn/convolution.py
@@ -37,7 +37,7 @@
calculate_transpose_padding,
canonicalize,
delayed_canonicalize_padding,
lazy_call,
maybe_lazy_call,
positive_int_cb,
positive_int_or_none_cb,
validate_axis_shape,
@@ -155,7 +155,7 @@ def fft_conv_general_dilated(


def is_lazy(instance, *_, **__) -> bool:
return instance.in_features is None
return getattr(instance, "in_features", False) is None


def infer_in_features(instance, x, *_, **__) -> int:
@@ -205,15 +205,9 @@ def __init__(
self.groups = positive_int_cb(groups)

if in_features is None:
# store the key for lazy initialization only
self.key = key
# going to be initialized lazily
return

if "key" in vars(self):
# coming from lazy initialization
del self.key

if self.out_features % self.groups != 0:
raise ValueError(f"{(out_features % groups == 0)=}")

@@ -231,7 +225,7 @@ def spatial_ndim(self) -> int:


class ConvND(BaseConvND):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -456,7 +450,7 @@ def spatial_ndim(self) -> int:


class FFTConvND(BaseConvND):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -713,14 +707,9 @@ def __init__(
self.groups = positive_int_cb(groups)

if in_features is None:
# store the key for lazy initialization only
self.key = key
return

if "key" in vars(self):
# coming from lazy initialization
del self.key

if self.out_features % self.groups != 0:
raise ValueError(f"{(self.out_features % self.groups ==0)=}")

@@ -738,7 +727,7 @@ def spatial_ndim(self) -> int:


class ConvNDTranspose(BaseConvNDTranspose):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -981,7 +970,7 @@ def spatial_ndim(self) -> int:


class FFTConvNDTranspose(BaseConvNDTranspose):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -1249,12 +1238,8 @@ def __init__(
self.bias_init = bias_init

if in_features is None:
# store the key for lazy initialization only
self.key = key
return
if "key" in vars(self):
# coming from lazy initialization
del self.key

weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW
self.weight = resolve_init_func(self.weight_init)(key, weight_shape)
@@ -1270,7 +1255,7 @@ def spatial_ndim(self) -> int:


class DepthwiseConvND(BaseDepthwiseConvND):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -1460,7 +1445,7 @@ def spatial_ndim(self) -> int:


class DepthwiseFFTConvND(BaseDepthwiseConvND):
@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -1675,18 +1660,6 @@ def __init__(
# going to lazy init
return

if hasattr(self, "key"):
# coming from lazy init
del self.in_features
del self.out_features
del self.kernel_size
del self.depth_multiplier
del self.strides
del self.padding
del self.depthwise_weight_init
del self.pointwise_weight_init
del self.key

self.depthwise_conv = self._depthwise_convolution_layer(
in_features=in_features,
depth_multiplier=depth_multiplier,
@@ -1709,7 +1682,7 @@ def __init__(
key=key,
)

@ft.partial(lazy_call, is_lazy=is_lazy, updates=conv_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
@@ -2191,10 +2164,11 @@ def __init__(
name="kernel_size",
)

if in_size is None:
self.in_size = None
else:
self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size")
self.in_size = (
canonicalize(in_size, self.spatial_ndim, name="in_size")
if in_size is not None
else None
)

self.strides = canonicalize(strides, self.spatial_ndim, name="strides")

@@ -2214,14 +2188,9 @@ def __init__(
self.bias_init = bias_init

if self.in_features is None or self.in_size is None:
# going to lazy initialization
self.key = key
return

if "key" in vars(self):
# coming from lazy initialization
del self.key

out_size = calculate_convolution_output_shape(
shape=self.in_size,
kernel_size=self.kernel_size,
@@ -2241,7 +2210,7 @@ def __init__(
bias_shape = (self.out_features, *out_size)
self.bias = resolve_init_func(self.bias_init)(key, bias_shape)

@ft.partial(lazy_call, is_lazy=is_lazy, updates=convlocal_updates)
@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=convlocal_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
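A note on how the lazy path above is used: constructing a layer with in_features=None defers weight creation, and the first call goes through maybe_lazy_call, which infers in_features from the input's channel axis, clears the placeholder state, and re-runs __init__ with the inferred value. A minimal sketch, assuming the sk.nn.Conv2D name and keyword signature, the unbatched channels-first input layout, and that a plain direct call is a supported entry point (rather than a functional-update wrapper); none of this is confirmed by the diff itself:

    import jax
    import jax.numpy as jnp
    import serket as sk  # assumed import name

    key = jax.random.PRNGKey(0)
    # in_features=None marks the layer lazy: __init__ returns before building weights
    conv = sk.nn.Conv2D(in_features=None, out_features=16, kernel_size=3, key=key)

    x = jnp.ones((8, 32, 32))  # (in_features, height, width), no batch axis
    # first call: is_lazy is True, so maybe_lazy_call infers in_features=8 from the
    # channel axis of x, re-initializes the instance, then runs the real convolution
    y = conv(x)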
7 changes: 6 additions & 1 deletion serket/nn/utils.py
@@ -319,7 +319,7 @@ def wrapper(self, array, *a, **k):
return wrapper


def lazy_call(
def maybe_lazy_call(
func: Callable[P, T],
is_lazy: Callable[..., bool],
updates: dict[str, Callable[..., Any]],
@@ -332,7 +332,12 @@ def inner(self, *a, **k):
kwargs = dict(vars(self))
for key, update in updates.items():
kwargs[key] = update(self, *a, **k)

# clear the instance information
vars(self).clear()
# re-initialize the instance
getattr(type(self), "__init__")(self, **kwargs)
# call the decorated function
return func(self, *a, **k)

return inner
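
For reference, the re-initialization mechanism in maybe_lazy_call reduces to the following self-contained sketch (plain Python, no serket dependencies; the is_lazy short-circuit at the top of inner is assumed from the unexpanded context above the hunk):

    import functools as ft

    def maybe_lazy_call(func, is_lazy, updates):
        # defer real construction until the wrapped method is first called
        @ft.wraps(func)
        def inner(self, *a, **k):
            if not is_lazy(self, *a, **k):
                return func(self, *a, **k)
            # collect the constructor arguments currently stored on the instance
            kwargs = dict(vars(self))
            # overwrite placeholders (e.g. in_features=None) with inferred values
            for key, update in updates.items():
                kwargs[key] = update(self, *a, **k)
            # clear the instance information and re-initialize the instance
            vars(self).clear()
            type(self).__init__(self, **kwargs)
            # call the decorated function on the now fully-built instance
            return func(self, *a, **k)
        return inner

    class TinyLinear:
        def __init__(self, in_features, out_features):
            self.in_features = in_features
            self.out_features = out_features
            if in_features is None:
                return  # lazy: keep only the arguments, build nothing yet
            self.weight = [[1.0] * out_features for _ in range(in_features)]

        @ft.partial(
            maybe_lazy_call,
            is_lazy=lambda self, *_, **__: getattr(self, "in_features", False) is None,
            updates={"in_features": lambda self, x, *_, **__: len(x)},
        )
        def __call__(self, x):
            # weight has shape (in_features, out_features); column-wise dot product
            return [sum(w * xi for w, xi in zip(col, x)) for col in zip(*self.weight)]

    layer = TinyLinear(None, 3)          # in_features unknown at construction
    print(layer([1.0, 2.0, 3.0, 4.0]))   # first call infers in_features=4 -> [10.0, 10.0, 10.0]

The design choice here is that laziness is resolved by mutate-then-reinit on the same object, so the wrapped method runs exactly once per call and later calls skip the is_lazy branch entirely.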
