Skip to content

Commit

Permalink
value and tree (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 authored Dec 16, 2023
1 parent a17c19e commit 2d9d3ce
Show file tree
Hide file tree
Showing 16 changed files with 412 additions and 487 deletions.
156 changes: 81 additions & 75 deletions docs/notebooks/distributed_training.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions docs/notebooks/fields.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -156,7 +156,7 @@
" return self.frozen_a + x\n",
"\n",
"\n",
"tree = Tree(frozen_a=1) # 1 is non-jaxtype\n",
"tree = Tree(frozen_a=1)\n",
"\n",
"\n",
"@jax.jit\n",
Expand All @@ -166,7 +166,7 @@
"\n",
"print(f(tree, 1.0))\n",
"\n",
"print(jax.grad(f)(tree, 1.0))\n",
"print(jax.grad(f)(tree, 1.0)) # 1 is not differentiable\n",
"\n",
"# not visible to `jax.tree_util...`\n",
"print(jax.tree_util.tree_leaves(tree))"
Expand All @@ -181,7 +181,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -212,7 +212,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -282,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down
113 changes: 2 additions & 111 deletions docs/notebooks/function_transformations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@
"def build_layer(in_dim, out_dim, *, key: jax.Array):\n",
" return Linear(in_dim, out_dim, key=key)\n",
"\n",
"\n",
"build_layer(config[\"in_dim\"], config[\"out_dim\"], key=config[\"key\"])"
]
},
Expand Down Expand Up @@ -480,6 +481,7 @@
" output = layer(input)\n",
" return layer, output\n",
"\n",
"\n",
"layer = sk.nn.Linear(1, 1, key=jr.PRNGKey(0))\n",
"\n",
"try:\n",
Expand Down Expand Up @@ -530,117 +532,6 @@
" print(\"\\nUsing `inline_automask`:\")\n",
" print(inline_automask(jax.eval_shape)(func, linear, jnp.ones((10, 10))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [3] Functional transformations\n",
"\n",
"In this example, a function transformation variant of `AtIndexer[method_name](*args, **kwargs)` that enables writing methods in stateful manner but executed in a functional manner is demonstrated.\n",
"\n",
"Inplace mutation is not desired in case of working within `jax` transformations. see [jax docs](https://jax.readthedocs.io/en/latest/jax-101/07-state.html), as a solution `value_and_tree` is introduced to enable writing methods in stateful manner but executed in a functional manner, by first copying the tree, then applying the method, then returning the new tree and the output of the method."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before increment:\tcounter=StatefulCounter(counter=0)\n",
"After increment:\tcounter=StatefulCounter(counter=1)\n"
]
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"from typing import Any\n",
"from typing_extensions import TypeVar\n",
"import functools as ft\n",
"\n",
"T = TypeVar(\"T\")\n",
"\n",
"\n",
"@jax.tree_util.register_pytree_node_class\n",
"class StatefulCounter:\n",
" def __init__(self):\n",
" self.counter = 0\n",
"\n",
" def tree_flatten(self):\n",
" keys, values = zip(*vars(self).items())\n",
" return tuple(values), keys\n",
"\n",
" @classmethod\n",
" def tree_unflatten(cls, keys, values):\n",
" self = object.__new__(cls)\n",
" vars(self).update(zip(keys, values))\n",
" return self\n",
"\n",
" def __repr__(self):\n",
" params = \", \".join(f\"{k}={v}\" for k, v in vars(self).items())\n",
" return f\"{self.__class__.__name__}({params})\"\n",
"\n",
" def increment(self) -> None:\n",
" # stateful function\n",
" self.counter += 1\n",
"\n",
" def increment_with_no_side_effect(self) -> tuple[None, \"StatefulCounter\"]:\n",
" # apply increment on a copy of self\n",
" # and return the result and the new instance\n",
" return sk.AtIndexer(self)[\"increment\"]()\n",
"\n",
"\n",
"# stateful counter\n",
"counter = StatefulCounter()\n",
"print(f\"Before increment:\\t{counter=}\")\n",
"counter_value = counter.increment()\n",
"print(f\"After increment:\\t{counter=}\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Before increment:\tcounter=StatefulCounter(counter=0)\n",
"After increment:\tcounter=StatefulCounter(counter=0)\n",
"new_counter=StatefulCounter(counter=1)\n",
"counter.increment_with_no_side_effect()=(None, StatefulCounter(counter=1))\n"
]
}
],
"source": [
"def value_and_tree(method):\n",
" \"\"\"Transforms bound method to return both the value of the method and a copy of the tree.\"\"\"\n",
"\n",
" def wrapper(*args, **kwargs):\n",
" indexer = sk.AtIndexer(method.__self__)\n",
" return indexer[method.__name__](*args, **kwargs)\n",
"\n",
" return wrapper\n",
"\n",
"\n",
"# using the `value_and_tree`, the method will return both the value and a copy of the tree\n",
"# that contains the mutated state without mutating the original tree\n",
"\n",
"counter = StatefulCounter()\n",
"print(f\"Before increment:\\t{counter=}\")\n",
"method_output, new_counter = value_and_tree(counter.increment)()\n",
"print(f\"After increment:\\t{counter=}\")\n",
"print(f\"{new_counter=}\")\n",
"\n",
"# this equivalent to calling `increment_with_no_side_effect`\n",
"print(f\"{counter.increment_with_no_side_effect()=}\")"
]
}
],
"metadata": {
Expand Down
8 changes: 4 additions & 4 deletions docs/notebooks/layers_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -278,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -287,14 +287,14 @@
"(5, 10)"
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# materialize the layer with single image\n",
"_, layer = layer.at[\"__call__\"](x[0])\n",
"_, layer = sk.value_and_tree(lambda layer: layer(x[0]))(layer)\n",
"# apply on batch\n",
"y = jax.vmap(layer)(x)\n",
"y.shape"
Expand Down
12 changes: 6 additions & 6 deletions docs/notebooks/misc_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@
"source": [
"## [3] Data pipelines\n",
"\n",
"In this example, `AtIndexer` is used in similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines."
"In this example, `at` is used in similar fashion to [PyFunctional](https://github.com/EntilZha/PyFunctional) to work on general data pipelines."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -214,13 +214,13 @@
"25"
]
},
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from serket import AtIndexer\n",
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
Expand All @@ -243,9 +243,8 @@
" Transaction(\"paycheck\", -1000),\n",
"]\n",
"\n",
"indexer = AtIndexer(transactions)\n",
"where = jax.tree_map(lambda x: x.reason == \"food\", transactions)\n",
"food_cost = indexer[where].reduce(lambda x, y: x + y.amount, initializer=0)\n",
"food_cost = sk.at(transactions)[where].reduce(lambda x, y: x + y.amount, initializer=0)\n",
"food_cost"
]
},
Expand Down Expand Up @@ -443,6 +442,7 @@
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"\n",
"\n",
"class TiedAutoEncoder(sk.TreeClass):\n",
" def __init__(self, *, key: jax.Array):\n",
" k1, k2, k3, k4 = jr.split(key, 4)\n",
Expand Down
Loading

0 comments on commit 2d9d3ce

Please sign in to comment.