diff --git a/serket/image/filter.py b/serket/image/filter.py index cd725c3..a336652 100644 --- a/serket/image/filter.py +++ b/serket/image/filter.py @@ -32,68 +32,82 @@ def filter_2d( - array: Annotated[jax.Array, "CHW"], - weight: Annotated[jax.Array, "OIHW"], -) -> jax.Array: + array: Annotated[jax.Array, "HW"], + weight: Annotated[jax.Array, "HW"], +) -> Annotated[jax.Array, "HW"]: """Filtering wrapping ``jax.lax.conv_general_dilated``. Args: - array: input array. shape is (in_features, *spatial). - weight: convolutional kernel. shape is (out_features, in_features, *kernel). + array: 2D input array. shape is (row, col). + weight: convolutional kernel. shape is (row, col). """ - assert array.ndim == 3 + assert array.ndim == 2 + assert weight.ndim == 2 + + array = jnp.expand_dims(array, 0) + weight = jnp.expand_dims(weight, (0, 1)) - ones = (1,) * (array.ndim - 1) x = jax.lax.conv_general_dilated( lhs=jnp.expand_dims(array, 0), rhs=weight, - window_strides=ones, + window_strides=(1, 1), padding="SAME", - rhs_dilation=ones, - dimension_numbers=generate_conv_dim_numbers(array.ndim - 1), + rhs_dilation=(1, 1), + dimension_numbers=generate_conv_dim_numbers(2), feature_group_count=array.shape[0], # in_features ) - return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0)) + return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1))) def fft_filter_2d( - array: Annotated[jax.Array, "CHW"], - weight: Annotated[jax.Array, "OIHW"], -) -> jax.Array: + array: Annotated[jax.Array, "HW"], + weight: Annotated[jax.Array, "HW"], +) -> Annotated[jax.Array, "HW"]: """Filtering wrapping ``serket`` ``fft_conv_general_dilated`` Args: - array: input array. shape is (in_features, *spatial). - weight: convolutional kernel. shape is (out_features, in_features, *kernel). + array: 2D input array. shape is (row, col). + weight: convolutional kernel. shape is (row, col). """ - assert array.ndim == 3 + assert array.ndim == 2 + assert weight.ndim == 2 - ones = (1,) * (array.ndim - 1) + array = jnp.expand_dims(array, 0) + weight = jnp.expand_dims(weight, (0, 1)) padding = resolve_string_padding( in_dim=array.shape[1:], padding="SAME", kernel_size=weight.shape[2:], - strides=ones, + strides=(1, 1), ) x = fft_conv_general_dilated( lhs=jnp.expand_dims(array, 0), rhs=weight, - strides=ones, + strides=(1, 1), padding=padding, - dilation=ones, + dilation=(1, 1), groups=array.shape[0], # in_features ) - return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0)) + return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, (0, 1))) + + +def calculate_average_kernel( + kernel_size: int, + dtype: DType, +) -> Annotated[jax.Array, "HW"]: + kernel = jnp.ones((kernel_size)) + kernel = kernel / jnp.sum(kernel) + kernel = kernel.astype(dtype) + kernel = jnp.expand_dims(kernel, 0) + return kernel class AvgBlur2DBase(sk.TreeClass): def __init__(self, kernel_size: int, *, dtype: DType = jnp.float32): kernel_size = positive_int_cb(kernel_size) - kernel = jnp.ones(kernel_size) - kernel = kernel / jnp.sum(kernel) - self.kernel = kernel[None, None, None].astype(dtype) + self.kernel = calculate_average_kernel(kernel_size, dtype) @property def spatial_ndim(self) -> int: @@ -123,11 +137,9 @@ class AvgBlur2D(AvgBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - k = jnp.moveaxis(self.kernel, 2, 3) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, k) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel.T) return x @@ -154,14 +166,26 @@ class FFTAvgBlur2D(AvgBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) - k = jnp.moveaxis(self.kernel, 2, 3) - x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel.T) return x +def calculate_gaussian_kernel( + kernel_size: int, + sigma: float, + dtype: DType, +) -> Annotated[jax.Array, "HW"]: + x = jnp.arange(kernel_size) - kernel_size // 2 + x = x + 0.5 if kernel_size % 2 == 0 else x + kernel = jnp.exp(-(x**2) / (2 * sigma**2)) + kernel = kernel / jnp.sum(kernel) + kernel = kernel.astype(dtype) + kernel = jnp.expand_dims(kernel, (0)) + return kernel + + class GaussianBlur2DBase(sk.TreeClass): def __init__( self, @@ -172,11 +196,7 @@ def __init__( ): self.kernel_size = positive_int_cb(kernel_size) self.sigma = sigma - x = jnp.arange(kernel_size) - kernel_size // 2 - x = x + 0.5 if kernel_size % 2 == 0 else x - kernel = jnp.exp(-(x**2) / (2 * sigma**2)) - kernel = kernel / jnp.sum(kernel) - self.kernel = kernel[None, None, None].astype(dtype) + self.kernel = calculate_gaussian_kernel(self.kernel_size, sigma, dtype) @property def spatial_ndim(self) -> int: @@ -207,11 +227,9 @@ class GaussianBlur2D(GaussianBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - k = jnp.moveaxis(self.kernel, 2, 3) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, k) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel.T) return x @@ -239,11 +257,9 @@ class FFTGaussianBlur2D(GaussianBlur2DBase): @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) - k = jnp.moveaxis(self.kernel, 2, 3) - x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, k) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel.T) return x @@ -276,13 +292,12 @@ def __init__( ): if not isinstance(kernel, jax.Array) or kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") - self.kernel = kernel[None, None].astype(dtype) + self.kernel = kernel.astype(dtype) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(filter_2d, in_axes=(0, None))(x, self.kernel) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(filter_2d, in_axes=(0, None))(x, kernel) return x @property @@ -315,13 +330,12 @@ def __init__(self, kernel: jax.Array, *, dtype: DType = jnp.float32): if kernel.ndim != 2: raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)") - self.kernel = kernel[None, None].astype(dtype) + self.kernel = kernel.astype(dtype) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array) -> jax.Array: - x = jnp.expand_dims(x, 1) - x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, self.kernel) - x = jnp.squeeze(x, 1) + kernel = jax.lax.stop_gradient_p.bind(self.kernel) + x = jax.vmap(fft_filter_2d, in_axes=(0, None))(x, kernel) return x @property