From 4b9188fa04523e4c1a1be002eab9f4fd35378a28 Mon Sep 17 00:00:00 2001 From: ASEM000 Date: Tue, 8 Aug 2023 19:18:09 +0300 Subject: [PATCH] add `Solarize2D` --- CHANEGLOG.md | 11 +++++++++ serket/nn/__init__.py | 2 ++ serket/nn/image.py | 47 ++++++++++++++++++++++++++++++++++++++ tests/test_image_filter.py | 21 ++++++++++++++++- 4 files changed, 80 insertions(+), 1 deletion(-) diff --git a/CHANEGLOG.md b/CHANEGLOG.md index d1decf6..e07c7d7 100644 --- a/CHANEGLOG.md +++ b/CHANEGLOG.md @@ -37,6 +37,17 @@ return SimpleRNNState(jnp.zeros([cell.hidden_features])) ``` +- `MultiHeadAttention` +- `BatchNorm` +- `RandomHorizontalShear2D` +- `RandomPerspective2D` +- `RandomRotate2D` +- `RandomVerticalShear2D` +- `Rotate2D` +- `VerticalShear2D` +- `Pixelate2D` +- `Solarize2D` + ### Deprecations - `Bilinear` is deprecated, use `Multilinear((in1_features, in2_features), out_features)` diff --git a/serket/nn/__init__.py b/serket/nn/__init__.py index 8743658..695d578 100644 --- a/serket/nn/__init__.py +++ b/serket/nn/__init__.py @@ -102,6 +102,7 @@ RandomRotate2D, RandomVerticalShear2D, Rotate2D, + Solarize2D, VerticalShear2D, ) from .linear import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear @@ -279,6 +280,7 @@ "RandomRotate2D", "RandomVerticalShear2D", "Rotate2D", + "Solarize2D", "VerticalShear2D", # pooling "AdaptiveAvgPool1D", diff --git a/serket/nn/image.py b/serket/nn/image.py index 925e2ec..30165fb 100644 --- a/serket/nn/image.py +++ b/serket/nn/image.py @@ -908,6 +908,53 @@ def spatial_ndim(self) -> int: return 2 +def solarize( + image: jax.Array, + threshold: float | int, + max_val: float | int, +) -> jax.Array: + """Inverts all values at or above a given threshold.""" + return jnp.where(image < threshold, image, max_val - image) + + +@sk.autoinit +class Solarize2D(sk.TreeClass): + """Inverts all values at or above a given threshold. + + Args: + threshold: The threshold value at or above which values are inverted. + max_val: The maximum value of the image. e.g. 
255 for uint8 images. + 1.0 for float images. default: 1.0 + + Example: + >>> import serket as sk + >>> import jax.numpy as jnp + >>> x = jnp.arange(1, 26).reshape(1, 5, 5) + >>> layer = sk.nn.Solarize2D(threshold=10, max_val=25) + >>> print(layer(x)) + [[[ 1 2 3 4 5] + [ 6 7 8 9 15] + [14 13 12 11 10] + [ 9 8 7 6 5] + [ 4 3 2 1 0]]] + + Reference: + - https://github.com/tensorflow/models/blob/v2.13.1/official/vision/ops/augment.py#L804-L809 + """ + + threshold: float + max_val: float = 1.0 + + @ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim") + def __call__(self, x: jax.Array, **k) -> jax.Array: + threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val)) + return solarize(x, threshold, max_val) + + @property + def spatial_ndim(self) -> int: + return 2 + + @tree_eval.def_eval(RandomContrast2D) @tree_eval.def_eval(RandomRotate2D) @tree_eval.def_eval(RandomHorizontalShear2D) diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py index ba445e4..b5fa449 100644 --- a/tests/test_image_filter.py +++ b/tests/test_image_filter.py @@ -18,7 +18,7 @@ import numpy.testing as npt import pytest -from serket.nn import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D +from serket.nn import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D, Solarize2D def test_AvgBlur2D(): @@ -99,3 +99,22 @@ def test_filter2d(): layer2 = FFTFilter2D(in_features=1, kernel=jnp.ones([3, 3]) / 9.0) npt.assert_allclose(layer(x), layer2(x), atol=1e-4) + + +def test_solarize2d(): + x = jnp.arange(1, 26).reshape(1, 5, 5) + layer = Solarize2D(threshold=10, max_val=25) + npt.assert_allclose( + layer(x), + jnp.array( + [ + [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 15], + [14, 13, 12, 11, 10], + [9, 8, 7, 6, 5], + [4, 3, 2, 1, 0], + ] + ] + ), + )