Skip to content

Commit

Permalink
resolve_init_func -> resolve_init
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 24, 2023
1 parent 6a62c0b commit 9bf0548
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 54 deletions.
18 changes: 9 additions & 9 deletions serket/nn/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax.lax import ConvDimensionNumbers

import serket as sk
from serket.nn.initialization import DType, InitType, resolve_init_func
from serket.nn.initialization import DType, InitType, resolve_init
from serket.nn.utils import (
DilationType,
KernelSizeType,
Expand Down Expand Up @@ -200,10 +200,10 @@ def __init__(
raise ValueError(f"{(out_features % groups == 0)=}")

weight_shape = (out_features, in_features // groups, *self.kernel_size)
self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype)
self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype)

bias_shape = (out_features, *(1,) * self.spatial_ndim)
self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype)
self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -869,10 +869,10 @@ def __init__(

in_features = positive_int_cb(self.in_features)
weight_shape = (out_features, in_features // groups, *self.kernel_size) # OIHW
self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype)
self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype)

bias_shape = (out_features, *(1,) * self.spatial_ndim)
self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype)
self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -1564,10 +1564,10 @@ def __init__(
self.bias_init = bias_init

weight_shape = (depth_multiplier * in_features, 1, *self.kernel_size) # OIHW
self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype)
self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype)

bias_shape = (depth_multiplier * in_features, *(1,) * self.spatial_ndim)
self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype)
self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype)

@property
@abc.abstractmethod
Expand Down Expand Up @@ -2818,10 +2818,10 @@ def __init__(
*out_size,
)

self.weight = resolve_init_func(self.weight_init)(key, weight_shape, dtype)
self.weight = resolve_init(self.weight_init)(key, weight_shape, dtype)

bias_shape = (self.out_features, *out_size)
self.bias = resolve_init_func(self.bias_init)(key, bias_shape, dtype)
self.bias = resolve_init(self.bias_init)(key, bias_shape, dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
Expand Down
33 changes: 21 additions & 12 deletions serket/nn/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# limitations under the License.
from __future__ import annotations

from types import FunctionType
from typing import Any, Callable, Literal, Tuple, Union, get_args
import functools as ft
from collections.abc import Callable
from typing import Any, Literal, Tuple, Union, get_args

import jax
import jax.nn.initializers as ji
Expand Down Expand Up @@ -63,19 +64,27 @@
init_map: dict[str, InitType] = dict(zip(get_args(InitLiteral), inits))


def resolve_init_func(init_func: str | InitFuncType) -> Callable:
if isinstance(init_func, FunctionType):
return jtu.Partial(init_func)
@ft.singledispatch
def resolve_init(init):
raise TypeError(f"Unknown type {type(init)}")

if isinstance(init_func, str):
if init_func in init_map:
return jtu.Partial(init_map[init_func])
raise ValueError(f"value must be one of ({', '.join(init_map.keys())})")

if init_func is None:
return jtu.Partial(lambda key, shape, dtype=None: None)
@resolve_init.register(str)
def _(init: str):
try:
return jtu.Partial(jax.tree_map(lambda x: x, init_map[init]))
except KeyError:
raise ValueError(f"Unknown {init=}, available init: {list(init_map)}")

raise ValueError("Value must be a string or a function.")

@resolve_init.register(type(None))
def _(init: None):
return jtu.Partial(lambda key, shape, dtype=None: None)


@resolve_init.register(Callable)
def _(init: Callable):
return jtu.Partial(init)


def def_init_entry(key: str, init_func: InitFuncType) -> None:
Expand Down
10 changes: 5 additions & 5 deletions serket/nn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ActivationType,
resolve_activation,
)
from serket.nn.initialization import DType, InitType, resolve_init_func
from serket.nn.initialization import DType, InitType, resolve_init
from serket.nn.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb

T = TypeVar("T")
Expand Down Expand Up @@ -176,8 +176,8 @@ def __init__(

k1, k2 = jr.split(key)
weight_shape = (*in_features, out_features)
self.weight = resolve_init_func(weight_init)(k1, weight_shape, dtype)
self.bias = resolve_init_func(bias_init)(k2, (out_features,), dtype)
self.weight = resolve_init(weight_init)(k1, weight_shape, dtype)
self.bias = resolve_init(bias_init)(k2, (out_features,), dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(self, *x) -> jax.Array:
Expand Down Expand Up @@ -332,8 +332,8 @@ def __init__(

k1, k2 = jr.split(key)
weight_shape = (*in_features, out_features)
self.weight = resolve_init_func(weight_init)(k1, weight_shape, dtype)
self.bias = resolve_init_func(bias_init)(k2, (self.out_features,), dtype)
self.weight = resolve_init(weight_init)(k1, weight_shape, dtype)
self.bias = resolve_init(bias_init)(k2, (self.out_features,), dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(self, x: jax.Array) -> jax.Array:
Expand Down
18 changes: 9 additions & 9 deletions serket/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import serket as sk
from serket.nn.custom_transform import tree_eval, tree_state
from serket.nn.initialization import DType, InitType, resolve_init_func
from serket.nn.initialization import DType, InitType, resolve_init
from serket.nn.utils import (
Range,
ScalarLike,
Expand Down Expand Up @@ -153,8 +153,8 @@ def __init__(
self.weight_init = weight_init
self.bias_init = bias_init

self.gamma = resolve_init_func(weight_init)(key, self.normalized_shape, dtype)
self.beta = resolve_init_func(bias_init)(key, self.normalized_shape, dtype)
self.gamma = resolve_init(weight_init)(key, self.normalized_shape, dtype)
self.beta = resolve_init(bias_init)(key, self.normalized_shape, dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down Expand Up @@ -246,8 +246,8 @@ def __init__(
if in_features % groups != 0:
raise ValueError(f"{in_features} must be divisible by {groups=}.")

self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype)
self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype)
self.weight = resolve_init(weight_init)(key, (in_features,), dtype)
self.bias = resolve_init(bias_init)(key, (in_features,), dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(self, x: jax.Array, **k) -> jax.Array:
Expand Down Expand Up @@ -531,8 +531,8 @@ def __init__(
self.bias_init = bias_init
self.axis = axis

self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype=dtype)
self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype=dtype)
self.weight = resolve_init(weight_init)(key, (in_features,), dtype=dtype)
self.bias = resolve_init(bias_init)(key, (in_features,), dtype=dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(
Expand Down Expand Up @@ -620,8 +620,8 @@ def __init__(
self.weight_init = weight_init
self.bias_init = bias_init
self.axis = axis
self.weight = resolve_init_func(weight_init)(key, (in_features,), dtype)
self.bias = resolve_init_func(bias_init)(key, (in_features,), dtype)
self.weight = resolve_init(weight_init)(key, (in_features,), dtype)
self.bias = resolve_init(bias_init)(key, (in_features,), dtype)

@ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates)
def __call__(
Expand Down
38 changes: 19 additions & 19 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,34 @@
import jax.tree_util as jtu
import pytest

from serket.nn.initialization import resolve_init_func
from serket.nn.initialization import resolve_init
from serket.nn.utils import canonicalize


def test_canonicalize_init_func():
k = jr.PRNGKey(0)

assert resolve_init_func("he_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("he_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("glorot_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("glorot_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("lecun_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("lecun_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("ones")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("zeros")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("xavier_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init_func("xavier_uniform")(k, (2, 2)).shape == (2, 2)

assert isinstance(resolve_init_func(jax.nn.initializers.he_normal()), jtu.Partial)
assert isinstance(resolve_init_func(None), jtu.Partial)
assert resolve_init("he_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init("he_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init("glorot_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init("glorot_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init("lecun_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init("lecun_uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init("normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init("uniform")(k, (2, 2)).shape == (2, 2)
assert resolve_init("ones")(k, (2, 2)).shape == (2, 2)
assert resolve_init("zeros")(k, (2, 2)).shape == (2, 2)
assert resolve_init("xavier_normal")(k, (2, 2)).shape == (2, 2)
assert resolve_init("xavier_uniform")(k, (2, 2)).shape == (2, 2)

assert isinstance(resolve_init(jax.nn.initializers.he_normal()), jtu.Partial)
assert isinstance(resolve_init(None), jtu.Partial)

with pytest.raises(ValueError):
resolve_init_func("invalid")
resolve_init("invalid")

with pytest.raises(ValueError):
resolve_init_func(1)
with pytest.raises(TypeError):
resolve_init(1)


def test_canonicalize():
Expand Down

0 comments on commit 9bf0548

Please sign in to comment.