Skip to content

Commit

Permalink
add FFTAvgBlur2D and FFTAvgBlur2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 5, 2023
1 parent e879883 commit cb8dfeb
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 49 deletions.
4 changes: 4 additions & 0 deletions CHANEGLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Changes

- Moved image related layers to `serket.image`

- `ScanRNN` changes:

- `cell.init_state` is deprecated use `sk.tree_state(cell, ...)` instead.
Expand Down Expand Up @@ -51,6 +53,8 @@
- `Solarize2D`
- `Posterize2D`
- `JigSaw2D`
- `FFTAvgBlur2D`
- `FFTGaussianBlur2D`

### Deprecations

Expand Down
6 changes: 4 additions & 2 deletions docs/API/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Filter API
.. currentmodule:: serket.image

.. autoclass:: AvgBlur2D
.. autoclass:: FFTAvgBlur2D
.. autoclass:: GaussianBlur2D
.. autoclass:: FFTGaussianBlur2D
.. autoclass:: FFTFilter2D
.. autoclass:: Filter2D
.. autoclass:: GaussianBlur2D
.. autoclass:: Filter2D
2 changes: 1 addition & 1 deletion docs/API/geometric.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Affine API
Geometric API
---------------------------------
.. currentmodule:: serket.image

Expand Down
23 changes: 17 additions & 6 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
Posterize2D,
RandomContrast2D,
)
from .filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from .filter import (
AvgBlur2D,
FFTAvgBlur2D,
FFTFilter2D,
FFTGaussianBlur2D,
Filter2D,
GaussianBlur2D,
)
from .geometric import (
HorizontalFlip2D,
HorizontalShear2D,
Expand All @@ -39,20 +46,24 @@
)

__all__ = (
# image
# augment
"AdjustContrast2D",
"JigSaw2D",
"PixelShuffle2D",
"Posterize2D",
"RandomContrast2D",
# filter
"AvgBlur2D",
"FFTFilter2D",
"Filter2D",
"GaussianBlur2D",
"FFTAvgBlur2D",
"FFTGaussianBlur2D",
# geometric
"HorizontalFlip2D",
"HorizontalShear2D",
"HorizontalTranslate2D",
"JigSaw2D",
"Pixelate2D",
"PixelShuffle2D",
"Posterize2D",
"RandomContrast2D",
"RandomHorizontalShear2D",
"RandomHorizontalTranslate2D",
"RandomPerspective2D",
Expand Down
201 changes: 162 additions & 39 deletions serket/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,38 @@ def fft_filter_2d(
return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0))


class AvgBlur2D(sk.TreeClass):
"""Average blur 2D layer
class AvgBlur2DBase(sk.TreeClass):
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int | None,
kernel_size: int,
*,
dtype: DType = jnp.float32,
):
self.in_features = positive_int_cb(in_features)
kernel_size = positive_int_cb(kernel_size)
kernel = jnp.ones(kernel_size)
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None]
kernel = jnp.repeat(kernel[None, None], in_features, axis=0).astype(dtype)
self.kernel = kernel

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
x = filter_2d(x, self.kernel)
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x

@property
def spatial_ndim(self) -> int:
return 2


class AvgBlur2D(AvgBlur2DBase):
"""Average blur 2D layer.
.. image:: ../_static/avgblur2d.png
Expand Down Expand Up @@ -152,36 +182,96 @@ class AvgBlur2D(sk.TreeClass):
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
x = filter_2d(x, self.kernel)
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x


class FFTAvgBlur2D(AvgBlur2DBase):
"""Average blur 2D layer using FFT.
.. image:: ../_static/avgblur2d.png
Args:
in_features: number of input channels.
kernel_size: size of the convolving kernel.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> layer = sk.image.FFTAvgBlur2D(in_features=1, kernel_size=3)
>>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP
[[[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]
Note:
:class:`.FFTAvgBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.image.AvgBlur2D = sk.image.FFTAvgBlur2D(None, 3)
... l2: sk.image.AvgBlur2D = sk.image.FFTAvgBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
x = fft_filter_2d(x, self.kernel)
x = fft_filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x


class GaussianBlur2DBase(sk.TreeClass):
@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int | None,
in_features: int,
kernel_size: int,
*,
sigma: float = 1.0,
dtype: DType = jnp.float32,
):
self.in_features = positive_int_cb(in_features)
kernel_size = positive_int_cb(kernel_size)
kernel = jnp.ones(kernel_size)
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None]
kernel = jnp.repeat(kernel[None, None], in_features, axis=0).astype(dtype)
self.kernel = kernel
self.kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
x = filter_2d(x, self.kernel)
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x
x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None].astype(dtype)
self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0)

@property
def spatial_ndim(self) -> int:
return 2


class GaussianBlur2D(sk.TreeClass):
class GaussianBlur2D(GaussianBlur2DBase):
"""Apply Gaussian blur to a channel-first image.
.. image:: ../_static/gaussianblur2d.png
Expand Down Expand Up @@ -228,25 +318,6 @@ class GaussianBlur2D(sk.TreeClass):
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

@ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
def __init__(
self,
in_features: int,
kernel_size: int,
*,
sigma: float = 1.0,
dtype: DType = jnp.float32,
):
self.in_features = positive_int_cb(in_features)
self.kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None].astype(dtype)
self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
Expand All @@ -255,9 +326,61 @@ def __call__(self, x: jax.Array) -> jax.Array:
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x

@property
def spatial_ndim(self) -> int:
return 2

class FFTGaussianBlur2D(GaussianBlur2DBase):
"""Apply Gaussian blur to a channel-first image using FFT.
.. image:: ../_static/gaussianblur2d.png
Args:
in_features: number of input features
kernel_size: kernel size
sigma: sigma. Defaults to 1.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> layer = sk.image.FFTGaussianBlur2D(in_features=1, kernel_size=3)
>>> print(layer(jnp.ones((1,5,5)))) # doctest: +SKIP
[[[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]
Note:
:class:`.FFTGaussianBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.image.FFTGaussianBlur2D = sk.image.FFTGaussianBlur2D(None, 3)
... l2: sk.image.FFTGaussianBlur2D = sk.image.FFTGaussianBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
x = fft_filter_2d(x, self.kernel)
x = fft_filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x


class Filter2D(sk.TreeClass):
Expand Down
17 changes: 16 additions & 1 deletion tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
PixelShuffle2D,
RandomContrast2D,
)
from serket.image.filter import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from serket.image.filter import (
AvgBlur2D,
FFTAvgBlur2D,
FFTFilter2D,
FFTGaussianBlur2D,
Filter2D,
GaussianBlur2D,
)
from serket.image.geometric import (
HorizontalFlip2D,
HorizontalShear2D,
Expand Down Expand Up @@ -64,6 +71,10 @@ def test_AvgBlur2D():
with pytest.raises(ValueError):
AvgBlur2D(0, 1)

# test with
z = FFTAvgBlur2D(1, 3)(jnp.arange(1, 26).reshape([1, 5, 5]).astype(jnp.float32))
npt.assert_allclose(y, z, atol=1e-5)


def test_GaussBlur2D():
layer = GaussianBlur2D(in_features=1, kernel_size=3, sigma=1.0)
Expand Down Expand Up @@ -91,6 +102,10 @@ def test_GaussBlur2D():
with pytest.raises(ValueError):
GaussianBlur2D(0, 1, sigma=1.0)

z = FFTGaussianBlur2D(1, 3, sigma=1.0)(jnp.ones([1, 5, 5])).astype(jnp.float32)

npt.assert_allclose(layer(x), z, atol=1e-5)


# def test_lazy_blur():
# layer = GaussianBlur2D(in_features=None, kernel_size=3, sigma=1.0)
Expand Down

0 comments on commit cb8dfeb

Please sign in to comment.