diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py
index 2d09de9..9421534 100644
--- a/serket/nn/convolution.py
+++ b/serket/nn/convolution.py
@@ -162,15 +162,7 @@ def infer_in_features(instance, x, *_, **__) -> int:
     return x.shape[0]
 
 
-def infer_in_size(instance, x, *_, **__) -> tuple[int, ...]:
-    return x.shape[1:]
-
-
-def infer_key(instance, *_, **__) -> jr.KeyArray:
-    return instance.key
-
-
-conv_updates = {"key": infer_key, "in_features": infer_in_features}
+conv_updates = dict(in_features=infer_in_features)
 
 
 class BaseConvND(sk.TreeClass):
@@ -2128,6 +2120,19 @@ def __init__(
         pointwise_bias_init: InitType = "zeros",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
+        if in_features is None:
+            self.in_features = in_features
+            self.out_features = out_features
+            self.kernel_size = kernel_size
+            self.depth_multiplier = depth_multiplier
+            self.strides = strides
+            self.padding = padding
+            self.depthwise_weight_init = depthwise_weight_init
+            self.pointwise_weight_init = pointwise_weight_init
+            self.pointwise_bias_init = pointwise_bias_init
+            self.key = key
+            return
+
         self.depthwise_conv = self._depthwise_convolution_layer(
             in_features=in_features,
             depth_multiplier=depth_multiplier,
@@ -2150,6 +2155,7 @@ def __init__(
             key=key,
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=conv_updates)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
         x = self.depthwise_conv(x)
         x = self.pointwise_conv(x)
@@ -2747,7 +2753,11 @@ def _depthwise_convolution_layer(self):
         return DepthwiseFFTConv3D
 
 
-convlocal_updates = {**conv_updates, "in_size": infer_in_size}
+def infer_in_size(_, x, *__, **___) -> tuple[int, ...]:
+    return x.shape[1:]
+
+
+convlocal_updates = {**dict(in_size=infer_in_size), **conv_updates}
 
 
 class ConvNDLocal(sk.TreeClass):
diff --git a/serket/nn/image.py b/serket/nn/image.py
index 89ebdf3..076548c 100644
--- a/serket/nn/image.py
+++ b/serket/nn/image.py
@@ -26,7 +26,23 @@
 from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
 from serket.nn.custom_transform import tree_eval
 from serket.nn.linear import Identity
-from serket.nn.utils import positive_int_cb, validate_axis_shape, validate_spatial_ndim
+from serket.nn.utils import (
+    maybe_lazy_call,
+    positive_int_cb,
+    validate_axis_shape,
+    validate_spatial_ndim,
+)
+
+
+def is_lazy(instance, *_, **__) -> bool:
+    return getattr(instance, "in_features", False) is None
+
+
+def infer_in_features(instance, x, *_, **__) -> int:
+    return x.shape[0]
+
+
+image_updates = dict(in_features=infer_in_features)
 
 
 class AvgBlur2D(sk.TreeClass):
@@ -46,9 +62,38 @@ class AvgBlur2D(sk.TreeClass):
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.6666667  1.         1.         1.         0.6666667 ]
          [0.44444448 0.6666667  0.6666667  0.6666667  0.44444448]]]
+
+    Note:
+        :class:`.AvgBlur2D` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> @sk.autoinit
+        ... class Blur(sk.TreeClass):
+        ...     l1: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
+        ...     l2: sk.nn.AvgBlur2D = sk.nn.AvgBlur2D(None, 3)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> # lazy initialization
+        >>> lazy_blur = Blur()
+        >>> # materialize the layer
+        >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
     """
 
-    def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
+    def __init__(self, in_features: int | None, kernel_size: int | tuple[int, int]):
+        if in_features is None:
+            self.in_features = None
+            self.kernel_size = kernel_size
+            return
+
         weight = jnp.ones(kernel_size)
         weight = weight / jnp.sum(weight)
         weight = weight[:, None]
@@ -70,6 +115,7 @@ def __init__(self, in_features: int, kernel_size: int | tuple[int, int]):
             bias_init=None,
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -98,9 +144,39 @@ class GaussianBlur2D(sk.TreeClass):
          [0.7259314  1.         1.         1.         0.7259314]
          [0.7259314  1.         1.         1.         0.7259314]
          [0.5269764  0.7259314  0.7259314  0.7259314  0.5269764]]]
+
+    Note:
+        :class:`.GaussianBlur2D` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> @sk.autoinit
+        ... class Blur(sk.TreeClass):
+        ...     l1: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
+        ...     l2: sk.nn.GaussianBlur2D = sk.nn.GaussianBlur2D(None, 3)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> # lazy initialization
+        >>> lazy_blur = Blur()
+        >>> # materialize the layer
+        >>> _, materialized_blur = lazy_blur.at["__call__"](jnp.ones((5, 2, 2)))
     """
 
     def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
+        if in_features is None:
+            self.in_features = None
+            self.kernel_size = kernel_size
+            self.sigma = sigma
+            return
+
         kernel_size = positive_int_cb(kernel_size)
         self.sigma = sigma
 
@@ -127,6 +203,7 @@ def __init__(self, in_features: int, kernel_size: int, *, sigma: float = 1.0):
             bias_init=None,
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="conv1.in_features", axis=0)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -157,6 +234,11 @@ class Filter2D(sk.TreeClass):
     """
 
     def __init__(self, in_features: int, kernel: jax.Array):
+        if in_features is None:
+            self.in_features = None
+            self.kernel = kernel
+            return
+
         if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
             raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
 
@@ -172,6 +254,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
             bias_init=None,
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
@@ -202,6 +285,11 @@ class FFTFilter2D(sk.TreeClass):
     """
 
     def __init__(self, in_features: int, kernel: jax.Array):
+        if in_features is None:
+            self.in_features = None
+            self.kernel = kernel
+            return
+
         if not isinstance(kernel, jax.Array) or kernel.ndim != 2:
             raise ValueError("Expected `kernel` to be a 2D `ndarray` with shape (H, W)")
 
@@ -217,6 +305,7 @@ def __init__(self, in_features: int, kernel: jax.Array):
             bias_init=None,
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=image_updates)
     @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
     @ft.partial(validate_axis_shape, attribute_name="conv.in_features", axis=0)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
diff --git a/serket/nn/linear.py b/serket/nn/linear.py
index 62d239c..c4f9cc5 100644
--- a/serket/nn/linear.py
+++ b/serket/nn/linear.py
@@ -28,7 +28,12 @@
     resolve_activation,
 )
 from serket.nn.initialization import InitType, resolve_init_func
-from serket.nn.utils import IsInstance, positive_int_cb
+from serket.nn.utils import (
+    IsInstance,
+    maybe_lazy_call,
+    positive_int_cb,
+    positive_int_or_none_cb,
+)
 
 T = TypeVar("T")
 
@@ -40,6 +45,17 @@ class Batched(Generic[T]):
 PyTree = Any
 
 
+def is_lazy(instance, *_, **__) -> bool:
+    return None in getattr(instance, "in_features", [False])
+
+
+def infer_in_features(_, *x, **__) -> int | tuple[int, ...]:
+    return x[0].shape[-1] if len(x) == 1 else tuple(xi.shape[-1] for xi in x)
+
+
+linear_updates = dict(in_features=infer_in_features)
+
+
 @ft.lru_cache(maxsize=None)
 def _multilinear_einsum_string(degree: int) -> str:
     # Generate einsum string for a linear layer of degree n
@@ -110,17 +126,52 @@ class Multilinear(sk.TreeClass):
     >>> layer = sk.nn.Multilinear((5,6,7), 8)
     >>> layer(jnp.ones((1,5)), jnp.ones((1,6)), jnp.ones((1,7))).shape
     (1, 8)
+
+    Note:
+        :class:`.Multilinear` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``(None,)`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> @sk.autoinit
+        ... class Linears(sk.TreeClass):
+        ...     l1: sk.nn.Multilinear = sk.nn.Multilinear((None,), 32, key=k1)
+        ...     l2: sk.nn.Multilinear = sk.nn.Multilinear((32,), 10, key=k2)
+        ...     def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x, y)))
+        >>> lazy_linears = Linears()
+        >>> x = jnp.ones([100, 28])
+        >>> y = jnp.ones([100, 56])
+        >>> _, materialized_linears = lazy_linears.at["__call__"](x, y)
+        >>> materialized_linears.l1.in_features
+        (28, 56)
     """
 
     def __init__(
         self,
-        in_features: int | tuple[int, ...] | None,
+        in_features: tuple[int | None, ...],
         out_features: int,
         *,
         weight_init: InitType = "he_normal",
         bias_init: InitType = "ones",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
+        if None in in_features:
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight_init = weight_init
+            self.bias_init = bias_init
+            self.key = key
+            return
+
         if not isinstance(in_features, (tuple, int)):
             raise ValueError(f"Expected tuple or int for {in_features=}.")
 
@@ -135,6 +186,7 @@ def __init__(
         self.weight = weight_init(k1, weight_shape)
         self.bias = bias_init(k2, (out_features,))
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=linear_updates)
     def __call__(self, *x, **k) -> jax.Array:
         einsum_string = _multilinear_einsum_string(len(self.in_features))
         x = jnp.einsum(einsum_string, *x, self.weight)
@@ -157,11 +209,37 @@ class Linear(Multilinear):
     >>> layer = sk.nn.Linear(5, 6)
     >>> layer(jnp.ones((1,5))).shape
     (1, 6)
+
+    Note:
+        :class:`.Linear` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> import jax
+        >>> k1, k2 = jr.split(jr.PRNGKey(0))
+        >>> @sk.autoinit
+        ... class Linears(sk.TreeClass):
+        ...     l1: sk.nn.Linear = sk.nn.Linear(None, 32, key=k1)
+        ...     l2: sk.nn.Linear = sk.nn.Linear(32, 10, key=k2)
+        ...     def __call__(self, x: jax.Array) -> jax.Array:
+        ...         return self.l2(jax.nn.relu(self.l1(x)))
+        >>> lazy_linears = Linears()
+        >>> x = jnp.ones((100, 28, 28)).reshape(100, -1)
+        >>> _, materialized_linears = lazy_linears.at["__call__"](x)
+        >>> materialized_linears.l1.in_features
+        (784,)
     """
 
     def __init__(
         self,
-        in_features: int,
+        in_features: int | None,
         out_features: int,
         *,
         weight_init: InitType = "he_normal",
@@ -267,7 +345,7 @@ def __init__(
         out_features: int,
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        self.in_features = positive_int_cb(in_features)
+        self.in_features = positive_int_or_none_cb(in_features)
         self.out_features = positive_int_cb(out_features)
         self.weight = jr.uniform(key, (self.in_features, self.out_features))
 
@@ -282,7 +360,7 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
         """
 
         if not jnp.issubdtype(x.dtype, jnp.integer):
-            raise TypeError("Input must be an integer array.")
+            raise TypeError(f"{x.dtype=} is not a subdtype of integer")
 
         return jnp.take(self.weight, x, axis=0)
 
@@ -314,6 +392,22 @@ class FNN(sk.TreeClass):
           and single activation function is applied between them.
         - :class:`.FNN` uses python ``for`` loop to apply layers and activation functions.
 
+    Note:
+        :class:`.FNN` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, add ``None`` as the first element of the
+        ``layers`` argument and use the ``.at["calling_method_name"]`` attribute
+        to call the layer with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> import jax.random as jr
+        >>> lazy_fnn = sk.nn.FNN([None, 10, 2, 1], key=jr.PRNGKey(0))
+        >>> _, materialized_fnn = lazy_fnn.at['__call__'](jnp.ones([1, 10]))
+        >>> materialized_fnn.layers[0].in_features
+        (10,)
     """
 
     def __init__(
@@ -457,6 +551,21 @@ class MLP(sk.TreeClass):
     403
     >>> len(mlp_jaxpr.jaxpr.eqns)
     10
+
+    Note:
+        :class:`.MLP` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> lazy_mlp = sk.nn.MLP(None, 1, num_hidden_layers=2, hidden_size=10)
+        >>> _, materialized_mlp = lazy_mlp.at['__call__'](jnp.ones([1, 10]))
+        >>> materialized_mlp.layers[0].in_features
     """
 
     def __init__(
@@ -494,6 +603,7 @@ def batched_linear(key) -> Batched[Linear]:
             + [Linear(hidden_size, out_features, key=keys[-1], **kwargs)]
         )
 
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=linear_updates)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
         l0, lm, lh = self.layers
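
Note on the mechanism: `maybe_lazy_call` is imported from `serket/nn/utils.py`, which this patch does not touch, so its real implementation is not shown above. The sketch below is a simplified, hypothetical stand-in (not serket's actual code) that illustrates what the decorator is relied upon to do in the hunks above: when `is_lazy` reports that `in_features` was left as `None`, each entry in the `updates` mapping infers the missing field from the call arguments, the layer is rebuilt, and the original method runs on the materialized layer. `SimpleLinear` and the `updates` lambda are invented for this example; in serket the materialized layer is additionally returned through the functional `.at["__call__"]` interface shown in the new docstrings, which this sketch omits.

# Minimal sketch of a maybe_lazy_call-style decorator (illustrative only).
from __future__ import annotations

import dataclasses
import functools as ft

import jax
import jax.numpy as jnp
import jax.random as jr


def maybe_lazy_call(func, *, is_lazy, updates):
    """Materialize a lazy instance from the call arguments, then dispatch."""

    @ft.wraps(func)
    def wrapper(instance, *args, **kwargs):
        if not is_lazy(instance, *args, **kwargs):
            return func(instance, *args, **kwargs)
        # infer each deferred field (e.g. ``in_features``) from the input
        inferred = {k: infer(instance, *args, **kwargs) for k, infer in updates.items()}
        kept = {k: v for k, v in vars(instance).items() if k not in inferred}
        # rebuild the layer with the inferred fields, then run the original method
        materialized = type(instance)(**kept, **inferred)
        return func(materialized, *args, **kwargs)

    return wrapper


def is_lazy(instance, *_, **__) -> bool:
    return getattr(instance, "in_features", False) is None


# infer ``in_features`` from the last axis of the first call argument
updates = dict(in_features=lambda _, x, *__, **___: x.shape[-1])


@dataclasses.dataclass
class SimpleLinear:
    in_features: int | None
    out_features: int

    def __post_init__(self):
        if self.in_features is None:  # lazy: defer weight creation
            return
        self.weight = jr.normal(jr.PRNGKey(0), (self.in_features, self.out_features))

    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=updates)
    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight


layer = SimpleLinear(None, 4)           # weight is not created yet
print(layer(jnp.ones((2, 3))).shape)    # (2, 4): materialized on first call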