Expose pooling functional forms (#95)
ASEM000 authored Dec 11, 2023
1 parent 4339b98 commit 899322e
Showing 8 changed files with 236 additions and 97 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -28,14 +28,18 @@ pip install git+https://github.com/ASEM000/serket

## 📖 Description and motivation<a id="Description"></a>

- `serket` aims to be the most intuitive and easy-to-use machine learning library in `JAX`.
- `serket` aims to be the most intuitive and easy-to-use machine learning library in `jax`.
- `serket` is fully transparent to `jax` transformations (e.g. `vmap`,`grad`,`jit`,...).

### 🏃 Quick example<a id="QuickExample"></a>
## Documentation
- [Documentation](https://serket.readthedocs.io/)
- [Train MNIST, UNet, ConvLSTM, PINN, ...](https://serket.readthedocs.io/training_guides.html)
- [Model surgery, Parallelism, Mixed precision, ...](https://serket.readthedocs.io/core_guides.html)
- [Optimizers, Augmentation composition, ...](https://serket.readthedocs.io/other_guides.html)
- [Interoperability with keras, tensorflow, ...](https://serket.readthedocs.io/interoperability.html)

See [🧠 `serket` mental model](https://serket.readthedocs.io/en/latest/notebooks/mental_model.html) and for examples, see [Training MNIST](https://serket.readthedocs.io/en/latest/notebooks/train_mnist.html)
or [Training Bidirectional-LSTM](https://serket.readthedocs.io/en/latest/notebooks/train_bilstm.html)
or [Training PINN](https://serket.readthedocs.io/en/latest/notebooks/train_pinn_burgers.html#) or [Image augmentation pipelines](https://serket.readthedocs.io/en/latest/notebooks/augmentations.html)

## 🏃 Quick example<a id="QuickExample"></a>

```python
import jax, jax.numpy as jnp
@@ -86,8 +90,7 @@ net = sk.tree_unmask(net)
| Attention | - `MultiHeadAttention` |
| Convolution | - `{FFT,_}Conv{1D,2D,3D}` <br> - `{FFT,_}Conv{1D,2D,3D}Transpose` <br> - `Depthwise{FFT,_}Conv{1D,2D,3D}` <br> - `Separable{FFT,_}Conv{1D,2D,3D}` <br> - `Conv{1D,2D,3D}Local` <br> - `SpectralConv{1D,2D,3D}` |
| Dropout | - `Dropout`<br> - `Dropout{1D,2D,3D}` <br> - `RandomCutout{1D,2D,3D}` |
| Linear | - `Linear`, `GeneralLinear`, `Identity` |
| Densely connected | - `FNN` , <br> - `MLP` _compile time_ optimized |
| Linear | - `Linear`, `MLP`, `Identity` |
| Normalization | - `{Layer,Instance,Group,Batch}Norm` |
| Pooling | - `{Avg,Max,LP}Pool{1D,2D,3D}` <br> - `Global{Avg,Max}Pool{1D,2D,3D}` <br> - `Adaptive{Avg,Max}Pool{1D,2D,3D}` |
| Reshaping | - `Upsample{1D,2D,3D}` <br> - `{Random,Center}Crop{1D,2D,3D}` ` |
8 changes: 7 additions & 1 deletion docs/API/pooling.rst
@@ -23,4 +23,10 @@ Pooling
.. autoclass:: AdaptiveAvgPool3D
.. autoclass:: AdaptiveMaxPool1D
.. autoclass:: AdaptiveMaxPool2D
.. autoclass:: AdaptiveMaxPool3D
.. autoclass:: AdaptiveMaxPool3D

.. autofunction:: adaptive_avg_pool_nd
.. autofunction:: adaptive_max_pool_nd
.. autofunction:: avg_pool_nd
.. autofunction:: lp_pool_nd
.. autofunction:: max_pool_nd
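
The autofunction entries above expose the functional counterparts of the pooling layers. As a rough sketch of what a functional pooling form does (plain JAX, channel-first layout assumed; the helper name and signature below are illustrative, not serket's actual API), N-D average pooling can be written with `jax.lax.reduce_window`:

```python
# Rough sketch of a functional average pool in plain JAX (not serket's API);
# assumes a channel-first (C, *spatial) layout and no padding.
import jax
import jax.numpy as jnp


def avg_pool_nd_sketch(x, kernel_size, strides):
    window = (1, *kernel_size)  # leave the channel axis un-pooled
    steps = (1, *strides)
    summed = jax.lax.reduce_window(x, 0.0, jax.lax.add, window, steps, padding="VALID")
    return summed / jnp.prod(jnp.asarray(kernel_size))


x = jnp.arange(16.0).reshape(1, 4, 4)               # one channel, 4x4 grid
print(avg_pool_nd_sketch(x, (2, 2), (2, 2)).shape)  # -> (1, 2, 2)
```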
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -72,7 +72,7 @@
autodoc_default_options = {
"member-order": "bysource",
"special-members": "__call__",
"exclude-members": "__repr__, __str__, __weakref__, at, spatial_ndim, conv_op, filter_op",
"exclude-members": "__repr__, __str__, __weakref__, at, spatial_ndim, conv_op, filter_op, attention_op",
"inherited-members": True,
}

132 changes: 68 additions & 64 deletions docs/notebooks/misc_recipes.ipynb
@@ -75,36 +75,33 @@
"import jax.random as jr\n",
"\n",
"\n",
"@sk.autoinit\n",
"class LazyLinear(sk.TreeClass):\n",
" out_features: int\n",
" def __init__(self, out_features: int):\n",
" self.out_features = out_features\n",
"\n",
" def param(self, name: str, value: Any):\n",
" # return the value if it exists, otherwise set it and return it\n",
" if name not in vars(self):\n",
" setattr(self, name, value)\n",
" return vars(self)[name]\n",
"\n",
" def __call__(self, x: jax.Array, *, key: jax.Array = jr.PRNGKey(0)):\n",
" weight = self.param(\"weight\", jnp.ones((x.shape[-1], self.out_features)))\n",
" def __call__(self, input: jax.Array) -> jax.Array:\n",
" weight = self.param(\"weight\", jnp.ones((self.out_features, input.shape[-1])))\n",
" bias = self.param(\"bias\", jnp.zeros((self.out_features,)))\n",
" return x @ weight + bias\n",
"\n",
" return input @ weight.T + bias\n",
"\n",
"x = jnp.ones([10, 1])\n",
"\n",
"lazy_linear = LazyLinear(out_features=1)\n",
"input = jnp.ones([10, 1])\n",
"\n",
"lazy_linear\n",
"print(f\"Layer before param is set:\\t{lazy_linear}\")\n",
"lazy = LazyLinear(out_features=1)\n",
"\n",
"print(f\"Layer before param is set:\\t{lazy}\")\n",
"\n",
"# first call will set the parameters\n",
"_, linear = lazy_linear.at[\"__call__\"](x, key=jr.PRNGKey(0))\n",
"_, material = lazy.at[\"__call__\"](input)\n",
"\n",
"print(f\"Layer after param is set:\\t{linear}\")\n",
"# subsequent calls will use the same parameters and not set them again\n",
"linear(x)"
"print(f\"Layer after param is set:\\t{material}\")\n",
"# subsequent calls will not set the parameters again\n",
"material(input)"
]
},
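
The heart of this recipe is the `param` helper: it creates a parameter the first time it is requested and returns the stored value on later calls, so parameter shapes are inferred from the first input. Because setting attributes mutates state, the first call goes through `.at["__call__"]`, which returns the output together with a new, materialized instance instead of mutating in place. A stripped-down sketch of the same pattern without serket (names are hypothetical):

```python
# Stripped-down sketch of the lazy-initialization idea, independent of serket.
import jax.numpy as jnp


class LazyParams:
    def __init__(self, out_features: int):
        self.out_features = out_features
        self.store = {}  # parameters materialized on first use

    def param(self, name, value):
        # create the parameter the first time it is requested, reuse it afterwards
        return self.store.setdefault(name, value)

    def __call__(self, x):
        weight = self.param("weight", jnp.ones((self.out_features, x.shape[-1])))
        bias = self.param("bias", jnp.zeros((self.out_features,)))
        return x @ weight.T + bias


layer = LazyParams(out_features=1)
print(layer(jnp.ones((10, 3))).shape)  # (10, 1); weight of shape (1, 3) inferred from the input
```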
{
@@ -432,35 +429,12 @@
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TiedAutoEncoder(\n",
" encoder=Linear(\n",
" in_features=(#1), \n",
" out_features=#10, \n",
" in_axis=(#-1), \n",
" out_axis=#-1, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=f32[10,1](μ=-0.78, σ=1.11, ∈[-2.58,0.00]), \n",
" bias=f32[10](μ=-0.39, σ=0.55, ∈[-1.29,0.00])\n",
" ), \n",
" decoder=Linear(\n",
" in_features=(#10), \n",
" out_features=#1, \n",
" in_axis=(#-1), \n",
" out_axis=#-1, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=None, \n",
" bias=f32[1](μ=-2.40, σ=0.00, ∈[-2.40,-2.40])\n",
" )\n",
")"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"10\n",
"200\n"
]
}
],
"source": [
@@ -469,31 +443,40 @@
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"\n",
"\n",
"class TiedAutoEncoder(sk.TreeClass):\n",
" def __init__(self, *, key: jax.Array):\n",
" k1, k2 = jr.split(key)\n",
" self.encoder = sk.nn.Linear(1, 10, key=k1)\n",
" # set the unused weight of decoder to `None` to avoid memory usage\n",
" self.decoder = sk.nn.Linear(10, 1, key=k2).at[\"weight\"].set(None)\n",
" k1, k2, k3, k4 = jr.split(key, 4)\n",
" self.enc1 = sk.nn.Linear(1, 10, key=k1)\n",
" self.enc2 = sk.nn.Linear(10, 20, key=k2)\n",
" self.dec2 = sk.nn.Linear(20, 10, key=k3)\n",
" self.dec1 = sk.nn.Linear(10, 1, key=k4)\n",
"\n",
" def _call(self, x):\n",
" def _tied_call(self, input: jax.Array) -> jax.Array:\n",
" # share/tie weights of encoder and decoder\n",
" # however this operation mutates the state\n",
" # so this method will only work with .at\n",
" # otherwise will throw `AttributeError`\n",
" self.decoder.weight = self.encoder.weight.T\n",
" out = self.decoder(jax.nn.relu(self.encoder(x)))\n",
" return out\n",
"\n",
" def __call__(self, x):\n",
" # make the mutating method `_call` work with .at\n",
" # since .at returns a tuple of the method value and a new instance\n",
" # of the class that has the mutated state (i.e. does not mutate in place)\n",
" # then we can define __call__ to return only the result of the method\n",
" # and ignore the new instance of the class\n",
" out, _ = self.at[\"_call\"](x)\n",
" return out\n",
" self.dec1.weight = self.enc1.weight.T\n",
" self.dec2.weight = self.enc2.weight.T\n",
" output = self.enc1(input)\n",
" output = self.enc2(output)\n",
" output = self.dec2(output)\n",
" output = self.dec1(output)\n",
" return output\n",
"\n",
" def tied_call(self, input: jax.Array) -> jax.Array:\n",
" # this method call `_tied_call` with .at to\n",
" # return the output without mutating the state\n",
" output, _ = self.at[\"_tied_call\"](input)\n",
" return output\n",
"\n",
" def non_tied_call(self, x):\n",
" # non-tied call\n",
" output = self.enc1(x)\n",
" output = self.enc2(output)\n",
" output = self.dec2(output)\n",
" output = self.dec1(output)\n",
" return output\n",
"\n",
"\n",
"tree = sk.tree_mask(TiedAutoEncoder(key=jr.PRNGKey(0)))\n",
Expand All @@ -503,15 +486,36 @@
"@jax.grad\n",
"def loss_func(net, x, y):\n",
" net = sk.tree_unmask(net)\n",
" return jnp.mean((jax.vmap(net)(x) - y) ** 2)\n",
" return jnp.mean((jax.vmap(net.tied_call)(x) - y) ** 2)\n",
"\n",
"\n",
"tree = sk.tree_mask(tree)\n",
"x = jnp.ones([10, 1]) + 0.0\n",
"y = jnp.ones([10, 1]) * 2.0\n",
"grads: TiedAutoEncoder = loss_func(tree, x, y)\n",
"\n",
"\n",
"# check that gradients are zero for tied weights (dec1.weight, dec2.weight)\n",
"assert jnp.count_nonzero(grads.dec1.weight) == 0\n",
"assert jnp.count_nonzero(grads.dec2.weight) == 0\n",
"\n",
"\n",
"# check for non-tied call\n",
"@jax.jit\n",
"@jax.grad\n",
"def loss_func(net, x, y):\n",
" net = sk.tree_unmask(net)\n",
" return jnp.mean((jax.vmap(net.non_tied_call)(x) - y) ** 2)\n",
"\n",
"\n",
"tree = sk.tree_mask(tree)\n",
"x = jnp.ones([10, 1]) + 0.0\n",
"y = jnp.ones([10, 1]) * 2.0\n",
"grads: TiedAutoEncoder = loss_func(tree, x, y)\n",
"\n",
"grads"
"# check non-zero gradients for the decoder weights\n",
"print(jnp.count_nonzero(grads.dec1.weight))\n",
"print(jnp.count_nonzero(grads.dec2.weight))"
]
}
],
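
The assertions above hold because `_tied_call` overwrites the decoder weights with the transposed encoder weights before use, so the stored decoder weights never influence the output and all of the gradient lands on the encoder side. A minimal sketch of the same effect in plain JAX (hypothetical names, not serket code):

```python
# Minimal weight-tying sketch in plain JAX; names are hypothetical, not serket code.
import jax
import jax.numpy as jnp


def tied_forward(params, x):
    # the stored "dec" weight is ignored: the decoder reuses enc.T,
    # so d(loss)/d(params["dec"]) is identically zero
    return (x @ params["enc"]) @ params["enc"].T


params = {"enc": jnp.ones((1, 10)), "dec": jnp.zeros((10, 1))}
x = jnp.ones((4, 1))
grads = jax.grad(lambda p: jnp.mean(tied_forward(p, x) ** 2))(params)
print(jnp.count_nonzero(grads["dec"]))  # 0: the tied copy carries no gradient
print(jnp.count_nonzero(grads["enc"]))  # non-zero: the gradient flows to the shared weight
```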
6 changes: 3 additions & 3 deletions serket/_src/nn/linear.py
@@ -140,9 +140,9 @@ class Linear(sk.TreeClass):
>>> import jax.random as jr
>>> key = jr.PRNGKey(0)
>>> input = jnp.ones((10, 5, 4))
>>> lazy_linear = sk.nn.Linear(None, 12, in_axis=(0, 2), key=key)
>>> _, material_linear = lazy_linear.at["__call__"](input)
>>> material_linear.in_features
>>> lazy = sk.nn.Linear(None, 12, in_axis=(0, 2), key=key)
>>> _, material = lazy.at["__call__"](input)
>>> material.in_features
(10, 4)
"""

(diffs for the remaining 3 changed files are not shown)
