Skip to content

Commit

Permalink
tree_evaluation -> tree_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jul 28, 2023
1 parent 8cef188 commit 4055028
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 46 deletions.
19 changes: 19 additions & 0 deletions CHANEGLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,25 @@
- `***_init_func` -> `***_init` shorter and more concise
- `gamma_init_func` -> `weight_init`
- `beta_init_func` -> `bias_init`
- `MLP` produces smaller `jaxprs` and are faster to compile. for my use case -higher order differentiation through `PINN`- the new `MLP` is faster to compile.

### Additions

- `tree_eval`: a dispatcher to define layers evaluation rule. for example `Dropout` is changed to `Identity` when `tree_state` is applied.

```python
@sk.tree_eval.def_eval(sk.nn.Dropout)
def dropout_evaluation(_) -> sk.nn.Identity:
return sk.nn.Identity()
```

- `tree_state`: a dispatcher to define state intialization for `BatchNorm`, `RNN` cells.

```python
@sk.tree_state.def_state(sk.nn.SimpleRNNCell)
def simple_rnn_init_state(cell: SimpleRNNCell) -> SimpleRNNState:
return SimpleRNNState(jnp.zeros([cell.hidden_features]))
```

### Deprecations

Expand Down
2 changes: 1 addition & 1 deletion docs/API/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
.. currentmodule:: serket

.. autofunction:: tree_state
.. autofunction:: tree_evaluation
.. autofunction:: tree_eval

.. toctree::
:maxdepth: 2
Expand Down
4 changes: 2 additions & 2 deletions serket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)

from . import nn
from .nn.custom_transform import tree_evaluation, tree_state
from .nn.custom_transform import tree_eval, tree_state

__all__ = (
# general utils
Expand Down Expand Up @@ -79,7 +79,7 @@
"leafwise",
# serket
"nn",
"tree_evaluation",
"tree_eval",
"tree_state",
)

Expand Down
6 changes: 3 additions & 3 deletions serket/nn/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import jax.random as jr

import serket as sk
from serket.nn.custom_transform import tree_evaluation
from serket.nn.custom_transform import tree_eval
from serket.nn.utils import Range


Expand Down Expand Up @@ -111,6 +111,6 @@ def __call__(self, x: jax.Array, key: jr.KeyArray = jr.PRNGKey(0)):
return self.layer(x) if jr.bernoulli(key, p) else x


@tree_evaluation.def_evaluation(RandomApply)
def tree_evaluation_random_apply(layer: RandomApply):
@tree_eval.def_eval(RandomApply)
def tree_eval_random_apply(layer: RandomApply):
return layer.layer
14 changes: 7 additions & 7 deletions serket/nn/custom_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def dispatch_func(leaf):
tree_state.def_state = tree_state.state_dispatcher.register


def tree_evaluation(tree):
def tree_eval(tree):
"""Modify tree layers to disable any trainning related behavior.
For example, :class:`nn.Dropout` layer is replaced by an :class:`nn.Identity` layer
Expand All @@ -110,21 +110,21 @@ def tree_evaluation(tree):
Example:
>>> # dropout is replaced by an identity layer in evaluation mode
>>> # by registering `tree_evaluation.def_evaluation(sk.nn.Dropout, sk.nn.Identity)`
>>> # by registering `tree_eval.def_eval(sk.nn.Dropout, sk.nn.Identity)`
>>> import jax.numpy as jnp
>>> import serket as sk
>>> layer = sk.nn.Dropout(0.5)
>>> sk.tree_evaluation(layer)
>>> sk.tree_eval(layer)
Identity()
"""

types = tuple(set(tree_evaluation.evaluation_dispatcher.registry) - {object})
types = tuple(set(tree_eval.eval_dispatcher.registry) - {object})

def is_leaf(x: Callable[[Any], bool]) -> bool:
return isinstance(x, types)

return jax.tree_map(tree_evaluation.evaluation_dispatcher, tree, is_leaf=is_leaf)
return jax.tree_map(tree_eval.eval_dispatcher, tree, is_leaf=is_leaf)


tree_evaluation.evaluation_dispatcher = ft.singledispatch(lambda x: x)
tree_evaluation.def_evaluation = tree_evaluation.evaluation_dispatcher.register
tree_eval.eval_dispatcher = ft.singledispatch(lambda x: x)
tree_eval.def_eval = tree_eval.eval_dispatcher.register
31 changes: 15 additions & 16 deletions serket/nn/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.random as jr

import serket as sk
from serket.nn.custom_transform import tree_evaluation
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.nn.utils import Range, canonicalize, positive_int_cb, validate_spatial_ndim

Expand All @@ -45,11 +45,11 @@ class Dropout(sk.TreeClass):
>>> layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen)
Note:
Use :func:`.tree_evaluation` to turn off dropout during evaluation.
Use :func:`.tree_eval` to turn off dropout during evaluation.
>>> import serket as sk
>>> layers = sk.nn.Sequential(sk.nn.Dropout(0.5), sk.nn.Linear(10, 10))
>>> sk.tree_evaluation(layers)
>>> sk.tree_eval(layers)
Sequential(
layers=(
Identity(),
Expand Down Expand Up @@ -108,11 +108,11 @@ class Dropout1D(DropoutND):
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
Note:
Use :func:`.tree_evaluation` to turn off dropout during evaluation.
Use :func:`.tree_eval` to turn off dropout during evaluation.
>>> import serket as sk
>>> layers = sk.nn.Sequential(sk.nn.Dropout1D(0.5), sk.nn.Linear(10, 10))
>>> sk.tree_evaluation(layers)
>>> sk.tree_eval(layers)
Sequential(
layers=(
Identity(),
Expand Down Expand Up @@ -153,11 +153,11 @@ class Dropout2D(DropoutND):
[2. 2. 2. 2. 2.]]]
Note:
Use :func:`.tree_evaluation` to turn off dropout during evaluation.
Use :func:`.tree_eval` to turn off dropout during evaluation.
>>> import serket as sk
>>> layers = sk.nn.Sequential(sk.nn.Dropout2D(0.5), sk.nn.Linear(10, 10))
>>> sk.tree_evaluation(layers)
>>> sk.tree_eval(layers)
Sequential(
layers=(
Identity(),
Expand Down Expand Up @@ -198,11 +198,11 @@ class Dropout3D(DropoutND):
[2. 2.]]]]
Note:
Use :func:`.tree_evaluation` to turn off dropout during evaluation.
Use :func:`.tree_eval` to turn off dropout during evaluation.
>>> import serket as sk
>>> layers = sk.nn.Sequential(sk.nn.Dropout2D(0.5), sk.nn.Linear(10, 10))
>>> sk.tree_evaluation(layers)
>>> sk.tree_eval(layers)
Sequential(
layers=(
Identity(),
Expand Down Expand Up @@ -311,7 +311,7 @@ class RandomCutout1D(sk.TreeClass):
fill_value: ``fill_value`` to fill the cutout region. Defaults to 0.
Note:
Use :func:`.tree_evaluation` to turn off the cutout during evaluation.
Use :func:`.tree_eval` to turn off the cutout during evaluation.
Examples:
>>> import jax.numpy as jnp
Expand Down Expand Up @@ -353,7 +353,7 @@ class RandomCutout2D(sk.TreeClass):
fill_value: ``fill_value`` to fill the cutout region. Defaults to 0.
Note:
Use :func:`.tree_evaluation` to turn off the cutout during evaluation.
Use :func:`.tree_eval` to turn off the cutout during evaluation.
Reference:
- https://arxiv.org/abs/1708.04552
Expand All @@ -380,10 +380,9 @@ def spatial_ndim(self) -> int:
return 2


@tree_evaluation.def_evaluation(RandomCutout1D)
@tree_evaluation.def_evaluation(RandomCutout2D)
@tree_evaluation.def_evaluation(Dropout)
@tree_evaluation.def_evaluation(DropoutND)
@tree_eval.def_eval(RandomCutout1D)
@tree_eval.def_eval(RandomCutout2D)
@tree_eval.def_eval(Dropout)
@tree_eval.def_eval(DropoutND)
def dropout_evaluation(_) -> Identity:
# dropout is a no-op during evaluation
return Identity()
6 changes: 3 additions & 3 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import serket as sk
from serket.nn.convolution import DepthwiseConv2D, DepthwiseFFTConv2D
from serket.nn.custom_transform import tree_evaluation
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.nn.utils import positive_int_cb, validate_axis_shape, validate_spatial_ndim

Expand Down Expand Up @@ -376,6 +376,6 @@ def spatial_ndim(self) -> int:
return 2


@tree_evaluation.def_evaluation(RandomContrast2D)
def tree_evaluation_random_contrast2d(_: RandomContrast2D):
@tree_eval.def_eval(RandomContrast2D)
def tree_eval_random_contrast2d(_: RandomContrast2D):
return Identity()
8 changes: 4 additions & 4 deletions serket/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.custom_batching import custom_vmap

import serket as sk
from serket.nn.custom_transform import tree_evaluation, tree_state
from serket.nn.custom_transform import tree_eval, tree_state
from serket.nn.initialization import InitType, resolve_init_func
from serket.nn.utils import Range, ScalarLike, positive_int_cb

Expand Down Expand Up @@ -323,7 +323,7 @@ class BatchNorm(sk.TreeClass):
- ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean``
- ``running_var = momentum * running_var + (1 - momentum) * batch_var``
For evaluation, use :func:`.tree_evaluation` to convert the layer to
For evaluation, use :func:`.tree_eval` to convert the layer to
:class:`nn.EvalNorm`.
Expand Down Expand Up @@ -452,7 +452,7 @@ class EvalNorm(sk.TreeClass):
>>> x = jax.random.uniform(jax.random.PRNGKey(0), shape=(5, 10))
>>> x, state = jax.vmap(bn, in_axes=(0, None))(x, state)
>>> # convert to evaluation mode
>>> bn = sk.tree_evaluation(bn)
>>> bn = sk.tree_eval(bn)
>>> x, state = jax.vmap(bn, in_axes=(0, None))(x, state)
Note:
Expand Down Expand Up @@ -498,7 +498,7 @@ def __call__(
return x, state


@tree_evaluation.def_evaluation(BatchNorm)
@tree_eval.def_eval(BatchNorm)
def _(batchnorm: BatchNorm) -> EvalNorm:
return EvalNorm(
in_features=batchnorm.in_features,
Expand Down
8 changes: 4 additions & 4 deletions serket/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,23 +1087,23 @@ def scan_func(carry, x):


@tree_state.def_state(SimpleRNNCell)
def simple_rnn_init_state(cell: SimpleRNNCell, _) -> SimpleRNNState:
def simple_rnn_init_state(cell: SimpleRNNCell) -> SimpleRNNState:
return SimpleRNNState(jnp.zeros([cell.hidden_features]))


@tree_state.def_state(DenseCell)
def dense_init_state(cell: DenseCell, _) -> DenseState:
def dense_init_state(cell: DenseCell) -> DenseState:
return DenseState(jnp.empty([cell.hidden_features]))


@tree_state.def_state(LSTMCell)
def lstm_init_state(cell: LSTMCell, _) -> LSTMState:
def lstm_init_state(cell: LSTMCell) -> LSTMState:
shape = [cell.hidden_features]
return LSTMState(jnp.zeros(shape), jnp.zeros(shape))


@tree_state.def_state(GRUCell)
def gru_init_state(cell: GRUCell, _) -> GRUState:
def gru_init_state(cell: GRUCell) -> GRUState:
return GRUState(jnp.zeros([cell.hidden_features]))


Expand Down
10 changes: 5 additions & 5 deletions serket/nn/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import jax.random as jr

import serket as sk
from serket.nn.custom_transform import tree_evaluation
from serket.nn.custom_transform import tree_eval
from serket.nn.linear import Identity
from serket.nn.utils import (
IsInstance,
Expand Down Expand Up @@ -655,9 +655,9 @@ def spatial_ndim(self) -> int:
return 2


@tree_evaluation.def_evaluation(RandomCrop1D)
@tree_evaluation.def_evaluation(RandomCrop2D)
@tree_evaluation.def_evaluation(RandomCrop3D)
@tree_evaluation.def_evaluation(RandomZoom2D)
@tree_eval.def_eval(RandomCrop1D)
@tree_eval.def_eval(RandomCrop2D)
@tree_eval.def_eval(RandomCrop3D)
@tree_eval.def_eval(RandomZoom2D)
def random_transform_eval(_) -> Identity:
return Identity()
2 changes: 1 addition & 1 deletion tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,6 @@ def test_batchnorm(axis):
npt.assert_allclose(bn_keras.moving_variance, state.running_var, rtol=1e-5)

x_keras = bn_keras(x_keras, training=False)
x_sk, _ = jax.vmap(sk.tree_evaluation(bn_sk), in_axes=(0, None))(x_sk, state)
x_sk, _ = jax.vmap(sk.tree_eval(bn_sk), in_axes=(0, None))(x_sk, state)

npt.assert_allclose(x_keras, x_sk, rtol=1e-5)

0 comments on commit 4055028

Please sign in to comment.