From 63da7b725d89c2c74adba6a3ad8713fbd0fff2c1 Mon Sep 17 00:00:00 2001
From: Vladimir Iglovikov
Date: Fri, 1 Nov 2024 19:22:17 -0700
Subject: [PATCH] Fix to gray (#2050)

* Empty-Commit

* Cleanup

* Cleanup
---
 .../domain_adaptation/functional.py        |  4 ++--
 .../domain_adaptation/transforms.py        |  4 ++--
 albumentations/augmentations/functional.py | 18 ++++++++++--------
 .../augmentations/geometric/functional.py  |  7 ++++---
 .../augmentations/text/functional.py       |  6 +++---
 5 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/albumentations/augmentations/domain_adaptation/functional.py b/albumentations/augmentations/domain_adaptation/functional.py
index 24f741cca..a5c2e4a11 100644
--- a/albumentations/augmentations/domain_adaptation/functional.py
+++ b/albumentations/augmentations/domain_adaptation/functional.py
@@ -277,9 +277,9 @@ def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: flo
     src_img = img.astype(np.float32)
     trg_img = target_img.astype(np.float32)
 
-    if len(src_img.shape) == MONO_CHANNEL_DIMENSIONS:
+    if src_img.ndim == MONO_CHANNEL_DIMENSIONS:
         src_img = np.expand_dims(src_img, axis=-1)
-    if len(trg_img.shape) == MONO_CHANNEL_DIMENSIONS:
+    if trg_img.ndim == MONO_CHANNEL_DIMENSIONS:
         trg_img = np.expand_dims(trg_img, axis=-1)
 
     num_channels = src_img.shape[-1]
diff --git a/albumentations/augmentations/domain_adaptation/transforms.py b/albumentations/augmentations/domain_adaptation/transforms.py
index a40ff880b..0de301cdc 100644
--- a/albumentations/augmentations/domain_adaptation/transforms.py
+++ b/albumentations/augmentations/domain_adaptation/transforms.py
@@ -511,8 +511,8 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A
             template = fgeometric.resize(template, img.shape[:2], interpolation=cv2.INTER_AREA)
 
         if get_num_channels(template) == 1 and get_num_channels(img) > 1:
-            template = np.stack((template,) * get_num_channels(img), axis=-1)
-
+            # Replicate single channel template across all channels to match input image
+            template = cv2.merge([template] * get_num_channels(img))
         # in order to support grayscale image with dummy dim
         template = template.reshape(img.shape)
 
diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py
index 464478a71..ee2ebf15f 100644
--- a/albumentations/augmentations/functional.py
+++ b/albumentations/augmentations/functional.py
@@ -148,7 +148,7 @@ def solarize(img: np.ndarray, threshold: int) -> np.ndarray:
         prev_shape = img.shape
         img = sz_lut(img, np.array(lut, dtype=dtype), inplace=False)
 
-        return np.expand_dims(img, -1) if len(prev_shape) != len(img.shape) else img
+        return np.expand_dims(img, -1) if len(prev_shape) != img.ndim else img
 
     cond = img >= threshold
     img[cond] = max_val - img[cond]
@@ -1300,14 +1300,16 @@ def grayscale_to_multichannel(grayscale_image: np.ndarray, num_output_channels:
         num_output_channels (int, optional): Number of channels in the output image. Defaults to 3.
 
     Returns:
-        np.ndarray: Multi-channel image with shape (height, width, num_channels).
-
-    Note:
-        If the input is already a multi-channel image with the desired number of channels,
-        it will be returned unchanged.
+        np.ndarray: Multi-channel image with shape (height, width, num_channels)
     """
-    grayscale_image = grayscale_image.copy().squeeze()
-    return np.stack([grayscale_image] * num_output_channels, axis=-1)
+    # If output should be single channel, just squeeze and return
+    if num_output_channels == 1:
+        return grayscale_image
+
+    # For multi-channel output, squeeze and stack
+    squeezed = np.squeeze(grayscale_image)
+
+    return cv2.merge([squeezed] * num_output_channels)
 
 
 @preserve_channel_dim
diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py
index 704c89a86..7ddcd2e87 100644
--- a/albumentations/augmentations/geometric/functional.py
+++ b/albumentations/augmentations/geometric/functional.py
@@ -1814,9 +1814,10 @@ def bbox_distort_image(
     bboxes = bboxes.copy()
     masks = masks_from_bboxes(bboxes, image_shape)
 
-    transformed_masks = np.stack(
-        [distort_image(mask, generated_mesh, cv2.INTER_NEAREST) for mask in masks],
-    )
+    transformed_masks = cv2.merge([distort_image(mask, generated_mesh, cv2.INTER_NEAREST) for mask in masks])
+
+    if transformed_masks.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
+        transformed_masks = transformed_masks.transpose(2, 0, 1)
 
     # Normalize the returned bboxes
     bboxes[:, :4] = bboxes_from_masks(transformed_masks)
diff --git a/albumentations/augmentations/text/functional.py b/albumentations/augmentations/text/functional.py
index 47717cb8b..9d4d8c865 100644
--- a/albumentations/augmentations/text/functional.py
+++ b/albumentations/augmentations/text/functional.py
@@ -64,11 +64,11 @@ def convert_image_to_pil(image: np.ndarray) -> Image:
     except ImportError:
         raise ImportError("Pillow is not installed") from ImportError
 
-    if len(image.shape) == MONO_CHANNEL_DIMENSIONS:  # (height, width)
+    if image.ndim == MONO_CHANNEL_DIMENSIONS:  # (height, width)
         return Image.fromarray(image)
-    if len(image.shape) == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == 1:  # (height, width, 1)
+    if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == 1:  # (height, width, 1)
         return Image.fromarray(image[:, :, 0], mode="L")
-    if len(image.shape) == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == NUM_RGB_CHANNELS:  # (height, width, 3)
+    if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == NUM_RGB_CHANNELS:  # (height, width, 3)
         return Image.fromarray(image)
 
     raise TypeError(f"Unsupported image shape: {image.shape}")
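Not part of the patch: a minimal sketch of the channel-replication pattern the diff switches to (cv2.merge in place of np.stack). The array name gray and its shape are assumptions made for illustration only.

import cv2
import numpy as np

# Hypothetical 2D uint8 grayscale input (height=4, width=5)
gray = np.random.randint(0, 256, (4, 5), dtype=np.uint8)

stacked = np.stack([gray] * 3, axis=-1)  # previous pattern -> shape (4, 5, 3)
merged = cv2.merge([gray] * 3)           # pattern adopted by the patch -> shape (4, 5, 3)

assert merged.shape == (4, 5, 3)
assert np.array_equal(merged, stacked)   # identical result for a 2D input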