Skip to content

Commit

Permalink
add def_init_entry, def_act_entry
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 1, 2023
1 parent 20723a5 commit 4ab21a8
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/API/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

.. autofunction:: tree_state
.. autofunction:: tree_eval
.. autofunction:: def_init_entry
.. autofunction:: def_act_entry

.. toctree::
:maxdepth: 2
Expand Down
4 changes: 4 additions & 0 deletions serket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
)

from . import nn
from .nn.activation import def_act_entry
from .nn.custom_transform import tree_eval, tree_state
from .nn.initialization import def_init_entry

__all__ = (
# general utils
Expand Down Expand Up @@ -81,6 +83,8 @@
"nn",
"tree_eval",
"tree_state",
"def_init_entry",
"def_act_entry",
)


Expand Down
49 changes: 48 additions & 1 deletion serket/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from typing import Callable, Literal, Union, get_args
from typing import Callable, Literal, Protocol, Union, get_args

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -383,6 +383,11 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
ActivationType = Union[ActivationLiteral, ActivationFunctionType]


class ActivationClassType(Protocol):
def __call__(self, x: jax.typing.ArrayLike) -> jax.Array:
...


def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
# in case the user passes a trainable activation function
# we need to make a copy of it to avoid unpredictable side effects
Expand All @@ -391,3 +396,45 @@ def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
return act_map[act_func]()
raise ValueError(f"Unknown {act_func=}, available activations: {list(act_map)}")
return act_func


def def_act_entry(key: str, act_func: ActivationClassType) -> None:
"""Register a custom activation function key for use in ``serket`` layers.
Args:
key: The key to register the function under.
act_func: a class with a ``__call__`` method that takes a single argument
and returns a ``jax`` array.
Note:
The registered key can be used in any of ``serket`` ``act_*`` arguments as
substitution for the function.
Note:
By design, activation functions can be passed directly to ``serket`` layers
with the ``act_func`` argument. This function is useful if you want to
represent activation functions as a string in a configuration file.
Example:
>>> import serket as sk
>>> import math
>>> import jax.numpy as jnp
>>> @sk.autoinit
... class MyTrainableActivation(sk.TreeClass):
... my_param: float = 10.0
... def __call__(self, x):
... return x * self.my_param
>>> sk.def_act_entry("my_act", MyTrainableActivation)
>>> x = jnp.ones((1, 1))
>>> sk.nn.FNN([1, 1, 1], act_func="my_act", weight_init="ones", bias_init=None)(x)
Array([[10.]], dtype=float32)
"""
if key in act_map:
raise ValueError(f"`init_key` {key=} already registered")

if not isinstance(act_func, type):
raise ValueError(f"Expected a class, got {act_func=}")
if not callable(act_func):
raise ValueError(f"Expected a class with a __call__ method, got {act_func=}")

act_map[key] = act_func
48 changes: 48 additions & 0 deletions serket/nn/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,51 @@ def resolve_init_func(init_func: str | InitFuncType) -> Callable:
return jtu.Partial(lambda key, shape, dtype=None: None)

raise ValueError("Value must be a string or a function.")


def def_init_entry(key: str, init_func: InitFuncType) -> None:
"""Register a custom initialization function key for use in ``serket`` layers.
Args:
key: The key to register the function under.
init_func: The function to register. must take three arguments: a key,
a shape, and a dtype, and return an array of the given shape and dtype.
dtype must have a default value.
Note:
By design initialization function can be passed directly to ``serket`` layers
without registration. This function is useful if you want to
represent initialization functions as a string in a configuration file.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import serket as sk
>>> import math
>>> def my_init_func(key, shape, dtype=jnp.float32):
... return jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
>>> sk.def_init_entry("my_init", my_init_func)
>>> sk.nn.Linear(1, 5, weight_init="my_init").weight
Array([[0., 1., 2., 3., 4.]], dtype=float32)
"""
import inspect

signature = inspect.signature(init_func)

if key in init_map:
raise ValueError(f"`init_key` {key=} already registered")

if len(signature.parameters) != 3:
# verify its a three argument function
raise ValueError(f"`init_func` {len(signature.parameters)=} != 3")

argnames = list(dict(signature.parameters))

if argnames != ["key", "shape", "dtype"]:
# verify the names of the parameters
raise ValueError(f"`init_func` {argnames=} != ['key', 'shape', 'dtype']")

if signature.parameters["dtype"].default is inspect._empty:
raise ValueError("`init_func` `dtype` must have a default value")

init_map[key] = init_func

0 comments on commit 4ab21a8

Please sign in to comment.