diff --git a/docs/API/color.rst b/docs/API/color.rst index 17c848f..8a1f48c 100644 --- a/docs/API/color.rst +++ b/docs/API/color.rst @@ -4,6 +4,10 @@ Color API .. autoclass:: RGBToGrayscale2D .. autoclass:: GrayscaleToRGB2D - .. autoclass:: RGBToHSV2D -.. autoclass:: HSVToRGB2D \ No newline at end of file +.. autoclass:: HSVToRGB2D + +.. autofunction:: rgb_to_grayscale +.. autofunction:: grayscale_to_rgb +.. autofunction:: rgb_to_hsv +.. autofunction:: hsv_to_rgb \ No newline at end of file diff --git a/serket/_src/image/augment.py b/serket/_src/image/augment.py index a2a6b67..62c833f 100644 --- a/serket/_src/image/augment.py +++ b/serket/_src/image/augment.py @@ -23,7 +23,7 @@ import serket as sk from serket._src.custom_transform import tree_eval -from serket._src.image.color import hsv_to_rgb_3d, rgb_to_hsv_3d +from serket._src.image.color import hsv_to_rgb, rgb_to_hsv from serket._src.nn.linear import Identity from serket._src.utils import ( CHWArray, @@ -188,18 +188,18 @@ def adjust_log_2d(image: HWArray, gain: float = 1, inv: bool = False) -> HWArray def adjust_hue_3d(image: CHWArray, factor: float) -> CHWArray: - h, s, v = rgb_to_hsv_3d(image) + h, s, v = rgb_to_hsv(image) divisor = 2 * jnp.pi h = jnp.fmod(h + factor, divisor) out = jnp.stack([h, s, v], axis=0) - return hsv_to_rgb_3d(out) + return hsv_to_rgb(out) def adust_saturation_3d(image: CHWArray, factor: float) -> CHWArray: - h, s, v = rgb_to_hsv_3d(image) + h, s, v = rgb_to_hsv(image) s = jnp.clip(s * factor, 0.0, 1.0) out = jnp.stack([h, s, v], axis=0) - return hsv_to_rgb_3d(out) + return hsv_to_rgb(out) def random_hue_3d( diff --git a/serket/_src/image/color.py b/serket/_src/image/color.py index 7cec4f5..9241682 100644 --- a/serket/_src/image/color.py +++ b/serket/_src/image/color.py @@ -78,8 +78,15 @@ def __call__(self, image: CHWArray) -> CHWArray: spatial_ndim: int = 2 -def rgb_to_hsv_3d(image: CHWArray) -> CHWArray: - """Convert an RGB image to HSV.""" +def rgb_to_hsv(image: CHWArray) -> CHWArray: + """Convert an RGB image to HSV. + + Args: + image: RGB image in channel-first format with range [0, 1]. + + Returns: + HSV image in channel-first format with range [0, 2pi] for hue, [0, 1] for saturation and value. + """ # https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html#rgb_to_hsv eps = jnp.finfo(image.dtype).eps @@ -106,7 +113,7 @@ def rgb_to_hsv_3d(image: CHWArray) -> CHWArray: return jnp.concatenate((h, s, v), axis=0) -def hsv_to_rgb_3d(image: CHWArray) -> CHWArray: +def hsv_to_rgb(image: CHWArray) -> CHWArray: """Convert an image from HSV to RGB.""" # https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html#rgb_to_hsv c, _, _ = image.shape @@ -168,7 +175,7 @@ class RGBToHSV2D(sk.TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: - return rgb_to_hsv_3d(image) + return rgb_to_hsv(image) spatial_ndim: int = 2 @@ -191,6 +198,6 @@ class HSVToRGB2D(sk.TreeClass): @ft.partial(validate_spatial_ndim, argnum=0) def __call__(self, image: CHWArray) -> CHWArray: - return hsv_to_rgb_3d(image) + return hsv_to_rgb(image) spatial_ndim: int = 2 diff --git a/serket/image/__init__.py b/serket/image/__init__.py index 7b3a5a0..9ece8df 100644 --- a/serket/image/__init__.py +++ b/serket/image/__init__.py @@ -35,6 +35,10 @@ HSVToRGB2D, RGBToGrayscale2D, RGBToHSV2D, + grayscale_to_rgb, + hsv_to_rgb, + rgb_to_grayscale, + rgb_to_hsv, ) from serket._src.image.filter import ( AvgBlur2D, @@ -189,4 +193,8 @@ "HSVToRGB2D", "RGBToGrayscale2D", "RGBToHSV2D", + "grayscale_to_rgb", + "hsv_to_rgb", + "rgb_to_grayscale", + "rgb_to_hsv", ]