From baecd10a196e8cd2cd49133b5b39b3d946c485cc Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Fri, 8 Sep 2023 19:02:46 +0900 Subject: [PATCH] docs --- docs/notebooks/regularization.ipynb | 13 ++-- docs/notebooks/train_bilstm.ipynb | 4 +- serket/image/augment.py | 7 +++ serket/image/geometric.py | 94 ++++++++++++++++++++++++++++- 4 files changed, 106 insertions(+), 12 deletions(-) diff --git a/docs/notebooks/regularization.ipynb b/docs/notebooks/regularization.ipynb index ed99ba8..952598b 100644 --- a/docs/notebooks/regularization.ipynb +++ b/docs/notebooks/regularization.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -43,10 +43,8 @@ "\n", "def negative_entries_l2_loss(net: Net):\n", " return (\n", - " # select all first level branches e.g `Net.weight` and `Net.bias`\n", - " net.at[...]\n", " # select all positive array entries\n", - " .at[jax.tree_map(lambda x: x > 0, net)]\n", + " net.at[jax.tree_map(lambda x: x > 0, net)]\n", " # set them to zero to exclude their loss\n", " .set(0)\n", " # select all branches at first level e.g `Net.weight` and `Net.bias`\n", @@ -69,14 +67,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "131.01498\n" + "86.17129\n" ] } ], @@ -107,9 +105,8 @@ " return (\n", " # select desired branches (linear1, linear2 in this example)\n", " # we can also select branches by a valid regex expression\n", - " net.at[re.compile(\"linear[1,2]\")]\n", " # select weight leaf\n", - " .at[\"weight\"]\n", + " net.at[re.compile(\"linear[1,2]\")][\"weight\"]\n", " # apply l1 loss\n", " .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)\n", " )\n", diff --git a/docs/notebooks/train_bilstm.ipynb b/docs/notebooks/train_bilstm.ipynb index e86e8be..569d128 100644 --- a/docs/notebooks/train_bilstm.ipynb +++ b/docs/notebooks/train_bilstm.ipynb @@ -156,13 +156,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.485198e-03\tTime: 0.021\r" + "Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.485198e-03\tTime: 0.018\r" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, diff --git a/serket/image/augment.py b/serket/image/augment.py index d02c897..47b30ca 100644 --- a/serket/image/augment.py +++ b/serket/image/augment.py @@ -130,6 +130,13 @@ def spatial_ndim(self) -> int: class RandomContrast2D(sk.TreeClass): """Randomly adjusts the contrast of an 1D input by scaling the pixel values by a factor. + Args: + contrast_range: contrast range to adust the contrast by. Defaults to (0.5, 1). + + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + Reference: - https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast - https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py diff --git a/serket/image/geometric.py b/serket/image/geometric.py index c7bf5ce..cfae1bb 100644 --- a/serket/image/geometric.py +++ b/serket/image/geometric.py @@ -130,6 +130,21 @@ class RandomRotate2D(sk.TreeClass): Args: angle_range: a tuple of min angle and max angle to randdomly choose from. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomRotate2D((10, 30)) + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax @@ -205,12 +220,27 @@ class RandomHorizontalShear2D(sk.TreeClass): Args: angle_range: a tuple of min angle and max angle to randdomly choose from. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomHorizontalShear2D((45, 45)) + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax >>> import jax.numpy as jnp >>> x = jnp.arange(1, 26).reshape(1, 5, 5) - >>> print(sk.image.RandomHorizontalShear2D((45,45))(x)) + >>> print(sk.image.RandomHorizontalShear2D((45, 45))(x)) [[[ 0 0 1 2 3] [ 0 6 7 8 9] [11 12 13 14 15] @@ -279,12 +309,27 @@ class RandomVerticalShear2D(sk.TreeClass): Args: angle_range: a tuple of min angle and max angle to randdomly choose from. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomVerticalShear2D((45, 45)) + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax >>> import jax.numpy as jnp >>> x = jnp.arange(1, 26).reshape(1, 5, 5) - >>> print(sk.image.RandomVerticalShear2D((45,45))(x)) + >>> print(sk.image.RandomVerticalShear2D((45, 45))(x)) [[[ 0 0 3 9 15] [ 0 2 8 14 20] [ 1 7 13 19 25] @@ -405,6 +450,21 @@ class RandomPerspective2D(sk.TreeClass): lead to higher degree of perspective transform. default to 1.0. 0.0 means no perspective transform. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomPerspective2D(100) + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax.numpy as jnp @@ -640,6 +700,21 @@ def spatial_ndim(self) -> int: class RandomHorizontalTranslate2D(sk.TreeClass): """Translate an image horizontally by a random pixel value. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomHorizontalTranslate2D() + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax.numpy as jnp @@ -668,6 +743,21 @@ def spatial_ndim(self) -> int: class RandomVerticalTranslate2D(sk.TreeClass): """Translate an image vertically by a random pixel value. + Note: + - Use :func:`tree_eval` to replace this layer with :class:`Identity` during + evaluation. + + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 17).reshape(1, 4, 4) + >>> layer = sk.image.RandomVerticalTranslate2D() + >>> eval_layer = sk.tree_eval(layer) + >>> print(eval_layer(x)) + [[[ 1 2 3 4] + [ 5 6 7 8] + [ 9 10 11 12] + [13 14 15 16]]] + Example: >>> import serket as sk >>> import jax.numpy as jnp