diff --git a/serket/image/__init__.py b/serket/image/__init__.py index afb9f6f..664bd50 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -54,11 +54,11 @@ "RandomContrast2D", # filter "AvgBlur2D", + "FFTAvgBlur2D", "FFTFilter2D", + "FFTGaussianBlur2D", "Filter2D", "GaussianBlur2D", - "FFTAvgBlur2D", - "FFTGaussianBlur2D", # geometric "HorizontalFlip2D", "HorizontalShear2D", diff --git a/serket/image/filter.py b/serket/image/filter.py index d6b1579..cd725c3 100644 --- a/serket/image/filter.py +++ b/serket/image/filter.py @@ -125,8 +125,9 @@ class AvgBlur2D(AvgBlur2DBase): def __call__(self, x: jax.Array) -> jax.Array: x = jnp.expand_dims(x, 1) x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) - x = x[:, 0] + k = jnp.moveaxis(self.kernel, 2, 3) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, k) + x = jnp.squeeze(x, 1) return x @@ -154,8 +155,9 @@ class FFTAvgBlur2D(AvgBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) + k = jnp.moveaxis(self.kernel, 2, 3) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k) x = jnp.squeeze(x, 1) return x @@ -170,8 +172,9 @@ def __init__( ): self.kernel_size = positive_int_cb(kernel_size) self.sigma = sigma - x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size) - kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma)) + x = jnp.arange(kernel_size) - kernel_size // 2 + x = x + 0.5 if kernel_size % 2 == 0 else x + kernel = jnp.exp(-(x**2) / (2 * sigma**2)) kernel = kernel / jnp.sum(kernel) self.kernel = kernel[None, None, None].astype(dtype) @@ -206,7 +209,8 @@ class GaussianBlur2D(GaussianBlur2DBase): def __call__(self, x: jax.Array) -> jax.Array: x = jnp.expand_dims(x, 1) x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + k = jnp.moveaxis(self.kernel, 2, 3) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, k) x = jnp.squeeze(x, 1) return x @@ -236,8 +240,9 @@ class FFTGaussianBlur2D(GaussianBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3)) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) + k = jnp.moveaxis(self.kernel, 2, 3) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k) x = jnp.squeeze(x, 1) return x @@ -306,12 +311,7 @@ class FFTFilter2D(sk.TreeClass): [4. 6.0000005 6.0000005 6.0000005 4. ]]] """ - def __init__( - self, - kernel: jax.Array, - *, - dtype: DType = jnp.float32, - ): + def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32): if kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") @@ -320,7 +320,7 @@ def __init__( @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) x = jnp.squeeze(x, 1) return x