Skip to content

Commit

Permalink
Update misc_recipes.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 17, 2023
1 parent 2d9d3ce commit 1b8c8bf
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions docs/notebooks/misc_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
"metadata": {},
"source": [
"## [1] Lazy layers.\n",
"In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `.at` to create a new instance with the added parameter."
"In this example, a `Linear` layer with a weight parameter based on the shape of the input will be created. Since this requires parameter creation (i.e., `weight`) after instance initialization, we will use `value_and_tree` to create a new instance with the added parameter."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -62,7 +62,7 @@
" [1.]], dtype=float32)"
]
},
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -97,7 +97,8 @@
"\n",
"print(f\"Layer before param is set:\\t{lazy}\")\n",
"\n",
"_, material = lazy.at[\"__call__\"](input)\n",
"# `value_and_tree` executes any mutating method in a functional way\n",
"_, material = sk.value_and_tree(lambda layer: layer(input))(lazy)\n",
"\n",
"print(f\"Layer after param is set:\\t{material}\")\n",
"# subsequent calls will not set the parameters again\n",
Expand All @@ -116,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -205,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -214,7 +215,7 @@
"25"
]
},
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -268,7 +269,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -319,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -347,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -424,7 +425,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -454,7 +455,7 @@
" def _tied_call(self, input: jax.Array) -> jax.Array:\n",
" # share/tie weights of encoder and decoder\n",
" # however this operation mutates the state\n",
" # so this method will only work with .at\n",
" # so this method will only work with `value_and_tree`\n",
" # otherwise will throw `AttributeError`\n",
" self.dec1.weight = self.enc1.weight.T\n",
" self.dec2.weight = self.enc2.weight.T\n",
Expand All @@ -465,9 +466,9 @@
" return output\n",
"\n",
" def tied_call(self, input: jax.Array) -> jax.Array:\n",
" # this method call `_tied_call` with .at to\n",
" # return the output without mutating the state\n",
" output, _ = self.at[\"_tied_call\"](input)\n",
" # this method call `_tied_call` with value_and_tree\n",
" # return the output without mutating the state of the network\n",
" output, _ = sk.value_and_tree(lambda net: net._tied_call(input))(self)\n",
" return output\n",
"\n",
" def non_tied_call(self, x):\n",
Expand Down

0 comments on commit 1b8c8bf

Please sign in to comment.