From 6770dbd8fd48bfcad541b401be8b3192fbcc692a Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Mon, 14 Aug 2023 22:31:24 +0300 Subject: [PATCH] `act_func` -> `act` --- CHANEGLOG.md | 1 + docs/notebooks/layers_overview.ipynb | 8 +- serket/nn/activation.py | 30 +++--- serket/nn/linear.py | 52 +++++----- serket/nn/recurrent.py | 136 +++++++++++++-------------- tests/test_fully_connected.py | 12 +-- tests/test_linear.py | 2 +- tests/test_rnn.py | 6 +- 8 files changed, 124 insertions(+), 123 deletions(-) diff --git a/CHANEGLOG.md b/CHANEGLOG.md index e1e8034..f808983 100644 --- a/CHANEGLOG.md +++ b/CHANEGLOG.md @@ -12,6 +12,7 @@ - `***_init_func` -> `***_init` shorter and more concise - `gamma_init_func` -> `weight_init` - `beta_init_func` -> `bias_init` + - `act_func` -> `act` - `MLP` produces smaller `jaxprs` and are faster to compile. for my use case -higher order differentiation through `PINN`- the new `MLP` is faster to compile. - `kernel_dilation` -> `dilation` - `input_dilation` -> Removed. diff --git a/docs/notebooks/layers_overview.ipynb b/docs/notebooks/layers_overview.ipynb index b028e67..61947cd 100644 --- a/docs/notebooks/layers_overview.ipynb +++ b/docs/notebooks/layers_overview.ipynb @@ -134,10 +134,10 @@ "import jax\n", "\n", "# 1) activation function with a string\n", - "linear = sk.nn.FNN([1, 1], act_func=\"relu\")\n", + "linear = sk.nn.FNN([1, 1], act=\"relu\")\n", "\n", "# 2) activation function with a function\n", - "linear = sk.nn.FNN([1, 1], act_func=jax.nn.relu)\n", + "linear = sk.nn.FNN([1, 1], act=jax.nn.relu)\n", "\n", "\n", "@sk.autoinit\n", @@ -149,11 +149,11 @@ "\n", "\n", "# 3) activation function with a class\n", - "linear = sk.nn.FNN([1, 1], act_func=MyTrainableActivation())\n", + "linear = sk.nn.FNN([1, 1], act=MyTrainableActivation())\n", "\n", "# 4) activation function with a registered class\n", "sk.def_act_entry(\"my_act\", MyTrainableActivation)\n", - "linear = sk.nn.FNN([1, 1], act_func=\"my_act\")" + "linear = sk.nn.FNN([1, 1], act=\"my_act\")" ] }, { diff --git a/serket/nn/activation.py b/serket/nn/activation.py index f28928b..bc5d6a0 100644 --- a/serket/nn/activation.py +++ b/serket/nn/activation.py @@ -388,22 +388,22 @@ def __call__(self, x: jax.typing.ArrayLike) -> jax.Array: ... -def resolve_activation(act_func: ActivationType) -> ActivationFunctionType: +def resolve_activation(act: ActivationType) -> ActivationFunctionType: # in case the user passes a trainable activation function # we need to make a copy of it to avoid unpredictable side effects - if isinstance(act_func, str): - if act_func in act_map: - return act_map[act_func]() - raise ValueError(f"Unknown {act_func=}, available activations: {list(act_map)}") - return act_func + if isinstance(act, str): + if act in act_map: + return act_map[act]() + raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}") + return act -def def_act_entry(key: str, act_func: ActivationClassType) -> None: +def def_act_entry(key: str, act: ActivationClassType) -> None: """Register a custom activation function key for use in ``serket`` layers. Args: key: The key to register the function under. - act_func: a class with a ``__call__`` method that takes a single argument + act: a class with a ``__call__`` method that takes a single argument and returns a ``jax`` array. Note: @@ -412,7 +412,7 @@ def def_act_entry(key: str, act_func: ActivationClassType) -> None: Note: By design, activation functions can be passed directly to ``serket`` layers - with the ``act_func`` argument. 
This function is useful if you want to + with the ``act`` argument. This function is useful if you want to represent activation functions as a string in a configuration file. Example: @@ -425,15 +425,15 @@ def def_act_entry(key: str, act_func: ActivationClassType) -> None: ... return x * self.my_param >>> sk.def_act_entry("my_act", MyTrainableActivation) >>> x = jnp.ones((1, 1)) - >>> sk.nn.FNN([1, 1, 1], act_func="my_act", weight_init="ones", bias_init=None)(x) + >>> sk.nn.FNN([1, 1, 1], act="my_act", weight_init="ones", bias_init=None)(x) Array([[10.]], dtype=float32) """ if key in act_map: raise ValueError(f"`init_key` {key=} already registered") - if not isinstance(act_func, type): - raise ValueError(f"Expected a class, got {act_func=}") - if not callable(act_func): - raise ValueError(f"Expected a class with a `__call__` method, got {act_func=}") + if not isinstance(act, type): + raise ValueError(f"Expected a class, got {act=}") + if not callable(act): + raise ValueError(f"Expected a class with a `__call__` method, got {act=}") - act_map[key] = act_func + act_map[key] = act diff --git a/serket/nn/linear.py b/serket/nn/linear.py index bdc48e1..7080139 100644 --- a/serket/nn/linear.py +++ b/serket/nn/linear.py @@ -400,7 +400,7 @@ class FNN(sk.TreeClass): Args: layers: Sequence of layer sizes - act_func: a single Activation function to be applied between layers or + act: a single Activation function to be applied between layers or ``len(layers)-2`` Sequence of activation functions applied between layers. weight_init: Weight initializer function. @@ -445,7 +445,7 @@ def __init__( self, layers: Sequence[int], *, - act_func: ActivationType | tuple[ActivationType, ...] = "tanh", + act: ActivationType | tuple[ActivationType, ...] = "tanh", weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), @@ -454,13 +454,13 @@ def __init__( keys = jr.split(key, len(layers) - 1) num_hidden_layers = len(layers) - 2 - if isinstance(act_func, tuple): - if len(act_func) != (num_hidden_layers): - raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}") + if isinstance(act, tuple): + if len(act) != (num_hidden_layers): + raise ValueError(f"{len(act)=} != {(num_hidden_layers)=}") - self.act_func = tuple(resolve_activation(act) for act in act_func) + self.act = tuple(resolve_activation(act) for act in act) else: - self.act_func = resolve_activation(act_func) + self.act = resolve_activation(act) self.layers = tuple( Linear( @@ -477,12 +477,12 @@ def __init__( def __call__(self, x: jax.Array, **k) -> jax.Array: *layers, last = self.layers - if isinstance(self.act_func, tuple): - for ai, li in zip(self.act_func, layers): + if isinstance(self.act, tuple): + for ai, li in zip(self.act, layers): x = ai(li(x)) else: for li in layers: - x = self.act_func(li(x)) + x = self.act(li(x)) return last(x) @@ -490,19 +490,19 @@ def __call__(self, x: jax.Array, **k) -> jax.Array: def _scan_batched_layer_with_single_activation( x: Batched[jax.Array], layer: Batched[Linear], - act_func: ActivationFunctionType, + act: ActivationFunctionType, ) -> jax.Array: if layer.bias is None: def scan_func(x: jax.Array, bias: Batched[jax.Array]): - return act_func(x + bias), None + return act(x + bias), None x, _ = jax.lax.scan(scan_func, x, layer.weight) return x def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]): weight, bias = weight_bias[..., :-1], weight_bias[..., -1] - return act_func(x @ weight + bias), None + return act(x @ weight + bias), None weight_bias = 
jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1) x, _ = jax.lax.scan(scan_func, x, weight_bias) @@ -512,13 +512,13 @@ def scan_func(x: jax.Array, weight_bias: Batched[jax.Array]): def _scan_batched_layer_with_multiple_activations( x: Batched[jax.Array], layer: Batched[Linear], - act_func: Sequence[ActivationFunctionType], + act: Sequence[ActivationFunctionType], ) -> jax.Array: if layer.bias is None: def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]): x, index = x_index - x = jax.lax.switch(index, act_func, x @ weight) + x = jax.lax.switch(index, act, x @ weight) return (x, index + 1), None (x, _), _ = jax.lax.scan(scan_func, (x, 0), layer.weight) @@ -527,7 +527,7 @@ def scan_func(x_index: tuple[jax.Array, int], weight: Batched[jax.Array]): def scan_func(x_index: jax.Array, weight_bias: Batched[jax.Array]): x, index = x_index weight, bias = weight_bias[..., :-1], weight_bias[..., -1] - x = jax.lax.switch(index, act_func, x @ weight + bias) + x = jax.lax.switch(index, act, x @ weight + bias) return [x, index + 1], None weight_bias = jnp.concatenate([layer.weight, layer.bias[:, :, None]], axis=-1) @@ -543,7 +543,7 @@ class MLP(sk.TreeClass): out_features: Number of output features. hidden_size: Number of hidden units in each hidden layer. num_hidden_layers: Number of hidden layers including the output layer. - act_func: Activation function. + act: Activation function. weight_init: Weight initialization function. bias_init: Bias initialization function. key: Random number generator key. @@ -610,7 +610,7 @@ def __init__( *, hidden_size: int, num_hidden_layers: int, - act_func: ActivationType | tuple[ActivationType, ...] = "tanh", + act: ActivationType | tuple[ActivationType, ...] = "tanh", weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", key: jr.KeyArray = jr.PRNGKey(0), @@ -621,12 +621,12 @@ def __init__( keys = jr.split(key, num_hidden_layers + 1) - if isinstance(act_func, tuple): - if len(act_func) != (num_hidden_layers): - raise ValueError(f"{len(act_func)=} != {(num_hidden_layers)=}") - self.act_func = tuple(resolve_activation(act) for act in act_func) + if isinstance(act, tuple): + if len(act) != (num_hidden_layers): + raise ValueError(f"{len(act)=} != {(num_hidden_layers)=}") + self.act = tuple(resolve_activation(act) for act in act) else: - self.act_func = resolve_activation(act_func) + self.act = resolve_activation(act) kwargs = dict(weight_init=weight_init, bias_init=bias_init, dtype=dtype) @@ -643,13 +643,13 @@ def batched_linear(key: jr.KeyArray) -> Batched[Linear]: def __call__(self, x: jax.Array, **k) -> jax.Array: l0, lm, lh = self.layers - if isinstance(self.act_func, tuple): - a0, *ah = self.act_func + if isinstance(self.act, tuple): + a0, *ah = self.act x = a0(l0(x)) x = _scan_batched_layer_with_multiple_activations(x, lm, ah) return lh(x) - a0 = self.act_func + a0 = self.act x = a0(l0(x)) x = _scan_batched_layer_with_single_activation(x, lm, a0) return lh(x) diff --git a/serket/nn/recurrent.py b/serket/nn/recurrent.py index 27cc42b..31d8fae 100644 --- a/serket/nn/recurrent.py +++ b/serket/nn/recurrent.py @@ -100,7 +100,7 @@ def __init__( weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", - act_func: ActivationType = jax.nn.tanh, + act: ActivationType = jax.nn.tanh, key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): @@ -112,7 +112,7 @@ def __init__( weight_init: the function to use to initialize the weights bias_init: the function 
to use to initialize the bias recurrent_weight_init: the function to use to initialize the recurrent weights - act_func: the activation function to use for the hidden state update + act: the activation function to use for the hidden state update key: the key to use to initialize the weights dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -150,7 +150,7 @@ def __init__( self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) + self.act = resolve_activation(act) i2h = sk.nn.Linear( in_features, @@ -182,7 +182,7 @@ def __call__(self, x: jax.Array, state: SimpleRNNState, **k) -> SimpleRNNState: ih = jnp.concatenate([x, state.hidden_state], axis=-1) h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias - return SimpleRNNState(self.act_func(h)) + return SimpleRNNState(self.act(h)) @property def spatial_ndim(self) -> int: @@ -201,7 +201,7 @@ class DenseCell(RNNCell): hidden_features: the number of hidden features weight_init: the function to use to initialize the weights bias_init: the function to use to initialize the bias - act_func: the activation function to use for the hidden state update, + act: the activation function to use for the hidden state update, use `None` for no activation key: the key to use to initialize the weights dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -244,13 +244,13 @@ def __init__( *, weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", - act_func: ActivationType = jax.nn.tanh, + act: ActivationType = jax.nn.tanh, key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) + self.act = resolve_activation(act) self.in_to_hidden = sk.nn.Linear( in_features, @@ -268,7 +268,7 @@ def __call__(self, x: jax.Array, state: DenseState, **k) -> DenseState: if not isinstance(state, DenseState): raise TypeError(f"Expected {state=} to be an instance of `DenseState`") - h = self.act_func(self.in_to_hidden(x)) + h = self.act(self.in_to_hidden(x)) return DenseState(h) @property @@ -290,8 +290,8 @@ class LSTMCell(RNNCell): weight_init: the function to use to initialize the weights bias_init: the function to use to initialize the bias recurrent_weight_init: the function to use to initialize the recurrent weights - act_func: the activation function to use for the hidden state update - recurrent_act_func: the activation function to use for the cell state update + act: the activation function to use for the hidden state update + recurrent_act: the activation function to use for the cell state update key: the key to use to initialize the weights dtype: dtype of the weights and biases. defaults to ``jnp.float32``. 
@@ -338,8 +338,8 @@ def __init__( weight_init: str | Callable = "glorot_uniform", bias_init: str | Callable | None = "zeros", recurrent_weight_init: str | Callable = "orthogonal", - act_func: str | Callable[[Any], Any] | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", + act: str | Callable[[Any], Any] | None = "tanh", + recurrent_act: ActivationType | None = "sigmoid", key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): @@ -347,8 +347,8 @@ def __init__( self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) - self.recurrent_act_func = resolve_activation(recurrent_act_func) + self.act = resolve_activation(act) + self.recurrent_act = resolve_activation(recurrent_act) i2h = sk.nn.Linear( in_features, @@ -382,12 +382,12 @@ def __call__(self, x: jax.Array, state: LSTMState, **k) -> LSTMState: ih = jnp.concatenate([x, h], axis=-1) h = ih @ self.in_hidden_to_hidden_weight + self.in_hidden_to_hidden_bias i, f, g, o = jnp.split(h, 4, axis=-1) - i = self.recurrent_act_func(i) - f = self.recurrent_act_func(f) - g = self.act_func(g) - o = self.recurrent_act_func(o) + i = self.recurrent_act(i) + f = self.recurrent_act(f) + g = self.act(g) + o = self.recurrent_act(o) c = f * c + i * g - h = o * self.act_func(c) + h = o * self.act(c) return LSTMState(h, c) @property @@ -408,8 +408,8 @@ class GRUCell(RNNCell): weight_init: the function to use to initialize the weights bias_init: the function to use to initialize the bias recurrent_weight_init: the function to use to initialize the recurrent weights - act_func: the activation function to use for the hidden state update - recurrent_act_func: the activation function to use for the cell state update + act: the activation function to use for the hidden state update + recurrent_act: the activation function to use for the cell state update key: the key to use to initialize the weights dtype: dtype of the weights and biases. defaults to ``jnp.float32``. 
@@ -455,8 +455,8 @@ def __init__( weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", + act: ActivationType | None = "tanh", + recurrent_act: ActivationType | None = "sigmoid", key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): @@ -464,8 +464,8 @@ def __init__( self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) - self.recurrent_act_func = resolve_activation(recurrent_act_func) + self.act = resolve_activation(act) + self.recurrent_act = resolve_activation(recurrent_act) self.in_to_hidden = sk.nn.Linear( in_features, @@ -495,9 +495,9 @@ def __call__(self, x: jax.Array, state: GRUState, **k) -> GRUState: h = state.hidden_state xe, xu, xo = jnp.split(self.in_to_hidden(x), 3, axis=-1) he, hu, ho = jnp.split(self.hidden_to_hidden(h), 3, axis=-1) - e = self.recurrent_act_func(xe + he) - u = self.recurrent_act_func(xu + hu) - o = self.act_func(xo + (e * ho)) + e = self.recurrent_act(xe + he) + u = self.recurrent_act(xu + hu) + o = self.act(xo + (e * ho)) h = (1 - u) * o + u * h return GRUState(hidden_state=h) @@ -525,8 +525,8 @@ def __init__( weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "hard_sigmoid", + act: ActivationType | None = "tanh", + recurrent_act: ActivationType | None = "hard_sigmoid", key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): @@ -534,8 +534,8 @@ def __init__( self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) - self.recurrent_act_func = resolve_activation(recurrent_act_func) + self.act = resolve_activation(act) + self.recurrent_act = resolve_activation(recurrent_act) self.in_to_hidden = self.convolution_layer( in_features, @@ -573,12 +573,12 @@ def __call__(self, x: jax.Array, state: ConvLSTMNDState, **k) -> ConvLSTMNDState h, c = state.hidden_state, state.cell_state h = self.in_to_hidden(x) + self.hidden_to_hidden(h) i, f, g, o = jnp.split(h, 4, axis=0) - i = self.recurrent_act_func(i) - f = self.recurrent_act_func(f) - g = self.act_func(g) - o = self.recurrent_act_func(o) + i = self.recurrent_act(i) + f = self.recurrent_act(f) + g = self.act(g) + o = self.recurrent_act(o) c = f * c + i * g - h = o * self.act_func(c) + h = o * self.act(c) return ConvLSTMNDState(h, c) @property @@ -600,8 +600,8 @@ class ConvLSTM1DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. 
@@ -658,8 +658,8 @@ class FFTConvLSTM1DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -716,8 +716,8 @@ class ConvLSTM2DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -774,8 +774,8 @@ class FFTConvLSTM2DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -832,8 +832,8 @@ class ConvLSTM3DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -890,8 +890,8 @@ class FFTConvLSTM3DCell(ConvLSTMNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. 
@@ -953,8 +953,8 @@ def __init__( weight_init: InitType = "glorot_uniform", bias_init: InitType = "zeros", recurrent_weight_init: InitType = "orthogonal", - act_func: ActivationType | None = "tanh", - recurrent_act_func: ActivationType | None = "sigmoid", + act: ActivationType | None = "tanh", + recurrent_act: ActivationType | None = "sigmoid", key: jr.KeyArray = jr.PRNGKey(0), dtype: DType = jnp.float32, ): @@ -962,8 +962,8 @@ def __init__( self.in_features = positive_int_cb(in_features) self.hidden_features = positive_int_cb(hidden_features) - self.act_func = resolve_activation(act_func) - self.recurrent_act_func = resolve_activation(recurrent_act_func) + self.act = resolve_activation(act) + self.recurrent_act = resolve_activation(recurrent_act) self.in_to_hidden = self.convolution_layer( in_features, @@ -1001,9 +1001,9 @@ def __call__(self, x: jax.Array, state: ConvGRUNDState, **k) -> ConvGRUNDState: h = state.hidden_state xe, xu, xo = jnp.split(self.in_to_hidden(x), 3, axis=0) he, hu, ho = jnp.split(self.hidden_to_hidden(h), 3, axis=0) - e = self.recurrent_act_func(xe + he) - u = self.recurrent_act_func(xu + hu) - o = self.act_func(xo + (e * ho)) + e = self.recurrent_act(xe + he) + u = self.recurrent_act(xu + hu) + o = self.act(xo + (e * ho)) h = (1 - u) * o + u * h return ConvGRUNDState(hidden_state=h) @@ -1026,8 +1026,8 @@ class ConvGRU1DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -1081,8 +1081,8 @@ class FFTConvGRU1DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -1136,8 +1136,8 @@ class ConvGRU2DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -1191,8 +1191,8 @@ class FFTConvGRU2DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. @@ -1246,8 +1246,8 @@ class ConvGRU3DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. 
defaults to ``jnp.float32``. @@ -1301,8 +1301,8 @@ class FFTConvGRU3DCell(ConvGRUNDCell): weight_init: Weight initialization function bias_init: Bias initialization function recurrent_weight_init: Recurrent weight initialization function - act_func: Activation function - recurrent_act_func: Recurrent activation function + act: Activation function + recurrent_act: Recurrent activation function key: PRNG key dtype: dtype of the weights and biases. defaults to ``jnp.float32``. diff --git a/tests/test_fully_connected.py b/tests/test_fully_connected.py index a6d422c..c5b65ee 100644 --- a/tests/test_fully_connected.py +++ b/tests/test_fully_connected.py @@ -20,8 +20,8 @@ def test_fnn(): - layer = FNN([1, 2, 3, 4], act_func=("relu", "tanh")) - assert layer.act_func[0] is not layer.act_func[1] + layer = FNN([1, 2, 3, 4], act=("relu", "tanh")) + assert layer.act[0] is not layer.act[1] assert layer.layers[0] is not layer.layers[1] x = jax.random.normal(jax.random.PRNGKey(0), (10, 1)) @@ -35,7 +35,7 @@ def test_fnn(): y = jax.nn.relu(y) y = y @ w3 - l1 = FNN([1, 5, 3, 4], act_func=("tanh", "relu"), bias_init=None) + l1 = FNN([1, 5, 3, 4], act=("tanh", "relu"), bias_init=None) l1 = l1.at["layers"].at[0].at["weight"].set(w1) l1 = l1.at["layers"].at[1].at["weight"].set(w2) l1 = l1.at["layers"].at[2].at["weight"].set(w3) @@ -49,7 +49,7 @@ def test_mlp(): 4, hidden_size=10, num_hidden_layers=2, - act_func=("relu", "tanh"), + act=("relu", "tanh"), bias_init=None, ) @@ -75,7 +75,7 @@ def test_mlp(): def test_fnn_mlp(): - fnn = FNN(layers=[2, 4, 4, 2], act_func="relu") - mlp = MLP(2, 2, hidden_size=4, num_hidden_layers=2, act_func="relu") + fnn = FNN(layers=[2, 4, 4, 2], act="relu") + mlp = MLP(2, 2, hidden_size=4, num_hidden_layers=2, act="relu") x = jax.random.normal(jax.random.PRNGKey(0), (10, 2)) npt.assert_allclose(fnn(x), mlp(x)) diff --git a/tests/test_linear.py b/tests/test_linear.py index f08d749..c36894a 100644 --- a/tests/test_linear.py +++ b/tests/test_linear.py @@ -47,7 +47,7 @@ def update(NN, x, y): nn = FNN( [1, 128, 128, 1], - act_func="relu", + act="relu", weight_init="he_normal", bias_init="ones", ) diff --git a/tests/test_rnn.py b/tests/test_rnn.py index 313341d..d9b8d9b 100644 --- a/tests/test_rnn.py +++ b/tests/test_rnn.py @@ -575,7 +575,7 @@ def test_conv_lstm1d(layer): cell = layer( in_features=in_features, hidden_features=hidden_features, - recurrent_act_func="sigmoid", + recurrent_act="sigmoid", kernel_size=3, padding="same", weight_init="glorot_uniform", @@ -604,7 +604,7 @@ def test_conv_lstm1d(layer): cell = layer( in_features=in_features, hidden_features=hidden_features, - recurrent_act_func="sigmoid", + recurrent_act="sigmoid", kernel_size=3, padding="same", weight_init="glorot_uniform", @@ -788,7 +788,7 @@ def test_dense_cell(): cell = DenseCell( in_features=10, hidden_features=10, - act_func=lambda x: x, + act=lambda x: x, weight_init="ones", bias_init=None, )
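
A minimal usage sketch of the renamed keyword, mirroring the notebook and docstring snippets above. It assumes ``serket`` and ``jax`` are importable and that the layer signatures and ``sk.nn`` exposure are otherwise as shown in this patch; the variable names below are illustrative only.

    import jax
    import jax.numpy as jnp
    import serket as sk

    # string or callable activations now go through ``act`` (was ``act_func``)
    fnn = sk.nn.FNN([1, 2, 1], act="relu")
    mlp = sk.nn.MLP(1, 1, hidden_size=4, num_hidden_layers=2, act=jax.nn.tanh)

    # per-layer activations: a tuple of length ``len(layers) - 2``
    fnn2 = sk.nn.FNN([1, 4, 4, 1], act=("relu", "tanh"))

    # registered-class activation, as in the notebook example
    @sk.autoinit
    class MyTrainableActivation(sk.TreeClass):
        my_param: float = 10.0

        def __call__(self, x):
            return x * self.my_param

    sk.def_act_entry("my_act", MyTrainableActivation)
    fnn3 = sk.nn.FNN([1, 1], act="my_act")

    # recurrent cells: ``act`` / ``recurrent_act`` replace
    # ``act_func`` / ``recurrent_act_func``
    cell = sk.nn.LSTMCell(in_features=3, hidden_features=5,
                          act="tanh", recurrent_act="sigmoid")

    x = jnp.ones((1, 1))
    y = fnn(x)  # shape (1, 1)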