Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Sep 8, 2023
1 parent 8422fa1 commit baecd10
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 12 deletions.
13 changes: 5 additions & 8 deletions docs/notebooks/regularization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -43,10 +43,8 @@
"\n",
"def negative_entries_l2_loss(net: Net):\n",
" return (\n",
" # select all first level branches e.g `Net.weight` and `Net.bias`\n",
" net.at[...]\n",
" # select all positive array entries\n",
" .at[jax.tree_map(lambda x: x > 0, net)]\n",
" net.at[jax.tree_map(lambda x: x > 0, net)]\n",
" # set them to zero to exclude their loss\n",
" .set(0)\n",
" # select all branches at first level e.g `Net.weight` and `Net.bias`\n",
Expand All @@ -69,14 +67,14 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"131.01498\n"
"86.17129\n"
]
}
],
Expand Down Expand Up @@ -107,9 +105,8 @@
" return (\n",
" # select desired branches (linear1, linear2 in this example)\n",
" # we can also select branches by a valid regex expression\n",
" net.at[re.compile(\"linear[1,2]\")]\n",
" # select weight leaf\n",
" .at[\"weight\"]\n",
" net.at[re.compile(\"linear[1,2]\")][\"weight\"]\n",
" # apply l1 loss\n",
" .reduce(lambda x, y: x + jnp.sum(jnp.abs(y)), initializer=0)\n",
" )\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/train_bilstm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.485198e-03\tTime: 0.021\r"
"Epoch: 100/100\tBatch: 100/100\tBatch loss: 1.485198e-03\tTime: 0.018\r"
]
},
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x1660b5480>"
"<matplotlib.legend.Legend at 0x16d857220>"
]
},
"execution_count": 5,
Expand Down
7 changes: 7 additions & 0 deletions serket/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ def spatial_ndim(self) -> int:
class RandomContrast2D(sk.TreeClass):
"""Randomly adjusts the contrast of an 1D input by scaling the pixel values by a factor.
Args:
contrast_range: contrast range to adust the contrast by. Defaults to (0.5, 1).
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
Reference:
- https://www.tensorflow.org/api_docs/python/tf/image/adjust_contrast
- https://github.com/deepmind/dm_pix/blob/master/dm_pix/_src/augment.py
Expand Down
94 changes: 92 additions & 2 deletions serket/image/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,21 @@ class RandomRotate2D(sk.TreeClass):
Args:
angle_range: a tuple of min angle and max angle to randdomly choose from.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomRotate2D((10, 30))
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax
Expand Down Expand Up @@ -205,12 +220,27 @@ class RandomHorizontalShear2D(sk.TreeClass):
Args:
angle_range: a tuple of min angle and max angle to randdomly choose from.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomHorizontalShear2D((45, 45))
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 26).reshape(1, 5, 5)
>>> print(sk.image.RandomHorizontalShear2D((45,45))(x))
>>> print(sk.image.RandomHorizontalShear2D((45, 45))(x))
[[[ 0 0 1 2 3]
[ 0 6 7 8 9]
[11 12 13 14 15]
Expand Down Expand Up @@ -279,12 +309,27 @@ class RandomVerticalShear2D(sk.TreeClass):
Args:
angle_range: a tuple of min angle and max angle to randdomly choose from.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomVerticalShear2D((45, 45))
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 26).reshape(1, 5, 5)
>>> print(sk.image.RandomVerticalShear2D((45,45))(x))
>>> print(sk.image.RandomVerticalShear2D((45, 45))(x))
[[[ 0 0 3 9 15]
[ 0 2 8 14 20]
[ 1 7 13 19 25]
Expand Down Expand Up @@ -405,6 +450,21 @@ class RandomPerspective2D(sk.TreeClass):
lead to higher degree of perspective transform. default to 1.0. 0.0
means no perspective transform.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomPerspective2D(100)
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
Expand Down Expand Up @@ -640,6 +700,21 @@ def spatial_ndim(self) -> int:
class RandomHorizontalTranslate2D(sk.TreeClass):
"""Translate an image horizontally by a random pixel value.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomHorizontalTranslate2D()
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
Expand Down Expand Up @@ -668,6 +743,21 @@ def spatial_ndim(self) -> int:
class RandomVerticalTranslate2D(sk.TreeClass):
"""Translate an image vertically by a random pixel value.
Note:
- Use :func:`tree_eval` to replace this layer with :class:`Identity` during
evaluation.
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 17).reshape(1, 4, 4)
>>> layer = sk.image.RandomVerticalTranslate2D()
>>> eval_layer = sk.tree_eval(layer)
>>> print(eval_layer(x))
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
Expand Down

0 comments on commit baecd10

Please sign in to comment.