Skip to content

Commit

Permalink
docs edit (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 26, 2023
1 parent 4ae5d6c commit 9644ecd
Show file tree
Hide file tree
Showing 3 changed files with 468 additions and 188 deletions.
160 changes: 107 additions & 53 deletions docs/notebooks/common_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@
"class Tree(sk.TreeClass):\n",
" def method_1(self, x: jax.Array) -> jax.Array:\n",
" return x + jax.lax.stop_gradient(self.buffer)\n",
" .\n",
" .\n",
" .\n",
" .\n",
" .\n",
" .\n",
" def method_n(self, x: jax.Array) -> jax.Array:\n",
" return x + jax.lax.stop_gradient(self.buffer)\n",
"```\n",
Expand Down Expand Up @@ -220,12 +220,12 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
"def frozen_field(**kwargs):\n",
" return sp.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze], **kwargs)\n",
" return sk.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze], **kwargs)\n",
"\n",
"\n",
"@sk.autoinit\n",
Expand Down Expand Up @@ -301,7 +301,7 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
Expand Down Expand Up @@ -343,7 +343,7 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
Expand Down Expand Up @@ -396,7 +396,7 @@
],
"source": [
"import jax\n",
"import serket as sp\n",
"import serket as sk\n",
"\n",
"\n",
"# you can use any function\n",
Expand Down Expand Up @@ -466,7 +466,7 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"from typing import Any\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -618,7 +618,7 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"from typing import Any\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -694,7 +694,7 @@
],
"source": [
"from typing import Any\n",
"import serket as sp\n",
"import serket as sk\n",
"import jax\n",
"import optax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -788,7 +788,7 @@
}
],
"source": [
"import serket as sp\n",
"import serket as sk\n",
"import jax\n",
"\n",
"\n",
Expand Down Expand Up @@ -854,7 +854,7 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import serket as sp\n",
"import serket as sk\n",
"import functools as ft\n",
"from typing import Generic, TypeVar\n",
"\n",
Expand Down Expand Up @@ -1003,7 +1003,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## [12] Regaularization\n",
"## [12] Regularization\n",
"\n",
"The following code showcase how to use `at` functionality to select some leaves of a model based on boolean mask or/and name condition to apply some weight regualrization on them. For example using `.at[...]` functionality the following can be achieved concisely:"
]
Expand Down Expand Up @@ -1270,77 +1270,131 @@
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"autojit\n",
"(Array([[5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.]], dtype=float32), 'hello')\n",
"autovmap\n",
"(Array([5., 5., 5., 5., 5.], dtype=float32), 'hello')\n"
]
}
],
"outputs": [],
"source": [
"import serket as sk\n",
"import functools as ft\n",
"import jax\n",
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"from typing import Any\n",
"\n",
"\n",
"def automask(jax_transform):\n",
" # takes a jax transformation and returns the same transformation\n",
" # but with the ability to apply it to arbitrary pytrees of non-jax types\n",
" \"\"\"Enable jax transformations to accept non-jax types.\"\"\"\n",
"\n",
" def out_transform(func, **transformation_kwargs):\n",
" @ft.partial(jax_transform, **transformation_kwargs)\n",
" def jax_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because outputs from `jax` transformation should return jax-types\n",
" # outputs should return jax types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before the `jax` boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(jax_boundary(*args, **kwargs))\n",
" # apply the jax transformation\n",
" output = jax_boundary(*args, **kwargs)\n",
" # unmask the outputs before returning\n",
" return sk.tree_unmask(output)\n",
"\n",
" return outer_wrapper\n",
"\n",
" return out_transform\n",
"\n",
"\n",
" return out_transform"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`automask` with `jit`"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`jit error`: Argument 'layer' of type <class 'str'> is not a valid JAX type\n",
"Using automask:\n",
"forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
" [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
" [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
" [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n",
" [4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)\n"
]
}
],
"source": [
"x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n",
"\n",
"\n",
"# test masked transformations\n",
"params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name=\"layer\")\n",
"\n",
"\n",
"@automask(jax.jit)\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal jit this will throw an error\n",
" return x @ y, name\n",
"def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n",
" return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n",
"\n",
"\n",
"print(\"autojit\")\n",
"print(func(x, y, \"hello\"))\n",
"\n",
"try:\n",
" forward_jit = jax.jit(forward)\n",
" print(forward_jit(params, x))\n",
"except TypeError as e:\n",
" print(\"`jit error`:\", e)\n",
" # now with `automask` the function can accept non-jax types (e.g. string)\n",
" forward_jit = automask(jax.jit)(forward)\n",
" print(\"Using automask:\")\n",
" print(f\"{forward_jit(params, x)=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`automask` with `vmap`"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"`vmap error`: Output from batched function 'layer' with type <class 'str'> is not a valid JAX type\n",
"Using automask:\n",
"{\n",
" name:layer, \n",
" w1:f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]), \n",
" w2:f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])\n",
"}\n"
]
}
],
"source": [
"def make_params(key: jax.Array):\n",
" k1, k2 = jax.random.split(key.astype(jnp.uint32))\n",
" return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n",
"\n",
"@automask(jax.vmap)\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal vmap this will throw an error\n",
" return x @ y, name\n",
"\n",
"keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n",
"\n",
"print(\"autovmap\")\n",
"print(func(x, y, \"hello\"))"
"try:\n",
" params = jax.vmap(make_params)(keys)\n",
" print(params)\n",
"except TypeError as e:\n",
" print(\"`vmap error`:\", e)\n",
" # now with `automask` the function can accept non-jax types (e.g. string)\n",
" params = automask(jax.vmap)(make_params)(keys)\n",
" print(\"Using automask:\")\n",
" print(sk.tree_repr(params))"
]
}
],
Expand Down
Loading

0 comments on commit 9644ecd

Please sign in to comment.