update freezing docs
ASEM000 committed Dec 21, 2023
1 parent 4131904 commit 23f84f1
Showing 5 changed files with 174 additions and 88 deletions.
171 changes: 111 additions & 60 deletions docs/notebooks/subset_training.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion serket/_src/containers.py
@@ -25,7 +25,7 @@


@ft.singledispatch
def sequential(key: jax.Array, _, __):
def sequential(key: jax.Array, _1, _2):
raise TypeError(f"Invalid {type(key)=}")


@@ -73,6 +73,7 @@ class Sequential(sk.TreeClass):
"""

def __init__(self, *layers):
# use variadic args to enforce tuple type and maintain immutability
self.layers = layers

def __call__(self, input: jax.Array, *, key: jax.Array | None = None) -> jax.Array:
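The new comment in ``Sequential.__init__`` points out that variadic arguments always arrive packed into a tuple, which keeps the stored layers immutable. A short framework-free sketch of that idea (hypothetical callables, not serket's actual implementation):

    class MiniSequential:
        def __init__(self, *layers):
            # *layers is packed into a tuple, so the collection of layers
            # cannot be mutated in place after construction
            self.layers = layers

        def __call__(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    net = MiniSequential(lambda x: x * 2, lambda x: x + 1)
    print(net(3))            # 7
    print(type(net.layers))  # <class 'tuple'>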
25 changes: 13 additions & 12 deletions serket/_src/nn/attention.py
@@ -147,18 +147,19 @@ class MultiHeadAttention(sk.TreeClass):
>>> kv_length = 2
>>> mask = jr.uniform(jr.PRNGKey(0), (batch, num_heads, q_length, kv_length))
>>> mask = (mask > 0.5).astype(jnp.float32)
>>> q = jr.uniform(jr.PRNGKey(1), (batch, q_length, q_features))
>>> k = jr.uniform(jr.PRNGKey(2), (batch, kv_length, k_features))
>>> v = jr.uniform(jr.PRNGKey(3), (batch, kv_length, v_features))
>>> k1, k2, k3, k4 = jr.split(jr.PRNGKey(0), 4)
>>> q = jr.uniform(k1, (batch, q_length, q_features))
>>> k = jr.uniform(k2, (batch, kv_length, k_features))
>>> v = jr.uniform(k3, (batch, kv_length, v_features))
>>> layer = sk.nn.MultiHeadAttention(
... num_heads,
... q_features,
... k_features,
... v_features,
... drop_rate=0.0,
... key=jr.PRNGKey(4),
... key=k4,
... )
>>> print(layer(q, k, v, mask=mask, key=jr.PRNGKey(0)).shape)
>>> print(layer(q, k, v, mask=mask, key=jr.PRNGKey(1)).shape)
(3, 4, 4)
Note:
@@ -184,13 +185,13 @@ class MultiHeadAttention(sk.TreeClass):
>>> import jax.random as jr
>>> import serket as sk
>>> q = jr.uniform(jr.PRNGKey(0), (3, 2, 6))
>>> k = jr.uniform(jr.PRNGKey(1), (3, 2, 6))
>>> v = jr.uniform(jr.PRNGKey(2), (3, 2, 6))
>>> key = jr.PRNGKey(0)
>>> lazy = sk.nn.MultiHeadAttention(2, None, key=key)
>>> _, material = sk.value_and_tree(lambda lazy: lazy(q, k, v, key=key))(lazy)
>>> material(q, k, v, key=key).shape
>>> k1, k2, k3, k4, k5 = jr.split(jr.PRNGKey(0), 5)
>>> q = jr.uniform(k1, (3, 2, 6))
>>> k = jr.uniform(k2, (3, 2, 6))
>>> v = jr.uniform(k3, (3, 2, 6))
>>> lazy = sk.nn.MultiHeadAttention(2, None, key=k4)
>>> _, material = sk.value_and_tree(lambda lazy: lazy(q, k, v, key=k4))(lazy)
>>> material(q, k, v, key=k5).shape
(3, 2, 6)
Reference:
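The updated doctests above stop reusing ``jr.PRNGKey(0)`` for several random draws and split a single root key instead. A small sketch in plain JAX of why that matters (illustrative only, independent of serket):

    import jax.random as jr

    # reusing one key reproduces the same samples
    a = jr.uniform(jr.PRNGKey(0), (2,))
    b = jr.uniform(jr.PRNGKey(0), (2,))
    print(bool((a == b).all()))  # True

    # splitting the root key gives independent streams
    k1, k2 = jr.split(jr.PRNGKey(0), 2)
    c = jr.uniform(k1, (2,))
    d = jr.uniform(k2, (2,))
    print(bool((c == d).all()))  # False (almost surely)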
1 change: 0 additions & 1 deletion serket/_src/nn/convolution.py
@@ -49,7 +49,6 @@
Weight = Annotated[jax.Array, "OI..."]


@ft.partial(jax.jit, static_argnums=(2, 3, 4, 5), inline=True)
def fft_conv_general_dilated(
lhs: jax.Array,
rhs: jax.Array,
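The deleted decorator above jitted ``fft_conv_general_dilated`` while marking some positional arguments as static. A minimal sketch of that decorator pattern in plain JAX (the ``scale`` function is hypothetical and unrelated to the convolution code):

    import functools as ft
    import jax
    import jax.numpy as jnp

    # static arguments become part of the compilation cache key rather than
    # traced values, so they may be plain Python objects
    @ft.partial(jax.jit, static_argnums=(1,), inline=True)
    def scale(x, factor):
        return x * factor

    print(scale(jnp.arange(3.0), 2))  # [0. 2. 4.]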
62 changes: 48 additions & 14 deletions serket/_src/utils.py
@@ -309,7 +309,7 @@ def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]:
return tuple(inspect.signature(func).parameters.values())


# TODO: maybe expose this as a public API
# Maybe expose this as a public API
# Handling lazy layers
"""
Creating a _lazy_ ``Linear`` layer example:
@@ -329,32 +329,67 @@ def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]:
translate code from both explicitly and implicitly shaped layers found in
libraries like ``pytorch`` and ``tensorflow``.
A quick sketch of how this works is shown in the following example:
>>> import jax
>>> class Lazy:
...     def __init__(self, dim_size: int | None):
...         # let dim_size be the array size; if we don't know it at
...         # construction time, we set it to None to be inferred later
...         self.dim_size = dim_size
...     def __call__(self, x):
...         return x * self.dim_size
>>> def maybe_lazy_init(func):
...     def wrapper(self, dim_size):
...         if dim_size is not None:
...             return func(self, dim_size)
...         # do not execute the init function because it is lazy
...         return None
...     return wrapper
>>> def maybe_lazy_call(func):
...     def wrapper(self, x):
...         if getattr(self, "dim_size", None) is not None:
...             return func(self, x)
...         # the layer is lazy, so we infer the dim size here.
...         # because ``TreeClass`` is immutable, the real implementation
...         # returns a new instance with the updated dim size; this plain
...         # class is mutable, so we just update the current instance
...         self.dim_size = x.size
...         return func(self, x)
...     return wrapper
>>> # now let's decorate our lazy class
>>> Lazy.__init__ = maybe_lazy_init(Lazy.__init__)
>>> Lazy.__call__ = maybe_lazy_call(Lazy.__call__)
>>> print(Lazy(2)(jax.numpy.ones([2])))
[2. 2.]
>>> print(Lazy(None)(jax.numpy.ones([2])))
[2. 2.]
Now let's create a lazy ``Linear`` layer using ``serket``:
>>> import functools as ft
>>> import serket as sk
>>> import jax.numpy as jnp
>>> from serket._src.utils import maybe_lazy_call, maybe_lazy_init
<BLANKLINE>
>>> def is_lazy_init(self, in_features, out_features):
...     # we need to define how to tell if the layer is lazy
...     # based on the inputs
...     return in_features is None  # or anything else really
<BLANKLINE>
>>> def is_lazy_call(self, x):
...     # we need to define how to tell if the layer is lazy
...     # at call time, replicating the lazy init condition
...     return getattr(self, "in_features", False) is None
<BLANKLINE>
>>> def infer_in_features(self, x):
...     # we need to define how to infer the in_features
...     # based on the inputs at call time; for linear layers,
...     # we can infer the in_features as the last dimension
...     return x.shape[-1]
<BLANKLINE>
>>> # lastly we need to assign this function to a dictionary that has the name
>>> # of the feature we want to infer
>>> updates = dict(in_features=infer_in_features)
<BLANKLINE>
>>> class SimpleLinear(sk.TreeClass):
...     @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init)
...     def __init__(self, in_features, out_features):
@@ -365,20 +400,17 @@ def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]:
>>> @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
... def __call__(self, x):
...     return x
<BLANKLINE>
>>> simple_lazy = SimpleLinear(None, 1)
>>> x = jnp.ones([10, 2]) # last dimension is the in_features of the layer
>>> print(repr(simple_lazy))
SimpleLinear(in_features=None, out_features=1)
<BLANKLINE>
>>> _, material = simple_lazy.at["__call__"](x)
<BLANKLINE>
>>> _, material = sk.value_and_tree(lambda layer: layer(x))(simple_lazy)
>>> print(repr(material))
SimpleLinear(
  in_features=2,
  out_features=1,
  weight=f32[2,1](μ=1.00, σ=0.00, ∈[1.00,1.00]),
  bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
)
"""

@@ -527,6 +559,8 @@ def inner(instance, *a, **k):
raise RuntimeError(LAZY_CALL_ERROR.format(**kwargs))

# re-initialize the instance with the resolved arguments
# this only works under ``value_and_tree``, which allows
# the instance to be mutable within its context after being copied first
getattr(type(instance), "__init__")(instance, **kwargs)
# call the decorated function
return func(instance, *a, **k)
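The new comment notes that re-running ``__init__`` on the instance only works under ``sk.value_and_tree``, which copies the otherwise immutable tree before the wrapped function runs. A framework-free sketch of that copy-then-mutate idea (an assumption about the mechanism, not serket's actual implementation):

    import copy

    def value_and_tree(func):
        # copy the (otherwise immutable) instance first, let ``func`` mutate
        # the copy, then return both the result and the mutated copy
        def wrapper(tree):
            mutable_copy = copy.deepcopy(tree)
            value = func(mutable_copy)
            return value, mutable_copy
        return wrapper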
