diff --git a/docs/API/constructor.rst b/docs/API/constructor.rst index ab39ea5..a519737 100644 --- a/docs/API/constructor.rst +++ b/docs/API/constructor.rst @@ -1,4 +1,4 @@ -🏗️ Constructor utils API +🏗️ Constructor API ============================= diff --git a/docs/API/sepes.rst b/docs/API/sepes.rst index 6e705be..bbdab9b 100644 --- a/docs/API/sepes.rst +++ b/docs/API/sepes.rst @@ -2,7 +2,7 @@ ============================= .. note:: - `sepes `_ API is fully re-exported under the ``serket`` namespace. + The tree API of `sepes `_ is fully re-exported under the ``serket`` namespace. `Check the docs `_ for full details. .. toctree:: diff --git a/docs/index.rst b/docs/index.rst index 6ac364d..f876c82 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,6 +64,7 @@ Install from github:: core_guides other_guides interoperability + tree_recipes .. currentmodule:: serket diff --git a/docs/notebooks/[recipes]fields.ipynb b/docs/notebooks/[recipes]fields.ipynb new file mode 100644 index 0000000..479aa00 --- /dev/null +++ b/docs/notebooks/[recipes]fields.ipynb @@ -0,0 +1,525 @@ +{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 🏟️ Fields"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install sepes"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This section introduces common recipes for fields. A `sepes.field` is a class variable that adds extra functionality to a class attribute. The examples here use `jax` and `numpy`, but fields work with any other framework.\n",
+    "\n",
+    "A field is declared like this:\n",
+    "\n",
+    "```python\n",
+    "class MyClass:\n",
+    "    my_field: Any = sepes.field()\n",
+    "```\n",
+    "For example, a `field` can be used to validate the input data, or to provide a default value. This notebook provides examples for common use cases.\n",
+    "\n",
+    "`sepes.field` is implemented as a [python descriptor](https://docs.python.org/3/howto/descriptor.html), which means that it can be used in any class, not necessarily a `sepes` class. Refer to the [python documentation](https://docs.python.org/3/howto/descriptor.html) for more information on descriptors and how they work."
+   ]
+  },
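+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of the default-value use case (assuming `sp.field` accepts a `default` argument, as in its `pytreeclass` predecessor; the `Counter` class is illustrative), `autoinit` picks the default up when generating `__init__`:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sepes as sp\n",
+    "\n",
+    "\n",
+    "@sp.autoinit\n",
+    "class Counter(sp.TreeClass):\n",
+    "    # `count` falls back to 0 when not passed to `__init__`\n",
+    "    count: int = sp.field(default=0)\n",
+    "\n",
+    "\n",
+    "assert Counter().count == 0\n",
+    "assert Counter(count=3).count == 3"
+   ]
+  },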
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [1] Buffers\n",
+    "In this example, a certain array is marked as non-trainable using `jax.lax.stop_gradient` and `field`.\n",
+    "\n",
+    "The standard way to mark an array as a buffer (e.g. non-trainable) is to write something like this:\n",
+    "```python\n",
+    "class Tree(sp.TreeClass):\n",
+    "    def __init__(self, buffer: jax.Array):\n",
+    "        self.buffer = buffer\n",
+    "\n",
+    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
+    "        return x + jax.lax.stop_gradient(self.buffer)\n",
+    "```\n",
+    "However, if the buffer is accessed from other methods, `jax.lax.stop_gradient` has to be repeated inside each of these methods:\n",
+    "\n",
+    "```python\n",
+    "class Tree(sp.TreeClass):\n",
+    "    def method_1(self, x: jax.Array) -> jax.Array:\n",
+    "        return x + jax.lax.stop_gradient(self.buffer)\n",
+    "    .\n",
+    "    .\n",
+    "    .\n",
+    "    def method_n(self, x: jax.Array) -> jax.Array:\n",
+    "        return x + jax.lax.stop_gradient(self.buffer)\n",
+    "```\n",
+    "\n",
+    "Similarly, if you access the `buffer` of a `Tree` instance from another context, you need to use `jax.lax.stop_gradient` again:\n",
+    "\n",
+    "```python\n",
+    "tree = Tree(buffer=...)\n",
+    "def func(tree: Tree):\n",
+    "    buffer = jax.lax.stop_gradient(tree.buffer)\n",
+    "    ... \n",
+    "```\n",
+    "\n",
+    "This becomes **cumbersome** when the process is repeated many times. Alternatively, `jax.lax.stop_gradient` can be applied to the `buffer` using `sepes.field` whenever the buffer is accessed. The next example demonstrates this."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "6.0\n",
+      "Tree(buffer=[0. 0. 0.])\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "\n",
+    "def buffer_field(**kwargs):\n",
+    "    return sp.field(on_getattr=[jax.lax.stop_gradient], **kwargs)\n",
+    "\n",
+    "\n",
+    "@sp.autoinit  # autoinit constructs `__init__` from fields\n",
+    "class Tree(sp.TreeClass):\n",
+    "    buffer: jax.Array = buffer_field()\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        return self.buffer**x\n",
+    "\n",
+    "\n",
+    "tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))\n",
+    "tree(2.0)  # Array([1., 4., 9.], dtype=float32)\n",
+    "\n",
+    "\n",
+    "@jax.jit\n",
+    "def f(tree: Tree, x: jax.Array):\n",
+    "    return jnp.sum(tree(x))\n",
+    "\n",
+    "\n",
+    "print(f(tree, 1.0))\n",
+    "print(jax.grad(f)(tree, 1.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [2] Masked field\n",
+    "\n",
+    "`sepes` provides a simple wrapper to *mask* data. Masking here means that the data yields no leaves when flattened. This is useful in frameworks like `jax` to hide certain values from being seen by the transformations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Flattening a masked value**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1, #2]\n",
+      "[1]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "\n",
+    "tree = [1, sp.tree_mask(2, cond=lambda _: True)]\n",
+    "print(tree)\n",
+    "print(jax.tree_util.tree_leaves(tree))  # note that 2 is removed from the leaves"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Using masking with `jax` transformations**\n",
+    "\n",
+    "The next example demonstrates how to use masking to work with data types that are not supported by `jax`."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import sepes as sp\n", + "import jax\n", + "\n", + "\n", + "def mask_field(**kwargs):\n", + " return sp.field(\n", + " # un mask when the value is accessed\n", + " on_getattr=[lambda x: sp.tree_unmask(x, cond=lambda node: True)],\n", + " # mask when the value is set\n", + " on_setattr=[lambda x: sp.tree_mask(x, cond=lambda node: True)],\n", + " **kwargs,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use this custom `field` to mark some class attributes as masked. Masking a value will effectively hide it from `jax` transformations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Without masking the `str` type**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Argument 'training' of type is not a valid JAX type.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 18\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tree(\u001b[38;5;28minput\u001b[39m)\n\u001b[1;32m 17\u001b[0m tree \u001b[38;5;241m=\u001b[39m Tree(training_mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining\u001b[39m\u001b[38;5;124m\"\u001b[39m, alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2.0\u001b[39m)\n\u001b[0;32m---> 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mloss_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtree\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2.0\u001b[39;49m\u001b[43m)\u001b[49m) \u001b[38;5;66;03m# <- will throw error with jax transformations.\u001b[39;00m\n", + " \u001b[0;31m[... skipping hidden 5 frame]\u001b[0m\n", + "File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/dev-jax/lib/python3.12/site-packages/jax/_src/dispatch.py:281\u001b[0m, in \u001b[0;36mcheck_arg\u001b[0;34m(arg)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcheck_arg\u001b[39m(arg: Any):\n\u001b[1;32m 280\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28misinstance\u001b[39m(arg, core\u001b[38;5;241m.\u001b[39mTracer) \u001b[38;5;129;01mor\u001b[39;00m core\u001b[38;5;241m.\u001b[39mvalid_jaxtype(arg)):\n\u001b[0;32m--> 281\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mArgument \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00marg\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(arg)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is not a valid \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 282\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mJAX type.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mTypeError\u001b[0m: Argument 'training' of type is not a valid JAX type." 
+     ]
+    }
+   ],
+   "source": [
+    "@sp.autoinit\n",
+    "class Tree(sp.TreeClass):\n",
+    "    training_mode: str  # <- will throw error with jax transformations.\n",
+    "    alpha: float\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        if self.training_mode == \"training\":\n",
+    "            return x**self.alpha\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "@jax.grad\n",
+    "def loss_func(tree, input):\n",
+    "    return tree(input)\n",
+    "\n",
+    "\n",
+    "tree = Tree(training_mode=\"training\", alpha=2.0)\n",
+    "print(loss_func(tree, 2.0))  # <- will throw error with jax transformations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The error results because `jax` recognizes only numerical values. The next example demonstrates how to modify the class to mask the `str` type."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@sp.autoinit\n",
+    "class Tree(sp.TreeClass):\n",
+    "    training_mode: str = mask_field()  # hide the field from jax transformations\n",
+    "    alpha: float\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        if self.training_mode == \"training\":\n",
+    "            return x**self.alpha\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "tree = Tree(training_mode=\"training\", alpha=2.0)\n",
+    "print(loss_func(tree, 2.0))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [3] Validator fields"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following provides an example of how to use `sepes.field` to validate the input data. The `validator` function is used to check if the input data is valid. If the data is invalid, an exception is raised. This example is inspired by the [python official docs example](https://docs.python.org/3/howto/descriptor.html#validator-class)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Range+Type validator"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import sepes as sp\n",
+    "\n",
+    "\n",
+    "# you can use any function\n",
+    "@sp.autoinit\n",
+    "class Range(sp.TreeClass):\n",
+    "    min: int | float = -float(\"inf\")\n",
+    "    max: int | float = float(\"inf\")\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        if not (self.min <= x <= self.max):\n",
+    "            raise ValueError(f\"{x} not in range [{self.min}, {self.max}]\")\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "@sp.autoinit\n",
+    "class IsInstance(sp.TreeClass):\n",
+    "    klass: type | tuple[type, ...]\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        if not isinstance(x, self.klass):\n",
+    "            raise TypeError(f\"{x} not an instance of {self.klass}\")\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "@sp.autoinit\n",
+    "class Foo(sp.TreeClass):\n",
+    "    # allow in_dim to be an integer between [1,100]\n",
+    "    in_dim: int = sp.field(on_setattr=[IsInstance(int), Range(1, 100)])\n",
+    "\n",
+    "\n",
+    "tree = Foo(1)\n",
+    "# no error\n",
+    "\n",
+    "try:\n",
+    "    tree = Foo(0)\n",
+    "except ValueError as e:\n",
+    "    print(e)\n",
+    "\n",
+    "try:\n",
+    "    tree = Foo(1.0)\n",
+    "except TypeError as e:\n",
+    "    print(e)"
+   ]
+  },
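+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that `on_setattr` callables run in sequence, so validation can be chained with conversion. A minimal sketch reusing `IsInstance` from above (the `Bar` class and its field are illustrative):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@sp.autoinit\n",
+    "class Bar(sp.TreeClass):\n",
+    "    # validate the type first, then convert the accepted value to `int`\n",
+    "    dim: int = sp.field(on_setattr=[IsInstance((int, float)), int])\n",
+    "\n",
+    "\n",
+    "assert Bar(dim=3.0).dim == 3"
+   ]
+  },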
"\n", + " Args:\n", + " shape: Expected shape of array. available values are int, None, ...\n", + " use int for fixed size, None for any size, and ... for any number\n", + " of dimensions. for example (..., 1) allows any number of dimensions\n", + " with the last dimension being 1. (1, ..., 1) allows any number of\n", + " dimensions with the first and last dimensions being 1.\n", + " dtype: Expected dtype of array.\n", + "\n", + " Example:\n", + " >>> x = jnp.ones((5, 5))\n", + " >>> # any number of dimensions with last dim=5\n", + " >>> shape = (..., 5)\n", + " >>> dtype = jnp.float32\n", + " >>> validator = ArrayValidator(shape, dtype)\n", + " >>> validator(x) # no error\n", + "\n", + " >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n", + " >>> shape = (None, 5)\n", + " >>> validator = ArrayValidator(shape, dtype)\n", + " >>> validator(x) # no error\n", + " \"\"\"\n", + "\n", + " def __init__(self, shape, dtype):\n", + " if shape.count(...) > 1:\n", + " raise ValueError(\"Only one ellipsis allowed\")\n", + "\n", + " for si in shape:\n", + " if not isinstance(si, (int, type(...), type(None))):\n", + " raise TypeError(f\"Expected int or ..., got {si}\")\n", + "\n", + " self.shape = shape\n", + " self.dtype = dtype\n", + "\n", + " def __call__(self, x):\n", + " if not (hasattr(x, \"shape\") and hasattr(x, \"dtype\")):\n", + " raise TypeError(f\"Expected array with shape {self.shape}, got {x}\")\n", + "\n", + " shape = list(self.shape)\n", + " array_shape = list(x.shape)\n", + " array_dtype = x.dtype\n", + "\n", + " if self.shape and array_dtype != self.dtype:\n", + " raise TypeError(f\"Dtype mismatch, {array_dtype=} != {self.dtype=}\")\n", + "\n", + " if ... in shape:\n", + " index = shape.index(...)\n", + " shape = (\n", + " shape[:index]\n", + " + [None] * (len(array_shape) - len(shape) + 1)\n", + " + shape[index + 1 :]\n", + " )\n", + "\n", + " if len(shape) != len(array_shape):\n", + " raise ValueError(f\"{len(shape)=} != {len(array_shape)=}\")\n", + "\n", + " for i, (li, ri) in enumerate(zip(shape, array_shape)):\n", + " if li is None:\n", + " continue\n", + " if li != ri:\n", + " raise ValueError(f\"Size mismatch, {li} != {ri} at dimension {i}\")\n", + " return x\n", + "\n", + "\n", + "# any number of dimensions with firt dim=3 and last dim=6\n", + "shape = (3, ..., 6)\n", + "# dtype must be float32\n", + "dtype = jnp.float32\n", + "\n", + "validator = ArrayValidator(shape=shape, dtype=dtype)\n", + "\n", + "# convert to half precision from float32\n", + "converter = lambda x: x.astype(jnp.float16)\n", + "\n", + "\n", + "@sp.autoinit\n", + "class Tree(sp.TreeClass):\n", + " array: jax.Array = sp.field(on_setattr=[validator, converter])\n", + "\n", + "\n", + "x = jnp.ones([3, 1, 2, 6])\n", + "tree = Tree(array=x)\n", + "\n", + "\n", + "try:\n", + " y = jnp.ones([1, 1, 2, 3])\n", + " tree = Tree(array=y)\n", + "except ValueError as e:\n", + " print(e, \"\\n\")\n", + " # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=) for field=`array`:\n", + " # Dtype mismatch, array_dtype=dtype('float16') != self.dtype=\n", + "\n", + "try:\n", + " z = x.astype(jnp.float16)\n", + " tree = Tree(array=z)\n", + "except TypeError as e:\n", + " print(e)\n", + " # On applying ArrayValidator(shape=(3, Ellipsis, 6), dtype=) for field=`array`:\n", + " # Size mismatch, 3 != 1 at dimension 0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [4] Parameterization field\n", + "\n", + "In this example, field value is 
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [4] Parameterization field\n",
+    "\n",
+    "In this example, the field value is [parameterized](https://pytorch.org/tutorials/intermediate/parametrizations.html) using `on_getattr`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "\n",
+    "def symmetric(array: jax.Array) -> jax.Array:\n",
+    "    triangle = jnp.triu(array)  # upper triangle\n",
+    "    return triangle + triangle.transpose(-1, -2)\n",
+    "\n",
+    "\n",
+    "@sp.autoinit\n",
+    "class Tree(sp.TreeClass):\n",
+    "    symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])\n",
+    "\n",
+    "\n",
+    "tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))\n",
+    "print(tree.symmetric_matrix)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "dev-jax",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]intermediates.ipynb b/docs/notebooks/[recipes]intermediates.ipynb new file mode 100644 index 0000000..251d530 --- /dev/null +++ b/docs/notebooks/[recipes]intermediates.ipynb @@ -0,0 +1,130 @@ +{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# 🔧 Intermediates handling\n",
+    "\n",
+    "This notebook demonstrates how to capture the intermediate outputs of a model during inference. This is useful for debugging, understanding the model, and visualizing the model's internal representations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install sepes"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Capture intermediate values\n",
+    "\n",
+    "In this example, we will capture the intermediate values in a method by simply returning them as part of the output."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'b': 2.0, 'c': 4.0}"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "\n",
+    "\n",
+    "class Foo(sp.TreeClass):\n",
+    "    def __init__(self):\n",
+    "        self.a = 1.0\n",
+    "\n",
+    "    def __call__(self, x):\n",
+    "        capture = {}\n",
+    "        b = self.a + x\n",
+    "        capture[\"b\"] = b\n",
+    "        c = 2 * b\n",
+    "        capture[\"c\"] = c\n",
+    "        e = 4 * c\n",
+    "        return e, capture\n",
+    "\n",
+    "\n",
+    "foo = Foo()\n",
+    "\n",
+    "_, inter_values = foo(1.0)\n",
+    "inter_values"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Capture intermediate gradients\n",
+    "\n",
+    "In this example, we will capture the intermediate gradients in a method by 1) perturbing the desired value and 2) using `argnums` in `jax.grad` to compute the intermediate gradients.\n",
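+    "\n",
+    "Why this works: adding a zero-valued perturbation `p` at an intermediate value `b` leaves the forward pass unchanged, while the gradient with respect to `p` equals the gradient with respect to `b`. A schematic sketch (`g` and `h` stand for arbitrary sub-computations):\n",
+    "\n",
+    "```python\n",
+    "def f(x, p):\n",
+    "    b = g(x) + p  # p == 0.0 at evaluation\n",
+    "    return h(b)\n",
+    "\n",
+    "\n",
+    "jax.grad(f, argnums=1)(x, 0.0)  # equals dh/db evaluated at b = g(x)\n",
+    "```"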
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'b': Array(8., dtype=float32, weak_type=True),\n", + " 'c': Array(4., dtype=float32, weak_type=True)}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sepes as sp\n", + "import jax\n", + "\n", + "\n", + "class Foo(sp.TreeClass):\n", + " def __init__(self):\n", + " self.a = 1.0\n", + "\n", + " def __call__(self, x, perturb):\n", + " # pass in the perturbations as a pytree\n", + " b = self.a + x + perturb[\"b\"]\n", + " c = 2 * b + perturb[\"c\"]\n", + " e = 4 * c\n", + " return e\n", + "\n", + "\n", + "foo = Foo()\n", + "\n", + "# de/dc = 4\n", + "# de/db = de/dc * dc/db = 4 * 2 = 8\n", + "\n", + "# take gradient with respect to the perturbations pytree\n", + "# by setting `argnums=1` in `jax.grad`\n", + "inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))\n", + "inter_grads" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/[recipes]misc.ipynb b/docs/notebooks/[recipes]misc.ipynb new file mode 100644 index 0000000..607da61 --- /dev/null +++ b/docs/notebooks/[recipes]misc.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🗂️ Misc recipes" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sepes" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This section introduces some miscellaneous recipes that are not covered in the previous sections." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [1] Lazy layers.\n", + "In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `value_and_tree` to create a new instance with the added parameter." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Layer before param is set:\tLazyLinear(out_features=1)\n",
+      "Layer after param is set:\tLazyLinear(out_features=1, weight=[[1.]], bias=[0.])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Array([[1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.],\n",
+       "       [1.]], dtype=float32)"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "from typing import Any\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "\n",
+    "class LazyLinear(sp.TreeClass):\n",
+    "    def __init__(self, out_features: int):\n",
+    "        self.out_features = out_features\n",
+    "\n",
+    "    def param(self, name: str, value: Any):\n",
+    "        # return the value if it exists, otherwise set it and return it\n",
+    "        if name not in vars(self):\n",
+    "            setattr(self, name, value)\n",
+    "        return vars(self)[name]\n",
+    "\n",
+    "    def __call__(self, input: jax.Array) -> jax.Array:\n",
+    "        weight = self.param(\"weight\", jnp.ones((self.out_features, input.shape[-1])))\n",
+    "        bias = self.param(\"bias\", jnp.zeros((self.out_features,)))\n",
+    "        return input @ weight.T + bias\n",
+    "\n",
+    "\n",
+    "input = jnp.ones([10, 1])\n",
+    "\n",
+    "lazy = LazyLinear(out_features=1)\n",
+    "\n",
+    "print(f\"Layer before param is set:\\t{lazy}\")\n",
+    "\n",
+    "# `value_and_tree` executes any mutating method in a functional way\n",
+    "_, material = sp.value_and_tree(lambda layer: layer(input))(lazy)\n",
+    "\n",
+    "print(f\"Layer after param is set:\\t{material}\")\n",
+    "# subsequent calls will not set the parameters again\n",
+    "material(input)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [2] Regularization\n",
+    "\n",
+    "The following code showcases how to use the `at` functionality to select some leaves of a model, based on a boolean mask and/or a name condition, and apply weight regularization to them. Using `.at[...]`, the following can be achieved concisely:"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Boolean-based mask\n",
+    "\n",
+    "The entries of the arrays or leaves are selected based on a tree of the same structure, but with boolean (`True`/`False`) leaves. A `True` leaf marks a place where the operation is applied, while a `False` leaf indicates that the leaf should not be touched. A minimal selection sketch follows; the full regularization example comes after it."
+   ]
+  },
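+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of mask-based selection (assuming `.at[mask].get()` filters array entries by the boolean mask, as in the surgery recipes; the small `tree` here is illustrative):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "\n",
+    "tree = [jnp.array([-1.0, 2.0]), jnp.array([3.0, -4.0])]\n",
+    "# boolean mask with the same structure as `tree`\n",
+    "mask = jax.tree_util.tree_map(lambda x: x < 0, tree)\n",
+    "# keep only the entries marked `True` by the mask\n",
+    "sp.at(tree)[mask].get()"
+   ]
+  },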
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.8333335\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "import jax.numpy as jnp\n",
+    "import jax\n",
+    "\n",
+    "\n",
+    "class Net(sp.TreeClass):\n",
+    "    def __init__(self):\n",
+    "        self.weight = jnp.array([-1, -2, -3, 1, 2, 3])\n",
+    "        self.bias = jnp.array([-1, 1])\n",
+    "\n",
+    "\n",
+    "def negative_entries_l2_loss(net: Net):\n",
+    "    return (\n",
+    "        # select all positive array entries\n",
+    "        net.at[jax.tree_util.tree_map(lambda x: x > 0, net)]\n",
+    "        # set them to zero to exclude their loss\n",
+    "        .set(0)\n",
+    "        # select all leaves\n",
+    "        .at[...]\n",
+    "        # finally reduce with l2 loss\n",
+    "        .reduce(lambda x, y: x + jnp.mean(y**2), initializer=0)\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "net = Net()\n",
+    "print(negative_entries_l2_loss(net))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Name-based mask\n",
+    "\n",
+    "In this step, the mask is based on the path of the leaf."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "331.84155\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import jax.random as jr\n",
+    "\n",
+    "\n",
+    "class Linear(sp.TreeClass):\n",
+    "    def __init__(self, in_features: int, out_features: int, key: jax.Array):\n",
+    "        self.weight = jr.normal(key=key, shape=(out_features, in_features))\n",
+    "        self.bias = jnp.zeros((out_features,))\n",
+    "\n",
+    "\n",
+    "class Net(sp.TreeClass):\n",
+    "    def __init__(self, key: jax.Array) -> None:\n",
+    "        k1, k2, k3, k4 = jax.random.split(key, 4)\n",
+    "        self.linear1 = Linear(1, 20, key=k1)\n",
+    "        self.linear2 = Linear(20, 20, key=k2)\n",
+    "        self.linear3 = Linear(20, 20, key=k3)\n",
+    "        self.linear4 = Linear(20, 1, key=k4)\n",
+    "\n",
+    "\n",
+    "def linear_12_weight_l1_loss(net: Net):\n",
+    "    return (\n",
+    "        # select desired branches (linear1, linear2 in this example)\n",
+    "        # and the desired leaves (weight)\n",
+    "        net.at[\"linear1\", \"linear2\"][\"weight\"]\n",
+    "        # alternatively, regex can be used for the same functionality\n",
+    "        # >>> import re\n",
+    "        # >>> net.at[re.compile(\"linear[12]\")][\"weight\"]\n",
+    "        # finally apply l1 loss\n",
+    "        .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "net = Net(key=jr.PRNGKey(0))\n",
+    "print(linear_12_weight_l1_loss(net))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This recipe can then be included inside the loss function, for example\n",
+    "\n",
+    "``` python\n",
+    "\n",
+    "def loss_fnc(net, x, y):\n",
+    "    loss = ...  # e.g. a data-fit term\n",
+    "    loss += linear_12_weight_l1_loss(net)\n",
+    "    ...\n",
+    "```"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "dev-jax",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks/[recipes]sharing.ipynb b/docs/notebooks/[recipes]sharing.ipynb new file mode 100644 index 0000000..2649773 --- /dev/null +++ b/docs/notebooks/[recipes]sharing.ipynb @@ -0,0 +1,212 @@ +{
+ "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🪡 Sharing/Tie Weights\n", + "\n", + "Because sharing weights convert a pytree to a graph by pointing one leaf to another, a careful handling is needed to avoid breaking the _tree_ assumptions.\n", + "\n", + "In `sepes`, sharing/tie weights is done inside methods, this means, instead sharing the reference within `__init__` method, the reference is shared within the method of which the call is made.\n", + "\n", + "**From**\n", + "\n", + "```python\n", + "class TiedAutoEncoder:\n", + " def __init__(self, input_dim, hidden_dim):\n", + " self.encoder = Linear(input_dim, hidden_dim)\n", + " self.decoder = Linear(hidden_dim, input_dim)\n", + " self.decoder.weight = self.encoder.weight\n", + "\n", + " def __call__(self, x):\n", + " return self.decoder(self.encoder(x))\n", + "```\n", + "\n", + "**To** \n", + "\n", + "```python\n", + "class TiedAutoEncoder:\n", + " def __init__(self, input_dim, hidden_dim):\n", + " self.encoder = Linear(input_dim, hidden_dim)\n", + " self.decoder = Linear(hidden_dim, input_dim)\n", + " \n", + " def __call__(self, x):\n", + " self.decoder.weight = self.encoder.weight.T\n", + " return self.decoder(self.encoder(x))\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sepes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example a simple `AutoEncoder` with shared `weight` between the encode/decoder is demonstrated." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sepes as sp\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import functools as ft\n", + "\n", + "\n", + "def sharing(method):\n", + " # sharing simply copies the instance, executes the method, and returns the output\n", + " # **without modifying the original instance.**\n", + " @ft.wraps(method)\n", + " def wrapper(self, *args, **kwargs):\n", + " # `value_and_tree` executes any mutating method in a functional way\n", + " # by copying `self`, executing the method, and returning the new state\n", + " # along with the output.\n", + " output, _ = sp.value_and_tree(method)(self, *args, **kwargs)\n", + " return output\n", + "\n", + " return wrapper\n", + "\n", + "\n", + "class Linear(sp.TreeClass):\n", + " def __init__(self, in_features: int, out_features: int, key: jax.Array):\n", + " self.weight = jr.normal(key=key, shape=(out_features, in_features))\n", + " self.bias = jnp.zeros((out_features,))\n", + "\n", + " def __call__(self, input):\n", + " return input @ self.weight.T + self.bias\n", + "\n", + "\n", + "class AutoEncoder(sp.TreeClass):\n", + " def __init__(self, *, key: jax.Array):\n", + " k1, k2, k3, k4 = jr.split(key, 4)\n", + " self.enc1 = Linear(1, 10, key=k1)\n", + " self.enc2 = Linear(10, 20, key=k2)\n", + " self.dec2 = Linear(20, 10, key=k3)\n", + " self.dec1 = Linear(10, 1, key=k4)\n", + "\n", + " @sharing\n", + " def tied_call(self, input: jax.Array) -> jax.Array:\n", + " self.dec1.weight = self.enc1.weight.T\n", + " self.dec2.weight = self.enc2.weight.T\n", + " output = self.enc1(input)\n", + " output = self.enc2(output)\n", + " output = self.dec2(output)\n", + " output = self.dec1(output)\n", + " return output\n", + "\n", + " def non_tied_call(self, x):\n", + " output = self.enc1(x)\n", + " output = self.enc2(output)\n", + " output = self.dec2(output)\n", + " output = self.dec1(output)\n", + 
" return output" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])\n", + ") Linear(\n", + " weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])\n", + ")\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "@jax.grad\n", + "def tied_loss_func(net, x, y):\n", + " net = sp.tree_unmask(net)\n", + " return jnp.mean((jax.vmap(net.tied_call)(x) - y) ** 2)\n", + "\n", + "\n", + "tree = sp.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))\n", + "x = jnp.ones([10, 1]) + 0.0\n", + "y = jnp.ones([10, 1]) * 2.0\n", + "grads: AutoEncoder = tied_loss_func(tree, x, y)\n", + "# note that the shared weights have 0 gradient\n", + "print(repr(grads.dec1), repr(grads.dec2))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Linear(\n", + " weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " bias=f32[1](μ=622.29, σ=0.00, ∈[622.29,622.29])\n", + ") Linear(\n", + " weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", + " bias=f32[10](μ=-107.37, σ=706.50, ∈[-1561.75,949.41])\n", + ")\n" + ] + } + ], + "source": [ + "# check for non-tied call\n", + "@jax.jit\n", + "@jax.grad\n", + "def non_tied_loss_func(net, x, y):\n", + " net = sp.tree_unmask(net)\n", + " return jnp.mean((jax.vmap(net.non_tied_call)(x) - y) ** 2)\n", + "\n", + "\n", + "tree = sp.tree_mask(tree)\n", + "x = jnp.ones([10, 1]) + 0.0\n", + "y = jnp.ones([10, 1]) * 2.0\n", + "grads: AutoEncoder = tied_loss_func(tree, x, y)\n", + "\n", + "# note that the shared weights have non-zero gradients\n", + "print(repr(grads.dec1), repr(grads.dec2))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev-jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/[recipes]surgery.ipynb b/docs/notebooks/[recipes]surgery.ipynb new file mode 100644 index 0000000..7f63fed --- /dev/null +++ b/docs/notebooks/[recipes]surgery.ipynb @@ -0,0 +1,315 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ✂️ Surgery\n", + "\n", + "This notebook provides tree editing (surgery) recipes using `at`. `at` encapsules a pytree and provides a way to edit it in out-of-place manner. 
Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.\n", + "\n", + "```python\n", + "import sepes as sp\n", + "import re\n", + "tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])\n", + "where = re.compile(\"key_.*\") # select all keys starting with \"key_\"\n", + "func = lambda node: sum(map(abs, node)) # sum of absolute values\n", + "sp.at(tree)[where].apply(func)\n", + "# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sepes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Out-of-place editing\n", + "\n", + "Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n", + "pytree1 is pytree2 = False\n" + ] + } + ], + "source": [ + "import sepes as sp\n", + "\n", + "pytree1 = [1, [2, 3], 4]\n", + "pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n", + "print(f\"{pytree1=}, {pytree2=}\")\n", + "# even though pytree1 and pytree2 are the same, they are not the same object\n", + "# because pytree2 is a copy of pytree1\n", + "print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Matching keys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Match all\n", + "\n", + "Use `...` to match all keys. \n", + "\n", + "The following example applies `plus_one` function to all tree nodes. This is equivalent to `tree = tree_map(plus_one, tree)`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[2, [3, 4], 5]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "\n", + "pytree1 = [1, [2, 3], 4]\n", + "plus_one = lambda x: x + 1\n", + "pytree2 = sp.at(pytree1)[...].apply(plus_one)\n", + "pytree2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Integer indexing\n", + "\n", + "`at` can edit pytrees by integer paths." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, [100, 3], 4]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "\n", + "pytree1 = [1, [2, 3], 4]\n", + "pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n", + "pytree2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Named path indexing\n", + "`at` can edit pytrees by named paths." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "\n", + "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n", + "pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n", + "pytree2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regular expressions indexing\n", + "`at` can edit pytrees by regular expressions applied to keys." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "import re\n", + "\n", + "pytree1 = {\n", + " \"key_1\": 1,\n", + " \"key_2\": {\"key_3\": 3, \"key_4\": 4},\n", + " \"key_5\": 5,\n", + " \"key_6\": {\"key_7\": 7, \"key_8\": 8},\n", + "}\n", + "# from 1 - 5, set the value to 100\n", + "pattern = re.compile(r\"key_[1-5]\")\n", + "pytree2 = sp.at(pytree1)[pattern].set(100)\n", + "pytree2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Mask indexing\n", + "\n", + "`at` can edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. \n", + "\n", + "\n", + "The following example set all negative tree nodes to zero." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "import jax\n", + "\n", + "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n", + "# mask defines all desired entries to apply the function\n", + "mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n", + "pytree2 = sp.at(pytree1)[mask].set(0)\n", + "pytree2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Composition\n", + "\n", + "`at` can compose multiple keys, integer paths, named paths, regular expressions, and masks to edit the tree.\n", + "\n", + "The following example demonstrates how to apply a function to all nodes with:\n", + "- Name starting with \"key_\"\n", + "- Positive values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import sepes as sp\n", + "import re\n", + "import jax\n", + "\n", + "pytree1 = {\"key_1\": 1, \"key_2\": -2, \"value_1\": 1, \"value_2\": 2}\n", + "pattern = re.compile(r\"key_.*\")\n", + "mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)\n", + "pytree2 = sp.at(pytree1)[pattern][mask].set(100)\n", + "pytree2" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev-jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/notebooks/[recipes]transformations.ipynb b/docs/notebooks/[recipes]transformations.ipynb new file mode 100644 index 0000000..df053ec --- /dev/null +++ b/docs/notebooks/[recipes]transformations.ipynb @@ -0,0 +1,594 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 💫 Transformations" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This section introduces common function transformations that are used in conjunction with pytrees. Examples includes function transformation that wraps `jax` transforms or a function transformation that wraps `numpy`." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sepes" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [1] Broadcasting transformations\n", + "\n", + "Using `bcmap` to apply a function over pytree leaves with automatic broadcasting for scalar arguments. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `bcmap` + `numpy`\n", + "\n", + "In this recipe, `numpy` functions will operate directly on `TreeClass` instances." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tree(a=0, b=(0.0, 103.0), c=[104. 105. 106.])\n" + ] + } + ], + "source": [ + "import sepes as sp\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "@sp.leafwise # enable math operations on leaves\n", + "@sp.autoinit # generate __init__ from type annotations\n", + "class Tree(sp.TreeClass):\n", + " a: int = 1\n", + " b: tuple[float] = (2.0, 3.0)\n", + " c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n", + "\n", + "\n", + "tree = Tree()\n", + "\n", + "# make where work with arbitrary pytrees\n", + "tree_where = sp.bcmap(jnp.where)\n", + "# for values > 2, add 100, else set to 0\n", + "print(tree_where(tree > 2, tree + 100, 0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`bcmap` on pytrees with non-jaxtype\n", + "\n", + "In case the tree has some non-jaxtype leaves, The above will fail, but we can use `tree_mask`/`tree_unmask` to fix it" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bcmap fail '>' not supported between instances of 'str' and 'int'\n", + "Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 
106.], name=tree, func=(x))\n"
+     ]
+    }
+   ],
+   "source": [
+    "# in case the tree has some non-jaxtype leaves\n",
+    "# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it\n",
+    "import sepes as sp\n",
+    "import jax.numpy as jnp\n",
+    "from typing import Callable\n",
+    "\n",
+    "\n",
+    "@sp.leafwise  # enable math operations on leaves\n",
+    "@sp.autoinit  # generate __init__ from type annotations\n",
+    "class Tree(sp.TreeClass):\n",
+    "    a: float = 1.0\n",
+    "    b: tuple[float] = (2.0, 3.0)\n",
+    "    c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n",
+    "    name: str = \"tree\"  # non-jaxtype\n",
+    "    func: Callable = lambda x: x  # non-jaxtype\n",
+    "\n",
+    "\n",
+    "tree = Tree()\n",
+    "\n",
+    "try:\n",
+    "    # make where work with arbitrary pytrees with non-jaxtype leaves\n",
+    "    tree_where = sp.bcmap(jnp.where)\n",
+    "    # for values > 2, add 100, else set to 0\n",
+    "    print(tree_where(tree > 2, tree + 100, 0))\n",
+    "except TypeError as e:\n",
+    "    print(\"bcmap fail\", e)\n",
+    "    # now we can use `tree_mask`/`tree_unmask` to fix it\n",
+    "    masked_tree = sp.tree_mask(tree)  # mask non-jaxtype leaves\n",
+    "    masked_tree = tree_where(masked_tree > 2, masked_tree + 100, 0)\n",
+    "    unmasked_tree = sp.tree_unmask(masked_tree)\n",
+    "    print(unmasked_tree)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### `bcmap` + configs\n",
+    "\n",
+    "The next example shows how to use `sepes.bcmap` to loop over a configuration dictionary that defines the creation of simple linear layers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[Linear(\n",
+       "   weight=f32[1,1](μ=0.31, σ=0.00, ∈[0.31,0.31]), \n",
+       "   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+       " ),\n",
+       " Linear(\n",
+       "   weight=f32[2,1](μ=-1.27, σ=0.33, ∈[-1.59,-0.94]), \n",
+       "   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+       " ),\n",
+       " Linear(\n",
+       "   weight=f32[3,1](μ=0.24, σ=0.53, ∈[-0.48,0.77]), \n",
+       "   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+       " ),\n",
+       " Linear(\n",
+       "   weight=f32[4,1](μ=-0.28, σ=0.21, ∈[-0.64,-0.08]), \n",
+       "   bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n",
+       " )]"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import sepes as sp\n",
+    "import jax\n",
+    "\n",
+    "\n",
+    "class Linear(sp.TreeClass):\n",
+    "    def __init__(self, in_dim: int, out_dim: int, *, key: jax.Array):\n",
+    "        self.weight = jax.random.normal(key, (in_dim, out_dim))\n",
+    "        self.bias = jnp.zeros((out_dim,))\n",
+    "\n",
+    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
+    "        return x @ self.weight + self.bias\n",
+    "\n",
+    "\n",
+    "config = {\n",
+    "    # each layer gets a different input dimension\n",
+    "    \"in_dim\": [1, 2, 3, 4],\n",
+    "    # out_dim is broadcasted to all layers\n",
+    "    \"out_dim\": 1,\n",
+    "    # each layer gets a different key\n",
+    "    \"key\": list(jax.random.split(jax.random.PRNGKey(0), 4)),\n",
+    "}\n",
+    "\n",
+    "\n",
+    "# `bcmap` transforms a function that takes single inputs into a function that\n",
in case of a single input, the input is broadcasted\n", + "# to match the tree structure of the first argument\n", + "# (in our example is a list of 4 inputs)\n", + "\n", + "\n", + "@sp.bcmap\n", + "def build_layer(in_dim, out_dim, *, key: jax.Array):\n", + " return Linear(in_dim, out_dim, key=key)\n", + "\n", + "\n", + "build_layer(config[\"in_dim\"], config[\"out_dim\"], key=config[\"key\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## [2] Masked transformations\n", + "\n", + "As an alternative to using `sp.tree_unmask` on pytrees before calling the function -as seen throughout training examples and recipes- , another approach is to wrap a certain transformation - not pytrees - (e.g. `jit`) to be make the masking/unmasking automatic; however this apporach will incur more overhead than applying `sp.tree_unmask` before the function call.\n", + "\n", + "The following example demonstrate how to wrap `jit`, and `vmap`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import sepes as sp\n", + "import functools as ft\n", + "import jax\n", + "import jax.random as jr\n", + "import jax.numpy as jnp\n", + "from typing import Any\n", + "from typing import Any, Callable, TypeVar\n", + "from typing_extensions import ParamSpec\n", + "\n", + "T = TypeVar(\"T\")\n", + "P = ParamSpec(\"P\")\n", + "\n", + "\n", + "def automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n", + " \"\"\"Enable jax transformations to accept non-jax types. e.g. ``jax.grad``.\"\"\"\n", + " # works with functions that takes a function as input\n", + " # and returns a function as output e.g. `jax.grad`\n", + "\n", + " def out_transform(func, **transformation_kwargs):\n", + " @ft.partial(jax_transform, **transformation_kwargs)\n", + " def jax_boundary(*args, **kwargs):\n", + " args, kwargs = sp.tree_unmask((args, kwargs))\n", + " return sp.tree_mask(func(*args, **kwargs))\n", + "\n", + " @ft.wraps(func)\n", + " def outer_wrapper(*args, **kwargs):\n", + " args, kwargs = sp.tree_mask((args, kwargs))\n", + " output = jax_boundary(*args, **kwargs)\n", + " return sp.tree_unmask(output)\n", + "\n", + " return outer_wrapper\n", + "\n", + " return out_transform\n", + "\n", + "\n", + "def inline_automask(jax_transform: Callable[P, T]) -> Callable[P, T]:\n", + " \"\"\"Enable jax transformations to accept non-jax types e.g. ``jax.lax.scan``.\"\"\"\n", + " # works with functions that takes a function and arguments as input\n", + " # and returns jax types as output e.g. 
`jax.lax.scan`\n", + "\n", + " def outer_wrapper(func, *args, **kwargs):\n", + " args, kwargs = sp.tree_mask((args, kwargs))\n", + "\n", + " def func_masked(*args, **kwargs):\n", + " args, kwargs = sp.tree_unmask((args, kwargs))\n", + " return sp.tree_mask(func(*args, **kwargs))\n", + "\n", + " output = jax_transform(func_masked, *args, **kwargs)\n", + " return sp.tree_unmask(output)\n", + "\n", + " return outer_wrapper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `automask`(`jit`)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`jit error`: Argument 'layer' of type is not a valid JAX type\n", + "\n", + "Using automask:\n", + "forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)\n" + ] + } + ], + "source": [ + "x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n", + "\n", + "params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name=\"layer\")\n", + "\n", + "\n", + "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n", + " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n", + "\n", + "\n", + "try:\n", + " forward_jit = jax.jit(forward)\n", + " print(forward_jit(params, x))\n", + "except TypeError as e:\n", + " print(\"`jit error`:\", e)\n", + " # now with `automask` the function can accept non-jax types (e.g. string)\n", + " forward_jit = automask(jax.jit)(forward)\n", + " print(\"\\nUsing automask:\")\n", + " print(f\"{forward_jit(params, x)=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `automask`(`vmap`)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`vmap error`: Output from batched function 'layer' with type is not a valid JAX type\n", + "\n", + "Using automask:\n", + "dict(\n", + " name=layer, \n", + " w1=f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]), \n", + " w2=f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])\n", + ")\n" + ] + } + ], + "source": [ + "def make_params(key: jax.Array):\n", + " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n", + " return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n", + "\n", + "\n", + "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n", + "\n", + "try:\n", + " params = jax.vmap(make_params)(keys)\n", + " print(params)\n", + "except TypeError as e:\n", + " print(\"`vmap error`:\", e)\n", + " # now with `automask` the function can accept non-jax types (e.g. 
string)\n", + " params = automask(jax.vmap)(make_params)(keys)\n", + " print(\"\\nUsing automask:\")\n", + " print(sp.tree_repr(params))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `automask`(`make_jaxpr`)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`jax.make_jaxpr` failed: error=TypeError(\"Argument 'layer' of type is not a valid JAX type\")\n", + "\n", + "Using `automask:\n", + "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[4,5,5]\u001b[39m b\u001b[35m:f32[4,5,5]\u001b[39m c\u001b[35m:f32[10,5]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[10,4,5]\u001b[39m = dot_general[\n", + " dimension_numbers=(([1], [1]), ([], []))\n", + " preferred_element_type=float32\n", + " ] c a\n", + " e\u001b[35m:f32[4,10,5]\u001b[39m = transpose[permutation=(1, 0, 2)] d\n", + " f\u001b[35m:f32[4,10,5]\u001b[39m = tanh e\n", + " g\u001b[35m:f32[4,10,5]\u001b[39m = dot_general[\n", + " dimension_numbers=(([2], [1]), ([0], [0]))\n", + " preferred_element_type=float32\n", + " ] f b\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(g,) }\n" + ] + } + ], + "source": [ + "def make_params(key: jax.Array):\n", + " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n", + " return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n", + "\n", + "\n", + "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n", + "params = automask(jax.vmap)(make_params)(keys)\n", + "\n", + "\n", + "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n", + " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n", + "\n", + "\n", + "try:\n", + " jax.make_jaxpr(forward)(params, jnp.ones((10, 5)))\n", + "except TypeError as error:\n", + " print(f\"`jax.make_jaxpr` failed: {error=}\")\n", + " print(\"\\nUsing `automask:\")\n", + " print(automask(jax.make_jaxpr)(forward)(params, jnp.ones((10, 5))))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `inline_automask`(`scan`)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`jax.lax.scan` Failed: error=TypeError(\"Value 'layer' with type is not a valid JAX type\")\n", + "\n", + "Using `inline_automask`:\n", + "({'name': 'layer', 'w1': Array([[[0.6022109 , 0.06545091, 0.7613505 ]],\n", + "\n", + " [[0.33657324, 0.3744743 , 0.12130237]],\n", + "\n", + " [[0.51550114, 0.17686307, 0.6407058 ]],\n", + "\n", + " [[0.9101157 , 0.9690273 , 0.36771262]]], dtype=float32), 'w2': Array([[[0.2678218 ],\n", + " [0.3963921 ],\n", + " [0.7078583 ]],\n", + "\n", + " [[0.18808937],\n", + " [0.8475715 ],\n", + " [0.04241407]],\n", + "\n", + " [[0.74411213],\n", + " [0.6318574 ],\n", + " [0.58551705]],\n", + "\n", + " [[0.34456158],\n", + " [0.5347049 ],\n", + " [0.3992592 ]]], dtype=float32)}, Array([[[[0.62451595],\n", + " [0.3141999 ],\n", + " [0.59660065],\n", + " [0.7389193 ]],\n", + "\n", + " [[0.1839285 ],\n", + " [0.36948383],\n", + " [0.26153624],\n", + " [0.7847949 ]],\n", + "\n", + " [[0.81791794],\n", + " [0.53822035],\n", + " [0.7945141 ],\n", + " [1.2155443 ]],\n", + "\n", + " [[0.4768083 ],\n", + " [0.35134616],\n", + " [0.48272693],\n", + " [0.78913575]]]], dtype=float32))\n" + ] + } + ], + "source": [ + "def make_params(key: jax.Array):\n", + " k1, k2 = 
jax.random.split(key.astype(jnp.uint32))\n", + " return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name=\"layer\")\n", + "\n", + "\n", + "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n", + "params = automask(jax.vmap)(make_params)(keys)\n", + "\n", + "\n", + "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n", + " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n", + "\n", + "\n", + "def scan_func(params, input):\n", + " # layer contains non-jax types\n", + " output = forward(params, input)\n", + " return params, output\n", + "\n", + "\n", + "try:\n", + " jax.lax.scan(scan_func, params, jnp.ones((1, 1)))\n", + "except TypeError as error:\n", + " print(f\"`jax.lax.scan` Failed: {error=}\")\n", + " print(\"\\nUsing `inline_automask`:\")\n", + " print(inline_automask(jax.lax.scan)(scan_func, params, jnp.ones((1, 1))))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `inline_automask`(`eval_shape`)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`jax.eval_shape` Failed: error=TypeError(\"Argument 'layer' of type is not a valid JAX type\")\n", + "\n", + "Using `inline_automask`:\n", + "ShapeDtypeStruct(shape=(10, 1), dtype=float32)\n" + ] + } + ], + "source": [ + "def make_params(key: jax.Array):\n", + " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n", + " return dict(w1=jr.uniform(k1, (1, 3)), w2=jr.uniform(k2, (3, 1)), name=\"layer\")\n", + "\n", + "\n", + "params = make_params(jr.PRNGKey(0))\n", + "\n", + "\n", + "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n", + " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n", + "\n", + "\n", + "try:\n", + " jax.eval_shape(forward, params, jnp.ones((10, 1)))\n", + "except TypeError as error:\n", + " print(f\"`jax.eval_shape` Failed: {error=}\")\n", + " print(\"\\nUsing `inline_automask`:\")\n", + " print(inline_automask(jax.eval_shape)(forward, params, jnp.ones((10, 1))))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev-jax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tree_recipes.rst b/docs/tree_recipes.rst new file mode 100644 index 0000000..18da62e --- /dev/null +++ b/docs/tree_recipes.rst @@ -0,0 +1,17 @@ +🍳 Tree recipes +---------------- + +.. note:: + `sepes `_ for tree API is fully re-exported under the ``serket`` namespace. + `Check the docs `_ for full details. + +.. toctree:: + :caption: recipes + :maxdepth: 1 + + notebooks/[recipes]surgery + notebooks/[recipes]fields + notebooks/[recipes]intermediates + notebooks/[recipes]misc + notebooks/[recipes]sharing + notebooks/[recipes]transformations \ No newline at end of file