diff --git a/pyproject.toml b/pyproject.toml
index d0233c7..2109e02 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ keywords = [
     "functional-programming",
     "machine-learning",
 ]
-dependencies = ["pytreeclass>=0.5.0"]
+dependencies = ["pytreeclass>=0.6.0"]
 
 classifiers = [
     "Development Status :: 5 - Production/Stable",
@@ -30,6 +30,8 @@ classifiers = [
     "Topic :: Software Development :: Libraries :: Python Modules",
 ]
 
+readme = "README.md"
+
 [tool.setuptools.dynamic]
 version = { attr = "serket.__version__" }
 
diff --git a/serket/nn/containers.py b/serket/nn/containers.py
index 501bcca..3109ebd 100644
--- a/serket/nn/containers.py
+++ b/serket/nn/containers.py
@@ -81,8 +81,7 @@ def __reversed__(self):
 
 @sk.autoinit
 class RandomApply(sk.TreeClass):
-    """
-    Randomly applies a layer with probability p.
+    """Randomly applies a layer with probability p.
 
     Args:
         layer: layer to apply.
diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py
index 705efa2..1ed800e 100644
--- a/serket/nn/convolution.py
+++ b/serket/nn/convolution.py
@@ -37,7 +37,9 @@
     calculate_transpose_padding,
     canonicalize,
     delayed_canonicalize_padding,
+    maybe_lazy_call,
     positive_int_cb,
+    positive_int_or_none_cb,
     validate_axis_shape,
     validate_spatial_ndim,
 )
@@ -152,10 +154,29 @@ def fft_conv_general_dilated(
     return jax.lax.slice(z, start, end, (1, 1, *strides))
 
 
+def is_lazy(instance, *_, **__) -> bool:
+    return getattr(instance, "in_features", False) is None
+
+
+def infer_in_features(instance, x, *_, **__) -> int:
+    return x.shape[0]
+
+
+def infer_in_size(instance, x, *_, **__) -> tuple[int, ...]:
+    return x.shape[1:]
+
+
+def infer_key(instance, *_, **__) -> jr.KeyArray:
+    return instance.key
+
+
+conv_updates = {"key": infer_key, "in_features": infer_in_features}
+
+
 class BaseConvND(sk.TreeClass):
     def __init__(
         self,
-        in_features: int,
+        in_features: int | None,
         out_features: int,
         kernel_size: KernelSizeType,
         *,
@@ -167,36 +188,28 @@ def __init__(
         groups: int = 1,
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        self.in_features = positive_int_cb(in_features)
+        self.in_features = positive_int_or_none_cb(in_features)
         self.out_features = positive_int_cb(out_features)
-        self.kernel_size = canonicalize(
-            kernel_size, self.spatial_ndim, name="kernel_size"
-        )
-        self.strides = canonicalize(strides, self.spatial_ndim, name="strides")
+        self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size")
+        self.strides = canonicalize(strides, self.spatial_ndim, "strides")
         self.padding = padding
-        self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation")
-
-        weight_init = resolve_init_func(weight_init)
-        bias_init = resolve_init_func(bias_init)
-
+        self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation")
+        self.weight_init = weight_init
+        self.bias_init = bias_init
         self.groups = positive_int_cb(groups)
 
+        if in_features is None:
+            self.key = key
+            return
+
         if self.out_features % self.groups != 0:
             raise ValueError(f"{(out_features % groups == 0)=}")
 
         weight_shape = (out_features, in_features // groups, *self.kernel_size)
-        self.weight = weight_init(key, weight_shape)
+        self.weight = resolve_init_func(self.weight_init)(key, weight_shape)
 
         bias_shape = (out_features, *(1,) * self.spatial_ndim)
-        self.bias = bias_init(key, bias_shape)
-
-    @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
-    @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
-    def __call__(self, x: jax.Array, **k) -> jax.Array:
-        x = self._convolution_operation(jnp.expand_dims(x, 0))
-
if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) + self.bias = resolve_init_func(self.bias_init)(key, bias_shape) @property @abc.abstractmethod @@ -204,23 +217,21 @@ def spatial_ndim(self) -> int: """Number of spatial dimensions of the convolutional layer.""" ... - @abc.abstractmethod - def _convolution_operation(self, x: jax.Array) -> jax.Array: - """Convolution operation.""" - ... - class ConvND(BaseConvND): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - return jax.lax.conv_general_dilated( - lhs=x, + x = jax.lax.conv_general_dilated( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, window_strides=self.strides, padding=padding, @@ -229,6 +240,10 @@ def _convolution_operation(self, x: jax.Array) -> jax.Array: feature_group_count=self.groups, ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + self.bias), 0) + class Conv1D(ConvND): """1D Convolutional layer. @@ -287,6 +302,34 @@ class Conv1D(ConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5) + Note: + :class:`.Conv1D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv1D = sk.nn.Conv1D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv1D = sk.nn.Conv1D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -353,6 +396,34 @@ class Conv2D(ConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5) + Note: + :class:`.Conv2D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv2D = sk.nn.Conv2D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv2D = sk.nn.Conv2D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -419,6 +490,34 @@ class Conv3D(ConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5, 5) + Note: + :class:`.Conv3D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv3D = sk.nn.Conv3D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv3D = sk.nn.Conv3D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -429,16 +528,19 @@ def spatial_ndim(self) -> int: class FFTConvND(BaseConvND): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - return fft_conv_general_dilated( - lhs=x, + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, strides=self.strides, padding=padding, @@ -446,6 +548,10 @@ def _convolution_operation(self, x: jax.Array) -> jax.Array: dilation=self.dilation, ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + self.bias), 0) + class FFTConv1D(FFTConvND): """1D Convolutional layer. @@ -504,6 +610,34 @@ class FFTConv1D(FFTConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5) + Note: + :class:`.FFTConv1D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv1D = sk.nn.FFTConv1D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv1D = sk.nn.FFTConv1D(None, 1, 3, key=jr.PRNGKey(2)) + ... 
def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -570,6 +704,34 @@ class FFTConv2D(FFTConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5) + Note: + :class:`.FFTConv2D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv2D = sk.nn.FFTConv2D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv2D = sk.nn.FFTConv2D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -636,6 +798,34 @@ class FFTConv3D(FFTConvND): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5, 5) + Note: + :class:`.FFTConv3D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv3D = sk.nn.FFTConv3D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv3D = sk.nn.FFTConv3D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -648,52 +838,43 @@ def spatial_ndim(self) -> int: class BaseConvNDTranspose(sk.TreeClass): def __init__( self, - in_features: int, + in_features: int | None, out_features: int, kernel_size: KernelSizeType, *, strides: StridesType = 1, padding: PaddingType = "same", - output_padding: int = 0, + out_padding: int = 0, dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): - self.in_features = positive_int_cb(in_features) + self.in_features = positive_int_or_none_cb(in_features) self.out_features = positive_int_cb(out_features) - self.kernel_size = canonicalize( - kernel_size, self.spatial_ndim, name="kernel_size" - ) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") + self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") + self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding # delayed canonicalization - self.output_padding = canonicalize( - output_padding, - self.spatial_ndim, - name="output_padding", - ) - self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation") - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) + self.out_padding = canonicalize(out_padding, self.spatial_ndim, "out_padding") + self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation") + self.weight_init = weight_init + self.bias_init = bias_init self.groups = positive_int_cb(groups) + if in_features is None: + self.key = key + return + if self.out_features % self.groups != 0: raise ValueError(f"{(self.out_features % self.groups ==0)=}") + in_features = positive_int_cb(self.in_features) weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = weight_init(key, weight_shape) + self.weight = resolve_init_func(self.weight_init)(key, weight_shape) bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self._convolution_operation(jnp.expand_dims(x, 0)) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze(x + self.bias, 0) + self.bias = resolve_init_func(self.bias_init)(key, bias_shape) @property @abc.abstractmethod @@ -703,9 +884,12 @@ def spatial_ndim(self) -> int: class ConvNDTranspose(BaseConvNDTranspose): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, @@ -713,15 +897,13 @@ def _convolution_operation(self, x: 
jax.Array) -> jax.Array: transposed_padding = calculate_transpose_padding( padding=padding, - extra_padding=self.output_padding, + extra_padding=self.out_padding, kernel_size=self.kernel_size, input_dilation=self.dilation, ) - # breakpoint() - - return jax.lax.conv_transpose( - lhs=x, + x = jax.lax.conv_transpose( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, strides=self.strides, padding=transposed_padding, @@ -729,6 +911,10 @@ def _convolution_operation(self, x: jax.Array) -> jax.Array: dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + self.bias, 0) + class Conv1DTranspose(ConvNDTranspose): """1D Convolution transpose layer. @@ -761,7 +947,7 @@ class Conv1DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: + out_padding: padding of the output after convolution. accepts: - single integer for same padding in all dimensions. @@ -791,6 +977,34 @@ class Conv1DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5) + Note: + :class:`.Conv1DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv1DTranspose = sk.nn.Conv1DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv1DTranspose = sk.nn.Conv1DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -831,7 +1045,7 @@ class Conv2DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: + out_padding: padding of the output after convolution. accepts: - single integer for same padding in all dimensions. @@ -861,6 +1075,34 @@ class Conv2DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5) + Note: + :class:`.Conv2DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv2DTranspose = sk.nn.Conv2DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv2DTranspose = sk.nn.Conv2DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... 
def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -901,7 +1143,7 @@ class Conv3DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: + out_padding: padding of the output after convolution. accepts: - single integer for same padding in all dimensions. @@ -931,6 +1173,34 @@ class Conv3DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5, 5) + Note: + :class:`.Conv3DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.Conv3DTranspose = sk.nn.Conv3DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.Conv3DTranspose = sk.nn.Conv3DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -941,9 +1211,12 @@ def spatial_ndim(self) -> int: class FFTConvNDTranspose(BaseConvNDTranspose): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, @@ -951,19 +1224,22 @@ def _convolution_operation(self, x: jax.Array) -> jax.Array: transposed_padding = calculate_transpose_padding( padding=padding, - extra_padding=self.output_padding, + extra_padding=self.out_padding, kernel_size=self.kernel_size, input_dilation=self.dilation, ) - return fft_conv_general_dilated( - lhs=x, + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, strides=self.strides, padding=transposed_padding, dilation=self.dilation, groups=1, ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze(x + self.bias, 0) class FFTConv1DTranspose(FFTConvNDTranspose): @@ -997,7 +1273,7 @@ class FFTConv1DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. 
accepts: + out_padding: Padding of the output after convolution. accepts: - single integer for same padding in all dimensions. @@ -1027,6 +1303,34 @@ class FFTConv1DTranspose(FFTConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5) + Note: + :class:`.FFTConv1DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv1DTranspose = sk.nn.FFTConv1DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv1DTranspose = sk.nn.FFTConv1DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -1067,7 +1371,7 @@ class FFTConv2DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. accepts: + out_padding: Padding of the output after convolution. accepts: - single integer for same padding in all dimensions. @@ -1097,6 +1401,34 @@ class FFTConv2DTranspose(FFTConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5) + Note: + :class:`.FFTConv2DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv2DTranspose = sk.nn.FFTConv2DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv2DTranspose = sk.nn.FFTConv2DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -1137,7 +1469,7 @@ class FFTConv3DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. accepts: + out_padding: Padding of the output after convolution. accepts: - single integer for same padding in all dimensions. 
@@ -1167,6 +1499,34 @@ class FFTConv3DTranspose(FFTConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5, 5) + Note: + :class:`.FFTConv3DTranspose` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.FFTConv3DTranspose = sk.nn.FFTConv3DTranspose(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.FFTConv3DTranspose = sk.nn.FFTConv3DTranspose(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 12 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -1179,7 +1539,7 @@ def spatial_ndim(self) -> int: class BaseDepthwiseConvND(sk.TreeClass): def __init__( self, - in_features: int, + in_features: int | None, kernel_size: KernelSizeType, *, depth_multiplier: int = 1, @@ -1189,30 +1549,23 @@ def __init__( bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - self.in_features = positive_int_cb(in_features) - self.kernel_size = canonicalize( - kernel_size, self.spatial_ndim, name="kernel_size" - ) + self.in_features = positive_int_or_none_cb(in_features) + self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") self.depth_multiplier = positive_int_cb(depth_multiplier) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") + self.strides = canonicalize(strides, self.spatial_ndim, "strides") self.padding = padding # delayed canonicalization - self.dilation = canonicalize(1, self.spatial_ndim, name="dilation") - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) + self.weight_init = weight_init + self.bias_init = bias_init + + if in_features is None: + self.key = key + return weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW - self.weight = weight_init(key, weight_shape) + self.weight = resolve_init_func(self.weight_init)(key, weight_shape) bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self._convolution_operation(jnp.expand_dims(x, 0)) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) + self.bias = resolve_init_func(self.bias_init)(key, bias_shape) @property @abc.abstractmethod @@ -1220,30 +1573,33 @@ def spatial_ndim(self) -> int: """Number of spatial dimensions of the convolutional layer.""" ... - @abc.abstractmethod - def _convolution_operation(self, x: jax.Array) -> jax.Array: - ... 
- class DepthwiseConvND(BaseDepthwiseConvND): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - return jax.lax.conv_general_dilated( - lhs=x, + x = jax.lax.conv_general_dilated( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, window_strides=self.strides, padding=padding, - rhs_dilation=self.dilation, + rhs_dilation=canonicalize(1, self.spatial_ndim, "dilation"), dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), feature_group_count=self.in_features, ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + self.bias), 0) + class DepthwiseConv1D(DepthwiseConvND): """1D Depthwise convolution layer. @@ -1281,7 +1637,6 @@ class DepthwiseConv1D(DepthwiseConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1289,6 +1644,34 @@ class DepthwiseConv1D(DepthwiseConvND): >>> l1(jnp.ones((3, 32))).shape (6, 16) + Note: + :class:`.DepthwiseConv1D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseConv1D = sk.nn.DepthwiseConv1D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseConv1D = sk.nn.DepthwiseConv1D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1335,7 +1718,6 @@ class DepthwiseConv2D(DepthwiseConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1343,6 +1725,34 @@ class DepthwiseConv2D(DepthwiseConvND): >>> l1(jnp.ones((3, 32, 32))).shape (6, 16, 16) + Note: + :class:`.DepthwiseConv2D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. 
+ + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseConv2D = sk.nn.DepthwiseConv2D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseConv2D = sk.nn.DepthwiseConv2D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1390,7 +1800,6 @@ class DepthwiseConv3D(DepthwiseConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1398,6 +1807,34 @@ class DepthwiseConv3D(DepthwiseConvND): >>> l1(jnp.ones((3, 32, 32, 32))).shape (6, 16, 16, 16) + Note: + :class:`.DepthwiseConv3D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseConv3D = sk.nn.DepthwiseConv3D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseConv3D = sk.nn.DepthwiseConv3D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1409,23 +1846,30 @@ def spatial_ndim(self) -> int: class DepthwiseFFTConvND(BaseDepthwiseConvND): - def _convolution_operation(self, x: jax.Array) -> jax.Array: + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) + def __call__(self, x: jax.Array, **k) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[2:], + in_dim=x.shape[1:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - return fft_conv_general_dilated( - lhs=x, + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(x, 0), rhs=self.weight, strides=self.strides, padding=padding, - dilation=self.dilation, + dilation=canonicalize(1, self.spatial_ndim, "dilation"), groups=self.in_features, ) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + self.bias), 0) + class DepthwiseFFTConv1D(DepthwiseFFTConvND): """1D Depthwise FFT convolution layer. @@ -1463,7 +1907,6 @@ class DepthwiseFFTConv1D(DepthwiseFFTConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1471,6 +1914,34 @@ class DepthwiseFFTConv1D(DepthwiseFFTConvND): >>> l1(jnp.ones((3, 32))).shape (6, 16) + Note: + :class:`.DepthwiseFFTConv1D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseFFTConv1D = sk.nn.DepthwiseFFTConv1D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseFFTConv1D = sk.nn.DepthwiseFFTConv1D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1517,7 +1988,6 @@ class DepthwiseFFTConv2D(DepthwiseFFTConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. 
- Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1525,6 +1995,34 @@ class DepthwiseFFTConv2D(DepthwiseFFTConvND): >>> l1(jnp.ones((3, 32, 32))).shape (6, 16, 16) + Note: + :class:`.DepthwiseFFTConv2D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseFFTConv2D = sk.nn.DepthwiseFFTConv2D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseFFTConv2D = sk.nn.DepthwiseFFTConv2D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1571,7 +2069,6 @@ class DepthwiseFFTConv3D(DepthwiseFFTConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1579,6 +2076,33 @@ class DepthwiseFFTConv3D(DepthwiseFFTConvND): >>> l1(jnp.ones((3, 32, 32, 32))).shape (6, 16, 16, 16) + Note: + :class:`.DepthwiseFFTConv3D` supports lazy initialization, meaning that the + weights and biases are not initialized until the first call to the layer. + This is useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.DepthwiseFFTConv3D = sk.nn.DepthwiseFFTConv3D(None, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.DepthwiseFFTConv3D = sk.nn.DepthwiseFFTConv3D(None, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features) + None None + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features) + 5 5 References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1592,7 +2116,7 @@ def spatial_ndim(self) -> int: class SeparableConvND(sk.TreeClass): def __init__( self, - in_features: int, + in_features: int | None, out_features: int, kernel_size: KernelSizeType, *, @@ -1604,6 +2128,19 @@ def __init__( pointwise_bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): + if in_features is None: + self.in_features = in_features + self.out_features = out_features + self.kernel_size = kernel_size + self.depth_multiplier = depth_multiplier + self.strides = strides + self.padding = padding + self.depthwise_weight_init = depthwise_weight_init + self.pointwise_weight_init = pointwise_weight_init + self.key = key + # going to lazy init + return + self.depthwise_conv = self._depthwise_convolution_layer( in_features=in_features, depth_multiplier=depth_multiplier, @@ -1626,6 +2163,7 @@ def __init__( key=key, ) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates) def __call__(self, x: jax.Array, **k) -> jax.Array: x = self.depthwise_conv(x) x = self.pointwise_conv(x) @@ -1694,7 +2232,6 @@ class SeparableConv1D(SeparableConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1702,6 +2239,30 @@ class SeparableConv1D(SeparableConvND): >>> l1(jnp.ones((3, 32))).shape (3, 32) + Note: + :class:`.SeparableConv1D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableConv1D = sk.nn.SeparableConv1D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableConv1D = sk.nn.SeparableConv1D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1774,6 +2335,30 @@ class SeparableConv2D(SeparableConvND): >>> l1(jnp.ones((3, 32, 32))).shape (3, 32, 32) + Note: + :class:`.SeparableConv2D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. 
+ + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableConv2D = sk.nn.SeparableConv2D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableConv2D = sk.nn.SeparableConv2D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1846,6 +2431,30 @@ class SeparableConv3D(SeparableConvND): >>> l1(jnp.ones((3, 32, 32, 32))).shape (3, 32, 32, 32) + Note: + :class:`.SeparableConv3D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableConv3D = sk.nn.SeparableConv3D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableConv3D = sk.nn.SeparableConv3D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1911,7 +2520,6 @@ class SeparableFFTConv1D(SeparableConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk @@ -1919,6 +2527,30 @@ class SeparableFFTConv1D(SeparableConvND): >>> l1(jnp.ones((3, 32))).shape (3, 32) + Note: + :class:`.SeparableFFTConv1D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableFFTConv1D = sk.nn.SeparableFFTConv1D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableFFTConv1D = sk.nn.SeparableFFTConv1D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2))) + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -1991,6 +2623,30 @@ class SeparableFFTConv2D(SeparableConvND): >>> l1(jnp.ones((3, 32, 32))).shape (3, 32, 32) + Note: + :class:`.SeparableFFTConv2D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableFFTConv2D = sk.nn.SeparableFFTConv2D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableFFTConv2D = sk.nn.SeparableFFTConv2D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2))) + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -2063,6 +2719,30 @@ class SeparableFFTConv3D(SeparableConvND): >>> l1(jnp.ones((3, 32, 32, 32))).shape (3, 32, 32, 32) + Note: + :class:`.SeparableFFTConv3D` supports lazy initialization, meaning that the weights and + biases are not initialized until the first call to the layer. This is + useful when the input shape is not known at initialization time. + + To use lazy initialization, pass ``None`` as the ``in_features`` argument + and use the ``.at["calling_method_name"]`` attribute to call the layer + with an input of known shape. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> import jax.random as jr + >>> import jax + >>> @sk.autoinit + ... class CNN(sk.TreeClass): + ... l1: sk.nn.SeparableFFTConv3D = sk.nn.SeparableFFTConv3D(None, 12, 3, key=jr.PRNGKey(1)) + ... l2: sk.nn.SeparableFFTConv3D = sk.nn.SeparableFFTConv3D(None, 1, 3, key=jr.PRNGKey(2)) + ... def __call__(self, x: jax.Array) -> jax.Array: + ... 
return self.l2(jax.nn.relu(self.l1(x))) + >>> # lazy initialization + >>> lazy_cnn = CNN() + >>> # materialize the layer + >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2))) + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py @@ -2081,10 +2761,13 @@ def _depthwise_convolution_layer(self): return DepthwiseFFTConv3D -class BaseConvNDLocal(sk.TreeClass): +convlocal_updates = {**conv_updates, "in_size": infer_in_size} + + +class ConvNDLocal(sk.TreeClass): def __init__( self, - in_features: int, + in_features: int | None, out_features: int, kernel_size: KernelSizeType, *, @@ -2096,20 +2779,34 @@ def __init__( bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - # checked by callbacks - self.in_features = positive_int_cb(in_features) + self.in_features = positive_int_or_none_cb(in_features) self.out_features = positive_int_cb(out_features) - self.kernel_size = canonicalize( - kernel_size, self.spatial_ndim, name="kernel_size" - ) - self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size") - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = delayed_canonicalize_padding( - self.in_size, padding, self.kernel_size, self.strides + self.kernel_size = canonicalize(kernel_size, self.spatial_ndim, "kernel_size") + self.in_size = ( + canonicalize(in_size, self.spatial_ndim, name="in_size") + if in_size is not None + else None ) - self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation") - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) + self.strides = canonicalize(strides, self.spatial_ndim, "strides") + + if in_size is None: + self.padding = padding + + else: + self.padding = delayed_canonicalize_padding( + self.in_size, + padding, + self.kernel_size, + self.strides, + ) + + self.dilation = canonicalize(dilation, self.spatial_ndim, "dilation") + self.weight_init = weight_init + self.bias_init = bias_init + + if self.in_features is None or self.in_size is None: + self.key = key + return out_size = calculate_convolution_output_shape( shape=self.in_size, @@ -2125,15 +2822,25 @@ def __init__( *out_size, ) - self.weight = weight_init(key, weight_shape) + self.weight = resolve_init_func(self.weight_init)(key, weight_shape) bias_shape = (self.out_features, *out_size) - self.bias = bias_init(key, bias_shape) + self.bias = resolve_init_func(self.bias_init)(key, bias_shape) + @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=convlocal_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self._convolution_operation(jnp.expand_dims(x, 0)) + x = jax.lax.conv_general_dilated_local( + lhs=jnp.expand_dims(x, 0), + rhs=self.weight, + window_strides=self.strides, + padding=self.padding, + filter_shape=self.kernel_size, + rhs_dilation=self.dilation, + dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + ) + if self.bias is None: return jnp.squeeze(x, 0) return jnp.squeeze((x + self.bias), 0) @@ -2144,23 +2851,6 @@ def spatial_ndim(self) -> int: """Number of spatial dimensions of the convolutional layer.""" ... - @abc.abstractmethod - def _convolution_operation(self, x: jax.Array) -> jax.Array: - ... 
-
-
-class ConvNDLocal(BaseConvNDLocal):
-    def _convolution_operation(self, x: jax.Array) -> jax.Array:
-        return jax.lax.conv_general_dilated_local(
-            lhs=x,
-            rhs=self.weight,
-            window_strides=self.strides,
-            padding=self.padding,
-            filter_shape=self.kernel_size,
-            rhs_dilation=self.dilation,
-            dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim),
-        )
-
 
 class Conv1DLocal(ConvNDLocal):
     """1D Local convolutional layer.
@@ -2169,7 +2859,6 @@ class Conv1DLocal(ConvNDLocal):
     kernel is applied to a local region of the input. The kernel weights
     are *not* shared across the spatial dimensions of the input.
 
-
     Args:
         in_features: Number of input feature maps, for 1D convolution this is the
             length of the input, for 2D convolution this is the number of input
@@ -2213,6 +2902,35 @@ class Conv1DLocal(ConvNDLocal):
     >>> l1(jnp.ones((3, 32))).shape
     (3, 32)
 
+    Note:
+        :class:`.Conv1DLocal` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> @sk.autoinit
+        ... class CNN(sk.TreeClass):
+        ...     l1: sk.nn.Conv1DLocal = sk.nn.Conv1DLocal(None, 12, 3, in_size=None, key=k1)
+        ...     l2: sk.nn.Conv1DLocal = sk.nn.Conv1DLocal(None, 5, 3, in_size=None, key=k2)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> # lazy initialization
+        >>> lazy_cnn = CNN()
+        >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features)
+        None None
+        >>> # materialize the layer
+        >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2)))
+        >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features)
+        5 12
+
     Reference:
         - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
         - https://github.com/google/flax/blob/main/flax/linen/linear.py
@@ -2230,7 +2948,6 @@ class Conv2DLocal(ConvNDLocal):
     kernel is applied to a local region of the input. This means that the kernel
     weights are *not* shared across the spatial dimensions of the input.
 
-
     Args:
         in_features: Number of input feature maps, for 1D convolution this is the
             length of the input, for 2D convolution this is the number of input
@@ -2274,6 +2991,35 @@ class Conv2DLocal(ConvNDLocal):
     >>> l1(jnp.ones((3, 32, 32))).shape
     (3, 32, 32)
 
+    Note:
+        :class:`.Conv2DLocal` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> @sk.autoinit
+        ... class CNN(sk.TreeClass):
+        ...     l1: sk.nn.Conv2DLocal = sk.nn.Conv2DLocal(None, 12, 3, in_size=None, key=k1)
+        ...     l2: sk.nn.Conv2DLocal = sk.nn.Conv2DLocal(None, 5, 3, in_size=None, key=k2)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> # lazy initialization
+        >>> lazy_cnn = CNN()
+        >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features)
+        None None
+        >>> # materialize the layer
+        >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2)))
+        >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features)
+        5 12
+
     Reference:
         - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
         - https://github.com/google/flax/blob/main/flax/linen/linear.py
@@ -2291,7 +3037,6 @@ class Conv3DLocal(ConvNDLocal):
     kernel is applied to a local region of the input. This means that the kernel
     weights are *not* shared across the spatial dimensions of the input.
 
-
     Args:
         in_features: Number of input feature maps, for 1D convolution this is the
             length of the input, for 2D convolution this is the number of input
@@ -2335,6 +3080,35 @@ class Conv3DLocal(ConvNDLocal):
     >>> l1(jnp.ones((3, 32, 32, 32))).shape
     (3, 32, 32, 32)
 
+    Note:
+        :class:`.Conv3DLocal` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> @sk.autoinit
+        ... class CNN(sk.TreeClass):
+        ...     l1: sk.nn.Conv3DLocal = sk.nn.Conv3DLocal(None, 12, 3, in_size=None, key=k1)
+        ...     l2: sk.nn.Conv3DLocal = sk.nn.Conv3DLocal(None, 5, 3, in_size=None, key=k2)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> # lazy initialization
+        >>> lazy_cnn = CNN()
+        >>> print(lazy_cnn.l1.in_features, lazy_cnn.l2.in_features)
+        None None
+        >>> # materialize the layer
+        >>> _, materialized_cnn = lazy_cnn.at["__call__"](jnp.ones((5, 2, 2, 2)))
+        >>> print(materialized_cnn.l1.in_features, materialized_cnn.l2.in_features)
+        5 12
+
     Reference:
         - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html
         - https://github.com/google/flax/blob/main/flax/linen/linear.py
diff --git a/serket/nn/utils.py b/serket/nn/utils.py
index 59e2980..a305145 100644
--- a/serket/nn/utils.py
+++ b/serket/nn/utils.py
@@ -164,7 +164,7 @@ def delayed_canonicalize_padding(
     )
 
 
-def canonicalize(value, ndim, *, name: str | None = None):
+def canonicalize(value, ndim, name: str | None = None):
     if isinstance(value, int):
         return (value,) * ndim
     if isinstance(value, jax.Array):
@@ -247,6 +247,17 @@ def positive_int_cb(value):
     return value
 
 
+def positive_int_or_none_cb(value):
+    """Return value if it is a positive integer or None, otherwise raise an error."""
+    if value is None:
+        return value
+    if not isinstance(value, int):
+        raise ValueError(f"value must be an integer, got {type(value).__name__}")
+    if value <= 0:
+        raise ValueError(f"{value=} must be positive.")
+    return value
+
+
 def recursive_getattr(obj, attr: Sequence[str]):
     return (
         getattr(obj, attr[0])
@@ -264,7 +275,7 @@ def check_spatial_in_shape(x, spatial_ndim: int) -> None:
     spatial = {", ".join(("rows", "cols", "depths")[:spatial_ndim])}
     raise ValueError(
         f"Input should satisfy:\n"
-        f"- {spatial_ndim+1=} dimension, got {x.ndim=}.\n"
+        f"- {(spatial_ndim + 1)=} dimension, got {x.ndim=}.\n"
         f"- shape of (in_features, {spatial}), got {x.shape=}.\n"
         + (
             # maybe the user apply the layer on a batched input
@@ -293,6 +304,9 @@ def validate_axis_shape(
     attribute_list = attribute_name.split(".")
 
     def check_axis_shape(x, in_features: int, axis: int) -> None:
+        if in_features is None:
+            # lazy initialization
+            return x
         if x.shape[axis] != in_features:
             raise ValueError(f"Specified {in_features=}, got {x.shape[axis]=}.")
         return x
@@ -303,3 +317,30 @@ def wrapper(self, array, *a, **k):
         return func(self, array, *a, **k)
 
     return wrapper
+
+
+def maybe_lazy_call(
+    func: Callable[P, T],
+    is_lazy: Callable[..., bool],
+    updates: dict[str, Callable[..., Any]],
+) -> Callable[P, T]:
+    """Reinitialize the instance if it is lazy."""
+
+    @ft.wraps(func)
+    def inner(instance, *a, **k):
+        if not is_lazy(instance, *a, **k):
+            return func(instance, *a, **k)
+
+        kwargs = dict(vars(instance))
+        for key, update in updates.items():
+            kwargs[key] = update(instance, *a, **k)
+
+        # clear the instance information
+        for key in kwargs:
+            delattr(instance, key)
+        # re-initialize the instance
+        getattr(type(instance), "__init__")(instance, **kwargs)
+        # call the decorated function
+        return func(instance, *a, **k)
+
+    return inner
diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 45ec023..3ef3351 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -1001,75 +1001,20 @@ def test_groups_error():
         Conv3DTranspose(1, 1, 3, groups=0)
 
 
-# def test_lazy_conv():
-#     layer = Conv1D(None, 1, 3)
-#     assert layer(jnp.ones([10, 3])).shape == (1, 3)
-
-#     layer = Conv2D(None, 1, 3)
-#     assert layer(jnp.ones([10, 3, 3])).shape == (1, 3, 3)
-
-#     layer = Conv3D(None, 1, 3)
-#     assert layer(jnp.ones([10, 3, 3, 3])).shape == (1, 3, 3, 3)
-
-#     layer = Conv1DTranspose(None, 1, 3)
-#     assert layer(jnp.ones([10, 3])).shape == (1, 3)
-
-#     layer = Conv2DTranspose(None, 1, 3)
-#     assert layer(jnp.ones([10, 3, 3])).shape == (1, 3, 3)
-
-#     layer = Conv3DTranspose(None, 1, 3)
-#     assert layer(jnp.ones([10, 3, 3, 3])).shape == (1, 3, 3, 3)
-
-#     layer = DepthwiseConv1D(None, 3)
-#     assert layer(jnp.ones([10, 3])).shape == (10, 3)
-
-#     layer = DepthwiseConv2D(None, 3)
-#     assert layer(jnp.ones([10, 3, 3])).shape == (10, 3, 3)
-
-#     layer = Conv1DLocal(None, 1, 3, in_size=(3,))
-#     assert layer(jnp.ones([10, 3])).shape == (1, 3)
-
-#     layer = Conv2DLocal(None, 1, 3, in_size=(3, 3))
-#     assert layer(jnp.ones([10, 3, 3])).shape == (1, 3, 3)
-
-#     layer = SeparableConv1D(None, 1, 3)
-#     assert layer(jnp.ones([10, 3])).shape == (1, 3)
-
-#     layer = SeparableConv2D(None, 1, 3)
-#     assert layer(jnp.ones([10, 3, 3])).shape == (1, 3, 3)
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv1D(None, 1, 3))(jnp.ones([10, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv2D(None, 1, 3))(jnp.ones([10, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv3D(None, 1, 3))(jnp.ones([10, 3, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv1DTranspose(None, 1, 3))(jnp.ones([10, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv2DTranspose(None, 1, 3))(jnp.ones([10, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv3DTranspose(None, 1, 3))(jnp.ones([10, 3, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(DepthwiseConv1D(None, 3))(jnp.ones([10, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(DepthwiseConv2D(None, 3))(jnp.ones([10, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv1DLocal(None, 1, 3, in_size=(3,)))(jnp.ones([10, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(Conv2DLocal(None, 1, 3, in_size=(3, 3)))(jnp.ones([10, 3, 3]))
-
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(SeparableConv1D(None, 1, 3))(jnp.ones([10, 3]))
+@pytest.mark.parametrize(
+    "layer,array,expected_shape",
+    [
+        [Conv1D, jnp.ones([10, 3]), (1, 3)],
+        [Conv2D, jnp.ones([10, 3, 3]), (1, 3, 3)],
+        [Conv3D, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)],
+        [Conv1DTranspose, jnp.ones([10, 3]), (1, 3)],
+        [Conv2DTranspose, jnp.ones([10, 3, 3]), (1, 3, 3)],
+        [Conv3DTranspose, jnp.ones([10, 3, 3, 3]), (1, 3, 3, 3)],
+    ],
+)
+def test_lazy_conv(layer, array, expected_shape):
+    lazy_layer = layer(None, 1, 3)
+    value, materialized_layer = lazy_layer.at["__call__"](array)
 
-#     with pytest.raises(ConcretizationTypeError):
-#         jax.jit(SeparableConv2D(None, 1, 3))(jnp.ones([10, 3, 3]))
+    assert value.shape == expected_shape
+    assert materialized_layer.in_features == 10
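
Usage sketch (not part of the patch): the doctest below follows the lazy-initialization flow exercised by ``test_lazy_conv`` and the docstring notes above; passing ``in_features=None`` defers weight creation until the first call, and ``.at["__call__"]`` returns the output together with a materialized layer. The name ``lazy_layer`` and the shapes are illustrative only.

>>> import jax.numpy as jnp
>>> import serket as sk
>>> lazy_layer = sk.nn.Conv1D(None, 2, 3)  # in_features inferred from the first input
>>> value, materialized_layer = lazy_layer.at["__call__"](jnp.ones([10, 3]))
>>> value.shape
(2, 3)
>>> materialized_layer.in_features
10
>>> # the materialized layer behaves like an eagerly constructed one
>>> materialized_layer(jnp.ones([10, 3])).shape
(2, 3)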