edit def_act_entry
ASEM000 committed Aug 24, 2023
1 parent be1e174 commit a7fbd3e
Showing 2 changed files with 55 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/layers_overview.ipynb
@@ -152,7 +152,7 @@
"linear = sk.nn.FNN([1, 1], act=MyTrainableActivation())\n",
"\n",
"# 4) activation function with a registered class\n",
"sk.def_act_entry(\"my_act\", MyTrainableActivation)\n",
"sk.def_act_entry(\"my_act\", MyTrainableActivation())\n",
"linear = sk.nn.FNN([1, 1], act=\"my_act\")"
]
},
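For reference, the notebook cell now registers an instance of the activation rather than the class itself, matching the updated ``def_act_entry`` below. A minimal sketch of the pattern, assuming a trainable activation defined as a ``sk.TreeClass`` (the class name and parameter are illustrative):

    import jax.numpy as jnp
    import serket as sk

    @sk.autoinit
    class MyTrainableActivation(sk.TreeClass):
        my_param: float = 10.0  # illustrative trainable leaf

        def __call__(self, x):
            return x * self.my_param

    # register an instance, not the class, under a string key
    sk.def_act_entry("my_act", MyTrainableActivation())

    # the key is then accepted wherever serket expects an activation
    x = jnp.ones((1, 1))
    linear = sk.nn.FNN([1, 1], act="my_act")
    output = linear(x)  # maps the single input feature to a single output
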
101 changes: 54 additions & 47 deletions serket/nn/activation.py
@@ -14,7 +14,8 @@

from __future__ import annotations

from typing import Callable, Literal, Protocol, Union, get_args
import functools as ft
from typing import Any, Callable, Literal, TypeVar, Union, get_args

import jax
import jax.numpy as jnp
@@ -23,6 +24,8 @@
import serket as sk
from serket.nn.utils import IsInstance, Range, ScalarLike

T = TypeVar("T")


@sk.autoinit
class AdaptiveLeakyReLU(sk.TreeClass):
@@ -107,7 +110,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
class GELU(sk.TreeClass):
"""Gaussian error linear unit"""

approximate: bool = sk.field(default=1.0, callbacks=[IsInstance(bool)])
approximate: bool = sk.field(default=False, callbacks=[IsInstance(bool)])

def __call__(self, x: jax.Array) -> jax.Array:
return jax.nn.gelu(x, approximate=self.approximate)
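
The ``approximate`` default previously was ``1.0``, a float that contradicts the ``IsInstance(bool)`` callback; it is now ``False``. For reference, a short sketch of the two modes the flag selects in ``jax.nn.gelu`` (the exact Gaussian-CDF form versus the cheaper tanh approximation):

    import jax
    import jax.numpy as jnp

    x = jnp.linspace(-3.0, 3.0, 7)
    exact = jax.nn.gelu(x, approximate=False)   # x * Phi(x), exact form
    approx = jax.nn.gelu(x, approximate=True)   # tanh-based approximation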
@@ -344,36 +347,36 @@ def __call__(self, x: jax.Array) -> jax.Array:


acts = [
AdaptiveLeakyReLU,
AdaptiveReLU,
AdaptiveSigmoid,
AdaptiveTanh,
CeLU,
ELU,
GELU,
GLU,
HardShrink,
HardSigmoid,
HardSwish,
HardTanh,
LeakyReLU,
LogSigmoid,
LogSoftmax,
Mish,
PReLU,
ReLU,
ReLU6,
SeLU,
Sigmoid,
Snake,
SoftPlus,
SoftShrink,
SoftSign,
SquarePlus,
Swish,
Tanh,
TanhShrink,
ThresholdedReLU,
AdaptiveLeakyReLU(),
AdaptiveReLU(),
AdaptiveSigmoid(),
AdaptiveTanh(),
CeLU(),
ELU(),
GELU(),
GLU(),
HardShrink(),
HardSigmoid(),
HardSwish(),
HardTanh(),
LeakyReLU(),
LogSigmoid(),
LogSoftmax(),
Mish(),
PReLU(),
ReLU(),
ReLU6(),
SeLU(),
Sigmoid(),
Snake(),
SoftPlus(),
SoftShrink(),
SoftSign(),
SquarePlus(),
Swish(),
Tanh(),
TanhShrink(),
ThresholdedReLU(),
]
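
With the registry holding ready-made instances, a string key resolves directly to a configured object. The actual string-to-entry mapping (``act_map``) is defined elsewhere in the module; the snippet below is a hypothetical reconstruction keyed by lower-cased class name, purely to illustrate what the entries now look like:

    # hypothetical: serket's real act_map keys may differ from these
    act_map = {type(act).__name__.lower(): act for act in acts}

    act_map["relu"]       # a ReLU() instance, ready to call
    act_map["hardswish"]  # a HardSwish() instance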


@@ -383,28 +386,34 @@ def __call__(self, x: jax.Array) -> jax.Array:
ActivationType = Union[ActivationLiteral, ActivationFunctionType]


class ActivationClassType(Protocol):
class ActivationClassType(sk.TreeClass):
def __call__(self, x: jax.typing.ArrayLike) -> jax.Array:
...


def resolve_activation(act: ActivationType) -> ActivationFunctionType:
# in case the user passes a trainable activation function
# we need to make a copy of it to avoid unpredictable side effects
if isinstance(act, str):
if act in act_map:
return act_map[act]()
raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}")
@ft.singledispatch
def resolve_activation(act: T) -> T:
return act


def def_act_entry(key: str, act: ActivationClassType) -> None:
@resolve_activation.register(str)
def _(act: str) -> sk.TreeClass:
try:
return jax.tree_map(lambda x: x, act_map[act])
except KeyError:
raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}")


def def_act_entry(
key: str,
act: Callable[[jax.typing.ArrayLike], jax.Array] | Any,
) -> None:
"""Register a custom activation function key for use in ``serket`` layers.
Args:
key: The key to register the function under.
act: a class with a ``__call__`` method that takes a single argument
and returns a ``jax`` array.
act: a callable object that takes a single argument and returns a ``jax``
array.
Note:
The registered key can be used in any of ``serket`` ``act_*`` arguments as
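
The isinstance-based branching in ``resolve_activation`` is replaced with ``functools.singledispatch``: the generic overload passes non-string activations through untouched, while a ``str`` overload looks the key up in the registry and returns a copy. A self-contained sketch of the same dispatch pattern, using a stand-in registry rather than serket's real ``act_map``:

    import functools as ft

    registry = {"relu": "ReLU-instance"}  # stand-in for act_map

    @ft.singledispatch
    def resolve(act):
        # fallback: callables, instances, anything non-string pass through
        return act

    @resolve.register(str)
    def _(act: str):
        try:
            return registry[act]
        except KeyError:
            raise ValueError(f"Unknown {act=}, available: {list(registry)}")

    resolve("relu")        # looked up in the registry
    resolve(lambda x: x)   # returned unchanged by the generic overload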
@@ -423,17 +432,15 @@ def def_act_entry(key: str, act: ActivationClassType) -> None:
... my_param: float = 10.0
... def __call__(self, x):
... return x * self.my_param
>>> sk.def_act_entry("my_act", MyTrainableActivation)
>>> sk.def_act_entry("my_act", MyTrainableActivation())
>>> x = jnp.ones((1, 1))
>>> sk.nn.FNN([1, 1, 1], act="my_act", weight_init="ones", bias_init=None)(x)
Array([[10.]], dtype=float32)
"""
if key in act_map:
raise ValueError(f"`init_key` {key=} already registered")

if not isinstance(act, type):
raise ValueError(f"Expected a class, got {act=}")
if not callable(act):
raise ValueError(f"Expected a class with a `__call__` method, got {act=}")
raise TypeError(f"{act=} must be a callable object")

act_map[key] = act
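
Because the registry entries are instances, the ``str`` overload above returns ``jax.tree_map(lambda x: x, act_map[act])`` instead of instantiating a class. A small sketch of what that identity ``tree_map`` buys, assuming the entry is a pytree as ``sk.TreeClass`` instances are: the leaves are kept as-is, but the container is rebuilt, so each caller gets a fresh object rather than the single shared registry entry (the class below is illustrative):

    import jax
    import serket as sk

    @sk.autoinit
    class Scale(sk.TreeClass):
        factor: float = 2.0  # illustrative trainable leaf

        def __call__(self, x):
            return x * self.factor

    entry = Scale()                          # what would sit in act_map
    fresh = jax.tree_map(lambda x: x, entry)

    assert fresh is not entry                # new container, same leaf values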
