Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

value and tree #96

Merged
merged 3 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 81 additions & 75 deletions docs/notebooks/distributed_training.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions docs/notebooks/fields.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -156,7 +156,7 @@
" return self.frozen_a + x\n",
"\n",
"\n",
"tree = Tree(frozen_a=1) # 1 is non-jaxtype\n",
"tree = Tree(frozen_a=1)\n",
"\n",
"\n",
"@jax.jit\n",
Expand All @@ -166,7 +166,7 @@
"\n",
"print(f(tree, 1.0))\n",
"\n",
"print(jax.grad(f)(tree, 1.0))\n",
"print(jax.grad(f)(tree, 1.0)) # 1 is not differentiable\n",
"\n",
"# not visible to `jax.tree_util...`\n",
"print(jax.tree_util.tree_leaves(tree))"
Expand All @@ -181,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -212,7 +212,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -282,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down
113 changes: 2 additions & 111 deletions docs/notebooks/function_transformations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@
"def build_layer(in_dim, out_dim, *, key: jax.Array):\n",
" return Linear(in_dim, out_dim, key=key)\n",
"\n",
"\n",
"build_layer(config[\"in_dim\"], config[\"out_dim\"], key=config[\"key\"])"
]
},
Expand Down Expand Up @@ -480,6 +481,7 @@
" output = layer(input)\n",
" return layer, output\n",
"\n",
"\n",
"layer = sk.nn.Linear(1, 1, key=jr.PRNGKey(0))\n",
"\n",
"try:\n",
Expand Down Expand Up @@ -530,117 +532,6 @@
" print(\"\\nUsing `inline_automask`:\")\n",
" print(inline_automask(jax.eval_shape)(func, linear, jnp.ones((10, 10))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [3] Functional transformations\n",
"\n",
"In this example, a function transformation variant of `AtIndexer[method_name](*args, **kwargs)` that enables writing methods in stateful manner but executed in a functional manner is demonstrated.\n",
"\n",
"Inplace mutation is not desired in case of working within `jax` transformations. see [jax docs](https://jax.readthedocs.io/en/latest/jax-101/07-state.html), as a solution `value_and_tree` is introduced to enable writing methods in stateful manner but executed in a functional manner, by first copying the tree, then applying the method, then returning the new tree and the output of the method."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before increment:\tcounter=StatefulCounter(counter=0)\n",
"After increment:\tcounter=StatefulCounter(counter=1)\n"
]
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"from typing import Any\n",
"from typing_extensions import TypeVar\n",
"import functools as ft\n",
"\n",
"T = TypeVar(\"T\")\n",
"\n",
"\n",
"@jax.tree_util.register_pytree_node_class\n",
"class StatefulCounter:\n",
" def __init__(self):\n",
" self.counter = 0\n",
"\n",
" def tree_flatten(self):\n",
" keys, values = zip(*vars(self).items())\n",
" return tuple(values), keys\n",
"\n",
" @classmethod\n",
" def tree_unflatten(cls, keys, values):\n",
" self = object.__new__(cls)\n",
" vars(self).update(zip(keys, values))\n",
" return self\n",
"\n",
" def __repr__(self):\n",
" params = \", \".join(f\"{k}={v}\" for k, v in vars(self).items())\n",
" return f\"{self.__class__.__name__}({params})\"\n",
"\n",
" def increment(self) -> None:\n",
" # stateful function\n",
" self.counter += 1\n",
"\n",
" def increment_with_no_side_effect(self) -> tuple[None, \"StatefulCounter\"]:\n",
" # apply increment on a copy of self\n",
" # and return the result and the new instance\n",
" return sk.AtIndexer(self)[\"increment\"]()\n",
"\n",
"\n",
"# stateful counter\n",
"counter = StatefulCounter()\n",
"print(f\"Before increment:\\t{counter=}\")\n",
"counter_value = counter.increment()\n",
"print(f\"After increment:\\t{counter=}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before increment:\tcounter=StatefulCounter(counter=0)\n",
"After increment:\tcounter=StatefulCounter(counter=0)\n",
"new_counter=StatefulCounter(counter=1)\n",
"counter.increment_with_no_side_effect()=(None, StatefulCounter(counter=1))\n"
]
}
],
"source": [
"def value_and_tree(method):\n",
" \"\"\"Transforms bound method to return both the value of the method and a copy of the tree.\"\"\"\n",
"\n",
" def wrapper(*args, **kwargs):\n",
" indexer = sk.AtIndexer(method.__self__)\n",
" return indexer[method.__name__](*args, **kwargs)\n",
"\n",
" return wrapper\n",
"\n",
"\n",
"# using the `value_and_tree`, the method will return both the value and a copy of the tree\n",
"# that contains the mutated state without mutating the original tree\n",
"\n",
"counter = StatefulCounter()\n",
"print(f\"Before increment:\\t{counter=}\")\n",
"method_output, new_counter = value_and_tree(counter.increment)()\n",
"print(f\"After increment:\\t{counter=}\")\n",
"print(f\"{new_counter=}\")\n",
"\n",
"# this equivalent to calling `increment_with_no_side_effect`\n",
"print(f\"{counter.increment_with_no_side_effect()=}\")"
]
}
],
"metadata": {
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/layers_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -278,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -287,14 +287,14 @@
"(5, 10)"
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# materialize the layer with single image\n",
"_, layer = layer.at[\"__call__\"](x[0])\n",
"_, layer = sk.value_and_tree(lambda layer: layer(x[0]))(layer)\n",
"# apply on batch\n",
"y = jax.vmap(layer)(x)\n",
"y.shape"
Expand Down
12 changes: 6 additions & 6 deletions docs/notebooks/misc_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@
"source": [
"## [3] Data pipelines\n",
"\n",
"In this example, `AtIndexer` is used in similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines."
"In this example, `at` is used in similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -214,13 +214,13 @@
"25"
]
},
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from serket import AtIndexer\n",
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
Expand All @@ -243,9 +243,8 @@
" Transaction(\"paycheck\", -1000),\n",
"]\n",
"\n",
"indexer = AtIndexer(transactions)\n",
"where = jax.tree_map(lambda x: x.reason == \"food\", transactions)\n",
"food_cost = indexer[where].reduce(lambda x, y: x + y.amount, initializer=0)\n",
"food_cost = sk.at(transactions)[where].reduce(lambda x, y: x + y.amount, initializer=0)\n",
"food_cost"
]
},
Expand Down Expand Up @@ -443,6 +442,7 @@
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"\n",
"\n",
"class TiedAutoEncoder(sk.TreeClass):\n",
" def __init__(self, *, key: jax.Array):\n",
" k1, k2, k3, k4 = jr.split(key, 4)\n",
Expand Down
Loading
Loading