From 4ab21a8dac91106240d3e7c76c33092928b280f9 Mon Sep 17 00:00:00 2001
From: ASEM000
Date: Tue, 1 Aug 2023 15:38:56 +0900
Subject: [PATCH] add `def_init_entry`, `def_act_entry`

---
 docs/API/api.rst            |  2 ++
 serket/__init__.py          |  4 +++
 serket/nn/activation.py     | 49 ++++++++++++++++++++++++++++++++++++-
 serket/nn/initialization.py | 48 ++++++++++++++++++++++++++++++++++++
 4 files changed, 102 insertions(+), 1 deletion(-)

diff --git a/docs/API/api.rst b/docs/API/api.rst
index 4021425..b86b273 100644
--- a/docs/API/api.rst
+++ b/docs/API/api.rst
@@ -5,6 +5,8 @@
 
 .. autofunction:: tree_state
 .. autofunction:: tree_eval
+.. autofunction:: def_init_entry
+.. autofunction:: def_act_entry
 
 .. toctree::
     :maxdepth: 2
diff --git a/serket/__init__.py b/serket/__init__.py
index 8c97542..06a382e 100644
--- a/serket/__init__.py
+++ b/serket/__init__.py
@@ -43,7 +43,9 @@
 )
 
 from . import nn
+from .nn.activation import def_act_entry
 from .nn.custom_transform import tree_eval, tree_state
+from .nn.initialization import def_init_entry
 
 __all__ = (
     # general utils
@@ -81,6 +83,8 @@
     "nn",
     "tree_eval",
     "tree_state",
+    "def_init_entry",
+    "def_act_entry",
 )
 
 
diff --git a/serket/nn/activation.py b/serket/nn/activation.py
index c75e7bd..624af35 100644
--- a/serket/nn/activation.py
+++ b/serket/nn/activation.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import Callable, Literal, Union, get_args
+from typing import Callable, Literal, Protocol, Union, get_args
 
 import jax
 import jax.numpy as jnp
@@ -383,6 +383,11 @@ def __call__(self, x: jax.Array, **k) -> jax.Array:
 ActivationType = Union[ActivationLiteral, ActivationFunctionType]
 
 
+class ActivationClassType(Protocol):
+    def __call__(self, x: jax.typing.ArrayLike) -> jax.Array:
+        ...
+
+
 def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
     # in case the user passes a trainable activation function
     # we need to make a copy of it to avoid unpredictable side effects
@@ -391,3 +396,45 @@ def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
             return act_map[act_func]()
         raise ValueError(f"Unknown {act_func=}, available activations: {list(act_map)}")
     return act_func
+
+
+def def_act_entry(key: str, act_func: ActivationClassType) -> None:
+    """Register a custom activation function key for use in ``serket`` layers.
+
+    Args:
+        key: The key to register the function under.
+        act_func: a class with a ``__call__`` method that takes a single argument
+            and returns a ``jax`` array.
+
+    Note:
+        The registered key can be used in any of ``serket`` ``act_*`` arguments as
+        substitution for the function.
+
+    Note:
+        By design, activation functions can be passed directly to ``serket`` layers
+        with the ``act_func`` argument. This function is useful if you want to
+        represent activation functions as a string in a configuration file.
+
+    Example:
+        >>> import serket as sk
+        >>> import jax.numpy as jnp
+        >>> @sk.autoinit
+        ... class MyTrainableActivation(sk.TreeClass):
+        ...     my_param: float = 10.0
+        ...     def __call__(self, x):
+        ...         return x * self.my_param
+        >>> sk.def_act_entry("my_act", MyTrainableActivation)
+        >>> x = jnp.ones((1, 1))
+        >>> sk.nn.FNN([1, 1, 1], act_func="my_act", weight_init="ones", bias_init=None)(x)
+        Array([[10.]], dtype=float32)
+    """
+    if key in act_map:
+        raise ValueError(f"`act_key` {key=} already registered")
+
+    if not isinstance(act_func, type):
+        raise ValueError(f"Expected a class, got {act_func=}")
+    # `callable(cls)` is always True for a class, so look for `__call__` in the MRO
+    if not any("__call__" in vars(base) for base in act_func.__mro__):
+        raise ValueError(f"Expected a class with a __call__ method, got {act_func=}")
+
+    act_map[key] = act_func
diff --git a/serket/nn/initialization.py b/serket/nn/initialization.py
index 91b9541..51da41f 100644
--- a/serket/nn/initialization.py
+++ b/serket/nn/initialization.py
@@ -75,3 +75,51 @@ def resolve_init_func(init_func: str | InitFuncType) -> Callable:
         return jtu.Partial(lambda key, shape, dtype=None: None)
 
     raise ValueError("Value must be a string or a function.")
+
+
+def def_init_entry(key: str, init_func: InitFuncType) -> None:
+    """Register a custom initialization function key for use in ``serket`` layers.
+
+    Args:
+        key: The key to register the function under.
+        init_func: The function to register. Must take three arguments: a key,
+            a shape, and a dtype, and return an array of the given shape and dtype.
+            dtype must have a default value.
+
+    Note:
+        By design, initialization functions can be passed directly to ``serket`` layers
+        without registration. This function is useful if you want to
+        represent initialization functions as a string in a configuration file.
+
+    Example:
+        >>> import jax.numpy as jnp
+        >>> import serket as sk
+        >>> import math
+        >>> def my_init_func(key, shape, dtype=jnp.float32):
+        ...     return jnp.arange(math.prod(shape), dtype=dtype).reshape(shape)
+        >>> sk.def_init_entry("my_init", my_init_func)
+        >>> sk.nn.Linear(1, 5, weight_init="my_init").weight
+        Array([[0., 1., 2., 3., 4.]], dtype=float32)
+    """
+    import inspect
+
+    signature = inspect.signature(init_func)
+
+    if key in init_map:
+        raise ValueError(f"`init_key` {key=} already registered")
+
+    if len(signature.parameters) != 3:
+        # verify it's a three-argument function
+        raise ValueError(f"`init_func` {len(signature.parameters)=} != 3")
+
+    argnames = list(signature.parameters)
+
+    if argnames != ["key", "shape", "dtype"]:
+        # verify the names of the parameters
+        raise ValueError(f"`init_func` {argnames=} != ['key', 'shape', 'dtype']")
+
+    if signature.parameters["dtype"].default is inspect.Parameter.empty:
+        # `Parameter.empty` is the public sentinel meaning "no default declared"
+        raise ValueError("`init_func` `dtype` must have a default value")
+
+    init_map[key] = init_func