From 72ce35c0bae6382bc945f3a14d89ead621d09125 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Fri, 8 Sep 2023 02:36:23 +0900 Subject: [PATCH] simplify filter funcs --- serket/image/filter.py | 212 ++++++------------------------------- tests/test_image_filter.py | 25 ++--- 2 files changed, 43 insertions(+), 194 deletions(-) diff --git a/serket/image/filter.py b/serket/image/filter.py index ceab1d9..d6b1579 100644 --- a/serket/image/filter.py +++ b/serket/image/filter.py @@ -25,30 +25,12 @@ from serket.nn.initialization import DType from serket.utils import ( generate_conv_dim_numbers, - maybe_lazy_call, - maybe_lazy_init, positive_int_cb, resolve_string_padding, - validate_axis_shape, validate_spatial_ndim, ) -def is_lazy_call(instance, *_, **__) -> bool: - return getattr(instance, "in_features", False) is None - - -def is_lazy_init(_, in_features, *__, **___) -> bool: - return in_features is None - - -def infer_in_features(_, x, *__, **___) -> int: - return x.shape[0] - - -image_updates = dict(in_features=infer_in_features) - - def filter_2d( array: Annotated[jax.Array, "CHW"], weight: Annotated[jax.Array, "OIHW"], @@ -107,21 +89,11 @@ def fft_filter_2d( class AvgBlur2DBase(sk.TreeClass): - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__( - self, - in_features: int | None, - kernel_size: int, - *, - dtype: DType = jnp.float32, - ): - self.in_features = positive_int_cb(in_features) + def __init__(self, kernel_size: int, *, dtype: DType = jnp.float32): kernel_size = positive_int_cb(kernel_size) kernel = jnp.ones(kernel_size) kernel = kernel / jnp.sum(kernel) - kernel = kernel[:, None] - kernel = jnp.repeat(kernel[None, None], in_features, axis=0).astype(dtype) - self.kernel = kernel + self.kernel = kernel[None, None, None].astype(dtype) @property def spatial_ndim(self) -> int: @@ -134,52 +106,27 @@ class AvgBlur2D(AvgBlur2DBase): .. image:: ../_static/avgblur2d.png Args: - in_features: number of input channels. kernel_size: size of the convolving kernel. dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.AvgBlur2D(in_features=1, kernel_size=3) + >>> layer = sk.image.AvgBlur2D(kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP [[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448] [0.6666667 1. 1. 1. 0.6666667 ] [0.6666667 1. 1. 1. 0.6666667 ] [0.6666667 1. 1. 1. 0.6666667 ] [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] - - Note: - :class:`.AvgBlur2D` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, pass ``None`` as the ``in_features`` argument - and use the ``.at["calling_method_name"]`` attribute to call the layer - with an input of known shape. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import jax - >>> @sk.autoinit - ... class Blur(sk.TreeClass): - ... l1: sk.image.AvgBlur2D = sk.image.AvgBlur2D(None, 3) - ... l2: sk.image.AvgBlur2D = sk.image.AvgBlur2D(None, 3) - ... def __call__(self, x: jax.Array) -> jax.Array: - ... return self.l2(jax.nn.relu(self.l1(x))) - >>> # lazy initialization - >>> lazy_blur = Blur() - >>> # materialize the layer - >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2))) """ - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - x = filter_2d(x, self.kernel) - x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = x[:, 0] return x @@ -189,74 +136,44 @@ class FFTAvgBlur2D(AvgBlur2DBase): .. image:: ../_static/avgblur2d.png Args: - in_features: number of input channels. kernel_size: size of the convolving kernel. dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.FFTAvgBlur2D(in_features=1, kernel_size=3) + >>> layer = sk.image.FFTAvgBlur2D(kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP [[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448] [0.6666667 1. 1. 1. 0.6666667 ] [0.6666667 1. 1. 1. 0.6666667 ] [0.6666667 1. 1. 1. 0.6666667 ] [0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]] - - Note: - :class:`.FFTAvgBlur2D` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, pass ``None`` as the ``in_features`` argument - and use the ``.at["calling_method_name"]`` attribute to call the layer - with an input of known shape. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import jax - >>> @sk.autoinit - ... class Blur(sk.TreeClass): - ... l1: sk.image.AvgBlur2D = sk.image.FFTAvgBlur2D(None, 3) - ... l2: sk.image.AvgBlur2D = sk.image.FFTAvgBlur2D(None, 3) - ... def __call__(self, x: jax.Array) -> jax.Array: - ... return self.l2(jax.nn.relu(self.l1(x))) - >>> # lazy initialization - >>> lazy_blur = Blur() - >>> # materialize the layer - >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2))) """ - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - x = fft_filter_2d(x, self.kernel) - x = fft_filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.squeeze(x, 1) return x class GaussianBlur2DBase(sk.TreeClass): - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, - in_features: int, kernel_size: int, *, sigma: float = 1.0, dtype: DType = jnp.float32, ): - self.in_features = positive_int_cb(in_features) self.kernel_size = positive_int_cb(kernel_size) self.sigma = sigma - x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size) kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma)) kernel = kernel / jnp.sum(kernel) - kernel = kernel[:, None].astype(dtype) - self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0) + self.kernel = kernel[None, None, None].astype(dtype) @property def spatial_ndim(self) -> int: @@ -269,7 +186,6 @@ class GaussianBlur2D(GaussianBlur2DBase): .. image:: ../_static/gaussianblur2d.png Args: - in_features: number of input features kernel_size: kernel size sigma: sigma. Defaults to 1. dtype: data type of the layer. Defaults to ``jnp.float32``. @@ -277,45 +193,21 @@ class GaussianBlur2D(GaussianBlur2DBase): Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.GaussianBlur2D(in_features=1, kernel_size=3) + >>> layer = sk.image.GaussianBlur2D(kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764] [0.7259314 1. 1. 1. 0.7259314] [0.7259314 1. 1. 1. 0.7259314] [0.7259314 1. 1. 1. 0.7259314] [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] - - Note: - :class:`.GaussianBlur2D` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, pass ``None`` as the ``in_features`` argument - and use the ``.at["calling_method_name"]`` attribute to call the layer - with an input of known shape. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import jax - >>> @sk.autoinit - ... class Blur(sk.TreeClass): - ... l1: sk.image.GaussianBlur2D = sk.image.GaussianBlur2D(None, 3) - ... l2: sk.image.GaussianBlur2D = sk.image.GaussianBlur2D(None, 3) - ... def __call__(self, x: jax.Array) -> jax.Array: - ... return self.l2(jax.nn.relu(self.l1(x))) - >>> # lazy initialization - >>> lazy_blur = Blur() - >>> # materialize the layer - >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2))) """ - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - x = filter_2d(x, self.kernel) - x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.squeeze(x, 1) return x @@ -325,7 +217,6 @@ class FFTGaussianBlur2D(GaussianBlur2DBase): .. image:: ../_static/gaussianblur2d.png Args: - in_features: number of input features kernel_size: kernel size sigma: sigma. Defaults to 1. dtype: data type of the layer. Defaults to ``jnp.float32``. @@ -333,45 +224,21 @@ class FFTGaussianBlur2D(GaussianBlur2DBase): Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.FFTGaussianBlur2D(in_features=1, kernel_size=3) + >>> layer = sk.image.FFTGaussianBlur2D(kernel_size=3) >>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP [[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764] [0.7259314 1. 1. 1. 0.7259314] [0.7259314 1. 1. 1. 0.7259314] [0.7259314 1. 1. 1. 0.7259314] [0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]] - - Note: - :class:`.FFTGaussianBlur2D` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, pass ``None`` as the ``in_features`` argument - and use the ``.at["calling_method_name"]`` attribute to call the layer - with an input of known shape. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import jax - >>> @sk.autoinit - ... class Blur(sk.TreeClass): - ... l1: sk.image.FFTGaussianBlur2D = sk.image.FFTGaussianBlur2D(None, 3) - ... l2: sk.image.FFTGaussianBlur2D = sk.image.FFTGaussianBlur2D(None, 3) - ... def __call__(self, x: jax.Array) -> jax.Array: - ... return self.l2(jax.nn.relu(self.l1(x))) - >>> # lazy initialization - >>> lazy_blur = Blur() - >>> # materialize the layer - >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2))) """ - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - x = fft_filter_2d(x, self.kernel) - x = fft_filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jnp.squeeze(x, 1) return x @@ -381,14 +248,13 @@ class Filter2D(sk.TreeClass): .. image:: ../_static/filter2d.png Args: - in_features: number of input channels. kernel: kernel array with shape (H, W). dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.Filter2D(in_features=1, kernel=jnp.ones((3,3))) + >>> layer = sk.image.Filter2D(kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) [[[4. 6. 6. 6. 4.] [6. 9. 9. 9. 6.] @@ -397,26 +263,22 @@ class Filter2D(sk.TreeClass): [4. 6. 6. 6. 4.]]] """ - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, - in_features: int, kernel: jax.Array, *, dtype: DType = jnp.float32, ): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") + self.kernel = kernel[None, None].astype(dtype) - self.in_features = positive_int_cb(in_features) - kernel = jnp.stack([kernel] * in_features, axis=0) - self.kernel = kernel[:, None].astype(dtype) - - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - return filter_2d(x, self.kernel) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jnp.squeeze(x, 1) + return x @property def spatial_ndim(self) -> int: @@ -429,14 +291,13 @@ class FFTFilter2D(sk.TreeClass): .. image:: ../_static/filter2d.png Args: - in_features: number of input channels kernel: kernel array dtype: data type of the layer. Defaults to ``jnp.float32``. Example: >>> import serket as sk >>> import jax.numpy as jnp - >>> layer = sk.image.FFTFilter2D(in_features=1, kernel=jnp.ones((3,3))) + >>> layer = sk.image.FFTFilter2D(kernel=jnp.ones((3,3))) >>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP [[[4.0000005 6.0000005 6.000001 6.0000005 4.0000005] [6.0000005 9. 9. 9. 6.0000005] @@ -445,10 +306,8 @@ class FFTFilter2D(sk.TreeClass): [4. 6.0000005 6.0000005 6.0000005 4. ]]] """ - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, - in_features: int, kernel: jax.Array, *, dtype: DType = jnp.float32, @@ -456,15 +315,14 @@ def __init__( if kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") - self.in_features = positive_int_cb(in_features) - kernel = jnp.stack([kernel] * in_features, axis=0) - self.kernel = kernel[:, None].astype(dtype) + self.kernel = kernel[None, None].astype(dtype) - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - return fft_filter_2d(x, self.kernel) + x = jnp.expand_dims(x, 1) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jnp.squeeze(x, 1) + return x @property def spatial_ndim(self) -> int: diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index e89293f..ac806cb 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -51,7 +51,7 @@ def test_AvgBlur2D(): - x = AvgBlur2D(1, 3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) + x = AvgBlur2D(3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) y = [ [ @@ -65,19 +65,13 @@ def test_AvgBlur2D(): npt.assert_allclose(x, y, atol=1e-5) - # with pytest.raises(ValueError): - # AvgBlur2D(1, 0) - - with pytest.raises(ValueError): - AvgBlur2D(0, 1) - # test with - z = FFTAvgBlur2D(1, 3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) + z = FFTAvgBlur2D(3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32)) npt.assert_allclose(y, z, atol=1e-5) def test_GaussBlur2D(): - layer = GaussianBlur2D(in_features=1, kernel_size=3, sigma=1.0) + layer = GaussianBlur2D(kernel_size=3, sigma=1.0) x = jnp.ones([1, 5, 5]) npt.assert_allclose( @@ -97,12 +91,9 @@ def test_GaussBlur2D(): ) with pytest.raises(ValueError): - GaussianBlur2D(1, 0, sigma=1.0) - - with pytest.raises(ValueError): - GaussianBlur2D(0, 1, sigma=1.0) + GaussianBlur2D(0, sigma=1.0) - z = FFTGaussianBlur2D(1, 3, sigma=1.0)(jnp.ones([1, 5, 5])).astype(jnp.float32) + z = FFTGaussianBlur2D(3, sigma=1.0)(jnp.ones([1, 5, 5])).astype(jnp.float32) npt.assert_allclose(layer(x), z, atol=1e-5) @@ -128,12 +119,12 @@ def test_GaussBlur2D(): def test_filter2d(): - layer = Filter2D(in_features=1, kernel=jnp.ones([3, 3]) / 9.0) + layer = Filter2D(kernel=jnp.ones([3, 3]) / 9.0) x = jnp.ones([1, 5, 5]) - npt.assert_allclose(AvgBlur2D(1, 3)(x), layer(x), atol=1e-4) + npt.assert_allclose(AvgBlur2D(3)(x), layer(x), atol=1e-4) - layer2 = FFTFilter2D(in_features=1, kernel=jnp.ones([3, 3]) / 9.0) + layer2 = FFTFilter2D(kernel=jnp.ones([3, 3]) / 9.0) npt.assert_allclose(layer(x), layer2(x), atol=1e-4)