diff --git a/docs/API/activations.rst b/docs/API/activations.rst index b490bff..1821506 100644 --- a/docs/API/activations.rst +++ b/docs/API/activations.rst @@ -2,10 +2,6 @@ Activations --------------------------------- .. currentmodule:: serket.nn -.. autoclass:: AdaptiveLeakyReLU -.. autoclass:: AdaptiveReLU -.. autoclass:: AdaptiveSigmoid -.. autoclass:: AdaptiveTanh .. autoclass:: CeLU .. autoclass:: ELU .. autoclass:: GELU @@ -28,7 +24,6 @@ Activations .. autoclass:: SoftSign .. autoclass:: SquarePlus .. autoclass:: Swish -.. autoclass:: Snake .. autoclass:: Tanh .. autoclass:: TanhShrink .. autoclass:: ThresholdedReLU \ No newline at end of file diff --git a/docs/notebooks/layers_overview.ipynb b/docs/notebooks/layers_overview.ipynb index f897059..5ac9d19 100644 --- a/docs/notebooks/layers_overview.ipynb +++ b/docs/notebooks/layers_overview.ipynb @@ -23,206 +23,6 @@ "## `serket` general design features" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Handling weight initalization\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Layers that contain `weight_init` or `bias_init` can accept:\n", - "\n", - "- A string: \n", - " - `he_normal`\n", - " - `he_uniform`\n", - " - `glorot_normal`\n", - " - `glorot_uniform`\n", - " - `lecun_normal`\n", - " - `lecun_uniform`\n", - " - `normal`\n", - " - `uniform`\n", - " - `ones`\n", - " - `zeros`\n", - " - `xavier_normal`\n", - " - `xavier_uniform`\n", - " - `orthogonal`\n", - "- A function with the following signature `key:jax.Array, shape:tuple[int,...], dtype`.\n", - "- `None` to indicate no initialization (e.g no bias for layers that have `bias_init` argument).\n", - "- A registered string by `sk.def_init_entry(\"my_init\", ....)` to map to custom init function." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n", - "[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n" - ] - } - ], - "source": [ - "import serket as sk\n", - "import jax\n", - "import math\n", - "import jax.random as jr\n", - "\n", - "# 1) linear layer with no bias\n", - "linear = sk.nn.Linear(1, 10, weight_init=\"he_normal\", bias_init=None, key=jr.PRNGKey(0))\n", - "\n", - "\n", - "# linear layer with custom initialization function\n", - "def init_func(key, shape, dtype=jax.numpy.float32):\n", - " return jax.numpy.arange(math.prod(shape), dtype=dtype).reshape(shape)\n", - "\n", - "\n", - "linear = sk.nn.Linear(1, 10, weight_init=init_func, bias_init=None, key=jr.PRNGKey(0))\n", - "print(linear.weight)\n", - "# [[0. 1. 2. 3. 4. 5. 6. 7. 8. 
9.]]\n", - "\n", - "# linear layer with custom initialization function registered under a key\n", - "sk.def_init_entry(\"my_init\", init_func)\n", - "linear = sk.nn.Linear(1, 10, weight_init=\"my_init\", bias_init=None, key=jr.PRNGKey(0))\n", - "print(linear.weight)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Handling activation functions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Layers that contain `act_func` accepts:\n", - "\n", - "- A string: \n", - " - `adaptive_leaky_relu`\n", - " - `adaptive_relu`\n", - " - `adaptive_sigmoid`\n", - " - `adaptive_tanh`\n", - " - `celu`\n", - " - `elu`\n", - " - `gelu`\n", - " - `glu`\n", - " - `hard_shrink`\n", - " - `hard_sigmoid`\n", - " - `hard_swish`\n", - " - `hard_tanh`\n", - " - `leaky_relu`\n", - " - `log_sigmoid`\n", - " - `log_softmax`\n", - " - `mish`\n", - " - `prelu`\n", - " - `relu`\n", - " - `relu6`\n", - " - `selu`\n", - " - `sigmoid`\n", - " - `snake`\n", - " - `softplus`\n", - " - `softshrink`\n", - " - `softsign`\n", - " - `squareplus`\n", - " - `swish`\n", - " - `tanh`\n", - " - `tanh_shrink`\n", - " - `thresholded_relu`\n", - "- A function of single input and output of `jax.Array`.\n", - "- A registered string by `sk.def_act_entry(\"my_act\", ....)` to map to custom activation class with a `__call__` method." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "import serket as sk\n", - "import jax\n", - "import jax.random as jr\n", - "\n", - "# 1) activation function with a string\n", - "linear = sk.nn.FNN([1, 1], act=\"relu\", key=jr.PRNGKey(0))\n", - "\n", - "# 2) activation function with a function\n", - "linear = sk.nn.FNN([1, 1], act=jax.nn.relu, key=jr.PRNGKey(0))\n", - "\n", - "\n", - "@sk.autoinit\n", - "class MyTrainableActivation(sk.TreeClass):\n", - " my_param: float = 10.0\n", - "\n", - " def __call__(self, x):\n", - " return x * self.my_param\n", - "\n", - "\n", - "# 3) activation function with a class\n", - "linear = sk.nn.FNN([1, 1], act=MyTrainableActivation(), key=jr.PRNGKey(0))\n", - "\n", - "# 4) activation function with a registered class\n", - "sk.def_act_entry(\"my_act\", MyTrainableActivation())\n", - "linear = sk.nn.FNN([1, 1], act=\"my_act\", key=jr.PRNGKey(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Handling dtype" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Layers that contain `dtype`, accept any valid `numpy.dtype` variant. this is useful if mixed precision policy is desired. 
For more, see the example on mixed precision training.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Linear(\n", - " in_features=(10), \n", - " out_features=5, \n", - " weight_init=glorot_uniform, \n", - " bias_init=zeros, \n", - " weight=f16[10,5](μ=0.07, σ=0.35, ∈[-0.63,0.60]), \n", - " bias=f16[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import serket as sk\n", - "import jax\n", - "import jax.random as jr\n", - "\n", - "linear = sk.nn.Linear(10, 5, dtype=jax.numpy.float16, key=jr.PRNGKey(0))\n", - "linear\n", - "# note the dtype is f16(float16) in the repr output" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -230,15 +30,6 @@ "### Lazy shape inference" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Lazy initialization is useful in scenarios where the dimensions of certain input features are not known in advance. For instance, when the number of neurons required for a flattened image input is uncertain, or the shape of the output from a flattened convolutional layer is not straightforward to calculate. In such cases, lazy initialization defers layers materialization until the first input.\n", - "\n", - "In `serket`, simply replace `in_features` with `None` to indicate that this layer is lazy. then materialzie the layer by functionally calling the layer. Recall that functional call - via `.at[method_name](*args, **kwargs)` _always_ returns a tuple of method output and a _new_ instance." - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/serket/__init__.py b/serket/__init__.py index 88aa921..dbbf05a 100644 --- a/serket/__init__.py +++ b/serket/__init__.py @@ -33,7 +33,6 @@ from serket._src.containers import RandomChoice, Sequential from serket._src.custom_transform import tree_eval, tree_state -from serket._src.nn.activation import def_act_entry from . import cluster, image, nn @@ -65,7 +64,6 @@ "image", "tree_eval", "tree_state", - "def_act_entry", # containers "Sequential", "RandomChoice", diff --git a/serket/_src/nn/activation.py b/serket/_src/nn/activation.py index 5dbf4f7..2138c4c 100644 --- a/serket/_src/nn/activation.py +++ b/serket/_src/nn/activation.py @@ -14,6 +14,8 @@ from __future__ import annotations +import inspect +from collections.abc import Callable as ABCCallable from typing import Callable, Literal, TypeVar, Union, get_args import jax @@ -26,109 +28,6 @@ T = TypeVar("T") -def adaptive_leaky_relu( - input: jax.typing.ArrayLike, - a: float = 1.0, - v: float = 1.0, -) -> jax.Array: - """Adaptive Leaky ReLU activation function - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - return jnp.maximum(0, a * input) - v * jnp.maximum(0, -a * input) - - -@sk.autoinit -class AdaptiveLeakyReLU(sk.TreeClass): - """Leaky ReLU activation function - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - - a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) - v: float = sk.field( - default=1.0, - on_setattr=[Range(0), ScalarLike()], - on_getattr=[lax.stop_gradient_p.bind], - ) - - def __call__(self, input: jax.Array) -> jax.Array: - return adaptive_leaky_relu(input, self.a, self.v) - - -def adaptive_relu(input: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: - """Adaptive ReLU activation function - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. 
- """ - return jnp.maximum(0, a * input) - - -@sk.autoinit -class AdaptiveReLU(sk.TreeClass): - """ReLU activation function with learnable parameters - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - - a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) - - def __call__(self, input: jax.Array) -> jax.Array: - return adaptive_relu(input, self.a) - - -def adaptive_sigmoid(input: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: - """Adaptive sigmoid activation function - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - return 1 / (1 + jnp.exp(-a * input)) - - -@sk.autoinit -class AdaptiveSigmoid(sk.TreeClass): - """Sigmoid activation function with learnable `a` parameter - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - - a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) - - def __call__(self, input: jax.Array) -> jax.Array: - return adaptive_sigmoid(input, self.a) - - -def adaptive_tanh(input: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: - """Adaptive tanh activation function - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - return (jnp.exp(a * input) - jnp.exp(-a * input)) / ( - jnp.exp(a * input) + jnp.exp(-a * input) - ) - - -@sk.autoinit -class AdaptiveTanh(sk.TreeClass): - """Tanh activation function with learnable parameters - - Reference: - https://arxiv.org/pdf/1906.01170.pdf. - """ - - a: float = sk.field(default=1.0, on_setattr=[Range(0), ScalarLike()]) - - def __call__(self, input: jax.Array) -> jax.Array: - return adaptive_tanh(input, self.a) - - @sk.autoinit class CeLU(sk.TreeClass): """Celu activation function""" @@ -176,11 +75,7 @@ def __call__(self, input: jax.Array) -> jax.Array: def hard_shrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: - """Hard shrink activation function - - Reference: - https://arxiv.org/pdf/1702.00783.pdf. - """ + """Hard shrink activation function""" return jnp.where(input > alpha, input, jnp.where(input < -alpha, input, 0.0)) @@ -295,11 +190,7 @@ def __call__(self, input: jax.Array) -> jax.Array: def softshrink(input: jax.typing.ArrayLike, alpha: float = 0.5) -> jax.Array: - """Soft shrink activation function - - Reference: - https://arxiv.org/pdf/1702.00783.pdf. - """ + """Soft shrink activation function""" return jnp.where( input < -alpha, input + alpha, @@ -322,11 +213,7 @@ def __call__(self, input: jax.Array) -> jax.Array: def squareplus(input: jax.typing.ArrayLike) -> jax.Array: - """SquarePlus activation function - - Reference: - https://arxiv.org/pdf/1908.08681.pdf. - """ + """SquarePlus activation function""" return 0.5 * (input + jnp.sqrt(input * input + 4)) @@ -413,45 +300,8 @@ def __call__(self, input: jax.Array) -> jax.Array: return prelu(input, self.a) -def snake(input: jax.typing.ArrayLike, a: float = 1.0) -> jax.Array: - """Snake activation function - - Args: - a: scalar (frequency) parameter of the activation function, default is 1.0. - - Reference: - https://arxiv.org/pdf/2006.08195.pdf. - """ - return input + (1 - jnp.cos(2 * a * input)) / (2 * a) - - -@sk.autoinit -class Snake(sk.TreeClass): - """Snake activation function - - Args: - a: scalar (frequency) parameter of the activation function, default is 1.0. - - Reference: - https://arxiv.org/pdf/2006.08195.pdf. 
- """ - - a: float = sk.field( - default=1.0, - on_setattr=[Range(0), ScalarLike()], - on_getattr=[lax.stop_gradient_p.bind], - ) - - def __call__(self, input: jax.Array) -> jax.Array: - return snake(input, self.a) - - # useful for building layers from configuration text ActivationLiteral = Literal[ - "adaptive_leaky_relu", - "adaptive_relu", - "adaptive_sigmoid", - "adaptive_tanh", "celu", "elu", "gelu", @@ -469,7 +319,6 @@ def __call__(self, input: jax.Array) -> jax.Array: "relu6", "selu", "sigmoid", - "snake", "softplus", "softshrink", "softsign", @@ -482,10 +331,6 @@ def __call__(self, input: jax.Array) -> jax.Array: acts = [ - adaptive_leaky_relu, - adaptive_relu, - adaptive_sigmoid, - adaptive_tanh, jax.nn.celu, jax.nn.elu, jax.nn.gelu, @@ -503,7 +348,6 @@ def __call__(self, input: jax.Array) -> jax.Array: jax.nn.relu6, jax.nn.selu, jax.nn.sigmoid, - snake, jax.nn.softplus, softshrink, softsign, @@ -521,8 +365,14 @@ def __call__(self, input: jax.Array) -> jax.Array: @single_dispatch(argnum=0) -def resolve_activation(act: T) -> T: - return act +def resolve_activation(act): + raise TypeError(f"Unknown activation type {type(act)}.") + + +@resolve_activation.def_type(ABCCallable) +def _(func: T) -> T: + assert len(inspect.getfullargspec(func).args) == 1 + return func @resolve_activation.def_type(str) @@ -531,40 +381,3 @@ def _(act: str): return jax.tree_map(lambda x: x, act_map[act]) except KeyError: raise ValueError(f"Unknown {act=}, available activations: {list(act_map)}") - - -def def_act_entry(key: str, act: ActivationFunctionType) -> None: - """Register a custom activation function key for use in ``serket`` layers. - - Args: - key: The key to register the function under. - act: a callable object that takes a single argument and returns a ``jax`` - array. - - Note: - The registered key can be used in any of ``serket`` ``act_*`` arguments as - substitution for the function. - - Note: - By design, activation functions can be passed directly to ``serket`` layers - with the ``act`` argument. This function is useful if you want to - represent activation functions as a string in a configuration file. - - Example: - >>> import serket as sk - >>> import jax.numpy as jnp - >>> import jax.random as jr - >>> @sk.autoinit - ... class MyTrainableActivation(sk.TreeClass): - ... my_param: float = 10.0 - ... def __call__(self, x): - ... 
return x * self.my_param - >>> sk.def_act_entry("my_act", MyTrainableActivation()) - """ - if key in act_map: - raise ValueError(f"`init_key` {key=} already registered") - - if not callable(act): - raise TypeError(f"{act=} must be a callable object") - - act_map[key] = act diff --git a/serket/_src/nn/attention.py b/serket/_src/nn/attention.py index 791529d..e0b1361 100644 --- a/serket/_src/nn/attention.py +++ b/serket/_src/nn/attention.py @@ -314,7 +314,7 @@ def __call__( # [..., v_length, v_features] -> [..., v_length, head_features*num_heads] v_heads = self.v_projection(v_input) - attention = type(self).attention_op( + attention = self.attention_op( q_heads=q_heads, k_heads=k_heads, v_heads=v_heads, @@ -331,4 +331,4 @@ def __call__( return self.out_projection(attention) - attention_op = dot_product_attention + attention_op = staticmethod(dot_product_attention) diff --git a/serket/_src/nn/convolution.py b/serket/_src/nn/convolution.py index 1f99e86..061ed0d 100644 --- a/serket/_src/nn/convolution.py +++ b/serket/_src/nn/convolution.py @@ -100,9 +100,7 @@ def grouped_matmul(x, y, groups) -> jax.Array: if lhs.shape[-1] % 2 != 0: lhs = jnp.pad(lhs, tuple([(0, 0)] * (lhs.ndim - 1) + [(0, 1)])) - kernel_pad = tuple( - (0, lhs.shape[i] - rhs.shape[i]) for i in range(2, spatial_ndim + 2) - ) + kernel_pad = ((0, lhs.shape[i] - rhs.shape[i]) for i in range(2, spatial_ndim + 2)) rhs = pad(rhs, ((0, 0), (0, 0), *kernel_pad)) x_fft = jnp.fft.rfftn(lhs, axes=range(2, spatial_ndim + 2)) diff --git a/serket/_src/nn/initialization.py b/serket/_src/nn/initialization.py index ccf4d0c..8411dac 100644 --- a/serket/_src/nn/initialization.py +++ b/serket/_src/nn/initialization.py @@ -78,7 +78,7 @@ def _(init: str): @resolve_init.def_type(type(None)) -def _(init: None): +def _(init): return jtu.Partial(lambda key, shape, dtype=None: None) diff --git a/serket/_src/nn/linear.py b/serket/_src/nn/linear.py index 24fa043..5b4fe9b 100644 --- a/serket/_src/nn/linear.py +++ b/serket/_src/nn/linear.py @@ -179,7 +179,9 @@ def __init__( raise ValueError(f"{len(in_axis)=} != {len(in_features)=}") # arrange the in_features by the in_axis - compare = lambda ik: in_axis[ik[0]] + def compare(ik): + return in_axis[ik[0]] + _, in_features = zip(*sorted(enumerate(in_features), key=compare)) self.in_features = in_features @@ -204,7 +206,7 @@ def __init__( @ft.partial(maybe_lazy_call, is_lazy=is_lazy_call, updates=updates) def __call__(self, input: jax.Array) -> jax.Array: """Apply a linear transformation to the input.""" - return linear( + return self.linear_op( input=input, weight=self.weight, bias=self.bias, @@ -212,6 +214,8 @@ def __call__(self, input: jax.Array) -> jax.Array: out_axis=self.out_axis, ) + linear_op = staticmethod(linear) + class Identity(sk.TreeClass): """Identity layer. 
Returns the input.""" diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 74b369f..294faee 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -16,10 +16,6 @@ ELU, GELU, GLU, - AdaptiveLeakyReLU, - AdaptiveReLU, - AdaptiveSigmoid, - AdaptiveTanh, CeLU, HardShrink, HardSigmoid, @@ -34,7 +30,6 @@ ReLU6, SeLU, Sigmoid, - Snake, SoftPlus, SoftShrink, SoftSign, @@ -180,10 +175,6 @@ "ELU", "GELU", "GLU", - "AdaptiveLeakyReLU", - "AdaptiveReLU", - "AdaptiveSigmoid", - "AdaptiveTanh", "CeLU", "HardShrink", "HardSigmoid", @@ -198,7 +189,6 @@ "ReLU6", "SeLU", "Sigmoid", - "Snake", "SoftPlus", "SoftShrink", "SoftSign", diff --git a/tests/test_activation.py b/tests/test_activation.py index fa49109..47ccd56 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -23,10 +23,6 @@ ELU, GELU, GLU, - AdaptiveLeakyReLU, - AdaptiveReLU, - AdaptiveSigmoid, - AdaptiveTanh, CeLU, HardShrink, HardSigmoid, @@ -41,7 +37,6 @@ ReLU6, SeLU, Sigmoid, - Snake, SoftPlus, SoftShrink, SoftSign, @@ -50,7 +45,6 @@ Tanh, TanhShrink, ThresholdedReLU, - def_act_entry, resolve_activation, ) @@ -63,60 +57,6 @@ def test_thresholded_relu(): npt.assert_allclose(actual, expected) -def test_AdaptiveReLU(): - npt.assert_allclose( - AdaptiveReLU(1.0)(jnp.array([1.0, 2.0, 3.0])), jnp.array([1.0, 2.0, 3.0]) - ) - npt.assert_allclose( - AdaptiveReLU(0.0)(jnp.array([1.0, 2.0, 3.0])), jnp.array([0.0, 0.0, 0.0]) - ) - npt.assert_allclose( - AdaptiveReLU(0.5)(jnp.array([1.0, 2.0, 3.0])), jnp.array([0.5, 1.0, 1.5]) - ) - - -def test_AdaptiveLeakyReLU(): - npt.assert_allclose( - AdaptiveLeakyReLU(0.0, 1.0)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([0, 0, 0]), - ) - npt.assert_allclose( - AdaptiveLeakyReLU(0.0, 0.5)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([0, 0, 0]), - ) - npt.assert_allclose( - AdaptiveLeakyReLU(1.0, 0.0)(jnp.array([1.0, 2.0, 3.0])), jnp.array([1, 2, 3]) - ) - npt.assert_allclose( - AdaptiveLeakyReLU(1.0, 0.5)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([1, 2, 3]), - ) - - -def test_AdaptiveSigmoid(): - npt.assert_allclose( - AdaptiveSigmoid(1.0)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([0.7310586, 0.880797, 0.95257413]), - ) - npt.assert_allclose( - AdaptiveSigmoid(0.0)(jnp.array([1.0, 2.0, 3.0])), jnp.array([0.5, 0.5, 0.5]) - ) - - -def test_AdaptiveTanh(): - npt.assert_allclose( - AdaptiveTanh(1.0)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([0.7615942, 0.9640276, 0.9950547]), - ) - npt.assert_allclose( - AdaptiveTanh(0.0)(jnp.array([1.0, 2.0, 3.0])), jnp.array([0, 0, 0]) - ) - npt.assert_allclose( - AdaptiveTanh(0.5)(jnp.array([1.0, 2.0, 3.0])), - jnp.array([0.46211714, 0.7615942, 0.9051482]), - ) - - def test_prelu(): x = jnp.array([-1.0, 0, 1]) expected = x @@ -285,13 +225,6 @@ def test_mish(): npt.assert_allclose(actual, expected, atol=1e-4) -def test_snake(): - x = jnp.array([-1.0, 0, 1]) - expected = jnp.array([-0.29192656, 0.0, 1.7080734]) - actual = Snake()(x) - npt.assert_allclose(actual, expected, atol=1e-4) - - def test_square_plus(): x = jnp.array([-1.0, 0, 1]) expected = 0.5 * (x + jnp.sqrt(x**2 + 4)) @@ -304,13 +237,6 @@ def test_resolving(): resolve_activation("nonexistent") -def test_def_act_entry(): - def_act_entry("id", lambda x: x) - - with pytest.raises(ValueError): - # duplicate entry - def_act_entry("id", lambda x: x) - - with pytest.raises(TypeError): - # non-callable - def_act_entry("one", 1) +def test_invalid_act_sig(): + with pytest.raises(AssertionError): + resolve_activation(lambda x, y: x)
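After this patch, an activation handed to ``serket`` layers must be either one of the remaining built-in name strings or a callable taking a single argument; the ``def_act_entry`` registry, the adaptive activations, and ``Snake`` are removed, and ``resolve_activation`` now rejects unknown strings and multi-argument callables. The snippet below is a minimal sketch of what downstream code looks like after the change. It is illustrative only and not part of the patch: the ``sk.nn.FNN(..., act=..., key=jr.PRNGKey(0))`` call is reused from the notebook cell deleted above, and the ``snake`` helper is a hypothetical one-argument stand-in for the deleted ``Snake`` layer with its ``a`` parameter fixed to 1.0.

    import jax.numpy as jnp
    import jax.random as jr
    import serket as sk

    # One-argument replacement for the removed snake activation:
    # snake(x) = x + (1 - cos(2*a*x)) / (2*a), with a fixed to 1.0.
    def snake(x):
        return x + (1.0 - jnp.cos(2.0 * x)) / 2.0

    # A plain single-argument callable passes the new Callable dispatch.
    net = sk.nn.FNN([1, 1], act=snake, key=jr.PRNGKey(0))

    # Built-in names that survive the patch still resolve from act_map.
    net = sk.nn.FNN([1, 1], act="relu", key=jr.PRNGKey(0))

    # Removed names now raise ValueError, and callables that do not take
    # exactly one positional argument fail the new assertion:
    #   resolve_activation("snake")          -> ValueError
    #   resolve_activation(lambda x, y: x)   -> AssertionError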