diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 8e9e36f..7a52828 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -1,9 +1,5 @@ name: pypi - -on: - release: - types: [created] - +on: workflow_dispatch jobs: deploy: runs-on: ubuntu-latest @@ -16,11 +12,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + pip install build + pip install twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | - python setup.py sdist bdist_wheel - twine upload dist/* + python -m build + python -m twine upload dist/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a37e82e..85f104b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,8 +22,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install git+https://github.com/ASEM000/PyTreeClass - python -m pip install tensorflow + python -m pip install pytreeclass>=0.4.0 + python -m pip install keras_core>=0.1.1 python -m pip install pytest wheel optax jaxlib coverage kernex - name: Pytest Check run: | diff --git a/serket/__init__.py b/serket/__init__.py index e3ca36b..8d7cffb 100644 --- a/serket/__init__.py +++ b/serket/__init__.py @@ -27,6 +27,7 @@ is_tree_equal, tree_diagram, tree_flatten_with_trace, + tree_graph, tree_leaves_with_trace, tree_map_with_trace, tree_mask, @@ -40,6 +41,8 @@ ) from . import nn +from .nn.evaluation import tree_evaluation +from .nn.state import tree_state __all__ = ( # general utils @@ -49,6 +52,7 @@ "fields", # pprint utils "tree_diagram", + "tree_graph", "tree_mermaid", "tree_repr", "tree_str", @@ -72,7 +76,9 @@ "Partial", # serket "nn", + "tree_evaluation", + "tree_state", ) -__version__ = "0.2.0b7" +__version__ = "0.4.0b1" diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 6330869..d685f35 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -86,16 +86,8 @@ from .flatten import Flatten, Unflatten from .flip import FlipLeftRight2D, FlipUpDown2D from .fully_connected import FNN, MLP -from .linear import ( - Bilinear, - Embedding, - GeneralLinear, - Identity, - Linear, - MergeLinear, - Multilinear, -) -from .normalization import GroupNorm, InstanceNorm, LayerNorm +from .linear import Bilinear, Embedding, GeneralLinear, Identity, Linear, Multilinear +from .normalization import BatchNorm, GroupNorm, InstanceNorm, LayerNorm from .padding import Pad1D, Pad2D, Pad3D from .pooling import ( AdaptiveAvgPool1D, @@ -149,7 +141,6 @@ "Multilinear", "GeneralLinear", "Embedding", - "MergeLinear", # Dropout "Dropout", "Dropout1D", @@ -215,6 +206,7 @@ "LayerNorm", "InstanceNorm", "GroupNorm", + "BatchNorm", # Blur "AvgBlur2D", "GaussianBlur2D", diff --git a/serket/nn/contrast.py b/serket/nn/contrast.py index 0e9238c..d96af36 100644 --- a/serket/nn/contrast.py +++ b/serket/nn/contrast.py @@ -95,9 +95,10 @@ def __init__(self, contrast_range=(0.5, 1)): and len(contrast_range) == 2 and contrast_range[0] <= contrast_range[1] ): - msg = "contrast_range must be a tuple of two floats, " - msg += "with the first one smaller than the second one." - raise ValueError(msg) + raise ValueError( + "`contrast_range` must be a tuple of two floats, " + "with the first one smaller than the second one." + ) self.contrast_range = contrast_range diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 9aba873..3c50438 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -81,24 +81,20 @@ def __init__( self.spatial_ndim, name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) + self.groups = positive_int_cb(groups) if self.out_features % self.groups != 0: - raise ValueError( - f"Expected out_features % groups == 0, \n" - f"got {self.out_features % self.groups}" - ) + raise ValueError(f"{(out_features % groups == 0)=}") weight_shape = (out_features, in_features // groups, *self.kernel_size) - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) - if bias_init_func is None: - self.bias = None - else: - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + bias_shape = (out_features, *(1,) * self.spatial_ndim) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -432,24 +428,18 @@ def __init__( self.spatial_ndim, name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) self.groups = positive_int_cb(groups) if self.out_features % self.groups != 0: - raise ValueError( - "Expected out_features % groups == 0," - f"got {self.out_features % self.groups}" - ) + raise ValueError(f"{(self.out_features % self.groups ==0)=}") weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) - if bias_init_func is None: - self.bias = None - else: - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + bias_shape = (out_features, *(1,) * self.spatial_ndim) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -774,19 +764,18 @@ def __init__( self.padding = padding # delayed canonicalization self.input_dilation = canonicalize(1, self.spatial_ndim, name="input_dilation") self.kernel_dilation = canonicalize( - 1, self.spatial_ndim, name="kernel_dilation" + 1, + self.spatial_ndim, + name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) - if bias_init_func is None: - self.bias = None - else: - bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -1359,8 +1348,8 @@ def __init__( self.spatial_ndim, name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) out_size = calculate_convolution_output_shape( shape=self.in_size, @@ -1376,14 +1365,10 @@ def __init__( *out_size, ) - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) bias_shape = (self.out_features, *out_size) - - if bias_init_func is None: - self.bias = None - else: - self.bias = self.bias_init_func(key, bias_shape) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) diff --git a/serket/nn/dropout.py b/serket/nn/dropout.py index ff40bfd..721e2eb 100644 --- a/serket/nn/dropout.py +++ b/serket/nn/dropout.py @@ -22,6 +22,8 @@ from jax import lax import serket as sk +from serket.nn.evaluation import tree_evaluation +from serket.nn.linear import Identity from serket.nn.utils import Range, validate_spatial_ndim @@ -38,7 +40,7 @@ class Dropout(sk.TreeClass): >>> import jax.numpy as jnp >>> layer = sk.nn.Dropout(0.5) >>> # change `p` to 0.0 to turn off dropout - >>> layer = layer.at["p"].set(0.0, is_leaf=pytc.is_frozen) + >>> layer = layer.at["p"].set(0.0, is_leaf=sk.is_frozen) Note: Use `p`= 0.0 to turn off dropout. @@ -157,3 +159,10 @@ def __init__(self, p: float = 0.5): @property def spatial_ndim(self) -> int: return 3 + + +@tree_evaluation.def_evalutation(Dropout) +@tree_evaluation.def_evalutation(DropoutND) +def dropout_evaluation(_): + # dropout is a no-op during evaluation + return Identity() diff --git a/serket/nn/evaluation.py b/serket/nn/evaluation.py new file mode 100644 index 0000000..a3206c4 --- /dev/null +++ b/serket/nn/evaluation.py @@ -0,0 +1,58 @@ +# Copyright 2023 Serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Define dispatchers for custom tree evaluation.""" + +from __future__ import annotations + +import functools as ft +from typing import Any, Callable, TypeVar + +import jax + +T = TypeVar("T") + + +def tree_evaluation(tree: T) -> T: + """Modify tree layers to disable any trainning related behavior. + + For example, `Dropout` layers drop probability is set to 0.0. and `BatchNorm` + layer `track_running_stats` is set to False when evaluating the tree. + + Args: + tree: A tree of layers. + + Returns: + A tree of layers with evaluation behavior. + + Example: + >>> # dropout is replaced by an identity layer in evaluation mode + >>> # by registering `tree_evaluation.def_evaluation(sk.nn.Dropout, sk.nn.Identity)` + >>> import jax.numpy as jnp + >>> import serket as sk + >>> layer = sk.nn.Dropout(0.5) + >>> sk.tree_evaluation(layer) + Identity() + """ + + def is_leaf(x: Callable[[Any], bool]) -> bool: + types = set(tree_evaluation.evaluation_dispatcher.registry.keys()) + types.discard(object) + return isinstance(x, tuple(types)) + + return jax.tree_map(tree_evaluation.evaluation_dispatcher, tree, is_leaf=is_leaf) + + +tree_evaluation.evaluation_dispatcher = ft.singledispatch(lambda x: x) +tree_evaluation.def_evalutation = tree_evaluation.evaluation_dispatcher.register diff --git a/serket/nn/fft_convolution.py b/serket/nn/fft_convolution.py index c49b12f..6fc15c5 100644 --- a/serket/nn/fft_convolution.py +++ b/serket/nn/fft_convolution.py @@ -179,8 +179,8 @@ def __init__( self.spatial_ndim, name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) self.groups = positive_int_cb(groups) if self.out_features % self.groups != 0: @@ -188,13 +188,10 @@ def __init__( raise ValueError(msg) weight_shape = (out_features, in_features // groups, *self.kernel_size) - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) - if bias_init_func is None: - self.bias = None - else: - bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + bias_shape = (out_features, *(1,) * self.spatial_ndim) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -513,8 +510,8 @@ def __init__( self.spatial_ndim, name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) self.groups = positive_int_cb(groups) if self.in_features % self.groups != 0: @@ -524,13 +521,13 @@ def __init__( ) weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) if bias_init_func is None: self.bias = None else: bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -857,19 +854,18 @@ def __init__( self.padding = padding self.input_dilation = canonicalize(1, self.spatial_ndim, name="input_dilation") self.kernel_dilation = canonicalize( - 1, self.spatial_ndim, name="kernel_dilation" + 1, + self.spatial_ndim, + name="kernel_dilation", ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) - if bias_init_func is None: - self.bias = None - else: - bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) - self.bias = self.bias_init_func(key, bias_shape) + bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) + self.bias = bias_init_func(key, bias_shape) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) diff --git a/serket/nn/initialization.py b/serket/nn/initialization.py index ff838d7..2f91741 100644 --- a/serket/nn/initialization.py +++ b/serket/nn/initialization.py @@ -62,7 +62,7 @@ init_map: dict[str, InitType] = dict(zip(get_args(InitLiteral), inits)) -def resolve_init_func(init_func: str | Callable) -> Callable: +def resolve_init_func(init_func: str | InitFuncType) -> Callable: if isinstance(init_func, FunctionType): return jtu.Partial(init_func) @@ -74,6 +74,6 @@ def resolve_init_func(init_func: str | Callable) -> Callable: raise ValueError(f"value must be one of ({', '.join(init_map.keys())})") if init_func is None: - return None + return jtu.Partial(lambda key, shape, dtype=None: None) raise ValueError("Value must be a string or a function.") diff --git a/serket/nn/linear.py b/serket/nn/linear.py index 5aeea2f..8deb872 100644 --- a/serket/nn/linear.py +++ b/serket/nn/linear.py @@ -110,16 +110,14 @@ def __init__( self.in_features = in_features self.out_features = out_features - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) weight_shape = (*self.in_features, out_features) - self.weight = self.weight_init_func(key, weight_shape) + self.weight = weight_init_func(key, weight_shape) self.bias = ( - None - if bias_init_func is None - else self.bias_init_func(key, (out_features,)) + None if bias_init_func is None else bias_init_func(key, (out_features,)) ) def __call__(self, *x, **k) -> jax.Array: @@ -239,15 +237,10 @@ def __init__( f"got {len(in_axes)=} and {len(in_features)=}" ) - self.weight_init_func = resolve_init_func(weight_init_func) - self.bias_init_func = resolve_init_func(bias_init_func) - self.weight = self.weight_init_func(key, (*self.in_features, self.out_features)) - - self.bias = ( - None - if self.bias_init_func is None - else self.bias_init_func(key, (self.out_features,)) - ) + weight_init_func = resolve_init_func(weight_init_func) + bias_init_func = resolve_init_func(bias_init_func) + self.weight = weight_init_func(key, (*self.in_features, self.out_features)) + self.bias = bias_init_func(key, (self.out_features,)) def __call__(self, x: jax.Array, **k) -> jax.Array: # ensure negative axes @@ -305,45 +298,3 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: raise TypeError("Input must be an integer array.") return jnp.take(self.weight, x, axis=0) - - -class MergeLinear(sk.TreeClass): - """Merge multiple linear layers with the same `out_features`. - - Args: - layers: linear layers to merge - - Example: - >>> import serket as sk - >>> import numpy.testing as npt - >>> layer1 = sk.nn.Linear(5, 6) # 5 input features, 6 output features - >>> layer2 = sk.nn.Linear(7, 6) # 7 input features, 6 output features - >>> merged_layer = sk.nn.MergeLinear(layer1, layer2) # 12 input features, 6 output features - >>> x1 = jnp.ones([1, 5]) # 1 sample, 5 features - >>> x2 = jnp.ones([1, 7]) # 1 sample, 7 features - >>> y = merged_layer(x1, x2) # one matrix multiplication - >>> z = layer1(x1) + layer2(x2) # two matrix multiplications - >>> npt.assert_allclose(y, z, atol=1e-6) - - Note: - Use this layer to reduce the matrix multiplication operations in the forward pass. - """ - - def __init__(self, *layers: tuple[Linear, ...]): - out_dim0 = layers[0].out_features - if not all(isinstance(layer, Linear) for layer in layers): - raise TypeError("All layers must be instances of Linear.") - - for layer in layers[1:]: - if layer.out_features != out_dim0: - raise ValueError( - "All layers must have the same output dimension." - f" Got {out_dim0} and {layer.out_features}" - ) - - self.weight = jnp.concatenate([L.weight for L in layers], axis=0) - self.bias = sum([L.bias for L in layers if L.bias_init_func]) - - def __call__(self, *xs: tuple[jax.Array, ...], **k) -> jax.Array: - xs = jnp.concatenate(xs, axis=-1) - return xs @ self.weight + self.bias diff --git a/serket/nn/normalization.py b/serket/nn/normalization.py index 69cecd3..00f80ca 100644 --- a/serket/nn/normalization.py +++ b/serket/nn/normalization.py @@ -14,14 +14,16 @@ from __future__ import annotations -from typing import NamedTuple - import jax import jax.numpy as jnp +import jax.random as jr from jax.custom_batching import custom_vmap import serket as sk -from serket.nn.utils import IsInstance, Range, ScalarLike, positive_int_cb +from serket.nn.evaluation import tree_evaluation +from serket.nn.initialization import InitType, resolve_init_func +from serket.nn.state import tree_state +from serket.nn.utils import Range, ScalarLike, positive_int_cb def layer_norm( @@ -49,8 +51,12 @@ def layer_norm( σ_2 = jnp.var(x, axis=dims, keepdims=True) x̂ = (x - μ) * jax.lax.rsqrt((σ_2 + eps)) - if gamma is not None and beta is not None: - return x̂ * gamma + beta + if gamma is not None: + x̂ = x̂ * gamma + + if beta is not None: + x̂ = x̂ + beta + return x̂ @@ -80,10 +86,13 @@ def group_norm( x̂ = (xx - μ) * jax.lax.rsqrt((σ_2 + eps)) x̂ = x̂.reshape(*x.shape) - if gamma is not None and beta is not None: + if gamma is not None: gamma = jnp.expand_dims(gamma, axis=range(1, x.ndim)) + x̂ *= gamma + + if beta is not None: beta = jnp.expand_dims(beta, axis=range(1, x.ndim)) - x̂ = x̂ * gamma + beta + x̂ += beta return x̂ @@ -95,17 +104,23 @@ class LayerNorm(sk.TreeClass): Args: normalized_shape: the shape of the input to be normalized. eps: a value added to the denominator for numerical stability. - affine: a boolean value that when set to True, this module has learnable affine parameters. + gamma_init_func: a function to initialize the scale. Defaults to ones. + if None, the scale is not trainable. + beta_init_func: a function to initialize the shift. Defaults to zeros. + if None, the shift is not trainable. + key: a random key for initialization. Defaults to jax.random.PRNGKey(0). """ - eps: float = sk.field(callbacks=[Range(0), ScalarLike()]) + eps: float = sk.field(callbacks=[Range(0, min_inclusive=False), ScalarLike()]) def __init__( self, normalized_shape: int | tuple[int, ...], *, eps: float = 1e-5, - affine: bool = True, + gamma_init_func: InitType = "ones", + beta_init_func: InitType = "zeros", + key: jr.KeyArray = jr.PRNGKey(0), ): self.normalized_shape = ( normalized_shape @@ -113,11 +128,8 @@ def __init__( else (normalized_shape,) ) self.eps = eps - self.affine = affine - - # make gamma and beta trainable - self.gamma = jnp.ones(normalized_shape) if self.affine else None - self.beta = jnp.zeros(normalized_shape) if self.affine else None + self.gamma = resolve_init_func(gamma_init_func)(key, self.normalized_shape) + self.beta = resolve_init_func(beta_init_func)(key, self.normalized_shape) def __call__(self, x: jax.Array, **kwargs) -> jax.Array: return layer_norm( @@ -138,7 +150,11 @@ class GroupNorm(sk.TreeClass): in_features : the shape of the input to be normalized. groups : number of groups to separate the channels into. eps : a value added to the denominator for numerical stability. - affine : a boolean value that when set to True, this module has learnable affine parameters. + gamma_init_func: a function to initialize the scale. Defaults to ones. + if None, the scale is not trainable. + beta_init_func: a function to initialize the shift. Defaults to zeros. + if None, the shift is not trainable. + key: a random key for initialization. Defaults to jax.random.PRNGKey(0). """ eps: float = sk.field(callbacks=[Range(0), ScalarLike()]) @@ -149,21 +165,20 @@ def __init__( *, groups: int, eps: float = 1e-5, - affine: bool = True, + gamma_init_func: InitType = "ones", + beta_init_func: InitType = "zeros", + key: jr.KeyArray = jr.PRNGKey(0), ): self.in_features = positive_int_cb(in_features) self.groups = positive_int_cb(groups) - self.affine = affine self.eps = eps # needs more info for checking if in_features % groups != 0: - msg = f"in_features must be divisible by groups. Got {in_features} and {groups}" - raise ValueError(msg) + raise ValueError(f"{in_features} must be divisible by {groups=}.") - # make gamma and beta trainable - self.gamma = jnp.ones(self.in_features) if self.affine else None - self.beta = jnp.zeros(self.in_features) if self.affine else None + self.gamma = resolve_init_func(gamma_init_func)(key, (in_features,)) + self.beta = resolve_init_func(beta_init_func)(key, (in_features,)) def __call__(self, x: jax.Array, **k) -> jax.Array: return group_norm( @@ -183,7 +198,11 @@ class InstanceNorm(GroupNorm): Args: in_features : the shape of the input to be normalized. eps : a value added to the denominator for numerical stability. - affine : a boolean value that when set to True, this module has learnable affine parameters. + gamma_init_func: a function to initialize the scale. Defaults to ones. + if None, the scale is not trainable. + beta_init_func: a function to initialize the shift. Defaults to zeros. + if None, the shift is not trainable. + key: a random key for initialization. Defaults to jax.random.PRNGKey(0). """ def __init__( @@ -191,33 +210,86 @@ def __init__( in_features: int, *, eps: float = 1e-5, - affine: bool = True, + gamma_init_func: InitType = "ones", + beta_init_func: InitType = "zeros", + key: jr.KeyArray = jr.PRNGKey(0), ): super().__init__( in_features=in_features, groups=in_features, eps=eps, - affine=affine, + gamma_init_func=gamma_init_func, + beta_init_func=beta_init_func, + key=key, ) -class BatchNormState(NamedTuple): +class BatchNormState(sk.TreeClass): running_mean: jax.Array running_var: jax.Array +def _batchnorm_impl( + x: jax.Array, + state: BatchNormState, + momentum: float = 0.1, + eps: float = 1e-3, + gamma: jax.Array = None, + beta: jax.Array = None, + evalution: bool = False, + axis: int = 1, +): + # reduce over axis=1 + broadcast_shape = [1] * x.ndim + broadcast_shape[axis] = x.shape[axis] + + def bn_eval_step(x, state): + run_mean, run_var = state.running_mean, state.running_var + run_mean = jnp.reshape(run_mean, broadcast_shape) + run_var = jnp.reshape(run_var, broadcast_shape) + output = (x - run_mean) / jnp.sqrt(run_var + eps) + + return output, state + + def bn_train_step(x, state): + # maybe support axes option + run_mean, run_var = state.running_mean, state.running_var + axes = list(range(x.ndim)) + with jax.ensure_compile_time_eval(): + del axes[axis] + batch_mean = jnp.mean(x, axis=axes, keepdims=True) + batch_var = jnp.mean(jnp.square(x), axis=axes, keepdims=True) - batch_mean**2 + output = (x - batch_mean) / jnp.sqrt(batch_var + eps) + run_mean = momentum * run_mean + (1 - momentum) * jnp.squeeze(batch_mean) + run_var = momentum * run_var + (1 - momentum) * jnp.squeeze(batch_var) + return output, BatchNormState(run_mean, run_var) + + output, state = jax.lax.cond(evalution, bn_eval_step, bn_train_step, x, state) + + state = jax.lax.stop_gradient(state) + + if gamma is not None: + output *= jnp.reshape(gamma, broadcast_shape) + + if beta is not None: + output += jnp.reshape(beta, broadcast_shape) + + return output, state + + @custom_vmap def batchnorm( x: jax.Array, - state: tuple[jax.Array, jax.Array], - *, + state: BatchNormState, momentum: float = 0.1, eps: float = 1e-5, gamma: jax.Array | None = None, beta: jax.Array | None = None, - track_running_stats: bool = False, -): - del momentum, eps, gamma, beta, track_running_stats + evaluation: bool = False, + axis: int = 1, +) -> tuple[jax.Array, BatchNormState]: + del momentum, eps, gamma, beta, evaluation, axis + # no-op when unbatched return x, state @@ -226,40 +298,106 @@ def _( axis_size, in_batched, x: jax.Array, - state: tuple[jax.Array, jax.Array], - *, - momentum: float = 0.1, - eps: float = 1e-5, - track_running_stats: bool = True, -): - run_mean, run_var = state + state: BatchNormState, + momentum: float = 0.99, + eps: float = 1e-3, + gamma: jax.Array | None = None, + beta: jax.Array | None = None, + evaluation: bool = True, + axis: int = 1, +) -> tuple[jax.Array, BatchNormState]: + output = _batchnorm_impl( + x=x, + state=state, + momentum=momentum, + eps=eps, + gamma=gamma, + beta=beta, + evalution=evaluation, + axis=axis, + ) + return output, (True, BatchNormState(True, True)) - axes = [0] + list(range(2, x.ndim)) - batch_mean, batch_var = jnp.mean(x, axis=axes), jnp.var(x, axis=axes) +class BatchNorm(sk.TreeClass): + """Applies normalization over batched inputs` - run_mean = jnp.where( - track_running_stats, - (1 - momentum) * run_mean + momentum * batch_mean, - batch_mean, - ) + Works under ``jax.vmap(BatchNorm(...), in_axes=(0, None))``, otherwise will be a no-op. - run_var = jnp.where( - track_running_stats, - (1 - momentum) * run_var + momentum * batch_var * (axis_size / (axis_size - 1)), - batch_var, - ) - x_normalized = (x - batch_mean) * jax.lax.rsqrt(batch_var + eps) - return (x_normalized, (run_mean, run_var)), (True, (True, True)) + Evaluation behavior: + ``output = (x - running_mean) / sqrt(running_var + eps)`` + Training behavior: + ``output = (x - batch_mean) / sqrt(batch_var + eps)`` + ``running_mean = momentum * running_mean + (1 - momentum) * batch_mean`` + ``running_var = momentum * running_var + (1 - momentum) * batch_var`` -class BatchNorm(sk.TreeClass): - in_features: int = sk.field(callbacks=[IsInstance(int), Range(1)]) - momentum: float = sk.field(callbacks=[Range(0, 1), ScalarLike()]) - eps: float = sk.field(callbacks=[Range(0), ScalarLike()]) - track_running_stats: bool = sk.field(callbacks=[IsInstance(bool)]) + Args: + in_features : the shape of the input to be normalized. + momentum : the value used for the ``running_mean`` and ``running_var`` + computation. must be a number between ``0`` and ``1``. + eps : a value added to the denominator for numerical stability. + gamma_init_func: a function to initialize the scale. Defaults to ones. + if None, the scale is not trainable. + beta_init_func: a function to initialize the shift. Defaults to zeros. + if None, the shift is not trainable. + axis: the axis that should be normalized. Defaults to 1. + evaluation : a boolean value that when set to True, this module will run in + evaluation mode. In this case, this module will always use the running + estimates of the batch statistics during training. + + Note: + https://keras.io/api/layers/normalization_layers/batch_normalization/ + """ - def __post_init__(self): - self.state = BatchNormState( - jnp.zeros(self.in_features), jnp.ones(self.in_features) + def __init__( + self, + in_features: int, + *, + momentum: float = 0.99, + eps: float = 1e-3, + gamma_init_func: InitType = "ones", + beta_init_func: InitType = "zeros", + axis: int = 1, + evaluation: bool = False, + key: jr.KeyArray = jr.PRNGKey(0), + ) -> None: + self.in_features = in_features + self.momentum = momentum + self.eps = eps + self.gamma = resolve_init_func(gamma_init_func)(key, (in_features,)) + self.beta = resolve_init_func(beta_init_func)(key, (in_features,)) + self.axis = axis + self.evaluation = evaluation + + def __call__( + self, + x: jax.Array, + state: BatchNormState | None = None, + **k, + ) -> jax.Array: + state = sk.tree_state(self) if state is None else state + + x, state = batchnorm( + x, + state, + self.momentum, + self.eps, + self.gamma, + self.beta, + self.evaluation, + self.axis, ) + return x, state + + +@tree_evaluation.def_evalutation(BatchNorm) +def _(batchnorm: BatchNorm) -> BatchNorm: + return batchnorm.at["evaluation"].set(True) + + +@tree_state.def_state(BatchNorm) +def batchnorm_init_state(batchnorm: BatchNorm, _) -> BatchNormState: + running_mean = jnp.zeros([batchnorm.in_features]) + running_var = jnp.ones([batchnorm.in_features]) + return BatchNormState(running_mean, running_var) diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 84832a9..6d33cb2 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -25,6 +25,7 @@ import serket as sk from serket.nn.activation import ActivationType, resolve_activation from serket.nn.initialization import InitType +from serket.nn.state import tree_state from serket.nn.utils import ( DilationType, KernelSizeType, @@ -62,13 +63,6 @@ class RNNCell(sk.TreeClass): def __call__(self, x: jax.Array, state: RNNState, **k) -> RNNState: ... - @abc.abstractclassmethod - def init_state(self, spatial_shape: tuple[int, ...]) -> RNNState: - # return the initial state of the RNN for a given input - # for non-spatial RNNs, output shape is (hidden_features,) - # for spatial RNNs, output shape is (hidden_features, *spatial_shape) - ... - @property @abc.abstractclassmethod def spatial_ndim(self) -> int: @@ -104,8 +98,10 @@ def __init__( key: the key to use to initialize the weights Example: + >>> import serket as sk + >>> import jax.numpy as jnp >>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state - >>> rnn_state = cell.init_state() # 20-dimensional hidden state + >>> rnn_state = sk.tree_state(cell) # 20-dimensional hidden state >>> x = jnp.ones((10,)) # 10 features >>> result = cell(x, rnn_state) >>> result.hidden_state.shape # 20 features @@ -120,7 +116,7 @@ def __init__( self.hidden_features = positive_int_cb(hidden_features) self.act_func = resolve_activation(act_func) - in_to_hidden = sk.nn.Linear( + i2h = sk.nn.Linear( in_features, hidden_features, weight_init_func=weight_init_func, @@ -128,7 +124,7 @@ def __init__( key=k1, ) - hidden_to_hidden = sk.nn.Linear( + h2h = sk.nn.Linear( hidden_features, hidden_features, weight_init_func=recurrent_weight_init_func, @@ -136,7 +132,8 @@ def __init__( key=k2, ) - self.in_and_hidden_to_hidden = sk.nn.MergeLinear(in_to_hidden, hidden_to_hidden) + self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0) + self.ih2h_bias = i2h.bias @property def spatial_ndim(self) -> int: @@ -148,13 +145,14 @@ def __call__(self, x: jax.Array, state: SimpleRNNState, **k) -> SimpleRNNState: if not isinstance(state, SimpleRNNState): raise TypeError(f"Expected {state=} to be an instance of `SimpleRNNState`") - h = self.act_func(self.in_and_hidden_to_hidden(x, state.hidden_state)) - return SimpleRNNState(h) + ih = jnp.concatenate([x, state.hidden_state], axis=-1) + h = ih @ self.ih2h_weight + self.ih2h_bias + return SimpleRNNState(self.act_func(h)) - def init_state(self, spatial_dim: tuple[int, ...] = ()) -> SimpleRNNState: - del spatial_dim - shape = (self.hidden_features,) - return SimpleRNNState(jnp.zeros(shape)) + +@tree_state.def_state(SimpleRNNCell) +def simple_rnn_init_state(cell: SimpleRNNCell, _) -> SimpleRNNState: + return SimpleRNNState(jnp.zeros([cell.hidden_features])) class DenseState(RNNState): @@ -174,8 +172,10 @@ class DenseCell(RNNCell): key: the key to use to initialize the weights Example: + >>> import serket as sk + >>> import jax.numpy as jnp >>> cell = DenseCell(10, 20) # 10-dimensional input, 20-dimensional hidden state - >>> dummy_state = cell.init_state() # 20-dimensional hidden state + >>> dummy_state = sk.tree_state(cell) # 20-dimensional hidden state >>> x = jnp.ones((10,)) # 10 features >>> result = cell(x, dummy_state) >>> result.hidden_state.shape # 20 features @@ -217,10 +217,10 @@ def __call__(self, x: jax.Array, state: DenseState, **k) -> DenseState: h = self.act_func(self.in_to_hidden(x)) return DenseState(h) - def init_state(self, spatial_dim: tuple[int, ...] = ()) -> DenseState: - del spatial_dim - shape = (self.hidden_features,) - return DenseState(jnp.empty(shape)) # dummy state + +@tree_state.def_state(DenseCell) +def dense_init_state(cell: DenseCell, _) -> DenseState: + return DenseState(jnp.empty([cell.hidden_features])) class LSTMState(RNNState): @@ -264,7 +264,7 @@ def __init__( self.act_func = resolve_activation(act_func) self.recurrent_act_func = resolve_activation(recurrent_act_func) - in_to_hidden = sk.nn.Linear( + i2h = sk.nn.Linear( in_features, hidden_features * 4, weight_init_func=weight_init_func, @@ -272,7 +272,7 @@ def __init__( key=k1, ) - hidden_to_hidden = sk.nn.Linear( + h2h = sk.nn.Linear( hidden_features, hidden_features * 4, weight_init_func=recurrent_weight_init_func, @@ -280,7 +280,8 @@ def __init__( key=k2, ) - self.in_and_hidden_to_hidden = sk.nn.MergeLinear(in_to_hidden, hidden_to_hidden) + self.ih2h_weight = jnp.concatenate([i2h.weight, h2h.weight], axis=0) + self.ih2h_bias = i2h.bias @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") @ft.partial(validate_axis_shape, attribute_name="in_features", axis=0) @@ -289,7 +290,8 @@ def __call__(self, x: jax.Array, state: LSTMState, **k) -> LSTMState: raise TypeError(f"Expected {state=} to be an instance of `LSTMState`") h, c = state.hidden_state, state.cell_state - h = self.in_and_hidden_to_hidden(x, h) + ih = jnp.concatenate([x, h], axis=-1) + h = ih @ self.ih2h_weight + self.ih2h_bias i, f, g, o = jnp.split(h, 4, axis=-1) i = self.recurrent_act_func(i) f = self.recurrent_act_func(f) @@ -309,6 +311,12 @@ def spatial_ndim(self) -> int: return 0 +@tree_state.def_state(LSTMCell) +def lstm_init_state(cell: LSTMCell, _) -> LSTMState: + shape = [cell.hidden_features] + return LSTMState(jnp.zeros(shape), jnp.zeros(shape)) + + class GRUState(RNNState): ... @@ -384,10 +392,10 @@ def __call__(self, x: jax.Array, state: GRUState, **k) -> GRUState: h = (1 - u) * o + u * h return GRUState(hidden_state=h) - def init_state(self, spatial_dim: tuple[int, ...]) -> GRUState: - del spatial_dim - shape = (self.hidden_features,) - return GRUState(jnp.zeros(shape, dtype=jnp.float32)) + +@tree_state.def_state(GRUCell) +def gru_init_state(cell: GRUCell, _) -> GRUState: + return GRUState(jnp.zeros([cell.hidden_features])) # Spatial RNN @@ -487,15 +495,33 @@ def __call__(self, x: jax.Array, state: ConvLSTMNDState, **k) -> ConvLSTMNDState h = o * self.act_func(c) return ConvLSTMNDState(h, c) - def init_state(self, spatial_dim: tuple[int, ...]) -> ConvLSTMNDState: - msg = f"Expected spatial_dim to be a tuple of length {self.spatial_ndim}, got {spatial_dim}" - assert len(spatial_dim) == self.spatial_ndim, msg - shape = (self.hidden_features, *spatial_dim) - return ConvLSTMNDState(jnp.zeros(shape), jnp.zeros(shape)) + +@tree_state.def_state(ConvLSTMNDCell) +def conv_lstm_init_state(cell: ConvLSTMNDCell, x: jax.Array | None) -> ConvLSTMNDState: + if not (hasattr(x, "ndim") and hasattr(x, "shape")): + raise TypeError( + f"Expected {x=} to have ndim and shape attributes.", + "To initialize the `ConvLSTMNDCell` state.\n" + "pass a single sample array to `tree_state` second argument.", + ) + + if x.ndim != cell.spatial_ndim + 1: + raise ValueError( + f"{x.ndim=} != {(cell.spatial_ndim + 1)=}.", + "Expected input to have shape (channel, *spatial_dim)." + "Pass a single sample array to `tree_state", + ) + + spatial_dim = x.shape[1:] + if len(spatial_dim) != cell.spatial_ndim: + raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.") + shape = (cell.hidden_features, *spatial_dim) + return ConvLSTMNDState(jnp.zeros(shape), jnp.zeros(shape)) class ConvLSTM1DCell(ConvLSTMNDCell): """1D Convolution LSTM cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -510,7 +536,6 @@ class ConvLSTM1DCell(ConvLSTMNDCell): act_func: Activation function recurrent_act_func: Recurrent activation function key: PRNG key - spatial_ndim: Number of spatial dimensions. Note: https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D @@ -557,6 +582,7 @@ def spatial_ndim(self) -> int: class ConvLSTM2DCell(ConvLSTMNDCell): """2D Convolution LSTM cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -571,7 +597,6 @@ class ConvLSTM2DCell(ConvLSTMNDCell): act_func: Activation function recurrent_act_func: Recurrent activation function key: PRNG key - spatial_ndim: Number of spatial dimensions. Note: https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D @@ -618,6 +643,7 @@ def spatial_ndim(self) -> int: class ConvLSTM3DCell(ConvLSTMNDCell): """3D Convolution LSTM cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -632,7 +658,6 @@ class ConvLSTM3DCell(ConvLSTMNDCell): act_func: Activation function recurrent_act_func: Recurrent activation function key: PRNG key - spatial_ndim: Number of spatial dimensions. Note: https://www.tensorflow.org/api_docs/python/tf/keras/layers/ConvLSTM1D @@ -683,6 +708,7 @@ class ConvGRUNDState(RNNState): class ConvGRUNDCell(RNNCell): """Convolution GRU cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -698,7 +724,6 @@ class ConvGRUNDCell(RNNCell): recurrent_act_func: Recurrent activation function key: PRNG key spatial_ndim: Number of spatial dimensions. - """ def __init__( @@ -755,7 +780,7 @@ def __init__( @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") def __call__(self, x: jax.Array, state: ConvGRUNDState, **k) -> ConvGRUNDState: if not isinstance(state, ConvGRUNDState): - raise TypeError(f"Expected {state=} to be an instance of GRUState") + raise TypeError(f"Expected {state=} to be an instance of `GRUState`") h = state.hidden_state xe, xu, xo = jnp.split(self.in_to_hidden(x), 3, axis=0) @@ -766,15 +791,35 @@ def __call__(self, x: jax.Array, state: ConvGRUNDState, **k) -> ConvGRUNDState: h = (1 - u) * o + u * h return ConvGRUNDState(hidden_state=h) - def init_state(self, spatial_dim: tuple[int, ...]) -> ConvGRUNDState: - msg = f"Expected spatial_dim to be a tuple of length {self.spatial_ndim}, got {spatial_dim}" - assert len(spatial_dim) == self.spatial_ndim, msg - shape = (self.hidden_features, *spatial_dim) - return ConvGRUNDState(hidden_state=jnp.zeros(shape)) + +@tree_state.def_state(ConvGRUNDCell) +def conv_gru_init_state(cell: ConvGRUNDCell, x: jax.Array | None) -> ConvGRUNDState: + if not (hasattr(x, "ndim") and hasattr(x, "shape")): + # maybe the input is not an array + raise TypeError( + f"Expected {x=} to have ndim and shape attributes.", + "To initialize the `ConvGRUNDCell` state.\n" + "pass a single sample array to `tree_state` second argument.", + ) + + if x.ndim != cell.spatial_ndim + 1: + # channel, *spatial_dim + raise ValueError( + f"{x.ndim=} != {(cell.spatial_ndim + 1)=}.", + "Expected input to have shape (channel, *spatial_dim)." + "Pass a single sample array to `tree_state", + ) + + spatial_dim = x.shape[1:] + if len(spatial_dim) != cell.spatial_ndim: + raise ValueError(f"{len(spatial_dim)=} != {cell.spatial_ndim=}.") + shape = (cell.hidden_features, *spatial_dim) + return ConvGRUNDState(jnp.zeros(shape), jnp.zeros(shape)) class ConvGRU1DCell(ConvGRUNDCell): """1D Convolution GRU cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -790,7 +835,6 @@ class ConvGRU1DCell(ConvGRUNDCell): recurrent_act_func: Recurrent activation function key: PRNG key spatial_ndim: Number of spatial dimensions. - """ def __init__( @@ -893,6 +937,7 @@ def spatial_ndim(self) -> int: class ConvGRU3DCell(ConvGRUNDCell): """3D Convolution GRU cell that defines the update rule for the hidden state and cell state + Args: in_features: Number of input features hidden_features: Number of output features @@ -907,8 +952,6 @@ class ConvGRU3DCell(ConvGRUNDCell): act_func: Activation function recurrent_act_func: Recurrent activation function key: PRNG key - spatial_ndim: Number of spatial dimensions. - """ def __init__( @@ -966,7 +1009,15 @@ class ScanRNN(sk.TreeClass): >>> cell = SimpleRNNCell(10, 20) # 10-dimensional input, 20-dimensional hidden state >>> rnn = ScanRNN(cell) >>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features - >>> result = rnn(x) # 20 features + >>> result, state = rnn(x) # 20 features + >>> print(result.shape) + (20,) + >>> cell = SimpleRNNCell(10, 20) + >>> rnn = ScanRNN(cell, return_sequences=True) + >>> x = jnp.ones((5, 10)) # 5 timesteps, 10 features + >>> result, state = rnn(x) # 5 timesteps, 20 features + >>> print(result.shape) + (5, 20) """ # cell: RNN @@ -994,7 +1045,20 @@ def __call__( state: RNNState | None = None, backward_state: RNNState | None = None, **k, - ) -> jax.Array: + ) -> tuple[jax.Array, tuple[RNNState, RNNState] | RNNState]: + """Scans the RNN cell over a sequence. + + Args: + x: the input sequence. + state: the initial state. if None, a zero state is used. + backward_state: the initial backward state. if None, a zero state is used. + + Returns: + the output sequence and the final two states tuple if backward_cell + is not ``None``, otherwise return the final state of the forward + cell. + """ + if not isinstance(state, (RNNState, type(None))): raise TypeError(f"Expected state to be an instance of RNNState, {state=}") @@ -1013,20 +1077,29 @@ def __call__( f"Expected x to have shape (timesteps, {self.cell.in_features}," f"{'*'*self.cell.spatial_ndim}), got {x.shape=}" ) - - state = state or self.cell.init_state(x.shape[2:]) + # pass a sample not the whole sequence + state = state or tree_state(self.cell, x[0]) if self.backward_cell is not None and backward_state is None: - backward_state = self.backward_cell.init_state(x.shape[2:]) + # pass a sample not the whole sequence + backward_state = tree_state(self.backward_cell, x[0]) scan_func = _accumulate_scan if self.return_sequences else _no_accumulate_scan - result = scan_func(x, self.cell, state) + result, state = scan_func(x, self.cell, state) + + states = state if self.backward_cell is not None: - back_result = scan_func(x, self.backward_cell, backward_state) + backward_result, backward_state = scan_func( + x, + self.backward_cell, + backward_state, + ) + states = (state, backward_state) concat_axis = int(self.return_sequences) - result = jnp.concatenate((result, back_result), axis=concat_axis) - return result + result = jnp.concatenate((result, backward_result), axis=concat_axis) + + return result, states def _accumulate_scan( @@ -1034,14 +1107,17 @@ def _accumulate_scan( cell: RNNCell, state: RNNState, reverse: bool = False, -) -> jax.Array: +) -> tuple[jax.Array, RNNState]: def scan_func(carry, x): state = cell(x, state=carry) return state, state x = jnp.flip(x, axis=0) if reverse else x # flip over time axis result = jax.lax.scan(scan_func, state, x)[1].hidden_state - return jnp.flip(result, axis=-1) if reverse else result + carry, result = jax.lax.scan(scan_func, state, x) + result = result.hidden_state + result = jnp.flip(result, axis=-1) if reverse else result + return result, carry def _no_accumulate_scan( @@ -1055,4 +1131,6 @@ def scan_func(carry, x): return state, None x = jnp.flip(x, axis=0) if reverse else x - return jax.lax.scan(scan_func, state, x)[0].hidden_state + carry, _ = jax.lax.scan(scan_func, state, x) + result = carry.hidden_state + return result, carry diff --git a/serket/nn/state.py b/serket/nn/state.py new file mode 100644 index 0000000..d157a68 --- /dev/null +++ b/serket/nn/state.py @@ -0,0 +1,91 @@ +# Copyright 2023 Serket authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Define dispatchers for custom tree state.""" + +from __future__ import annotations + +import functools as ft +from typing import Any, Callable, TypeVar + +import jax + +import serket as sk + +T = TypeVar("T") + + +class NoState(sk.TreeClass): + """No state placeholder.""" + + def __init__(self, _: Any, __: Any): + del _, __ + + +def tree_state(tree: T, array: jax.Array | None = None) -> T: + """Build state for a tree of layers. + + Some layers require state to be initialized before training. For example, + `BatchNorm` layers require `running_mean` and `running_var` to be initialized + before training. This function initializes the state for a tree of layers, + based on the layer defined ``state`` rule using ``tree_state.def_state``. + + Args: + tree: A tree of layers. + array: An array to use for initializing state required by some layers + (e.g. ConvGRUNDCell). default: ``None``. + + Returns: + A tree of state leaves if it has state, otherwise ``None``. + + Example: + >>> import jax.numpy as jnp + >>> import serket as sk + >>> tree = [1, 2, sk.nn.BatchNorm(5)] + >>> sk.tree_state(tree) + [NoState(), NoState(), BatchNormState( + running_mean=f32[5](μ=0.00, σ=0.00, ∈(0.00,0.00)), + running_var=f32[5](μ=1.00, σ=0.00, ∈(1.00,1.00)) + )] + + Example: + >>> # define state initialization rule for a custom layer + >>> import jax + >>> import serket as sk + >>> class LayerWithState(sk.TreeClass): + ... pass + >>> # state function accept the `layer` and optional input array as arguments + >>> @sk.tree_state.def_state(LayerWithState) + ... def _(leaf, _): + ... del _ # array is not used + ... return "some state" + >>> sk.tree_state(LayerWithState()) + 'some state' + >>> sk.tree_state(LayerWithState(), jax.numpy.ones((1, 1))) + 'some state' + """ + + def is_leaf(x: Callable[[Any], bool]) -> bool: + types = set(tree_state.state_dispatcher.registry.keys()) + types.discard(object) + return isinstance(x, tuple(types)) + + def dispatch_func(node): + return tree_state.state_dispatcher(node, array) + + return jax.tree_map(dispatch_func, tree, is_leaf=is_leaf) + + +tree_state.state_dispatcher = ft.singledispatch(NoState) +tree_state.def_state = tree_state.state_dispatcher.register diff --git a/serket/nn/utils.py b/serket/nn/utils.py index 0ba2a1f..4908498 100644 --- a/serket/nn/utils.py +++ b/serket/nn/utils.py @@ -15,6 +15,7 @@ from __future__ import annotations import functools as ft +import operator as op from typing import Any, Sequence, Tuple, Union import jax @@ -184,11 +185,17 @@ class Range(sk.TreeClass): min_val: float = -float("inf") max_val: float = float("inf") + min_inclusive: bool = True + max_inclusive: bool = True def __call__(self, value: Any): - if self.min_val <= value <= self.max_val: + lop, ls = (op.ge, "[") if self.min_inclusive else (op.gt, "(") + rop, rs = (op.le, "]") if self.max_inclusive else (op.lt, ")") + + if lop(value, self.min_val) and rop(value, self.max_val): return value - raise ValueError(f"Not in range[{self.min_val}, {self.max_val}] got {value=}.") + + raise ValueError(f"Not in {ls}{self.min_val}, {self.max_val}{rs} got {value=}.") class IsInstance(sk.TreeClass): diff --git a/tests/test_linear.py b/tests/test_linear.py index a01b137..6556611 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -26,7 +26,6 @@ GeneralLinear, Identity, Linear, - MergeLinear, Multilinear, ) @@ -144,20 +143,3 @@ def test_general_linear(): with pytest.raises(ValueError): GeneralLinear(in_features=(1,), in_axes=(0, -3), out_features=5) - - -def test_merge_linear(): - layer1 = Linear(5, 6) # 5 input features, 6 output features - layer2 = Linear(7, 6) # 7 input features, 6 output features - merged_layer = MergeLinear(layer1, layer2) # 12 input features, 6 output features - x1 = jnp.ones([1, 5]) # 1 sample, 5 features - x2 = jnp.ones([1, 7]) # 1 sample, 7 features - y = merged_layer(x1, x2) - z = layer1(x1) + layer2(x2) - npt.assert_allclose(y, z, atol=1e-6) - - with pytest.raises(ValueError): - # output features of layer1 and layer2 mismatch - l1 = Linear(5, 6) - l2 = Linear(7, 8) - MergeLinear(l1, l2) diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 8af1d6d..7492e39 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +import jax import jax.numpy as jnp import numpy.testing as npt import pytest -from serket.nn import GroupNorm, InstanceNorm, LayerNorm +import serket as sk +from serket.nn import BatchNorm, GroupNorm, InstanceNorm, LayerNorm + +os.environ["KERAS_BACKEND"] = "jax" def test_LayerNorm(): - layer = LayerNorm((5, 2), affine=False) + layer = LayerNorm((5, 2), beta_init_func=None, gamma_init_func=None) x = jnp.array( [ @@ -96,7 +102,7 @@ def test_InstanceNorm(): npt.assert_allclose(layer(x), y, atol=1e-5) - layer = InstanceNorm(in_features=3, affine=False) + layer = InstanceNorm(in_features=3, gamma_init_func=None, beta_init_func=None) npt.assert_allclose(layer(x), y, atol=1e-5) @@ -210,10 +216,39 @@ def test_group_norm(): layer = GroupNorm(in_features=-1, groups=0) -# def test_lazy_normalization(): -# layer = GroupNorm(None, groups=1) -# assert layer(jnp.ones([1, 2, 3, 4])).shape == (1, 2, 3, 4) +@pytest.mark.parametrize("axis", [0, 1, 2, 3]) +def test_batchnorm(axis): + import math + + from keras_core.layers import BatchNormalization + + mat_jax = lambda n: jnp.arange(1, math.prod(n) + 1).reshape(*n).astype(jnp.float32) + + x_keras = mat_jax((5, 10, 7, 8)) + + bn_keras = BatchNormalization(axis=axis, momentum=0.5, center=False, scale=False) + + for i in range(5): + x_keras = bn_keras(x_keras, training=True) + + bn_sk = BatchNorm( + x_keras.shape[axis], + momentum=0.5, + axis=axis, + beta_init_func=None, + gamma_init_func=None, + ) + state = sk.tree_state(bn_sk) + x_sk = mat_jax((5, 10, 7, 8)) + + for _ in range(5): + x_sk, state = jax.vmap(bn_sk, in_axes=(0, None))(x_sk, state) + + npt.assert_allclose(x_keras, x_sk, atol=1e-5) + npt.assert_allclose(bn_keras.moving_mean, state.running_mean, atol=1e-5) + npt.assert_allclose(bn_keras.moving_variance, state.running_var, rtol=1e-5) + + x_keras = bn_keras(x_keras, training=False) + x_sk, _ = jax.vmap(bn_sk.at["evaluation"].set(True), in_axes=(0, None))(x_sk, state) -# with pytest.raises(ConcretizationTypeError): -# layer = jax.jit(GroupNorm(None, groups=1)) -# layer(jnp.ones([1, 2, 3, 4])) + npt.assert_allclose(x_keras, x_sk, rtol=1e-5) diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 9221723..7f9a0c8 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -114,10 +114,10 @@ def test_vanilla_rnn(): ) w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["in_and_hidden_to_hidden"].at["weight"].set(w_combined) + cell = cell.at["ih2h_weight"].set(w_combined) sk_layer = ScanRNN(cell) y = jnp.array([0.9637042, -0.8282256, 0.7314449]) - npt.assert_allclose(sk_layer(x), y) + npt.assert_allclose(sk_layer(x)[0], y) def test_lstm(): @@ -228,12 +228,13 @@ def test_lstm(): recurrent_weight_init_func="glorot_uniform", ) w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["in_and_hidden_to_hidden"].at["weight"].set(w_combined) - cell = cell.at["in_and_hidden_to_hidden"].at["bias"].set(b_hidden_to_hidden) + cell = cell.at["ih2h_weight"].set(w_combined) + cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) sk_layer = ScanRNN(cell, return_sequences=False) + y = jnp.array([0.18658024, -0.6338659, 0.3445018]) - npt.assert_allclose(y, sk_layer(x), atol=1e-5) + npt.assert_allclose(y, sk_layer(x)[0], atol=1e-5) w_in_to_hidden = jnp.array( [ @@ -327,8 +328,8 @@ def test_lstm(): w_combined = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["in_and_hidden_to_hidden"].at["weight"].set(w_combined) - cell = cell.at["in_and_hidden_to_hidden"].at["bias"].set(b_hidden_to_hidden) + cell = cell.at["ih2h_weight"].set(w_combined) + cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) sk_layer = ScanRNN(cell, return_sequences=True) @@ -347,7 +348,7 @@ def test_lstm(): ] ) - npt.assert_allclose(y, sk_layer(x), atol=1e-5) + npt.assert_allclose(y, sk_layer(x)[0], atol=1e-5) cell = LSTMCell( in_features=in_features, @@ -356,7 +357,7 @@ def test_lstm(): ) sk_layer = ScanRNN(cell, return_sequences=True) - assert sk_layer(x).shape == (10, 3) + assert sk_layer(x)[0].shape == (10, 3) def test_gru(): @@ -418,7 +419,7 @@ def test_gru(): cell = cell.at["in_to_hidden"].at["weight"].set(w1) cell = cell.at["hidden_to_hidden"].at["weight"].set(w2) y = jnp.array([[-0.00142191, 0.11011646, 0.1613554]]) - ypred = ScanRNN(cell, return_sequences=True)(jnp.ones([1, 1])) + ypred, _ = ScanRNN(cell, return_sequences=True)(jnp.ones([1, 1])) npt.assert_allclose(y, ypred, atol=1e-4) @@ -586,7 +587,7 @@ def test_conv_lstm1d(): x = jnp.ones([time_steps, in_features, *spatial_dim]) - res_sk = ScanRNN(cell, return_sequences=False)(x) + res_sk, _ = ScanRNN(cell, return_sequences=False)(x) y = jnp.array( [ @@ -609,7 +610,7 @@ def test_conv_lstm1d(): bias_init_func="zeros", ) - res_sk = ScanRNN(cell, return_sequences=False)(x) + res_sk, _ = ScanRNN(cell, return_sequences=False)(x) assert res_sk.shape == (3, 3) @@ -748,22 +749,16 @@ def test_bilstm(): b_hidden_to_hidden_reverse = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]) combined_w = jnp.concatenate([w_in_to_hidden, w_hidden_to_hidden], axis=0) - cell = cell.at["in_and_hidden_to_hidden"].at["weight"].set(combined_w) - cell = cell.at["in_and_hidden_to_hidden"].at["bias"].set(b_hidden_to_hidden) + cell = cell.at["ih2h_weight"].set(combined_w) + cell = cell.at["ih2h_bias"].set(b_hidden_to_hidden) combined_w_reverse = jnp.concatenate( [w_in_to_hidden_reverse, w_hidden_to_hidden_reverse], axis=0 ) - reverse_cell = ( - reverse_cell.at["in_and_hidden_to_hidden"].at["weight"].set(combined_w_reverse) - ) - reverse_cell = ( - reverse_cell.at["in_and_hidden_to_hidden"] - .at["bias"] - .set(b_hidden_to_hidden_reverse) - ) + reverse_cell = reverse_cell.at["ih2h_weight"].set(combined_w_reverse) + reverse_cell = reverse_cell.at["ih2h_bias"].set(b_hidden_to_hidden_reverse) - res = ScanRNN(cell, backward_cell=reverse_cell, return_sequences=False)(x) + res, _ = ScanRNN(cell, backward_cell=reverse_cell, return_sequences=False)(x) y = jnp.array([0.35901642, 0.00826644, -0.3015435, -0.13661332]) @@ -794,6 +789,6 @@ def test_dense_cell(): bias_init_func=None, ) x = jnp.ones([10, 10]) - res = ScanRNN(cell=cell)(x) + res, _ = ScanRNN(cell=cell)(x) # 1x10 @ 10x10 => 1x10 npt.assert_allclose(res, jnp.ones([10]) * 10.0) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7b45a69..93f6c5c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,7 +38,7 @@ def test_canonicalize_init_func(): assert resolve_init_func("xavier_uniform")(k, (2, 2)).shape == (2, 2) assert isinstance(resolve_init_func(jax.nn.initializers.he_normal()), jtu.Partial) - assert isinstance(resolve_init_func(None), type(None)) + assert isinstance(resolve_init_func(None), jtu.Partial) with pytest.raises(ValueError): resolve_init_func("invalid")