Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rotate and project #2159

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [RandomVerticalFlip](https://explore.albumentations.ai/transform/RandomVerticalFlip) | ✓ | ✓ | ✓ | ✓ |
| [Resize](https://explore.albumentations.ai/transform/Resize) | ✓ | ✓ | ✓ | ✓ |
| [Rotate](https://explore.albumentations.ai/transform/Rotate) | ✓ | ✓ | ✓ | ✓ |
| [RotateAndProject](https://explore.albumentations.ai/transform/RotateAndProject) | ✓ | ✓ | ✓ | ✓ |
| [SafeRotate](https://explore.albumentations.ai/transform/SafeRotate) | ✓ | ✓ | ✓ | ✓ |
| [ShiftScaleRotate](https://explore.albumentations.ai/transform/ShiftScaleRotate) | ✓ | ✓ | ✓ | ✓ |
| [SmallestMaxSize](https://explore.albumentations.ai/transform/SmallestMaxSize) | ✓ | ✓ | ✓ | ✓ |
Expand Down
76 changes: 74 additions & 2 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,7 +2265,7 @@ def center(image_shape: tuple[int, int]) -> tuple[float, float]:
image_shape (tuple[int, int]): The shape of the image.

Returns:
tuple[float, float]: The center coordinates.
tuple[float, float]: center_x, center_y
"""
height, width = image_shape[:2]
return width / 2 - 0.5, height / 2 - 0.5
Expand All @@ -2278,7 +2278,7 @@ def center_bbox(image_shape: tuple[int, int]) -> tuple[float, float]:
image_shape (tuple[int, int]): The shape of the image.

Returns:
tuple[float, float]: The center coordinates.
tuple[float, float]: center_x, center_y
"""
height, width = image_shape[:2]
return width / 2, height / 2
Expand Down Expand Up @@ -3128,3 +3128,75 @@ def get_fisheye_distortion_maps(
map_y = cy + r_dist * np.sin(theta)

return map_x, map_y


def get_projection_matrix(
    image_shape: tuple[int, int],
    x_angle: float,
    y_angle: float,
    z_angle: float,
    focal_length: float,
    center_xy: tuple[float, float],
) -> np.ndarray:
    """Get the inverse projection matrix for a 3D rotation about a center point.

    The forward transform is composed as
    ``from_origin @ rotation @ focal @ to_origin``: homogeneous 2D points are
    shifted so that ``center_xy`` becomes the origin, scaled by
    ``focal_length``, multiplied by the 3D rotation matrix, and shifted back.
    The inverse of that product is returned.

    Args:
        image_shape: Height and width of the image. It does not take part in
            the computation; it is kept so existing callers keep working.
        x_angle: Rotation angle around X axis in radians
        y_angle: Rotation angle around Y axis in radians
        z_angle: Rotation angle around Z axis in radians
        focal_length: Focal length of the virtual camera; used as the x/y scale
            factor of the ``focal`` matrix
        center_xy: Center point (x, y) of the transform

    Returns:
        3x3 ``float32`` matrix, the inverse of the composed forward transform

    Raises:
        numpy.linalg.LinAlgError: If the composed matrix is singular, for
            example when ``focal_length`` is 0.
    """
    center_x, center_y = center_xy

    # Translation that moves the chosen center to the origin, and its inverse.
    to_origin = np.array(
        [[1.0, 0.0, -center_x], [0.0, 1.0, -center_y], [0.0, 0.0, 1.0]],
        dtype=np.float64,
    )
    from_origin = np.array(
        [[1.0, 0.0, center_x], [0.0, 1.0, center_y], [0.0, 0.0, 1.0]],
        dtype=np.float64,
    )

    # Focal length matrix: scales x and y, leaves the homogeneous coordinate alone.
    focal = np.array(
        [[focal_length, 0.0, 0.0], [0.0, focal_length, 0.0], [0.0, 0.0, 1.0]],
        dtype=np.float64,
    )

    rotation = get_rotation_matrix_3d(x_angle, y_angle, z_angle)

    # Compose final matrix: from_origin @ rotation @ focal @ to_origin
    matrix = from_origin @ rotation @ focal @ to_origin

    # Return inverse matrix for warpPerspective; computed in float64, stored as float32.
    return np.linalg.inv(matrix).astype(np.float32)
ternaus marked this conversation as resolved.
Show resolved Hide resolved


def get_rotation_matrix_3d(x_angle: float, y_angle: float, z_angle: float) -> np.ndarray:
    """Get 3D rotation matrix.

    The result is the product ``rx @ ry @ rz`` of the three elementary
    rotation matrices. Applied to a column vector ``v`` as ``R @ v``, this
    performs the Z rotation first, then Y, then X.

    Args:
        x_angle: Rotation angle around X axis in radians
        y_angle: Rotation angle around Y axis in radians
        z_angle: Rotation angle around Z axis in radians

    Returns:
        3x3 ``float64`` rotation matrix
    """
    cos_x, sin_x = np.cos(x_angle), np.sin(x_angle)
    cos_y, sin_y = np.cos(y_angle), np.sin(y_angle)
    cos_z, sin_z = np.cos(z_angle), np.sin(z_angle)

    # X rotation
    rx = np.array(
        [[1.0, 0.0, 0.0], [0.0, cos_x, -sin_x], [0.0, sin_x, cos_x]],
        dtype=np.float64,
    )

    # Y rotation
    ry = np.array(
        [[cos_y, 0.0, sin_y], [0.0, 1.0, 0.0], [-sin_y, 0.0, cos_y]],
        dtype=np.float64,
    )

    # Z rotation
    rz = np.array(
        [[cos_z, -sin_z, 0.0], [sin_z, cos_z, 0.0], [0.0, 0.0, 1.0]],
        dtype=np.float64,
    )

    # Combine rotations: X @ Y @ Z.
    # This order matches the expected test results, so do not reorder the factors.
    return rx @ ry @ rz
181 changes: 177 additions & 4 deletions albumentations/augmentations/geometric/rotate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import math
from typing import Any, cast
from typing import Annotated, Any, cast

import cv2
import numpy as np
from pydantic import AfterValidator
from typing_extensions import Literal

from albumentations.augmentations.crops import functional as fcrops
from albumentations.augmentations.geometric.transforms import Affine
from albumentations.core.pydantic import BorderModeType, InterpolationType, SymmetricRangeType
from albumentations.augmentations.geometric.transforms import Affine, Perspective
from albumentations.core.pydantic import BorderModeType, InterpolationType, SymmetricRangeType, nondecreasing
from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
from albumentations.core.types import (
ColorType,
Expand All @@ -19,7 +20,7 @@

from . import functional as fgeometric

__all__ = ["Rotate", "RandomRotate90", "SafeRotate"]
__all__ = ["Rotate", "RandomRotate90", "SafeRotate", "RotateAndProject"]

SMALL_NUMBER = 1e-10

Expand Down Expand Up @@ -484,3 +485,175 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A
"bbox_matrix": bbox_matrix,
"output_shape": image_shape,
}


class RotateAndProject(Perspective):
    """Applies 3D rotation to an image and projects it back to 2D plane using perspective projection.

    This transform simulates viewing a 2D image from different 3D viewpoints by:
    1. Sampling rotation angles around three axes (X, Y, Z) and a focal length
    2. Building a 3x3 projection matrix from them with ``fgeometric.get_projection_matrix``
    3. Building a second matrix around a different center for bounding boxes

    Args:
        x_angle_range (tuple[float, float]): Range for rotation around x-axis in degrees.
            Positive angles rotate the top edge away from viewer.
            Default: (-15, 15)
        y_angle_range (tuple[float, float]): Range for rotation around y-axis in degrees.
            Positive angles rotate the right edge away from viewer.
            Default: (-15, 15)
        z_angle_range (tuple[float, float]): Range for rotation around z-axis in degrees.
            Positive angles rotate clockwise in image plane.
            Default: (-15, 15)
        focal_range (tuple[float, float]): Range from which the focal length of the
            perspective projection is sampled uniformly.
            Controls the strength of perspective effect:
                - Values < 1.0: Strong perspective (wide-angle lens effect)
                - Value = 1.0: Normal perspective
                - Values > 1.0: Weak perspective (telephoto lens effect)
            Default: (0.5, 1.5)
        pad_mode (OpenCV flag): Padding mode for borders after rotation.
            Should be one of:
            - cv2.BORDER_CONSTANT: pads with constant value
            - cv2.BORDER_REFLECT: reflects border pixels
            - cv2.BORDER_REFLECT_101: reflects border pixels without duplicating edge pixels
            - cv2.BORDER_REPLICATE: replicates border pixels
            Default: cv2.BORDER_CONSTANT
        pad_val (int, float, list): Padding value if pad_mode is cv2.BORDER_CONSTANT.
            Default: 0
        mask_pad_val (int, float, list): Padding value for masks if pad_mode is cv2.BORDER_CONSTANT.
            Default: 0
        interpolation (OpenCV flag): Interpolation method for image transformation.
            Should be one of:
            - cv2.INTER_NEAREST: nearest-neighbor interpolation
            - cv2.INTER_LINEAR: bilinear interpolation
            - cv2.INTER_CUBIC: bicubic interpolation
            Default: cv2.INTER_LINEAR
        mask_interpolation (OpenCV flag): Interpolation method for mask transformation.
            Default: cv2.INTER_NEAREST
        p (float): Probability of applying the transform.
            Default: 0.5

    Targets:
        image, mask, keypoints, bboxes

    Image types:
        uint8, float32

    Note:
        - ``keep_size=True`` is passed to ``Perspective`` and ``max_height`` / ``max_width``
          are set to the input height and width
        - Uses different center calculations: ``fgeometric.center`` for ``matrix`` and
          ``fgeometric.center_bbox`` for ``matrix_bbox``
        - Bounding boxes are transformed with ``matrix_bbox``; the remaining targets are
          handled by the methods inherited from ``Perspective``
        - The sign conventions of the angles and the lens analogies for ``focal_range``
          are carried over from the original documentation and were not re-derived
          from the matrix code; verify before relying on them

    Example:
        >>> import albumentations as A
        >>> transform = A.RotateAndProject(
        ...     x_angle_range=(-30, 30),
        ...     y_angle_range=(-30, 30),
        ...     z_angle_range=(-15, 15),
        ...     focal_range=(0.7, 1.3),
        ...     p=1.0
        ... )
        >>> result = transform(image=image, bboxes=bboxes, keypoints=keypoints)
    """

    class InitSchema(BaseTransformInitSchema):
        # Each range must be given as (low, high) with low <= high.
        x_angle_range: Annotated[tuple[float, float], AfterValidator(nondecreasing)]
        y_angle_range: Annotated[tuple[float, float], AfterValidator(nondecreasing)]
        z_angle_range: Annotated[tuple[float, float], AfterValidator(nondecreasing)]
        focal_range: Annotated[tuple[float, float], AfterValidator(nondecreasing)]
        mask_interpolation: InterpolationType
        interpolation: InterpolationType
        pad_mode: int
        pad_val: ColorType
        mask_pad_val: ColorType

    def __init__(
        self,
        x_angle_range: tuple[float, float] = (-15, 15),
        y_angle_range: tuple[float, float] = (-15, 15),
        z_angle_range: tuple[float, float] = (-15, 15),
        focal_range: tuple[float, float] = (0.5, 1.5),
        pad_mode: int = cv2.BORDER_CONSTANT,
        pad_val: ColorType = 0,
        mask_pad_val: ColorType = 0,
        interpolation: int = cv2.INTER_LINEAR,
        mask_interpolation: int = cv2.INTER_NEAREST,
        p: float = 0.5,
        always_apply: bool | None = None,
    ):
        # NOTE(review): `always_apply` is accepted but never forwarded to the parent
        # initializer, so it is silently ignored here - confirm this is intentional.
        super().__init__(
            scale=(0, 0),  # Unused but required by parent
            keep_size=True,
            pad_mode=pad_mode,
            pad_val=pad_val,
            mask_pad_val=mask_pad_val,
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            p=p,
        )
        self.x_angle_range = x_angle_range
        self.y_angle_range = y_angle_range
        self.z_angle_range = z_angle_range
        self.focal_range = focal_range
        self.pad_val = pad_val
        self.mask_pad_val = mask_pad_val
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Sample the rotation angles and focal length and build both projection matrices.

        Args:
            params: Transform parameters; ``params["shape"]`` supplies the image shape.
            data: Input targets (not used here).

        Returns:
            Dict with ``matrix`` (built around ``fgeometric.center``), ``matrix_bbox``
            (built around ``fgeometric.center_bbox``), and ``max_height`` / ``max_width``
            set to the input image height and width.
        """
        image_shape = params["shape"][:2]
        height, width = image_shape

        # Draw order (x, y, z, focal) is fixed so a given random state yields the same values.
        x_angle = np.deg2rad(self.py_random.uniform(*self.x_angle_range))
        y_angle = np.deg2rad(self.py_random.uniform(*self.y_angle_range))
        z_angle = np.deg2rad(self.py_random.uniform(*self.z_angle_range))
        focal_length = self.py_random.uniform(*self.focal_range)

        # Both matrices share the sampled pose and differ only in the center point.
        pose = (x_angle, y_angle, z_angle, focal_length)

        matrix = fgeometric.get_projection_matrix(
            image_shape,
            *pose,
            fgeometric.center(image_shape),
        )

        matrix_bbox = fgeometric.get_projection_matrix(
            image_shape,
            *pose,
            fgeometric.center_bbox(image_shape),
        )

        return {"matrix": matrix, "max_height": height, "max_width": width, "matrix_bbox": matrix_bbox}

    def get_transform_init_args_names(self) -> tuple[str, ...]:
        """Return the names of the constructor arguments listed for this transform."""
        return (
            "x_angle_range",
            "y_angle_range",
            "z_angle_range",
            "focal_range",
            "pad_mode",
            "pad_val",
            "mask_pad_val",
            "interpolation",
            "mask_interpolation",
        )

    def apply_to_bboxes(
        self,
        bboxes: np.ndarray,
        matrix_bbox: np.ndarray,
        max_height: int,
        max_width: int,
        **params: Any,
    ) -> np.ndarray:
        """Transform bounding boxes with the bbox-specific projection matrix.

        The final positional argument is passed as ``True``, matching the
        ``keep_size=True`` given to the parent initializer.
        """
        return fgeometric.perspective_bboxes(bboxes, params["shape"], matrix_bbox, max_width, max_height, True)
8 changes: 4 additions & 4 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def apply_to_mask(
matrix,
max_width,
max_height,
self.pad_val,
self.mask_pad_val,
self.pad_mode,
self.keep_size,
self.mask_interpolation,
Expand All @@ -428,15 +428,15 @@ def apply_to_mask(
def apply_to_bboxes(
self,
bboxes: np.ndarray,
matrix: np.ndarray,
matrix_bbox: np.ndarray,
max_height: int,
max_width: int,
**params: Any,
) -> np.ndarray:
return fgeometric.perspective_bboxes(
bboxes,
params["shape"],
matrix,
matrix_bbox,
max_width,
max_height,
self.keep_size,
Expand Down Expand Up @@ -472,7 +472,7 @@ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, A
if self.fit_output:
matrix, max_width, max_height = fgeometric.expand_transform(matrix, image_shape)

return {"matrix": matrix, "max_height": max_height, "max_width": max_width}
return {"matrix": matrix, "max_height": max_height, "max_width": max_width, "matrix_bbox": matrix}

def get_transform_init_args_names(self) -> tuple[str, ...]:
return (
Expand Down
1 change: 1 addition & 0 deletions tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,5 @@
[A.Illumination, {}],
[A.ThinPlateSpline, {}],
[A.AutoContrast, {}],
[A.RotateAndProject, {}],
]
Loading