Commit: fnn edits

reduce jaxpr length of mlp

ASEM000 committed Jul 26, 2023
1 parent 1c47f5a commit 89d0aae

Showing 2 changed files with 203 additions and 69 deletions.
213 changes: 149 additions & 64 deletions serket/nn/fully_connected.py
@@ -20,7 +20,11 @@
import jax.random as jr

import serket as sk
from serket.nn.activation import ActivationType, resolve_activation
from serket.nn.activation import (
ActivationFunctionType,
ActivationType,
resolve_activation,
)
from serket.nn.initialization import InitType
from serket.nn.linear import Linear

@@ -52,7 +56,7 @@ class FNN(sk.TreeClass):
``len(layers)-2`` activation functions, for example, ``layers=[10, 5, 2]``
yields 2 linear layers with weight shapes (10, 5) and (5, 2),
with a single activation function applied between them.
- ``FNN`` uses python ``for`` loop to apply layers and activation functions.
- :class:`.FNN` uses python ``for`` loop to apply layers and activation functions.
"""

@@ -90,22 +94,65 @@ def __init__(
for (ki, di, do) in (zip(keys, layers[:-1], layers[1:]))
)

def _multi_call(self, x: jax.Array, **k) -> jax.Array:
def __call__(self, x: jax.Array, **k) -> jax.Array:
*layers, last = self.layers
for ai, li in zip(self.act_func, layers):
x = ai(li(x))
return last(x)

def _single_call(self, x: jax.Array, **k) -> jax.Array:
*layers, last = self.layers
for li in layers:
x = self.act_func(li(x))
if isinstance(self.act_func, tuple):
for ai, li in zip(self.act_func, layers):
x = ai(li(x))
else:
for li in layers:
x = self.act_func(li(x))

return last(x)

def __call__(self, x: jax.Array, **k) -> jax.Array:
if isinstance(self.act_func, tuple):
return self._multi_call(x, **k)
return self._single_call(x, **k)

def _scan_batched_layer_with_single_activation(
x: Batched[jax.Array],
layer: Batched[Linear],
act_func: ActivationFunctionType,
) -> jax.Array:
if layer.bias is None:

def scan_func(x: jax.Array, weight: Batched[jax.Array]):
return act_func(x @ weight), None

x, _ = jax.lax.scan(scan_func, x, layer.weight)
return x

def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
return act_func(x @ weight + bias), None

weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
x, _ = jax.lax.scan(scan_func, x, weight_bias)
return x


def _scan_batched_layer_with_multiple_activations(
x: Batched[jax.Array],
layer: Batched[Linear],
act_func: Sequence[ActivationFunctionType],
) -> jax.Array:
if layer.bias is None:

def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]):
x, index = x_index
x = jax.lax.switch(index, act_func, x @ weight)
return (x, index + 1), None

(x, _), _ = jax.lax.scan(scan_func, (x, 0), layer.weight)
return x

def scan_func(x_index: tuple[jax.Array, int], weight_bias: Batched[jax.Array]):
x, index = x_index
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
x = jax.lax.switch(index, act_func, x @ weight + bias)
return [x, index + 1], None

weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
(x, _), _ = jax.lax.scan(scan_func, [x, 0], weight_bias)
return x

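The two helpers above carry the core trick of this commit: stack the intermediate layers' parameters along a leading axis and fold the stack into a single ``jax.lax.scan`` (with ``jax.lax.switch`` selecting the activation when layers use different ones), so the traced program contains one loop instead of one block of equations per layer. A minimal, serket-independent sketch of the single-activation case (all names below are illustrative):

import jax
import jax.numpy as jnp

num_layers, hidden = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
# stacked parameters: weight is (num_layers, hidden, hidden), bias is (num_layers, hidden)
weight = jax.vmap(lambda k: jax.random.normal(k, (hidden, hidden)))(keys)
bias = jnp.zeros((num_layers, hidden))

def scan_func(x, weight_bias):
    # split the per-layer slice back into weight and bias
    w, b = weight_bias[..., :-1], weight_bias[..., -1]
    return jax.nn.relu(x @ w + b), None

# concatenate weight and bias so a single stacked array is scanned over
weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1)
x = jnp.ones((hidden,))
out, _ = jax.lax.scan(scan_func, x, weight_bias)
print(out.shape)  # (4,)

The number of equations in the resulting ``jaxpr`` stays fixed as ``num_layers`` grows, which is where the 403 vs 10 comparison in the :class:`.MLP` docstring below comes from.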

class MLP(sk.TreeClass):
@@ -129,15 +176,34 @@ class MLP(sk.TreeClass):
(3, 2)
Note:
- ``MLP`` with ``in_features=1``, ``out_features=2``, ``hidden_size=4``,
- :class:`.MLP` with ``in_features=1``, ``out_features=2``, ``hidden_size=4``,
``num_hidden_layers=2`` is equivalent to ``[1, 4, 4, 2]`` which has one
input layer (1, 4), one intermediate layer (4, 4), and one output
layer (4, 2), i.e. ``num_hidden_layers + 1`` layers in total.
- ``MLP`` exploits same input/out size for intermediate layers to use
Note:
- :class:`.MLP` exploits the same input/output size for intermediate layers to use
``jax.lax.scan``, which offers better compilation speed for a large
number of layers and produces a smaller ``jaxpr``, but could be
slower than equivalent `FNN` for small number of layers.
slower than an equivalent :class:`.FNN` for a small number of layers.
The following compares the size of the ``jaxpr`` for an :class:`.MLP` and an :class:`.FNN`
with equivalent layers.
>>> import jax
>>> import jax.numpy as jnp
>>> import serket as sk
>>> import numpy.testing as npt
>>> fnn = sk.nn.FNN([1] + [4] * 100 + [2])
>>> mlp = sk.nn.MLP(1, 2, hidden_size=4, num_hidden_layers=100)
>>> x = jnp.ones((100, 1))
>>> fnn_jaxpr = jax.make_jaxpr(fnn)(x)
>>> mlp_jaxpr = jax.make_jaxpr(mlp)(x)
>>> npt.assert_allclose(fnn(x), mlp(x), atol=1e-6)
>>> len(fnn_jaxpr.jaxpr.eqns)
403
>>> len(mlp_jaxpr.jaxpr.eqns)
10
"""

def __init__(
@@ -153,7 +219,7 @@ def __init__(
key: jr.KeyArray = jr.PRNGKey(0),
):
if hidden_size < 1:
raise ValueError(f"hidden_size must be positive, got {hidden_size}")
raise ValueError(f"`{hidden_size=}` must be positive.")

keys = jr.split(key, num_hidden_layers + 1)

@@ -167,60 +233,79 @@ def __init__(
else:
self.act_func = resolve_activation(act_func)

kwargs = dict(weight_init_func=weight_init_func, bias_init_func=bias_init_func)

def batched_linear(key) -> Batched[Linear]:
return sk.tree_mask(Linear(hidden_size, hidden_size, key=key, **kwargs))

self.layers = tuple(
[
Linear(
in_features=in_features,
out_features=hidden_size,
weight_init_func=weight_init_func,
bias_init_func=bias_init_func,
key=keys[0],
)
]
+ [
Linear(
in_features=hidden_size,
out_features=hidden_size,
weight_init_func=weight_init_func,
bias_init_func=bias_init_func,
key=key,
)
for key in keys[1:-1]
]
+ [
Linear(
in_features=hidden_size,
out_features=out_features,
weight_init_func=weight_init_func,
bias_init_func=bias_init_func,
key=keys[-1],
)
]
[Linear(in_features, hidden_size, key=keys[0], **kwargs)]
+ [sk.tree_unmask(jax.vmap(batched_linear)(keys[1:-1]))]
+ [Linear(hidden_size, out_features, key=keys[-1], **kwargs)]
)

def _single_call(self, x: jax.Array, **k) -> jax.Array:
def scan_func(carry, _):
x, (l0, *lh) = carry
return [self.act_func(l0(x)), [*lh, l0]], None
x = self.act_func(self.in_layer(x))

if self.mid_layer.bias is None:

def scan_mid_layer(x: jax.Array, weight: jax.Array):
return self.act_func(x @ weight), None

x, _ = jax.lax.scan(scan_mid_layer, x, self.mid_layer.weight)

else:
bias = self.mid_layer.bias[:, :, None]
weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)

def scan_mid_layer(x: jax.Array, weight_bias: jax.Array):
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
return self.act_func(x @ weight + bias), None

(l0, *lh, lf) = self.layers
x = self.act_func(l0(x))
if length := len(lh):
(x, _), _ = jax.lax.scan(scan_func, [x, lh], None, length=length)
return lf(x)
x, _ = jax.lax.scan(scan_mid_layer, x, weight_bias)

x = self.out_layer(x)
return x

def _multi_call(self, x: jax.Array, **k) -> jax.Array:
def scan_func(carry, _):
x, (l0, *lh), (a0, *ah) = carry
return [a0(l0(x)), [*lh, l0], [*ah, a0]], None
a0, *ah = self.act_func
x = a0(self.in_layer(x))

(l0, *lh, lf), (a0, *ah) = self.layers, self.act_func
x = a0(l0(x))
if length := len(lh):
(x, _, _), _ = jax.lax.scan(scan_func, [x, lh, ah], None, length=length)
return lf(x)
if self.mid_layer.bias is None:

def scan_mid_layer(x_index: jax.Array, weight: jax.Array):
x, index = x_index
x = jax.lax.switch(index, ah, x @ weight)
return [x, index + 1], None

(x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], self.mid_layer.weight)

else:
bias = self.mid_layer.bias[:, :, None]
weight_bias = jnp.concatenate([self.mid_layer.weight, bias], axis=-1)

def scan_mid_layer(x_index: jax.Array, weight_bias: jax.Array):
x, index = x_index
weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
x = jax.lax.switch(index, ah, x @ weight + bias)
return [x, index + 1], None

(x, _), _ = jax.lax.scan(scan_mid_layer, [x, 0], weight_bias)

x = self.out_layer(x)
return x

def __call__(self, x: jax.Array, **k) -> jax.Array:
l0, lm, lh = self.layers

if isinstance(self.act_func, tuple):
return self._multi_call(x, **k)
return self._single_call(x, **k)
a0, *ah = self.act_func
x = a0(l0(x))
x = _scan_batched_layer_with_multiple_activations(x, lm, ah)
return lh(x)

a0 = self.act_func
x = a0(l0(x))
x = _scan_batched_layer_with_single_activation(x, lm, a0)
return lh(x)
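For reference, a rough sketch of what the vmapped constructor in ``__init__`` above produces, assuming ``sk.tree_mask``/``sk.tree_unmask`` only hide the non-array leaves from ``jax.vmap``: the middle entry of ``self.layers`` is a single ``Linear`` whose ``weight`` (and ``bias``) carry a leading layer axis, which is exactly the stacked layout the scan helpers consume.

import jax
import serket as sk
from serket.nn.linear import Linear

hidden_size, n_mid = 4, 5
keys = jax.random.split(jax.random.PRNGKey(0), n_mid)

def batched_linear(key):
    # mask non-array leaves so the constructor can be vmapped over keys
    return sk.tree_mask(Linear(hidden_size, hidden_size, key=key))

mid = sk.tree_unmask(jax.vmap(batched_linear)(keys))
print(mid.weight.shape)  # expected: (5, 4, 4) -- five stacked (4, 4) layers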
59 changes: 54 additions & 5 deletions tests/test_fully_connected.py
@@ -14,17 +14,66 @@

import jax
import numpy.testing as npt

import jax.numpy as jnp
from serket.nn import FNN, MLP


def test_FNN():
layer = FNN([1, 2, 3, 4], act_func=("relu", "relu", "relu"))
assert not (layer.act_func[0] is layer.act_func[1])
assert not (layer.layers[0] is layer.layers[1])
def test_fnn():
layer = FNN([1, 2, 3, 4], act_func=("relu", "tanh"))
assert layer.act_func[0] is not layer.act_func[1]
assert layer.layers[0] is not layer.layers[1]

x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 5))
w2 = jax.random.normal(jax.random.PRNGKey(2), (5, 3))
w3 = jax.random.normal(jax.random.PRNGKey(3), (3, 4))

y = x @ w1
y = jnp.tanh(y)
y = y @ w2
y = jax.nn.relu(y)
y = y @ w3

l1 = FNN([1, 5, 3, 4], act_func=("tanh", "relu"), bias_init_func=None)
l1 = l1.at["layers"].at[0].at["weight"].set(w1)
l1 = l1.at["layers"].at[1].at["weight"].set(w2)
l1 = l1.at["layers"].at[2].at["weight"].set(w3)

npt.assert_allclose(l1(x), y)


def test_mlp():
layer = MLP(
1,
4,
hidden_size=10,
num_hidden_layers=2,
act_func=("relu", "tanh"),
bias_init_func=None,
)

x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 10))
w2 = jax.random.normal(jax.random.PRNGKey(2), (10, 10))
w3 = jax.random.normal(jax.random.PRNGKey(3), (10, 4))

y = x @ w1
y = jax.nn.relu(y)
y = y @ w2
y = jnp.tanh(y)
y = y @ w3

layer = layer.at["layers"].at[0].at["weight"].set(w1)
layer = layer.at["layers"].at[1].at["weight"].set(w2[None])
layer = layer.at["layers"].at[2].at["weight"].set(w3)


npt.assert_allclose(layer(x), y)


def test_fnn_mlp():
fnn = FNN(layers=[2, 4, 4, 2], act_func="relu")
mlp = MLP(2, 2, hidden_size=4, num_hidden_layers=2, act_func="relu")
x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
