diff --git a/docs/notebooks/misc_recipes.ipynb b/docs/notebooks/misc_recipes.ipynb index 58a2744..a58a6d5 100644 --- a/docs/notebooks/misc_recipes.ipynb +++ b/docs/notebooks/misc_recipes.ipynb @@ -31,12 +31,12 @@ "metadata": {}, "source": [ "## [1] Lazy layers.\n", - "In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `.at` to create a new instance with the added parameter." + "In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `value_and_tree` to create a new instance with the added parameter." ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -62,7 +62,7 @@ " [1.]], dtype=float32)" ] }, - "execution_count": 2, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } @@ -97,7 +97,8 @@ "\n", "print(f\"Layer before param is set:\\t{lazy}\")\n", "\n", - "_, material = lazy.at[\"__call__\"](input)\n", + "# `value_and_tree` executes any mutating method in a functional way\n", + "_, material = sk.value_and_tree(lambda layer: layer(input))(lazy)\n", "\n", "print(f\"Layer after param is set:\\t{material}\")\n", "# subsequent calls will not set the parameters again\n", @@ -116,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -205,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -214,7 +215,7 @@ "25" ] }, - "execution_count": 1, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -268,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -319,7 +320,7 @@ }, { 
"cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -347,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -424,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -454,7 +455,7 @@ " def _tied_call(self, input: jax.Array) -> jax.Array:\n", " # share/tie weights of encoder and decoder\n", " # however this operation mutates the state\n", - " # so this method will only work with .at\n", + " # so this method will only work with `value_and_tree`\n", " # otherwise will throw `AttributeError`\n", " self.dec1.weight = self.enc1.weight.T\n", " self.dec2.weight = self.enc2.weight.T\n", @@ -465,9 +466,9 @@ " return output\n", "\n", " def tied_call(self, input: jax.Array) -> jax.Array:\n", - " # this method call `_tied_call` with .at to\n", - " # return the output without mutating the state\n", - " output, _ = self.at[\"_tied_call\"](input)\n", + " # this method calls `_tied_call` with `value_and_tree` to\n", + " # return the output without mutating the state of the network\n", + " output, _ = sk.value_and_tree(lambda net: net._tied_call(input))(self)\n", " return output\n", "\n", " def non_tied_call(self, x):\n",