remove def_act_entry, non-essential act entries
ASEM000 committed Apr 1, 2024
1 parent 749a0ee commit eaec12b
Showing 10 changed files with 26 additions and 511 deletions.
5 changes: 0 additions & 5 deletions docs/API/activations.rst
@@ -2,10 +2,6 @@ Activations
---------------------------------
.. currentmodule:: serket.nn

.. autoclass:: AdaptiveLeakyReLU
.. autoclass:: AdaptiveReLU
.. autoclass:: AdaptiveSigmoid
.. autoclass:: AdaptiveTanh
.. autoclass:: CeLU
.. autoclass:: ELU
.. autoclass:: GELU
@@ -28,7 +24,6 @@ Activations
.. autoclass:: SoftSign
.. autoclass:: SquarePlus
.. autoclass:: Swish
.. autoclass:: Snake
.. autoclass:: Tanh
.. autoclass:: TanhShrink
.. autoclass:: ThresholdedReLU
209 changes: 0 additions & 209 deletions docs/notebooks/layers_overview.ipynb
@@ -23,222 +23,13 @@
"## `serket` general design features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handling weight initalization\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layers that contain `weight_init` or `bias_init` can accept:\n",
"\n",
"- A string: \n",
" - `he_normal`\n",
" - `he_uniform`\n",
" - `glorot_normal`\n",
" - `glorot_uniform`\n",
" - `lecun_normal`\n",
" - `lecun_uniform`\n",
" - `normal`\n",
" - `uniform`\n",
" - `ones`\n",
" - `zeros`\n",
" - `xavier_normal`\n",
" - `xavier_uniform`\n",
" - `orthogonal`\n",
"- A function with the following signature `key:jax.Array, shape:tuple[int,...], dtype`.\n",
"- `None` to indicate no initialization (e.g no bias for layers that have `bias_init` argument).\n",
"- A registered string by `sk.def_init_entry(\"my_init\", ....)` to map to custom init function."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n",
"[[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n"
]
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"import math\n",
"import jax.random as jr\n",
"\n",
"# 1) linear layer with no bias\n",
"linear = sk.nn.Linear(1, 10, weight_init=\"he_normal\", bias_init=None, key=jr.PRNGKey(0))\n",
"\n",
"\n",
"# linear layer with custom initialization function\n",
"def init_func(key, shape, dtype=jax.numpy.float32):\n",
" return jax.numpy.arange(math.prod(shape), dtype=dtype).reshape(shape)\n",
"\n",
"\n",
"linear = sk.nn.Linear(1, 10, weight_init=init_func, bias_init=None, key=jr.PRNGKey(0))\n",
"print(linear.weight)\n",
"# [[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]]\n",
"\n",
"# linear layer with custom initialization function registered under a key\n",
"sk.def_init_entry(\"my_init\", init_func)\n",
"linear = sk.nn.Linear(1, 10, weight_init=\"my_init\", bias_init=None, key=jr.PRNGKey(0))\n",
"print(linear.weight)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handling activation functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layers that contain `act_func` accepts:\n",
"\n",
"- A string: \n",
" - `adaptive_leaky_relu`\n",
" - `adaptive_relu`\n",
" - `adaptive_sigmoid`\n",
" - `adaptive_tanh`\n",
" - `celu`\n",
" - `elu`\n",
" - `gelu`\n",
" - `glu`\n",
" - `hard_shrink`\n",
" - `hard_sigmoid`\n",
" - `hard_swish`\n",
" - `hard_tanh`\n",
" - `leaky_relu`\n",
" - `log_sigmoid`\n",
" - `log_softmax`\n",
" - `mish`\n",
" - `prelu`\n",
" - `relu`\n",
" - `relu6`\n",
" - `selu`\n",
" - `sigmoid`\n",
" - `snake`\n",
" - `softplus`\n",
" - `softshrink`\n",
" - `softsign`\n",
" - `squareplus`\n",
" - `swish`\n",
" - `tanh`\n",
" - `tanh_shrink`\n",
" - `thresholded_relu`\n",
"- A function of single input and output of `jax.Array`.\n",
"- A registered string by `sk.def_act_entry(\"my_act\", ....)` to map to custom activation class with a `__call__` method."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import serket as sk\n",
"import jax\n",
"import jax.random as jr\n",
"\n",
"# 1) activation function with a string\n",
"linear = sk.nn.FNN([1, 1], act=\"relu\", key=jr.PRNGKey(0))\n",
"\n",
"# 2) activation function with a function\n",
"linear = sk.nn.FNN([1, 1], act=jax.nn.relu, key=jr.PRNGKey(0))\n",
"\n",
"\n",
"@sk.autoinit\n",
"class MyTrainableActivation(sk.TreeClass):\n",
" my_param: float = 10.0\n",
"\n",
" def __call__(self, x):\n",
" return x * self.my_param\n",
"\n",
"\n",
"# 3) activation function with a class\n",
"linear = sk.nn.FNN([1, 1], act=MyTrainableActivation(), key=jr.PRNGKey(0))\n",
"\n",
"# 4) activation function with a registered class\n",
"sk.def_act_entry(\"my_act\", MyTrainableActivation())\n",
"linear = sk.nn.FNN([1, 1], act=\"my_act\", key=jr.PRNGKey(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handling dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layers that contain `dtype`, accept any valid `numpy.dtype` variant. this is useful if mixed precision policy is desired. For more, see the example on mixed precision training.\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Linear(\n",
" in_features=(10), \n",
" out_features=5, \n",
" weight_init=glorot_uniform, \n",
" bias_init=zeros, \n",
" weight=f16[10,5](μ=0.07, σ=0.35, ∈[-0.63,0.60]), \n",
" bias=f16[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"import jax.random as jr\n",
"\n",
"linear = sk.nn.Linear(10, 5, dtype=jax.numpy.float16, key=jr.PRNGKey(0))\n",
"linear\n",
"# note the dtype is f16(float16) in the repr output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Lazy shape inference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lazy initialization is useful in scenarios where the dimensions of certain input features are not known in advance. For instance, when the number of neurons required for a flattened image input is uncertain, or the shape of the output from a flattened convolutional layer is not straightforward to calculate. In such cases, lazy initialization defers layers materialization until the first input.\n",
"\n",
"In `serket`, simply replace `in_features` with `None` to indicate that this layer is lazy. then materialzie the layer by functionally calling the layer. Recall that functional call - via `.at[method_name](*args, **kwargs)` _always_ returns a tuple of method output and a _new_ instance."
]
},
{
"cell_type": "markdown",
"metadata": {},
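The notebook cell above describes lazy materialization but its code is truncated in this diff view. A minimal sketch of the pattern, assuming the `.at[method_name]` functional-call API described in the cell; the input shape and the use of `"__call__"` are illustrative:

```python
import serket as sk
import jax.numpy as jnp
import jax.random as jr

# lazy layer: in_features=None defers the weight shape
lazy = sk.nn.Linear(None, 10, key=jr.PRNGKey(0))

x = jnp.ones((5,))

# functional call: returns (output, new materialized instance)
# rather than mutating `lazy` in place
output, materialized = lazy.at["__call__"](x)
```

Note that `lazy` itself stays unmaterialized; subsequent calls go through the returned `materialized` instance.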
2 changes: 0 additions & 2 deletions serket/__init__.py
@@ -33,7 +33,6 @@

from serket._src.containers import RandomChoice, Sequential
from serket._src.custom_transform import tree_eval, tree_state
from serket._src.nn.activation import def_act_entry

from . import cluster, image, nn

@@ -65,7 +64,6 @@
"image",
"tree_eval",
"tree_state",
"def_act_entry",
# containers
"Sequential",
"RandomChoice",
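With `def_act_entry` removed, the string-registration path for custom activations goes away; passing the callable itself, as the notebook's third example already did, covers the same use case. A sketch of the post-removal usage, assuming `FNN` continues to accept any callable for `act`:

```python
import serket as sk
import jax.random as jr


@sk.autoinit
class MyTrainableActivation(sk.TreeClass):
    my_param: float = 10.0

    def __call__(self, x):
        return x * self.my_param


# before this commit:
#   sk.def_act_entry("my_act", MyTrainableActivation())
#   net = sk.nn.FNN([1, 1], act="my_act", key=jr.PRNGKey(0))
# after: pass the instance directly
net = sk.nn.FNN([1, 1], act=MyTrainableActivation(), key=jr.PRNGKey(0))
```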