diff --git a/docs/notebooks/common_recipes.ipynb b/docs/notebooks/common_recipes.ipynb index cc9e8ee..4307e41 100644 --- a/docs/notebooks/common_recipes.ipynb +++ b/docs/notebooks/common_recipes.ipynb @@ -127,9 +127,9 @@ "class Tree(sk.TreeClass):\n", " def method_1(self, x: jax.Array) -> jax.Array:\n", " return x + jax.lax.stop_gradient(self.buffer)\n", - " .\n", - " .\n", - " .\n", + " .\n", + " .\n", + " .\n", " def method_n(self, x: jax.Array) -> jax.Array:\n", " return x + jax.lax.stop_gradient(self.buffer)\n", "```\n", @@ -220,12 +220,12 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "import jax\n", "\n", "\n", "def frozen_field(**kwargs):\n", - " return sp.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze], **kwargs)\n", + " return sk.field(on_getattr=[sk.unfreeze], on_setattr=[sk.freeze], **kwargs)\n", "\n", "\n", "@sk.autoinit\n", @@ -301,7 +301,7 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "import jax.numpy as jnp\n", "\n", "\n", @@ -343,7 +343,7 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "import jax.numpy as jnp\n", "\n", "\n", @@ -396,7 +396,7 @@ ], "source": [ "import jax\n", - "import serket as sp\n", + "import serket as sk\n", "\n", "\n", "# you can use any function\n", @@ -466,7 +466,7 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "from typing import Any\n", "import jax\n", "import jax.numpy as jnp\n", @@ -618,7 +618,7 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "from typing import Any\n", "import jax\n", "import jax.numpy as jnp\n", @@ -694,7 +694,7 @@ ], "source": [ "from typing import Any\n", - "import serket as sp\n", + "import serket as sk\n", "import jax\n", "import optax\n", "import jax.numpy as jnp\n", @@ -788,7 +788,7 @@ } ], "source": [ - "import serket as sp\n", + "import serket as sk\n", "import jax\n", "\n", "\n", @@ -854,7 +854,7 @@ "import jax\n", "import jax.numpy as jnp\n", "import 
jax.random as jr\n", - "import serket as sp\n", + "import serket as sk\n", "import functools as ft\n", "from typing import Generic, TypeVar\n", "\n", @@ -1003,7 +1003,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## [12] Regaularization\n", + "## [12] Regularization\n", "\n", "The following code showcase how to use `at` functionality to select some leaves of a model based on boolean mask or/and name condition to apply some weight regualrization on them. For example using `.at[...]` functionality the following can be achieved concisely:" ] @@ -1270,77 +1270,131 @@ "cell_type": "code", "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "autojit\n", - "(Array([[5., 5., 5., 5., 5.],\n", - " [5., 5., 5., 5., 5.],\n", - " [5., 5., 5., 5., 5.],\n", - " [5., 5., 5., 5., 5.],\n", - " [5., 5., 5., 5., 5.]], dtype=float32), 'hello')\n", - "autovmap\n", - "(Array([5., 5., 5., 5., 5.], dtype=float32), 'hello')\n" - ] - } - ], + "outputs": [], "source": [ "import serket as sk\n", "import functools as ft\n", "import jax\n", + "import jax.random as jr\n", "import jax.numpy as jnp\n", + "from typing import Any\n", "\n", "\n", "def automask(jax_transform):\n", - " # takes a jax transformation and returns the same transformation\n", - " # but with the ability to apply it to arbitrary pytrees of non-jax types\n", + " \"\"\"Enable jax transformations to accept non-jax types.\"\"\"\n", "\n", " def out_transform(func, **transformation_kwargs):\n", " @ft.partial(jax_transform, **transformation_kwargs)\n", " def jax_boundary(*args, **kwargs):\n", " # unmask the inputs before pasing to the actual function\n", " args, kwargs = sk.tree_unmask((args, kwargs))\n", - " # mask the outputs after calling the actual function\n", - " # because outputs from `jax` transformation should return jax-types\n", + " # outputs should return jax types\n", " return sk.tree_mask(func(*args, **kwargs))\n", "\n", " 
@ft.wraps(func)\n", " def outer_wrapper(*args, **kwargs):\n", " # mask the inputs before the `jax` boundary\n", " args, kwargs = sk.tree_mask((args, kwargs))\n", - " return sk.tree_unmask(jax_boundary(*args, **kwargs))\n", + " # apply the jax transformation\n", + " output = jax_boundary(*args, **kwargs)\n", + " # unmask the outputs before returning\n", + " return sk.tree_unmask(output)\n", "\n", " return outer_wrapper\n", "\n", - " return out_transform\n", - "\n", - "\n", + " return out_transform" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`automask` with `jit`" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`jit error`: Argument 'layer' of type is not a valid JAX type\n", + "Using automask:\n", + "forward_jit(params, x)=Array([[4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546],\n", + " [4.999546, 4.999546, 4.999546, 4.999546, 4.999546]], dtype=float32)\n" + ] + } + ], + "source": [ "x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n", "\n", - "\n", - "# test masked transformations\n", + "params = dict(w1=jnp.ones([5, 5]), w2=jnp.ones([5, 5]), name=\"layer\")\n", "\n", "\n", - "@automask(jax.jit)\n", - "def func(x: jax.Array, y: jax.Array, name: str):\n", - " # name is not a jax type, with normal jit this will throw an error\n", - " return x @ y, name\n", + "def forward(params: dict[str, Any], x: jax.Array) -> jax.Array:\n", + " return jnp.tanh(x @ params[\"w1\"]) @ params[\"w2\"]\n", "\n", "\n", - "print(\"autojit\")\n", - "print(func(x, y, \"hello\"))\n", - "\n", + "try:\n", + " forward_jit = jax.jit(forward)\n", + " print(forward_jit(params, x))\n", + "except TypeError as e:\n", + " print(\"`jit error`:\", e)\n", + " # now with `automask` the 
function can accept non-jax types (e.g. string)\n", + " forward_jit = automask(jax.jit)(forward)\n", + " print(\"Using automask:\")\n", + " print(f\"{forward_jit(params, x)=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`automask` with `vmap`" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`vmap error`: Output from batched function 'layer' with type is not a valid JAX type\n", + "Using automask:\n", + "{\n", + " name:layer, \n", + " w1:f32[4,5,5](μ=0.50, σ=0.28, ∈[0.02,1.00]), \n", + " w2:f32[4,5,5](μ=0.46, σ=0.27, ∈[0.01,0.99])\n", + "}\n" + ] + } + ], + "source": [ + "def make_params(key: jax.Array):\n", + " k1, k2 = jax.random.split(key.astype(jnp.uint32))\n", + " return dict(w1=jr.uniform(k1, (5, 5)), w2=jr.uniform(k2, (5, 5)), name=\"layer\")\n", "\n", - "@automask(jax.vmap)\n", - "def func(x: jax.Array, y: jax.Array, name: str):\n", - " # name is not a jax type, with normal vmap this will throw an error\n", - " return x @ y, name\n", "\n", + "keys = jr.split(jr.PRNGKey(0), 4).astype(jnp.float32)\n", "\n", - "print(\"autovmap\")\n", - "print(func(x, y, \"hello\"))" + "try:\n", + " params = jax.vmap(make_params)(keys)\n", + " print(params)\n", + "except TypeError as e:\n", + " print(\"`vmap error`:\", e)\n", + " # now with `automask` the function can accept non-jax types (e.g. 
string)\n", + " params = automask(jax.vmap)(make_params)(keys)\n", + " print(\"Using automask:\")\n", + " print(sk.tree_repr(params))" ] } ], diff --git a/docs/notebooks/model_surgery.ipynb b/docs/notebooks/model_surgery.ipynb index aa50b72..8d67aa8 100644 --- a/docs/notebooks/model_surgery.ipynb +++ b/docs/notebooks/model_surgery.ipynb @@ -11,95 +11,312 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install git+https://github.com/ASEM000/serket --quiet" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import serket as sk\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import re" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## `AtIndexer` basics\n", + "## `AtIndexer` basics" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`serket.AtIndexer` wraps any pytree (nested container) to manipulate its content in out-of-place fashion. This means that any change will be applied to a _new_ instance of the pytree.\n", "\n", - "`serket.AtIndexer` wraps any pytree to manipulate its content in out-of-place fashion. This means that any change will be applied on a _new_ instance of the pytree. 
The following example demonstrate this point:" + " The following example demonstrates this point:" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "[1, [100, 3], 4]\n" + "pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n", + "pytree1 is pytree2 = False\n" ] - }, - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "import serket as sk\n", - "pytree1 = [1, [2, 3], 4]\n", - "pytree2 = sk.AtIndexer(pytree1)[1][0].set(100) # equivalent to pytree[1][0] = 100\n", - "print(pytree2)\n", - "# [1, [100, 3], 4]\n", - "pytree1 is pytree2 # test out-of-place update" + "pytree1 = [1, [2, 3], 4] \n", + "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)\n", + "pytree2 = indexer[...].get() # get the whole pytree using ...\n", + "print(f\"{pytree1=}, {pytree2=}\")\n", + "# even though pytree1 and pytree2 are the same, they are not the same object\n", + "# because pytree2 is a copy of pytree1\n", + "print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`serket.AtIndexer` can also edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. The following example set all negative entries to 0:" + "Note that each `[ ]` is selecting at a certain depth, meaning that `[a][b]` is selecting\n", + "`a` at depth=1 and `b` at depth=2." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Integer indexing\n", + "\n", + "`serket.AtIndexer` can edit pytrees by integer paths."
] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}\n" + "pytree1=[1, [2, 3], 4], pytree2=[1, [100, 3], 4]\n" ] - }, - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "import serket as sk\n", - "import jax\n", + "pytree1 = [1, [2, 3], 4]\n", + "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)\n", + "pytree2 = indexer[1][0].set(100) # equivalent to pytree1[1][0] = 100\n", "\n", + "print(f\"{pytree1=}, {pytree2=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Named path indexing\n", + "`serket.AtIndexer` can edit pytrees by named paths." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n", + "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + "Original pytree1\n", + "dict\n", + "├── ['a']=-1\n", + "├── ['b']:dict\n", + "│ ├── ['c']=2\n", + "│ └── ['d']=3\n", + "├── ['e']=-4\n", + "└── ['f']:dict\n", + " ├── ['g']=7\n", + " └── ['h']=8\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# example 1: set the value of pytree1[\"b\"][\"c\"] to 200\n", + "pytree2 = indexer[\"b\"][\"c\"].set(200)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + "Set pytree1['b']['c'] to 200:\n", + " dict\n", + " ├── ['a']=-1\n", + " ├── ['b']:dict\n", + "+│ ├── ['c']=200\n", + " │ └── ['d']=3\n", + " ├── ['e']=-4\n", + " └── ['f']:dict\n", + " ├── ['g']=7\n", + " └── ['h']=8\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata":
{}, + "outputs": [], + "source": [ + "# example 2: set the value of pytree1[\"b\"] to 100\n", + "pytree3 = indexer[\"b\"].set(100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + "Set pytree1['b'] to 100:\n", + " dict\n", + " ├── ['a']=-1\n", + "+├── ['b']=100\n", + " ├── ['e']=-4\n", + " └── ['f']:dict\n", + " ├── ['g']=7\n", + " └── ['h']=8\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# example 3: set _all leaves_ of \"b\" subtree to 100\n", + "pytree4 = indexer[\"b\"][...].set(100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + "Set _all leaves_ of pytree1['b'] to 100:\n", + " dict\n", + " ├── ['a']=-1\n", + " ├── ['b']:dict\n", + "+│ ├── ['c']=100\n", + "+│ └── ['d']=100\n", + " ├── ['e']=-4\n", + " └── ['f']:dict\n", + " ├── ['g']=7\n", + " └── ['h']=8\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# example 4: set _all leaves_ of pytree1[\"b\"] _and_ pytree1[\"f\"] to 100\n", + "pytree5 = indexer[\"b\", \"f\"][...].set(100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + "Set _all leaves_ of pytree1['b'] and pytree1['f'] to 100:\n", + " dict\n", + " ├── ['a']=-1\n", + " ├── ['b']:dict\n", + "+│ ├── ['c']=100\n", + "+│ └── ['d']=100\n", + " ├── ['e']=-4\n", + " └── ['f']:dict\n", + "+ ├── ['g']=100\n", + "+ └── ['h']=100\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Masked indexing\n", + "`serket.AtIndexer` can also edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. 
The following example sets all negative entries to 0:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ "pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n", "mask = jax.tree_map(lambda x: x < 0, pytree1)\n", - "pytree2 = sk.AtIndexer(pytree1)[mask].set(0)\n", - "print(pytree2)\n", - "# {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}\n", - "pytree1 is pytree2 # test out-of-place update\n", - "# False" + "indexer: sk.AtIndexer = sk.AtIndexer(pytree1)\n", + "pytree2 = indexer[mask].set(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Original pytree\n", + "```diff\n", + "dict\n", + " ├── ['a']=-1\n", + " ├── ['b']:dict\n", + " │ ├── ['c']=2\n", + " │ └── ['d']=3\n", + " └── ['e']=-4\n", + "```\n", + "\n", + "Mask\n", + "\n", + "```diff\n", + "dict\n", + "+├── ['a']=True\n", + " ├── ['b']:dict\n", + " │ ├── ['c']=False\n", + " │ └── ['d']=False\n", + "+└── ['e']=True\n", + "```\n", + "\n", + "Modified pytree\n", + "\n", + "```diff\n", + "dict\n", + "+├── ['a']=0\n", + " ├── ['b']:dict\n", + " │ ├── ['c']=2\n", + " │ └── ['d']=3\n", + "+└── ['e']=0\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Other features include `get`, `apply`, `scan`, `reduce`, and `pluck`. Check the documentation for more examples." ] }, { @@ -108,127 +325,136 @@ "source": [ "## `serket` layers surgery\n", "\n", - "Similarly, `serket` layers are pytrees as above. Howver, `AtIndexer` is embeded in `TreeClass` under `.at` property, this design enables powerful composition of both name/index based and boolean based updates. The next example demonstrates this point.\n" + "Similarly, `serket` layers are pytrees as above with `AtIndexer` embedded in `TreeClass` under `.at` property. This design enables powerful composition of both name/index based and boolean based updates.
The next example demonstrates this point.\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "ConvNet\n", - "├── .conv1:Conv2D\n", - "│ ├── .in_features=3\n", - "│ ├── .out_features=10\n", - "│ ├── .kernel_size=(...)\n", - "│ ├── .strides=(...)\n", - "│ ├── .padding=same\n", - "│ ├── .dilation=(...)\n", - "│ ├── .weight_init=glorot_uniform\n", - "│ ├── .bias_init=zeros\n", - "│ ├── .groups=1\n", - "│ ├── .weight=f32[10,3,3,3](μ=-0.00, σ=0.11, ∈[-0.18,0.18])\n", - "│ └── .bias=f32[10,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - "└── .conv2:Conv2D\n", - " ├── .in_features=10\n", - " ├── .out_features=1\n", - " ├── .kernel_size=(...)\n", - " ├── .strides=(...)\n", - " ├── .padding=same\n", - " ├── .dilation=(...)\n", - " ├── .weight_init=glorot_uniform\n", - " ├── .bias_init=zeros\n", - " ├── .groups=1\n", - " ├── .weight=f32[1,10,1,1](μ=-0.18, σ=0.29, ∈[-0.53,0.31])\n", - " └── .bias=f32[1,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n" + "net1=Net(\n", + " encoder={\n", + " weight:f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84]), \n", + " bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }, \n", + " decoder={\n", + " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", + " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }\n", + ")\n" ] } ], "source": [ - "import serket as sk\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "\n", "# basic convnet with two convolutional layers\n", - "class ConvNet(sk.TreeClass):\n", - " def __init__(self, indim, outdim, key):\n", + "class Net(sk.TreeClass):\n", + " def __init__(self, in_features: int, out_features: int, *, key: jax.Array):\n", " k1, k2 = jax.random.split(key)\n", - " self.conv1 = sk.nn.Conv2D(indim, outdim, 3, key=k1)\n", - " self.conv2 = sk.nn.Conv2D(outdim, 1, 1, key=k2)\n", + " W1 = jax.random.normal(k1, (out_features, in_features))\n", + " W2 = jax.random.normal(k2, (out_features, out_features))\n", "\n", - 
" def __call__(self, x):\n", - " x = self.conv1(x)\n", - " x = jax.nn.relu(x)\n", - " x = self.conv2(x)\n", - " return x\n", + " self.encoder = {\"weight\": W1, \"bias\": jnp.zeros((out_features,))}\n", + " self.decoder = {\"weight\": W2, \"bias\": jnp.zeros((in_features,))}\n", "\n", + " def __call__(self, x):\n", + " x = x @ self.encoder[\"weight\"] + self.encoder[\"bias\"]\n", + " x = x @ self.decoder[\"weight\"] + self.decoder[\"bias\"]\n", + " return\n", "\n", - "cnn1 = ConvNet(3, 10, jax.random.PRNGKey(0))\n", "\n", - "# note that `ConvNet` is composed of two branches\n", - "print(sk.tree_diagram(cnn1, depth=2))" + "net1 = Net(3, 5, key=jax.random.PRNGKey(0))\n", + "print(f\"{net1=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, suppose we want to set the range of 'weight' in both layers to `[-0.2, 0.2]` by setting out-of-range values to zero. Combining the name-based indexing - i.e. `conv1.weight` and `conv2.weight` - with boolean masking - i.e. a mask that is true if `x<-0.2` or `x>0.2` - suffices to achieve this. The following example show how can achieve this by _composition_." + "```diff\n", + "Net\n", + "├── .encoder:dict\n", + "│ ├── ['bias']=f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + "│ └── ['weight']=f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84])\n", + "└── .decoder:dict\n", + " ├── ['bias']=f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " └── ['weight']=f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, suppose we want to set the range of `weight` in both layers `encoder` and `decoder` to `[-0.2, 0.2]` by clipping out-of-bound values."
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ConvNet\n", - "├── .conv1:Conv2D\n", - "│ ├── .in_features=3\n", - "│ ├── .out_features=10\n", - "│ ├── .kernel_size=(...)\n", - "│ ├── .strides=(...)\n", - "│ ├── .padding=same\n", - "│ ├── .dilation=(...)\n", - "│ ├── .weight_init=glorot_uniform\n", - "│ ├── .bias_init=zeros\n", - "│ ├── .groups=1\n", - "│ ├── .weight=f32[10,3,3,3](μ=-0.00, σ=0.11, ∈[-0.18,0.18])\n", - "│ └── .bias=f32[10,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", - "└── .conv2:Conv2D\n", - " ├── .in_features=10\n", - " ├── .out_features=1\n", - " ├── .kernel_size=(...)\n", - " ├── .strides=(...)\n", - " ├── .padding=same\n", - " ├── .dilation=(...)\n", - " ├── .weight_init=glorot_uniform\n", - " ├── .bias_init=zeros\n", - " ├── .groups=1\n", - " ├── .weight=f32[1,10,1,1](μ=-0.02, σ=0.08, ∈[-0.17,0.14])\n", - " └── .bias=f32[1,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])\n" - ] - } - ], + "outputs": [], "source": [ - "def set_to_zero(x):\n", - " # set all values of x to zero if they are not in the range [-0.2, 0.2]\n", - " return jnp.where(x < -0.2, 0, jnp.where(x > 0.2, 0, x))\n", - "\n", - "\n", - "# note that ['conv1', 'conv2'] is basically selecting both 'conv1' and 'conv2'\n", - "cnn2 = cnn1.at[\"conv1\", \"conv2\"][\"weight\"].apply(set_to_zero)\n", - "\n", - "# note that weight of both 'conv1' and 'conv2' range is changed\n", - "print(sk.tree_diagram(cnn2, depth=2))" + "# example 1: clip the `weights` of `encoder` and `decoder` to [-0.2, 0.2]\n", + "net2 = net1.at[\"encoder\", \"decoder\"][\"weight\"].apply(lambda x: jnp.clip(x, -0.2, 0.2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + " Net\n", + " ├── .encoder:dict\n", + " │ ├── ['bias']=f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + "+│ └── ['weight']=f32[5,3](μ=0.04, σ=0.18, ∈[-0.20,0.20])\n", + " └── .decoder:dict\n", + " ├── 
['bias']=f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + "+ └── ['weight']=f32[5,5](μ=-0.02, σ=0.18, ∈[-0.20,0.20])\n", + "````" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# example 2: load pretrained weights for `encoder`\n", + "pretrained = {\"weight\": jnp.ones((5, 3))*100., \"bias\": jnp.ones((5,))*100.}\n", + "net3 = net1.at[\"encoder\"].set(pretrained)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```diff\n", + " net1=Net(\n", + " encoder={\n", + "- weight:f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84]), \n", + "- bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }, \n", + " decoder={\n", + " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", + " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }\n", + " )\n", + " net3=Net(\n", + " encoder={\n", + "+ weight:f32[5,3](μ=100.00, σ=0.00, ∈[100.00,100.00]), \n", + "+ bias:f32[5](μ=100.00, σ=0.00, ∈[100.00,100.00])\n", + " }, \n", + " decoder={\n", + " weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), \n", + " bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])\n", + " }\n", + " )\n", + "```" ] } ], diff --git a/docs/notebooks/train_unet.ipynb b/docs/notebooks/train_unet.ipynb index 4e23ba4..24d9b21 100644 --- a/docs/notebooks/train_unet.ipynb +++ b/docs/notebooks/train_unet.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Train Semantic segmenter with `UNet`\n", + "# Train `UNet` segmenter\n", "\n", "In this example, a `UNet` architecture is used to segment Oxford pets datasets.\n", "\n", @@ -24,9 +24,9 @@ "metadata": {}, "outputs": [], "source": [ - "# !pip install git+https://github.com/ASEM000/serket --quiet\n", - "# !pip install optax --quiet\n", - "# !pip install ml_collections --quiet" + "!pip install git+https://github.com/ASEM000/serket --quiet\n", + "!pip install optax --quiet\n", + "!pip install ml_collections --quiet" ] }, { @@ -1075,7 +1075,7 @@ "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.1.0" } }, "nbformat": 4,