Commit

Scan cell (#87)
* unify output,state = ...(input,state)

* linear edit

* Update recurrent.py

* Update train_bilstm.ipynb

* Update train_convlstm.ipynb

* Update recurrent.py

* Update train_bilstm.ipynb
ASEM000 authored Dec 2, 2023
1 parent 412e41d commit 7473148
Showing 9 changed files with 373 additions and 1,339 deletions.
7 changes: 1 addition & 6 deletions docs/API/recurrent.rst
@@ -3,8 +3,6 @@ Recurrent

.. currentmodule:: serket.nn

.. autoclass:: RNNCell

.. autoclass:: LSTMCell
.. autoclass:: GRUCell
.. autoclass:: SimpleRNNCell
@@ -24,7 +22,4 @@ Recurrent
.. autoclass:: FFTConvGRU2DCell
.. autoclass:: FFTConvGRU3DCell

.. autoclass:: ScanRNN


.. autofunction:: scan_rnn
.. autofunction:: scan_cell
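
The API change above replaces ``ScanRNN``/``scan_rnn`` with ``scan_cell``, in line with the commit message's unified calling convention ``output, state = cell(input, state)``. The snippet below is a minimal, self-contained JAX sketch of why that convention composes directly with ``jax.lax.scan``; ``toy_cell`` and ``scan_over_time`` are hypothetical stand-ins, not serket's actual ``scan_cell`` implementation.

import jax
import jax.numpy as jnp

def toy_cell(input, state):
    # hypothetical cell obeying the unified convention: (input, state) -> (output, state)
    new_state = jnp.tanh(input + state)
    return new_state, new_state

def scan_over_time(cell, inputs, state):
    # jax.lax.scan expects a carry-first step function, so adapt the cell's signature
    def step(state, input):
        output, state = cell(input, state)
        return state, output
    state, outputs = jax.lax.scan(step, state, inputs)
    return outputs, state

inputs = jnp.ones((10, 4))   # (time, features)
state = jnp.zeros((4,))
outputs, final_state = scan_over_time(toy_cell, inputs, state)
print(outputs.shape)         # (10, 4)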
78 changes: 43 additions & 35 deletions docs/notebooks/train_bilstm.ipynb

Large diffs are not rendered by default.

138 changes: 45 additions & 93 deletions docs/notebooks/train_convlstm.ipynb

Large diffs are not rendered by default.

17 changes: 14 additions & 3 deletions serket/_src/nn/linear.py
@@ -61,12 +61,19 @@ def general_linear(
out = "".join(str(axis) for axis in range(input.ndim) if axis not in in_axis)
out_axis = out_axis if out_axis >= 0 else out_axis + len(out) + 1
out = out[:out_axis] + "F" + out[out_axis:]
result = jnp.einsum(f"{lhs},{rhs}->{out}", input, weight)

try:
einsum = f"{lhs},{rhs}->{out}"
result = jnp.einsum(einsum, input, weight)
except ValueError as error:
raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}")

if bias is None:
return result
broadcast_shape = list(range(result.ndim))
del broadcast_shape[out_axis]

with jax.ensure_compile_time_eval():
broadcast_shape = list(range(result.ndim))
del broadcast_shape[out_axis]
bias = jnp.expand_dims(bias, axis=broadcast_shape)
return result + bias
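
Note: the hunk above adds two things: a ``try``/``except`` that attaches the einsum spec and operand shapes to the error, and a ``jax.ensure_compile_time_eval()`` block around the bias broadcast. A standalone sketch of the same pattern (hypothetical name ``linear_sketch`` and a fixed einsum spec for brevity; not serket's full ``general_linear``):

import jax
import jax.numpy as jnp

def linear_sketch(input, weight, bias, lhs="ab", rhs="bF", out="aF", out_axis=1):
    einsum = f"{lhs},{rhs}->{out}"
    try:
        result = jnp.einsum(einsum, input, weight)
    except ValueError as error:
        # surface the spec and operand shapes alongside the original error
        raise ValueError(f"{einsum=}\n{input.shape=}\n{weight.shape=}\n{error=}")
    if bias is None:
        return result
    # expand bias over every axis except the output feature axis so it broadcasts
    with jax.ensure_compile_time_eval():
        broadcast_shape = list(range(result.ndim))
        del broadcast_shape[out_axis]
        bias = jnp.expand_dims(bias, axis=broadcast_shape)
    return result + bias

x = jnp.ones((4, 3))
w = jnp.ones((3, 5))
b = jnp.zeros((5,))
print(linear_sketch(x, w, b).shape)  # (4, 5)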

@@ -306,6 +313,10 @@ class MLP(sk.TreeClass):
>>> _, material_layer = lazy_layer.at['__call__'](input)
>>> material_layer.in_linear.in_features
(10,)
Note:
:class:`.MLP` uses ``jax.lax.scan`` to reduce the ``jaxpr`` size, which leads to faster compilation times (a sketch of the idea follows this file's diff).
"""

def __init__(
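On the new ``MLP`` note above: when the hidden layers share a shape, their weights can be stacked along a leading axis and applied with a single ``jax.lax.scan`` step instead of unrolling one matmul per layer, so the traced ``jaxpr`` stays roughly constant-size in the number of layers. A rough illustration of the idea in plain JAX (illustrative names such as ``scanned_mlp``; not serket's ``MLP`` internals):

import jax
import jax.numpy as jnp

num_layers, width = 8, 64
key = jax.random.PRNGKey(0)
# stack the per-layer weights along a leading "layer" axis
weights = jax.random.normal(key, (num_layers, width, width)) / jnp.sqrt(width)

def scanned_mlp(input, weights):
    def layer(x, w):
        # one hidden layer per scan step; the jaxpr contains this body only once
        return jnp.tanh(x @ w), None
    output, _ = jax.lax.scan(layer, input, weights)
    return output

x = jnp.ones((width,))
print(jax.make_jaxpr(scanned_mlp)(x, weights))  # a single scan, not 8 unrolled dots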
74 changes: 36 additions & 38 deletions serket/_src/nn/normalization.py
@@ -211,6 +211,7 @@ class GroupNorm(sk.TreeClass):
>>> import serket as sk
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>> key = jr.PRNGKey(0)
>>> layer = sk.nn.GroupNorm(5, groups=1, key=key)
>>> input = jnp.ones((5,10))
>>> layer(input).shape
@@ -426,13 +427,12 @@ class BatchNorm(sk.TreeClass):
.. warning::
Works under
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)``
- ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)``
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)``
otherwise it will be a no-op.
Training behavior:
- ``output = (x - batch_mean) / sqrt(batch_var + eps)``
- ``output = (input - batch_mean) / sqrt(batch_var + eps)``
- ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean``
- ``running_var = momentum * running_var + (1 - momentum) * batch_var``
@@ -461,8 +461,9 @@ class BatchNorm(sk.TreeClass):
>>> import jax.random as jr
>>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0))
>>> state = sk.tree_state(bn)
>>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10))
>>> x, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state)
>>> key = jr.PRNGKey(0)
>>> input = jr.uniform(key, shape=(5, 10))
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state)
Example:
Working with :class:`.BatchNorm` with threading the state.
@@ -476,19 +477,19 @@ class BatchNorm(sk.TreeClass):
... k1, k2 = jax.random.split(key)
... self.bn1 = sk.nn.BatchNorm(5, axis=-1, key=k1)
... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2)
... def __call__(self, x, state):
... x, bn1 = self.bn1(x, state.bn1)
... x = x + 1.0
... x, bn2 = self.bn2(x, state.bn2)
... def __call__(self, input, state):
... input, bn1 = self.bn1(input, state.bn1)
... input = input + 1.0
... input, bn2 = self.bn2(input, state.bn2)
... # update the output state
... state = state.at["bn1"].set(bn1).at["bn2"].set(bn2)
... return x, state
... return input, state
>>> net: ThreadedBatchNorm = ThreadedBatchNorm(key=jr.PRNGKey(0))
>>> # initialize state with the same structure as the net
>>> state: ThreadedBatchNorm = sk.tree_state(net)
>>> x = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for xi in x:
... out, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(xi, state)
>>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for input in inputs:
... output, state = jax.vmap(net, in_axes=(0, None), out_axes=(0, None))(input, state)
Example:
Working with :class:`.BatchNorm` without threading the state.
@@ -511,28 +512,28 @@ class BatchNorm(sk.TreeClass):
... self.bn1_state = sk.tree_state(self.bn1)
... self.bn2 = sk.nn.BatchNorm(5, axis=-1, key=k2)
... self.bn2_state = sk.tree_state(self.bn2)
... def _call(self, x):
... def _call(self, input):
... # this method will raise `AttributeError` if used directly
... # because this method mutates the state
... # instead, use `at["_call"]` to call this method to
... # return the output and updated state in a functional manner
... x, self.bn1_state = self.bn1(x, self.bn1_state)
... x = x + 1.0
... x, self.bn2_state = self.bn2(x, self.bn2_state)
... return x
... def __call__(self, x):
... return self.at["_call"](x)
... input, self.bn1_state = self.bn1(input, self.bn1_state)
... input = input + 1.0
... input, self.bn2_state = self.bn2(input, self.bn2_state)
... return input
... def __call__(self, input):
... return self.at["_call"](input)
>>> # define a function to mask and unmask the net across `vmap`
>>> # this is necessary because `vmap` needs the output to be of inexact type
>>> def mask_vmap(net, x):
>>> def mask_vmap(net, input):
... @ft.partial(jax.vmap, out_axes=(0, None))
... def forward(x):
... return sk.tree_mask(net(x))
... return sk.tree_unmask(forward(x))
... def forward(input):
... return sk.tree_mask(net(input))
... return sk.tree_unmask(forward(input))
>>> net: UnthreadedBatchNorm = UnthreadedBatchNorm(key=jr.PRNGKey(0))
>>> input = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for xi in input:
... out, net = mask_vmap(net, xi)
>>> inputs = jnp.linspace(-jnp.pi, jnp.pi, 50 * 20).reshape(20, 10, 5)
>>> for input in inputs:
... output, net = mask_vmap(net, input)
Note:
:class:`.BatchNorm` supports lazy initialization, meaning that the
@@ -549,8 +550,9 @@ class BatchNorm(sk.TreeClass):
>>> key = jr.PRNGKey(0)
>>> lazy_layer = sk.nn.BatchNorm(None, key=key)
>>> input = jnp.ones((5,10))
>>> _ , material_layer = lazy_layer.at['__call__'](input)
>>> output, state = jax.vmap(material_layer, out_axes=(0, None))(input)
>>> _ , material_layer = lazy_layer.at["__call__"](input, None)
>>> state = sk.tree_state(material_layer)
>>> output, state = jax.vmap(material_layer, in_axes=(0, None), out_axes=(0, None))(input, state)
>>> output.shape
(5, 10)
@@ -589,11 +591,8 @@ def __init__(

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(
self,
input: jax.Array,
state: BatchNormState | None = None,
self, input: jax.Array, state: BatchNormState
) -> tuple[jax.Array, BatchNormState]:
state = sk.tree_state(self) if state is None else state
batchnorm_impl = custom_vmap(lambda x, state: (x, state))
momentum, eps = jax.lax.stop_gradient((self.momentum, self.eps))
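
Note: the ``custom_vmap(lambda x, state: (x, state))`` line above is what makes :class:`.BatchNorm` a no-op outside of ``jax.vmap`` (the behavior the docstring warning describes): the primal function is the identity, and only the batching rule, which runs under ``jax.vmap``, sees the mapped axis and can normalize over it. A rough sketch of that mechanism, assuming ``custom_vmap`` here is ``jax.custom_batching.custom_vmap`` or a thin wrapper around it; the rule below is illustrative, not serket's actual batch-norm rule.

import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def norm(x):
    # called outside jax.vmap: no batch axis to normalize over, so do nothing
    return x

@norm.def_vmap
def norm_vmap_rule(axis_size, in_batched, x):
    # called under jax.vmap: x carries the mapped (batch) axis at position 0
    out = (x - x.mean(0)) / jnp.sqrt(x.var(0) + 1e-5)
    return out, in_batched[0]

x = jnp.arange(12.0).reshape(3, 4)
print(bool(jnp.all(norm(x) == x)))  # True: identity without vmap
print(jax.vmap(norm)(x).mean(0))    # ~0: normalized across the batch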

@@ -622,13 +621,12 @@ class EvalNorm(sk.TreeClass):
.. warning::
Works under
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(x, state)``
- ``jax.vmap(BatchNorm(...), out_axes=(0, None))(x)``
- ``jax.vmap(BatchNorm(...), in_axes=(0, None), out_axes=(0, None))(input, state)``
otherwise it will be a no-op.
Evaluation behavior:
- ``output = (x - running_mean) / sqrt(running_var + eps)``
- ``output = (input - running_mean) / sqrt(running_var + eps)``
Args:
in_features: the shape of the input to be normalized.
Expand All @@ -653,10 +651,10 @@ class EvalNorm(sk.TreeClass):
>>> bn = sk.nn.BatchNorm(10, key=jr.PRNGKey(0))
>>> state = sk.tree_state(bn)
>>> input = jax.random.uniform(jr.PRNGKey(0), shape=(5, 10))
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(x, state)
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0, None))(input, state)
>>> # convert to evaluation mode
>>> bn = sk.tree_eval(bn)
>>> output, state = jax.vmap(bn, in_axes=(0, None))(input, state)
>>> output, state = jax.vmap(bn, in_axes=(0, None), out_axes=(0,None))(input, state)
Note:
If ``axis_name`` is specified, then ``axis_name`` argument must be passed
