custom convolution
ASEM000 committed Dec 7, 2023
1 parent 6c7458c commit f4300f1
Showing 2 changed files with 320 additions and 1 deletion.
318 changes: 318 additions & 0 deletions docs/notebooks/custom_convolutions.ipynb
@@ -0,0 +1,318 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🎛 Custom convolutions\n",
"\n",
"In this notebook, overriding the convolution layers operation is demonstrated using [kernex](https://github.com/ASEM000/kernex/blob/main/README.md). By defining only the kernel operation, the layer can be used in the same way as the original layer and parameter initialization/shape checking is handled automatically."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Direct convolution\n",
"\n",
"This example demonstrates how to recreate the convolution operation using the `kernex` library. `kernex` offers function transformation similar to `jax.vmap`, that wraps a kernel operation (e.g. `lambda input,kernel: sum(input*kernel)`) and returns a function that works on array views."
]
},
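{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the full layer, the next cell is a small standalone sketch (added for illustration; it is not part of the layer code) of what `kex.kmap` does: the wrapped function receives sliding views of its first argument. The window count shown assumes unit strides and no padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import kernex as kex\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"# minimal sketch: a moving sum over windows of size 3,\n",
"# using the same kernel_size/strides/padding keyword style as the cell below\n",
"@kex.kmap(kernel_size=(3,), strides=(1,), padding=((0, 0),))\n",
"def moving_sum(view):\n",
"    # `view` is a length-3 window of the input array\n",
"    return jnp.sum(view)\n",
"\n",
"\n",
"print(moving_sum(jnp.arange(10.0)).shape)  # expected: (8,) -> one scalar per window"
]
},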
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/ASEM000/serket --quiet\n",
"!pip install kernex --quiet"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import kernex as kex # for stencil operations like convolutions\n",
"import serket as sk\n",
"import jax\n",
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"import numpy.testing as npt\n",
"\n",
"\n",
"def my_conv(\n",
" input: jax.Array,\n",
" weight: jax.Array,\n",
" bias: jax.Array | None,\n",
" strides: tuple[int, ...],\n",
" padding: tuple[tuple[int, int], ...],\n",
" dilation: tuple[int, ...],\n",
" groups: int,\n",
" mask: jax.Array | None,\n",
"):\n",
" # same function signature as serket.nn.conv_nd\n",
" del mask #\n",
" del dilation # for simplicity\n",
" del groups # for simplicity\n",
" _, in_features, *kernel_size = weight.shape\n",
"\n",
" @kex.kmap(\n",
" kernel_size=(in_features, *kernel_size),\n",
" strides=(1, *strides),\n",
" padding=((0, 0), *padding),\n",
" )\n",
" def conv_func(input, weight):\n",
" # define the kernel operation\n",
" return jnp.sum(input * weight)\n",
"\n",
" # vectorize over the out_features of the weight\n",
" out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n",
" # squeeze out the vmapped axis\n",
" out = jnp.squeeze(out, axis=1)\n",
" return out + bias if bias is not None else out\n",
"\n",
"\n",
"class CustomConv2D(sk.nn.Conv2D):\n",
" # override the conv_op\n",
" conv_op = my_conv\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"\n",
"basic_conv = sk.nn.Conv2D(\n",
" in_features=1,\n",
" out_features=2,\n",
" kernel_size=3,\n",
" bias_init=None,\n",
" key=k1,\n",
")\n",
"\n",
"custom_conv = CustomConv2D(\n",
" in_features=1,\n",
" out_features=2,\n",
" kernel_size=3,\n",
" bias_init=None,\n",
" key=k1,\n",
")\n",
"\n",
"# channel-first input\n",
"input = jr.uniform(k2, shape=(1, 10, 10))\n",
"\n",
"npt.assert_allclose(\n",
" basic_conv(input),\n",
" custom_conv(input),\n",
" atol=1e-6,\n",
")\n",
"# lets check gradients\n",
"npt.assert_allclose(\n",
" jax.grad(lambda x: basic_conv(x).sum())(input),\n",
" jax.grad(lambda x: custom_conv(x).sum())(input),\n",
" atol=1e-6,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Depthwise convolution\n",
"\n",
"Similar to the above example, For recreating depthwise convolution, the only addition is to add vectorize the kernel operation over the channels dimension using `jax.vmap`"
]
},
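{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a small standalone sketch (for illustration only) of that extra step: stacking `jax.vmap` on top of a `kex.kmap`-wrapped kernel maps a per-channel windowed operation over the leading channel axis. The shapes assume unit strides and no padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import kernex as kex\n",
"\n",
"\n",
"@jax.vmap  # <- map the windowed operation over the leading (channel) axis\n",
"@kex.kmap(kernel_size=(3, 3), strides=(1, 1), padding=((0, 0), (0, 0)))\n",
"def per_channel_sum(view):\n",
"    # `view` is a 3x3 window of one channel\n",
"    return jnp.sum(view)\n",
"\n",
"\n",
"x = jnp.ones((4, 10, 10))  # (channels, height, width)\n",
"print(per_channel_sum(x).shape)  # expected: (4, 8, 8)"
]
},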
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import kernex as kex # for stencil operations like convolutions\n",
"import jax\n",
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"import numpy.testing as npt\n",
"\n",
"\n",
"def my_depthwise_conv(\n",
" input: jax.Array,\n",
" weight: jax.Array,\n",
" bias: jax.Array | None,\n",
" strides: tuple[int, ...],\n",
" padding: tuple[tuple[int, int], ...],\n",
" mask: jax.Array | None,\n",
"):\n",
" # same function signature as serket.nn.depthwise_conv_nd\n",
" del mask #\n",
" _, _, *kernel_size = weight.shape\n",
"\n",
" @jax.vmap # <- vectorize over the input channels\n",
" @kex.kmap(\n",
" kernel_size=tuple(kernel_size),\n",
" strides=strides,\n",
" padding=padding,\n",
" )\n",
" def conv_func(input, weight):\n",
" # define the kernel operation\n",
" return jnp.sum(input * weight)\n",
"\n",
" # vectorize over the output channels (filters)\n",
" out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n",
" out = jnp.squeeze(out, axis=1) # squeeze out the vmapped axis\n",
" return out + bias if bias is not None else out\n",
"\n",
"\n",
"class CustomDepthwiseConv2D(sk.nn.DepthwiseConv2D):\n",
" # override the conv_op\n",
" conv_op = my_depthwise_conv\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"\n",
"basic_conv = sk.nn.DepthwiseConv2D(\n",
" in_features=1,\n",
" depth_multiplier=2,\n",
" kernel_size=3,\n",
" bias_init=None,\n",
" key=k1,\n",
")\n",
"\n",
"custom_conv = CustomDepthwiseConv2D(\n",
" in_features=1,\n",
" depth_multiplier=2,\n",
" kernel_size=3,\n",
" bias_init=None,\n",
" key=k1,\n",
")\n",
"\n",
"# channel-first input\n",
"input = jr.uniform(k2, shape=(1, 10, 10))\n",
"\n",
"npt.assert_allclose(\n",
" basic_conv(input),\n",
" custom_conv(input),\n",
" atol=1e-6,\n",
")\n",
"# lets check gradients\n",
"npt.assert_allclose(\n",
" jax.grad(lambda x: basic_conv(x).sum())(input),\n",
" jax.grad(lambda x: custom_conv(x).sum())(input),\n",
" atol=1e-6,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Positive kernel convolution\n",
"\n",
"In this example, a custom convolution operation is defined. As a toy examaple the operation will only multiply weight values\n",
"that are not zero."
]
},
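{
"cell_type": "markdown",
"metadata": {},
"source": [
"The masking step itself is just `jnp.where`; the next cell is a tiny standalone sketch (for illustration only) of how negative weights are zeroed before the multiply-and-sum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"\n",
"view = jnp.ones((2, 2))\n",
"weight = jnp.array([[-1.0, 2.0], [3.0, -4.0]])\n",
"# negative weights are zeroed, so only the positive entries (2 and 3) contribute\n",
"masked = jnp.where(weight < 0, 0, weight)\n",
"print(jnp.sum(view * masked))  # 5.0"
]
},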
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2, 10, 10)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import kernex as kex # for stencil operations like convolutions\n",
"import serket as sk\n",
"import jax\n",
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"import numpy.testing as npt\n",
"\n",
"\n",
"def my_custom_conv(\n",
" input: jax.Array,\n",
" weight: jax.Array,\n",
" bias: jax.Array | None,\n",
" strides: tuple[int, ...],\n",
" padding: tuple[tuple[int, int], ...],\n",
" dilation: tuple[int, ...],\n",
" groups: int,\n",
" mask: jax.Array | None,\n",
"):\n",
" # same function signature as serket.nn.conv_nd\n",
" del mask #\n",
" del dilation # for simplicity\n",
" del groups # for simplicity\n",
" _, in_features, *kernel_size = weight.shape\n",
"\n",
" @kex.kmap(\n",
" kernel_size=(in_features, *kernel_size),\n",
" strides=(1, *strides),\n",
" padding=((0, 0), *padding),\n",
" )\n",
" def conv_func(input, weight):\n",
" # define a custom kernel operation\n",
" # that only multiplies the input with the weight\n",
" # if the weight is positive\n",
" return jnp.sum(input * jnp.where(weight < 0, 0, weight))\n",
"\n",
" # vectorize over the out_features of the weight\n",
" out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n",
" # squeeze out the vmapped axis\n",
" out = jnp.squeeze(out, axis=1)\n",
" return out + bias if bias is not None else out\n",
"\n",
"\n",
"class CustomConv2D(sk.nn.Conv2D):\n",
" # override the conv_op\n",
" conv_op = my_custom_conv\n",
"\n",
"\n",
"k1, k2 = jr.split(jr.PRNGKey(0), 2)\n",
"\n",
"\n",
"custom_conv = CustomConv2D(\n",
" in_features=1,\n",
" out_features=2,\n",
" kernel_size=3,\n",
" bias_init=None,\n",
" key=k1,\n",
")\n",
"\n",
"# channel-first input\n",
"input = jr.uniform(k2, shape=(1, 10, 10))\n",
"\n",
"basic_conv(input).shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 2 additions & 1 deletion docs/other_guides.rst
@@ -10,4 +10,5 @@
notebooks/hyperparam
notebooks/loss_landscape
notebooks/augmentations
notebooks/deep_ensembles
notebooks/deep_ensembles
notebooks/custom_convolutions
