Skip to content

Commit

Permalink
add Solarize2D
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Aug 8, 2023
1 parent 10b64c3 commit 4b9188f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 1 deletion.
11 changes: 11 additions & 0 deletions CHANEGLOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@
return SimpleRNNState(jnp.zeros([cell.hidden_features]))
```

- `MultiHeadAttention`
- `BatchNorm`
- `RandomHorizontalShear2D`
- `RandomPerspective2D`
- `RandomRotate2D`
- `RandomVerticalShear2D`
- `Rotate2D`
- `VerticalShear2D`
- `Pixelate2D`
- `Solarize2D`

### Deprecations

- `Bilinear` is deprecated, use `Multilinear((in1_features, in2_features), out_features)`
2 changes: 2 additions & 0 deletions serket/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
RandomRotate2D,
RandomVerticalShear2D,
Rotate2D,
Solarize2D,
VerticalShear2D,
)
from .linear import FNN, MLP, Embedding, GeneralLinear, Identity, Linear, Multilinear
Expand Down Expand Up @@ -279,6 +280,7 @@
"RandomRotate2D",
"RandomVerticalShear2D",
"Rotate2D",
"Solarize2D",
"VerticalShear2D",
# pooling
"AdaptiveAvgPool1D",
Expand Down
47 changes: 47 additions & 0 deletions serket/nn/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,53 @@ def spatial_ndim(self) -> int:
return 2


def solarize(
image: jax.Array,
threshold: float | int,
max_val: float | int,
) -> jax.Array:
"""Inverts all values above a given threshold."""
return jnp.where(image < threshold, image, max_val - image)


@sk.autoinit
class Solarize2D(sk.TreeClass):
"""Inverts all values above a given threshold.
Args:
threshold: The threshold value above which to invert.
max_val: The maximum value of the image. e.g. 255 for uint8 images.
1.0 for float images. default: 1.0
Example:
>>> import serket as sk
>>> import jax.numpy as jnp
>>> x = jnp.arange(1, 26).reshape(1, 5, 5)
>>> layer = sk.nn.Solarize2D(threshold=10, max_val=25)
>>> print(layer(x))
[[[ 1 2 3 4 5]
[ 6 7 8 9 15]
[14 13 12 11 10]
[ 9 8 7 6 5]
[ 4 3 2 1 0]]]
Reference:
- https://github.com/tensorflow/models/blob/v2.13.1/official/vision/ops/augment.py#L804-L809
"""

threshold: float
max_val: float = 1.0

@ft.partial(validate_spatial_ndim, attribute_name="spatial_ndim")
def __call__(self, x: jax.Array, **k) -> jax.Array:
threshold, max_val = jax.lax.stop_gradient((self.threshold, self.max_val))
return solarize(x, threshold, max_val)

@property
def spatial_ndim(self) -> int:
return 2


@tree_eval.def_eval(RandomContrast2D)
@tree_eval.def_eval(RandomRotate2D)
@tree_eval.def_eval(RandomHorizontalShear2D)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_image_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy.testing as npt
import pytest

from serket.nn import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D
from serket.nn import AvgBlur2D, FFTFilter2D, Filter2D, GaussianBlur2D, Solarize2D


def test_AvgBlur2D():
Expand Down Expand Up @@ -99,3 +99,22 @@ def test_filter2d():
layer2 = FFTFilter2D(in_features=1, kernel=jnp.ones([3, 3]) / 9.0)

npt.assert_allclose(layer(x), layer2(x), atol=1e-4)


def test_solarize2d():
x = jnp.arange(1, 26).reshape(1, 5, 5)
layer = Solarize2D(threshold=10, max_val=25)
npt.assert_allclose(
layer(x),
jnp.array(
[
[
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 15],
[14, 13, 12, 11, 10],
[9, 8, 7, 6, 5],
[4, 3, 2, 1, 0],
]
]
),
)

0 comments on commit 4b9188f

Please sign in to comment.