diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index b425244..69e7e18 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Convolutional layers.""" + from __future__ import annotations import abc @@ -42,11 +44,115 @@ @ft.lru_cache(maxsize=None) -def generate_conv_dim_numbers(spatial_ndim): +def generate_conv_dim_numbers(spatial_ndim) -> ConvDimensionNumbers: return ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3)) -class ConvND(sk.TreeClass): +@ft.partial(jax.jit, inline=True) +def _ungrouped_matmul(x, y) -> jax.Array: + alpha = "".join(map(str, range(max(x.ndim, y.ndim)))) + lhs = "a" + alpha[: x.ndim - 1] + rhs = "b" + alpha[: y.ndim - 1] + out = "ab" + lhs[2:] + return jnp.einsum(f"{lhs},{rhs}->{out}", x, y) + + +@ft.partial(jax.jit, static_argnums=(2,), inline=True) +def _grouped_matmul(x, y, groups) -> jax.Array: + b, c, *s = x.shape # batch, channels, spatial + o, i, *k = y.shape # out_channels, in_channels, kernel + x = x.reshape(groups, b, c // groups, *s) # groups, batch, channels, spatial + y = y.reshape(groups, o // groups, *(i, *k)) + z = jax.vmap(_ungrouped_matmul, in_axes=(0, 0), out_axes=1)(x, y) + return z.reshape(z.shape[0], z.shape[1] * z.shape[2], *z.shape[3:]) + + +def grouped_matmul(x, y, groups: int = 1): + return _ungrouped_matmul(x, y) if groups == 1 else _grouped_matmul(x, y, groups) + + +@ft.partial(jax.jit, static_argnums=(1, 2), inline=True) +def _intersperse_along_axis(x: jax.Array, dilation: int, axis: int) -> jax.Array: + shape = list(x.shape) + shape[axis] = (dilation) * shape[axis] - (dilation - 1) + z = jnp.zeros(shape) + z = z.at[(slice(None),) * axis + (slice(None, None, (dilation)),)].set(x) + return z + + +@ft.partial(jax.jit, static_argnums=(1, 2), inline=True) +def _general_intersperse( + x: jax.Array, + dilation: tuple[int, ...], + axis: tuple[int, ...], +) -> jax.Array: + for di, ai in zip(dilation, axis): + x = _intersperse_along_axis(x, di, ai) if di > 1 else x + return x + + +@ft.partial(jax.jit, static_argnums=(1,), inline=True) +def _general_pad(x: jax.Array, pad_width: tuple[tuple[int, int], ...]) -> jax.Array: + """Pad the input with `pad_width` on each side. Negative value will lead to cropping. + Example: + >>> print(_general_pad(jnp.ones([3,3]),((0,0),(-1,1)))) # DOCTEST: +NORMALIZE_WHITESPACE + [[1. 1. 0.] + [1. 1. 0.] + [1. 1. 0.]] + """ + + for axis, (lhs, rhs) in enumerate(pad_width := list(pad_width)): + if lhs < 0 and rhs < 0: + x = jax.lax.dynamic_slice_in_dim(x, -lhs, x.shape[axis] + lhs + rhs, axis) + elif lhs < 0: + x = jax.lax.dynamic_slice_in_dim(x, -lhs, x.shape[axis] + lhs, axis) + elif rhs < 0: + x = jax.lax.dynamic_slice_in_dim(x, 0, x.shape[axis] + rhs, axis) + + return jnp.pad(x, [(max(lhs, 0), max(rhs, 0)) for (lhs, rhs) in (pad_width)]) + + +@ft.partial(jax.jit, static_argnums=(2, 3, 4, 5), inline=True) +def fft_conv_general_dilated( + lhs: jax.Array, + rhs: jax.Array, + strides: tuple[int, ...], + padding: tuple[tuple[int, int], ...], + groups: int, + dilation: tuple[int, ...], +) -> jax.Array: + spatial_ndim = lhs.ndim - 2 # spatial dimensions + rhs = _general_intersperse(rhs, dilation=dilation, axis=range(2, 2 + spatial_ndim)) + lhs = _general_pad(lhs, ((0, 0), (0, 0), *padding)) + + x_shape, w_shape = lhs.shape, rhs.shape + + if lhs.shape[-1] % 2 != 0: + lhs = jnp.pad(lhs, tuple([(0, 0)] * (lhs.ndim - 1) + [(0, 1)])) + + kernel_padding = ( + (0, lhs.shape[i] - rhs.shape[i]) for i in range(2, spatial_ndim + 2) + ) + rhs = _general_pad(rhs, ((0, 0), (0, 0), *kernel_padding)) + + # for real-valued input + x_fft = jnp.fft.rfftn(lhs, axes=range(2, spatial_ndim + 2)) + w_fft = jnp.conjugate(jnp.fft.rfftn(rhs, axes=range(2, spatial_ndim + 2))) + z_fft = grouped_matmul(x_fft, w_fft, groups) + + z = jnp.fft.irfftn(z_fft, axes=range(2, spatial_ndim + 2)) + + start = (0,) * (spatial_ndim + 2) + end = [z.shape[0], z.shape[1]] + end += [max((x_shape[i] - w_shape[i] + 1), 0) for i in range(2, spatial_ndim + 2)] + + if all(s == 1 for s in strides): + return jax.lax.dynamic_slice(z, start, end) + + return jax.lax.slice(z, start, end, (1, 1, *strides)) + + +class BaseConvND(sk.TreeClass): def __init__( self, in_features: int, @@ -55,8 +161,7 @@ def __init__( *, strides: StridesType = 1, padding: PaddingType = "same", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, + dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", groups: int = 1, @@ -65,22 +170,11 @@ def __init__( self.in_features = positive_int_cb(in_features) self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( - kernel_size, - self.spatial_ndim, - name="kernel_size", + kernel_size, self.spatial_ndim, name="kernel_size" ) self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = padding # delayed canonicalization - self.input_dilation = canonicalize( - input_dilation, - self.spatial_ndim, - name="input_dilation", - ) - self.kernel_dilation = canonicalize( - kernel_dilation, - self.spatial_ndim, - name="kernel_dilation", - ) + self.padding = padding + self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation") weight_init = resolve_init_func(weight_init) bias_init = resolve_init_func(bias_init) @@ -99,34 +193,42 @@ def __init__( @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: + x = self.convolution_operation(jnp.expand_dims(x, 0)) + if self.bias is None: + return jnp.squeeze(x, 0) + return jnp.squeeze((x + self.bias), 0) + + @property + @abc.abstractmethod + def spatial_ndim(self) -> int: + """Number of spatial dimensions of the convolutional layer.""" + ... + + @abc.abstractmethod + def convolution_operation(self, x: jax.Array) -> jax.Array: + """Convolution operation.""" + ... + + +class ConvND(BaseConvND): + def convolution_operation(self, x: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], + in_dim=x.shape[2:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - x = jax.lax.conv_general_dilated( - lhs=jnp.expand_dims(x, 0), + return jax.lax.conv_general_dilated( + lhs=x, rhs=self.weight, window_strides=self.strides, padding=padding, - lhs_dilation=self.input_dilation, - rhs_dilation=self.kernel_dilation, + rhs_dilation=self.dilation, dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), feature_group_count=self.groups, ) - if self.bias is None: - return jnp.squeeze(x, 0) - return jnp.squeeze((x + self.bias), 0) - - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolutional layer.""" - ... - class Conv1D(ConvND): """1D Convolutional layer. @@ -159,12 +261,7 @@ class Conv1D(ConvND): as the input. - ``valid``/``VALID`` for no padding. - input_dilation: dilation of the input. accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - kernel_dilation: Dilation of the convolutional kernel accepts: + dilation: Dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. @@ -230,12 +327,7 @@ class Conv2D(ConvND): as the input. - ``valid``/``VALID`` for no padding. - input_dilation: Dilation of the input. accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - kernel_dilation: dilation of the convolutional kernel accepts: + dilation: dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. @@ -301,12 +393,7 @@ class Conv3D(ConvND): as the input. - ``valid``/``VALID`` for no padding. - input_dilation: Dilation of the input. accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - kernel_dilation: dilation of the convolutional kernel accepts: + dilation: dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. @@ -341,91 +428,27 @@ def spatial_ndim(self) -> int: return 3 -class ConvNDTranspose(sk.TreeClass): - def __init__( - self, - in_features: int, - out_features: int, - kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "same", - output_padding: int = 0, - kernel_dilation: DilationType = 1, - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - groups: int = 1, - key: jr.KeyArray = jr.PRNGKey(0), - ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) - self.kernel_size = canonicalize( - kernel_size, self.spatial_ndim, name="kernel_size" - ) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = padding # delayed canonicalization - self.output_padding = canonicalize( - output_padding, - self.spatial_ndim, - name="output_padding", - ) - self.kernel_dilation = canonicalize( - kernel_dilation, - self.spatial_ndim, - name="kernel_dilation", - ) - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) - self.groups = positive_int_cb(groups) - - if self.out_features % self.groups != 0: - raise ValueError(f"{(self.out_features % self.groups ==0)=}") - - weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = weight_init(key, weight_shape) - - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: +class FFTConvND(BaseConvND): + def convolution_operation(self, x: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], + in_dim=x.shape[2:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - transposed_padding = calculate_transpose_padding( - padding=padding, - extra_padding=self.output_padding, - kernel_size=self.kernel_size, - input_dilation=self.kernel_dilation, - ) - - y = jax.lax.conv_transpose( - lhs=jnp.expand_dims(x, 0), + return fft_conv_general_dilated( + lhs=x, rhs=self.weight, strides=self.strides, - padding=transposed_padding, - rhs_dilation=self.kernel_dilation, - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + padding=padding, + groups=self.groups, + dilation=self.dilation, ) - if self.bias is None: - return jnp.squeeze(y, 0) - return jnp.squeeze(y + self.bias, 0) - - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolutional layer.""" - ... - -class Conv1DTranspose(ConvNDTranspose): - """1D Convolution transpose layer. +class FFTConv1D(FFTConvND): + """1D Convolutional layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -455,18 +478,14 @@ class Conv1DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: dilation of the convolutional kernel accepts: + dilation: Dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. - weight_init: Function to use for initializing the weights. defaults + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. @@ -475,7 +494,7 @@ class Conv1DTranspose(ConvNDTranspose): >>> import jax.numpy as jnp >>> import serket as sk >>> import jax - >>> layer = sk.nn.Conv1DTranspose(1, 2, 3) + >>> layer = sk.nn.FFTConv1D(in_features=1, out_features=2, kernel_size=3) >>> # single sample >>> x = jnp.ones((1, 5)) >>> print(layer(x).shape) @@ -485,8 +504,8 @@ class Conv1DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5) - Reference: - - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + References: + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @property @@ -494,8 +513,8 @@ def spatial_ndim(self) -> int: return 1 -class Conv2DTranspose(ConvNDTranspose): - """2D Convolution transpose layer. +class FFTConv2D(FFTConvND): + """2D FFT Convolutional layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -525,18 +544,14 @@ class Conv2DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: dilation of the convolutional kernel accepts: + dilation: Dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. - weight_init: Function to use for initializing the weights. defaults + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. @@ -545,7 +560,7 @@ class Conv2DTranspose(ConvNDTranspose): >>> import jax.numpy as jnp >>> import serket as sk >>> import jax - >>> layer = sk.nn.Conv2DTranspose(1, 2, 3) + >>> layer = sk.nn.FFTConv2D(in_features=1, out_features=2, kernel_size=3) >>> # single sample >>> x = jnp.ones((1, 5, 5)) >>> print(layer(x).shape) @@ -555,7 +570,7 @@ class Conv2DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5) - Reference: + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -564,8 +579,8 @@ def spatial_ndim(self) -> int: return 2 -class Conv3DTranspose(ConvNDTranspose): - """3D Convolution transpose layer. +class FFTConv3D(FFTConvND): + """3D FFT Convolutional layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -595,18 +610,14 @@ class Conv3DTranspose(ConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: dilation of the convolutional kernel accepts: + dilation: Dilation of the convolutional kernel accepts: - single integer for same dilation in all dimensions. - sequence of integers for different dilation in each dimension. - weight_init: Function to use for initializing the weights. defaults + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. @@ -615,7 +626,7 @@ class Conv3DTranspose(ConvNDTranspose): >>> import jax.numpy as jnp >>> import serket as sk >>> import jax - >>> layer = sk.nn.Conv3DTranspose(1, 2, 3) + >>> layer = sk.nn.FFTConv3D(in_features=1, out_features=2, kernel_size=3) >>> # single sample >>> x = jnp.ones((1, 5, 5, 5)) >>> print(layer(x).shape) @@ -625,7 +636,7 @@ class Conv3DTranspose(ConvNDTranspose): >>> print(jax.vmap(layer)(x).shape) (2, 2, 5, 5, 5) - Reference: + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html """ @@ -634,88 +645,107 @@ def spatial_ndim(self) -> int: return 3 -class DepthwiseConvND(sk.TreeClass): +class BaseConvNDTranspose(sk.TreeClass): def __init__( self, in_features: int, + out_features: int, kernel_size: KernelSizeType, *, - depth_multiplier: int = 1, - strides: int = 1, + strides: StridesType = 1, padding: PaddingType = "same", + output_padding: int = 0, + dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", + groups: int = 1, key: jr.KeyArray = jr.PRNGKey(0), ): self.in_features = positive_int_cb(in_features) + self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( kernel_size, self.spatial_ndim, name="kernel_size" ) - self.depth_multiplier = positive_int_cb(depth_multiplier) self.strides = canonicalize(strides, self.spatial_ndim, name="strides") self.padding = padding # delayed canonicalization - self.input_dilation = canonicalize(1, self.spatial_ndim, name="input_dilation") - self.kernel_dilation = canonicalize( - 1, + self.output_padding = canonicalize( + output_padding, self.spatial_ndim, - name="kernel_dilation", + name="output_padding", ) + self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation") weight_init = resolve_init_func(weight_init) bias_init = resolve_init_func(bias_init) + self.groups = positive_int_cb(groups) - weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW + if self.out_features % self.groups != 0: + raise ValueError(f"{(self.out_features % self.groups ==0)=}") + + weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW self.weight = weight_init(key, weight_shape) - bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) + bias_shape = (out_features, *(1,) * self.spatial_ndim) self.bias = bias_init(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: + y = self.convolution_operation(jnp.expand_dims(x, 0)) + if self.bias is None: + return jnp.squeeze(y, 0) + return jnp.squeeze(y + self.bias, 0) + + @property + @abc.abstractmethod + def spatial_ndim(self) -> int: + """Number of spatial dimensions of the convolutional layer.""" + ... + + +class ConvNDTranspose(BaseConvNDTranspose): + def convolution_operation(self, x: jax.Array) -> jax.Array: padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], + in_dim=x.shape[2:], padding=self.padding, kernel_size=self.kernel_size, strides=self.strides, ) - y = jax.lax.conv_general_dilated( - lhs=jnp.expand_dims(x, axis=0), - rhs=self.weight, - window_strides=self.strides, + transposed_padding = calculate_transpose_padding( padding=padding, - lhs_dilation=self.input_dilation, - rhs_dilation=self.kernel_dilation, - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), - feature_group_count=self.in_features, + extra_padding=self.output_padding, + kernel_size=self.kernel_size, + input_dilation=self.dilation, ) - if self.bias is None: - return jnp.squeeze(y, 0) - return jnp.squeeze((y + self.bias), 0) + # breakpoint() - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolutional layer.""" - ... + return jax.lax.conv_transpose( + lhs=x, + rhs=self.weight, + strides=self.strides, + padding=transposed_padding, + rhs_dilation=self.dilation, + dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + ) -class DepthwiseConv1D(DepthwiseConvND): - """1D Depthwise convolution layer. +class Conv1DTranspose(ConvNDTranspose): + """1D Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -731,23 +761,38 @@ class DepthwiseConv1D(DepthwiseConvND): as the input. - ``valid``/``VALID`` for no padding. + output_padding: padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseConv1D(3, 3, depth_multiplier=2, strides=2) - >>> l1(jnp.ones((3, 32))).shape - (6, 16) + >>> import jax + >>> layer = sk.nn.Conv1DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5)) + >>> print(layer(x).shape) + (2, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -755,21 +800,22 @@ def spatial_ndim(self) -> int: return 1 -class DepthwiseConv2D(DepthwiseConvND): - """2D Depthwise convolution layer. +class Conv2DTranspose(ConvNDTranspose): + """2D Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -785,23 +831,38 @@ class DepthwiseConv2D(DepthwiseConvND): as the input. - ``valid``/``VALID`` for no padding. + output_padding: padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseConv2D(3, 3, depth_multiplier=2, strides=2) - >>> l1(jnp.ones((3, 32, 32))).shape - (6, 16, 16) + >>> import jax + >>> layer = sk.nn.Conv2DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -809,28 +870,28 @@ def spatial_ndim(self) -> int: return 2 -class DepthwiseConv3D(DepthwiseConvND): - """3D Depthwise convolution layer. +class Conv3DTranspose(ConvNDTranspose): + """3D Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. - strides: stride of the convolution. accepts: + strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. - sequence of integers for different strides in each dimension. - padding: adding of the input before convolution. accepts: + padding: Padding of the input before convolution. accepts: - single integer for same padding in all dimensions. - tuple of integers for different padding in each dimension. @@ -840,23 +901,38 @@ class DepthwiseConv3D(DepthwiseConvND): as the input. - ``valid``/``VALID`` for no padding. + output_padding: padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseConv3D(3, 3, depth_multiplier=2, strides=2) - >>> l1(jnp.ones((3, 32, 32, 32))).shape - (6, 16, 16, 16) + >>> import jax + >>> layer = sk.nn.Conv3DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5, 5) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -864,73 +940,34 @@ def spatial_ndim(self) -> int: return 3 -class SeparableConvND(sk.TreeClass): - def __init__( - self, - in_features: int, - out_features: int, - kernel_size: KernelSizeType, - *, - depth_multiplier: int = 1, - strides: StridesType = 1, - padding: PaddingType = "same", - depthwise_weight_init: InitType = "glorot_uniform", - pointwise_weight_init: InitType = "glorot_uniform", - pointwise_bias_init: InitType = "zeros", - key: jr.KeyArray = jr.PRNGKey(0), - ): - self.depthwise_conv = self.depthwise_convolution_layer( - in_features=in_features, - depth_multiplier=depth_multiplier, - kernel_size=kernel_size, - strides=strides, - padding=padding, - weight_init=depthwise_weight_init, - bias_init=None, # no bias for lhs - key=key, +class FFTConvNDTranspose(BaseConvNDTranspose): + def convolution_operation(self, x: jax.Array) -> jax.Array: + padding = delayed_canonicalize_padding( + in_dim=x.shape[2:], + padding=self.padding, + kernel_size=self.kernel_size, + strides=self.strides, ) - self.pointwise_conv = self.pointwise_convolution_layer( - in_features=in_features * depth_multiplier, - out_features=out_features, - kernel_size=1, - strides=strides, + transposed_padding = calculate_transpose_padding( padding=padding, - weight_init=pointwise_weight_init, - bias_init=pointwise_bias_init, - key=key, + extra_padding=self.output_padding, + kernel_size=self.kernel_size, + input_dilation=self.dilation, ) - def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self.depthwise_conv(x) - x = self.pointwise_conv(x) - return x - - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - ... - - @property - @abc.abstractmethod - def pointwise_convolution_layer(self): - ... - - @property - @abc.abstractmethod - def depthwise_convolution_layer(self): - ... - + return fft_conv_general_dilated( + lhs=x, + rhs=self.weight, + strides=self.strides, + padding=transposed_padding, + dilation=self.dilation, + groups=1, + ) -class SeparableConv1D(SeparableConvND): - """1D Separable convolution layer. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. +class FFTConv1DTranspose(FFTConvNDTranspose): + """1D FFT Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -945,9 +982,6 @@ class SeparableConv1D(SeparableConvND): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -963,47 +997,47 @@ class SeparableConv1D(SeparableConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: Function to use for initializing the weights. defaults + output_padding: Padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: Dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableConv1D(3, 3, 3, depth_multiplier=2) - >>> l1(jnp.ones((3, 32))).shape - (3, 32) + >>> import jax + >>> layer = sk.nn.FFTConv1DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5)) + >>> print(layer(x).shape) + (2, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5) - Reference: + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 1 - @property - def pointwise_convolution_layer(self): - return Conv1D - - @property - def depthwise_convolution_layer(self): - return DepthwiseConv1D - - -class SeparableConv2D(SeparableConvND): - """2D Separable convolution layer. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. +class FFTConv2DTranspose(FFTConvNDTranspose): + """2D FFT Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -1018,9 +1052,6 @@ class SeparableConv2D(SeparableConvND): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1036,46 +1067,47 @@ class SeparableConv2D(SeparableConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: Function to use for initializing the weights. defaults + output_padding: Padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: Dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableConv2D(3, 3, 3, depth_multiplier=2) - >>> l1(jnp.ones((3, 32, 32))).shape - (3, 32, 32) + >>> import jax + >>> layer = sk.nn.FFTConv2DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5) - Reference: + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 2 - @property - def pointwise_convolution_layer(self): - return Conv2D - - @property - def depthwise_convolution_layer(self): - return DepthwiseConv2D - - -class SeparableConv3D(SeparableConvND): - """3D Separable convolution layer. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. +class FFTConv3DTranspose(FFTConvNDTranspose): + """3D FFT Convolution transpose layer. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -1090,9 +1122,6 @@ class SeparableConv3D(SeparableConvND): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1108,113 +1137,79 @@ class SeparableConv3D(SeparableConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: Function to use for initializing the weights. defaults + output_padding: Padding of the output after convolution. accepts: + + - single integer for same padding in all dimensions. + + dilation: Dilation of the convolutional kernel accepts: + + - single integer for same dilation in all dimensions. + - sequence of integers for different dilation in each dimension. + + weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: Function to use for initializing the bias. defaults to + bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. + groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableConv3D(3, 3, 3, depth_multiplier=2) - >>> l1(jnp.ones((3, 32, 32, 32))).shape - (3, 32, 32, 32) + >>> import jax + >>> layer = sk.nn.FFTConv3DTranspose(1, 2, 3) + >>> # single sample + >>> x = jnp.ones((1, 5, 5, 5)) + >>> print(layer(x).shape) + (2, 5, 5, 5) + >>> # batch of samples + >>> x = jnp.ones((2, 1, 5, 5, 5)) + >>> print(jax.vmap(layer)(x).shape) + (2, 2, 5, 5, 5) - Reference: + References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 3 - @property - def pointwise_convolution_layer(self): - return Conv3D - - @property - def depthwise_convolution_layer(self): - return DepthwiseConv3D - -class ConvNDLocal(sk.TreeClass): +class BaseDepthwiseConvND(sk.TreeClass): def __init__( self, in_features: int, - out_features: int, kernel_size: KernelSizeType, *, - in_size: Sequence[int], - strides: StridesType = 1, + depth_multiplier: int = 1, + strides: int = 1, padding: PaddingType = "same", - input_dilation: DilationType = 1, - kernel_dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - # checked by callbacks self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) self.kernel_size = canonicalize( kernel_size, self.spatial_ndim, name="kernel_size" ) - self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size") + self.depth_multiplier = positive_int_cb(depth_multiplier) self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = delayed_canonicalize_padding( - self.in_size, - padding, - self.kernel_size, - self.strides, - ) - self.input_dilation = canonicalize( - input_dilation, - self.spatial_ndim, - name="input_dilation", - ) - self.kernel_dilation = canonicalize( - kernel_dilation, - self.spatial_ndim, - name="kernel_dilation", - ) + self.padding = padding # delayed canonicalization + self.dilation = canonicalize(1, self.spatial_ndim, name="dilation") weight_init = resolve_init_func(weight_init) bias_init = resolve_init_func(bias_init) - out_size = calculate_convolution_output_shape( - shape=self.in_size, - kernel_size=self.kernel_size, - padding=self.padding, - strides=self.strides, - ) - - # OIHW - weight_shape = ( - self.out_features, - self.in_features * ft.reduce(op.mul, self.kernel_size), - *out_size, - ) - + weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW self.weight = weight_init(key, weight_shape) - bias_shape = (self.out_features, *out_size) + bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) self.bias = bias_init(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: - y = jax.lax.conv_general_dilated_local( - lhs=jnp.expand_dims(x, 0), - rhs=self.weight, - window_strides=self.strides, - padding=self.padding, - filter_shape=self.kernel_size, - lhs_dilation=self.kernel_dilation, - rhs_dilation=self.input_dilation, # atrous dilation - dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), - ) - + y = self.convolution_operation(jnp.expand_dims(x, 0)) if self.bias is None: return jnp.squeeze(y, 0) return jnp.squeeze((y + self.bias), 0) @@ -1225,30 +1220,46 @@ def spatial_ndim(self) -> int: """Number of spatial dimensions of the convolutional layer.""" ... + @abc.abstractmethod + def convolution_operation(self, x: jax.Array) -> jax.Array: + ... -class Conv1DLocal(ConvNDLocal): - """1D Local convolutional layer. - Local convolutional layer is a convolutional layer where the convolution - kernel is applied to a local region of the input. The kernel weights are - *not* shared across the spatial dimensions of the input. +class DepthwiseConvND(BaseDepthwiseConvND): + def convolution_operation(self, x: jax.Array) -> jax.Array: + padding = delayed_canonicalize_padding( + in_dim=x.shape[2:], + padding=self.padding, + kernel_size=self.kernel_size, + strides=self.strides, + ) + + return jax.lax.conv_general_dilated( + lhs=x, + rhs=self.weight, + window_strides=self.strides, + padding=padding, + rhs_dilation=self.dilation, + dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + feature_group_count=self.in_features, + ) +class DepthwiseConv1D(DepthwiseConvND): + """1D Depthwise convolution layer. + Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - in_size: the size of the spatial dimensions of the input. e.g excluding - the first dimension. accepts a sequence of integer(s). + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1270,12 +1281,13 @@ class Conv1DLocal(ConvNDLocal): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.Conv1DLocal(3, 3, 3, in_size=(32,)) + >>> l1 = sk.nn.DepthwiseConv1D(3, 3, depth_multiplier=2, strides=2) >>> l1(jnp.ones((3, 32))).shape - (3, 32) + (6, 16) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -1287,29 +1299,21 @@ def spatial_ndim(self) -> int: return 1 -class Conv2DLocal(ConvNDLocal): - """2D Local convolutional layer. - - Local convolutional layer is a convolutional layer where the convolution - kernel is applied to a local region of the input. This means that the kernel - weights are *not* shared across the spatial dimensions of the input. - +class DepthwiseConv2D(DepthwiseConvND): + """2D Depthwise convolution layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - in_size: the size of the spatial dimensions of the input. e.g excluding - the first dimension. accepts a sequence of integer(s). + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1331,12 +1335,13 @@ class Conv2DLocal(ConvNDLocal): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.Conv2DLocal(3, 3, 3, in_size=(32, 32)) + >>> l1 = sk.nn.DepthwiseConv2D(3, 3, depth_multiplier=2, strides=2) >>> l1(jnp.ones((3, 32, 32))).shape - (3, 32, 32) + (6, 16, 16) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -1348,35 +1353,28 @@ def spatial_ndim(self) -> int: return 2 -class Conv3DLocal(ConvNDLocal): - """3D Local convolutional layer. - - Local convolutional layer is a convolutional layer where the convolution - kernel is applied to a local region of the input. This means that the kernel - weights are *not* shared across the spatial dimensions of the input. - +class DepthwiseConv3D(DepthwiseConvND): + """3D Depthwise convolution layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. + kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. - in_size: the size of the spatial dimensions of the input. e.g excluding - the first dimension. accepts a sequence of integer(s). - strides: Stride of the convolution. accepts: + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. + strides: stride of the convolution. accepts: - single integer for same stride in all dimensions. - sequence of integers for different strides in each dimension. - padding: Padding of the input before convolution. accepts: + padding: adding of the input before convolution. accepts: - single integer for same padding in all dimensions. - tuple of integers for different padding in each dimension. @@ -1392,12 +1390,13 @@ class Conv3DLocal(ConvNDLocal): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.Conv3DLocal(3, 3, 3, in_size=(32, 32, 32)) + >>> l1 = sk.nn.DepthwiseConv3D(3, 3, depth_multiplier=2, strides=2) >>> l1(jnp.ones((3, 32, 32, 32))).shape - (3, 32, 32, 32) + (6, 16, 16, 16) Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -1409,208 +1408,40 @@ def spatial_ndim(self) -> int: return 3 -@ft.partial(jax.jit, inline=True) -def _ungrouped_matmul(x, y) -> jax.Array: - alpha = "".join(map(str, range(max(x.ndim, y.ndim)))) - lhs = "a" + alpha[: x.ndim - 1] - rhs = "b" + alpha[: y.ndim - 1] - out = "ab" + lhs[2:] - return jnp.einsum(f"{lhs},{rhs}->{out}", x, y) +class DepthwiseFFTConvND(BaseDepthwiseConvND): + def convolution_operation(self, x: jax.Array) -> jax.Array: + padding = delayed_canonicalize_padding( + in_dim=x.shape[2:], + padding=self.padding, + kernel_size=self.kernel_size, + strides=self.strides, + ) + return fft_conv_general_dilated( + lhs=x, + rhs=self.weight, + strides=self.strides, + padding=padding, + dilation=self.dilation, + groups=self.in_features, + ) -@ft.partial(jax.jit, static_argnums=(2,), inline=True) -def _grouped_matmul(x, y, groups) -> jax.Array: - b, c, *s = x.shape # batch, channels, spatial - o, i, *k = y.shape # out_channels, in_channels, kernel - x = x.reshape(groups, b, c // groups, *s) # groups, batch, channels, spatial - y = y.reshape(groups, o // groups, *(i, *k)) - z = jax.vmap(_ungrouped_matmul, in_axes=(0, 0), out_axes=1)(x, y) - return z.reshape(z.shape[0], z.shape[1] * z.shape[2], *z.shape[3:]) - - -def grouped_matmul(x, y, groups: int = 1): - return _ungrouped_matmul(x, y) if groups == 1 else _grouped_matmul(x, y, groups) - - -@ft.partial(jax.jit, static_argnums=(1, 2), inline=True) -def _intersperse_along_axis(x: jax.Array, dilation: int, axis: int) -> jax.Array: - shape = list(x.shape) - shape[axis] = (dilation) * shape[axis] - (dilation - 1) - z = jnp.zeros(shape) - z = z.at[(slice(None),) * axis + (slice(None, None, (dilation)),)].set(x) - return z - - -@ft.partial(jax.jit, static_argnums=(1, 2), inline=True) -def _general_intersperse( - x: jax.Array, - dilation: tuple[int, ...], - axis: tuple[int, ...], -) -> jax.Array: - for di, ai in zip(dilation, axis): - x = _intersperse_along_axis(x, di, ai) if di > 1 else x - return x - -@ft.partial(jax.jit, static_argnums=(1,), inline=True) -def _general_pad(x: jax.Array, pad_width: tuple[tuple[int, int], ...]) -> jax.Array: - """Pad the input with `pad_width` on each side. Negative value will lead to cropping. - Example: - >>> print(_general_pad(jnp.ones([3,3]),((0,0),(-1,1)))) # DOCTEST: +NORMALIZE_WHITESPACE - [[1. 1. 0.] - [1. 1. 0.] - [1. 1. 0.]] - """ - - for axis, (lhs, rhs) in enumerate(pad_width := list(pad_width)): - if lhs < 0 and rhs < 0: - x = jax.lax.dynamic_slice_in_dim(x, -lhs, x.shape[axis] + lhs + rhs, axis) - elif lhs < 0: - x = jax.lax.dynamic_slice_in_dim(x, -lhs, x.shape[axis] + lhs, axis) - elif rhs < 0: - x = jax.lax.dynamic_slice_in_dim(x, 0, x.shape[axis] + rhs, axis) - - return jnp.pad(x, [(max(lhs, 0), max(rhs, 0)) for (lhs, rhs) in (pad_width)]) - - -@ft.partial(jax.jit, static_argnums=(2, 3, 4, 5), inline=True) -def fft_conv_general_dilated( - x: jax.Array, - w: jax.Array, - strides: tuple[int, ...], - padding: tuple[tuple[int, int], ...], - groups: int, - dilation: tuple[int, ...], -) -> jax.Array: - """General dilated convolution using FFT - Args: - x: input array in shape (batch, in_features, *spatial_in_shape) - w: kernel array in shape of (out_features, in_features // groups, *kernel_size) - strides: strides in form of tuple of ints for each spatial dimension - padding: padding in the form of ((before_1, after_1), ..., (before_N, after_N)) - for each spatial dimension - groups: number of groups - dilation: dilation in the form of tuple of ints for each spatial dimension - """ - - spatial_ndim = x.ndim - 2 # spatial dimensions - w = _general_intersperse(w, dilation=dilation, axis=range(2, 2 + spatial_ndim)) - x = _general_pad(x, ((0, 0), (0, 0), *padding)) - - x_shape, w_shape = x.shape, w.shape - - if x.shape[-1] % 2 != 0: - x = jnp.pad(x, tuple([(0, 0)] * (x.ndim - 1) + [(0, 1)])) - - kernel_padding = ((0, x.shape[i] - w.shape[i]) for i in range(2, spatial_ndim + 2)) - w = _general_pad(w, ((0, 0), (0, 0), *kernel_padding)) - - # for real-valued input - x_fft = jnp.fft.rfftn(x, axes=range(2, spatial_ndim + 2)) - w_fft = jnp.conjugate(jnp.fft.rfftn(w, axes=range(2, spatial_ndim + 2))) - z_fft = grouped_matmul(x_fft, w_fft, groups) - - z = jnp.fft.irfftn(z_fft, axes=range(2, spatial_ndim + 2)) - - start = (0,) * (spatial_ndim + 2) - end = [z.shape[0], z.shape[1]] - end += [max((x_shape[i] - w_shape[i] + 1), 0) for i in range(2, spatial_ndim + 2)] - - if all(s == 1 for s in strides): - return jax.lax.dynamic_slice(z, start, end) - - return jax.lax.slice(z, start, end, (1, 1, *strides)) - - -class FFTConvND(sk.TreeClass): - def __init__( - self, - in_features: int, - out_features: int, - kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "same", - kernel_dilation: DilationType = 1, - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - groups: int = 1, - key: jr.KeyArray = jr.PRNGKey(0), - ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) - self.kernel_size = canonicalize( - kernel_size, - ndim=self.spatial_ndim, - name="kernel_size", - ) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = padding - self.kernel_dilation = canonicalize( - kernel_dilation, - self.spatial_ndim, - name="kernel_dilation", - ) - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) - self.groups = positive_int_cb(groups) - - if self.out_features % self.groups != 0: - msg = f"Expected out_features % groups == 0, got {self.out_features % self.groups}" - raise ValueError(msg) - - weight_shape = (out_features, in_features // groups, *self.kernel_size) - self.weight = weight_init(key, weight_shape) - - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: - padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], - padding=self.padding, - kernel_size=self.kernel_size, - strides=self.strides, - ) - - y = fft_conv_general_dilated( - jnp.expand_dims(x, axis=0), - self.weight, - strides=self.strides, - padding=padding, - groups=self.groups, - dilation=self.kernel_dilation, - ) - y = jnp.squeeze(y, axis=0) - if self.bias is None: - return y - return y + self.bias - - @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolution.""" - ... - - -class FFTConv1D(FFTConvND): - """1D Convolutional layer. +class DepthwiseFFTConv1D(DepthwiseFFTConvND): + """1D Depthwise FFT convolution layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1626,34 +1457,23 @@ class FFTConv1D(FFTConvND): as the input. - ``valid``/``VALID`` for no padding. - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv1D(in_features=1, out_features=2, kernel_size=3) - >>> # single sample - >>> x = jnp.ones((1, 5)) - >>> print(layer(x).shape) - (2, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5) + >>> l1 = sk.nn.DepthwiseFFTConv1D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32))).shape + (6, 16) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -1661,22 +1481,21 @@ def spatial_ndim(self) -> int: return 1 -class FFTConv2D(FFTConvND): - """2D FFT Convolutional layer. +class DepthwiseFFTConv2D(DepthwiseFFTConvND): + """2D Depthwise convolution layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimensions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimnsions. + - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1692,34 +1511,23 @@ class FFTConv2D(FFTConvND): as the input. - ``valid``/``VALID`` for no padding. - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv2D(in_features=1, out_features=2, kernel_size=3) - >>> # single sample - >>> x = jnp.ones((1, 5, 5)) - >>> print(layer(x).shape) - (2, 5, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5, 5) + >>> l1 = sk.nn.DepthwiseFFTConv2D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32, 32))).shape + (6, 16, 16) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -1727,22 +1535,21 @@ def spatial_ndim(self) -> int: return 2 -class FFTConv3D(FFTConvND): - """3D FFT Convolutional layer. +class DepthwiseFFTConv3D(DepthwiseFFTConvND): + """3D Depthwise FFT convolution layer. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. - out_features: Number of output features maps, for 1D convolution this is - the length of the output, for 2D convolution this is the number of - output channels, for 3D convolution this is the number of output - channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimensions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimnsions. + - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1758,34 +1565,23 @@ class FFTConv3D(FFTConvND): as the input. - ``valid``/``VALID`` for no padding. - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - weight_init: function to use for initializing the weights. defaults to ``glorot uniform``. bias_init: function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv3D(in_features=1, out_features=2, kernel_size=3) - >>> # single sample - >>> x = jnp.ones((1, 5, 5, 5)) - >>> print(layer(x).shape) - (2, 5, 5, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5, 5, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5, 5, 5) + >>> l1 = sk.nn.DepthwiseFFTConv3D(3, 3, depth_multiplier=2, strides=2) + >>> l1(jnp.ones((3, 32, 32, 32))).shape + (6, 16, 16, 16) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property @@ -1793,101 +1589,73 @@ def spatial_ndim(self) -> int: return 3 -class FFTConvNDTranspose(sk.TreeClass): +class SeparableConvND(sk.TreeClass): def __init__( self, in_features: int, out_features: int, kernel_size: KernelSizeType, - *, - strides: StridesType = 1, - padding: PaddingType = "same", - output_padding: int = 0, - kernel_dilation: DilationType = 1, - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - groups: int = 1, - key: jr.KeyArray = jr.PRNGKey(0), - ): - self.in_features = positive_int_cb(in_features) - self.out_features = positive_int_cb(out_features) - self.kernel_size = canonicalize( - kernel_size, - self.spatial_ndim, - name="kernel_size", - ) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = padding - self.output_padding = canonicalize( - output_padding, - self.spatial_ndim, - name="output_padding", - ) - self.kernel_dilation = canonicalize( - kernel_dilation, - self.spatial_ndim, - name="kernel_dilation", - ) - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) - self.groups = positive_int_cb(groups) - - if self.in_features % self.groups != 0: - raise ValueError( - f"Expected in_features % groups == 0, " - f"got {self.in_features % self.groups}" - ) - - weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = weight_init(key, weight_shape) - - if bias_init is None: - self.bias = None - else: - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: - padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], - padding=self.padding, - kernel_size=self.kernel_size, - strides=self.strides, + *, + depth_multiplier: int = 1, + strides: StridesType = 1, + padding: PaddingType = "same", + depthwise_weight_init: InitType = "glorot_uniform", + pointwise_weight_init: InitType = "glorot_uniform", + pointwise_bias_init: InitType = "zeros", + key: jr.KeyArray = jr.PRNGKey(0), + ): + self.depthwise_conv = self.depthwise_convolution_layer( + in_features=in_features, + depth_multiplier=depth_multiplier, + kernel_size=kernel_size, + strides=strides, + padding=padding, + weight_init=depthwise_weight_init, + bias_init=None, # no bias for lhs + key=key, ) - transposed_padding = calculate_transpose_padding( + self.pointwise_conv = self.pointwise_convolution_layer( + in_features=in_features * depth_multiplier, + out_features=out_features, + kernel_size=1, + strides=strides, padding=padding, - extra_padding=self.output_padding, - kernel_size=self.kernel_size, - input_dilation=self.kernel_dilation, + weight_init=pointwise_weight_init, + bias_init=pointwise_bias_init, + key=key, ) - y = fft_conv_general_dilated( - jnp.expand_dims(x, axis=0), - self.weight, - strides=self.strides, - padding=transposed_padding, - groups=self.groups, - dilation=self.kernel_dilation, - ) + def __call__(self, x: jax.Array, **k) -> jax.Array: + x = self.depthwise_conv(x) + x = self.pointwise_conv(x) + return x - y = jnp.squeeze(y, axis=0) + @property + @abc.abstractmethod + def spatial_ndim(self) -> int: + ... - if self.bias is None: - return y - return y + self.bias + @property + @abc.abstractmethod + def pointwise_convolution_layer(self): + ... @property @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolution.""" + def depthwise_convolution_layer(self): ... -class FFTConv1DTranspose(FFTConvNDTranspose): - """1D FFT Convolution transpose layer. +class SeparableConv1D(SeparableConvND): + """1D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -1902,6 +1670,9 @@ class FFTConv1DTranspose(FFTConvNDTranspose): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1917,47 +1688,47 @@ class FFTConv1DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. + Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv1DTranspose(1, 2, 3) - >>> # single sample - >>> x = jnp.ones((1, 5)) - >>> print(layer(x).shape) - (2, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5) + >>> l1 = sk.nn.SeparableConv1D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32))).shape + (3, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 1 + @property + def pointwise_convolution_layer(self): + return Conv1D -class FFTConv2DTranspose(FFTConvNDTranspose): - """2D FFT Convolution transpose layer. + @property + def depthwise_convolution_layer(self): + return DepthwiseConv1D + + +class SeparableConv2D(SeparableConvND): + """2D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -1972,6 +1743,9 @@ class FFTConv2DTranspose(FFTConvNDTranspose): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -1987,47 +1761,46 @@ class FFTConv2DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv2DTranspose(1, 2, 3) - >>> # single sample - >>> x = jnp.ones((1, 5, 5)) - >>> print(layer(x).shape) - (2, 5, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5, 5) + >>> l1 = sk.nn.SeparableConv2D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32, 32))).shape + (3, 32, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 2 + @property + def pointwise_convolution_layer(self): + return Conv2D + + @property + def depthwise_convolution_layer(self): + return DepthwiseConv2D -class FFTConv3DTranspose(FFTConvNDTranspose): - """3D FFT Convolution transpose layer. + +class SeparableConv3D(SeparableConvND): + """3D Separable convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -2042,6 +1815,9 @@ class FFTConv3DTranspose(FFTConvNDTranspose): - single integer for same kernel size in all dimensions. - sequence of integers for different kernel sizes in each dimension. + depth_multiplier: multiplier for the number of output channels. for example + if the input has 32 channels and the depth multiplier is 2 then the + output will have 64 channels. strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -2057,121 +1833,59 @@ class FFTConv3DTranspose(FFTConvNDTranspose): as the input. - ``valid``/``VALID`` for no padding. - output_padding: Padding of the output after convolution. accepts: - - - single integer for same padding in all dimensions. - - kernel_dilation: Dilation of the convolutional kernel accepts: - - - single integer for same dilation in all dimensions. - - sequence of integers for different dilation in each dimension. - - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. - groups: number of groups to use for grouped convolution. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> import jax - >>> layer = sk.nn.FFTConv3DTranspose(1, 2, 3) - >>> # single sample - >>> x = jnp.ones((1, 5, 5, 5)) - >>> print(layer(x).shape) - (2, 5, 5, 5) - >>> # batch of samples - >>> x = jnp.ones((2, 1, 5, 5, 5)) - >>> print(jax.vmap(layer)(x).shape) - (2, 2, 5, 5, 5) + >>> l1 = sk.nn.SeparableConv3D(3, 3, 3, depth_multiplier=2) + >>> l1(jnp.ones((3, 32, 32, 32))).shape + (3, 32, 32, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html + - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @property def spatial_ndim(self) -> int: return 3 - -class DepthwiseFFTConvND(sk.TreeClass): - def __init__( - self, - in_features: int, - kernel_size: int | tuple[int, ...], - *, - depth_multiplier: int = 1, - strides: StridesType = 1, - padding: PaddingType = "same", - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - key: jr.KeyArray = jr.PRNGKey(0), - ): - self.in_features = positive_int_cb(in_features) - self.kernel_size = canonicalize( - kernel_size, self.spatial_ndim, name="kernel_size" - ) - self.depth_multiplier = positive_int_cb(depth_multiplier) - self.strides = canonicalize(strides, self.spatial_ndim, name="strides") - self.padding = padding - self.input_dilation = canonicalize(1, self.spatial_ndim, name="input_dilation") - self.kernel_dilation = canonicalize( - 1, - self.spatial_ndim, - name="kernel_dilation", - ) - weight_init = resolve_init_func(weight_init) - bias_init = resolve_init_func(bias_init) - - weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW - self.weight = weight_init(key, weight_shape) - - bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) - self.bias = bias_init(key, bias_shape) - - @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) - def __call__(self, x: jax.Array, **k) -> jax.Array: - padding = delayed_canonicalize_padding( - in_dim=x.shape[1:], - padding=self.padding, - kernel_size=self.kernel_size, - strides=self.strides, - ) - - y = fft_conv_general_dilated( - jnp.expand_dims(x, axis=0), - self.weight, - strides=self.strides, - padding=padding, - groups=x.shape[0], - dilation=self.kernel_dilation, - ) - y = jnp.squeeze(y, axis=0) - if self.bias is None: - return y - return y + self.bias + @property + def pointwise_convolution_layer(self): + return Conv3D @property - @abc.abstractmethod - def spatial_ndim(self) -> int: - """Number of spatial dimensions of the convolution.""" - ... + def depthwise_convolution_layer(self): + return DepthwiseConv3D -class DepthwiseFFTConv1D(DepthwiseFFTConvND): - """1D Depthwise FFT convolution layer. +class SeparableFFTConv1D(SeparableConvND): + """1D Separable FFT convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimensions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimnsions. + - sequence of integers for different kernel sizes in each dimension. depth_multiplier: multiplier for the number of output channels. for example if the input has 32 channels and the depth multiplier is 2 then the @@ -2201,9 +1915,9 @@ class DepthwiseFFTConv1D(DepthwiseFFTConvND): Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseFFTConv1D(3, 3, depth_multiplier=2, strides=2) + >>> l1 = sk.nn.SeparableFFTConv1D(3, 3, 3, depth_multiplier=2) >>> l1(jnp.ones((3, 32))).shape - (6, 16) + (3, 32) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -2214,14 +1928,33 @@ class DepthwiseFFTConv1D(DepthwiseFFTConvND): def spatial_ndim(self) -> int: return 1 + @property + def pointwise_convolution_layer(self): + return FFTConv1D -class DepthwiseFFTConv2D(DepthwiseFFTConvND): - """2D Depthwise convolution layer. + @property + def depthwise_convolution_layer(self): + return DepthwiseFFTConv1D + + +class SeparableFFTConv2D(SeparableConvND): + """2D Separable FFT convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimnsions. @@ -2251,13 +1984,12 @@ class DepthwiseFFTConv2D(DepthwiseFFTConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseFFTConv2D(3, 3, depth_multiplier=2, strides=2) + >>> l1 = sk.nn.SeparableFFTConv2D(3, 3, 3, depth_multiplier=2) >>> l1(jnp.ones((3, 32, 32))).shape - (6, 16, 16) + (3, 32, 32) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -2268,14 +2000,33 @@ class DepthwiseFFTConv2D(DepthwiseFFTConvND): def spatial_ndim(self) -> int: return 2 + @property + def pointwise_convolution_layer(self): + return FFTConv2D + + @property + def depthwise_convolution_layer(self): + return DepthwiseFFTConv2D + -class DepthwiseFFTConv3D(DepthwiseFFTConvND): - """3D Depthwise FFT convolution layer. +class SeparableFFTConv3D(SeparableConvND): + """3D Separable FFT convolution layer. + + Separable convolution is a depthwise convolution followed by a pointwise + convolution. The objective is to reduce the number of parameters in the + convolutional layer. For example, for I input features and O output features, + and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O + parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O + parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the length of the input, for 2D convolution this is the number of input channels, for 3D convolution this is the number of input channels. + out_features: Number of output features maps, for 1D convolution this is + the length of the output, for 2D convolution this is the number of + output channels, for 3D convolution this is the number of output + channels. kernel_size: Size of the convolutional kernel. accepts: - single integer for same kernel size in all dimnsions. @@ -2305,13 +2056,12 @@ class DepthwiseFFTConv3D(DepthwiseFFTConvND): ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.DepthwiseFFTConv3D(3, 3, depth_multiplier=2, strides=2) + >>> l1 = sk.nn.SeparableFFTConv3D(3, 3, 3, depth_multiplier=2) >>> l1(jnp.ones((3, 32, 32, 32))).shape - (6, 16, 16, 16) + (3, 32, 32, 32) References: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html @@ -2322,83 +2072,103 @@ class DepthwiseFFTConv3D(DepthwiseFFTConvND): def spatial_ndim(self) -> int: return 3 + @property + def pointwise_convolution_layer(self): + return FFTConv3D + + @property + def depthwise_convolution_layer(self): + return DepthwiseFFTConv3D + -class SeparableFFTConvND(sk.TreeClass): +class BaseConvNDLocal(sk.TreeClass): def __init__( self, in_features: int, out_features: int, kernel_size: KernelSizeType, *, - depth_multiplier: int = 1, + in_size: Sequence[int], strides: StridesType = 1, padding: PaddingType = "same", - depthwise_weight_init: InitType = "glorot_uniform", - pointwise_weight_init: InitType = "glorot_uniform", - pointwise_bias_init: InitType = "zeros", + dilation: DilationType = 1, + weight_init: InitType = "glorot_uniform", + bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), ): - self.in_features = in_features - self.depth_multiplier = canonicalize( - depth_multiplier, - self.in_features, - name="depth_multiplier", + # checked by callbacks + self.in_features = positive_int_cb(in_features) + self.out_features = positive_int_cb(out_features) + self.kernel_size = canonicalize( + kernel_size, self.spatial_ndim, name="kernel_size" + ) + self.in_size = canonicalize(in_size, self.spatial_ndim, name="in_size") + self.strides = canonicalize(strides, self.spatial_ndim, name="strides") + self.padding = delayed_canonicalize_padding( + self.in_size, padding, self.kernel_size, self.strides ) + self.dilation = canonicalize(dilation, self.spatial_ndim, name="dilation") + weight_init = resolve_init_func(weight_init) + bias_init = resolve_init_func(bias_init) - self.depthwise_conv = self.depthwise_convolution_layer( - in_features=in_features, - depth_multiplier=depth_multiplier, - kernel_size=kernel_size, - strides=strides, - padding=padding, - weight_init=depthwise_weight_init, - bias_init=None, # no bias for lhs - key=key, + out_size = calculate_convolution_output_shape( + shape=self.in_size, + kernel_size=self.kernel_size, + padding=self.padding, + strides=self.strides, ) - self.pointwise_conv = self.pointwise_convolution_layer( - in_features=in_features * depth_multiplier, - out_features=out_features, - kernel_size=1, - strides=strides, - padding=padding, - weight_init=pointwise_weight_init, - bias_init=pointwise_bias_init, - key=key, + # OIHW + weight_shape = ( + self.out_features, + self.in_features * ft.reduce(op.mul, self.kernel_size), + *out_size, ) + self.weight = weight_init(key, weight_shape) + + bias_shape = (self.out_features, *out_size) + self.bias = bias_init(key, bias_shape) + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array, **k) -> jax.Array: - x = self.depthwise_conv(x) - x = self.pointwise_conv(x) - return x + y = self.convolution_operation(jnp.expand_dims(x, 0)) + if self.bias is None: + return jnp.squeeze(y, 0) + return jnp.squeeze((y + self.bias), 0) @property - @abc.abstractclassmethod + @abc.abstractmethod def spatial_ndim(self) -> int: + """Number of spatial dimensions of the convolutional layer.""" ... - @property - @abc.abstractclassmethod - def pointwise_convolution_layer(self): + @abc.abstractmethod + def convolution_operation(self, x: jax.Array) -> jax.Array: ... - @property - @abc.abstractclassmethod - def depthwise_convolution_layer(self): - ... + +class ConvNDLocal(BaseConvNDLocal): + def convolution_operation(self, x: jax.Array) -> jax.Array: + return jax.lax.conv_general_dilated_local( + lhs=x, + rhs=self.weight, + window_strides=self.strides, + padding=self.padding, + filter_shape=self.kernel_size, + rhs_dilation=self.dilation, + dimension_numbers=generate_conv_dim_numbers(self.spatial_ndim), + ) -class SeparableFFTConv1D(SeparableFFTConvND): - """1D Separable FFT convolution layer. +class Conv1DLocal(ConvNDLocal): + """1D Local convolutional layer. + + Local convolutional layer is a convolutional layer where the convolution + kernel is applied to a local region of the input. The kernel weights are + *not* shared across the spatial dimensions of the input. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -2410,12 +2180,11 @@ class SeparableFFTConv1D(SeparableFFTConvND): channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimnsions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimensions. + - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. + in_size: the size of the spatial dimensions of the input. e.g excluding + the first dimension. accepts a sequence of integer(s). strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -2431,21 +2200,20 @@ class SeparableFFTConv1D(SeparableFFTConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. - Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableFFTConv1D(3, 3, 3, depth_multiplier=2) + >>> l1 = sk.nn.Conv1DLocal(3, 3, 3, in_size=(32,)) >>> l1(jnp.ones((3, 32))).shape (3, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @@ -2454,24 +2222,14 @@ class SeparableFFTConv1D(SeparableFFTConvND): def spatial_ndim(self) -> int: return 1 - @property - def pointwise_convolution_layer(self): - return FFTConv1D - - @property - def depthwise_convolution_layer(self): - return DepthwiseFFTConv1D +class Conv2DLocal(ConvNDLocal): + """2D Local convolutional layer. -class SeparableFFTConv2D(SeparableFFTConvND): - """2D Separable FFT convolution layer. + Local convolutional layer is a convolutional layer where the convolution + kernel is applied to a local region of the input. This means that the kernel + weights are *not* shared across the spatial dimensions of the input. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -2483,12 +2241,11 @@ class SeparableFFTConv2D(SeparableFFTConvND): channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimnsions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimensions. + - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. + in_size: the size of the spatial dimensions of the input. e.g excluding + the first dimension. accepts a sequence of integer(s). strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -2504,20 +2261,20 @@ class SeparableFFTConv2D(SeparableFFTConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableFFTConv2D(3, 3, 3, depth_multiplier=2) + >>> l1 = sk.nn.Conv2DLocal(3, 3, 3, in_size=(32, 32)) >>> l1(jnp.ones((3, 32, 32))).shape (3, 32, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @@ -2526,24 +2283,14 @@ class SeparableFFTConv2D(SeparableFFTConvND): def spatial_ndim(self) -> int: return 2 - @property - def pointwise_convolution_layer(self): - return FFTConv2D - - @property - def depthwise_convolution_layer(self): - return DepthwiseFFTConv2D +class Conv3DLocal(ConvNDLocal): + """3D Local convolutional layer. -class SeparableFFTConv3D(SeparableFFTConvND): - """3D Separable FFT convolution layer. + Local convolutional layer is a convolutional layer where the convolution + kernel is applied to a local region of the input. This means that the kernel + weights are *not* shared across the spatial dimensions of the input. - Separable convolution is a depthwise convolution followed by a pointwise - convolution. The objective is to reduce the number of parameters in the - convolutional layer. For example, for I input features and O output features, - and a kernel size = Ki, then standard convolution has I * O * K0 ... * Kn + O - parameters, whereas separable convolution has I * K0 ... * Kn + I * O + O - parameters. Args: in_features: Number of input feature maps, for 1D convolution this is the @@ -2555,12 +2302,11 @@ class SeparableFFTConv3D(SeparableFFTConvND): channels. kernel_size: Size of the convolutional kernel. accepts: - - single integer for same kernel size in all dimnsions. - - sequence of integers for different kernel sizes in each dimension. + - single integer for same kernel size in all dimensions. + - sequence of integers for different kernel sizes in each dimension. - depth_multiplier: multiplier for the number of output channels. for example - if the input has 32 channels and the depth multiplier is 2 then the - output will have 64 channels. + in_size: the size of the spatial dimensions of the input. e.g excluding + the first dimension. accepts a sequence of integer(s). strides: Stride of the convolution. accepts: - single integer for same stride in all dimensions. @@ -2576,20 +2322,20 @@ class SeparableFFTConv3D(SeparableFFTConvND): as the input. - ``valid``/``VALID`` for no padding. - weight_init: function to use for initializing the weights. defaults + weight_init: Function to use for initializing the weights. defaults to ``glorot uniform``. - bias_init: function to use for initializing the bias. defaults to + bias_init: Function to use for initializing the bias. defaults to ``zeros``. set to ``None`` to not use a bias. key: key to use for initializing the weights. defaults to ``0``. Example: >>> import jax.numpy as jnp >>> import serket as sk - >>> l1 = sk.nn.SeparableFFTConv3D(3, 3, 3, depth_multiplier=2) + >>> l1 = sk.nn.Conv3DLocal(3, 3, 3, in_size=(32, 32, 32)) >>> l1(jnp.ones((3, 32, 32, 32))).shape (3, 32, 32, 32) - References: + Reference: - https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv.html - https://github.com/google/flax/blob/main/flax/linen/linear.py """ @@ -2597,11 +2343,3 @@ class SeparableFFTConv3D(SeparableFFTConvND): @property def spatial_ndim(self) -> int: return 3 - - @property - def pointwise_convolution_layer(self): - return FFTConv3D - - @property - def depthwise_convolution_layer(self): - return DepthwiseFFTConv3D diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py index f5028d0..30ace14 100644 --- a/serket/nn/dropout.py +++ b/serket/nn/dropout.py @@ -51,15 +51,15 @@ class Dropout(sk.TreeClass): >>> layers = sk.nn.Sequential(sk.nn.Dropout(0.5), sk.nn.Linear(10, 10)) >>> sk.tree_eval(layers) Sequential( - layers=( - Identity(), - Linear( - in_features=(10), - out_features=10, - weight=f32[10,10](μ=0.04, σ=0.43, ∈[-0.86,0.95]), - bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) - ) + layers=( + Identity(), + Linear( + in_features=(10), + out_features=10, + weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]), + bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) ) + ) ) """ @@ -114,15 +114,15 @@ class Dropout1D(DropoutND): >>> layers = sk.nn.Sequential(sk.nn.Dropout1D(0.5), sk.nn.Linear(10, 10)) >>> sk.tree_eval(layers) Sequential( - layers=( - Identity(), - Linear( - in_features=(10), - out_features=10, - weight=f32[10,10](μ=0.04, σ=0.43, ∈[-0.86,0.95]), - bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) - ) + layers=( + Identity(), + Linear( + in_features=(10), + out_features=10, + weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]), + bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) ) + ) ) Reference: @@ -159,15 +159,15 @@ class Dropout2D(DropoutND): >>> layers = sk.nn.Sequential(sk.nn.Dropout2D(0.5), sk.nn.Linear(10, 10)) >>> sk.tree_eval(layers) Sequential( - layers=( - Identity(), - Linear( - in_features=(10), - out_features=10, - weight=f32[10,10](μ=0.04, σ=0.43, ∈[-0.86,0.95]), - bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) - ) + layers=( + Identity(), + Linear( + in_features=(10), + out_features=10, + weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]), + bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) ) + ) ) Reference: @@ -204,15 +204,15 @@ class Dropout3D(DropoutND): >>> layers = sk.nn.Sequential(sk.nn.Dropout2D(0.5), sk.nn.Linear(10, 10)) >>> sk.tree_eval(layers) Sequential( - layers=( - Identity(), - Linear( - in_features=(10), - out_features=10, - weight=f32[10,10](μ=0.04, σ=0.43, ∈[-0.86,0.95]), - bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) - ) + layers=( + Identity(), + Linear( + in_features=(10), + out_features=10, + weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]), + bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00]) ) + ) ) Reference: diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 243701c..1b8006b 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -389,8 +389,7 @@ class ConvLSTMNDCell(RNNCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - input_dilation: Dilation of the input - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -410,7 +409,7 @@ def __init__( *, strides: StridesType = 1, padding: PaddingType = "same", - kernel_dilation: DilationType = 1, + dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", @@ -431,7 +430,7 @@ def __init__( kernel_size, strides=strides, padding=padding, - kernel_dilation=kernel_dilation, + dilation=dilation, weight_init=weight_init, bias_init=bias_init, key=k1, @@ -443,7 +442,7 @@ def __init__( kernel_size, strides=strides, padding=padding, - kernel_dilation=kernel_dilation, + dilation=dilation, weight_init=recurrent_weight_init, bias_init=None, key=k2, @@ -481,7 +480,7 @@ class ConvLSTM1DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -511,7 +510,7 @@ class FFTConvLSTM1DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -541,7 +540,7 @@ class ConvLSTM2DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -571,7 +570,7 @@ class FFTConvLSTM2DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -601,7 +600,7 @@ class ConvLSTM3DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -631,7 +630,7 @@ class FFTConvLSTM3DCell(ConvLSTMNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -665,7 +664,7 @@ class ConvGRUNDCell(RNNCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -683,7 +682,7 @@ def __init__( *, strides: StridesType = 1, padding: PaddingType = "same", - kernel_dilation: DilationType = 1, + dilation: DilationType = 1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", @@ -704,7 +703,7 @@ def __init__( kernel_size, strides=strides, padding=padding, - kernel_dilation=kernel_dilation, + dilation=dilation, weight_init=weight_init, bias_init=bias_init, key=k1, @@ -716,7 +715,7 @@ def __init__( kernel_size, strides=strides, padding=padding, - kernel_dilation=kernel_dilation, + dilation=dilation, weight_init=recurrent_weight_init, bias_init=None, key=k2, @@ -751,7 +750,7 @@ class ConvGRU1DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -779,7 +778,7 @@ class FFTConvGRU1DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -807,7 +806,7 @@ class ConvGRU2DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -835,7 +834,7 @@ class FFTConvGRU2DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -863,7 +862,7 @@ class ConvGRU3DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function @@ -890,7 +889,7 @@ class FFTConvGRU3DCell(ConvGRUNDCell): kernel_size: Size of the convolutional kernel strides: Stride of the convolution padding: Padding of the convolution - kernel_dilation: Dilation of the convolutional kernel + dilation: Dilation of the convolutional kernel weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function diff --git a/serket/nn/utils.py b/serket/nn/utils.py index bfe4711..59e2980 100644 --- a/serket/nn/utils.py +++ b/serket/nn/utils.py @@ -70,7 +70,11 @@ def calculate_convolution_output_shape( ) -def same_padding_along_dim(in_dim: int, kernel_size: int, stride: int): +def same_padding_along_dim( + in_dim: int, + kernel_size: int, + stride: int, +) -> tuple[int, int]: # https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 # di: input dimension # ki: kernel size @@ -88,7 +92,8 @@ def resolve_tuple_padding( padding: PaddingType, kernel_size: KernelSizeType, strides: StridesType, -): +) -> tuple[tuple[int, int], ...]: + del in_dim, strides if len(padding) != len(kernel_size): raise ValueError(f"Length mismatch {len(kernel_size)=}!={len(padding)=}.") @@ -100,21 +105,9 @@ def resolve_tuple_padding( elif isinstance(item, tuple): if len(item) != 2: - # ex: padding = ((1, 2), (3, 4), (5, 6)) raise ValueError(f"Expected tuple of length 2, got {len(item)=}") resolved_padding[i] = item - elif isinstance(item, str): - if item.lower() == "same": - di, ki, si = in_dim[i], kernel_size[i], strides[i] - resolved_padding[i] = same_padding_along_dim(di, ki, si) - - elif item.lower() == "valid": - resolved_padding[i] = (0, 0) - - else: - raise ValueError("Invalid padding, must be in [`same`, `valid`].") - return tuple(resolved_padding) diff --git a/tests/test_conv.py b/tests/test_conv.py index 6744624..c969181 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -67,7 +67,7 @@ # kernel_size=ki, # strides=si, # padding=pi, -# kernel_dilation=ddi, +# dilation=ddi, # groups=gi, # ) @@ -143,7 +143,7 @@ # kernel_size=ki, # strides=si, # padding=pi, -# kernel_dilation=ddi, +# dilation=ddi, # groups=gi, # bias_init=None, # ) diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 206f404..45ec023 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -115,13 +115,13 @@ def test_conv1D(): ] ) - layer = Conv1D(1, 2, 3, padding=2, strides=1, kernel_dilation=2) + layer = Conv1D(1, 2, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) npt.assert_allclose(layer(x), y) - layer = Conv1D(1, 2, 3, padding=2, strides=1, kernel_dilation=2, bias_init=None) + layer = Conv1D(1, 2, 3, padding=2, strides=1, dilation=2, bias_init=None) layer = layer.at["weight"].set(w) npt.assert_allclose(layer(x), y) @@ -237,15 +237,13 @@ def test_conv1dtranspose(): b = jnp.array([[[0.0]]]) - layer = Conv1DTranspose(4, 1, 3, padding=2, strides=1, kernel_dilation=2) + layer = Conv1DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) y = jnp.array([[0.27022034, 0.24495776, -0.00368674]]) npt.assert_allclose(layer(x), y, atol=1e-5) - layer = Conv1DTranspose( - 4, 1, 3, padding=2, strides=1, kernel_dilation=2, bias_init=None - ) + layer = Conv1DTranspose(4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) layer = layer.at["weight"].set(w) y = jnp.array([[0.27022034, 0.24495776, -0.00368674]]) npt.assert_allclose(layer(x), y, atol=1e-5) @@ -299,7 +297,7 @@ def test_conv2dtranspose(): b = jnp.array([[[0.0]]]) - layer = Conv2DTranspose(3, 1, 3, padding=2, strides=1, kernel_dilation=2) + layer = Conv2DTranspose(3, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) @@ -317,9 +315,7 @@ def test_conv2dtranspose(): npt.assert_allclose(layer(x), y, atol=1e-5) - layer = Conv2DTranspose( - 3, 1, 3, padding=2, strides=1, kernel_dilation=2, bias_init=None - ) + layer = Conv2DTranspose(3, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) layer = layer.at["weight"].set(w) @@ -488,7 +484,7 @@ def test_conv3dtranspose(): b = jnp.array([[[[0.0]]]]) - layer = Conv3DTranspose(4, 1, 3, padding=2, strides=1, kernel_dilation=2) + layer = Conv3DTranspose(4, 1, 3, padding=2, strides=1, dilation=2) layer = layer.at["weight"].set(w) layer = layer.at["bias"].set(b) @@ -516,9 +512,7 @@ def test_conv3dtranspose(): npt.assert_allclose(y, layer(x), atol=1e-5) - layer = Conv3DTranspose( - 4, 1, 3, padding=2, strides=1, kernel_dilation=2, bias_init=None - ) + layer = Conv3DTranspose(4, 1, 3, padding=2, strides=1, dilation=2, bias_init=None) layer = layer.at["weight"].set(w) y = jnp.array( diff --git a/tests/test_fft_convolution.py b/tests/test_fft_convolution.py index 5f60634..db0dfd4 100644 --- a/tests/test_fft_convolution.py +++ b/tests/test_fft_convolution.py @@ -535,8 +535,8 @@ def test_fft_conv(): x = jnp.ones([10, 1]) npt.assert_allclose( - FFTConv1D(10, 1, 3, kernel_dilation=2)(x), - Conv1D(10, 1, 3, kernel_dilation=2)(x), + FFTConv1D(10, 1, 3, dilation=2)(x), + Conv1D(10, 1, 3, dilation=2)(x), atol=1e-5, ) @@ -547,14 +547,14 @@ def test_fft_conv(): x = jnp.ones([10, 10, 10]) npt.assert_allclose( - FFTConv2D(10, 1, 3, kernel_dilation=3)(x), - Conv2D(10, 1, 3, kernel_dilation=3)(x), + FFTConv2D(10, 1, 3, dilation=3)(x), + Conv2D(10, 1, 3, dilation=3)(x), atol=1e-5, ) x = jnp.ones([7, 8, 9]) npt.assert_allclose( - FFTConv2D(7, 1, 3, kernel_dilation=2)(x), - Conv2D(7, 1, 3, kernel_dilation=2)(x), + FFTConv2D(7, 1, 3, dilation=2)(x), + Conv2D(7, 1, 3, dilation=2)(x), atol=1e-5, ) @@ -565,8 +565,8 @@ def test_fft_conv(): x = jnp.ones([7, 8, 9, 10]) npt.assert_allclose( - FFTConv3D(7, 1, 3, kernel_dilation=(1, 2, 3))(x), - Conv3D(7, 1, 3, kernel_dilation=(1, 2, 3))(x), + FFTConv3D(7, 1, 3, dilation=(1, 2, 3))(x), + Conv3D(7, 1, 3, dilation=(1, 2, 3))(x), atol=1e-5, ) @@ -596,8 +596,8 @@ def test_conv_transpose(): x = jnp.ones([10, 4]) npt.assert_allclose( - Conv1DTranspose(10, 4, 3, kernel_dilation=2)(x), - FFTConv1DTranspose(10, 4, 3, kernel_dilation=2)(x), + Conv1DTranspose(10, 4, 3, dilation=2)(x), + FFTConv1DTranspose(10, 4, 3, dilation=2)(x), atol=1e-5, ) @@ -608,8 +608,8 @@ def test_conv_transpose(): x = jnp.ones([10, 4, 4, 4]) npt.assert_allclose( - Conv3DTranspose(10, 4, 3, kernel_dilation=2)(x), - FFTConv3DTranspose(10, 4, 3, kernel_dilation=2)(x), + Conv3DTranspose(10, 4, 3, dilation=2)(x), + FFTConv3DTranspose(10, 4, 3, dilation=2)(x), atol=1e-5, )