From b0e0b147ee057b89c8c146e679e20c2bd7f8bfa8 Mon Sep 17 00:00:00 2001
From: ASEM000
Date: Tue, 5 Sep 2023 02:07:12 +0900
Subject: [PATCH] functionalize filter

---
 serket/__init__.py       |   3 +-
 serket/image/filter.py   | 175 +++++++++++++++++++++------------
 serket/nn/convolution.py |   7 +-
 serket/utils.py          |   5 ++
 4 files changed, 101 insertions(+), 89 deletions(-)

diff --git a/serket/__init__.py b/serket/__init__.py
index 409918e..a3af86d 100644
--- a/serket/__init__.py
+++ b/serket/__init__.py
@@ -42,7 +42,7 @@
     unfreeze,
 )
 
-from . import nn
+from . import image, nn
 from .custom_transform import tree_eval, tree_state
 from .nn.activation import def_act_entry
 from .nn.initialization import def_init_entry
@@ -81,6 +81,7 @@
     "leafwise",
     # serket
     "nn",
+    "image",
     "tree_eval",
     "tree_state",
     "def_init_entry",
diff --git a/serket/image/filter.py b/serket/image/filter.py
index bbefefc..a4b76e7 100644
--- a/serket/image/filter.py
+++ b/serket/image/filter.py
@@ -18,15 +18,17 @@
 
 import jax
 import jax.numpy as jnp
-from jax import lax
+from typing_extensions import Annotated
 
 import serket as sk
-from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
+from serket.nn.convolution import fft_conv_general_dilated
 from serket.nn.initialization import DType
 from serket.utils import (
+    generate_conv_dim_numbers,
     maybe_lazy_call,
     maybe_lazy_init,
     positive_int_cb,
+    resolve_string_padding,
     validate_axis_shape,
     validate_spatial_ndim,
 )
@@ -47,6 +49,63 @@ def infer_in_features(_, x, *__, **___) -> int:
 image_updates = dict(in_features=infer_in_features)
 
 
+def filter_2d(
+    array: Annotated[jax.Array, "CHW"],
+    weight: Annotated[jax.Array, "OIHW"],
+) -> jax.Array:
+    """Depthwise 2D filtering wrapping ``jax.lax.conv_general_dilated``.
+
+    Args:
+        array: input array. shape is (in_features, *spatial).
+        weight: depthwise convolutional kernel. shape is (out_features, 1, *kernel).
+    """
+    assert array.ndim == 3
+
+    ones = (1,) * (array.ndim - 1)
+    x = jax.lax.conv_general_dilated(
+        lhs=jnp.expand_dims(array, 0),
+        rhs=weight,
+        window_strides=ones,
+        padding="SAME",
+        rhs_dilation=ones,
+        dimension_numbers=generate_conv_dim_numbers(array.ndim - 1),
+        feature_group_count=array.shape[0],  # in_features
+    )
+    return jax.lax.stop_gradient(jnp.squeeze(x, 0))
+
+
+def fft_filter_2d(
+    array: Annotated[jax.Array, "CHW"],
+    weight: Annotated[jax.Array, "OIHW"],
+) -> jax.Array:
+    """Depthwise 2D filtering wrapping ``fft_conv_general_dilated``.
+
+    Args:
+        array: input array. shape is (in_features, *spatial).
+        weight: depthwise convolutional kernel. shape is (out_features, 1, *kernel).
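+
+    Example:
+        >>> # hedged usage sketch, not part of the original patch; shapes follow Args
+        >>> import jax.numpy as jnp
+        >>> x = jnp.ones((1, 5, 5))           # (in_features, H, W)
+        >>> w = jnp.ones((1, 1, 3, 3)) / 9.0  # 3x3 box kernel, shape (O, 1, kH, kW)
+        >>> fft_filter_2d(x, w).shape
+        (1, 5, 5)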
+ """ + assert array.ndim == 3 + + ones = (1,) * (array.ndim - 1) + + padding = resolve_string_padding( + in_dim=array.shape[1:], + padding="SAME", + kernel_size=weight.shape[2:], + strides=ones, + ) + + x = fft_conv_general_dilated( + lhs=jnp.expand_dims(array, 0), + rhs=weight, + strides=ones, + padding=padding, + dilation=ones, + groups=array.shape[0], # in_features + ) + return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0)) + + class AvgBlur2D(sk.TreeClass): """Average blur 2D layer @@ -101,35 +160,21 @@ def __init__( *, dtype: DType = jnp.float32, ): + self.in_features = positive_int_cb(in_features) kernel_size = positive_int_cb(kernel_size) - weight = jnp.ones(kernel_size) - weight = weight / jnp.sum(weight) - weight = weight[:, None] - weight = jnp.repeat(weight[None, None], in_features, axis=0) - - self.conv1 = DepthwiseConv2D( - in_features=in_features, - kernel_size=(kernel_size, 1), - padding="same", - weight_init=lambda *_: weight, - bias_init=None, - dtype=dtype, - ) - - self.conv2 = DepthwiseConv2D( - in_features=in_features, - kernel_size=(1, kernel_size), - padding="same", - weight_init=lambda *_: jnp.moveaxis(weight, 2, 3), - bias_init=None, - dtype=dtype, - ) + kernel = jnp.ones(kernel_size) + kernel = kernel / jnp.sum(kernel) + kernel = kernel[:, None] + kernel = jnp.repeat(kernel[None, None], in_features, axis=0).astype(dtype) + self.kernel = kernel @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0) + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - return lax.stop_gradient(self.conv2(self.conv1(x))) + x = filter_2d(x, self.kernel) + x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + return x @property def spatial_ndim(self) -> int: @@ -192,39 +237,23 @@ def __init__( sigma: float = 1.0, dtype: DType = jnp.float32, ): - kernel_size = positive_int_cb(kernel_size) + self.in_features = positive_int_cb(in_features) + self.kernel_size = positive_int_cb(kernel_size) self.sigma = sigma x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size) - weight = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma)) - - weight = weight / jnp.sum(weight) - weight = weight[:, None] - - weight = jnp.repeat(weight[None, None], in_features, axis=0) - self.conv1 = DepthwiseFFTConv2D( - in_features=in_features, - kernel_size=(kernel_size, 1), - padding="same", - weight_init=lambda *_: weight, - bias_init=None, - dtype=dtype, - ) - - self.conv2 = DepthwiseFFTConv2D( - in_features=in_features, - kernel_size=(1, kernel_size), - padding="same", - weight_init=lambda *_: jnp.moveaxis(weight, 2, 3), - bias_init=None, - dtype=dtype, - ) + kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma)) + kernel = kernel / jnp.sum(kernel) + kernel = kernel[:, None].astype(dtype) + self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") - @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0) + @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) def __call__(self, x: jax.Array) -> jax.Array: - return lax.stop_gradient(self.conv1(self.conv2(x))) + x = filter_2d(x, self.kernel) + x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3)) + return x @property def 
 
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
-    @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
+    @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
     def __call__(self, x: jax.Array) -> jax.Array:
-        return lax.stop_gradient(self.conv2(self.conv1(x)))
+        x = filter_2d(x, self.kernel)
+        x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
+        return x
 
     @property
     def spatial_ndim(self) -> int:
@@ -192,39 +237,23 @@ def __init__(
         sigma: float = 1.0,
         dtype: DType = jnp.float32,
     ):
-        kernel_size = positive_int_cb(kernel_size)
+        self.in_features = positive_int_cb(in_features)
+        self.kernel_size = positive_int_cb(kernel_size)
         self.sigma = sigma
         x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
-        weight = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
-
-        weight = weight / jnp.sum(weight)
-        weight = weight[:, None]
-
-        weight = jnp.repeat(weight[None, None], in_features, axis=0)
-        self.conv1 = DepthwiseFFTConv2D(
-            in_features=in_features,
-            kernel_size=(kernel_size, 1),
-            padding="same",
-            weight_init=lambda *_: weight,
-            bias_init=None,
-            dtype=dtype,
-        )
-
-        self.conv2 = DepthwiseFFTConv2D(
-            in_features=in_features,
-            kernel_size=(1, kernel_size),
-            padding="same",
-            weight_init=lambda *_: jnp.moveaxis(weight, 2, 3),
-            bias_init=None,
-            dtype=dtype,
-        )
+        kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
+        kernel = kernel / jnp.sum(kernel)
+        kernel = kernel[:, None].astype(dtype)
+        self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0)
 
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
-    @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
+    @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
     def __call__(self, x: jax.Array) -> jax.Array:
-        return lax.stop_gradient(self.conv1(self.conv2(x)))
+        x = filter_2d(x, self.kernel)
+        x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
+        return x
 
     @property
     def spatial_ndim(self) -> int:
@@ -238,7 +267,7 @@ class Filter2D(sk.TreeClass):
 
     Args:
         in_features: number of input channels.
-        kernel: kernel array.
+        kernel: kernel array with shape (H, W).
         dtype: data type of the layer. Defaults to ``jnp.float32``.
 
     Example:
@@ -264,24 +293,15 @@ def __init__(
         if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
             raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
 
-        in_features = positive_int_cb(in_features)
-        weight = jnp.stack([kernel] * in_features, axis=0)
-        weight = weight[:, None]
-
-        self.conv = DepthwiseConv2D(
-            in_features=in_features,
-            kernel_size=kernel.shape,
-            padding="same",
-            weight_init=lambda *_: weight,
-            bias_init=None,
-            dtype=dtype,
-        )
+        self.in_features = positive_int_cb(in_features)
+        kernel = jnp.stack([kernel] * in_features, axis=0)
+        self.kernel = kernel[:, None].astype(dtype)
 
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
-    @ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
+    @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
     def __call__(self, x: jax.Array) -> jax.Array:
-        return lax.stop_gradient(self.conv(x))
+        return filter_2d(x, self.kernel)
 
     @property
     def spatial_ndim(self) -> int:
@@ -318,27 +338,18 @@ def __init__(
         *,
         dtype: DType = jnp.float32,
     ):
-        if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
+        if kernel.ndim != 2:
             raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
 
-        in_features = positive_int_cb(in_features)
-        weight = jnp.stack([kernel] * in_features, axis=0)
-        weight = weight[:, None]
-
-        self.conv = DepthwiseFFTConv2D(
-            in_features=in_features,
-            kernel_size=kernel.shape,
-            padding="same",
-            weight_init=lambda *_: weight,
-            bias_init=None,
-            dtype=dtype,
-        )
+        self.in_features = positive_int_cb(in_features)
+        kernel = jnp.stack([kernel] * in_features, axis=0)
+        self.kernel = kernel[:, None].astype(dtype)
 
     @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
    @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
-    @ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
+    @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
     def __call__(self, x: jax.Array) -> jax.Array:
-        return lax.stop_gradient(self.conv(x))
+        return fft_filter_2d(x, self.kernel)
 
     @property
     def spatial_ndim(self) -> int:
diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py
index 0a27909..72e8283 100644
--- a/serket/nn/convolution.py
+++ b/serket/nn/convolution.py
@@ -24,7 +24,6 @@
 import jax
 import jax.numpy as jnp
 import jax.random as jr
-from jax.lax import ConvDimensionNumbers
 from typing_extensions import Annotated
 
 import serket as sk
@@ -38,6 +37,7 @@
     calculate_transpose_padding,
     canonicalize,
     delayed_canonicalize_padding,
+    generate_conv_dim_numbers,
     maybe_lazy_call,
     maybe_lazy_init,
     positive_int_cb,
@@ -46,11 +46,6 @@
 )
 
 
-@ft.lru_cache(maxsize=None)
-def generate_conv_dim_numbers(spatial_ndim) -> ConvDimensionNumbers:
-    return ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3))
-
-
 @ft.partial(jax.jit, inline=True)
 def _ungrouped_matmul(x, y) -> jax.Array:
     alpha = "".join(map(str, range(max(x.ndim, y.ndim))))
diff --git a/serket/utils.py b/serket/utils.py
index fb2b00d..cf001e8 100644
--- a/serket/utils.py
+++ b/serket/utils.py
@@ -35,6 +35,11 @@
 T = TypeVar("T")
 
 
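+# Builds ConvDimensionNumbers where lhs, rhs, and output all use the default
+# (batch/output-feature, feature, *spatial) layout; lru_cache reuses one spec
+# per spatial rank instead of rebuilding it on every conv call.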
+@ft.lru_cache(maxsize=None)
+def generate_conv_dim_numbers(spatial_ndim) -> jax.lax.ConvDimensionNumbers:
+    return jax.lax.ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3))
+
+
 @ft.lru_cache(maxsize=128)
 def calculate_transpose_padding(
     padding,