From 89d0aae1823a1b76db7f16316de738ef6fdb6e78 Mon Sep 17 00:00:00 2001
From: ASEM000
Date: Thu, 27 Jul 2023 06:12:38 +0900
Subject: [PATCH] fnn edits

reduce jaxpr length of mlp
---
 serket/nn/fully_connected.py  | 213 ++++++++++++++++++++++++----------
 tests/test_fully_connected.py |  59 +++++++++-
 2 files changed, 203 insertions(+), 69 deletions(-)

diff --git a/serket/nn/fully_connected.py b/serket/nn/fully_connected.py
index 010f521..5fdf6a7 100644
--- a/serket/nn/fully_connected.py
+++ b/serket/nn/fully_connected.py
@@ -20,7 +20,11 @@
 import jax.random as jr
 
 import serket as sk
-from serket.nn.activation import ActivationType, resolve_activation
+from serket.nn.activation import (
+    ActivationFunctionType,
+    ActivationType,
+    resolve_activation,
+)
 from serket.nn.initialization import InitType
 from serket.nn.linear import Linear
 
@@ -52,7 +56,7 @@ class FNN(sk.TreeClass):
         ``len(layers)-2`` activation functions, for example, ``layers=[10, 5, 2]``
         yields 2 linear layers with weight shapes (10, 5) and (5, 2),
         and a single activation function is applied between them.
-        - ``FNN`` uses python ``for`` loop to apply layers and activation functions.
+        - :class:`.FNN` uses a python ``for`` loop to apply layers and activation functions.
     """
 
@@ -90,22 +94,65 @@ def __init__(
             for (ki, di, do) in (zip(keys, layers[:-1], layers[1:]))
         )
 
-    def _multi_call(self, x: jax.Array, **k) -> jax.Array:
+    def __call__(self, x: jax.Array, **k) -> jax.Array:
         *layers, last = self.layers
-        for ai, li in zip(self.act_func, layers):
-            x = ai(li(x))
-        return last(x)
 
-    def _single_call(self, x: jax.Array, **k) -> jax.Array:
-        *layers, last = self.layers
-        for li in layers:
-            x = self.act_func(li(x))
+        if isinstance(self.act_func, tuple):
+            for ai, li in zip(self.act_func, layers):
+                x = ai(li(x))
+        else:
+            for li in layers:
+                x = self.act_func(li(x))
+
         return last(x)
 
-    def __call__(self, x: jax.Array, **k) -> jax.Array:
-        if isinstance(self.act_func, tuple):
-            return self._multi_call(x, **k)
-        return self._single_call(x, **k)
+
+def _scan_batched_layer_with_single_activation(
+    x: Batched[jax.Array],
+    layer: Batched[Linear],
+    act_func: ActivationFunctionType,
+) -> jax.Array:
+    if layer.bias is None:
+
+        def scan_func(x: jax.Array, weight: Batched[jax.Array]):
+            return act_func(x @ weight), None
+
+        x, _ = jax.lax.scan(scan_func, x, layer.weight)
+        return x
+
+    def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]):
+        weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
+        return act_func(x @ weight + bias), None
+
+    weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
+    x, _ = jax.lax.scan(scan_func, x, weight_bias)
+    return x
+
+
+def _scan_batched_layer_with_multiple_activations(
+    x: Batched[jax.Array],
+    layer: Batched[Linear],
+    act_func: Sequence[ActivationFunctionType],
+) -> jax.Array:
+    if layer.bias is None:
+
+        def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]):
+            x, index = x_index
+            x = jax.lax.switch(index, act_func, x @ weight)
+            return (x, index + 1), None
+
+        (x, _), _ = jax.lax.scan(scan_func, (x, 0), layer.weight)
+        return x
+
+    def scan_func(x_index: tuple[jax.Array, int], weight_bias: Batched[jax.Array]):
+        x, index = x_index
+        weight, bias = weight_bias[..., :-1], weight_bias[..., -1]
+        x = jax.lax.switch(index, act_func, x @ weight + bias)
+        return (x, index + 1), None
+
+    weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1)
+    (x, _), _ = jax.lax.scan(scan_func, (x, 0), weight_bias)
+    return x
 
 
 class MLP(sk.TreeClass):
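
The two helpers above implement one trick: the parameters of the identically-shaped middle ``Linear`` layers are stacked along a leading axis, and ``jax.lax.scan`` applies them one step at a time, so the traced program contains a single layer application instead of one per layer. A minimal, self-contained sketch of the same pattern (all names and shapes here are illustrative, not serket API):

    import jax
    import jax.numpy as jnp

    num_layers, hidden = 5, 4
    keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
    # stack the parameters of `num_layers` (hidden -> hidden) layers:
    # weight has shape (num_layers, hidden, hidden), bias (num_layers, hidden)
    weight = jax.vmap(lambda k: jax.random.normal(k, (hidden, hidden)))(keys)
    bias = jnp.zeros((num_layers, hidden))

    def apply_layer(x, weight_bias):
        # scan feeds one (weight, bias) slice per step; the carry is the
        # running activation
        w, b = weight_bias
        return jax.nn.relu(x @ w + b), None

    x = jnp.ones(hidden)
    x, _ = jax.lax.scan(apply_layer, x, (weight, bias))
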
@@ -129,15 +176,34 @@ class MLP(sk.TreeClass):
         (3, 2)
 
     Note:
-        - ``MLP`` with ``in_features=1``, ``out_features=2``, ``hidden_size=4``,
+        - :class:`.MLP` with ``in_features=1``, ``out_features=2``, ``hidden_size=4``,
           ``num_hidden_layers=2`` is equivalent to ``[1, 4, 4, 2]`` which has one
           input layer (1, 4), one intermediate layer (4, 4), and one output
           layer (4, 2) = ``num_hidden_layers + 1``
-        - ``MLP`` exploits same input/out size for intermediate layers to use
+
+    Note:
+        - :class:`.MLP` exploits the same input/output size of its intermediate layers to use
           ``jax.lax.scan``, which offers better compilation speed for a large
           number of layers and produces a smaller ``jaxpr``, but could be
-          slower than equivalent `FNN` for small number of layers.
+          slower than an equivalent :class:`.FNN` for a small number of layers.
+
+          The following compares the size of the ``jaxpr`` of :class:`.MLP` and
+          :class:`.FNN` of equivalent layers.
+
+          >>> import jax
+          >>> import jax.numpy as jnp
+          >>> import serket as sk
+          >>> import numpy.testing as npt
+          >>> fnn = sk.nn.FNN([1] + [4] * 100 + [2])
+          >>> mlp = sk.nn.MLP(1, 2, hidden_size=4, num_hidden_layers=100)
+          >>> x = jnp.ones((100, 1))
+          >>> fnn_jaxpr = jax.make_jaxpr(fnn)(x)
+          >>> mlp_jaxpr = jax.make_jaxpr(mlp)(x)
+          >>> npt.assert_allclose(fnn(x), mlp(x), atol=1e-6)
+          >>> len(fnn_jaxpr.jaxpr.eqns)
+          403
+          >>> len(mlp_jaxpr.jaxpr.eqns)
+          10
     """
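
When each scanned step needs a different activation, carrying the stacked weights alone is not enough: the ``_scan_batched_layer_with_multiple_activations`` helper above also carries an integer step index through the scan and dispatches on it with ``jax.lax.switch``. A condensed, serket-free sketch of that dispatch (activations and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    acts = (jax.nn.relu, jnp.tanh)  # one activation per scanned layer
    weight = jnp.stack([jnp.eye(3), 2.0 * jnp.eye(3)])  # two stacked (3, 3) layers

    def step(carry, w):
        x, index = carry
        # switch picks acts[index]; index stays a traced scalar inside scan
        x = jax.lax.switch(index, acts, x @ w)
        return (x, index + 1), None

    (x, _), _ = jax.lax.scan(step, (jnp.ones(3), 0), weight)
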
@@ -153,7 +219,7 @@ def __init__(
         key: jr.KeyArray = jr.PRNGKey(0),
     ):
         if hidden_size < 1:
-            raise ValueError(f"hidden_size must be positive, got {hidden_size}")
+            raise ValueError(f"`{hidden_size=}` must be positive.")
 
         keys = jr.split(key, num_hidden_layers + 1)
 
@@ -167,60 +233,79 @@ def __init__(
         else:
             self.act_func = resolve_activation(act_func)
 
+        kwargs = dict(weight_init_func=weight_init_func, bias_init_func=bias_init_func)
+
+        def batched_linear(key) -> Batched[Linear]:
+            return sk.tree_mask(Linear(hidden_size, hidden_size, key=key, **kwargs))
+
         self.layers = tuple(
-            [
-                Linear(
-                    in_features=in_features,
-                    out_features=hidden_size,
-                    weight_init_func=weight_init_func,
-                    bias_init_func=bias_init_func,
-                    key=keys[0],
-                )
-            ]
-            + [
-                Linear(
-                    in_features=hidden_size,
-                    out_features=hidden_size,
-                    weight_init_func=weight_init_func,
-                    bias_init_func=bias_init_func,
-                    key=key,
-                )
-                for key in keys[1:-1]
-            ]
-            + [
-                Linear(
-                    in_features=hidden_size,
-                    out_features=out_features,
-                    weight_init_func=weight_init_func,
-                    bias_init_func=bias_init_func,
-                    key=keys[-1],
-                )
-            ]
+            [Linear(in_features, hidden_size, key=keys[0], **kwargs)]
+            + [sk.tree_unmask(jax.vmap(batched_linear)(keys[1:-1]))]
+            + [Linear(hidden_size, out_features, key=keys[-1], **kwargs)]
         )
 
-    def _single_call(self, x: jax.Array, **k) -> jax.Array:
-        def scan_func(carry, _):
-            x, (l0, *lh) = carry
-            return [self.act_func(l0(x)), [*lh, l0]], None
-
-        (l0, *lh, lf) = self.layers
-        x = self.act_func(l0(x))
-        if length := len(lh):
-            (x, _), _ = jax.lax.scan(scan_func, [x, lh], None, length=length)
-        return lf(x)
-
-    def _multi_call(self, x: jax.Array, **k) -> jax.Array:
-        def scan_func(carry, _):
-            x, (l0, *lh), (a0, *ah) = carry
-            return [a0(l0(x)), [*lh, l0], [*ah, a0]], None
-
-        (l0, *lh, lf), (a0, *ah) = self.layers, self.act_func
-        x = a0(l0(x))
-        if length := len(lh):
-            (x, _, _), _ = jax.lax.scan(scan_func, [x, lh, ah], None, length=length)
-        return lf(x)
-
     def __call__(self, x: jax.Array, **k) -> jax.Array:
+        l0, lm, lh = self.layers
+
         if isinstance(self.act_func, tuple):
-            return self._multi_call(x, **k)
-        return self._single_call(x, **k)
+            a0, *ah = self.act_func
+            x = a0(l0(x))
+            x = _scan_batched_layer_with_multiple_activations(x, lm, ah)
+            return lh(x)
+
+        a0 = self.act_func
+        x = a0(l0(x))
+        x = _scan_batched_layer_with_single_activation(x, lm, a0)
+        return lh(x)
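
Note that the rewritten ``__init__`` above builds the middle layers by ``jax.vmap``-ing the ``Linear`` constructor over a batch of keys, producing a single ``Linear`` whose array leaves carry an extra leading axis; ``sk.tree_mask``/``sk.tree_unmask`` shield the non-array leaves from ``vmap``. A rough sketch of the same construction without serket (the dict layout is illustrative):

    import jax
    import jax.numpy as jnp

    def init_linear(key, features=4):
        # one (features -> features) layer as a plain pytree
        return {
            "weight": jax.random.normal(key, (features, features)),
            "bias": jnp.zeros(features),
        }

    keys = jax.random.split(jax.random.PRNGKey(0), 10)
    # vmap over the keys stacks every leaf: weight (10, 4, 4), bias (10, 4)
    stacked = jax.vmap(init_linear)(keys)
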
diff --git a/tests/test_fully_connected.py b/tests/test_fully_connected.py
index 9afc86c..6225453 100644
--- a/tests/test_fully_connected.py
+++ b/tests/test_fully_connected.py
@@ -14,17 +14,66 @@
 
 import jax
+import jax.numpy as jnp
 import numpy.testing as npt
-
 from serket.nn import FNN, MLP
 
 
-def test_FNN():
-    layer = FNN([1, 2, 3, 4], act_func=("relu", "relu", "relu"))
-    assert not (layer.act_func[0] is layer.act_func[1])
-    assert not (layer.layers[0] is layer.layers[1])
+def test_fnn():
+    layer = FNN([1, 2, 3, 4], act_func=("relu", "tanh"))
+    assert layer.act_func[0] is not layer.act_func[1]
+    assert layer.layers[0] is not layer.layers[1]
+
+    x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
+    w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 5))
+    w2 = jax.random.normal(jax.random.PRNGKey(2), (5, 3))
+    w3 = jax.random.normal(jax.random.PRNGKey(3), (3, 4))
+
+    y = x @ w1
+    y = jnp.tanh(y)
+    y = y @ w2
+    y = jax.nn.relu(y)
+    y = y @ w3
+
+    l1 = FNN([1, 5, 3, 4], act_func=("tanh", "relu"), bias_init_func=None)
+    l1 = l1.at["layers"].at[0].at["weight"].set(w1)
+    l1 = l1.at["layers"].at[1].at["weight"].set(w2)
+    l1 = l1.at["layers"].at[2].at["weight"].set(w3)
+
+    npt.assert_allclose(l1(x), y)
 
 
 def test_mlp():
+    layer = MLP(
+        1,
+        4,
+        hidden_size=10,
+        num_hidden_layers=2,
+        act_func=("relu", "tanh"),
+        bias_init_func=None,
+    )
+
+    x = jax.random.normal(jax.random.PRNGKey(0), (10, 1))
+    w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 10))
+    w2 = jax.random.normal(jax.random.PRNGKey(2), (10, 10))
+    w3 = jax.random.normal(jax.random.PRNGKey(3), (10, 4))
+
+    y = x @ w1
+    y = jax.nn.relu(y)
+    y = y @ w2
+    y = jnp.tanh(y)
+    y = y @ w3
+
+    layer = layer.at["layers"].at[0].at["weight"].set(w1)
+    layer = layer.at["layers"].at[1].at["weight"].set(w2[None])
+    layer = layer.at["layers"].at[2].at["weight"].set(w3)
+
+    npt.assert_allclose(layer(x), y)
+
+
+def test_fnn_mlp():
     fnn = FNN(layers=[2, 4, 4, 2], act_func="relu")
     mlp = MLP(2, 2, hidden_size=4, num_hidden_layers=2, act_func="relu")
     x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))