Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 6, 2024
1 parent 6142118 commit b205108
Show file tree
Hide file tree
Showing 6 changed files with 846 additions and 68 deletions.
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ Install from pip::
:caption: 📖 Guides
:maxdepth: 1

recipes
notebooks/[guides]surgery
notebooks/[guides]optimlib

.. toctree::
:caption: API Documentation
Expand Down
527 changes: 527 additions & 0 deletions docs/notebooks/[guides]optimlib.ipynb

Large diffs are not rendered by default.

304 changes: 304 additions & 0 deletions docs/notebooks/[guides]surgery.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ✂️ Surgery\n",
"\n",
"This notebook provides tree editing (surgery) recipes using `at`. `at` encapsules a pytree and provides a way to edit it in out-of-place manner. Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.\n",
"\n",
"```python\n",
"import sepes as sp\n",
"import re\n",
"tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])\n",
"where = re.compile(\"key_.*\") # select all keys starting with \"key_\"\n",
"func = lambda node: sum(map(abs, node)) # sum of absolute values\n",
"sp.at(tree)[where].apply(func)\n",
"# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install sepes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Out-of-place editing\n",
"\n",
"Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n",
"pytree1 is pytree2 = False\n"
]
}
],
"source": [
"import sepes as sp\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n",
"print(f\"{pytree1=}, {pytree2=}\")\n",
"# even though pytree1 and pytree2 are the same, they are not the same object\n",
"# because pytree2 is a copy of pytree1\n",
"print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Matching keys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Match all\n",
"\n",
"Use `...` to match all keys. \n",
"\n",
"The following example applies `plus_one` function to all tree nodes. This is equivalent to `tree = tree_map(plus_one, tree)`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2, [3, 4], 5]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"pytree1 = [1, [2, 3], 4]\n",
"plus_one = lambda x: x + 1\n",
"pytree2 = sp.at(pytree1)[...].apply(plus_one)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Integer indexing\n",
"\n",
"`at` can edit pytrees by integer paths."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, [100, 3], 4]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Named path indexing\n",
"`at` can edit pytrees by named paths."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n",
"pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expressions indexing\n",
"`at` can edit pytrees by regular expressions applied to keys."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import re\n",
"pytree1 = {\"key_1\": 1, \"key_2\": {\"key_3\": 3, \"key_4\": 4}, \"key_5\": 5, \"key_6\": {\"key_7\": 7, \"key_8\": 8}}\n",
"# from 1 - 5, set the value to 100\n",
"pattern = re.compile(r\"key_[1-5]\")\n",
"pytree2 = sp.at(pytree1)[pattern].set(100)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mask indexing\n",
"\n",
"`at` can edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. \n",
"\n",
"\n",
"The following example set all negative tree nodes to zero."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import jax\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n",
"# mask defines all desired entries to apply the function\n",
"mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n",
"pytree2 = sp.at(pytree1)[mask].set(0)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Composition\n",
"\n",
"`at` can compose multiple keys, integer paths, named paths, regular expressions, and masks to edit the tree.\n",
"\n",
"The following example demonstrates how to apply a function to all nodes with:\n",
"- Name starting with \"key_\"\n",
"- Positive values"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import re\n",
"import jax\n",
"\n",
"pytree1 = {\"key_1\": 1, \"key_2\": -2, \"value_1\": 1, \"value_2\": 2}\n",
"pattern = re.compile(r\"key_.*\")\n",
"mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)\n",
"pytree2 = sp.at(pytree1)[pattern][mask].set(100)\n",
"pytree2"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev-jax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
22 changes: 11 additions & 11 deletions docs/notebooks/[recipes]fields.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"\n",
"The standard way to mark an array as a buffer (e.g. non-trainable) is to write something like this:\n",
"```python\n",
"class Tree(sk.TreeClass):\n",
"class Tree(sp.TreeClass):\n",
" def __init__(self, buffer: jax.Array):\n",
" self.buffer = buffer\n",
"\n",
Expand All @@ -54,7 +54,7 @@
"However, if you access this buffer from other methods, then another `jax.lax.stop_gradient` should be used and written inside all the methods:\n",
"\n",
"```python\n",
"class Tree(sk.TreeClass):\n",
"class Tree(sp.TreeClass):\n",
" def method_1(self, x: jax.Array) -> jax.Array:\n",
" return x + jax.lax.stop_gradient(self.buffer)\n",
" .\n",
Expand Down Expand Up @@ -175,12 +175,12 @@
"metadata": {},
"outputs": [],
"source": [
"import serket as sk\n",
"import sepes as sp\n",
"import jax\n",
"\n",
"\n",
"def mask_field(**kwargs):\n",
" return sk.field(\n",
" return sp.field(\n",
" # un mask when the value is accessed\n",
" on_getattr=[lambda x: sp.tree_unmask(x, cond=lambda node: True)],\n",
" # mask when the value is set\n",
Expand Down Expand Up @@ -223,8 +223,8 @@
}
],
"source": [
"@sk.autoinit\n",
"class Tree(sk.TreeClass):\n",
"@sp.autoinit\n",
"class Tree(sp.TreeClass):\n",
" training_mode: str # <- will throw error with jax transformations.\n",
" alpha: float\n",
"\n",
Expand Down Expand Up @@ -256,8 +256,8 @@
"metadata": {},
"outputs": [],
"source": [
"@sk.autoinit\n",
"class Tree(sk.TreeClass):\n",
"@sp.autoinit\n",
"class Tree(sp.TreeClass):\n",
" training_mode: str = mask_field() # hide the field from jax transformations\n",
" alpha: float\n",
"\n",
Expand Down Expand Up @@ -303,7 +303,7 @@
"\n",
"\n",
"# you can use any function\n",
"@sk.autoinit\n",
"@sp.autoinit\n",
"class Range(sp.TreeClass):\n",
" min: int | float = -float(\"inf\")\n",
" max: int | float = float(\"inf\")\n",
Expand All @@ -314,7 +314,7 @@
" return x\n",
"\n",
"\n",
"@sk.autoinit\n",
"@sp.autoinit\n",
"class IsInstance(sp.TreeClass):\n",
" klass: type | tuple[type, ...]\n",
"\n",
Expand All @@ -324,7 +324,7 @@
" return x\n",
"\n",
"\n",
"@sk.autoinit\n",
"@sp.autoinit\n",
"class Foo(sp.TreeClass):\n",
" # allow in_dim to be an integer between [1,100]\n",
" in_dim: int = sp.field(on_setattr=[IsInstance(int), Range(1, 100)])\n",
Expand Down
Loading

0 comments on commit b205108

Please sign in to comment.