Skip to content

Commit

Permalink
Cleaned docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Oct 20, 2024
1 parent 7e129aa commit ae60dca
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 40 deletions.
4 changes: 2 additions & 2 deletions albumentations/augmentations/blur/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def __init__(
always_apply: bool | None = None,
p: float = 0.5,
):
super().__init__(p, always_apply)
super().__init__(p=p, always_apply=always_apply)
self.radius = cast(tuple[int, int], radius)
self.alias_blur = cast(tuple[float, float], alias_blur)

Expand Down Expand Up @@ -803,7 +803,7 @@ def __init__(
always_apply: bool | None = None,
p: float = 0.5,
):
super().__init__(p, always_apply)
super().__init__(p=p, always_apply=always_apply)
self.max_factor = cast(tuple[float, float], max_factor)
self.step_factor = cast(tuple[float, float], step_factor)

Expand Down
14 changes: 14 additions & 0 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ class RandomSizedCrop(_BaseRandomSizedCrop):
interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 1.0
Targets:
Expand Down Expand Up @@ -649,6 +652,9 @@ class RandomResizedCrop(_BaseRandomSizedCrop):
interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST
p (float): Probability of applying the transform. Default: 1.0
Targets:
Expand Down Expand Up @@ -1020,6 +1026,9 @@ class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop):
interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_LINEAR.
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 1.0.
Targets:
Expand Down Expand Up @@ -1175,6 +1184,11 @@ class CropAndPad(DualTransform):
OpenCV interpolation flag used for resizing if keep_size is True.
Default: cv2.INTER_LINEAR.
mask_interpolation (int):
OpenCV interpolation flag used for resizing if keep_size is True.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float):
Probability of applying the transform. Default: 1.0.
Expand Down
35 changes: 28 additions & 7 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class BaseDistortion(DualTransform):
cv2.BORDER_CONSTANT. Default: None
mask_value (ColorType | None): Padding value for mask if
border_mode is cv2.BORDER_CONSTANT. Default: None
mask_interpolation (int): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 0.5
Targets:
Expand Down Expand Up @@ -182,6 +185,9 @@ class ElasticTransform(BaseDistortion):
less accurate for large sigma values. Default: False
same_dxdy (bool): Whether to use the same random displacement field for both x and y
directions. Can speed up the transform at the cost of less diverse distortions. Default: False
mask_interpolation (int): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 0.5
Targets:
Expand Down Expand Up @@ -227,8 +233,8 @@ def __init__(
mask_value: ScalarType | list[ScalarType] | None = None,
always_apply: bool | None = None,
approximate: bool = False,
mask_interpolation: int = cv2.INTER_NEAREST,
same_dxdy: bool = False,
mask_interpolation: int = cv2.INTER_NEAREST,
p: float = 0.5,
):
super().__init__(
Expand Down Expand Up @@ -291,6 +297,11 @@ class Perspective(DualTransform):
to True. If False, parts of the transformed image may be outside of the image plane.
This setting should not be set to True when using large scale values as it could lead to very large images.
Default: False.
interpolation (int): Interpolation method to be used for image transformation. Should be one
of the OpenCV interpolation types. Default: cv2.INTER_LINEAR
mask_interpolation (int): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 0.5.
Targets:
Expand Down Expand Up @@ -850,6 +861,9 @@ class ShiftScaleRotate(Affine):
in the range [-, 1]. Default: None.
rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
Default: "largest_box"
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -873,6 +887,7 @@ class InitSchema(BaseTransformInitSchema):
shift_limit_x: ScaleFloatType | None = Field(default=None)
shift_limit_y: ScaleFloatType | None = Field(default=None)
rotate_method: Literal["largest_box", "ellipse"] = "largest_box"
mask_interpolation: InterpolationType

@model_validator(mode="after")
def check_shift_limit(self) -> Self:
Expand Down Expand Up @@ -903,6 +918,7 @@ def __init__(
shift_limit_x: ScaleFloatType | None = None,
shift_limit_y: ScaleFloatType | None = None,
rotate_method: Literal["largest_box", "ellipse"] = "largest_box",
mask_interpolation: InterpolationType = cv2.INTER_NEAREST,
always_apply: bool | None = None,
p: float = 0.5,
):
Expand All @@ -912,7 +928,7 @@ def __init__(
rotate=rotate_limit,
shear=(0, 0),
interpolation=interpolation,
mask_interpolation=cv2.INTER_NEAREST,
mask_interpolation=mask_interpolation,
cval=value,
cval_mask=mask_value,
mode=border_mode,
Expand Down Expand Up @@ -946,6 +962,7 @@ def get_transform_init_args(self) -> dict[str, Any]:
"value": self.value,
"mask_value": self.mask_value,
"rotate_method": self.rotate_method,
"mask_interpolation": self.mask_interpolation,
}


Expand Down Expand Up @@ -978,8 +995,9 @@ class PiecewiseAffine(DualTransform):
- 3: Bi-cubic
- 4: Bi-quartic
- 5: Bi-quintic
mask_interpolation (int): The order of interpolation for masks. Similar to 'interpolation' but for masks.
Default: 0 (Nearest-neighbor).
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
cval (number): The constant value to use when filling in newly created pixels.
Default: 0.
cval_mask (number): The constant value to use when filling in newly created pixels in masks.
Expand Down Expand Up @@ -1673,8 +1691,9 @@ class OpticalDistortion(BaseDistortion):
is cv2.BORDER_CONSTANT. Default: None.
mask_value (int, float, list of int, list of float): Padding value for mask
if border_mode is cv2.BORDER_CONSTANT. Default: None.
always_apply (bool): If True, the transform will be always applied.
Default: None.
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 0.5.
Targets:
Expand Down Expand Up @@ -1779,7 +1798,9 @@ class GridDistortion(BaseDistortion):
normalized (bool): If True, ensures that the distortion does not move pixels
outside the image boundaries. This can result in less extreme distortions
but guarantees that no information is lost. Default: True.
always_apply (bool): If True, the transform will be always applied. Default: None.
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Default: cv2.INTER_NEAREST.
p (float): Probability of applying the transform. Default: 0.5.
Targets:
Expand Down
109 changes: 78 additions & 31 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class BaseCompose(Serializable):
check_each_transform: tuple[DataProcessor, ...] | None = None
main_compose: bool = True

def __init__(self, transforms: TransformsSeqType, p: float):
def __init__(self, transforms: TransformsSeqType, p: float, mask_interpolation: int | None = None):
if isinstance(transforms, (BaseCompose, BasicTransform)):
warnings.warn(
"transforms is single transform, but a sequence is expected! Transform will be wrapped into list.",
Expand All @@ -118,6 +118,19 @@ def __init__(self, transforms: TransformsSeqType, p: float):
self._available_keys: set[str] = set()
self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}
self._set_keys()
self.set_mask_interpolation(mask_interpolation)

def set_mask_interpolation(self, mask_interpolation: int | None) -> None:
self.mask_interpolation = mask_interpolation
self._set_mask_interpolation_recursive(self.transforms)

def _set_mask_interpolation_recursive(self, transforms: TransformsSeqType) -> None:
for transform in transforms:
if isinstance(transform, BasicTransform):
if hasattr(transform, "mask_interpolation") and self.mask_interpolation is not None:
transform.mask_interpolation = self.mask_interpolation
elif isinstance(transform, BaseCompose):
transform.set_mask_interpolation(self.mask_interpolation)

def __iter__(self) -> Iterator[TransformType]:
return iter(self.transforms)
Expand Down Expand Up @@ -228,20 +241,43 @@ def check_data_post_transform(self, data: Any) -> dict[str, Any]:


class Compose(BaseCompose, HubMixin):
"""Compose transforms and handle all transformations regarding bounding boxes
"""Compose multiple transforms together and apply them sequentially to input data.
This class allows you to chain multiple image augmentation transforms and apply them
in a specified order. It also handles bounding box and keypoint transformations if
the appropriate parameters are provided.
Args:
transforms (list): list of transformations to compose.
bbox_params (BboxParams): Parameters for bounding boxes transforms
keypoint_params (KeypointParams): Parameters for keypoints transforms
additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
p (float): probability of applying all list of transforms. Default: 1.0.
is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you
would like to disable this check - pass False (do it only if you are sure in your data consistency).
strict (bool): If True, unknown keys will raise an error. If False, unknown keys will be ignored. Default: True.
return_params (bool): if True returns params of each applied transform
save_key (str): key to save applied params, default is 'applied_params'
transforms (List[Union[BasicTransform, BaseCompose]]): A list of transforms to apply.
bbox_params (Union[dict, BboxParams, None]): Parameters for bounding box transforms.
Can be a dict of params or a BboxParams object. Default is None.
keypoint_params (Union[dict, KeypointParams, None]): Parameters for keypoint transforms.
Can be a dict of params or a KeypointParams object. Default is None.
additional_targets (Dict[str, str], optional): A dictionary mapping additional target names
to their types. For example, {'image2': 'image'}. Default is None.
p (float): Probability of applying all transforms. Should be in range [0, 1]. Default is 1.0.
is_check_shapes (bool): If True, checks consistency of shapes for image/mask/masks on each call.
Disable only if you are sure about your data consistency. Default is True.
strict (bool): If True, raises an error on unknown input keys. If False, ignores them. Default is True.
return_params (bool): If True, returns parameters of applied transforms. Default is False.
save_key (str): Key to save applied params if return_params is True. Default is 'applied_params'.
mask_interpolation (int, optional): Interpolation method for mask transforms. When defined,
it overrides the interpolation method specified in individual transforms. Default is None.
Example:
>>> import albumentations as A
>>> transform = A.Compose([
... A.RandomCrop(width=256, height=256),
... A.HorizontalFlip(p=0.5),
... A.RandomBrightnessContrast(p=0.2),
... ])
>>> transformed = transform(image=image)
Note:
- The class checks the validity of input data and shapes if is_check_args and is_check_shapes are True.
- When bbox_params or keypoint_params are provided, it sets up the corresponding processors.
- The transform can handle additional targets specified in the additional_targets dictionary.
- If return_params is True, it will return the parameters of applied transforms in the output.
"""

def __init__(
Expand All @@ -257,7 +293,7 @@ def __init__(
save_key: str = "applied_params",
mask_interpolation: int | None = None, # Add this parameter
):
super().__init__(transforms=transforms, p=p)
super().__init__(transforms=transforms, p=p, mask_interpolation=mask_interpolation)

if bbox_params:
if isinstance(bbox_params, dict):
Expand Down Expand Up @@ -305,17 +341,6 @@ def __init__(

self._set_processors_for_transforms(self.transforms)

self.mask_interpolation = mask_interpolation
self._set_mask_interpolation(self.transforms)

def _set_mask_interpolation(self, transforms: TransformsSeqType) -> None:
for transform in transforms:
if isinstance(transform, BasicTransform):
if hasattr(transform, "mask_interpolation") and self.mask_interpolation is not None:
transform.mask_interpolation = self.mask_interpolation
elif isinstance(transform, BaseCompose):
self._set_mask_interpolation(transform.transforms)

def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
for transform in transforms:
if isinstance(transform, BasicTransform):
Expand Down Expand Up @@ -518,18 +543,40 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[s


class SomeOf(BaseCompose):
"""Select N transforms to apply. Selected transforms will be called with `force_apply=True`.
Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.
"""Apply a random subset of transforms from the given list.
This class selects a specified number of transforms from the provided list
and applies them to the input data. The selection can be done with or without
replacement, allowing for the same transform to be potentially applied multiple times.
Args:
transforms (list): list of transformations to compose.
n (int): number of transforms to apply.
replace (bool): Whether the sampled transforms are with or without replacement. Default: True.
p (float): probability of applying selected transform. Default: 1.
transforms (List[Union[BasicTransform, BaseCompose]]): A list of transforms to choose from.
n (int): The number of transforms to apply. If greater than the number of
transforms and replace=False, it will be set to the number of transforms.
replace (bool): Whether to sample transforms with replacement. Default is True.
p (float): Probability of applying the selected transforms. Should be in the range [0, 1].
Default is 1.0.
mask_interpolation (int, optional): Interpolation method for mask transforms.
When defined, it overrides the interpolation method
specified in individual transforms. Default is None.
Note:
- If `n` is greater than the number of transforms and `replace` is False,
`n` will be set to the number of transforms with a warning.
- The probabilities of individual transforms are used as weights for sampling.
- When `replace` is True, the same transform can be selected multiple times.
Example:
>>> import albumentations as A
>>> transform = A.SomeOf([
... A.HorizontalFlip(p=1),
... A.VerticalFlip(p=1),
... A.RandomBrightnessContrast(p=1),
... ], n=2, replace=False, p=0.5)
>>> # This will apply 2 out of the 3 transforms with 50% probability
"""

def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1):
def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
super().__init__(transforms, p)
self.n = n
if not replace and n > len(self.transforms):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,3 +1411,21 @@ def test_mask_interpolation(augmentation_cls, params, interpolation):
assert transformed["mask"].flags["C_CONTIGUOUS"]

np.testing.assert_array_equal(transformed["mask"], transformed["image"])



@pytest.mark.parametrize("interpolation", [cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_CUBIC,
cv2.INTER_AREA
])
@pytest.mark.parametrize("compose", [A.Compose, A.OneOf, A.Sequential, A.SomeOf])
def test_mask_interpolation_someof(interpolation, compose):
transform = A.Compose([compose([A.Affine(p=1), A.RandomSizedCrop(min_max_height=(4, 8), size= (113, 103), p=1)], p=1)], mask_interpolation=interpolation)

image = SQUARE_UINT8_IMAGE
mask = image.copy()

transformed = transform(image=image, mask=mask)

assert transformed["mask"].flags["C_CONTIGUOUS"]

0 comments on commit ae60dca

Please sign in to comment.