Skip to content

Commit

Permalink
expose color functional api
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Dec 27, 2023
1 parent 23f84f1 commit 15c6ab0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 12 deletions.
8 changes: 6 additions & 2 deletions docs/API/color.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Color API

.. autoclass:: RGBToGrayscale2D
.. autoclass:: GrayscaleToRGB2D

.. autoclass:: RGBToHSV2D
.. autoclass:: HSVToRGB2D
.. autoclass:: HSVToRGB2D

.. autofunction:: rgb_to_grayscale
.. autofunction:: grayscale_to_rgb
.. autofunction:: rgb_to_hsv
.. autofunction:: hsv_to_rgb
10 changes: 5 additions & 5 deletions serket/_src/image/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import serket as sk
from serket._src.custom_transform import tree_eval
from serket._src.image.color import hsv_to_rgb_3d, rgb_to_hsv_3d
from serket._src.image.color import hsv_to_rgb, rgb_to_hsv
from serket._src.nn.linear import Identity
from serket._src.utils import (
CHWArray,
Expand Down Expand Up @@ -188,18 +188,18 @@ def adjust_log_2d(image: HWArray, gain: float = 1, inv: bool = False) -> HWArray


def adjust_hue_3d(image: CHWArray, factor: float) -> CHWArray:
h, s, v = rgb_to_hsv_3d(image)
h, s, v = rgb_to_hsv(image)
divisor = 2 * jnp.pi
h = jnp.fmod(h + factor, divisor)
out = jnp.stack([h, s, v], axis=0)
return hsv_to_rgb_3d(out)
return hsv_to_rgb(out)


def adust_saturation_3d(image: CHWArray, factor: float) -> CHWArray:
h, s, v = rgb_to_hsv_3d(image)
h, s, v = rgb_to_hsv(image)
s = jnp.clip(s * factor, 0.0, 1.0)
out = jnp.stack([h, s, v], axis=0)
return hsv_to_rgb_3d(out)
return hsv_to_rgb(out)


def random_hue_3d(
Expand Down
17 changes: 12 additions & 5 deletions serket/_src/image/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,15 @@ def __call__(self, image: CHWArray) -> CHWArray:
spatial_ndim: int = 2


def rgb_to_hsv_3d(image: CHWArray) -> CHWArray:
"""Convert an RGB image to HSV."""
def rgb_to_hsv(image: CHWArray) -> CHWArray:
"""Convert an RGB image to HSV.
Args:
image: RGB image in channel-first format with range [0, 1].
Returns:
HSV image in channel-first format with range [0, 2pi] for hue, [0, 1] for saturation and value.
"""
# https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html#rgb_to_hsv

eps = jnp.finfo(image.dtype).eps
Expand All @@ -106,7 +113,7 @@ def rgb_to_hsv_3d(image: CHWArray) -> CHWArray:
return jnp.concatenate((h, s, v), axis=0)


def hsv_to_rgb_3d(image: CHWArray) -> CHWArray:
def hsv_to_rgb(image: CHWArray) -> CHWArray:
"""Convert an image from HSV to RGB."""
# https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html#rgb_to_hsv
c, _, _ = image.shape
Expand Down Expand Up @@ -168,7 +175,7 @@ class RGBToHSV2D(sk.TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return rgb_to_hsv_3d(image)
return rgb_to_hsv(image)

spatial_ndim: int = 2

Expand All @@ -191,6 +198,6 @@ class HSVToRGB2D(sk.TreeClass):

@ft.partial(validate_spatial_ndim, argnum=0)
def __call__(self, image: CHWArray) -> CHWArray:
return hsv_to_rgb_3d(image)
return hsv_to_rgb(image)

spatial_ndim: int = 2
8 changes: 8 additions & 0 deletions serket/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
HSVToRGB2D,
RGBToGrayscale2D,
RGBToHSV2D,
grayscale_to_rgb,
hsv_to_rgb,
rgb_to_grayscale,
rgb_to_hsv,
)
from serket._src.image.filter import (
AvgBlur2D,
Expand Down Expand Up @@ -189,4 +193,8 @@
"HSVToRGB2D",
"RGBToGrayscale2D",
"RGBToHSV2D",
"grayscale_to_rgb",
"hsv_to_rgb",
"rgb_to_grayscale",
"rgb_to_hsv",
]

0 comments on commit 15c6ab0

Please sign in to comment.