
act_func -> act
ASEM000 committed Aug 14, 2023
1 parent 1492ee5 commit 6770dbd
Showing 8 changed files with 124 additions and 123 deletions.
1 change: 1 addition & 0 deletions CHANEGLOG.md
@@ -12,6 +12,7 @@
- `***_init_func` -> `***_init` shorter and more concise
- `gamma_init_func` -> `weight_init`
- `beta_init_func` -> `bias_init`
- `act_func` -> `act`
- `MLP` produces smaller `jaxprs` and is therefore faster to compile, e.g. for my use case of higher-order differentiation through a `PINN`.
- `kernel_dilation` -> `dilation`
- `input_dilation` -> Removed.
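A minimal before/after sketch of this rename (illustrative only; the layer sizes are arbitrary and `"relu"` is one of the registered activation keys):

import serket as sk

# before this commit
net = sk.nn.FNN([1, 32, 1], act_func="relu")

# after this commit
net = sk.nn.FNN([1, 32, 1], act="relu")
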
8 changes: 4 additions & 4 deletions docs/notebooks/layers_overview.ipynb
@@ -134,10 +134,10 @@
"import jax\n",
"\n",
"# 1) activation function with a string\n",
"linear = sk.nn.FNN([1, 1], act_func=\"relu\")\n",
"linear = sk.nn.FNN([1, 1], act=\"relu\")\n",
"\n",
"# 2) activation function with a function\n",
"linear = sk.nn.FNN([1, 1], act_func=jax.nn.relu)\n",
"linear = sk.nn.FNN([1, 1], act=jax.nn.relu)\n",
"\n",
"\n",
"@sk.autoinit\n",
@@ -149,11 +149,11 @@
"\n",
"\n",
"# 3) activation function with a class\n",
"linear = sk.nn.FNN([1, 1], act_func=MyTrainableActivation())\n",
"linear = sk.nn.FNN([1, 1], act=MyTrainableActivation())\n",
"\n",
"# 4) activation function with a registered class\n",
"sk.def_act_entry(\"my_act\", MyTrainableActivation)\n",
"linear = sk.nn.FNN([1, 1], act_func=\"my_act\")"
"linear = sk.nn.FNN([1, 1], act=\"my_act\")"
]
},
{
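For readability, here is the notebook cell above as one self-contained sketch using the new keyword. The body of `MyTrainableActivation` is elided in this diff, so the version below is reconstructed from the `def_act_entry` docstring example further down and should be treated as an assumption:

import jax
import serket as sk

# 1) activation function with a string
linear = sk.nn.FNN([1, 1], act="relu")

# 2) activation function with a function
linear = sk.nn.FNN([1, 1], act=jax.nn.relu)

# assumed definition, reconstructed from the docstring example in activation.py
@sk.autoinit
class MyTrainableActivation(sk.TreeClass):
    my_param: float = 10.0

    def __call__(self, x):
        return x * self.my_param

# 3) activation function with a callable instance
linear = sk.nn.FNN([1, 1], act=MyTrainableActivation())

# 4) activation function with a registered class
sk.def_act_entry("my_act", MyTrainableActivation)
linear = sk.nn.FNN([1, 1], act="my_act")
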
30 changes: 15 additions & 15 deletions serket/nn/activation.py
@@ -388,22 +388,22 @@ def __call__(self, x: jax.typing.ArrayLike) -> jax.Array:
...


def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
def resolve_activation(act: ActivationType) -> ActivationFunctionType:
# in case the user passes a trainable activation function
# we need to make a copy of it to avoid unpredictable side effects
if isinstance(act_func, str):
if act_func in act_map:
return act_map[act_func]()
raise ValueError(f"Unknown {act_func=}, available activations: {list(act_map)}")
return act_func
if isinstance(act, str):
if act in act_map:
return act_map[act]()
raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}")
return act


def def_act_entry(key: str, act_func: ActivationClassType) -> None:
def def_act_entry(key: str, act: ActivationClassType) -> None:
"""Register a custom activation function key for use in ``serket`` layers.
Args:
key: The key to register the function under.
act_func: a class with a ``__call__`` method that takes a single argument
act: a class with a ``__call__`` method that takes a single argument
and returns a ``jax`` array.
Note:
@@ -412,7 +412,7 @@ def def_act_entry(key: str, act_func: ActivationClassType) -> None:
Note:
By design, activation functions can be passed directly to ``serket`` layers
with the ``act_func`` argument. This function is useful if you want to
with the ``act`` argument. This function is useful if you want to
represent activation functions as a string in a configuration file.
Example:
@@ -425,15 +425,15 @@ def def_act_entry(key: str, act_func: ActivationClassType) -> None:
... return x * self.my_param
>>> sk.def_act_entry("my_act", MyTrainableActivation)
>>> x = jnp.ones((1, 1))
>>> sk.nn.FNN([1, 1, 1], act_func="my_act", weight_init="ones", bias_init=None)(x)
>>> sk.nn.FNN([1, 1, 1], act="my_act", weight_init="ones", bias_init=None)(x)
Array([[10.]], dtype=float32)
"""
if key in act_map:
raise ValueError(f"`init_key` {key=} already registered")

if not isinstance(act_func, type):
raise ValueError(f"Expected a class, got {act_func=}")
if not callable(act_func):
raise ValueError(f"Expected a class with a `__call__` method, got {act_func=}")
if not isinstance(act, type):
raise ValueError(f"Expected a class, got {act=}")
if not callable(act):
raise ValueError(f"Expected a class with a `__call__` method, got {act=}")

act_map[key] = act_func
act_map[key] = act
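A small usage sketch of `resolve_activation` as implemented above (illustrative; the function is internal, and the import path below assumes the module path matches the file path `serket/nn/activation.py`):

import jax
from serket.nn.activation import resolve_activation

act = resolve_activation("relu")        # string key: looked up in act_map and instantiated
act = resolve_activation(jax.nn.relu)   # callable: returned unchanged

try:
    resolve_activation("not_a_real_key")
except ValueError as err:
    # unknown keys raise and list the available activation names
    print(err)
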
52 changes: 26 additions & 26 deletions serket/nn/linear.py
@@ -400,7 +400,7 @@ class FNN(sk.TreeClass):
Args:
layers: Sequence of layer sizes
act_func: a single Activation function to be applied between layers or
act: a single Activation function to be applied between layers or
``len(layers)-2`` Sequence of activation functions applied between
layers.
weight_init: Weight initializer function.
Expand Down Expand Up @@ -445,7 +445,7 @@ def __init__(
self,
layers: Sequence[int],
*,
act_func: ActivationType | tuple[ActivationType, ...] = "tanh",
act: ActivationType | tuple[ActivationType, ...] = "tanh",
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
@@ -454,13 +454,13 @@ def __init__(
keys = jr.split(key, len(layers) - 1)
num_hidden_layers = len(layers) - 2

if isinstance(act_func, tuple):
if len(act_func) != (num_hidden_layers):
raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")
if isinstance(act, tuple):
if len(act) != (num_hidden_layers):
raise ValueError(f"{len(act)=} != {(num_hidden_layers)=}")

self.act_func = tuple(resolve_activation(act) for act in act_func)
self.act = tuple(resolve_activation(act) for act in act)
else:
self.act_func = resolve_activation(act_func)
self.act = resolve_activation(act)

self.layers = tuple(
Linear(
@@ -477,32 +477,32 @@ def __init__(
def __call__(self, x: jax.Array, **k) -> jax.Array:
*layers, last = self.layers

if isinstance(self.act_func, tuple):
for ai, li in zip(self.act_func, layers):
if isinstance(self.act, tuple):
for ai, li in zip(self.act, layers):
x = ai(li(x))
else:
for li in layers:
x = self.act_func(li(x))
x = self.act(li(x))

return last(x)


def _scan_batched_layer_with_single_activation(
x: Batched[jax.Array],
layer: Batched[Linear],
act_func: ActivationFunctionType,
act: ActivationFunctionType,
) -> jax.Array:
if layer.bias is None:

def scan_func(x: jax.Array, bias: Batched[jax.Array]):
return act_func(x + bias), None
return act(x + bias), None

x, _ = jax.lax.scan(scan_func, x, layer.weight)
return x

def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
return act_func(x @ weight + bias), None
return act(x @ weight + bias), None

weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
x, _ = jax.lax.scan(scan_func, x, weight_bias)
@@ -512,13 +512,13 @@ def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
def _scan_batched_layer_with_multiple_activations(
x: Batched[jax.Array],
layer: Batched[Linear],
act_func: Sequence[ActivationFunctionType],
act: Sequence[ActivationFunctionType],
) -> jax.Array:
if layer.bias is None:

def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]):
x, index = x_index
x = jax.lax.switch(index, act_func, x @ weight)
x = jax.lax.switch(index, act, x @ weight)
return (x, index + 1), None

(x, _), _ = jax.lax.scan(scan_func, (x, 0), layer.weight)
@@ -527,7 +527,7 @@ def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]):
def scan_func(x_index: jax.Array, weight_bias: Batched[jax.Array]):
x, index = x_index
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
x = jax.lax.switch(index, act_func, x @ weight + bias)
x = jax.lax.switch(index, act, x @ weight + bias)
return [x, index + 1], None

weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
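For context on the changelog note about `MLP` producing smaller `jaxprs`: the hidden layers are stacked along a leading axis and applied with `jax.lax.scan`, so the traced program contains one copy of the layer body instead of `num_hidden_layers` copies. A standalone sketch of that pattern (illustrative, not serket's exact code):

import jax
import jax.numpy as jnp

def apply_stacked(x, weights, biases, act=jax.nn.tanh):
    # weights: (num_layers, width, width), biases: (num_layers, width)
    def body(carry, layer_params):
        w, b = layer_params
        return act(carry @ w + b), None

    out, _ = jax.lax.scan(body, x, (weights, biases))
    return out

x = jnp.ones((4,))
weights = jnp.stack([jnp.eye(4)] * 3)   # three identical hidden layers
biases = jnp.zeros((3, 4))
y = apply_stacked(x, weights, biases)   # one scanned layer body in the jaxpr
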
@@ -543,7 +543,7 @@ class MLP(sk.TreeClass):
out_features: Number of output features.
hidden_size: Number of hidden units in each hidden layer.
num_hidden_layers: Number of hidden layers including the output layer.
act_func: Activation function.
act: Activation function.
weight_init: Weight initialization function.
bias_init: Bias initialization function.
key: Random number generator key.
@@ -610,7 +610,7 @@ def __init__(
*,
hidden_size: int,
num_hidden_layers: int,
act_func: ActivationType | tuple[ActivationType, ...] = "tanh",
act: ActivationType | tuple[ActivationType, ...] = "tanh",
weight_init: InitType = "glorot_uniform",
bias_init: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
@@ -621,12 +621,12 @@ def __init__(

keys = jr.split(key, num_hidden_layers + 1)

if isinstance(act_func, tuple):
if len(act_func) != (num_hidden_layers):
raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}")
self.act_func = tuple(resolve_activation(act) for act in act_func)
if isinstance(act, tuple):
if len(act) != (num_hidden_layers):
raise ValueError(f"{len(act)=} != {(num_hidden_layers)=}")
self.act = tuple(resolve_activation(act) for act in act)
else:
self.act_func = resolve_activation(act_func)
self.act = resolve_activation(act)

kwargs = dict(weight_init=weight_init, bias_init=bias_init, dtype=dtype)

@@ -643,13 +643,13 @@ def batched_linear(key: jr.KeyArray) -> Batched[Linear]:
def __call__(self, x: jax.Array, **k) -> jax.Array:
l0, lm, lh = self.layers

if isinstance(self.act_func, tuple):
a0, *ah = self.act_func
if isinstance(self.act, tuple):
a0, *ah = self.act
x = a0(l0(x))
x = _scan_batched_layer_with_multiple_activations(x, lm, ah)
return lh(x)

a0 = self.act_func
a0 = self.act
x = a0(l0(x))
x = _scan_batched_layer_with_single_activation(x, lm, a0)
return lh(x)
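A usage sketch of the renamed argument on both classes, including the tuple form validated in `__init__` above. The positional `in_features`/`out_features` parameters of `MLP` are elided in this diff, so their order here is an assumption:

import jax
import jax.numpy as jnp
import serket as sk

x = jnp.ones((3,))

# a single activation shared by all hidden layers
fnn = sk.nn.FNN([3, 8, 8, 1], act="tanh")

# one activation per hidden layer: len(layers) - 2 == 2 entries here
fnn = sk.nn.FNN([3, 8, 8, 1], act=("relu", jax.nn.tanh))
y = fnn(x)

# MLP takes the same keyword; the tuple must have num_hidden_layers entries
mlp = sk.nn.MLP(3, 1, hidden_size=8, num_hidden_layers=2, act=("relu", "tanh"))
y = mlp(x)
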
