
functionalize filter
ASEM000 committed Sep 4, 2023
1 parent 9c15c6e commit b0e0b14
Showing 4 changed files with 101 additions and 89 deletions.
3 changes: 2 additions & 1 deletion serket/__init__.py
@@ -42,7 +42,7 @@
unfreeze,
)

from . import nn
from . import image, nn
from .custom_transform import tree_eval, tree_state
from .nn.activation import def_act_entry
from .nn.initialization import def_init_entry
@@ -81,6 +81,7 @@
"leafwise",
# serket
"nn",
"image",
"tree_eval",
"tree_state",
"def_init_entry",
175 changes: 93 additions & 82 deletions serket/image/filter.py
@@ -18,15 +18,17 @@

import jax
import jax.numpy as jnp
from jax import lax
from typing_extensions import Annotated

import serket as sk
from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
from serket.nn.convolution import fft_conv_general_dilated
from serket.nn.initialization import DType
from serket.utils import (
generate_conv_dim_numbers,
maybe_lazy_call,
maybe_lazy_init,
positive_int_cb,
resolve_string_padding,
validate_axis_shape,
validate_spatial_ndim,
)
@@ -47,6 +49,63 @@ def infer_in_features(_, x, *__, **___) -> int:
image_updates = dict(in_features=infer_in_features)


def filter_2d(
    array: Annotated[jax.Array, "CHW"],
    weight: Annotated[jax.Array, "OIHW"],
) -> jax.Array:
    """Filtering wrapping ``jax.lax.conv_general_dilated``.

    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
    """
    assert array.ndim == 3

    ones = (1,) * (array.ndim - 1)
    x = jax.lax.conv_general_dilated(
        lhs=jnp.expand_dims(array, 0),
        rhs=weight,
        window_strides=ones,
        padding="SAME",
        rhs_dilation=ones,
        dimension_numbers=generate_conv_dim_numbers(array.ndim - 1),
        feature_group_count=array.shape[0],  # in_features
    )
    return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0))

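A minimal usage sketch of the new functional path (the image shape and the 3x3 box kernel below are illustrative assumptions, not taken from the commit):

    import jax.numpy as jnp

    image = jnp.ones((3, 8, 8))            # (in_features, H, W)
    kernel = jnp.ones((3, 1, 3, 3)) / 9.0  # (out, in // groups, kH, kW): one box kernel per channel
    blurred = filter_2d(image, kernel)     # "SAME" padding keeps the (3, 8, 8) shape

Since ``feature_group_count`` is set to the channel count, each channel is convolved only with its own kernel slice, i.e. a depthwise filter.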

def fft_filter_2d(
    array: Annotated[jax.Array, "CHW"],
    weight: Annotated[jax.Array, "OIHW"],
) -> jax.Array:
    """Filtering wrapping ``fft_conv_general_dilated``.

    Args:
        array: input array. shape is (in_features, *spatial).
        weight: convolutional kernel. shape is (out_features, in_features, *kernel).
    """
    assert array.ndim == 3

    ones = (1,) * (array.ndim - 1)

    padding = resolve_string_padding(
        in_dim=array.shape[1:],
        padding="SAME",
        kernel_size=weight.shape[2:],
        strides=ones,
    )

    x = fft_conv_general_dilated(
        lhs=jnp.expand_dims(array, 0),
        rhs=weight,
        strides=ones,
        padding=padding,
        dilation=ones,
        groups=array.shape[0],  # in_features
    )
    return jax.lax.stop_gradient_p.bind(jnp.squeeze(x, 0))

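Both helpers are meant to compute the same depthwise "SAME"-padded filtering, one through ``jax.lax.conv_general_dilated`` and one through the FFT-based convolution; a quick cross-check could look like this (the random shapes and the tolerance are illustrative assumptions):

    import jax
    import jax.numpy as jnp

    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    image = jax.random.normal(k1, (2, 16, 16))    # (in_features, H, W)
    kernel = jax.random.normal(k2, (2, 1, 5, 5))  # one 5x5 kernel per channel
    assert jnp.allclose(filter_2d(image, kernel), fft_filter_2d(image, kernel), atol=1e-4)

The FFT path generally pays off for large kernels, where direct spatial convolution becomes the bottleneck.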

class AvgBlur2D(sk.TreeClass):
"""Average blur 2D layer
@@ -101,35 +160,21 @@ def __init__(
*,
dtype: DType = jnp.float32,
):
self.in_features = positive_int_cb(in_features)
kernel_size = positive_int_cb(kernel_size)
weight = jnp.ones(kernel_size)
weight = weight / jnp.sum(weight)
weight = weight[:, None]
weight = jnp.repeat(weight[None, None], in_features, axis=0)

self.conv1 = DepthwiseConv2D(
in_features=in_features,
kernel_size=(kernel_size, 1),
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

self.conv2 = DepthwiseConv2D(
in_features=in_features,
kernel_size=(1, kernel_size),
padding="same",
weight_init=lambda *_: jnp.moveaxis(weight, 2, 3),
bias_init=None,
dtype=dtype,
)
kernel = jnp.ones(kernel_size)
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None]
kernel = jnp.repeat(kernel[None, None], in_features, axis=0).astype(dtype)
self.kernel = kernel

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
return lax.stop_gradient(self.conv2(self.conv1(x)))
x = filter_2d(x, self.kernel)
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x
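The two ``filter_2d`` calls above apply the average blur separably: a (k, 1) pass followed by a (1, k) pass equals one full k x k box kernel while costing O(k) rather than O(k^2) per pixel. A small sketch of that equivalence (single channel and 3x3 size chosen for illustration):

    import jax.numpy as jnp

    image = jnp.ones((1, 8, 8))
    k1d = jnp.ones((1, 1, 3, 1)) / 3.0  # vertical 1D box kernel
    separable = filter_2d(filter_2d(image, k1d), jnp.moveaxis(k1d, 2, 3))
    full = filter_2d(image, jnp.ones((1, 1, 3, 3)) / 9.0)
    assert jnp.allclose(separable, full, atol=1e-6)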

@property
def spatial_ndim(self) -> int:
@@ -192,39 +237,23 @@ def __init__(
sigma: float = 1.0,
dtype: DType = jnp.float32,
):
kernel_size = positive_int_cb(kernel_size)
self.in_features = positive_int_cb(in_features)
self.kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
weight = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))

weight = weight / jnp.sum(weight)
weight = weight[:, None]

weight = jnp.repeat(weight[None, None], in_features, axis=0)
self.conv1 = DepthwiseFFTConv2D(
in_features=in_features,
kernel_size=(kernel_size, 1),
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)

self.conv2 = DepthwiseFFTConv2D(
in_features=in_features,
kernel_size=(1, kernel_size),
padding="same",
weight_init=lambda *_: jnp.moveaxis(weight, 2, 3),
bias_init=None,
dtype=dtype,
)
kernel = jnp.exp(-0.5 * jnp.square(x) * jax.lax.rsqrt(self.sigma))
kernel = kernel / jnp.sum(kernel)
kernel = kernel[:, None].astype(dtype)
self.kernel = jnp.repeat(kernel[None, None], in_features, axis=0)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
return lax.stop_gradient(self.conv1(self.conv2(x)))
x = filter_2d(x, self.kernel)
x = filter_2d(x, jnp.moveaxis(self.kernel, 2, 3))
return x

@property
def spatial_ndim(self) -> int:
@@ -238,7 +267,7 @@ class Filter2D(sk.TreeClass):
Args:
in_features: number of input channels.
kernel: kernel array.
kernel: kernel array with shape (H, W).
dtype: data type of the layer. Defaults to ``jnp.float32``.
Example:
@@ -264,24 +293,15 @@ def __init__(
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

in_features = positive_int_cb(in_features)
weight = jnp.stack([kernel] * in_features, axis=0)
weight = weight[:, None]

self.conv = DepthwiseConv2D(
in_features=in_features,
kernel_size=kernel.shape,
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)
self.in_features = positive_int_cb(in_features)
kernel = jnp.stack([kernel] * in_features, axis=0)
self.kernel = kernel[:, None].astype(dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
return lax.stop_gradient(self.conv(x))
return filter_2d(x, self.kernel)
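A hypothetical construction of the refactored layer, using an illustrative 3x3 sharpening kernel rather than the one from the hidden docstring example:

    import jax.numpy as jnp

    sharpen = jnp.array([[ 0.0, -1.0,  0.0],
                         [-1.0,  5.0, -1.0],
                         [ 0.0, -1.0,  0.0]])
    layer = Filter2D(in_features=1, kernel=sharpen)
    out = layer(jnp.ones((1, 8, 8)))  # the 2D kernel is repeated across all in_features channels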

@property
def spatial_ndim(self) -> int:
@@ -318,27 +338,18 @@ def __init__(
*,
dtype: DType = jnp.float32,
):
if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
if kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

in_features = positive_int_cb(in_features)
weight = jnp.stack([kernel] * in_features, axis=0)
weight = weight[:, None]

self.conv = DepthwiseFFTConv2D(
in_features=in_features,
kernel_size=kernel.shape,
padding="same",
weight_init=lambda *_: weight,
bias_init=None,
dtype=dtype,
)
self.in_features = positive_int_cb(in_features)
kernel = jnp.stack([kernel] * in_features, axis=0)
self.kernel = kernel[:, None].astype(dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
@ft.partial(validate_axis_shape, attribute_name="in_features", axis=0)
def __call__(self, x: jax.Array) -> jax.Array:
return lax.stop_gradient(self.conv(x))
return fft_filter_2d(x, self.kernel)

@property
def spatial_ndim(self) -> int:
7 changes: 1 addition & 6 deletions serket/nn/convolution.py
@@ -24,7 +24,6 @@
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.lax import ConvDimensionNumbers
from typing_extensions import Annotated

import serket as sk
@@ -38,6 +37,7 @@
calculate_transpose_padding,
canonicalize,
delayed_canonicalize_padding,
generate_conv_dim_numbers,
maybe_lazy_call,
maybe_lazy_init,
positive_int_cb,
@@ -46,11 +46,6 @@
)


@ft.lru_cache(maxsize=None)
def generate_conv_dim_numbers(spatial_ndim) -> ConvDimensionNumbers:
    return ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3))


@ft.partial(jax.jit, inline=True)
def _ungrouped_matmul(x, y) -> jax.Array:
alpha = "".join(map(str, range(max(x.ndim, y.ndim))))
5 changes: 5 additions & 0 deletions serket/utils.py
@@ -35,6 +35,11 @@
T = TypeVar("T")


@ft.lru_cache(maxsize=None)
def generate_conv_dim_numbers(spatial_ndim) -> jax.lax.ConvDimensionNumbers:
    return jax.lax.ConvDimensionNumbers(*((tuple(range(spatial_ndim + 2)),) * 3))
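For reference, a small sketch of what this cached helper evaluates to in the 2D case (a reading of the code, not text from the commit):

    dn = generate_conv_dim_numbers(2)
    # ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))
    # i.e. NCHW input, OIHW kernel, NCHW output -- the layout filter_2d in serket/image/filter.py relies on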


@ft.lru_cache(maxsize=128)
def calculate_transpose_padding(
padding,
