diff --git a/docs/notebooks/misc_recipes.ipynb b/docs/notebooks/misc_recipes.ipynb index a58a6d5..7f5856b 100644 --- a/docs/notebooks/misc_recipes.ipynb +++ b/docs/notebooks/misc_recipes.ipynb @@ -14,7 +14,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install git+https://github.com/ASEM000/serket --quiet" + "# !pip install git+https://github.com/ASEM000/serket --quiet" ] }, { @@ -36,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -62,7 +62,7 @@ " [1.]], dtype=float32)" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -72,8 +72,6 @@ "from typing import Any\n", "import jax\n", "import jax.numpy as jnp\n", - "import jax.random as jr\n", - "\n", "\n", "class LazyLinear(sk.TreeClass):\n", " def __init__(self, out_features: int):\n", @@ -110,89 +108,111 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## [2] Intermediates handling.\n", + "## [2] Intermediates handling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Capture intermediate values. \n", "\n", - "This example shows how to capture specific intermediate values within each function call in this example." + "In this example, we will capture the intermediate values in a method by simply returning them as part of the output." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Intermediate values:\t\n", - " (Array([[0. ],\n", - " [0.5],\n", - " [1. ],\n", - " [1.5],\n", - " [2. ]], dtype=float32), Array([[-0.09999937],\n", - " [ 0.40000063],\n", - " [ 0.90000063],\n", - " [ 1.4000006 ],\n", - " [ 1.9000006 ]], dtype=float32))\n", - "\n", - "Final tree:\t\n", - " Tree(a=0.801189)\n" - ] + "data": { + "text/plain": [ + "{'b': 2.0, 'c': 4.0}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "from typing import Any\n", "import serket as sk\n", "import jax\n", - "import optax\n", - "import jax.numpy as jnp\n", "\n", "\n", - "@sk.autoinit\n", - "class Tree(sk.TreeClass):\n", - " a: float = 1.0\n", + "class Foo(sk.TreeClass):\n", + " def __init__(self):\n", + " self.a = 1.0\n", "\n", - " def __call__(self, x: jax.Array, intermediate: tuple[Any, ...]):\n", - " x = x + self.a\n", - " # store intermediate variables\n", - " return x, intermediate + (x,)\n", + " def __call__(self, x):\n", + " capture = {}\n", + " b = self.a + x\n", + " capture[\"b\"] = b\n", + " c = 2 * b\n", + " capture[\"c\"] = c\n", + " e = 4 * c\n", + " return e, capture\n", "\n", "\n", - "def loss_func(tree: Tree, x: jax.Array, y: jax.Array, intermediate: tuple[Any, ...]):\n", - " ypred, intermediate = tree(x, intermediate)\n", - " loss = jnp.mean((ypred - y) ** 2)\n", - " return loss, intermediate\n", + "foo = Foo()\n", "\n", + "_, inter_values = foo(1.0)\n", + "inter_values" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Capture intermediate gradients\n", "\n", - "@jax.jit\n", - "def train_step(\n", - " tree: Tree,\n", - " optim_state: optax.OptState,\n", - " x: jax.Array,\n", - " y: jax.Array,\n", - " intermediate: tuple[Any, ...],\n", - "):\n", - " grads, intermediate = jax.grad(loss_func, has_aux=True)(tree, x, y, intermediate)\n", - " updates, optim_state = optim.update(grads, optim_state)\n", - " tree = optax.apply_updates(tree, updates)\n", - " return tree, optim_state, intermediate\n", + "In 
this example, we will capture the intermediate gradients in a method by 1) adding a zero-valued perturbation to each intermediate value of interest and 2) using `argnums` in `jax.grad` to differentiate with respect to the perturbations." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'b': Array(8., dtype=float32, weak_type=True),\n", " 'c': Array(4., dtype=float32, weak_type=True)}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import serket as sk\n", "import jax\n", "\n", "\n", "class Foo(sk.TreeClass):\n", "    def __init__(self):\n", "        self.a = 1.0\n", "\n", "    def __call__(self, x, perturb):\n", "        # pass in the perturbations as a pytree\n", "        b = self.a + x + perturb[\"b\"]\n", "        c = 2 * b + perturb[\"c\"]\n", "        e = 4 * c\n", "        return e\n", "\n", "\n", "foo = Foo()\n", "\n", "# de/dc = 4\n", "# de/db = de/dc * dc/db = 4 * 2 = 8\n", "\n", "# take gradient with respect to the perturbations pytree\n", "# by setting `argnums=1` in `jax.grad`\n", "inter_grads = jax.grad(foo, argnums=1)(1.0, dict(b=0.0, c=0.0))\n", "inter_grads" ] }, { @@ -206,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -215,7 +235,7 @@ "25" ] }, - "execution_count": 3, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], @@ -269,7 +289,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -320,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -348,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -420,23 +440,16 @@ "source": [ "## [5] Sharing/Tie Weights\n", "\n", - "In this example a simple `AutoEncoder` with shared `weight` between the encode/decoder is demonstrated." + "In this example, a simple `AutoEncoder` with a `weight` shared between the encoder and decoder is demonstrated.\n", + "\n", + "The key idea here is that weight sharing takes effect only within methods and does not extend beyond that scope. This limited scope is designed to comply with `jax`'s functional requirements."
] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10\n", - "200\n" - ] - } - ], + "outputs": [], "source": [ "import serket as sk\n", "import jax\n", @@ -444,7 +457,7 @@ "import jax.random as jr\n", "\n", "\n", - "class TiedAutoEncoder(sk.TreeClass):\n", + "class AutoEncoder(sk.TreeClass):\n", "    def __init__(self, *, key: jax.Array):\n", "        k1, k2, k3, k4 = jr.split(key, 4)\n", "        self.enc1 = sk.nn.Linear(1, 10, key=k1)\n", @@ -472,7 +485,6 @@ "        return output\n", "\n", "    def non_tied_call(self, x):\n", - "        # non-tied call\n", "        output = self.enc1(x)\n", "        output = self.enc2(output)\n", "        output = self.dec2(output)\n", @@ -480,12 +492,44 @@ "        return output\n", "\n", "\n", - "tree = sk.tree_mask(TiedAutoEncoder(key=jr.PRNGKey(0)))\n", - "\n", - "\n", + "tree = sk.tree_mask(AutoEncoder(key=jr.PRNGKey(0)))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(\n", "  in_features=(#10), \n", "  out_features=#1, \n", "  in_axis=(#-1), \n", "  out_axis=#-1, \n", "  weight_init=#glorot_uniform, \n", "  bias_init=#zeros, \n", "  weight=f32[1,10](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", "  bias=f32[1](μ=0.35, σ=0.00, ∈[0.35,0.35])\n", ") Linear(\n", "  in_features=(#20), \n", "  out_features=#10, \n", "  in_axis=(#-1), \n", "  out_axis=#-1, \n", "  weight_init=#glorot_uniform, \n", "  bias_init=#zeros, \n", "  weight=f32[10,20](μ=0.00, σ=0.00, ∈[0.00,0.00]), \n", "  bias=f32[10](μ=0.11, σ=0.09, ∈[-0.03,0.24])\n", ")\n" ] } ], "source": [ "@jax.jit\n", "@jax.grad\n", - "def loss_func(net, x, y):\n", + "def tied_loss_func(net, x, y):\n", "    net = sk.tree_unmask(net)\n", "    return jnp.mean((jax.vmap(net.tied_call)(x) - y) ** 2)\n", "\n", @@ -493,18 +537,47 @@ "tree = sk.tree_mask(tree)\n", "x = jnp.ones([10, 1]) + 0.0\n", "y = jnp.ones([10, 1]) * 2.0\n", - "grads: TiedAutoEncoder = loss_func(tree, x, y)\n", - "\n", - "\n", - "# check that gradients are zero for tied weights (dec1.weight, dec2.weight)\n", - "assert jnp.count_nonzero(grads.dec1.weight) == 0\n", - "assert jnp.count_nonzero(grads.dec2.weight) == 0\n", - "\n", - "\n", + "grads: AutoEncoder = tied_loss_func(tree, x, y)\n", + "# note that the shared weights have zero gradients\n", + "print(repr(grads.dec1), repr(grads.dec2))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(\n", "  in_features=(#10), \n", "  out_features=#1, \n", "  in_axis=(#-1), \n", "  out_axis=#-1, \n", "  weight_init=#glorot_uniform, \n", "  bias_init=#zeros, \n", "  weight=f32[1,10](μ=-0.12, σ=1.01, ∈[-1.18,1.63]), \n", "  bias=f32[1](μ=-2.74, σ=0.00, ∈[-2.74,-2.74])\n", ") Linear(\n", "  in_features=(#20), \n", "  out_features=#10, \n", "  in_axis=(#-1), \n", "  out_axis=#-1, \n", "  weight_init=#glorot_uniform, \n", "  bias_init=#zeros, \n", "  weight=f32[10,20](μ=-0.00, σ=0.35, ∈[-1.59,1.02]), \n", "  bias=f32[10](μ=-0.88, σ=0.57, ∈[-1.65,0.07])\n", ")\n" ] } ], "source": [ "# check for non-tied call\n", "@jax.jit\n", "@jax.grad\n", - "def loss_func(net, x, y):\n", + "def non_tied_loss_func(net, x, y):\n", "    net = sk.tree_unmask(net)\n", "    return jnp.mean((jax.vmap(net.non_tied_call)(x) - y) ** 2)\n", "\n", @@ -512,11 +585,10 @@ "tree = sk.tree_mask(tree)\n", "x = jnp.ones([10,
1]) + 0.0\n", "y = jnp.ones([10, 1]) * 2.0\n", - "grads: TiedAutoEncoder = loss_func(tree, x, y)\n", + "grads: AutoEncoder = non_tied_loss_func(tree, x, y)\n", "\n", - "# check non-zero gradients for the decoder weights\n", - "print(jnp.count_nonzero(grads.dec1.weight))\n", - "print(jnp.count_nonzero(grads.dec2.weight))" + "# note that without tying, the decoder weights have non-zero gradients\n", + "print(repr(grads.dec1), repr(grads.dec2))" ] } ],