Skip to content

Commit

Permalink
add Laplacian2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 11, 2023
1 parent 9a6b4f2 commit 8704067
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/API/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Filter API
.. autoclass:: FFTBoxBlur2D
.. autoclass:: GaussianBlur2D
.. autoclass:: FFTGaussianBlur2D
.. autoclass:: Laplacian2D
.. autoclass:: FFTLaplacian2D
.. autoclass:: UnsharpMask2D
.. autoclass:: FFTUnsharpMask2D

Expand Down
83 changes: 83 additions & 0 deletions serket/_src/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,89 @@ def __call__(self, x: jax.Array) -> jax.Array:
return x


def calculate_laplacian_kernel(
kernel_size: tuple[int, int],
dtype: DType,
) -> Annotated[jax.Array, "HW"]:
ky, kx = kernel_size
kernel = jnp.ones((ky, kx))
kernel = kernel.at[ky // 2, kx // 2].set(1 - jnp.sum(kernel)).astype(dtype)
return kernel


class Laplacian2DBase(sk.TreeClass):
def __init__(
self,
kernel_size: int | tuple[int, int],
*,
dtype: DType = jnp.float32,
):
self.kernel_size = canonicalize(kernel_size, ndim=2, name="kernel_size")
self.kernel = calculate_laplacian_kernel(self.kernel_size, dtype)

@property
def spatial_ndim(self) -> int:
return 2


class Laplacian2D(Laplacian2DBase):
"""Apply Laplacian filter to a channel-first image.
Args:
kernel_size: size of the convolving kernel. Accepts int or tuple of two ints.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> layer = sk.image.Laplacian2D(kernel_size=(3, 5))
>>> print(layer(jnp.ones((1, 5, 5))))
[[[-9. -7. -5. -7. -9.]
[-6. -3. 0. -3. -6.]
[-6. -3. 0. -3. -6.]
[-6. -3. 0. -3. -6.]
[-9. -7. -5. -7. -9.]]]
Note:
The laplacian considers all the neighbors of a pixel.
"""

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
return x


class FFTLaplacian2D(Laplacian2DBase):
"""Apply Laplacian filter to a channel-first image using FFT.
Args:
kernel_size: size of the convolving kernel. Accepts int or tuple of two ints.
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> layer = sk.image.FFTLaplacian2D(kernel_size=(3, 5))
>>> print(layer(jnp.ones((1, 5, 5))))
[[[-9. -7. -5. -7. -9.]
[-6. -3. 0. -3. -6.]
[-6. -3. 0. -3. -6.]
[-6. -3. 0. -3. -6.]
[-9. -7. -5. -7. -9.]]]
Note:
The laplacian considers all the neighbors of a pixel.
"""

@ft.partial(validate_spatial_nd, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
return x


class Filter2D(sk.TreeClass):
"""Apply 2D filter for each channel
Expand Down
4 changes: 4 additions & 0 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
FFTBoxBlur2D,
FFTFilter2D,
FFTGaussianBlur2D,
FFTLaplacian2D,
FFTUnsharpMask2D,
Filter2D,
GaussianBlur2D,
Laplacian2D,
UnsharpMask2D,
)
from serket._src.image.geometric import (
Expand Down Expand Up @@ -63,9 +65,11 @@
"FFTBoxBlur2D",
"FFTFilter2D",
"FFTGaussianBlur2D",
"FFTLaplacian2D",
"FFTUnsharpMask2D",
"Filter2D",
"GaussianBlur2D",
"Laplacian2D",
"UnsharpMask2D",
# geometric
"HorizontalFlip2D",
Expand Down
18 changes: 18 additions & 0 deletions tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,21 @@ def test_unsharp_mask():
x + (x - guassian_x),
atol=1e-5,
)


def test_laplacian():
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 10, 10))

kernel = jnp.array(([[1.0, 1.0, 1.0], [1.0, -8.0, 1.0], [1.0, 1.0, 1.0]]))

npt.assert_allclose(
sk.image.Laplacian2D(3)(x),
sk.image.Filter2D(kernel)(x),
atol=1e-5,
)

npt.assert_allclose(
sk.image.FFTLaplacian2D(3)(x),
sk.image.Filter2D(kernel)(x),
atol=1e-5,
)

0 comments on commit 8704067

Please sign in to comment.