Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Lazy image and linear #26

Merged
merged 3 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,7 @@
return x.shape[0]


def infer_in_size(instance, x, *_, **__) -> tuple[int, ...]:
    """Return the spatial size of ``x``: its shape with the leading (channel) axis dropped."""
    _channels, *spatial = x.shape
    return tuple(spatial)


def infer_key(instance, *_, **__) -> jr.KeyArray:
    """Fetch the PRNG key stored on the layer instance; positional/keyword extras are ignored."""
    key = instance.key
    return key


conv_updates = {"key": infer_key, "in_features": infer_in_features}
conv_updates = dict(in_features=infer_in_features)


class BaseConvND(sk.TreeClass):
Expand Down Expand Up @@ -2128,6 +2120,19 @@
pointwise_bias_init: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
):
if in_features is None:
self.in_features = in_features
self.out_features = out_features
self.kernel_size = kernel_size
self.depth_multiplier = depth_multiplier
self.strides = strides
self.padding = padding
self.depthwise_weight_init = depthwise_weight_init
self.pointwise_weight_init = pointwise_weight_init
self.pointwise_bias_init = pointwise_bias_init
self.key = key
return

Check warning on line 2134 in serket/nn/convolution.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/convolution.py#L2124-L2134

Added lines #L2124 - L2134 were not covered by tests

self.depthwise_conv = self._depthwise_convolution_layer(
in_features=in_features,
depth_multiplier=depth_multiplier,
Expand All @@ -2150,6 +2155,7 @@
key=key,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
x = self.depthwise_conv(x)
x = self.pointwise_conv(x)
Expand Down Expand Up @@ -2747,7 +2753,11 @@
return DepthwiseFFTConv3D


convlocal_updates = {**conv_updates, "in_size": infer_in_size}
def infer_in_size(_, x, *__, **___) -> tuple[int, ...]:
    """Infer the spatial input size from ``x`` — everything after the channel axis."""
    full_shape = x.shape
    return full_shape[1:]

Check warning on line 2757 in serket/nn/convolution.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/convolution.py#L2757

Added line #L2757 was not covered by tests


convlocal_updates = {**dict(in_size=infer_in_size), **conv_updates}


class ConvNDLocal(sk.TreeClass):
Expand Down
93 changes: 91 additions & 2 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,23 @@
from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.nn.utils import positive_int_cb, validate_axis_shape, validate_spatial_ndim
from serket.nn.utils import (
maybe_lazy_call,
positive_int_cb,
validate_axis_shape,
validate_spatial_ndim,
)


def is_lazy(instance, *_, **__) -> bool:
    """Return ``True`` when the layer was constructed with ``in_features=None`` (lazy init)."""
    # A missing attribute falls back to False, which is not None -> treated as non-lazy.
    in_features = getattr(instance, "in_features", False)
    return in_features is None


def infer_in_features(instance, x, *_, **__) -> int:
    """Infer the channel count from ``x``'s leading axis; ``instance`` and extras are unused."""
    leading_axis, *_rest = x.shape
    return leading_axis

Check warning on line 42 in serket/nn/image.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/image.py#L42

Added line #L42 was not covered by tests


image_updates = dict(in_features=infer_in_features)


class AvgBlur2D(sk.TreeClass):
Expand All @@ -46,9 +62,38 @@
[0.6666667 1. 1. 1. 0.6666667 ]
[0.6666667 1. 1. 1. 0.6666667 ]
[0.44444448 0.6666667 0.6666667 0.6666667 0.44444448]]]

Note:
:class:`.AvgBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.

To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
... l2: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
if in_features is None:
self.in_features = None
self.kernel_size = kernel_size
return

Check warning on line 95 in serket/nn/image.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/image.py#L93-L95

Added lines #L93 - L95 were not covered by tests

weight = jnp.ones(kernel_size)
weight = weight / jnp.sum(weight)
weight = weight[:, None]
Expand All @@ -70,6 +115,7 @@
bias_init=None,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down Expand Up @@ -98,9 +144,39 @@
[0.7259314 1. 1. 1. 0.7259314]
[0.7259314 1. 1. 1. 0.7259314]
[0.5269764 0.7259314 0.7259314 0.7259314 0.5269764]]]

Note:
:class:`.GaussianBlur2D` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.

To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.

>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> import jax
>>> @sk.autoinit
... class Blur(sk.TreeClass):
... l1: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
... l2: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
... def __call__(self, x: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x)))
>>> # lazy initialization
>>> lazy_blur = Blur()
>>> # materialize the layer
>>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
"""

def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
if in_features is None:
self.in_features = None
self.kernel_size = kernel_size
self.sigma = sigma
return

Check warning on line 178 in serket/nn/image.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/image.py#L175-L178

Added lines #L175 - L178 were not covered by tests

kernel_size = positive_int_cb(kernel_size)
self.sigma = sigma

Expand All @@ -127,6 +203,7 @@
bias_init=None,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down Expand Up @@ -157,6 +234,11 @@
"""

def __init__(self, in_features: int, kernel: jax.Array):
if in_features is None:
self.in_features = None
self.kernel = kernel
return

Check warning on line 240 in serket/nn/image.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/image.py#L238-L240

Added lines #L238 - L240 were not covered by tests

if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

Expand All @@ -172,6 +254,7 @@
bias_init=None,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down Expand Up @@ -202,6 +285,11 @@
"""

def __init__(self, in_features: int, kernel: jax.Array):
if in_features is None:
self.in_features = None
self.kernel = kernel
return

Check warning on line 291 in serket/nn/image.py

View check run for this annotation

Codecov / codecov/patch

serket/nn/image.py#L289-L291

Added lines #L289 - L291 were not covered by tests

if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")

Expand All @@ -217,6 +305,7 @@
bias_init=None,
)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
@ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down
Loading
Loading