diff --git a/serket/nn/convolution.py b/serket/nn/convolution.py index 4fdeabb..3fec1f3 100644 --- a/serket/nn/convolution.py +++ b/serket/nn/convolution.py @@ -27,7 +27,7 @@ from jax.lax import ConvDimensionNumbers import serket as sk -from serket.nn.initialization import DType, InitType, resolve_init_func +from serket.nn.initialization import DType, InitType, resolve_init from serket.nn.utils import ( DilationType, KernelSizeType, @@ -200,10 +200,10 @@ def __init__( raise ValueError(f"{(out_features % groups == 0)=}") weight_shape = (out_features, in_features // groups, *self.kernel_size) - self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype) + self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype) bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype) + self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @property @abc.abstractmethod @@ -869,10 +869,10 @@ def __init__( in_features = positive_int_cb(self.in_features) weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW - self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype) + self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype) bias_shape = (out_features, *(1,) * self.spatial_ndim) - self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype) + self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @property @abc.abstractmethod @@ -1564,10 +1564,10 @@ def __init__( self.bias_init = bias_init weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW - self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype) + self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype) bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim) - self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype) + self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @property @abc.abstractmethod @@ -2818,10 +2818,10 @@ def __init__( *out_size, ) - self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype) + self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype) bias_shape = (self.out_features, *out_size) - self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype) + self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") diff --git a/serket/nn/initialization.py b/serket/nn/initialization.py index fd7f03b..d86ed3a 100644 --- a/serket/nn/initialization.py +++ b/serket/nn/initialization.py @@ -13,8 +13,9 @@ # limitations under the License. from __future__ import annotations -from types import FunctionType -from typing import Any, Callable, Literal, Tuple, Union, get_args +import functools as ft +from collections.abc import Callable +from typing import Any, Literal, Tuple, Union, get_args import jax import jax.nn.initializers as ji @@ -63,19 +64,27 @@ init_map: dict[str, InitType] = dict(zip(get_args(InitLiteral), inits)) -def resolve_init_func(init_func: str | InitFuncType) -> Callable: - if isinstance(init_func, FunctionType): - return jtu.Partial(init_func) +@ft.singledispatch +def resolve_init(init): + raise TypeError(f"Unknown type {type(init)}") - if isinstance(init_func, str): - if init_func in init_map: - return jtu.Partial(init_map[init_func]) - raise ValueError(f"value must be one of ({', '.join(init_map.keys())})") - if init_func is None: - return jtu.Partial(lambda key, shape, dtype=None: None) +@resolve_init.register(str) +def _(init: str): + try: + return jtu.Partial(jax.tree_map(lambda x: x, init_map[init])) + except KeyError: + raise ValueError(f"Unknown {init=}, available init: {list(init_map)}") - raise ValueError("Value must be a string or a function.") + +@resolve_init.register(type(None)) +def _(init: None): + return jtu.Partial(lambda key, shape, dtype=None: None) + + +@resolve_init.register(Callable) +def _(init: Callable): + return jtu.Partial(init) def def_init_entry(key: str, init_func: InitFuncType) -> None: diff --git a/serket/nn/linear.py b/serket/nn/linear.py index 2e8bc1a..6654ec7 100644 --- a/serket/nn/linear.py +++ b/serket/nn/linear.py @@ -27,7 +27,7 @@ ActivationType, resolve_activation, ) -from serket.nn.initialization import DType, InitType, resolve_init_func +from serket.nn.initialization import DType, InitType, resolve_init from serket.nn.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb T = TypeVar("T") @@ -176,8 +176,8 @@ def __init__( k1, k2 = jr.split(key) weight_shape = (*in_features, out_features) - self.weight = resolve_init_func(weight_init)(k1, weight_shape, dtype) - self.bias = resolve_init_func(bias_init)(k2, (out_features,), dtype) + self.weight = resolve_init(weight_init)(k1, weight_shape, dtype) + self.bias = resolve_init(bias_init)(k2, (out_features,), dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, *x) -> jax.Array: @@ -332,8 +332,8 @@ def __init__( k1, k2 = jr.split(key) weight_shape = (*in_features, out_features) - self.weight = resolve_init_func(weight_init)(k1, weight_shape, dtype) - self.bias = resolve_init_func(bias_init)(k2, (self.out_features,), dtype) + self.weight = resolve_init(weight_init)(k1, weight_shape, dtype) + self.bias = resolve_init(bias_init)(k2, (self.out_features,), dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, x: jax.Array) -> jax.Array: diff --git a/serket/nn/normalization.py b/serket/nn/normalization.py index bad2a01..18be0a8 100644 --- a/serket/nn/normalization.py +++ b/serket/nn/normalization.py @@ -23,7 +23,7 @@ import serket as sk from serket.nn.custom_transform import tree_eval, tree_state -from serket.nn.initialization import DType, InitType, resolve_init_func +from serket.nn.initialization import DType, InitType, resolve_init from serket.nn.utils import ( Range, ScalarLike, @@ -153,8 +153,8 @@ def __init__( self.weight_init = weight_init self.bias_init = bias_init - self.gamma = resolve_init_func(weight_init)(key, self.normalized_shape, dtype) - self.beta = resolve_init_func(bias_init)(key, self.normalized_shape, dtype) + self.gamma = resolve_init(weight_init)(key, self.normalized_shape, dtype) + self.beta = resolve_init(bias_init)(key, self.normalized_shape, dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -246,8 +246,8 @@ def __init__( if in_features % groups != 0: raise ValueError(f"{in_features} must be divisible by {groups=}.") - self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype) - self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype) + self.weight = resolve_init(weight_init)(key, (in_features,), dtype) + self.bias = resolve_init(bias_init)(key, (in_features,), dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, x: jax.Array, **k) -> jax.Array: @@ -531,8 +531,8 @@ def __init__( self.bias_init = bias_init self.axis = axis - self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype=dtype) - self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype=dtype) + self.weight = resolve_init(weight_init)(key, (in_features,), dtype=dtype) + self.bias = resolve_init(bias_init)(key, (in_features,), dtype=dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__( @@ -620,8 +620,8 @@ def __init__( self.weight_init = weight_init self.bias_init = bias_init self.axis = axis - self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype) - self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype) + self.weight = resolve_init(weight_init)(key, (in_features,), dtype) + self.bias = resolve_init(bias_init)(key, (in_features,), dtype) @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__( diff --git a/tests/test_utils.py b/tests/test_utils.py index 93f6c5c..71d6481 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,34 +17,34 @@ import jax.tree_util as jtu import pytest -from serket.nn.initialization import resolve_init_func +from serket.nn.initialization import resolve_init from serket.nn.utils import canonicalize def test_canonicalize_init_func(): k = jr.PRNGKey(0) - assert resolve_init_func("he_normal")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("he_uniform")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("glorot_normal")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("glorot_uniform")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("lecun_normal")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("lecun_uniform")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("normal")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("uniform")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("ones")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("zeros")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("xavier_normal")(k, (2, 2)).shape == (2, 2) - assert resolve_init_func("xavier_uniform")(k, (2, 2)).shape == (2, 2) - - assert isinstance(resolve_init_func(jax.nn.initializers.he_normal()), jtu.Partial) - assert isinstance(resolve_init_func(None), jtu.Partial) + assert resolve_init("he_normal")(k, (2, 2)).shape == (2, 2) + assert resolve_init("he_uniform")(k, (2, 2)).shape == (2, 2) + assert resolve_init("glorot_normal")(k, (2, 2)).shape == (2, 2) + assert resolve_init("glorot_uniform")(k, (2, 2)).shape == (2, 2) + assert resolve_init("lecun_normal")(k, (2, 2)).shape == (2, 2) + assert resolve_init("lecun_uniform")(k, (2, 2)).shape == (2, 2) + assert resolve_init("normal")(k, (2, 2)).shape == (2, 2) + assert resolve_init("uniform")(k, (2, 2)).shape == (2, 2) + assert resolve_init("ones")(k, (2, 2)).shape == (2, 2) + assert resolve_init("zeros")(k, (2, 2)).shape == (2, 2) + assert resolve_init("xavier_normal")(k, (2, 2)).shape == (2, 2) + assert resolve_init("xavier_uniform")(k, (2, 2)).shape == (2, 2) + + assert isinstance(resolve_init(jax.nn.initializers.he_normal()), jtu.Partial) + assert isinstance(resolve_init(None), jtu.Partial) with pytest.raises(ValueError): - resolve_init_func("invalid") + resolve_init("invalid") - with pytest.raises(ValueError): - resolve_init_func(1) + with pytest.raises(TypeError): + resolve_init(1) def test_canonicalize():