Skip to content

Commit

Permalink
Update misc_recipes.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Feb 22, 2024
1 parent 846a811 commit 3d5a75f
Showing 1 changed file with 86 additions and 117 deletions.
203 changes: 86 additions & 117 deletions docs/notebooks/misc_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/ASEM000/serket --quiet"
"# !pip install git+https://github.com/ASEM000/serket --quiet"
]
},
{
Expand Down Expand Up @@ -72,8 +72,6 @@
"from typing import Any\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"\n",
"\n",
"class LazyLinear(sk.TreeClass):\n",
" def __init__(self, out_features: int):\n",
Expand Down Expand Up @@ -110,9 +108,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## [2] Intermediates handling.\n",
"## [2] Intermediates handling."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "### Capture intermediate values\n",
"\n",
"This example shows how to capture specific intermediate values within each function call in this example."
"In this example, we will capture the intermediate values in a method by simply returning them as part of the output."
]
},
{
Expand All @@ -121,78 +126,93 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Intermediate values:\t\n",
" (Array([[0. ],\n",
" [0.5],\n",
" [1. ],\n",
" [1.5],\n",
" [2. ]], dtype=float32), Array([[-0.09999937],\n",
" [ 0.40000063],\n",
" [ 0.90000063],\n",
" [ 1.4000006 ],\n",
" [ 1.9000006 ]], dtype=float32))\n",
"\n",
"Final tree:\t\n",
" Tree(a=0.801189)\n"
]
"data": {
"text/plain": [
"{'b': 2.0, 'c': 4.0}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Any\n",
"import serket as sk\n",
"import jax\n",
"import optax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"@sk.autoinit\n",
"class Tree(sk.TreeClass):\n",
" a: float = 1.0\n",
"class Foo(sk.TreeClass):\n",
" def __init__(self):\n",
" self.a = 1.0\n",
"\n",
" def __call__(self, x: jax.Array, intermediate: tuple[Any, ...]):\n",
" x = x + self.a\n",
" # store intermediate variables\n",
" return x, intermediate + (x,)\n",
" def __call__(self, x):\n",
" capture = {}\n",
" b = self.a + x\n",
" capture[\"b\"] = b\n",
" c = 2 * b\n",
" capture[\"c\"] = c\n",
" e = 4 * c\n",
" return e, capture\n",
"\n",
"\n",
"def loss_func(tree: Tree, x: jax.Array, y: jax.Array, intermediate: tuple[Any, ...]):\n",
" ypred, intermediate = tree(x, intermediate)\n",
" loss = jnp.mean((ypred - y) ** 2)\n",
" return loss, intermediate\n",
"foo = Foo()\n",
"\n",
"_, inter_values = foo(1.0)\n",
"inter_values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Capture intermediate gradients\n",
"\n",
"@jax.jit\n",
"def train_step(\n",
" tree: Tree,\n",
" optim_state: optax.OptState,\n",
" x: jax.Array,\n",
" y: jax.Array,\n",
" intermediate: tuple[Any, ...],\n",
"):\n",
" grads, intermediate = jax.grad(loss_func, has_aux=True)(tree, x, y, intermediate)\n",
" updates, optim_state = optim.update(grads, optim_state)\n",
" tree = optax.apply_updates(tree, updates)\n",
" return tree, optim_state, intermediate\n",
    "In this example, we will capture the intermediate gradients in a method by 1) perturbing the desired value and 2) using `argnums` in `jax.grad` to compute the intermediate gradients."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'b': Array(8., dtype=float32, weak_type=True),\n",
" 'c': Array(4., dtype=float32, weak_type=True)}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
"tree = Tree()\n",
"optim = optax.adam(1e-1)\n",
"optim_state = optim.init(tree)\n",
"class Foo(sk.TreeClass):\n",
" def __init__(self):\n",
" self.a = 1.0\n",
"\n",
"x = jnp.linspace(-1, 1, 5)[:, None]\n",
"y = x**2\n",
" def __call__(self, x, perturb):\n",
" # pass in the perturbations as a pytree\n",
" b = self.a + x + perturb[\"b\"]\n",
" c = 2 * b + perturb[\"c\"]\n",
" e = 4 * c\n",
" return e\n",
"\n",
"intermediate = ()\n",
"\n",
"for i in range(2):\n",
" tree, optim_state, intermediate = train_step(tree, optim_state, x, y, intermediate)\n",
"foo = Foo()\n",
"\n",
"# de/dc = 4\n",
"# de/db = de/dc * dc/db = 4 * 2 = 8\n",
"\n",
"print(\"Intermediate values:\\t\\n\", intermediate)\n",
"print(\"\\nFinal tree:\\t\\n\", tree)"
"# take gradient with respect to the perturbations pytree\n",
"# by setting `argnums=1` in `jax.grad`\n",
"inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))\n",
"inter_grads"
]
},
{
Expand All @@ -206,7 +226,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -215,7 +235,7 @@
"25"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -269,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -320,7 +340,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -348,7 +368,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -427,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -477,7 +497,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -524,7 +544,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -570,57 +590,6 @@
"# note that the shared weights have non-zero gradients\n",
"print(repr(grads.dec1), repr(grads.dec2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [6] Extract intermediate gradients"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'b': Array(8., dtype=float32, weak_type=True),\n",
" 'c': Array(4., dtype=float32, weak_type=True)}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
"class Foo(sk.TreeClass):\n",
" def __init__(self):\n",
" self.a = 1.0\n",
"\n",
" def __call__(self, x, perturb):\n",
" b = self.a + x + perturb[\"b\"]\n",
" c = 2 * b + perturb[\"c\"]\n",
" e = 4 * c\n",
" return e\n",
"\n",
"\n",
"foo = Foo()\n",
"\n",
"# de/dc = 4\n",
"# de/db = de/dc * dc/db = 4 * 2 = 8\n",
"\n",
"# take gradient with respect to the perturbations pytree\n",
"# by setting `argnums=1` in `jax.grad`\n",
"inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))\n",
"inter_grads"
]
}
],
"metadata": {
Expand Down

0 comments on commit 3d5a75f

Please sign in to comment.