org
ASEM000 committed Sep 17, 2023
1 parent fb71369 commit 7e5f1a7
Showing 3 changed files with 62 additions and 130 deletions.
1 change: 0 additions & 1 deletion docs/index.rst
@@ -77,7 +77,6 @@ Install from github::

train_examples
notebooks/evaluation
notebooks/lazy_initialization
notebooks/mixed_precision
notebooks/checkpointing
notebooks/regularization
64 changes: 62 additions & 2 deletions docs/notebooks/layers_overview.ipynb
@@ -27,7 +27,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Handling weight initialization**\n",
"### Handling weight initialization\n",
"\n",
"Layers that contain `weight_init` or `bias_init` can accept:\n",
"\n",
@@ -92,7 +92,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Handling activation functions**\n",
"### Handling activation functions\n",
"\n",
"Layers that contain `act_func` accept:\n",
"\n",
@@ -203,6 +203,66 @@
"linear\n",
"# note the dtype is f16(float16) in the repr output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Lazy shape inference\n",
"\n",
"Lazy initialization is useful when the dimensions of certain input features are not known in advance: for instance, when the number of neurons required for a flattened image input is uncertain, or when the output shape of a flattened convolutional layer is not straightforward to calculate. In such cases, lazy initialization defers layer materialization until the first input is seen.\n",
"\n",
"In `serket`, simply set `in_features` to `None` to mark the layer as lazy, then materialize it by calling the layer functionally. Recall that a functional call via `.at[method_name](*args, **kwargs)` _always_ returns a tuple of the method output and a _new_ instance."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Marking the layer lazy**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import serket as sk\n",
"\n",
"# 5 images from MNIST\n",
"x = jax.numpy.ones([5, 1, 28, 28])\n",
"\n",
"layer = sk.nn.Sequential(\n",
" jax.numpy.ravel,\n",
"    # pass `None` to infer in_features lazily\n",
" sk.nn.Linear(None, 10),\n",
" jax.nn.relu,\n",
" sk.nn.Linear(10, 10),\n",
" jax.nn.softmax,\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Materialization by functional call**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# materialize the layer with single image\n",
"_, layer = layer.at[\"__call__\"](x[0])\n",
"# apply on batch\n",
"y = jax.vmap(layer)(x)\n",
"y.shape"
]
}
],
"metadata": {
127 changes: 0 additions & 127 deletions docs/notebooks/lazy_initialization.ipynb

This file was deleted.
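The lazy shape inference the moved notebook cells describe can be sketched outside of `serket`. The toy below is a hypothetical `LazyLinear` class in plain NumPy, not serket's implementation: for brevity it materializes its weights by mutating in place, whereas serket is immutable and instead returns a new materialized instance from the functional `.at["__call__"]` call.

```python
import numpy as np

class LazyLinear:
    """Hypothetical toy layer that infers `in_features` from the
    first input it sees (illustration only, not serket's API)."""

    def __init__(self, in_features, out_features):
        # in_features=None means "infer on first call"
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None
        self.bias = np.zeros(out_features)

    def __call__(self, x):
        if self.weight is None:
            # materialize: the trailing axis of the first input
            # fixes in_features, so the weight shape is now known
            self.in_features = x.shape[-1]
            self.weight = np.zeros((self.in_features, self.out_features))
        return x @ self.weight + self.bias

layer = LazyLinear(None, 10)
x = np.ones(28 * 28)      # one flattened 28x28 image
y = layer(x)              # first call materializes the weights
print(layer.in_features)  # 784
print(y.shape)            # (10,)
```

The design choice mirrors the notebook: the caller never computes the flattened size by hand; the first concrete input carries that information.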
