Skip to content

Commit

Permalink
Update common_recipes.ipynb (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 2, 2023
1 parent 7473148 commit 81b4aae
Showing 1 changed file with 98 additions and 42 deletions.
140 changes: 98 additions & 42 deletions docs/notebooks/common_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,63 @@
"print(tree_where(tree > 2, tree + 100, 0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`bcmap` on pytrees with non-jaxtype\n",
"\n",
"In case the tree has some non-jaxtype leaves, The above will fail, but we can use `tree_mask`/`tree_unmask` to fix it"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bcmap fail '>' not supported between instances of 'str' and 'int'\n",
"Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 106.], name=tree, func=<lambda>(x))\n"
]
}
],
"source": [
"# in case the tree has some non-jaxtype leaves\n",
"# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it\n",
"import serket as sk\n",
"import jax.numpy as jnp\n",
"from typing import Callable\n",
"\n",
"\n",
"@sk.leafwise # enable math operations on leaves\n",
"@sk.autoinit # generate __init__ from type annotations\n",
"class Tree(sk.TreeClass):\n",
" a: float = 1.0\n",
" b: tuple[float] = (2.0, 3.0)\n",
" c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n",
" name: str = \"tree\" # non-jaxtype\n",
" func: Callable = lambda x: x # non-jaxtype\n",
"\n",
"\n",
"tree = Tree()\n",
"\n",
"try:\n",
" # make where work with arbitrary pytrees with non-jaxtype leaves\n",
" tree_where = sk.bcmap(jnp.where)\n",
" # for values > 2, add 100, else set to 0\n",
" print(tree_where(tree > 2, tree + 100, 0))\n",
"except TypeError as e:\n",
" print(\"bcmap fail\", e)\n",
" # now we can use `tree_mask`/`tree_unmask` to fix it\n",
" masked_tree = sk.tree_mask(tree) # mask non-jaxtype leaves\n",
" masked_tree = tree_where(masked_tree > 2, masked_tree + 100, 0)\n",
" unmasked_tree = sk.tree_unmask(masked_tree)\n",
" print(unmasked_tree)"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -380,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -450,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand All @@ -473,31 +530,30 @@
"\n",
"\n",
"class ArrayValidator(sk.TreeClass):\n",
" \"\"\"Validate shape and dtype of input array.\n",
"\n",
" Args:\n",
" shape: Expected shape of array. available values are int, None, ...\n",
" use int for fixed size, None for any size, and ... for any number\n",
" of dimensions. for example (..., 1) allows any number of dimensions\n",
" with the last dimension being 1. (1, ..., 1) allows any number of\n",
" dimensions with the first and last dimensions being 1.\n",
" dtype: Expected dtype of array.\n",
"\n",
" Example:\n",
" >>> x = jnp.ones((5, 5))\n",
" >>> # any number of dimensions with last dim=5\n",
" >>> shape = (..., 5)\n",
" >>> dtype = jnp.float32\n",
" >>> validator = ArrayValidator(shape, dtype)\n",
" >>> validator(x) # no error\n",
"\n",
" >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n",
" >>> shape = (None, 5)\n",
" >>> validator = ArrayValidator(shape, dtype)\n",
" >>> validator(x) # no error\n",
" \"\"\"\n",
" def __init__(self, shape, dtype):\n",
" \"\"\"Validate shape and dtype of input array.\n",
"\n",
" Args:\n",
" shape: Expected shape of array. available values are int, None, ...\n",
" use int for fixed size, None for any size, and ... for any number\n",
" of dimensions. for example (..., 1) allows any number of dimensions\n",
" with the last dimension being 1. (1, ..., 1) allows any number of\n",
" dimensions with the first and last dimensions being 1.\n",
" dtype: Expected dtype of array.\n",
"\n",
" Example:\n",
" >>> x = jnp.ones((5, 5))\n",
" >>> # any number of dimensions with last dim=5\n",
" >>> shape = (..., 5)\n",
" >>> dtype = jnp.float32\n",
" >>> validator = ArrayValidator(shape, dtype)\n",
" >>> validator(x) # no error\n",
"\n",
" >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n",
" >>> shape = (None, 5)\n",
" >>> validator = ArrayValidator(shape, dtype)\n",
" >>> validator(x) # no error\n",
" \"\"\"\n",
"\n",
" if shape.count(...) > 1:\n",
" raise ValueError(\"Only one ellipsis allowed\")\n",
"\n",
Expand Down Expand Up @@ -586,7 +642,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -612,7 +668,7 @@
" [1.]], dtype=float32)"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -669,7 +725,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -758,7 +814,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand All @@ -782,7 +838,7 @@
" )]"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -838,7 +894,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -955,7 +1011,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand All @@ -964,7 +1020,7 @@
"25"
]
},
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1019,7 +1075,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1070,7 +1126,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand All @@ -1096,7 +1152,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1173,7 +1229,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"outputs": [
{
Expand All @@ -1199,7 +1255,7 @@
")"
]
},
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1268,7 +1324,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1314,7 +1370,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1361,7 +1417,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"metadata": {},
"outputs": [
{
Expand Down

0 comments on commit 81b4aae

Please sign in to comment.