Lazy image and linear (#26)
ASEM000 committed Aug 1, 2023
1 parent 43ad7fa commit 6b70844
Showing 3 changed files with 233 additions and 22 deletions.
30 changes: 20 additions & 10 deletions serket/nn/convolution.py
@@ -162,15 +162,7 @@ def infer_in_features(instance, x, *_, **__) -> int:
return x.shape[0]


-def infer_in_size(instance, x, *_, **__) -> tuple[int, ...]:
-    return x.shape[1:]
-
-
-def infer_key(instance, *_, **__) -> jr.KeyArray:
-    return instance.key
-
-
-conv_updates = {"key": infer_key, "in_features": infer_in_features}
+conv_updates = dict(in_features=infer_in_features)
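Each value in the updates map takes the instance plus the first call's arguments and returns the value for the like-named ``__init__`` argument. A quick check of that contract, using the functions defined above (``None`` stands in for the instance, which ``infer_in_features`` ignores):

import jax.numpy as jnp

x = jnp.ones((3, 8, 8))  # channels-first input: (in_features, *spatial)
# the inference function reads the channel axis; the instance is unused here
assert infer_in_features(None, x) == 3
assert "in_features" in conv_updates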


class BaseConvND(sk.TreeClass):
@@ -2128,6 +2120,19 @@ def __init__(
pointwise_bias_init: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
if in_features is None:
self.in_features = in_features
self.out_features = out_features
self.kernel_size = kernel_size
self.depth_multiplier = depth_multiplier
self.strides = strides
self.padding = padding
self.depthwise_weight_init = depthwise_weight_init
self.pointwise_weight_init = pointwise_weight_init
self.pointwise_bias_init = pointwise_bias_init
self.key = key
return

self.depthwise_conv = self._depthwise_convolution_layer(
in_features=in_features,
depth_multiplier=depth_multiplier,
@@ -2150,6 +2155,7 @@
key=key,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
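``maybe_lazy_call`` itself lives in ``serket.nn.utils`` and is not part of this diff; conceptually, it gates the wrapped method on ``is_lazy``. A minimal sketch of that shape, assuming a mutable toy class and ignoring serket's immutable ``.at[...]`` machinery, which the real decorator has to respect:

import functools as ft

def sketch_maybe_lazy_call(func, is_lazy, updates):
    # toy shape of the decorator: if the instance was built lazily, fill the
    # missing constructor arguments from the first input, re-run __init__ to
    # materialize the weights, then execute the wrapped method as usual.
    @ft.wraps(func)
    def wrapper(instance, *a, **k):
        if is_lazy(instance, *a, **k):
            inferred = {name: fn(instance, *a, **k) for name, fn in updates.items()}
            # works because the lazy __init__ branch stores exactly the
            # constructor arguments as attributes (mutable toy case only)
            instance.__init__(**{**vars(instance), **inferred})
        return func(instance, *a, **k)
    return wrapper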
@@ -2747,7 +2753,11 @@ def _depthwise_convolution_layer(self):
return DepthwiseFFTConv3D


-convlocal_updates = {**conv_updates, "in_size": infer_in_size}
+def infer_in_size(_, x, *__, **___) -> tuple[int, ...]:
+    return x.shape[1:]
+
+
+convlocal_updates = {**dict(in_size=infer_in_size), **conv_updates}
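The local-convolution variant additionally infers the spatial size; a small check of both entries (again with ``None`` standing in for the unused instance argument):

import jax.numpy as jnp

x = jnp.ones((3, 8, 8))
assert infer_in_size(None, x) == (8, 8)  # spatial dims, channel axis dropped
assert set(convlocal_updates) == {"in_size", "in_features"}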


class ConvNDLocal(sk.TreeClass):
93 changes: 91 additions & 2 deletions serket/nn/image.py
@@ -26,7 +26,23 @@
from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
-from serket.nn.utils import positive_int_cb, validate_axis_shape, validate_spatial_ndim
+from serket.nn.utils import (
+    maybe_lazy_call,
+    positive_int_cb,
+    validate_axis_shape,
+    validate_spatial_ndim,
+)


def is_lazy(instance, *_, **__) -> bool:
return getattr(instance, "in_features", False) is None


def infer_in_features(instance, x, *_, **__) -> int:
return x.shape[0]


image_updates = dict(in_features=infer_in_features)
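``is_lazy`` keys off the literal ``None`` stored by the lazy ``__init__`` branch; a tiny check of the predicate defined above, with throwaway stand-in objects:

class _Lazy:
    in_features = None

class _Built:
    in_features = 3

assert is_lazy(_Lazy())        # in_features is None -> still lazy
assert not is_lazy(_Built())   # materialized layer
assert not is_lazy(object())   # attribute missing -> getattr default False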


class AvgBlur2D(sk.TreeClass):
@@ -46,9 +62,38 @@ class AvgBlur2D(sk.TreeClass):
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]

Note:
:class:`.AvgBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
... l2: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

-def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
+def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
if in_features is None:
self.in_features = None
self.kernel_size = kernel_size
return

weight = jnp.ones(kernel_size)
weight = weight / jnp.sum(weight)
weight = weight[:, None]
@@ -70,6 +115,7 @@ def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
bias_init=None,
)
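The constructor normalizes a 1-D kernel and reshapes it to a column (``weight[:, None]``); the folded lines presumably apply the matching row kernel in a second depthwise pass. The separable trick is exact for an average filter, as a quick check shows (illustrative, not part of the diff):

import jax.numpy as jnp

k = 3
col = jnp.ones((k, 1)) / k          # first depthwise pass: vertical average
row = jnp.ones((1, k)) / k          # second pass: horizontal average
assert jnp.allclose(col @ row, jnp.ones((k, k)) / k**2)  # equals the full kxk mean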

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -98,9 +144,39 @@ class GaussianBlur2D(sk.TreeClass):
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]

Note:
:class:`.GaussianBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
... l2: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

def __init__(self, in_features: int | None, kernel_size: int, *, sigma: float = 1.0):
if in_features is None:
self.in_features = None
self.kernel_size = kernel_size
self.sigma = sigma
return

kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

@@ -127,6 +203,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
bias_init=None,
)
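The Gaussian weight construction between these hunks is folded out of view; for orientation, a separable Gaussian blur typically builds its 1-D kernel like the sketch below (an assumption about the elided code, not a quote of it):

import jax
import jax.numpy as jnp

def gaussian_kernel_1d(kernel_size: int, sigma: float) -> jax.Array:
    # sample points centered at zero, e.g. [-1., 0., 1.] for kernel_size=3
    t = jnp.arange(kernel_size) - (kernel_size - 1) / 2
    k = jnp.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()  # normalize so a constant image passes through unchanged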

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -157,6 +234,11 @@ class Filter2D(sk.TreeClass):
"""

def __init__(self, in_features: int | None, kernel: jax.Array):
if in_features is None:
self.in_features = None
self.kernel = kernel
return

if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

@@ -172,6 +254,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
bias_init=None,
)
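Assuming ``Filter2D`` follows the same lazy pattern documented for the blur layers above, usage would look roughly like this (a sketch, not an example from the diff):

import jax.numpy as jnp
import serket as sk

kernel = jnp.ones((3, 3)) / 9.0                      # a 3x3 box filter
lazy = sk.nn.Filter2D(None, kernel)                  # in_features deferred
_, layer = lazy.at["__call__"](jnp.ones((2, 5, 5)))  # materialized with in_features=2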

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -202,6 +285,11 @@ class FFTFilter2D(sk.TreeClass):
"""

def __init__(self, in_features: int | None, kernel: jax.Array):
if in_features is None:
self.in_features = None
self.kernel = kernel
return

if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

@@ -217,6 +305,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
bias_init=None,
)
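``FFTFilter2D`` presumably mirrors ``Filter2D`` but evaluates the convolution in the frequency domain, which pays off for large kernels. A 1-D sanity check of the direct/FFT equivalence (illustrative only):

import jax.numpy as jnp

x = jnp.arange(8.0)
k = jnp.array([1.0, 2.0, 1.0]) / 4.0   # symmetric smoothing kernel
n = x.size + k.size - 1                # length of the full linear convolution
full = jnp.fft.irfft(jnp.fft.rfft(x, n) * jnp.fft.rfft(k, n), n)
same = full[(k.size - 1) // 2 : (k.size - 1) // 2 + x.size]
assert jnp.allclose(same, jnp.convolve(x, k, mode="same"), atol=1e-4)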

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
