-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
846 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# ✂️ Surgery\n", | ||
"\n", | ||
"This notebook provides tree editing (surgery) recipes using `at`. `at` encapsules a pytree and provides a way to edit it in out-of-place manner. Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.\n", | ||
"\n", | ||
"```python\n", | ||
"import sepes as sp\n", | ||
"import re\n", | ||
"tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])\n", | ||
"where = re.compile(\"key_.*\") # select all keys starting with \"key_\"\n", | ||
"func = lambda node: sum(map(abs, node)) # sum of absolute values\n", | ||
"sp.at(tree)[where].apply(func)\n", | ||
"# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}\n", | ||
"```\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install sepes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Out-of-place editing\n", | ||
"\n", | ||
"Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n", | ||
"pytree1 is pytree2 = False\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"pytree1 = [1, [2, 3], 4]\n", | ||
"pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n", | ||
"print(f\"{pytree1=}, {pytree2=}\")\n", | ||
"# even though pytree1 and pytree2 are the same, they are not the same object\n", | ||
"# because pytree2 is a copy of pytree1\n", | ||
"print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Matching keys" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Match all\n", | ||
"\n", | ||
"Use `...` to match all keys. \n", | ||
"\n", | ||
"The following example applies `plus_one` function to all tree nodes. This is equivalent to `tree = tree_map(plus_one, tree)`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[2, [3, 4], 5]" | ||
] | ||
}, | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"pytree1 = [1, [2, 3], 4]\n", | ||
"plus_one = lambda x: x + 1\n", | ||
"pytree2 = sp.at(pytree1)[...].apply(plus_one)\n", | ||
"pytree2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Integer indexing\n", | ||
"\n", | ||
"`at` can edit pytrees by integer paths." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"[1, [100, 3], 4]" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"pytree1 = [1, [2, 3], 4]\n", | ||
"pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n", | ||
"pytree2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Named path indexing\n", | ||
"`at` can edit pytrees by named paths." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n", | ||
"pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n", | ||
"pytree2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Regular expressions indexing\n", | ||
"`at` can edit pytrees by regular expressions applied to keys." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"import re\n", | ||
"pytree1 = {\"key_1\": 1, \"key_2\": {\"key_3\": 3, \"key_4\": 4}, \"key_5\": 5, \"key_6\": {\"key_7\": 7, \"key_8\": 8}}\n", | ||
"# from 1 - 5, set the value to 100\n", | ||
"pattern = re.compile(r\"key_[1-5]\")\n", | ||
"pytree2 = sp.at(pytree1)[pattern].set(100)\n", | ||
"pytree2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Mask indexing\n", | ||
"\n", | ||
"`at` can edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. \n", | ||
"\n", | ||
"\n", | ||
"The following example set all negative tree nodes to zero." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}" | ||
] | ||
}, | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"import jax\n", | ||
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n", | ||
"# mask defines all desired entries to apply the function\n", | ||
"mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n", | ||
"pytree2 = sp.at(pytree1)[mask].set(0)\n", | ||
"pytree2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Composition\n", | ||
"\n", | ||
"`at` can compose multiple keys, integer paths, named paths, regular expressions, and masks to edit the tree.\n", | ||
"\n", | ||
"The following example demonstrates how to apply a function to all nodes with:\n", | ||
"- Name starting with \"key_\"\n", | ||
"- Positive values" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import sepes as sp\n", | ||
"import re\n", | ||
"import jax\n", | ||
"\n", | ||
"pytree1 = {\"key_1\": 1, \"key_2\": -2, \"value_1\": 1, \"value_2\": 2}\n", | ||
"pattern = re.compile(r\"key_.*\")\n", | ||
"mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)\n", | ||
"pytree2 = sp.at(pytree1)[pattern][mask].set(100)\n", | ||
"pytree2" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "dev-jax", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.