From f4300f1d03697da2c63a70a1a17a9592b0ccd452 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Thu, 7 Dec 2023 18:21:02 +0900 Subject: [PATCH] custom convolution --- docs/notebooks/custom_convolutions.ipynb | 318 +++++++++++++++++++++++ docs/other_guides.rst | 3 +- 2 files changed, 320 insertions(+), 1 deletion(-) create mode 100644 docs/notebooks/custom_convolutions.ipynb diff --git a/docs/notebooks/custom_convolutions.ipynb b/docs/notebooks/custom_convolutions.ipynb new file mode 100644 index 0000000..ee7eba2 --- /dev/null +++ b/docs/notebooks/custom_convolutions.ipynb @@ -0,0 +1,318 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 🎛 Custom convolutions\n", + "\n", + "In this notebook, overriding the convolution layers operation is demonstrated using [kernex](https://github.com/ASEM000/kernex/blob/main/README.md). By defining only the kernel operation, the layer can be used in the same way as the original layer and parameter initialization/shape checking is handled automatically." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Direct convolution\n", + "\n", + "This example demonstrates how to recreate the convolution operation using the `kernex` library. `kernex` offers function transformation similar to `jax.vmap`, that wraps a kernel operation (e.g. `lambda input,kernel: sum(input*kernel)`) and returns a function that works on array views." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install git+https://github.com/ASEM000/serket --quiet\n", + "!pip install kernex --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import kernex as kex # for stencil operations like convolutions\n", + "import serket as sk\n", + "import jax\n", + "import jax.random as jr\n", + "import jax.numpy as jnp\n", + "import numpy.testing as npt\n", + "\n", + "\n", + "def my_conv(\n", + " input: jax.Array,\n", + " weight: jax.Array,\n", + " bias: jax.Array | None,\n", + " strides: tuple[int, ...],\n", + " padding: tuple[tuple[int, int], ...],\n", + " dilation: tuple[int, ...],\n", + " groups: int,\n", + " mask: jax.Array | None,\n", + "):\n", + " # same function signature as serket.nn.conv_nd\n", + " del mask #\n", + " del dilation # for simplicity\n", + " del groups # for simplicity\n", + " _, in_features, *kernel_size = weight.shape\n", + "\n", + " @kex.kmap(\n", + " kernel_size=(in_features, *kernel_size),\n", + " strides=(1, *strides),\n", + " padding=((0, 0), *padding),\n", + " )\n", + " def conv_func(input, weight):\n", + " # define the kernel operation\n", + " return jnp.sum(input * weight)\n", + "\n", + " # vectorize over the out_features of the weight\n", + " out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n", + " # squeeze out the vmapped axis\n", + " out = jnp.squeeze(out, axis=1)\n", + " return out + bias if bias is not None else out\n", + "\n", + "\n", + "class CustomConv2D(sk.nn.Conv2D):\n", + " # override the conv_op\n", + " conv_op = my_conv\n", + "\n", + "\n", + "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", + "\n", + "basic_conv = sk.nn.Conv2D(\n", + " in_features=1,\n", + " out_features=2,\n", + " kernel_size=3,\n", + " bias_init=None,\n", + " key=k1,\n", + ")\n", + "\n", + "custom_conv = CustomConv2D(\n", + " in_features=1,\n", + " out_features=2,\n", + " kernel_size=3,\n", + " bias_init=None,\n", + " key=k1,\n", + ")\n", + "\n", + "# channel-first input\n", + "input = jr.uniform(k2, shape=(1, 10, 10))\n", + "\n", + "npt.assert_allclose(\n", + " basic_conv(input),\n", + " custom_conv(input),\n", + " atol=1e-6,\n", + ")\n", + "# lets check gradients\n", + "npt.assert_allclose(\n", + " jax.grad(lambda x: basic_conv(x).sum())(input),\n", + " jax.grad(lambda x: custom_conv(x).sum())(input),\n", + " atol=1e-6,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Depthwise convolution\n", + "\n", + "Similar to the above example, For recreating depthwise convolution, the only addition is to add vectorize the kernel operation over the channels dimension using `jax.vmap`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import kernex as kex # for stencil operations like convolutions\n", + "import jax\n", + "import jax.random as jr\n", + "import jax.numpy as jnp\n", + "import numpy.testing as npt\n", + "\n", + "\n", + "def my_depthwise_conv(\n", + " input: jax.Array,\n", + " weight: jax.Array,\n", + " bias: jax.Array | None,\n", + " strides: tuple[int, ...],\n", + " padding: tuple[tuple[int, int], ...],\n", + " mask: jax.Array | None,\n", + "):\n", + " # same function signature as serket.nn.depthwise_conv_nd\n", + " del mask #\n", + " _, _, *kernel_size = weight.shape\n", + "\n", + " @jax.vmap # <- vectorize over the input channels\n", + " @kex.kmap(\n", + " kernel_size=tuple(kernel_size),\n", + " strides=strides,\n", + " padding=padding,\n", + " )\n", + " def conv_func(input, weight):\n", + " # define the kernel operation\n", + " return jnp.sum(input * weight)\n", + "\n", + " # vectorize over the output channels (filters)\n", + " out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n", + " out = jnp.squeeze(out, axis=1) # squeeze out the vmapped axis\n", + " return out + bias if bias is not None else out\n", + "\n", + "\n", + "class CustomDepthwiseConv2D(sk.nn.DepthwiseConv2D):\n", + " # override the conv_op\n", + " conv_op = my_depthwise_conv\n", + "\n", + "\n", + "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", + "\n", + "basic_conv = sk.nn.DepthwiseConv2D(\n", + " in_features=1,\n", + " depth_multiplier=2,\n", + " kernel_size=3,\n", + " bias_init=None,\n", + " key=k1,\n", + ")\n", + "\n", + "custom_conv = CustomDepthwiseConv2D(\n", + " in_features=1,\n", + " depth_multiplier=2,\n", + " kernel_size=3,\n", + " bias_init=None,\n", + " key=k1,\n", + ")\n", + "\n", + "# channel-first input\n", + "input = jr.uniform(k2, shape=(1, 10, 10))\n", + "\n", + "npt.assert_allclose(\n", + " basic_conv(input),\n", + " custom_conv(input),\n", + " atol=1e-6,\n", + ")\n", + "# lets check gradients\n", + "npt.assert_allclose(\n", + " jax.grad(lambda x: basic_conv(x).sum())(input),\n", + " jax.grad(lambda x: custom_conv(x).sum())(input),\n", + " atol=1e-6,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Positive kernel convolution\n", + "\n", + "In this example, a custom convolution operation is defined. As a toy examaple the operation will only multiply weight values\n", + "that are not zero." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 10, 10)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import kernex as kex # for stencil operations like convolutions\n", + "import serket as sk\n", + "import jax\n", + "import jax.random as jr\n", + "import jax.numpy as jnp\n", + "import numpy.testing as npt\n", + "\n", + "\n", + "def my_custom_conv(\n", + " input: jax.Array,\n", + " weight: jax.Array,\n", + " bias: jax.Array | None,\n", + " strides: tuple[int, ...],\n", + " padding: tuple[tuple[int, int], ...],\n", + " dilation: tuple[int, ...],\n", + " groups: int,\n", + " mask: jax.Array | None,\n", + "):\n", + " # same function signature as serket.nn.conv_nd\n", + " del mask #\n", + " del dilation # for simplicity\n", + " del groups # for simplicity\n", + " _, in_features, *kernel_size = weight.shape\n", + "\n", + " @kex.kmap(\n", + " kernel_size=(in_features, *kernel_size),\n", + " strides=(1, *strides),\n", + " padding=((0, 0), *padding),\n", + " )\n", + " def conv_func(input, weight):\n", + " # define a custom kernel operation\n", + " # that only multiplies the input with the weight\n", + " # if the weight is positive\n", + " return jnp.sum(input * jnp.where(weight < 0, 0, weight))\n", + "\n", + " # vectorize over the out_features of the weight\n", + " out = jax.vmap(conv_func, in_axes=(None, 0))(input, weight)\n", + " # squeeze out the vmapped axis\n", + " out = jnp.squeeze(out, axis=1)\n", + " return out + bias if bias is not None else out\n", + "\n", + "\n", + "class CustomConv2D(sk.nn.Conv2D):\n", + " # override the conv_op\n", + " conv_op = my_custom_conv\n", + "\n", + "\n", + "k1, k2 = jr.split(jr.PRNGKey(0), 2)\n", + "\n", + "\n", + "custom_conv = CustomConv2D(\n", + " in_features=1,\n", + " out_features=2,\n", + " kernel_size=3,\n", + " bias_init=None,\n", + " key=k1,\n", + ")\n", + "\n", + "# channel-first input\n", + "input = jr.uniform(k2, shape=(1, 10, 10))\n", + "\n", + "basic_conv(input).shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py311", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/other_guides.rst b/docs/other_guides.rst index a547a71..91d2ac4 100644 --- a/docs/other_guides.rst +++ b/docs/other_guides.rst @@ -10,4 +10,5 @@ notebooks/hyperparam notebooks/loss_landscape notebooks/augmentations - notebooks/deep_ensembles \ No newline at end of file + notebooks/deep_ensembles + notebooks/custom_convolutions \ No newline at end of file