Scan cell #87

Merged
merged 7 commits on Dec 2, 2023
7 changes: 1 addition & 6 deletions docs/API/recurrent.rst
@@ -3,8 +3,6 @@ Recurrent

.. currentmodule:: serket.nn

.. autoclass:: RNNCell

.. autoclass:: LSTMCell
.. autoclass:: GRUCell
.. autoclass:: SimpleRNNCell
@@ -24,7 +22,4 @@ Recurrent
.. autoclass:: FFTConvGRU2DCell
.. autoclass:: FFTConvGRU3DCell

.. autoclass:: ScanRNN


.. autofunction:: scan_rnn
.. autofunction:: scan_cell
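
The API docs now list the functional ``scan_rnn`` / ``scan_cell`` in place of the removed ``ScanRNN`` class. As a rough sketch of the underlying idea, scanning a cell over the time axis with ``jax.lax.scan`` in plain JAX; the ``rnn_cell`` below is a hypothetical stand-in, not serket's API, and the real ``scan_cell`` signature may differ:

import functools as ft
import jax
import jax.numpy as jnp

# hypothetical stand-in cell: it follows the (carry, input) -> (new_carry, output)
# contract that jax.lax.scan expects
def rnn_cell(carry, x, *, w_h, w_x):
    new_carry = jnp.tanh(carry @ w_h + x @ w_x)
    return new_carry, new_carry

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_h = jax.random.normal(k1, (8, 8))  # hidden-to-hidden weights
w_x = jax.random.normal(k2, (4, 8))  # input-to-hidden weights
xs = jnp.ones((10, 4))               # (time, in_features)

cell = ft.partial(rnn_cell, w_h=w_h, w_x=w_x)
final_carry, outputs = jax.lax.scan(cell, jnp.zeros((8,)), xs)
print(outputs.shape)  # (10, 8): one hidden state per time step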
78 changes: 43 additions & 35 deletions docs/notebooks/train_bilstm.ipynb

Large diffs are not rendered by default.

138 changes: 45 additions & 93 deletions docs/notebooks/train_convlstm.ipynb

Large diffs are not rendered by default.

17 changes: 14 additions & 3 deletions serket/_src/nn/linear.py
@@ -61,12 +61,19 @@
out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis)
out_axis = out_axis if out_axis >= 0 else out_axis + len(out) + 1
out = out[:out_axis] + "F" + out[out_axis:]
result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight)

try:
    einsum = f"{lhs},{rhs}->{out}"
    result = jnp.einsum(einsum, input, weight)
except ValueError as error:
    raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}")

Codecov (codecov/patch) warning on serket/_src/nn/linear.py#L68-L69: added lines were not covered by tests.

if bias is None:
    return result
broadcast_shape = list(range(result.ndim))
del broadcast_shape[out_axis]

with jax.ensure_compile_time_eval():
    broadcast_shape = list(range(result.ndim))
    del broadcast_shape[out_axis]
bias = jnp.expand_dims(bias, axis=broadcast_shape)
return result + bias
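
As a standalone sketch of the error-reporting pattern added in this hunk, re-raising an einsum failure with the spec string and operand shapes attached; ``checked_einsum`` is a hypothetical name for illustration, and the error path below is the ``except`` branch that the Codecov annotation above flags as untested:

import jax.numpy as jnp

def checked_einsum(lhs, rhs, out, input, weight):
    # attach the einsum spec and both operand shapes to the error so that
    # shape mismatches are easy to diagnose
    einsum = f"{lhs},{rhs}->{out}"
    try:
        return jnp.einsum(einsum, input, weight)
    except ValueError as error:
        raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}")

# happy path: (5, 10) contracted with (10, 3) -> (5, 3)
print(checked_einsum("ab", "bF", "aF", jnp.ones((5, 10)), jnp.ones((10, 3))).shape)

# error path: contracting 4 against 10 fails, and the re-raised error
# carries the spec and the operand shapes
try:
    checked_einsum("ab", "bF", "aF", jnp.ones((5, 4)), jnp.ones((10, 3)))
except ValueError as error:
    assert "einsum=" in str(error)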

@@ -306,6 +313,10 @@
>>> _, material_layer = lazy_layer.at['__call__'](input)
>>> material_layer.in_linear.in_features
(10,)

Note:
    :class:`.MLP` uses ``jax.lax.scan`` to reduce the ``jaxpr`` size,
    which leads to faster compilation times.
"""

def __init__(
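
The new ``Note`` above states that ``MLP`` uses ``jax.lax.scan`` to keep the ``jaxpr`` small. A minimal sketch of that technique, assuming equal-width hidden layers whose weights are stacked along a leading axis (hypothetical names, not serket's internals):

import jax
import jax.numpy as jnp

def scanned_mlp(x, w_stack, b_stack):
    # w_stack: (num_layers, width, width), b_stack: (num_layers, width);
    # one traced layer body is reused for every layer, so the jaxpr stays
    # small, whereas a Python loop would inline each layer separately
    def layer(h, wb):
        w, b = wb
        return jax.nn.relu(h @ w + b), None
    h, _ = jax.lax.scan(layer, x, (w_stack, b_stack))
    return h

kw, kb = jax.random.split(jax.random.PRNGKey(0))
w_stack = 0.1 * jax.random.normal(kw, (4, 16, 16))
b_stack = 0.1 * jax.random.normal(kb, (4, 16))
x = jnp.ones((16,))
print(scanned_mlp(x, w_stack, b_stack).shape)  # (16,)
# the jaxpr contains a single scan primitive instead of 4 inlined layers
print(jax.make_jaxpr(scanned_mlp)(x, w_stack, b_stack))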
74 changes: 36 additions & 38 deletions serket/_src/nn/normalization.py
@@ -211,6 +211,7 @@ class GroupNorm(sk.TreeClass):
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> key = jr.PRNGKey(0)
>>> layer = sk.nn.GroupNorm(5, groups=1, key=key)
>>> input = jnp.ones((5,10))
>>> layer(input).shape
@@ -426,13 +427,12 @@ class BatchNorm(sk.TreeClass):

.. warning::
Works under
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)``
- ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)``
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)``

otherwise it will be a no-op.

Training behavior:
- ``output = (x - batch_mean) / sqrt(batch_var + eps)``
- ``output = (input - batch_mean) / sqrt(batch_var + eps)``
- ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean``
- ``running_var = momentum * running_var + (1 - momentum) * batch_var``

@@ -461,8 +461,9 @@ class BatchNorm(sk.TreeClass):
>>> import jax.random as jr
>>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0))
>>> state = sk.tree_state(bn)
>>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10))
>>> x, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state)
>>> key = jr.PRNGKey(0)
>>> input = jr.uniform(key, shape=(5, 10))
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state)

Example:
Working with :class:`.BatchNorm` by threading the state.
@@ -476,19 +477,19 @@
... k1, k2 = jax.random.split(key)
... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1)
... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2)
... def __call__(self, x, state):
... x, bn1 = self.bn1(x, state.bn1)
... x = x + 1.0
... x, bn2 = self.bn2(x, state.bn2)
... def __call__(self, input, state):
... input, bn1 = self.bn1(input, state.bn1)
... input = input + 1.0
... input, bn2 = self.bn2(input, state.bn2)
... # update the output state
... state = state.at["bn1"].set(bn1).at["bn2"].set(bn2)
... return x, state
... return input, state
>>> net: ThreadedBatchNorm = ThreadedBatchNorm(key=jr.PRNGKey(0))
>>> # initialize state as the same structure as tree
>>> state: ThreadedBatchNorm = sk.tree_state(net)
>>> x = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for xi in x:
... out, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(xi, state)
>>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for input in inputs:
... output, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(input, state)

Example:
Working with :class:`.BatchNorm` without threading the state.
@@ -511,28 +512,28 @@
... self.bn1_state = sk.tree_state(self.bn1)
... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2)
... self.bn2_state = sk.tree_state(self.bn2)
... def _call(self, x):
... def _call(self, input):
... # this method will raise `AttributeError` if used directly
... # because this method mutates the state
... # instead, use `at["_call"]` to call this method to
... # return the output and updated state in a functional manner
... x, self.bn1_state = self.bn1(x, self.bn1_state)
... x = x + 1.0
... x, self.bn2_state = self.bn2(x, self.bn2_state)
... return x
... def __call__(self, x):
... return self.at["_call"](x)
... input, self.bn1_state = self.bn1(input, self.bn1_state)
... input = input + 1.0
... input, self.bn2_state = self.bn2(input, self.bn2_state)
... return input
... def __call__(self, input):
... return self.at["_call"](input)
>>> # define a function to mask and unmask the net across `vmap`
>>> # this is necessary because `vmap` needs the output leaves to be of inexact type
>>> def mask_vmap(net, x):
>>> def mask_vmap(net, input):
... @ft.partial(jax.vmap, out_axes=(0, None))
... def forward(x):
... return sk.tree_mask(net(x))
... return sk.tree_unmask(forward(x))
... def forward(input):
... return sk.tree_mask(net(input))
... return sk.tree_unmask(forward(input))
>>> net: UnthreadedBatchNorm = UnthreadedBatchNorm(key=jr.PRNGKey(0))
>>> input = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for xi in input:
... out, net = mask_vmap(net, xi)
>>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for input in inputs:
... output, net = mask_vmap(net, input)

Note:
:class:`.BatchNorm` supports lazy initialization, meaning that the
@@ -549,8 +550,9 @@
>>> key = jr.PRNGKey(0)
>>> lazy_layer = sk.nn.BatchNorm(None, key=key)
>>> input = jnp.ones((5,10))
>>> _ , material_layer = lazy_layer.at['__call__'](input)
>>> output, state = jax.vmap(material_layer, out_axes=(0, None))(input)
>>> _ , material_layer = lazy_layer.at["__call__"](input, None)
>>> state = sk.tree_state(material_layer)
>>> output, state = jax.vmap(material_layer, in_axes=(0, None), out_axes=(0, None))(input, state)
>>> output.shape
(5, 10)

@@ -589,11 +591,8 @@ def __init__(

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(
self,
input: jax.Array,
state: BatchNormState | None = None,
self, input: jax.Array, state: BatchNormState
) -> tuple[jax.Array, BatchNormState]:
state = sk.tree_state(self) if state is None else state
batchnorm_impl = custom_vmap(lambda x, state: (x, state))
momentum, eps = jax.lax.stop_gradient((self.momentum, self.eps))

@@ -622,13 +621,12 @@ class EvalNorm(sk.TreeClass):

.. warning::
Works under
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)``
- ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)``
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)``

otherwise it will be a no-op.

Evaluation behavior:
- ``output = (x - running_mean) / sqrt(running_var + eps)``
- ``output = (input - running_mean) / sqrt(running_var + eps)``

Args:
in_features: the shape of the input to be normalized.
@@ -653,10 +651,10 @@
>>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0))
>>> state = sk.tree_state(bn)
>>> input = jax.random.uniform(jr.PRNGKey(0), shape=(5, 10))
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state)
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state)
>>> # convert to evaluation mode
>>> bn = sk.tree_eval(bn)
>>> output, state = jax.vmap(bn, in_axes=(0, None))(input, state)
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0,None))(input, state)

Note:
If ``axis_name`` is specified, then ``axis_name`` argument must be passed
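
Pulling the docstring examples together, a minimal train-then-eval round trip with the batched-axis mapping used throughout this file; this only recombines calls shown above (``sk.tree_state``, ``jax.vmap`` with ``in_axes=(0, None), out_axes=(0, None)``, and ``sk.tree_eval``), so treat it as a sketch of the documented usage rather than new API:

import jax
import jax.random as jr
import serket as sk

bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0))
state = sk.tree_state(bn)
input = jr.uniform(jr.PRNGKey(1), shape=(5, 10))

# training: normalize with batch statistics and update the running stats
output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state)

# evaluation: tree_eval swaps BatchNorm for EvalNorm, which normalizes with
# the frozen running statistics and returns the state unchanged
ebn = sk.tree_eval(bn)
output, state = jax.vmap(ebn, in_axes=(0, None), out_axes=(0, None))(input, state)
print(output.shape)  # (5, 10)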