Skip to content

Commit

Permalink
fix filter
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 7, 2023
1 parent 72ce35c commit 8422fa1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@
"RandomContrast2D",
# filter
"AvgBlur2D",
"FFTAvgBlur2D",
"FFTFilter2D",
"FFTGaussianBlur2D",
"Filter2D",
"GaussianBlur2D",
"FFTAvgBlur2D",
"FFTGaussianBlur2D",
# geometric
"HorizontalFlip2D",
"HorizontalShear2D",
Expand Down
32 changes: 16 additions & 16 deletions serket/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ class AvgBlur2D(AvgBlur2DBase):
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3))
x = x[:, 0]
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
return x


Expand Down Expand Up @@ -154,8 +155,9 @@ class FFTAvgBlur2D(AvgBlur2DBase):
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3))
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
return x

Expand All @@ -170,8 +172,9 @@ def __init__(
):
self.kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma
x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
x = jnp.arange(kernel_size) - kernel_size // 2
x = x + 0.5 if kernel_size % 2 == 0 else x
kernel = jnp.exp(-(x**2) / (2 * sigma**2))
kernel = kernel / jnp.sum(kernel)
self.kernel = kernel[None, None, None].astype(dtype)

Expand Down Expand Up @@ -206,7 +209,8 @@ class GaussianBlur2D(GaussianBlur2DBase):
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3))
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
return x

Expand Down Expand Up @@ -236,8 +240,9 @@ class FFTGaussianBlur2D(GaussianBlur2DBase):
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, jnp.moveaxis(self.kernel, 2, 3))
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
return x

Expand Down Expand Up @@ -306,12 +311,7 @@ class FFTFilter2D(sk.TreeClass):
[4. 6.0000005 6.0000005 6.0000005 4. ]]]
"""

def __init__(
self,
kernel: jax.Array,
*,
dtype: DType = jnp.float32,
):
def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32):
if kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

Expand All @@ -320,7 +320,7 @@ def __init__(
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
x = jnp.squeeze(x, 1)
return x

Expand Down

0 comments on commit 8422fa1

Please sign in to comment.