diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py
index 30ace14..870d519 100644
--- a/serket/nn/dropout.py
+++ b/serket/nn/dropout.py
@@ -56,6 +56,8 @@ class Dropout(sk.TreeClass):
         Linear(
           in_features=(10),
           out_features=10,
+          weight_init=he_normal,
+          bias_init=ones,
           weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
           bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
         )
@@ -119,6 +121,8 @@ class Dropout1D(DropoutND):
         Linear(
           in_features=(10),
           out_features=10,
+          weight_init=he_normal,
+          bias_init=ones,
           weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
           bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
         )
@@ -164,6 +168,8 @@ class Dropout2D(DropoutND):
         Linear(
           in_features=(10),
           out_features=10,
+          weight_init=he_normal,
+          bias_init=ones,
           weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
           bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
         )
@@ -209,6 +215,8 @@ class Dropout3D(DropoutND):
         Linear(
           in_features=(10),
           out_features=10,
+          weight_init=he_normal,
+          bias_init=ones,
           weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
           bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
         )
diff --git a/serket/nn/linear.py b/serket/nn/linear.py
index 9117c3e..a29bb2c 100644
--- a/serket/nn/linear.py
+++ b/serket/nn/linear.py
@@ -28,12 +28,7 @@
     resolve_activation,
 )
 from serket.nn.initialization import InitType, resolve_init_func
-from serket.nn.utils import (
-    IsInstance,
-    maybe_lazy_call,
-    positive_int_cb,
-    positive_int_or_none_cb,
-)
+from serket.nn.utils import maybe_lazy_call, positive_int_cb, positive_int_or_none_cb


 T = TypeVar("T")
@@ -46,14 +41,14 @@ class Batched(Generic[T]):


 def is_lazy(instance, *_, **__) -> bool:
-    return None in getattr(instance, "in_features", [False])
+    return getattr(instance, "in_features", False) is None


-def infer_in_features(_, *x, **__) -> int | tuple[int, ...]:
+def infer_multilinear_in_features(_, *x, **__) -> int | tuple[int, ...]:
     return x[0].shape[-1] if len(x) == 1 else tuple(xi.shape[-1] for xi in x)


-linear_updates = dict(in_features=infer_in_features)
+linear_updates = dict(in_features=infer_multilinear_in_features)


 @ft.lru_cache(maxsize=None)
@@ -132,7 +127,7 @@ class Multilinear(sk.TreeClass):
         biases are not initialized until the first call to the layer. This is
         useful when the input shape is not known at initialization time.

-        To use lazy initialization, pass ``(None,)`` as the ``in_features`` argument
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
         and use the ``.at["calling_method_name"]`` attribute to call the layer
         with an input of known shape.

@@ -143,7 +138,7 @@ class Multilinear(sk.TreeClass):
         >>> import jax
         >>> import jax.numpy as jnp
         >>> import serket as sk
         >>> k1, k2 = jr.split(jr.PRNGKey(0))
         >>> @sk.autoinit
         ... class Linears(sk.TreeClass):
-        ...    l1: sk.nn.Multilinear = sk.nn.Multilinear((None,), 32, key=k1)
+        ...    l1: sk.nn.Multilinear = sk.nn.Multilinear(None, 32, key=k1)
         ...    l2: sk.nn.Multilinear = sk.nn.Multilinear((32,), 10, key=k2)
         ...    def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
         ...        return self.l2(jax.nn.relu(self.l1(x, y)))
@@ -157,18 +152,19 @@ class Multilinear(sk.TreeClass):

     def __init__(
         self,
-        in_features: tuple[int | None, ...],
+        in_features: tuple[int, ...] | None,
         out_features: int,
         *,
         weight_init: InitType = "he_normal",
         bias_init: InitType = "ones",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        if None in in_features:
-            self.in_features = in_features
-            self.out_features = out_features
-            self.weight_init = weight_init
-            self.bias_init = bias_init
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight_init = weight_init
+        self.bias_init = bias_init
+
+        if in_features is None:
             self.key = key
             return

@@ -176,15 +172,9 @@ def __init__(
             raise ValueError(f"Expected tuple or int for {in_features=}.")

         k1, k2 = jr.split(key)
-
-        self.in_features = in_features
-        self.out_features = out_features
-        weight_init = resolve_init_func(weight_init)
-        bias_init = resolve_init_func(bias_init)
-
         weight_shape = (*in_features, out_features)
-        self.weight = weight_init(k1, weight_shape)
-        self.bias = bias_init(k2, (out_features,))
+        self.weight = resolve_init_func(weight_init)(k1, weight_shape)
+        self.bias = resolve_init_func(bias_init)(k2, (out_features,))

     @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=linear_updates)
     def __call__(self, *x, **k) -> jax.Array:
@@ -247,7 +237,7 @@ def __init__(
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
         super().__init__(
-            (in_features,),
+            in_features if in_features is None else (in_features,),
             out_features,
             weight_init=weight_init,
             bias_init=bias_init,
@@ -255,6 +245,18 @@ def __init__(
         )


+def is_lazy(instance, *_, **__) -> bool:
+    return getattr(instance, "in_features", False) is None
+
+
+def infer_in_features(instance, x, **__) -> tuple[int, ...]:
+    in_axes = getattr(instance, "in_axes", ())
+    return tuple(x.shape[i] for i in in_axes)
+
+
+general_linear_updates = dict(in_features=infer_in_features)
+
+
 class GeneralLinear(sk.TreeClass):
     """Apply a Linear Layer to input at in_axes

@@ -275,13 +277,32 @@
     (3, 4, 5)

     Note:
-        This layer is similar to to flax linen's DenseGeneral, the difference is that
-        this layer uses einsum to apply the linear layer to the specified axes.
+        This layer is similar to ``flax`` linen's ``DenseGeneral``; the difference
+        is that this layer uses einsum to apply the linear layer to the specified axes.
+
+    Note:
+        :class:`.GeneralLinear` supports lazy initialization, meaning that the weights and
+        biases are not initialized until the first call to the layer. This is
+        useful when the input shape is not known at initialization time.
+
+        To use lazy initialization, pass ``None`` as the ``in_features`` argument
+        and use the ``.at["calling_method_name"]`` attribute to call the layer
+        with an input of known shape.
+
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> lazy_linear = sk.nn.GeneralLinear(None, 12, in_axes=(0, 2))
+        >>> _, materialized_linear = lazy_linear.at['__call__'](jnp.ones((10, 5, 4)))
+        >>> materialized_linear.in_features
+        (10, 4)
+        >>> materialized_linear(jnp.ones((10, 5, 4))).shape
+        (5, 12)
     """

     def __init__(
         self,
-        in_features: tuple[int, ...],
+        in_features: tuple[int, ...] | None,
         out_features: int,
         *,
         in_axes: tuple[int, ...],
@@ -289,23 +310,31 @@
         bias_init: InitType = "ones",
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
-        self.in_features = IsInstance(tuple)(in_features)
+        self.in_features = in_features
         self.out_features = out_features
-        self.in_axes = IsInstance(tuple)(in_axes)
+        self.in_axes = in_axes
+        self.weight_init = weight_init
+        self.bias_init = bias_init
+
+        if in_features is None:
+            self.key = key
+            return
+
+        if not (all(isinstance(i, int) for i in in_features)):
+            raise ValueError(f"Expected tuple of ints for {in_features=}")
+
+        if not (all(isinstance(i, int) for i in in_axes)):
+            raise ValueError(f"Expected tuple of ints for {in_axes=}")

         if len(in_axes) != len(in_features):
-            raise ValueError(
-                "Expected in_axes and in_features to have the same length,"
-                f"got {len(in_axes)=} and {len(in_features)=}"
-            )
+            raise ValueError(f"{len(in_axes)=} != {len(in_features)=}")

         k1, k2 = jr.split(key)
+        weight_shape = (*in_features, out_features)
+        self.weight = resolve_init_func(weight_init)(k1, weight_shape)
+        self.bias = resolve_init_func(bias_init)(k2, (self.out_features,))

-        weight_init = resolve_init_func(weight_init)
-        bias_init = resolve_init_func(bias_init)
-        self.weight = weight_init(k1, (*self.in_features, self.out_features))
-        self.bias = bias_init(k2, (self.out_features,))
-
+    @ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=general_linear_updates)
     def __call__(self, x: jax.Array, **k) -> jax.Array:
         # ensure negative axes
         axes = map(lambda i: i if i < 0 else i - x.ndim, self.in_axes)
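A minimal usage sketch of the lazy-initialization path introduced above (illustrative only, not part of the diff; the variable names and shapes are assumptions that follow the pattern of the docstring examples): passing ``None`` as ``in_features`` defers weight creation, and the first call through ``.at["__call__"]`` infers ``in_features`` from the trailing axis of each input and returns the materialized layer alongside its output.

>>> import jax.numpy as jnp
>>> import serket as sk
>>> lazy = sk.nn.Multilinear(None, 32)        # no weight/bias materialized yet
>>> x, y = jnp.ones((10, 5)), jnp.ones((10, 4))
>>> _, layer = lazy.at["__call__"](x, y)      # in_features inferred from last axes
>>> layer.in_features
(5, 4)
>>> layer(x, y).shape                         # weight f32[5,4,32], bias f32[32]
(10, 32)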