Skip to content

Commit

Permalink
avoid replicating activation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Jun 8, 2023
1 parent e387c76 commit 1e41265
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
3 changes: 1 addition & 2 deletions serket/nn/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

import copy
from typing import Callable, Literal, Union, get_args

import jax
Expand Down Expand Up @@ -419,4 +418,4 @@ def resolve_activation(act_func: ActivationType) -> ActivationFunctionType:
f"Unknown activation function {act_func=}, "
f"available activations are {list(act_map.keys())}"
)
return copy.copy(act_func)
return act_func
65 changes: 56 additions & 9 deletions serket/nn/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
self,
layers: Sequence[int],
*,
act_func: ActivationType = "tanh",
act_func: ActivationType | tuple[ActivationType, ...] = "tanh",
weight_init_func: InitType = "glorot_uniform",
bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
Expand Down Expand Up @@ -62,7 +62,17 @@ def __init__(

keys = jr.split(key, len(layers) - 1)

self.act_funcs = tuple(resolve_activation(act_func) for _ in keys[1:])
num_hidden_layers = len(layers) - 2

if isinstance(act_func, tuple):
if len(act_func) != (num_hidden_layers + 1):
raise ValueError(
"tuple of activation functions must have "
f"length: {(num_hidden_layers+1)=}, "
)
self.act_func = tuple(resolve_activation(act) for act in act_func)
else:
self.act_func = resolve_activation(act_func)

self.layers = tuple(
Linear(
Expand All @@ -75,10 +85,22 @@ def __init__(
for (ki, di, do) in (zip(keys, layers[:-1], layers[1:]))
)

def _multi_call(self, x: jax.Array, **k) -> jax.Array:
    """Forward pass using one distinct activation per hidden layer.

    Each hidden ``Linear`` is followed by its paired activation from the
    ``self.act_func`` tuple; the final layer is applied without activation.
    """
    hidden, final = self.layers[:-1], self.layers[-1]
    for act, linear in zip(self.act_func, hidden):
        x = act(linear(x))
    return final(x)

def _single_call(self, x: jax.Array, **k) -> jax.Array:
    """Forward pass using a single shared activation for every hidden layer.

    The same ``self.act_func`` callable follows each hidden ``Linear``;
    the final layer is applied without activation.
    """
    act = self.act_func
    for linear in self.layers[:-1]:
        x = act(linear(x))
    return self.layers[-1](x)

def __call__(self, x: jax.Array, **k) -> jax.Array:
    """Forward pass.

    Dispatches on how activations were configured: a tuple means one
    activation per hidden layer (``_multi_call``); a single callable is
    shared across all hidden layers (``_single_call``).
    """
    # NOTE: the pre-commit loop over `self.act_funcs` (an attribute removed
    # by this change) returned unconditionally and made the dispatch below
    # unreachable — it is diff residue and has been dropped.
    if isinstance(self.act_func, tuple):
        return self._multi_call(x, **k)
    return self._single_call(x, **k)


class MLP(pytc.TreeClass):
Expand All @@ -89,7 +111,7 @@ def __init__(
*,
hidden_size: int,
num_hidden_layers: int,
act_func: ActivationType = "tanh",
act_func: ActivationType | tuple[ActivationType, ...] = "tanh",
weight_init_func: InitType = "glorot_uniform",
bias_init_func: InitType = "zeros",
key: jr.KeyArray = jr.PRNGKey(0),
Expand Down Expand Up @@ -118,7 +140,16 @@ def __init__(
raise ValueError(f"hidden_size must be positive, got {hidden_size}")

keys = jr.split(key, num_hidden_layers + 1)
self.act_funcs = tuple(resolve_activation(act_func) for _ in keys[1:])

if isinstance(act_func, tuple):
if len(act_func) != (num_hidden_layers + 1):
raise ValueError(
"tuple of activation functions must have "
f"length: {(num_hidden_layers+1)=}, "
)
self.act_func = tuple(resolve_activation(act) for act in act_func)
else:
self.act_func = resolve_activation(act_func)

self.layers = tuple(
[
Expand Down Expand Up @@ -151,13 +182,29 @@ def __init__(
]
)

def __call__(self, x: jax.Array, **k) -> jax.Array:
def _single_call(self, x: jax.Array, **k) -> jax.Array:
    """MLP forward pass with one shared activation.

    Applies the first layer eagerly, then folds the remaining hidden
    layers through ``jax.lax.scan`` by rotating the layer list inside the
    carry, and finishes with the (activation-free) final layer.
    """
    def scan_func(carry, _):
        # Pop the front layer, apply it + activation, and rotate it to the
        # back so each scan step consumes the next layer in order.
        # NOTE(review): scan requires every hidden layer pytree to have the
        # same structure/shapes (hidden_size x hidden_size) — presumably
        # guaranteed by __init__; confirm against the constructor.
        x, (l0, *lh) = carry
        return [self.act_func(l0(x)), [*lh, l0]], None

    (l0, *lh, lf) = self.layers
    x = self.act_func(l0(x))
    # Only scan when there are middle layers; walrus keeps `length` static.
    if length := len(lh):
        (x, _), _ = jax.lax.scan(scan_func, [x, lh], None, length=length)
    return lf(x)

def _multi_call(self, x: jax.Array, **k) -> jax.Array:
    """MLP forward pass with one distinct activation per hidden layer.

    Mirrors ``_single_call`` but rotates the activation tuple through the
    scan carry alongside the layers, so layer ``i`` is followed by
    activation ``i``.
    """
    def scan_func(carry, _):
        # Pop the front layer/activation pair, apply them, rotate both to
        # the back so the next step consumes the next pair in order.
        x, (l0, *lh), (a0, *ah) = carry
        return [a0(l0(x)), [*lh, l0], [*ah, a0]], None

    # Fixed diff residue: the stale pre-commit unpacking from
    # `self.act_funcs` (attribute removed by this commit) duplicated this
    # line and would raise AttributeError at runtime.
    (l0, *lh, lf), (a0, *ah) = self.layers, self.act_func
    x = a0(l0(x))
    # Only scan when there are middle layers; walrus keeps `length` static.
    if length := len(lh):
        (x, _, _), _ = jax.lax.scan(scan_func, [x, lh, ah], None, length=length)
    return lf(x)

def __call__(self, x: jax.Array, **k) -> jax.Array:
    """Forward pass.

    Uses the per-layer path when ``self.act_func`` is a tuple of
    activations, otherwise the shared-activation path.
    """
    multi = isinstance(self.act_func, tuple)
    return self._multi_call(x, **k) if multi else self._single_call(x, **k)
6 changes: 3 additions & 3 deletions tests/test_fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@


def test_FNN():
    """Activations and layers must be distinct objects per position.

    With a tuple of activation names, each entry is resolved independently,
    so positions must not alias each other; likewise for the Linear layers.
    (The pre-commit assertions against the removed ``act_funcs`` attribute
    were diff residue and have been dropped.)
    """
    layer = FNN([1, 2, 3, 4], act_func=("relu", "relu", "relu"))
    assert not (layer.act_func[0] is layer.act_func[1])
    assert not (layer.layers[0] is layer.layers[1])


def test_mlp():
Expand Down

0 comments on commit 1e41265

Please sign in to comment.