Commit
lazy general linear
ASEM000 committed Aug 1, 2023
1 parent 6b70844 commit 20723a5
Showing 2 changed files with 77 additions and 40 deletions.
8 changes: 8 additions & 0 deletions serket/nn/dropout.py
@@ -56,6 +56,8 @@ class Dropout(sk.TreeClass):
Linear(
in_features=(10),
out_features=10,
weight_init=he_normal,
bias_init=ones,
weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
)
@@ -119,6 +121,8 @@ class Dropout1D(DropoutND):
Linear(
in_features=(10),
out_features=10,
weight_init=he_normal,
bias_init=ones,
weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
)
@@ -164,6 +168,8 @@ class Dropout2D(DropoutND):
Linear(
in_features=(10),
out_features=10,
weight_init=he_normal,
bias_init=ones,
weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
)
@@ -209,6 +215,8 @@ class Dropout3D(DropoutND):
Linear(
in_features=(10),
out_features=10,
weight_init=he_normal,
bias_init=ones,
weight=f32[10,10](μ=0.01, σ=0.45, ∈[-0.96,0.95]),
bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
)
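The dropout.py hunks above only track a doctest-output change: the ``Linear`` repr shown in the ``Dropout``/``Dropout1D``/``Dropout2D``/``Dropout3D`` docstrings now lists the ``weight_init`` and ``bias_init`` fields. An illustrative check (hypothetical snippet, not part of the commit; the exact array statistics depend on the PRNG key):

>>> import serket as sk
>>> print(sk.nn.Linear(10, 10))  # doctest: +SKIP
Linear(
  in_features=(10),
  out_features=10,
  weight_init=he_normal,
  bias_init=ones,
  weight=f32[10,10](μ=..., σ=..., ∈[...]),
  bias=f32[10](μ=1.00, σ=0.00, ∈[1.00,1.00])
)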
109 changes: 69 additions & 40 deletions serket/nn/linear.py
@@ -28,12 +28,7 @@
resolve_activation,
)
from serket.nn.initialization import InitType, resolve_init_func
from serket.nn.utils import (
IsInstance,
maybe_lazy_call,
positive_int_cb,
positive_int_or_none_cb,
)
from serket.nn.utils import maybe_lazy_call, positive_int_cb, positive_int_or_none_cb

T = TypeVar("T")

@@ -46,14 +41,14 @@ class Batched(Generic[T]):


def is_lazy(instance, *_, **__) -> bool:
return None in getattr(instance, "in_features", [False])
return getattr(instance, "in_features", False) is None


def infer_in_features(_, *x, **__) -> int | tuple[int, ...]:
def infer_multilinear_in_features(_, *x, **__) -> int | tuple[int, ...]:
return x[0].shape[-1] if len(x) == 1 else tuple(xi.shape[-1] for xi in x)


linear_updates = dict(in_features=infer_in_features)
linear_updates = dict(in_features=infer_multilinear_in_features)


@ft.lru_cache(maxsize=None)
@@ -132,7 +127,7 @@ class Multilinear(sk.TreeClass):
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``(None,)`` as the ``in_features`` argument
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.
@@ -143,7 +138,7 @@ class Multilinear(sk.TreeClass):
>>> k1, k2 = jr.split(jr.PRNGKey(0))
>>> @sk.autoinit
... class Linears(sk.TreeClass):
... l1: sk.nn.Multilinear = sk.nn.Multilinear((None,), 32, key=k1)
... l1: sk.nn.Multilinear = sk.nn.Multilinear(None, 32, key=k1)
... l2: sk.nn.Multilinear = sk.nn.Multilinear((32,), 10, key=k2)
... def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array:
... return self.l2(jax.nn.relu(self.l1(x, y)))
@@ -157,34 +152,29 @@ class Multilinear(sk.TreeClass):

def __init__(
self,
in_features: tuple[int | None, ...],
in_features: tuple[int, ...] | None,
out_features: int,
*,
weight_init: InitType = "he_normal",
bias_init: InitType = "ones",
key: jr.KeyArray = jr.PRNGKey(0),
):
if None in in_features:
self.in_features = in_features
self.out_features = out_features
self.weight_init = weight_init
self.bias_init = bias_init
self.in_features = in_features
self.out_features = out_features
self.weight_init = weight_init
self.bias_init = bias_init

if in_features is None:
self.key = key
return

if not isinstance(in_features, (tuple, int)):
raise ValueError(f"Expected tuple or int for {in_features=}.")

k1, k2 = jr.split(key)

self.in_features = in_features
self.out_features = out_features
weight_init = resolve_init_func(weight_init)
bias_init = resolve_init_func(bias_init)

weight_shape = (*in_features, out_features)
self.weight = weight_init(k1, weight_shape)
self.bias = bias_init(k2, (out_features,))
self.weight = resolve_init_func(weight_init)(k1, weight_shape)
self.bias = resolve_init_func(bias_init)(k2, (out_features,))

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=linear_updates)
def __call__(self, *x, **k) -> jax.Array:
@@ -247,14 +237,26 @@ def __init__(
key: jr.KeyArray = jr.PRNGKey(0),
):
super().__init__(
(in_features,),
in_features if in_features is None else (in_features,),
out_features,
weight_init=weight_init,
bias_init=bias_init,
key=key,
)


def is_lazy(instance, *_, **__) -> bool:
return getattr(instance, "in_features", False) is None


def infer_in_features(instance, x, **__) -> tuple[int, ...]:
in_axes = getattr(instance, "in_axes", ())
return tuple(x.shape[i] for i in in_axes)


general_linear_updates = dict(in_features=infer_in_features)
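# Illustrative aside (not part of linear.py or this commit): the update rule
# above resolves ``in_features`` lazily from the first input's shape at the
# axes listed in ``in_axes``, e.g. axes (0, 2) of a (10, 5, 4) input:
x_shape, in_axes = (10, 5, 4), (0, 2)
assert tuple(x_shape[i] for i in in_axes) == (10, 4)  # what infer_in_features returns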


class GeneralLinear(sk.TreeClass):
"""Apply a Linear Layer to input at in_axes
@@ -275,37 +277,64 @@ class GeneralLinear(sk.TreeClass):
(3, 4, 5)
Note:
This layer is similar to to flax linen's DenseGeneral, the difference is that
this layer uses einsum to apply the linear layer to the specified axes.
This layer is similar to ``flax`` linen's ``DenseGeneral``; the difference
is that this layer uses einsum to apply the linear layer to the specified axes.
Note:
:class:`.GeneralLinear` supports lazy initialization, meaning that the weights and
biases are not initialized until the first call to the layer. This is
useful when the input shape is not known at initialization time.
To use lazy initialization, pass ``None`` as the ``in_features`` argument
and use the ``.at["calling_method_name"]`` attribute to call the layer
with an input of known shape.
>>> import jax
>>> import jax.numpy as jnp
>>> import serket as sk
>>> lazy_linear = sk.nn.GeneralLinear(None, 12, in_axes=(0, 2))
>>> _, materialized_linear = lazy_linear.at['__call__'](jnp.ones((10, 5, 4)))
>>> materialized_linear.in_features
(10, 4)
>>> materialized_linear(jnp.ones((10, 5, 4))).shape
(5, 12)
"""

def __init__(
self,
in_features: tuple[int, ...],
in_features: tuple[int, ...] | None,
out_features: int,
*,
in_axes: tuple[int, ...],
weight_init: InitType = "he_normal",
bias_init: InitType = "ones",
key: jr.KeyArray = jr.PRNGKey(0),
):
self.in_features = IsInstance(tuple)(in_features)
self.in_features = in_features
self.out_features = out_features
self.in_axes = IsInstance(tuple)(in_axes)
self.in_axes = in_axes
self.weight_init = weight_init
self.bias_init = bias_init

if in_features is None:
self.key = key
return

if not (all(isinstance(i, int) for i in in_features)):
raise ValueError(f"Expected tuple of ints for {in_features=}")

if not (all(isinstance(i, int) for i in in_axes)):
raise ValueError(f"Expected tuple of ints for {in_axes=}")

if len(in_axes) != len(in_features):
raise ValueError(
"Expected in_axes and in_features to have the same length,"
f"got {len(in_axes)=} and {len(in_features)=}"
)
raise ValueError(f"{len(in_axes)=} != {len(in_features)=}")

k1, k2 = jr.split(key)
weight_shape = (*in_features, out_features)
self.weight = resolve_init_func(weight_init)(k1, weight_shape)
self.bias = resolve_init_func(bias_init)(k2, (self.out_features,))

weight_init = resolve_init_func(weight_init)
bias_init = resolve_init_func(bias_init)
self.weight = weight_init(k1, (*self.in_features, self.out_features))
self.bias = bias_init(k2, (self.out_features,))

@ft.partial(maybe_lazy_call, is_lazy=is_lazy, updates=general_linear_updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
# ensure negative axes
axes = map(lambda i: i if i < 0 else i - x.ndim, self.in_axes)
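The diff truncates ``GeneralLinear.__call__``, so for orientation here is a rough, self-contained sketch of the einsum contraction the docstring describes: the axes named in ``in_axes`` are contracted against the leading axes of a weight shaped ``(*in_features, out_features)``. This is an assumption-laden illustration rather than serket's actual ``__call__``; the helper name ``general_linear_forward`` and the way the einsum string is built are invented for the example.

import jax.numpy as jnp
import jax.random as jr

def general_linear_forward(x, weight, bias, in_axes):
    # Normalize to negative axes, mirroring the "# ensure negative axes" step above.
    in_axes = tuple(i if i < 0 else i - x.ndim for i in in_axes)
    lhs = "".join(chr(ord("a") + i) for i in range(x.ndim))  # input subscripts, e.g. "abc"
    rhs = "".join(lhs[i] for i in in_axes) + "z"             # weight: contracted axes + output axis
    out = "".join(c for c in lhs if c not in rhs) + "z"      # surviving input axes + output axis
    return jnp.einsum(f"{lhs},{rhs}->{out}", x, weight) + bias

x = jnp.ones((10, 5, 4))                        # with in_axes=(0, 2), in_features is (10, 4)
weight = jr.normal(jr.PRNGKey(0), (10, 4, 12))  # (*in_features, out_features)
bias = jnp.ones((12,))
print(general_linear_forward(x, weight, bias, (0, 2)).shape)  # (5, 12), matching the docstring example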
