Skip to content

Commit

Permalink
Update filter.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 8, 2023
1 parent f43afeb commit 7c7b4ec
Showing 1 changed file with 72 additions and 58 deletions.
130 changes: 72 additions & 58 deletions serket/image/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,68 +32,82 @@


def filter_2d(
    array: Annotated[jax.Array, "HW"],
    weight: Annotated[jax.Array, "HW"],
) -> Annotated[jax.Array, "HW"]:
    """Filter a 2D array with a 2D kernel via ``jax.lax.conv_general_dilated``.

    Args:
        array: 2D input array. shape is (row, col).
        weight: convolutional kernel. shape is (row, col).

    Returns:
        Filtered 2D array with the same spatial shape (``SAME`` padding).
    """
    assert array.ndim == 2
    assert weight.ndim == 2

    # Lift both operands to the 4D layout the conv primitive expects:
    # array (H, W) -> (1, H, W) -> lhs (1, 1, H, W); weight (H, W) -> (1, 1, H, W).
    array = jnp.expand_dims(array, 0)
    weight = jnp.expand_dims(weight, (0, 1))

    x = jax.lax.conv_general_dilated(
        lhs=jnp.expand_dims(array, 0),
        rhs=weight,
        window_strides=(1, 1),
        padding="SAME",
        rhs_dilation=(1, 1),
        dimension_numbers=generate_conv_dim_numbers(2),
        feature_group_count=array.shape[0],  # in_features (1 after expand)
    )
    # Drop the batch and channel axes; stop gradients on the result.
    return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1)))


def fft_filter_2d(
    array: Annotated[jax.Array, "HW"],
    weight: Annotated[jax.Array, "HW"],
) -> Annotated[jax.Array, "HW"]:
    """Filter a 2D array with a 2D kernel via ``serket`` ``fft_conv_general_dilated``.

    Args:
        array: 2D input array. shape is (row, col).
        weight: convolutional kernel. shape is (row, col).

    Returns:
        Filtered 2D array with the same spatial shape (``SAME`` padding).
    """
    assert array.ndim == 2
    assert weight.ndim == 2

    # Lift both operands to the 4D layout the FFT conv expects:
    # array (H, W) -> (1, H, W) -> lhs (1, 1, H, W); weight (H, W) -> (1, 1, H, W).
    array = jnp.expand_dims(array, 0)
    weight = jnp.expand_dims(weight, (0, 1))

    # Resolve "SAME" into explicit per-axis padding for the FFT-based conv.
    padding = resolve_string_padding(
        in_dim=array.shape[1:],
        padding="SAME",
        kernel_size=weight.shape[2:],
        strides=(1, 1),
    )

    x = fft_conv_general_dilated(
        lhs=jnp.expand_dims(array, 0),
        rhs=weight,
        strides=(1, 1),
        padding=padding,
        dilation=(1, 1),
        groups=array.shape[0],  # in_features (1 after expand)
    )
    # Drop the batch and channel axes; stop gradients on the result.
    return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1)))


def calculate_average_kernel(
    kernel_size: int,
    dtype: DType,
) -> Annotated[jax.Array, "HW"]:
    """Return a normalized 1D box (average) kernel of shape ``(1, kernel_size)``.

    Every tap equals ``1 / kernel_size`` so the kernel sums to one; the leading
    singleton axis lets it be used directly as a row kernel (transpose for columns).

    Args:
        kernel_size: number of taps in the kernel.
        dtype: dtype of the returned kernel.
    """
    kernel = jnp.ones(kernel_size)
    kernel = kernel / jnp.sum(kernel)
    kernel = kernel.astype(dtype)
    return jnp.expand_dims(kernel, 0)


class AvgBlur2DBase(sk.TreeClass):
def __init__(self, kernel_size: int, *, dtype: DType = jnp.float32):
kernel_size = positive_int_cb(kernel_size)
kernel = jnp.ones(kernel_size)
kernel = kernel / jnp.sum(kernel)
self.kernel = kernel[None, None, None].astype(dtype)
self.kernel = calculate_average_kernel(kernel_size, dtype)

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -123,11 +137,9 @@ class AvgBlur2D(AvgBlur2DBase):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel.T)
return x


Expand All @@ -154,14 +166,26 @@ class FFTAvgBlur2D(AvgBlur2DBase):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel.T)
return x


def calculate_gaussian_kernel(
    kernel_size: int,
    sigma: float,
    dtype: DType,
) -> Annotated[jax.Array, "HW"]:
    """Return a normalized 1D Gaussian kernel of shape ``(1, kernel_size)``.

    Args:
        kernel_size: number of taps. Even sizes are shifted by half a step so
            the kernel stays centered.
        sigma: standard deviation of the Gaussian.
        dtype: dtype of the returned kernel.
    """
    x = jnp.arange(kernel_size) - kernel_size // 2
    # Even-sized kernels have no center tap; shift by 0.5 to stay symmetric.
    x = x + 0.5 if kernel_size % 2 == 0 else x
    kernel = jnp.exp(-(x**2) / (2 * sigma**2))
    kernel = kernel / jnp.sum(kernel)
    kernel = kernel.astype(dtype)
    return jnp.expand_dims(kernel, 0)


class GaussianBlur2DBase(sk.TreeClass):
def __init__(
self,
Expand All @@ -172,11 +196,7 @@ def __init__(
):
self.kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma
x = jnp.arange(kernel_size) - kernel_size // 2
x = x + 0.5 if kernel_size % 2 == 0 else x
kernel = jnp.exp(-(x**2) / (2 * sigma**2))
kernel = kernel / jnp.sum(kernel)
self.kernel = kernel[None, None, None].astype(dtype)
self.kernel = calculate_gaussian_kernel(self.kernel_size, sigma, dtype)

@property
def spatial_ndim(self) -> int:
Expand Down Expand Up @@ -207,11 +227,9 @@ class GaussianBlur2D(GaussianBlur2DBase):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel.T)
return x


Expand Down Expand Up @@ -239,11 +257,9 @@ class FFTGaussianBlur2D(GaussianBlur2DBase):

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
k = jnp.moveaxis(self.kernel, 2, 3)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel.T)
return x


Expand Down Expand Up @@ -276,13 +292,12 @@ def __init__(
):
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
self.kernel = kernel[None, None].astype(dtype)
self.kernel = kernel.astype(dtype)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel)
return x

@property
Expand Down Expand Up @@ -315,13 +330,12 @@ def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32):
if kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

self.kernel = kernel[None, None].astype(dtype)
self.kernel = kernel.astype(dtype)

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array) -> jax.Array:
x = jnp.expand_dims(x, 1)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel)
x = jnp.squeeze(x, 1)
kernel = jax.lax.stop_gradient_p.bind(self.kernel)
x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel)
return x

@property
Expand Down

0 comments on commit 7c7b4ec

Please sign in to comment.