Skip to content

Commit

Permalink
Fix to gray (#2050)
Browse files: browse the repository at this point in the history
* Empty-Commit

* Cleanup

* Cleanup
  • Loading branch information
ternaus authored Nov 2, 2024
1 parent 0dede8c commit 63da7b7
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 18 deletions.
4 changes: 2 additions & 2 deletions albumentations/augmentations/domain_adaptation/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: flo
src_img = img.astype(np.float32)
trg_img = target_img.astype(np.float32)

if len(src_img.shape) == MONO_CHANNEL_DIMENSIONS:
if src_img.ndim == MONO_CHANNEL_DIMENSIONS:
src_img = np.expand_dims(src_img, axis=-1)
if len(trg_img.shape) == MONO_CHANNEL_DIMENSIONS:
if trg_img.ndim == MONO_CHANNEL_DIMENSIONS:
trg_img = np.expand_dims(trg_img, axis=-1)

num_channels = src_img.shape[-1]
Expand Down
4 changes: 2 additions & 2 deletions albumentations/augmentations/domain_adaptation/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,8 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A
template = fgeometric.resize(template, img.shape[:2], interpolation=cv2.INTER_AREA)

if get_num_channels(template) == 1 and get_num_channels(img) > 1:
template = np.stack((template,) * get_num_channels(img), axis=-1)

# Replicate single channel template across all channels to match input image
template = cv2.merge([template] * get_num_channels(img))
# in order to support grayscale image with dummy dim
template = template.reshape(img.shape)

Expand Down
18 changes: 10 additions & 8 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def solarize(img: np.ndarray, threshold: int) -> np.ndarray:
prev_shape = img.shape
img = sz_lut(img, np.array(lut, dtype=dtype), inplace=False)

return np.expand_dims(img, -1) if len(prev_shape) != len(img.shape) else img
return np.expand_dims(img, -1) if len(prev_shape) != img.ndim else img

cond = img >= threshold
img[cond] = max_val - img[cond]
Expand Down Expand Up @@ -1300,14 +1300,16 @@ def grayscale_to_multichannel(grayscale_image: np.ndarray, num_output_channels:
num_output_channels (int, optional): Number of channels in the output image. Defaults to 3.
Returns:
np.ndarray: Multi-channel image with shape (height, width, num_channels).
Note:
If the input is already a multi-channel image with the desired number of channels,
it will be returned unchanged.
np.ndarray: Multi-channel image with shape (height, width, num_channels)
"""
grayscale_image = grayscale_image.copy().squeeze()
return np.stack([grayscale_image] * num_output_channels, axis=-1)
# If output should be single channel, just squeeze and return
if num_output_channels == 1:
return grayscale_image

# For multi-channel output, squeeze and stack
squeezed = np.squeeze(grayscale_image)

return cv2.merge([squeezed] * num_output_channels)


@preserve_channel_dim
Expand Down
7 changes: 4 additions & 3 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,9 +1814,10 @@ def bbox_distort_image(
bboxes = bboxes.copy()
masks = masks_from_bboxes(bboxes, image_shape)

transformed_masks = np.stack(
[distort_image(mask, generated_mesh, cv2.INTER_NEAREST) for mask in masks],
)
transformed_masks = cv2.merge([distort_image(mask, generated_mesh, cv2.INTER_NEAREST) for mask in masks])

if transformed_masks.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
transformed_masks = transformed_masks.transpose(2, 0, 1)

# Normalize the returned bboxes
bboxes[:, :4] = bboxes_from_masks(transformed_masks)
Expand Down
6 changes: 3 additions & 3 deletions albumentations/augmentations/text/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def convert_image_to_pil(image: np.ndarray) -> Image:
except ImportError:
raise ImportError("Pillow is not installed") from ImportError

if len(image.shape) == MONO_CHANNEL_DIMENSIONS: # (height, width)
if image.ndim == MONO_CHANNEL_DIMENSIONS: # (height, width)
return Image.fromarray(image)
if len(image.shape) == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == 1: # (height, width, 1)
if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == 1: # (height, width, 1)
return Image.fromarray(image[:, :, 0], mode="L")
if len(image.shape) == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == NUM_RGB_CHANNELS: # (height, width, 3)
if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and image.shape[2] == NUM_RGB_CHANNELS: # (height, width, 3)
return Image.fromarray(image)

raise TypeError(f"Unsupported image shape: {image.shape}")
Expand Down

0 comments on commit 63da7b7

Please sign in to comment.