diff --git a/docs/API/constructor.rst b/docs/API/constructor.rst
index ab39ea5..a519737 100644
--- a/docs/API/constructor.rst
+++ b/docs/API/constructor.rst
@@ -1,4 +1,4 @@
-🏗️ Constructor utils API
+🏗️ Constructor API
=============================
diff --git a/docs/API/sepes.rst b/docs/API/sepes.rst
index 6e705be..bbdab9b 100644
--- a/docs/API/sepes.rst
+++ b/docs/API/sepes.rst
@@ -2,7 +2,7 @@
=============================
.. note::
- `sepes `_ API is fully re-exported under the ``serket`` namespace.
+ `sepes `_ tree API is fully re-exported under the ``serket`` namespace.
`Check the docs `_ for full details.
.. toctree::
diff --git a/docs/index.rst b/docs/index.rst
index 6ac364d..f876c82 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -64,6 +64,7 @@ Install from github::
core_guides
other_guides
interoperability
+ tree_recipes
.. currentmodule:: serket
diff --git a/docs/notebooks/[recipes]fields.ipynb b/docs/notebooks/[recipes]fields.ipynb
new file mode 100644
index 0000000..479aa00
--- /dev/null
+++ b/docs/notebooks/[recipes]fields.ipynb
@@ -0,0 +1,525 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🏟️ Fields"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section introduces common recipes for fields. A `sepes.field` is class variable that adds certain functionality to the class with `jax` and `numpy`, but this can work with any other framework.\n",
+ "\n",
+ "Add field is written like this:\n",
+ "\n",
+ "```python\n",
+ "class MyClass:\n",
+ " my_field: Any = sepes.field()\n",
+ "```\n",
+ "For example, a `field` can be used to validate the input data, or to provide a default value. The notebook provides examples for common use cases.\n",
+ "\n",
+ "`sepes.field` is implemented as a [python descriptor](https://docs.python.org/3/howto/descriptor.html), which means that it can be used in any class not necessarily a `sepes` class. Refer to the [python documentation](https://docs.python.org/3/howto/descriptor.html) for more information on descriptors and how they work."
+ ]
+ },
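+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since `sepes.field` is a descriptor, it can be attached to a class that is not a `sepes` class. The following is a minimal sketch of that idea, assuming only the `on_setattr` pipeline shown in this notebook; the class name and the `float` callback are illustrative:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "\n",
+ "class PlainClass:\n",
+ "    # not a sepes class; the `on_setattr` callbacks run on assignment,\n",
+ "    # here coercing the assigned value to float\n",
+ "    scale: float = sp.field(on_setattr=[float])\n",
+ "\n",
+ "    def __init__(self, scale):\n",
+ "        self.scale = scale\n",
+ "\n",
+ "\n",
+ "plain = PlainClass(scale=\"1.5\")\n",
+ "print(plain.scale)  # 1.5, coerced to float on assignment"
+ ]
+ },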
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [1] Buffers\n",
+ "In this example, certain array will be marked as non-trainable using `jax.lax.stop_gradient` and `field`.\n",
+ "\n",
+ "The standard way to mark an array as a buffer (e.g. non-trainable) is to write something like this:\n",
+ "```python\n",
+ "class Tree(sp.TreeClass):\n",
+ " def __init__(self, buffer: jax.Array):\n",
+ " self.buffer = buffer\n",
+ "\n",
+ " def __call__(self, x: jax.Array) -> jax.Array:\n",
+ " return x + jax.lax.stop_gradient(self.buffer)\n",
+ "```\n",
+ "However, if you access this buffer from other methods, then another `jax.lax.stop_gradient` should be used and written inside all the methods:\n",
+ "\n",
+ "```python\n",
+ "class Tree(sp.TreeClass):\n",
+ " def method_1(self, x: jax.Array) -> jax.Array:\n",
+ " return x + jax.lax.stop_gradient(self.buffer)\n",
+ " .\n",
+ " .\n",
+ " .\n",
+ " def method_n(self, x: jax.Array) -> jax.Array:\n",
+ " return x + jax.lax.stop_gradient(self.buffer)\n",
+ "```\n",
+ "\n",
+ "Similarly, if you access `buffer` defined for `Tree` instances, from another context, you need to use `jax.lax.stop_gradient` again:\n",
+ "\n",
+ "```python\n",
+ "tree = Tree(buffer=...)\n",
+ "def func(tree: Tree):\n",
+ " buffer = jax.lax.stop_gradient(tree.buffer)\n",
+ " ... \n",
+ "```\n",
+ "\n",
+ "This becomes **cumbersome** if this process is repeated multiple times.Alternatively, `jax.lax.stop_gradient` can be applied to the `buffer` using `sepes.field` whenever the buffer is accessed. The next example demonstrates this."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "6.0\n",
+ "Tree(buffer=[0. 0. 0.])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "\n",
+ "def buffer_field(**kwargs):\n",
+ " return sp.field(on_getattr=[jax.lax.stop_gradient], **kwargs)\n",
+ "\n",
+ "\n",
+ "@sp.autoinit # autoinit construct `__init__` from fields\n",
+ "class Tree(sp.TreeClass):\n",
+ " buffer: jax.Array = buffer_field()\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " return self.buffer**x\n",
+ "\n",
+ "\n",
+ "tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))\n",
+ "tree(2.0) # Array([1., 4., 9.], dtype=float32)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "def f(tree: Tree, x: jax.Array):\n",
+ " return jnp.sum(tree(x))\n",
+ "\n",
+ "\n",
+ "print(f(tree, 1.0))\n",
+ "print(jax.grad(f)(tree, 1.0))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [2] Masked field\n",
+ "\n",
+ "`sepes` provide a simple wrapper to *mask* data. Masking here means that the data yields no leaves when flattened. This is useful in some frameworks like `jax` to hide a certain values from being seen by the transformation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Flattening a masked value**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1, #2]\n",
+ "[1]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "tree = [1, sp.tree_mask(2, cond=lambda _: True)]\n",
+ "print(tree)\n",
+ "print(jax.tree_util.tree_leaves(tree)) # note that 2 is removed from the leaves"
+ ]
+ },
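+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For completeness, a short sketch of the inverse operation: `sp.tree_unmask` (used later in this notebook) restores the masked value so it is visible to flattening again:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "tree = [1, sp.tree_mask(2, cond=lambda _: True)]\n",
+ "unmasked = sp.tree_unmask(tree)\n",
+ "print(unmasked)  # [1, 2]\n",
+ "print(jax.tree_util.tree_leaves(unmasked))  # 2 is restored to the leaves"
+ ]
+ },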
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Using masking with `jax` transformations**\n",
+ "\n",
+ "The next example demonstrates how to use masking to work with data types that are not supported by `jax`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "\n",
+ "def mask_field(**kwargs):\n",
+ " return sp.field(\n",
+ " # un mask when the value is accessed\n",
+ " on_getattr=[lambda x: sp.tree_unmask(x, cond=lambda node: True)],\n",
+ " # mask when the value is set\n",
+ " on_setattr=[lambda x: sp.tree_mask(x, cond=lambda node: True)],\n",
+ " **kwargs,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we can use this custom `field` to mark some class attributes as masked. Masking a value will effectively hide it from `jax` transformations."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Without masking the `str` type**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "Argument 'training' of type is not a valid JAX type.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[5], line 18\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tree(\u001b[38;5;28minput\u001b[39m)\n\u001b[1;32m 17\u001b[0m tree \u001b[38;5;241m=\u001b[39m Tree(training_mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining\u001b[39m\u001b[38;5;124m\"\u001b[39m, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2.0\u001b[39m)\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mloss_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtree\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2.0\u001b[39;49m\u001b[43m)\u001b[49m) \u001b[38;5;66;03m# <- will throw error with jax transformations.\u001b[39;00m\n",
+ " \u001b[0;31m[... skipping hidden 5 frame]\u001b[0m\n",
+ "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/dev-jax/lib/python3.12/site-packages/jax/_src/dispatch.py:281\u001b[0m, in \u001b[0;36mcheck_arg\u001b[0;34m(arg)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcheck_arg\u001b[39m(arg: Any):\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(arg, core\u001b[38;5;241m.\u001b[39mTracer) \u001b[38;5;129;01mor\u001b[39;00m core\u001b[38;5;241m.\u001b[39mvalid_jaxtype(arg)):\n\u001b[0;32m--> 281\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mArgument \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(arg)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not a valid \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 282\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mJAX type.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
+ "\u001b[0;31mTypeError\u001b[0m: Argument 'training' of type is not a valid JAX type."
+ ]
+ }
+ ],
+ "source": [
+ "@sp.autoinit\n",
+ "class Tree(sp.TreeClass):\n",
+ " training_mode: str # <- will throw error with jax transformations.\n",
+ " alpha: float\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " if self.training_mode == \"training\":\n",
+ " return x**self.alpha\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "@jax.grad\n",
+ "def loss_func(tree, input):\n",
+ " return tree(input)\n",
+ "\n",
+ "\n",
+ "tree = Tree(training_mode=\"training\", alpha=2.0)\n",
+ "print(loss_func(tree, 2.0)) # <- will throw error with jax transformations."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The error resulted because `jax` recognize numerical values only. The next example demonstrates how to modify the class to mask the `str` type."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@sp.autoinit\n",
+ "class Tree(sp.TreeClass):\n",
+ " training_mode: str = mask_field() # hide the field from jax transformations\n",
+ " alpha: float\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " if self.training_mode == \"training\":\n",
+ " return x**self.alpha\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "tree = Tree(training_mode=\"training\", alpha=2.0)\n",
+ "print(loss_func(tree, 2.0))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [3] Validator fields"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following provides an example of how to use `sepes.field` to validate the input data. The `validator` function is used to check if the input data is valid. If the data is invalid, an exception is raised. This example is inspired by the [python offical docs example](https://docs.python.org/3/howto/descriptor.html#validator-class)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Range+Type validator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import sepes as sp\n",
+ "\n",
+ "\n",
+ "# you can use any function\n",
+ "@sp.autoinit\n",
+ "class Range(sp.TreeClass):\n",
+ " min: int | float = -float(\"inf\")\n",
+ " max: int | float = float(\"inf\")\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " if not (self.min <= x <= self.max):\n",
+ " raise ValueError(f\"{x} not in range [{self.min}, {self.max}]\")\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "@sp.autoinit\n",
+ "class IsInstance(sp.TreeClass):\n",
+ " klass: type | tuple[type, ...]\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " if not isinstance(x, self.klass):\n",
+ " raise TypeError(f\"{x} not an instance of {self.klass}\")\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "@sp.autoinit\n",
+ "class Foo(sp.TreeClass):\n",
+ " # allow in_dim to be an integer between [1,100]\n",
+ " in_dim: int = sp.field(on_setattr=[IsInstance(int), Range(1, 100)])\n",
+ "\n",
+ "\n",
+ "tree = Foo(1)\n",
+ "# no error\n",
+ "\n",
+ "try:\n",
+ " tree = Foo(0)\n",
+ "except ValueError as e:\n",
+ " print(e)\n",
+ "\n",
+ "try:\n",
+ " tree = Foo(1.0)\n",
+ "except TypeError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Array validator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "from typing import Any\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "\n",
+ "class ArrayValidator(sp.TreeClass):\n",
+ " \"\"\"Validate shape and dtype of input array.\n",
+ "\n",
+ " Args:\n",
+ " shape: Expected shape of array. available values are int, None, ...\n",
+ " use int for fixed size, None for any size, and ... for any number\n",
+ " of dimensions. for example (..., 1) allows any number of dimensions\n",
+ " with the last dimension being 1. (1, ..., 1) allows any number of\n",
+ " dimensions with the first and last dimensions being 1.\n",
+ " dtype: Expected dtype of array.\n",
+ "\n",
+ " Example:\n",
+ " >>> x = jnp.ones((5, 5))\n",
+ " >>> # any number of dimensions with last dim=5\n",
+ " >>> shape = (..., 5)\n",
+ " >>> dtype = jnp.float32\n",
+ " >>> validator = ArrayValidator(shape, dtype)\n",
+ " >>> validator(x) # no error\n",
+ "\n",
+ " >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n",
+ " >>> shape = (None, 5)\n",
+ " >>> validator = ArrayValidator(shape, dtype)\n",
+ " >>> validator(x) # no error\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, shape, dtype):\n",
+ " if shape.count(...) > 1:\n",
+ " raise ValueError(\"Only one ellipsis allowed\")\n",
+ "\n",
+ " for si in shape:\n",
+ " if not isinstance(si, (int, type(...), type(None))):\n",
+ " raise TypeError(f\"Expected int or ..., got {si}\")\n",
+ "\n",
+ " self.shape = shape\n",
+ " self.dtype = dtype\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " if not (hasattr(x, \"shape\") and hasattr(x, \"dtype\")):\n",
+ " raise TypeError(f\"Expected array with shape {self.shape}, got {x}\")\n",
+ "\n",
+ " shape = list(self.shape)\n",
+ " array_shape = list(x.shape)\n",
+ " array_dtype = x.dtype\n",
+ "\n",
+ " if self.shape and array_dtype != self.dtype:\n",
+ " raise TypeError(f\"Dtype mismatch, {array_dtype=} != {self.dtype=}\")\n",
+ "\n",
+ " if ... in shape:\n",
+ " index = shape.index(...)\n",
+ " shape = (\n",
+ " shape[:index]\n",
+ " + [None] * (len(array_shape) - len(shape) + 1)\n",
+ " + shape[index + 1 :]\n",
+ " )\n",
+ "\n",
+ " if len(shape) != len(array_shape):\n",
+ " raise ValueError(f\"{len(shape)=} != {len(array_shape)=}\")\n",
+ "\n",
+ " for i, (li, ri) in enumerate(zip(shape, array_shape)):\n",
+ " if li is None:\n",
+ " continue\n",
+ " if li != ri:\n",
+ " raise ValueError(f\"Size mismatch, {li} != {ri} at dimension {i}\")\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "# any number of dimensions with firt dim=3 and last dim=6\n",
+ "shape = (3, ..., 6)\n",
+ "# dtype must be float32\n",
+ "dtype = jnp.float32\n",
+ "\n",
+ "validator = ArrayValidator(shape=shape, dtype=dtype)\n",
+ "\n",
+ "# convert to half precision from float32\n",
+ "converter = lambda x: x.astype(jnp.float16)\n",
+ "\n",
+ "\n",
+ "@sp.autoinit\n",
+ "class Tree(sp.TreeClass):\n",
+ " array: jax.Array = sp.field(on_setattr=[validator, converter])\n",
+ "\n",
+ "\n",
+ "x = jnp.ones([3, 1, 2, 6])\n",
+ "tree = Tree(array=x)\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " y = jnp.ones([1, 1, 2, 3])\n",
+ " tree = Tree(array=y)\n",
+ "except ValueError as e:\n",
+ " print(e, \"\\n\")\n",
+ " # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=) for field=`array`:\n",
+ " # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=\n",
+ "\n",
+ "try:\n",
+ " z = x.astype(jnp.float16)\n",
+ " tree = Tree(array=z)\n",
+ "except TypeError as e:\n",
+ " print(e)\n",
+ " # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=) for field=`array`:\n",
+ " # Size mismatch, 3 != 1 at dimension 0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [4] Parameterization field\n",
+ "\n",
+ "In this example, field value is [parameterized](https://pytorch.org/tutorials/intermediate/parametrizations.html) using `on_getattr`,\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "\n",
+ "def symmetric(array: jax.Array) -> jax.Array:\n",
+ " triangle = jnp.triu(array) # upper triangle\n",
+ " return triangle + triangle.transpose(-1, -2)\n",
+ "\n",
+ "\n",
+ "@sp.autoinit\n",
+ "class Tree(sp.TreeClass):\n",
+ " symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])\n",
+ "\n",
+ "\n",
+ "tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))\n",
+ "print(tree.symmetric_matrix)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "dev-jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]intermediates.ipynb b/docs/notebooks/[recipes]intermediates.ipynb
new file mode 100644
index 0000000..251d530
--- /dev/null
+++ b/docs/notebooks/[recipes]intermediates.ipynb
@@ -0,0 +1,130 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🔧 Intermediates handling\n",
+ "\n",
+ "This notebook demonstrates how to capture the intermediate outputs of a model during inference. This is useful for debugging, understanding the model, and visualizing the model's internal representations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Capture intermediate values. \n",
+ "\n",
+ "In this example, we will capture the intermediate values in a method by simply returning them as part of the output."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'b': 2.0, 'c': 4.0}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "\n",
+ "class Foo(sp.TreeClass):\n",
+ " def __init__(self):\n",
+ " self.a = 1.0\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " capture = {}\n",
+ " b = self.a + x\n",
+ " capture[\"b\"] = b\n",
+ " c = 2 * b\n",
+ " capture[\"c\"] = c\n",
+ " e = 4 * c\n",
+ " return e, capture\n",
+ "\n",
+ "\n",
+ "foo = Foo()\n",
+ "\n",
+ "_, inter_values = foo(1.0)\n",
+ "inter_values"
+ ]
+ },
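+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Alternatively, intermediates can be written to instance attributes inside the method and read back from the new state returned by `sp.value_and_tree`, which executes the mutating call functionally. The following is a sketch of this variant; the attribute names `b` and `c` are illustrative:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "\n",
+ "class Foo(sp.TreeClass):\n",
+ "    def __init__(self):\n",
+ "        self.a = 1.0\n",
+ "\n",
+ "    def __call__(self, x):\n",
+ "        # stash the intermediates on the (copied) instance\n",
+ "        self.b = self.a + x\n",
+ "        self.c = 2 * self.b\n",
+ "        return 4 * self.c\n",
+ "\n",
+ "\n",
+ "foo = Foo()\n",
+ "\n",
+ "# `value_and_tree` returns the output along with the new state\n",
+ "# that carries the stashed intermediates\n",
+ "output, new_foo = sp.value_and_tree(lambda foo, x: foo(x))(foo, 1.0)\n",
+ "print(new_foo.b, new_foo.c)  # 2.0 4.0"
+ ]
+ },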
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Capture intermediate gradients\n",
+ "\n",
+ "In this example, we will capture the intermediate gradients in a method by 1) perturbing the desired value and 2) using `argnum` in `jax.grad` to compute the intermediate gradients."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'b': Array(8., dtype=float32, weak_type=True),\n",
+ " 'c': Array(4., dtype=float32, weak_type=True)}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "\n",
+ "class Foo(sp.TreeClass):\n",
+ " def __init__(self):\n",
+ " self.a = 1.0\n",
+ "\n",
+ " def __call__(self, x, perturb):\n",
+ " # pass in the perturbations as a pytree\n",
+ " b = self.a + x + perturb[\"b\"]\n",
+ " c = 2 * b + perturb[\"c\"]\n",
+ " e = 4 * c\n",
+ " return e\n",
+ "\n",
+ "\n",
+ "foo = Foo()\n",
+ "\n",
+ "# de/dc = 4\n",
+ "# de/db = de/dc * dc/db = 4 * 2 = 8\n",
+ "\n",
+ "# take gradient with respect to the perturbations pytree\n",
+ "# by setting `argnums=1` in `jax.grad`\n",
+ "inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))\n",
+ "inter_grads"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]misc.ipynb b/docs/notebooks/[recipes]misc.ipynb
new file mode 100644
index 0000000..607da61
--- /dev/null
+++ b/docs/notebooks/[recipes]misc.ipynb
@@ -0,0 +1,266 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🗂️ Misc recipes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section introduces some miscellaneous recipes that are not covered in the previous sections."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [1] Lazy layers.\n",
+ "In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `value_and_tree` to create a new instance with the added parameter."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer before param is set:\tLazyLinear(out_features=1)\n",
+ "Layer after param is set:\tLazyLinear(out_features=1, weight=[[1.]], bias=[0.])\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Array([[1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.],\n",
+ " [1.]], dtype=float32)"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "from typing import Any\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "\n",
+ "class LazyLinear(sp.TreeClass):\n",
+ " def __init__(self, out_features: int):\n",
+ " self.out_features = out_features\n",
+ "\n",
+ " def param(self, name: str, value: Any):\n",
+ " # return the value if it exists, otherwise set it and return it\n",
+ " if name not in vars(self):\n",
+ " setattr(self, name, value)\n",
+ " return vars(self)[name]\n",
+ "\n",
+ " def __call__(self, input: jax.Array) -> jax.Array:\n",
+ " weight = self.param(\"weight\", jnp.ones((self.out_features, input.shape[-1])))\n",
+ " bias = self.param(\"bias\", jnp.zeros((self.out_features,)))\n",
+ " return input @ weight.T + bias\n",
+ "\n",
+ "\n",
+ "input = jnp.ones([10, 1])\n",
+ "\n",
+ "lazy = LazyLinear(out_features=1)\n",
+ "\n",
+ "print(f\"Layer before param is set:\\t{lazy}\")\n",
+ "\n",
+ "# `value_and_tree` executes any mutating method in a functional way\n",
+ "_, material = sp.value_and_tree(lambda layer: layer(input))(lazy)\n",
+ "\n",
+ "print(f\"Layer after param is set:\\t{material}\")\n",
+ "# subsequent calls will not set the parameters again\n",
+ "material(input)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [3] Regularization\n",
+ "\n",
+ "The following code showcase how to use `at` functionality to select some leaves of a model based on boolean mask or/and name condition to apply some weight regualrization on them. For example using `.at[...]` functionality the following can be achieved concisely:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Boolean-based mask\n",
+ "\n",
+ "The entries of the arrays or leaves are selected based on a tree of the same structure but with boolean (`True`/`False`) leave. The `True` leaf points to place where the operation can be done, while `False` leaf is indicating that this leaf should not be touched."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.8333335\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax.numpy as jnp\n",
+ "import jax\n",
+ "\n",
+ "\n",
+ "class Net(sp.TreeClass):\n",
+ " def __init__(self):\n",
+ " self.weight = jnp.array([-1, -2, -3, 1, 2, 3])\n",
+ " self.bias = jnp.array([-1, 1])\n",
+ "\n",
+ "\n",
+ "def negative_entries_l2_loss(net: Net):\n",
+ " return (\n",
+ " # select all positive array entries\n",
+ " net.at[jax.tree_map(lambda x: x > 0, net)]\n",
+ " # set them to zero to exclude their loss\n",
+ " .set(0)\n",
+ " # select all leaves\n",
+ " .at[...]\n",
+ " # finally reduce with l2 loss\n",
+ " .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)\n",
+ " )\n",
+ "\n",
+ "\n",
+ "net = Net()\n",
+ "print(negative_entries_l2_loss(net))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Name-based mask\n",
+ "\n",
+ "In this step, the mask is based on the path of the leaf."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "331.84155\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import jax.random as jr\n",
+ "\n",
+ "\n",
+ "class Linear(sp.TreeClass):\n",
+ " def __init__(self, in_features: int, out_features: int, key: jax.Array):\n",
+ " self.weight = jr.normal(key=key, shape=(out_features, in_features))\n",
+ " self.bias = jnp.zeros((out_features,))\n",
+ "\n",
+ "\n",
+ "class Net(sp.TreeClass):\n",
+ " def __init__(self, key: jax.Array) -> None:\n",
+ " k1, k2, k3, k4 = jax.random.split(key, 4)\n",
+ " self.linear1 = Linear(1, 20, key=k1)\n",
+ " self.linear2 = Linear(20, 20, key=k2)\n",
+ " self.linear3 = Linear(20, 20, key=k3)\n",
+ " self.linear4 = Linear(20, 1, key=k4)\n",
+ "\n",
+ "\n",
+ "def linear_12_weight_l1_loss(net: Net):\n",
+ " return (\n",
+ " # select desired branches (linear1, linear2 in this example)\n",
+ " # and the desired leaves (weight)\n",
+ " net.at[\"linear1\", \"linear2\"][\"weight\"]\n",
+ " # alternatively, regex can be used to do the same functiontality\n",
+ " # >>> import re\n",
+ " # >>> net.at[re.compile(\"linear[12]\")][\"weight\"]\n",
+ " # finally apply l1 loss\n",
+ " .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)\n",
+ " )\n",
+ "\n",
+ "\n",
+ "net = Net(key=jr.PRNGKey(0))\n",
+ "print(linear_12_weight_l1_loss(net))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This recipe can then be included inside the loss function, for example\n",
+ "\n",
+ "``` python\n",
+ "\n",
+ "def loss_fnc(net, x, y):\n",
+ " l1_loss = linear_12_weight_l1_loss(net)\n",
+ " loss += l1_loss\n",
+ " ...\n",
+ "```"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "dev-jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]sharing.ipynb b/docs/notebooks/[recipes]sharing.ipynb
new file mode 100644
index 0000000..2649773
--- /dev/null
+++ b/docs/notebooks/[recipes]sharing.ipynb
@@ -0,0 +1,212 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🪡 Sharing/Tie Weights\n",
+ "\n",
+ "Because sharing weights convert a pytree to a graph by pointing one leaf to another, a careful handling is needed to avoid breaking the _tree_ assumptions.\n",
+ "\n",
+ "In `sepes`, sharing/tie weights is done inside methods, this means, instead sharing the reference within `__init__` method, the reference is shared within the method of which the call is made.\n",
+ "\n",
+ "**From**\n",
+ "\n",
+ "```python\n",
+ "class TiedAutoEncoder:\n",
+ " def __init__(self, input_dim, hidden_dim):\n",
+ " self.encoder = Linear(input_dim, hidden_dim)\n",
+ " self.decoder = Linear(hidden_dim, input_dim)\n",
+ " self.decoder.weight = self.encoder.weight\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " return self.decoder(self.encoder(x))\n",
+ "```\n",
+ "\n",
+ "**To** \n",
+ "\n",
+ "```python\n",
+ "class TiedAutoEncoder:\n",
+ " def __init__(self, input_dim, hidden_dim):\n",
+ " self.encoder = Linear(input_dim, hidden_dim)\n",
+ " self.decoder = Linear(hidden_dim, input_dim)\n",
+ " \n",
+ " def __call__(self, x):\n",
+ " self.decoder.weight = self.encoder.weight.T\n",
+ " return self.decoder(self.encoder(x))\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this example a simple `AutoEncoder` with shared `weight` between the encode/decoder is demonstrated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import jax.random as jr\n",
+ "import functools as ft\n",
+ "\n",
+ "\n",
+ "def sharing(method):\n",
+ " # sharing simply copies the instance, executes the method, and returns the output\n",
+ " # **without modifying the original instance.**\n",
+ " @ft.wraps(method)\n",
+ " def wrapper(self, *args, **kwargs):\n",
+ " # `value_and_tree` executes any mutating method in a functional way\n",
+ " # by copying `self`, executing the method, and returning the new state\n",
+ " # along with the output.\n",
+ " output, _ = sp.value_and_tree(method)(self, *args, **kwargs)\n",
+ " return output\n",
+ "\n",
+ " return wrapper\n",
+ "\n",
+ "\n",
+ "class Linear(sp.TreeClass):\n",
+ " def __init__(self, in_features: int, out_features: int, key: jax.Array):\n",
+ " self.weight = jr.normal(key=key, shape=(out_features, in_features))\n",
+ " self.bias = jnp.zeros((out_features,))\n",
+ "\n",
+ " def __call__(self, input):\n",
+ " return input @ self.weight.T + self.bias\n",
+ "\n",
+ "\n",
+ "class AutoEncoder(sp.TreeClass):\n",
+ " def __init__(self, *, key: jax.Array):\n",
+ " k1, k2, k3, k4 = jr.split(key, 4)\n",
+ " self.enc1 = Linear(1, 10, key=k1)\n",
+ " self.enc2 = Linear(10, 20, key=k2)\n",
+ " self.dec2 = Linear(20, 10, key=k3)\n",
+ " self.dec1 = Linear(10, 1, key=k4)\n",
+ "\n",
+ " @sharing\n",
+ " def tied_call(self, input: jax.Array) -> jax.Array:\n",
+ " self.dec1.weight = self.enc1.weight.T\n",
+ " self.dec2.weight = self.enc2.weight.T\n",
+ " output = self.enc1(input)\n",
+ " output = self.enc2(output)\n",
+ " output = self.dec2(output)\n",
+ " output = self.dec1(output)\n",
+ " return output\n",
+ "\n",
+ " def non_tied_call(self, x):\n",
+ " output = self.enc1(x)\n",
+ " output = self.enc2(output)\n",
+ " output = self.dec2(output)\n",
+ " output = self.dec1(output)\n",
+ " return output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Linear(\n",
+ " weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n",
+ " bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])\n",
+ ") Linear(\n",
+ " weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n",
+ " bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "@jax.jit\n",
+ "@jax.grad\n",
+ "def tied_loss_func(net, x, y):\n",
+ " net = sp.tree_unmask(net)\n",
+ " return jnp.mean((jax.vmap(net.tied_call)(x) - y) ** 2)\n",
+ "\n",
+ "\n",
+ "tree = sp.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))\n",
+ "x = jnp.ones([10, 1]) + 0.0\n",
+ "y = jnp.ones([10, 1]) * 2.0\n",
+ "grads: AutoEncoder = tied_loss_func(tree, x, y)\n",
+ "# note that the shared weights have 0 gradient\n",
+ "print(repr(grads.dec1), repr(grads.dec2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Linear(\n",
+ " weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n",
+ " bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])\n",
+ ") Linear(\n",
+ " weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n",
+ " bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# check for non-tied call\n",
+ "@jax.jit\n",
+ "@jax.grad\n",
+ "def non_tied_loss_func(net, x, y):\n",
+ " net = sp.tree_unmask(net)\n",
+ " return jnp.mean((jax.vmap(net.non_tied_call)(x) - y) ** 2)\n",
+ "\n",
+ "\n",
+ "tree = sp.tree_mask(tree)\n",
+ "x = jnp.ones([10, 1]) + 0.0\n",
+ "y = jnp.ones([10, 1]) * 2.0\n",
+ "grads: AutoEncoder = tied_loss_func(tree, x, y)\n",
+ "\n",
+ "# note that the shared weights have non-zero gradients\n",
+ "print(repr(grads.dec1), repr(grads.dec2))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "dev-jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]surgery.ipynb b/docs/notebooks/[recipes]surgery.ipynb
new file mode 100644
index 0000000..7f63fed
--- /dev/null
+++ b/docs/notebooks/[recipes]surgery.ipynb
@@ -0,0 +1,315 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# ✂️ Surgery\n",
+ "\n",
+ "This notebook provides tree editing (surgery) recipes using `at`. `at` encapsules a pytree and provides a way to edit it in out-of-place manner. Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.\n",
+ "\n",
+ "```python\n",
+ "import sepes as sp\n",
+ "import re\n",
+ "tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])\n",
+ "where = re.compile(\"key_.*\") # select all keys starting with \"key_\"\n",
+ "func = lambda node: sum(map(abs, node)) # sum of absolute values\n",
+ "sp.at(tree)[where].apply(func)\n",
+ "# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Out-of-place editing\n",
+ "\n",
+ "Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n",
+ "pytree1 is pytree2 = False\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "pytree1 = [1, [2, 3], 4]\n",
+ "pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n",
+ "print(f\"{pytree1=}, {pytree2=}\")\n",
+ "# even though pytree1 and pytree2 are the same, they are not the same object\n",
+ "# because pytree2 is a copy of pytree1\n",
+ "print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Matching keys"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Match all\n",
+ "\n",
+ "Use `...` to match all keys. \n",
+ "\n",
+ "The following example applies `plus_one` function to all tree nodes. This is equivalent to `tree = tree_map(plus_one, tree)`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[2, [3, 4], 5]"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "pytree1 = [1, [2, 3], 4]\n",
+ "plus_one = lambda x: x + 1\n",
+ "pytree2 = sp.at(pytree1)[...].apply(plus_one)\n",
+ "pytree2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Integer indexing\n",
+ "\n",
+ "`at` can edit pytrees by integer paths."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[1, [100, 3], 4]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "pytree1 = [1, [2, 3], 4]\n",
+ "pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n",
+ "pytree2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Named path indexing\n",
+ "`at` can edit pytrees by named paths."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "\n",
+ "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n",
+ "pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n",
+ "pytree2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Regular expressions indexing\n",
+ "`at` can edit pytrees by regular expressions applied to keys."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import re\n",
+ "\n",
+ "pytree1 = {\n",
+ " \"key_1\": 1,\n",
+ " \"key_2\": {\"key_3\": 3, \"key_4\": 4},\n",
+ " \"key_5\": 5,\n",
+ " \"key_6\": {\"key_7\": 7, \"key_8\": 8},\n",
+ "}\n",
+ "# from 1 - 5, set the value to 100\n",
+ "pattern = re.compile(r\"key_[1-5]\")\n",
+ "pytree2 = sp.at(pytree1)[pattern].set(100)\n",
+ "pytree2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Mask indexing\n",
+ "\n",
+ "`at` can edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. \n",
+ "\n",
+ "\n",
+ "The following example set all negative tree nodes to zero."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n",
+ "# mask defines all desired entries to apply the function\n",
+ "mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n",
+ "pytree2 = sp.at(pytree1)[mask].set(0)\n",
+ "pytree2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Composition\n",
+ "\n",
+ "`at` can compose multiple keys, integer paths, named paths, regular expressions, and masks to edit the tree.\n",
+ "\n",
+ "The following example demonstrates how to apply a function to all nodes with:\n",
+ "- Name starting with \"key_\"\n",
+ "- Positive values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import re\n",
+ "import jax\n",
+ "\n",
+ "pytree1 = {\"key_1\": 1, \"key_2\": -2, \"value_1\": 1, \"value_2\": 2}\n",
+ "pattern = re.compile(r\"key_.*\")\n",
+ "mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)\n",
+ "pytree2 = sp.at(pytree1)[pattern][mask].set(100)\n",
+ "pytree2"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "dev-jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]transformations.ipynb b/docs/notebooks/[recipes]transformations.ipynb
new file mode 100644
index 0000000..df053ec
--- /dev/null
+++ b/docs/notebooks/[recipes]transformations.ipynb
@@ -0,0 +1,594 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 💫 Transformations"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section introduces common function transformations that are used in conjunction with pytrees. Examples includes function transformation that wraps `jax` transforms or a function transformation that wraps `numpy`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install sepes"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [1] Broadcasting transformations\n",
+ "\n",
+ "Using `bcmap` to apply a function over pytree leaves with automatic broadcasting for scalar arguments. "
+ ]
+ },
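+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of the idea, using a plain list pytree with `jnp.add` standing in for any leaf-wise function; the scalar second argument is broadcast against every leaf:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "# lift `jnp.add` to operate leaf-wise over pytrees\n",
+ "tree_add = sp.bcmap(jnp.add)\n",
+ "print(tree_add([jnp.array([1.0]), jnp.array([2.0])], 1.0))\n",
+ "# [Array([2.], dtype=float32), Array([3.], dtype=float32)]"
+ ]
+ },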
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `bcmap` + `numpy`\n",
+ "\n",
+ "In this recipe, `numpy` functions will operate directly on `TreeClass` instances."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])\n"
+ ]
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "\n",
+ "@sp.leafwise # enable math operations on leaves\n",
+ "@sp.autoinit # generate __init__ from type annotations\n",
+ "class Tree(sp.TreeClass):\n",
+ " a: int = 1\n",
+ " b: tuple[float] = (2.0, 3.0)\n",
+ " c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n",
+ "\n",
+ "\n",
+ "tree = Tree()\n",
+ "\n",
+ "# make where work with arbitrary pytrees\n",
+ "tree_where = sp.bcmap(jnp.where)\n",
+ "# for values > 2, add 100, else set to 0\n",
+ "print(tree_where(tree > 2, tree + 100, 0))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`bcmap` on pytrees with non-jaxtype\n",
+ "\n",
+ "In case the tree has some non-jaxtype leaves, The above will fail, but we can use `tree_mask`/`tree_unmask` to fix it"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "bcmap fail '>' not supported between instances of 'str' and 'int'\n",
+ "Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 106.], name=tree, func=(x))\n"
+ ]
+ }
+ ],
+ "source": [
+ "# in case the tree has some non-jaxtype leaves\n",
+ "# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it\n",
+ "import sepes as sp\n",
+ "import jax.numpy as jnp\n",
+ "from typing import Callable\n",
+ "\n",
+ "\n",
+ "@sp.leafwise # enable math operations on leaves\n",
+ "@sp.autoinit # generate __init__ from type annotations\n",
+ "class Tree(sp.TreeClass):\n",
+ " a: float = 1.0\n",
+ " b: tuple[float] = (2.0, 3.0)\n",
+ " c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n",
+ " name: str = \"tree\" # non-jaxtype\n",
+ " func: Callable = lambda x: x # non-jaxtype\n",
+ "\n",
+ "\n",
+ "tree = Tree()\n",
+ "\n",
+ "try:\n",
+ " # make where work with arbitrary pytrees with non-jaxtype leaves\n",
+ " tree_where = sp.bcmap(jnp.where)\n",
+ " # for values > 2, add 100, else set to 0\n",
+ " print(tree_where(tree > 2, tree + 100, 0))\n",
+ "except TypeError as e:\n",
+ " print(\"bcmap fail\", e)\n",
+ " # now we can use `tree_mask`/`tree_unmask` to fix it\n",
+ " masked_tree = sp.tree_mask(tree) # mask non-jaxtype leaves\n",
+ " masked_tree = tree_where(masked_tree > 2, masked_tree + 100, 0)\n",
+ " unmasked_tree = sp.tree_unmask(masked_tree)\n",
+ " print(unmasked_tree)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `bcmap` + configs\n",
+ "\n",
+ "The next example shows how to use `serket.bcmap` to loop over a configuration dictionary that defines creation of simple linear layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Linear(\n",
+ " weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]), \n",
+ " bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+ " ),\n",
+ " Linear(\n",
+ " weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]), \n",
+ " bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+ " ),\n",
+ " Linear(\n",
+ " weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]), \n",
+ " bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+ " ),\n",
+ " Linear(\n",
+ " weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]), \n",
+ " bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+ " )]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import sepes as sp\n",
+ "import jax\n",
+ "\n",
+ "\n",
+ "class Linear(sp.TreeClass):\n",
+ " def __init__(self, in_dim: int, out_dim: int, *, key: jax.Array):\n",
+ " self.weight = jax.random.normal(key, (in_dim, out_dim))\n",
+ " self.bias = jnp.zeros((out_dim,))\n",
+ "\n",
+ " def __call__(self, x: jax.Array) -> jax.Array:\n",
+ " return x @ self.weight + self.bias\n",
+ "\n",
+ "\n",
+ "config = {\n",
+ " # each layer gets a different input dimension\n",
+ " \"in_dim\": [1, 2, 3, 4],\n",
+ " # out_dim is broadcasted to all layers\n",
+ " \"out_dim\": 1,\n",
+ " # each layer gets a different key\n",
+ " \"key\": list(jax.random.split(jax.random.PRNGKey(0), 4)),\n",
+ "}\n",
+ "\n",
+ "\n",
+ "# `bcmap` transforms a function that takes a single input into a function that\n",
+ "# arbitrary pytree inputs. in case of a single input, the input is broadcasted\n",
+ "# to match the tree structure of the first argument\n",
+ "# (in our example is a list of 4 inputs)\n",
+ "\n",
+ "\n",
+ "@sp.bcmap\n",
+ "def build_layer(in_dim, out_dim, *, key: jax.Array):\n",
+ " return Linear(in_dim, out_dim, key=key)\n",
+ "\n",
+ "\n",
+ "build_layer(config[\"in_dim\"], config[\"out_dim\"], key=config[\"key\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## [2] Masked transformations\n",
+ "\n",
+ "As an alternative to using `sp.tree_unmask` on pytrees before calling the function -as seen throughout training examples and recipes- , another approach is to wrap a certain transformation - not pytrees - (e.g. `jit`) to be make the masking/unmasking automatic; however this apporach will incur more overhead than applying `sp.tree_unmask` before the function call.\n",
+ "\n",
+ "The following example demonstrate how to wrap `jit`, and `vmap`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sepes as sp\n",
+ "import functools as ft\n",
+ "import jax\n",
+ "import jax.random as jr\n",
+ "import jax.numpy as jnp\n",
+ "from typing import Any\n",
+ "from typing import Any, Callable, TypeVar\n",
+ "from typing_extensions import ParamSpec\n",
+ "\n",
+ "T = TypeVar(\"T\")\n",
+ "P = ParamSpec(\"P\")\n",
+ "\n",
+ "\n",
+ "def automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n",
+ " \"\"\"Enable jax transformations to accept non-jax types. e.g. ``jax.grad``.\"\"\"\n",
+ " # works with functions that takes a function as input\n",
+ " # and returns a function as output e.g. `jax.grad`\n",
+ "\n",
+ " def out_transform(func, **transformation_kwargs):\n",
+ " @ft.partial(jax_transform, **transformation_kwargs)\n",
+ " def jax_boundary(*args, **kwargs):\n",
+ " args, kwargs = sp.tree_unmask((args, kwargs))\n",
+ " return sp.tree_mask(func(*args, **kwargs))\n",
+ "\n",
+ " @ft.wraps(func)\n",
+ " def outer_wrapper(*args, **kwargs):\n",
+ " args, kwargs = sp.tree_mask((args, kwargs))\n",
+ " output = jax_boundary(*args, **kwargs)\n",
+ " return sp.tree_unmask(output)\n",
+ "\n",
+ " return outer_wrapper\n",
+ "\n",
+ " return out_transform\n",
+ "\n",
+ "\n",
+ "def inline_automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n",
+ " \"\"\"Enable jax transformations to accept non-jax types e.g. ``jax.lax.scan``.\"\"\"\n",
+ " # works with functions that takes a function and arguments as input\n",
+ " # and returns jax types as output e.g. `jax.lax.scan`\n",
+ "\n",
+ " def outer_wrapper(func, *args, **kwargs):\n",
+ " args, kwargs = sp.tree_mask((args, kwargs))\n",
+ "\n",
+ " def func_masked(*args, **kwargs):\n",
+ " args, kwargs = sp.tree_unmask((args, kwargs))\n",
+ " return sp.tree_mask(func(*args, **kwargs))\n",
+ "\n",
+ " output = jax_transform(func_masked, *args, **kwargs)\n",
+ " return sp.tree_unmask(output)\n",
+ "\n",
+ " return outer_wrapper"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `automask`(`jit`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "`jit error`: Argument 'layer' of type is not a valid JAX type\n",
+ "\n",
+ "Using automask:\n",
+ "forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
+ " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
+ " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
+ " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
+ " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n",
+ "\n",
+ "params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name=\"layer\")\n",
+ "\n",
+ "\n",
+ "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n",
+ " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " forward_jit = jax.jit(forward)\n",
+ " print(forward_jit(params, x))\n",
+ "except TypeError as e:\n",
+ " print(\"`jit error`:\", e)\n",
+ " # now with `automask` the function can accept non-jax types (e.g. string)\n",
+ " forward_jit = automask(jax.jit)(forward)\n",
+ " print(\"\\nUsing automask:\")\n",
+ " print(f\"{forward_jit(params, x)=}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `automask`(`vmap`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "`vmap error`: Output from batched function 'layer' with type is not a valid JAX type\n",
+ "\n",
+ "Using automask:\n",
+ "dict(\n",
+ " name=layer, \n",
+ " w1=f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]), \n",
+ " w2=f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])\n",
+ ")\n"
+ ]
+ }
+ ],
+ "source": [
+ "def make_params(key: jax.Array):\n",
+ " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n",
+ " return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n",
+ "\n",
+ "\n",
+ "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n",
+ "\n",
+ "try:\n",
+ " params = jax.vmap(make_params)(keys)\n",
+ " print(params)\n",
+ "except TypeError as e:\n",
+ " print(\"`vmap error`:\", e)\n",
+ " # now with `automask` the function can accept non-jax types (e.g. string)\n",
+ " params = automask(jax.vmap)(make_params)(keys)\n",
+ " print(\"\\nUsing automask:\")\n",
+ " print(sp.tree_repr(params))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `automask`(`make_jaxpr`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "`jax.make_jaxpr` failed: error=TypeError(\"Argument 'layer' of type is not a valid JAX type\")\n",
+ "\n",
+ "Using `automask:\n",
+ "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[4,5,5]\u001b[39m b\u001b[35m:f32[4,5,5]\u001b[39m c\u001b[35m:f32[10,5]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n",
+ " \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[10,4,5]\u001b[39m = dot_general[\n",
+ " dimension_numbers=(([1], [1]), ([], []))\n",
+ " preferred_element_type=float32\n",
+ " ] c a\n",
+ " e\u001b[35m:f32[4,10,5]\u001b[39m = transpose[permutation=(1, 0, 2)] d\n",
+ " f\u001b[35m:f32[4,10,5]\u001b[39m = tanh e\n",
+ " g\u001b[35m:f32[4,10,5]\u001b[39m = dot_general[\n",
+ " dimension_numbers=(([2], [1]), ([0], [0]))\n",
+ " preferred_element_type=float32\n",
+ " ] f b\n",
+ " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(g,) }\n"
+ ]
+ }
+ ],
+ "source": [
+ "def make_params(key: jax.Array):\n",
+ " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n",
+ " return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n",
+ "\n",
+ "\n",
+ "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n",
+ "params = automask(jax.vmap)(make_params)(keys)\n",
+ "\n",
+ "\n",
+ "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n",
+ " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " jax.make_jaxpr(forward)(params, jnp.ones((10, 5)))\n",
+ "except TypeError as error:\n",
+ " print(f\"`jax.make_jaxpr` failed: {error=}\")\n",
+ " print(\"\\nUsing `automask:\")\n",
+ " print(automask(jax.make_jaxpr)(forward)(params, jnp.ones((10, 5))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `inline_automask`(`scan`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "`jax.lax.scan` Failed: error=TypeError(\"Value 'layer' with type is not a valid JAX type\")\n",
+ "\n",
+ "Using `inline_automask`:\n",
+ "({'name': 'layer', 'w1': Array([[[0.6022109 , 0.06545091, 0.7613505 ]],\n",
+ "\n",
+ " [[0.33657324, 0.3744743 , 0.12130237]],\n",
+ "\n",
+ " [[0.51550114, 0.17686307, 0.6407058 ]],\n",
+ "\n",
+ " [[0.9101157 , 0.9690273 , 0.36771262]]], dtype=float32), 'w2': Array([[[0.2678218 ],\n",
+ " [0.3963921 ],\n",
+ " [0.7078583 ]],\n",
+ "\n",
+ " [[0.18808937],\n",
+ " [0.8475715 ],\n",
+ " [0.04241407]],\n",
+ "\n",
+ " [[0.74411213],\n",
+ " [0.6318574 ],\n",
+ " [0.58551705]],\n",
+ "\n",
+ " [[0.34456158],\n",
+ " [0.5347049 ],\n",
+ " [0.3992592 ]]], dtype=float32)}, Array([[[[0.62451595],\n",
+ " [0.3141999 ],\n",
+ " [0.59660065],\n",
+ " [0.7389193 ]],\n",
+ "\n",
+ " [[0.1839285 ],\n",
+ " [0.36948383],\n",
+ " [0.26153624],\n",
+ " [0.7847949 ]],\n",
+ "\n",
+ " [[0.81791794],\n",
+ " [0.53822035],\n",
+ " [0.7945141 ],\n",
+ " [1.2155443 ]],\n",
+ "\n",
+ " [[0.4768083 ],\n",
+ " [0.35134616],\n",
+ " [0.48272693],\n",
+ " [0.78913575]]]], dtype=float32))\n"
+ ]
+ }
+ ],
+ "source": [
+ "def make_params(key: jax.Array):\n",
+ " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n",
+ " return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name=\"layer\")\n",
+ "\n",
+ "\n",
+ "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n",
+ "params = automask(jax.vmap)(make_params)(keys)\n",
+ "\n",
+ "\n",
+ "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n",
+ " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n",
+ "\n",
+ "\n",
+ "def scan_func(params, input):\n",
+ " # layer contains non-jax types\n",
+ " output = forward(params, input)\n",
+ " return params, output\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " jax.lax.scan(scan_func, params, jnp.ones((1, 1)))\n",
+ "except TypeError as error:\n",
+ " print(f\"`jax.lax.scan` Failed: {error=}\")\n",
+ " print(\"\\nUsing `inline_automask`:\")\n",
+ " print(inline_automask(jax.lax.scan)(scan_func, params, jnp.ones((1, 1))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### `inline_automask`(`eval_shape`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "`jax.eval_shape` Failed: error=TypeError(\"Argument 'layer' of type is not a valid JAX type\")\n",
+ "\n",
+ "Using `inline_automask`:\n",
+ "ShapeDtypeStruct(shape=(10, 1), dtype=float32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "def make_params(key: jax.Array):\n",
+ " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n",
+ " return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name=\"layer\")\n",
+ "\n",
+ "\n",
+ "params = make_params(jr.PRNGKey(0))\n",
+ "\n",
+ "\n",
+ "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n",
+ " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n",
+ "\n",
+ "\n",
+ "try:\n",
+ " jax.eval_shape(forward, params, jnp.ones((10, 1)))\n",
+ "except TypeError as error:\n",
+ " print(f\"`jax.eval_shape` Failed: {error=}\")\n",
+ " print(\"\\nUsing `inline_automask`:\")\n",
+ " print(inline_automask(jax.eval_shape)(forward, params, jnp.ones((10, 1))))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "dev-jax",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/tree_recipes.rst b/docs/tree_recipes.rst
new file mode 100644
index 0000000..18da62e
--- /dev/null
+++ b/docs/tree_recipes.rst
@@ -0,0 +1,17 @@
+🍳 Tree recipes
+----------------
+
+.. note::
+ `sepes `_ tree API is fully re-exported under the ``serket`` namespace.
+ `Check the docs `_ for full details.
+
+.. toctree::
+ :caption: recipes
+ :maxdepth: 1
+
+ notebooks/[recipes]surgery
+ notebooks/[recipes]fields
+ notebooks/[recipes]intermediates
+ notebooks/[recipes]misc
+ notebooks/[recipes]sharing
+ notebooks/[recipes]transformations
\ No newline at end of file