From 81b4aaee3bb51282f595c61c078e0ba4a6d11101 Mon Sep 17 00:00:00 2001 From: Mahmoud Asem <48389287+ASEM000@users.noreply.github.com> Date: Sat, 2 Dec 2023 15:33:34 +0900 Subject: [PATCH] Update common_recipes.ipynb (#85) --- docs/notebooks/common_recipes.ipynb | 140 +++++++++++++++++++--------- 1 file changed, 98 insertions(+), 42 deletions(-) diff --git a/docs/notebooks/common_recipes.ipynb b/docs/notebooks/common_recipes.ipynb index 4307e41..57db088 100644 --- a/docs/notebooks/common_recipes.ipynb +++ b/docs/notebooks/common_recipes.ipynb @@ -363,6 +363,63 @@ "print(tree_where(tree > 2, tree + 100, 0))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`bcmap` on pytrees with non-jaxtype\n", + "\n", + "In case the tree has some non-jaxtype leaves, The above will fail, but we can use `tree_mask`/`tree_unmask` to fix it" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bcmap fail '>' not supported between instances of 'str' and 'int'\n", + "Tree(a=0.0, b=(0.0, 103.0), c=[104. 105. 106.], name=tree, func=(x))\n" + ] + } + ], + "source": [ + "# in case the tree has some non-jaxtype leaves\n", + "# the above will fail, but we can use `tree_mask`/`tree_unmask` to fix it\n", + "import serket as sk\n", + "import jax.numpy as jnp\n", + "from typing import Callable\n", + "\n", + "\n", + "@sk.leafwise # enable math operations on leaves\n", + "@sk.autoinit # generate __init__ from type annotations\n", + "class Tree(sk.TreeClass):\n", + " a: float = 1.0\n", + " b: tuple[float] = (2.0, 3.0)\n", + " c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n", + " name: str = \"tree\" # non-jaxtype\n", + " func: Callable = lambda x: x # non-jaxtype\n", + "\n", + "\n", + "tree = Tree()\n", + "\n", + "try:\n", + " # make where work with arbitrary pytrees with non-jaxtype leaves\n", + " tree_where = sk.bcmap(jnp.where)\n", + " # for values > 2, add 100, else set to 0\n", + " print(tree_where(tree > 2, tree + 100, 0))\n", + "except TypeError as e:\n", + " print(\"bcmap fail\", e)\n", + " # now we can use `tree_mask`/`tree_unmask` to fix it\n", + " masked_tree = sk.tree_mask(tree) # mask non-jaxtype leaves\n", + " masked_tree = tree_where(masked_tree > 2, masked_tree + 100, 0)\n", + " unmasked_tree = sk.tree_unmask(masked_tree)\n", + " print(unmasked_tree)" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -380,7 +437,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -450,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -473,31 +530,30 @@ "\n", "\n", "class ArrayValidator(sk.TreeClass):\n", + " \"\"\"Validate shape and dtype of input array.\n", + "\n", + " Args:\n", + " shape: Expected shape of array. available values are int, None, ...\n", + " use int for fixed size, None for any size, and ... for any number\n", + " of dimensions. for example (..., 1) allows any number of dimensions\n", + " with the last dimension being 1. (1, ..., 1) allows any number of\n", + " dimensions with the first and last dimensions being 1.\n", + " dtype: Expected dtype of array.\n", + "\n", + " Example:\n", + " >>> x = jnp.ones((5, 5))\n", + " >>> # any number of dimensions with last dim=5\n", + " >>> shape = (..., 5)\n", + " >>> dtype = jnp.float32\n", + " >>> validator = ArrayValidator(shape, dtype)\n", + " >>> validator(x) # no error\n", + "\n", + " >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n", + " >>> shape = (None, 5)\n", + " >>> validator = ArrayValidator(shape, dtype)\n", + " >>> validator(x) # no error\n", + " \"\"\"\n", " def __init__(self, shape, dtype):\n", - " \"\"\"Validate shape and dtype of input array.\n", - "\n", - " Args:\n", - " shape: Expected shape of array. available values are int, None, ...\n", - " use int for fixed size, None for any size, and ... for any number\n", - " of dimensions. for example (..., 1) allows any number of dimensions\n", - " with the last dimension being 1. (1, ..., 1) allows any number of\n", - " dimensions with the first and last dimensions being 1.\n", - " dtype: Expected dtype of array.\n", - "\n", - " Example:\n", - " >>> x = jnp.ones((5, 5))\n", - " >>> # any number of dimensions with last dim=5\n", - " >>> shape = (..., 5)\n", - " >>> dtype = jnp.float32\n", - " >>> validator = ArrayValidator(shape, dtype)\n", - " >>> validator(x) # no error\n", - "\n", - " >>> # must be 2 dimensions with first dim unconstrained and last dim=5\n", - " >>> shape = (None, 5)\n", - " >>> validator = ArrayValidator(shape, dtype)\n", - " >>> validator(x) # no error\n", - " \"\"\"\n", - "\n", " if shape.count(...) > 1:\n", " raise ValueError(\"Only one ellipsis allowed\")\n", "\n", @@ -586,7 +642,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -612,7 +668,7 @@ " [1.]], dtype=float32)" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -669,7 +725,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -758,7 +814,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -782,7 +838,7 @@ " )]" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -838,7 +894,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -955,7 +1011,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -964,7 +1020,7 @@ "25" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1019,7 +1075,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -1070,7 +1126,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1096,7 +1152,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1173,7 +1229,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1199,7 +1255,7 @@ ")" ] }, - "execution_count": 18, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -1268,7 +1324,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1314,7 +1370,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -1361,7 +1417,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [ {