diff --git a/docs/API/linear.rst b/docs/API/linear.rst index 755a929..933a652 100644 --- a/docs/API/linear.rst +++ b/docs/API/linear.rst @@ -4,8 +4,5 @@ Linear .. autoclass:: Linear .. autoclass:: Identity -.. autoclass:: GeneralLinear .. autoclass:: Embedding - -.. autoclass:: FNN .. autoclass:: MLP \ No newline at end of file diff --git a/serket/_src/nn/activation.py b/serket/_src/nn/activation.py index cf8e3a1..1fdab3b 100644 --- a/serket/_src/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -561,9 +561,6 @@ def def_act_entry(key: str, act: ActivationFunctionType) -> None: ... def __call__(self, x): ... return x * self.my_param >>> sk.def_act_entry("my_act", MyTrainableActivation()) - >>> x = jnp.ones((1, 1)) - >>> sk.nn.FNN([1, 1, 1], act="my_act", weight_init="ones", bias_init=None, key=jr.PRNGKey(0))(x) - Array([[10.]], dtype=float32) """ if key in act_map: raise ValueError(f"`init_key` {key=} already registered") diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index cc4616b..a829ce4 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -15,7 +15,7 @@ from __future__ import annotations import functools as ft -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, TypeVar import jax import jax.numpy as jnp @@ -28,7 +28,7 @@ resolve_activation, ) from serket._src.nn.initialization import DType, InitType, resolve_init -from serket._src.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb +from serket._src.utils import maybe_lazy_call, maybe_lazy_init, positive_int_cb, tuplify T = TypeVar("T") @@ -48,32 +48,6 @@ def is_lazy_init(_, in_features, *__, **___) -> bool: return in_features is None -def infer_multilinear_in_features(_, *x, **__) -> int | tuple[int, ...]: - return x[0].shape[-1] if len(x) == 1 else tuple(xi.shape[-1] for xi in x) - - -updates = dict(in_features=infer_multilinear_in_features) - - -def multilinear( - arrays: tuple[jax.Array, ...], - weight: jax.Array, - bias: jax.Array | None, -) -> jax.Array: - """Apply a linear layer to multiple inputs""" - - def generate_einsum_string(degree: int) -> str: - alpha = "".join(map(str, range(degree + 1))) - xs_string = [f"...{i}" for i in alpha[:degree]] - output_string = ",".join(xs_string) - output_string += f",{alpha[:degree+1]}->...{alpha[degree]}" - return output_string - - einsum_string = generate_einsum_string(len(arrays)) - out = jnp.einsum(einsum_string, *arrays, weight) - return out if bias is None else (out + bias) - - def general_linear( input: jax.Array, weight: jax.Array, @@ -99,119 +73,16 @@ def generate_einsum_string(*axes: tuple[int, ...]) -> str: return out if bias is None else (out + bias) -class Linear(sk.TreeClass): - """Linear layer with arbitrary number of inputs applied to last axis of each input - - Args: - in_features: number of input features for each input. accepts a tuple of ints - or a single int for single input. If ``None``, the layer is lazily - initialized. - out_features: number of output features. - key: key for the random number generator. - weight_init: function to initialize the weights. Defaults to ``glorot_uniform``. - bias_init: function to initialize the bias. Defaults to ``zeros``. - dtype: dtype of the weights and bias. defaults to ``jnp.float32``. 
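For reference, the removed ``multilinear`` helper contracted the last axis of each input against the weight via a generated einsum string ("...0,...1,012->...2" in the two-input case). The following is a minimal standalone sketch of that bilinear contraction, using letter subscripts and made-up shapes; it illustrates the removed code path only, since the merged ``Linear`` below takes a single input.

import jax.numpy as jnp
import jax.random as jr

# Bilinear contraction as performed by the removed ``multilinear`` helper:
# each input's last axis is contracted against the weight.
k1, k2, k3 = jr.split(jr.PRNGKey(0), 3)
x1 = jr.normal(k1, (1, 5))          # 5 features
x2 = jr.normal(k2, (1, 6))          # 6 features
weight = jr.normal(k3, (5, 6, 7))   # (*in_features, out_features)
out = jnp.einsum("...a,...b,abc->...c", x1, x2, weight)
print(out.shape)  # (1, 7)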
- - Example: - Linear layer example - - >>> import jax.numpy as jnp - >>> import serket as sk - >>> import jax.random as jr - >>> input = jnp.ones((1, 5)) - >>> layer = sk.nn.Linear(5, 6, key=jr.PRNGKey(0)) - >>> layer(input).shape - (1, 6) - - Example: - Bilinear layer example - - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import serket as sk - >>> input_1 = jnp.ones((1, 5)) # 5 features - >>> input_2 = jnp.ones((1, 6)) # 6 features - >>> key = jr.PRNGKey(0) - >>> layer = sk.nn.Linear((5, 6), 7, key=key) - >>> layer(input_1, input_2).shape - (1, 7) - - Note: - :class:`.Linear` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, pass ``None`` as the ``in_features`` argument - and use the ``.at["__call__"]`` attribute to call the layer - with an input of known shape. - - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import jax - >>> class Linears(sk.TreeClass): - ... def __init__(self, *, key: jax.Array): - ... k1, k2 = jr.split(key) - ... self.l1 = sk.nn.Linear(None, 32, key=k1) - ... self.l2 = sk.nn.Linear(32, 10, key=k2) - ... def __call__(self, x: jax.Array, y: jax.Array) -> jax.Array: - ... return self.l2(jax.nn.relu(self.l1(x, y))) - >>> key = jr.PRNGKey(0) - >>> lazy_layer = Linears(key=key) - >>> input_1 = jnp.ones([100, 28]) - >>> input_2 = jnp.ones([100, 56]) - >>> _, material_layer = lazy_layer.at["__call__"](input_1, input_2) - >>> material_layer.l1.in_features - (28, 56) - - Note: - The difference between :class:`.Linear` and :class:`.GeneralLinear` is that - :class:`.Linear` applies the linear layer to the last axis of each input - for possibly multiple inputs, while :class:`.GeneralLinear` applies the - linear layer to the axes specified by ``in_axes`` of a single input. - """ - - @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) - def __init__( - self, - in_features: int | tuple[int, ...] 
| None, - out_features: int, - *, - key: jax.Array, - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - dtype: DType = jnp.float32, - ): - if not isinstance(in_features, (int, tuple, type(None))): - raise TypeError(f"{in_features=} must be `None`, `tuple` or `int`") - if isinstance(in_features, int): - in_features = (in_features,) - self.in_features = in_features - self.out_features = out_features - self.weight_init = weight_init - self.bias_init = bias_init - k1, k2 = jr.split(key) - weight_shape = (*self.in_features, self.out_features) - self.weight = resolve_init(weight_init)(k1, weight_shape, dtype) - self.bias = resolve_init(bias_init)(k2, (out_features,), dtype) - - @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) - def __call__(self, *arrays) -> jax.Array: - if len(arrays) != len(self.in_features): - raise ValueError(f"{len(arrays)=} != {len(self.in_features)=}") - return multilinear(arrays, self.weight, self.bias) - - def infer_in_features(instance, x, **__) -> tuple[int, ...]: in_axes = getattr(instance, "in_axes", ()) - return tuple(x.shape[i] for i in in_axes) + return tuple(x.shape[i] for i in tuplify(in_axes)) updates = dict(in_features=infer_in_features) -class GeneralLinear(sk.TreeClass): - """Apply a Linear Layer to input at in_axes +class Linear(sk.TreeClass): + """Apply a Linear Layer to input at ``in_axes`` Args: in_features: number of input features corresponding to in_axes @@ -223,22 +94,22 @@ class GeneralLinear(sk.TreeClass): dtype: dtype of the weights and biases. defaults to ``jnp.float32``. Example: - Apply linear layer to first and second axes of input + Apply :class:`.Linear` layer to first and second axes of input >>> import jax.numpy as jnp >>> import serket as sk >>> import jax.random as jr - >>> x = jnp.ones([1, 2, 3, 4]) - >>> in_features = (1, 2) + >>> input = jnp.ones([1, 2, 3, 4]) + >>> in_features = (1, 2) # number of input features corresponding to ``in_axes`` >>> out_features = 5 - >>> in_axes = (0, 1) + >>> in_axes = (0, 1) # which axes to apply the linear layer to >>> key = jr.PRNGKey(0) - >>> layer = sk.nn.GeneralLinear(in_features, out_features, in_axes=in_axes, key=key) - >>> layer(x).shape + >>> layer = sk.nn.Linear(in_features, out_features, in_axes=in_axes, key=key) + >>> layer(input).shape (3, 4, 5) Note: - :class:`.GeneralLinear` supports lazy initialization, meaning that the weights and + :class:`.Linear` supports lazy initialization, meaning that the weights and biases are not initialized until the first call to the layer. This is useful when the input shape is not known at initialization time. @@ -249,49 +120,43 @@ class GeneralLinear(sk.TreeClass): >>> import jax.numpy as jnp >>> import serket as sk >>> import jax.random as jr - >>> lazy_linear = sk.nn.GeneralLinear(None, 12, in_axes=(0, 2), key=jr.PRNGKey(0)) - >>> _, material_linear = lazy_linear.at['__call__'](jnp.ones((10, 5, 4))) + >>> key = jr.PRNGKey(0) + >>> input = jnp.ones((10, 5, 4)) + >>> lazy_linear = sk.nn.Linear(None, 12, in_axes=(0, 2), key=jr.PRNGKey(0)) + >>> _, material_linear = lazy_linear.at['__call__'](input) >>> material_linear.in_features (10, 4) - >>> material_linear(jnp.ones((10, 5, 4))).shape - (5, 12) - - Note: - The difference between :class:`.Linear` and :class:`.GeneralLinear` is that - :class:`.Linear` applies the linear layer to the last axis of each input - for possibly multiple inputs, while :class:`.GeneralLinear` applies the - linear layer to the axes specified by ``in_axes`` of a single input. 
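A short usage sketch of the merged layer, assuming the serket API exactly as changed in this diff: ``sk.nn.Linear`` now accepts an optional ``in_axes`` that defaults to the last axis, and still supports lazy initialization via ``in_features=None``.

import jax.numpy as jnp
import jax.random as jr
import serket as sk

key = jr.PRNGKey(0)

# Default ``in_axes=-1``: behaves like a plain dense layer on the last axis.
dense = sk.nn.Linear(5, 6, key=key)
print(dense(jnp.ones((1, 5))).shape)  # (1, 6)

# Explicit ``in_axes``: contract the first and second axes of the input.
general = sk.nn.Linear((1, 2), 5, in_axes=(0, 1), key=key)
print(general(jnp.ones((1, 2, 3, 4))).shape)  # (3, 4, 5)

# Lazy initialization: pass ``None`` and materialize on the first call.
lazy = sk.nn.Linear(None, 12, in_axes=(0, 2), key=key)
_, material = lazy.at["__call__"](jnp.ones((10, 5, 4)))
print(material.in_features)  # (10, 4)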
""" @ft.partial(maybe_lazy_init, is_lazy=is_lazy_init) def __init__( self, - in_features: tuple[int, ...] | None, + in_features: int | tuple[int, ...] | None, out_features: int, *, key: jax.Array, - in_axes: tuple[int, ...], + in_axes: int | tuple[int, ...] = -1, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", dtype: DType = jnp.float32, ): - self.in_features = in_features + self.in_features = tuplify(in_features) self.out_features = out_features - self.in_axes = in_axes + self.in_axes = tuplify(in_axes) self.weight_init = weight_init self.bias_init = bias_init - if not (all(isinstance(i, int) for i in in_features)): - raise TypeError(f"Expected tuple of ints for {in_features=}") + if not (all(isinstance(i, int) for i in self.in_features)): + raise TypeError(f"Expected tuple of ints for {self.in_features=}") - if not (all(isinstance(i, int) for i in in_axes)): - raise TypeError(f"Expected tuple of ints for {in_axes=}") + if not (all(isinstance(i, int) for i in self.in_axes)): + raise TypeError(f"Expected tuple of ints for {self.in_axes=}") - if len(in_axes) != len(in_features): - raise ValueError(f"{len(in_axes)=} != {len(in_features)=}") + if len(self.in_axes) != len(self.in_features): + raise ValueError(f"{len(self.in_axes)=} != {len(self.in_features)=}") k1, k2 = jr.split(key) - weight_shape = (*in_features, out_features) + weight_shape = (*self.in_features, self.out_features) self.weight = resolve_init(weight_init)(k1, weight_shape, dtype) self.bias = resolve_init(bias_init)(k2, (self.out_features,), dtype) @@ -347,94 +212,16 @@ def __call__(self, input: jax.Array) -> jax.Array: if not jnp.issubdtype(input.dtype, jnp.integer): raise TypeError(f"{input.dtype=} is not a subdtype of integer") - return jnp.take(self.weight, input, axis=0) - - -class FNN(sk.TreeClass): - """Fully connected neural network - - Args: - layers: Sequence of layer sizes - key: Random number generator key. - act: Activation function. Defaults to ``tanh``. - weight_init: Weight initializer function. Defaults to ``glorot_uniform``. - bias_init: Bias initializer function. Defaults to ``zeros``. - dtype: dtype of the weights and biases. defaults to ``jnp.float32``. - - Example: - >>> import jax.numpy as jnp - >>> import serket as sk - >>> import jax.random as jr - >>> key = jr.PRNGKey(0) - >>> layer = sk.nn.FNN([10, 5, 2], key=key) - >>> input = jnp.ones((3, 10)) - >>> layer(input).shape - (3, 2) - - Note: - - layers argument yields ``len(layers) - 1`` linear layers with required - ``len(layers)-2`` activation functions, for example, ``layers=[10, 5, 2]`` - yields 2 linear layers with weight shapes (10, 5) and (5, 2) - and single activation function is applied between them. - - :class:`.FNN` uses python ``for`` loop to apply layers and activation functions. - - Note: - :class:`.FNN` supports lazy initialization, meaning that the weights and - biases are not initialized until the first call to the layer. This is - useful when the input shape is not known at initialization time. - - To use lazy initialization, add ``None`` as the the first element of the - ``layers`` argument and use the ``.at["__call__"]`` attribute - to call the layer with an input of known shape. 
- - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> key = jr.PRNGKey(0) - >>> lazy_layer = sk.nn.FNN([None, 10, 2, 1], key=key) - >>> _, material_layer = lazy_layer.at['__call__'](jnp.ones([1, 10])) - >>> material_layer.linear_0.in_features - (10,) - """ - - def __init__( - self, - layers: Sequence[int], - *, - key: jax.Array, - act: ActivationType = "tanh", - weight_init: InitType = "glorot_uniform", - bias_init: InitType = "zeros", - dtype: DType = jnp.float32, - ): - keys = jr.split(key, len(layers) - 1) - self.act = resolve_activation(act) - - for i, (di, do, ki) in enumerate(zip(layers[:-1], layers[1:], keys)): - layer = Linear( - in_features=di, - out_features=do, - key=ki, - weight_init=weight_init, - bias_init=bias_init, - dtype=dtype, - ) - setattr(self, f"linear_{i}", layer) - - def __call__(self, input: jax.Array) -> jax.Array: - vs = vars(self) - *layers, last = [vs[k] for k in vs if k.startswith("linear_")] - for li in layers: - input = self.act(li(input)) - return last(input) + return self.weight[input] -def _scan_linear( +def scan_linear( input: jax.Array, weight: Batched[jax.Array], bias: Batched[jax.Array] | None, act: ActivationFunctionType, ) -> jax.Array: + # reduce the ``jaxpr`` size by using ``scan`` if bias is None: def scan_func(x: jax.Array, weight: Batched[jax.Array]): @@ -448,8 +235,8 @@ def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]): return act(x @ weight + bias), None weight_bias = jnp.concatenate([weight, bias[:, :, None]], axis=-1) - input, _ = jax.lax.scan(scan_func, input, weight_bias) - return input + output, _ = jax.lax.scan(scan_func, input, weight_bias) + return output class MLP(sk.TreeClass): @@ -482,29 +269,6 @@ class MLP(sk.TreeClass): input layer (1, 4), one intermediate layer (4, 4), and one output layer (4, 2) = ``num_hidden_layers + 1`` - Note: - - :class:`.MLP` exploits same input/out size for intermediate layers to use - ``jax.lax.scan``, which offers better compilation speed for large - number of layers and producing a smaller ``jaxpr`` but could be - slower than equivalent :class:`.FNN` for small number of layers. - - The following compares the size of ``jaxpr`` for :class:`.MLP` and :class:`.FNN` - of equivalent layers. - - >>> import jax - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> import serket as sk - >>> import numpy.testing as npt - >>> key = jr.PRNGKey(0) - >>> fnn = sk.nn.FNN([1] + [4] * 100 + [2], key=key) - >>> mlp = sk.nn.MLP(1, 2, hidden_features=4, num_hidden_layers=100, key=key) - >>> input = jnp.ones((100, 1)) - >>> fnn_jaxpr = jax.make_jaxpr(fnn)(input) - >>> mlp_jaxpr = jax.make_jaxpr(mlp)(input) - >>> npt.assert_allclose(fnn(input), mlp(input), atol=1e-6) - >>> assert len(fnn_jaxpr.jaxpr.eqns) > len(mlp_jaxpr.jaxpr.eqns) - Note: :class:`.MLP` supports lazy initialization, meaning that the weights and biases are not initialized until the first call to the layer. 
This is @@ -521,7 +285,7 @@ class MLP(sk.TreeClass): >>> lazy_layer = sk.nn.MLP(None, 1, num_hidden_layers=2, hidden_features=10, key=key) >>> input = jnp.ones([1, 10]) >>> _, material_layer = lazy_layer.at['__call__'](input) - >>> material_layer.linear_i.in_features + >>> material_layer.in_linear.in_features (10,) """ @@ -545,17 +309,17 @@ def __init__( self.act = resolve_activation(act) kwargs = dict(weight_init=weight_init, bias_init=bias_init, dtype=dtype) + @jax.vmap def batched_linear(key: jax.Array) -> Batched[Linear]: layer = Linear(hidden_features, hidden_features, key=key, **kwargs) - # mask non-jaxtype on return return sk.tree_mask(layer) - self.linear_i = Linear(in_features, hidden_features, key=keys[0], **kwargs) - self.linear_h = sk.tree_unmask(jax.vmap(batched_linear)(keys[1:-1])) - self.linear_o = Linear(hidden_features, out_features, key=keys[-1], **kwargs) + self.in_linear = Linear(in_features, hidden_features, key=keys[0], **kwargs) + self.hidden_linear = sk.tree_unmask(batched_linear(keys[1:-1])) + self.out_linear = Linear(hidden_features, out_features, key=keys[-1], **kwargs) def __call__(self, input: jax.Array) -> jax.Array: - input = self.act(self.linear_i(input)) - weight_h, bias_h = self.linear_h.weight, self.linear_h.bias - input = _scan_linear(input, weight_h, bias_h, self.act) - return self.linear_o(input) + input = self.act(self.in_linear(input)) + weight_h, bias_h = self.hidden_linear.weight, self.hidden_linear.bias + input = scan_linear(input, weight_h, bias_h, self.act) + return self.out_linear(input) diff --git a/serket/_src/utils.py b/serket/_src/utils.py index be41ab7..e0ec6fc 100644 --- a/serket/_src/utils.py +++ b/serket/_src/utils.py @@ -398,6 +398,10 @@ def get_params(func: MethodType) -> tuple[inspect.Parameter, ...]: """ +def tuplify(value: T) -> T | tuple[T]: + return value if isinstance(value, tuple) else (value,) + + def maybe_lazy_init( func: Callable[P, T], is_lazy: Callable[..., bool], diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index ee17e39..d4b9456 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -98,7 +98,7 @@ dropout_nd, random_cutout_nd, ) -from serket._src.nn.linear import FNN, MLP, Embedding, GeneralLinear, Identity, Linear +from serket._src.nn.linear import MLP, Embedding, Identity, Linear from serket._src.nn.normalization import ( BatchNorm, GroupNorm, diff --git a/tests/test_linear.py b/tests/test_linear.py index 4b6bcde..fe6df8d 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -14,7 +14,6 @@ import jax import jax.numpy as jnp -import jax.tree_util as jtu import numpy.testing as npt import pytest @@ -30,132 +29,54 @@ def test_embed(): table(jnp.array([9.0])) -def test_linear(): - x = jnp.linspace(0, 1, 100)[:, None] - y = x**3 + jax.random.uniform(jax.random.PRNGKey(0), (100, 1)) * 0.01 - - @jax.value_and_grad - def loss_func(NN, x, y): - NN = sk.tree_unmask(NN) - return jnp.mean((NN(x) - y) ** 2) - - @jax.jit - def update(NN, x, y): - value, grad = loss_func(NN, x, y) - return value, jtu.tree_map(lambda x, g: x - 1e-3 * g, NN, grad) - - nn = sk.nn.FNN( - [1, 128, 128, 1], - act="relu", - weight_init="he_normal", - bias_init="ones", - key=jax.random.PRNGKey(0), - ) - - nn = sk.tree_mask(nn) - - for _ in range(20_000): - value, nn = update(nn, x, y) - - npt.assert_allclose(jnp.array(4.933563e-05), value, atol=1e-3) - - layer = sk.nn.Linear(1, 1, bias_init=None, key=jax.random.PRNGKey(0)) - w = jnp.array([[-0.31568417]]) - layer = layer.at["weight"].set(w) - y = jnp.array([[-0.31568417]]) 
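A standalone sketch of the idea behind ``scan_linear``: the identically shaped hidden layers are stacked along a leading axis and applied with ``jax.lax.scan``, so the traced program contains one layer body rather than ``num_hidden_layers`` copies of it. This sketch scans over a pytree of (weights, biases) for clarity; the library code in this diff instead concatenates weight and bias into a single scanned array.

import jax
import jax.numpy as jnp
import jax.random as jr

num_hidden_layers, hidden = 4, 8
k_w, k_b, k_x = jr.split(jr.PRNGKey(0), 3)
weights = jr.normal(k_w, (num_hidden_layers, hidden, hidden))
biases = jr.normal(k_b, (num_hidden_layers, hidden))
x = jr.normal(k_x, (3, hidden))

def scan_func(x, weight_bias):
    # carry is the activation; one (weight, bias) slice per scan step
    weight, bias = weight_bias
    return jnp.tanh(x @ weight + bias), None

out, _ = jax.lax.scan(scan_func, x, (weights, biases))
print(out.shape)  # (3, 8)

# A plain Python loop computes the same result; scan just keeps the jaxpr small.
ref = x
for w, b in zip(weights, biases):
    ref = jnp.tanh(ref @ w + b)
print(jnp.allclose(ref, out))  # True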
- npt.assert_allclose(layer(jnp.array([[1.0]])), y) - - layer = sk.nn.Linear(None, 1, bias_init="zeros", key=jax.random.PRNGKey(0)) - _, layer = layer.at["__call__"](jnp.ones([100, 2])) - assert layer.in_features == (2,) - - -def test_bilinear(): - W = jnp.array( - [ - [[-0.246, -0.3016], [-0.5532, 0.4251], [0.0983, 0.4425], [-0.1003, 0.1923]], - [[0.4584, -0.5352], [-0.449, 0.1154], [-0.3347, 0.3776], [0.2751, -0.0284]], - [ - [-0.4469, 0.3681], - [-0.2142, -0.0545], - [-0.5095, -0.2242], - [-0.4428, 0.2033], - ], - ] - ) - - x1 = jnp.array([[-0.7676, -0.7205, -0.0586]]) - x2 = jnp.array([[0.4600, -0.2508, 0.0115, 0.6155]]) - y = jnp.array([[-0.3001916, 0.28336674]]) - layer = sk.nn.Linear((3, 4), 2, bias_init=None, key=jax.random.PRNGKey(0)) - layer = layer.at["weight"].set(W) - - npt.assert_allclose(y, layer(x1, x2), atol=1e-4) - - layer = sk.nn.Linear((3, 4), 2, bias_init="zeros", key=jax.random.PRNGKey(0)) - layer = layer.at["weight"].set(W) - - npt.assert_allclose(y, layer(x1, x2), atol=1e-4) - - def test_identity(): x = jnp.array([[1, 2, 3], [4, 5, 6]]) layer = sk.nn.Identity() npt.assert_allclose(x, layer(x)) -def test_multi_linear(): - x = jnp.linspace(0, 1, 100)[:, None] - lhs = sk.nn.Linear(1, 10, key=jax.random.PRNGKey(0)) - rhs = sk.nn.Linear((1,), 10, key=jax.random.PRNGKey(0)) - npt.assert_allclose(lhs(x), rhs(x), atol=1e-4) - - with pytest.raises(TypeError): - sk.nn.Linear([1, 2], 10, key=jax.random.PRNGKey(0)) - - def test_general_linear(): x = jnp.ones([1, 2, 3, 4]) - layer = sk.nn.GeneralLinear( + layer = sk.nn.Linear( in_features=(1, 2), in_axes=(0, 1), out_features=5, key=jax.random.PRNGKey(0) ) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = sk.nn.GeneralLinear( + layer = sk.nn.Linear( in_features=(1, 2), in_axes=(0, 1), out_features=5, key=jax.random.PRNGKey(0) ) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = sk.nn.GeneralLinear( + layer = sk.nn.Linear( in_features=(1, 2), in_axes=(0, -3), out_features=5, key=jax.random.PRNGKey(0) ) assert layer(x).shape == (3, 4, 5) x = jnp.ones([1, 2, 3, 4]) - layer = sk.nn.GeneralLinear( + layer = sk.nn.Linear( in_features=(2, 3), in_axes=(1, -2), out_features=5, key=jax.random.PRNGKey(0) ) assert layer(x).shape == (1, 4, 5) - with pytest.raises(TypeError): - sk.nn.GeneralLinear( + with pytest.raises(ValueError): + sk.nn.Linear( in_features=2, in_axes=(1, -2), out_features=5, key=jax.random.PRNGKey(0) ) - with pytest.raises(TypeError): - sk.nn.GeneralLinear( + with pytest.raises(ValueError): + sk.nn.Linear( in_features=(2, 3), in_axes=2, out_features=5, key=jax.random.PRNGKey(0) ) with pytest.raises(ValueError): - sk.nn.GeneralLinear( + sk.nn.Linear( in_features=(1,), in_axes=(0, -3), out_features=5, key=jax.random.PRNGKey(0) ) with pytest.raises(TypeError): - sk.nn.GeneralLinear( + sk.nn.Linear( in_features=(1, "s"), in_axes=(0, -3), out_features=5, @@ -163,7 +84,7 @@ def test_general_linear(): ) with pytest.raises(TypeError): - sk.nn.GeneralLinear( + sk.nn.Linear( in_features=(1, 2), in_axes=(0, "s"), out_features=3, @@ -174,29 +95,6 @@ def test_general_linear(): def test_mlp(): x = jnp.linspace(0, 1, 100)[:, None] - fnn = sk.nn.FNN([1, 2, 1], key=jax.random.PRNGKey(0)) - mlp = sk.nn.MLP( - in_features=1, - out_features=1, - hidden_features=2, - num_hidden_layers=1, - key=jax.random.PRNGKey(0), - ) - - npt.assert_allclose(fnn(x), mlp(x), atol=1e-4) - - fnn = sk.nn.FNN([1, 2, 2, 1], act="tanh", key=jax.random.PRNGKey(0)) - mlp = sk.nn.MLP( - in_features=1, - out_features=1, - 
hidden_features=2, - num_hidden_layers=2, - act="tanh", - key=jax.random.PRNGKey(0), - ) - - npt.assert_allclose(fnn(x), mlp(x), atol=1e-4) - x = jax.random.normal(jax.random.PRNGKey(0), (10, 1)) w1 = jax.random.normal(jax.random.PRNGKey(1), (1, 10)) w2 = jax.random.normal(jax.random.PRNGKey(2), (10, 10)) @@ -208,23 +106,6 @@ def test_mlp(): y = jax.nn.tanh(y) y = y @ w3 - layer = sk.nn.FNN( - [1, 10, 10, 4], - act="tanh", - bias_init=None, - key=jax.random.PRNGKey(0), - ) - layer = ( - layer.at["linear_0"]["weight"] - .set(w1) - .at["linear_1"]["weight"] - .set(w2) - .at["linear_2"]["weight"] - .set(w3) - ) - - npt.assert_allclose(layer(x), y) - layer = sk.nn.MLP( 1, 4, @@ -235,22 +116,8 @@ def test_mlp(): key=jax.random.PRNGKey(0), ) - layer = layer.at["linear_i"]["weight"].set(w1) - layer = layer.at["linear_h"]["weight"].set(w2[None]) - layer = layer.at["linear_o"]["weight"].set(w3) + layer = layer.at["in_linear"]["weight"].set(w1) + layer = layer.at["hidden_linear"]["weight"].set(w2[None]) + layer = layer.at["out_linear"]["weight"].set(w3) npt.assert_allclose(layer(x), y) - - -def test_fnn_mlp(): - fnn = sk.nn.FNN(layers=[2, 4, 4, 2], act="relu", key=jax.random.PRNGKey(0)) - mlp = sk.nn.MLP( - 2, - 2, - hidden_features=4, - num_hidden_layers=2, - act="relu", - key=jax.random.PRNGKey(0), - ) - x = jax.random.normal(jax.random.PRNGKey(0), (10, 2)) - npt.assert_allclose(fnn(x), mlp(x))
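For downstream users, a hedged migration sketch based on the equivalences exercised by the removed comparison tests: ``FNN([d_in, h, ..., h, d_out])`` maps to ``MLP(d_in, d_out, hidden_features=h, num_hidden_layers=n)``, and ``GeneralLinear`` arguments carry over to the merged ``Linear`` unchanged.

import jax.numpy as jnp
import jax.random as jr
import serket as sk

key = jr.PRNGKey(0)

# Before: sk.nn.FNN([2, 4, 4, 2], act="relu", key=key)
# After: the equivalent MLP (same layer sizes, shared hidden width).
mlp = sk.nn.MLP(2, 2, hidden_features=4, num_hidden_layers=2, act="relu", key=key)
print(mlp(jnp.ones((10, 2))).shape)  # (10, 2)

# Before: sk.nn.GeneralLinear((1, 2), 5, in_axes=(0, 1), key=key)
# After: the same arguments on the merged Linear.
linear = sk.nn.Linear((1, 2), 5, in_axes=(0, 1), key=key)
print(linear(jnp.ones((1, 2, 3, 4))).shape)  # (3, 4, 5)

# MLP sub-layers are now accessed via the renamed leaves
# ``in_linear`` / ``hidden_linear`` / ``out_linear`` instead of
# ``linear_i`` / ``linear_h`` / ``linear_o``.
print(mlp.in_linear.in_features)  # (2,)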