Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/liezl/add-multiview-datastructur…
Browse files Browse the repository at this point in the history
…es' into liezl/add-cameragroup-class
  • Loading branch information
roomrys committed Jan 20, 2025
2 parents 1f7be2f + 0edce96 commit 6f9f95f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 17 deletions.
100 changes: 85 additions & 15 deletions sleap_io/model/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,26 +271,26 @@ class RecordingSession:
_camera_by_video: dict[Video, Camera] = field(factory=dict)

def get_video(self, camera: Camera) -> Video | None:
"""Get `Video` associated with `Camera`.
"""Get `Video` associated with `camera`.
Args:
camera: Camera to get video
camera: `Camera` to get `Video`
Returns:
Video associated with camera or None if not found
`Video` associated with `camera` or None if not found
"""
return self._video_by_camera.get(camera, None)

def add_video(self, video: Video, camera: Camera):
"""Add `Video` to `RecordingSession` and mapping to `Camera`.
"""Add `video` to `RecordingSession` and mapping to `camera`.
Args:
video: `Video` object to add to `RecordingSession`.
camera: `Camera` object to associate with `Video`.
camera: `Camera` object to associate with `video`.
Raises:
ValueError: If `Camera` is not in associated `CameraGroup`.
ValueError: If `Video` is not a `Video` object.
ValueError: If `camera` is not in associated `CameraGroup`.
ValueError: If `video` is not a `Video` object.
"""
# Raise ValueError if camera is not in associated camera group
self.camera_group.cameras.index(camera)
Expand All @@ -308,13 +308,13 @@ def add_video(self, video: Video, camera: Camera):
self._camera_by_video[video] = camera

def remove_video(self, video: Video):
"""Remove `Video` from `RecordingSession` and mapping to `Camera`.
"""Remove `video` from `RecordingSession` and mapping to `Camera`.
Args:
video: `Video` object to remove from `RecordingSession`.
Raises:
ValueError: If `Video` is not in associated `RecordingSession`.
ValueError: If `video` is not in associated `RecordingSession`.
"""
# Remove video from camera mapping
camera = self._camera_by_video.pop(video)
Expand All @@ -331,7 +331,7 @@ class Camera:
matrix: Intrinsic camera matrix of size (3, 3) and type float64.
dist: Radial-tangential distortion coefficients [k_1, k_2, p_1, p_2, k_3] of
size (5,) and type float64.
size: Image size of camera in pixels of size (2,) and type int.
size: Image size (width, height) of camera in pixels of size (2,) and type int.
rvec: Rotation vector in unnormalized axis-angle representation of size (3,) and
type float64.
tvec: Translation vector of size (3,) and type float64.
Expand Down Expand Up @@ -396,7 +396,7 @@ def _validate_shape(self, attribute: attrs.Attribute, value):
if np.shape(value) != expected_shape:
raise ValueError(
f"{attribute.name} must be a {expected_type} of size {expected_shape}, "
f"but recieved shape: {np.shape(value)} and type: {type(value)} for "
f"but received shape: {np.shape(value)} and type: {type(value)} for "
f"value: {value}"
)

Expand Down Expand Up @@ -490,20 +490,33 @@ def project(self, points: np.ndarray) -> np.ndarray:
"""Project 3D points to 2D using camera matrix and distortion coefficients.
Args:
points: 3D points to project of shape (N, 3) or (N, 1, 3).
points: 3D points to project of shape (..., 3) where "..." is any number of
dimensions (including 0).
Returns:
Projected 2D points of shape (N, 1, 2).
Projected 2D points of shape (..., 2) where "..." is the same as the input
"..." dimensions.
"""
points = points.reshape(-1, 1, 3)
# Validate points in
points_shape = points.shape
try:
if points_shape[-1] != 3:
raise ValueError
points = points.reshape(-1, 1, 3)
except Exception as e:
raise ValueError(
"Expected points to be an array of 3D points of shape (..., 3) where "
"'...' is any number of non-zero dimensions, but received shape "
f"{points_shape}.\n\n{e}"
)
out, _ = cv2.projectPoints(
points,
self.rvec,
self.tvec,
self.matrix,
self.dist,
)
return out
return out.reshape(*points_shape[:-1], 2)

def get_video(self, session: RecordingSession) -> Video | None:
"""Get video associated with recording session.
Expand All @@ -516,6 +529,63 @@ def get_video(self, session: RecordingSession) -> Video | None:
"""
return session.get_video(camera=self)

def to_dict(self) -> dict:
"""Convert `Camera` to dictionary.
Returns:
Dictionary containing camera information with the following keys:
- name: Camera name.
- size: Image size (width, height) of camera in pixels of size (2,) and
type int.
- matrix: Intrinsic camera matrix of size (3, 3) and type float64.
- distortions: Radial-tangential distortion coefficients
[k_1, k_2, p_1, p_2, k_3] of size (5,) and type float64.
- rotation: Rotation vector in unnormalized axis-angle representation of
size (3,) and type float64.
- translation: Translation vector of size (3,) and type float64.
"""
camera_dict = {
"name": self.name,
"size": list(self.size),
"matrix": self.matrix.tolist(),
"distortions": self.dist.tolist(),
"rotation": self.rvec.tolist(),
"translation": self.tvec.tolist(),
}

return camera_dict

@classmethod
def from_dict(cls, camera_dict: dict) -> Camera:
"""Create `Camera` from dictionary.
Args:
camera_dict: Dictionary containing camera information with the following
keys:
- name: Camera name.
- size: Image size (width, height) of camera in pixels of size (2,) and
type int.
- matrix: Intrinsic camera matrix of size (3, 3) and type float64.
- distortions: Radial-tangential distortion coefficients
[k_1, k_2, p_1, p_2, k_3] of size (5,) and type float64.
- rotation: Rotation vector in unnormalized axis-angle representation of
size (3,) and type float64.
- translation: Translation vector of size (3,) and type float64.
Returns:
`Camera` object created from dictionary.
"""
camera = cls(
name=camera_dict["name"],
size=camera_dict["size"],
matrix=camera_dict["matrix"],
dist=camera_dict["distortions"],
rvec=camera_dict["rotation"],
tvec=camera_dict["translation"],
)

return camera

# TODO: Remove this when we implement triangulation without aniposelib
def __getattr__(self, name: str):
"""Get attribute by name.
Expand Down
40 changes: 38 additions & 2 deletions tests/model/test_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,42 @@ def test_camera_extrinsic_matrix():
np.testing.assert_array_equal(camera.tvec, tvec)


def test_camera_from_dict_to_dict():
"""Test camera from_dict method."""

# Define camera dictionary
name = "back"
size = [1280, 1024]
matrix = [
[762.513822135494, 0.0, 639.5],
[0.0, 762.513822135494, 511.5],
[0.0, 0.0, 1.0],
]
distortions = [-0.2868458380166852, 0.0, 0.0, 0.0, 0.0]
rotation = [0.3571857188780474, 0.8879473292757126, 1.6832001677006176]
translation = [-555.4577842902744, -294.43494957092884, -190.82196458369515]
camera_dict = {
"name": name,
"size": size,
"matrix": matrix,
"distortions": distortions,
"rotation": rotation,
"translation": translation,
}

# Test camera from_dict
camera = Camera.from_dict(camera_dict)
assert camera.name == "back"
assert camera.size == tuple(size)
np.testing.assert_array_almost_equal(camera.matrix, np.array(matrix))
np.testing.assert_array_almost_equal(camera.dist, np.array(distortions))
np.testing.assert_array_almost_equal(camera.rvec, np.array(rotation))
np.testing.assert_array_almost_equal(camera.tvec, np.array(translation))

# Test camera to_dict
assert camera.to_dict() == camera_dict


def test_camera_undistort_points():
"""Test camera undistort points method."""
camera = Camera(
Expand Down Expand Up @@ -225,11 +261,11 @@ def test_camera_project():

points = np.random.rand(10, 3)
projected_points = camera.project(points)
assert projected_points.shape == (points.shape[0], 1, 2)
assert projected_points.shape == (*points.shape[:-1], 2)

points = np.random.rand(10, 1, 3)
projected_points = camera.project(points)
assert projected_points.shape == (points.shape[0], 1, 2)
assert projected_points.shape == (*points.shape[:-1], 2)


def test_camera_get_video():
Expand Down

0 comments on commit 6f9f95f

Please sign in to comment.