Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Labels QOL enhancements #81

Merged
merged 16 commits into from
Apr 14, 2024
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:

# Tests with pytest
tests:
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
Expand Down
1 change: 1 addition & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@
save_jabs,
load_video,
load_file,
save_file,
)
Comment on lines 26 to 30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 NOTE
This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [7-28]

Consider addressing the unused imports.

Many imports in this file are flagged as unused. If these are meant to be available for external use, add them to the __all__ list to explicitly export them. Otherwise, consider removing them to clean up the codebase. Here's an example of how you might modify the code:

- from sleap_io.model.skeleton import Node, Edge, Skeleton, Symmetry
- from sleap_io.model.video import Video
- from sleap_io.model.instance import (
-     Point,
-     PredictedPoint,
-     Track,
-     Instance,
-     PredictedInstance,
- )
- from sleap_io.model.labeled_frame import LabeledFrame
- from sleap_io.model.labels import Labels
- from sleap_io.io.main import (
-     load_slp,
-     save_slp,
-     load_nwb,
-     save_nwb,
-     load_labelstudio,
-     save_labelstudio,
-     load_jabs,
-     save_jabs,
-     load_video,
-     load_file,
-     save_file,
- )
+ __all__ = ['Node', 'Edge', 'Skeleton', 'Symmetry', 'Video', 'Point', 'PredictedPoint', 'Track', 'Instance', 'PredictedInstance', 'LabeledFrame', 'Labels', 'load_slp', 'save_slp', 'load_nwb', 'save_nwb', 'load_labelstudio', 'save_labelstudio', 'load_jabs', 'save_jabs', 'load_video', 'load_file', 'save_file']

8 changes: 8 additions & 0 deletions sleap_io/model/labeled_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,11 @@ def unused_predictions(self) -> list[Instance]:
]

return unused_predictions

def remove_predictions(self):
    """Remove all `PredictedInstance` objects from the frame.

    Only objects whose type is exactly `Instance` are kept; `self.instances` is
    rebound to a new list containing them in their original order.
    """
    kept = []
    for candidate in self.instances:
        # Exact type match on purpose: an `isinstance` check would also keep
        # subclasses of `Instance`.
        # NOTE(review): presumably `PredictedInstance` subclasses `Instance`,
        # which is why `isinstance` is not used here -- confirm.
        if type(candidate) is Instance:
            kept.append(candidate)
    self.instances = kept

def remove_empty_instances(self):
    """Remove all instances with no visible points.

    An instance is dropped when its `is_empty` attribute is truthy.
    `self.instances` is rebound to a new list holding the remaining instances
    in their original order.
    """
    non_empty = []
    for candidate in self.instances:
        if candidate.is_empty:
            continue
        non_empty.append(candidate)
    self.instances = non_empty
117 changes: 117 additions & 0 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame:
"""Return one or more labeled frames based on indexing criteria."""
if type(key) == int:
return self.labeled_frames[key]
elif type(key) == slice:
return [self.labeled_frames[i] for i in range(*key.indices(len(self)))]
elif type(key) == list:
return [self.labeled_frames[i] for i in key]
elif isinstance(key, np.ndarray):
return [self.labeled_frames[i] for i in key.tolist()]
elif type(key) == tuple and len(key) == 2:
video, frame_idx = key
res = self.find(video, frame_idx)
if len(res) == 1:
return res[0]
elif len(res) == 0:
raise IndexError(
f"No labeled frames found for video {video} and "
f"frame index {frame_idx}."
)
elif type(key) == Video:
res = self.find(key)
if len(res) == 0:
raise IndexError(f"No labeled frames found for video {key}.")
return res
Comment on lines +66 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using isinstance() for type checking instead of direct type comparison. This is more Pythonic and supports inheritance.

- elif type(key) == slice:
+ elif isinstance(key, slice):
- elif type(key) == list:
+ elif isinstance(key, list):
- elif type(key) == tuple and len(key) == 2:
+ elif isinstance(key, tuple) and len(key) == 2:
- elif type(key) == Video:
+ elif isinstance(key, Video):

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
elif type(key) == slice:
return [self.labeled_frames[i] for i in range(*key.indices(len(self)))]
elif type(key) == list:
return [self.labeled_frames[i] for i in key]
elif isinstance(key, np.ndarray):
return [self.labeled_frames[i] for i in key.tolist()]
elif type(key) == tuple and len(key) == 2:
video, frame_idx = key
res = self.find(video, frame_idx)
if len(res) == 1:
return res[0]
elif len(res) == 0:
raise IndexError(
f"No labeled frames found for video {video} and "
f"frame index {frame_idx}."
)
elif type(key) == Video:
res = self.find(key)
if len(res) == 0:
raise IndexError(f"No labeled frames found for video {key}.")
return res
elif isinstance(key, slice):
return [self.labeled_frames[i] for i in range(*key.indices(len(self)))]
elif isinstance(key, list):
return [self.labeled_frames[i] for i in key]
elif isinstance(key, np.ndarray):
return [self.labeled_frames[i] for i in key.tolist()]
elif isinstance(key, tuple) and len(key) == 2:
video, frame_idx = key
res = self.find(video, frame_idx)
if len(res) == 1:
return res[0]
elif len(res) == 0:
raise IndexError(
f"No labeled frames found for video {video} and "
f"frame index {frame_idx}."
)
elif isinstance(key, Video):
res = self.find(key)
if len(res) == 0:
raise IndexError(f"No labeled frames found for video {key}.")
return res

else:
raise IndexError(f"Invalid indexing argument for labels: {key}")

Expand Down Expand Up @@ -248,3 +269,99 @@ def find(
results.append(LabeledFrame(video=video, frame_idx=frame_ind))

return results

def save(self, filename: str, format: Optional[str] = None, **kwargs):
    """Save labels to file in specified format.

    Args:
        filename: Path to save labels to.
        format: The format to save the labels in. If `None`, the format will be
            inferred from the file extension. Available formats are "slp", "nwb",
            "labelstudio", and "jabs".
        **kwargs: Additional keyword arguments, forwarded unchanged to
            `sleap_io.save_file`.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably this avoids a circular import with the
    # `sleap_io` package root -- confirm.
    import sleap_io

    sleap_io.save_file(self, filename, format=format, **kwargs)

def clean(
    self,
    frames: bool = True,
    empty_instances: bool = False,
    skeletons: bool = True,
    tracks: bool = True,
    videos: bool = False,
):
    """Remove empty frames, unused skeletons, tracks and videos.

    The labeled frames are scanned once. Frames dropped as empty do not count
    towards "used" videos, skeletons or tracks.

    Args:
        frames: If `True` (the default), remove empty frames.
        empty_instances: If `True` (NOT default), remove instances that have no
            visible points. This happens before the empty-frame check, so a
            frame left with no instances is also dropped when `frames` is `True`.
        skeletons: If `True` (the default), remove unused skeletons.
        tracks: If `True` (the default), remove unused tracks.
        videos: If `True` (NOT default), remove videos that have no labeled frames.
    """
    videos_seen = []
    skeletons_seen = []
    tracks_seen = []
    frames_to_keep = []

    # Instances only need to be visited when skeleton or track usage matters.
    scan_instances = skeletons or tracks

    for lf in self.labeled_frames:
        if empty_instances:
            lf.remove_empty_instances()

        # An empty frame is skipped entirely: it is not kept and contributes
        # nothing to the "seen" lists below.
        if frames and len(lf) == 0:
            continue

        if videos and lf.video not in videos_seen:
            videos_seen.append(lf.video)

        if scan_instances:
            for inst in lf:
                if skeletons and inst.skeleton not in skeletons_seen:
                    skeletons_seen.append(inst.skeleton)

                if not tracks:
                    continue
                track = inst.track
                if track is not None and track not in tracks_seen:
                    tracks_seen.append(track)

        if frames:
            frames_to_keep.append(lf)

    # Each collection is filtered in place of the old one, preserving its
    # original ordering.
    if videos:
        self.videos = [vid for vid in self.videos if vid in videos_seen]

    if skeletons:
        self.skeletons = [skel for skel in self.skeletons if skel in skeletons_seen]

    if tracks:
        self.tracks = [trk for trk in self.tracks if trk in tracks_seen]

    if frames:
        self.labeled_frames = frames_to_keep

def remove_predictions(self, clean: bool = True):
    """Remove all predicted instances from the labels.

    Calls `remove_predictions()` on every labeled frame, then optionally tidies
    up whatever became unused.

    Args:
        clean: If `True` (the default), also remove any empty frames and unused
            tracks and skeletons. It does NOT remove videos that have no labeled
            frames or instances with no visible points.

    See also: `Labels.clean`
    """
    for labeled_frame in self.labeled_frames:
        labeled_frame.remove_predictions()

    if not clean:
        return

    # Spelled out explicitly so this method does not depend on the defaults
    # of `Labels.clean`.
    cleanup_options = dict(
        frames=True,
        empty_instances=False,
        skeletons=True,
        tracks=True,
        videos=False,
    )
    self.clean(**cleanup_options)
49 changes: 49 additions & 0 deletions tests/model/test_labeled_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numpy.testing import assert_equal
from sleap_io import Video, Skeleton, Instance, PredictedInstance
from sleap_io.model.labeled_frame import LabeledFrame
import numpy as np


def test_labeled_frame():
Expand All @@ -26,3 +27,51 @@ def test_labeled_frame():

# Test LabeledFrame.__getitem__ method
assert lf[0] == inst


def test_remove_predictions():
    """Test removing predictions from `LabeledFrame`."""
    user_inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"]))
    pred_inst = PredictedInstance([[4, 5], [6, 7]], skeleton=Skeleton(["A", "B"]))
    lf = LabeledFrame(
        video=Video(filename="test"),
        frame_idx=0,
        instances=[user_inst, pred_inst],
    )

    # Before: one user instance plus one prediction.
    assert len(lf) == 2
    assert len(lf.predicted_instances) == 1

    lf.remove_predictions()

    # After: only the user instance remains, with its points untouched.
    assert len(lf) == 1
    assert len(lf.predicted_instances) == 0
    assert type(lf[0]) is Instance
    assert_equal(lf.numpy(), [[[0, 1], [2, 3]]])


def test_remove_empty_instances():
    """Test removing empty instances from `LabeledFrame`."""
    visible_inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"]))
    # Every coordinate is NaN, so this instance has no visible points.
    nan_inst = Instance(
        [[np.nan, np.nan], [np.nan, np.nan]], skeleton=Skeleton(["A", "B"])
    )
    lf = LabeledFrame(
        video=Video(filename="test"),
        frame_idx=0,
        instances=[visible_inst, nan_inst],
    )
    assert len(lf) == 2

    lf.remove_empty_instances()

    # Only the instance with real coordinates survives.
    assert len(lf) == 1
    assert type(lf[0]) is Instance
    assert_equal(lf.numpy(), [[[0, 1], [2, 3]]])
141 changes: 141 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
Instance,
PredictedInstance,
LabeledFrame,
Track,
load_slp,
load_video,
)
from sleap_io.model.labels import Labels
import numpy as np


def test_labels():
Expand Down Expand Up @@ -117,3 +120,141 @@ def test_labels_skeleton():
labels.skeletons.append(Skeleton(["B"]))
with pytest.raises(ValueError):
labels.skeleton


def test_labels_getitem(slp_typical):
    """Exercise the key types accepted by `Labels.__getitem__`."""
    labels = load_slp(slp_typical)
    labels.labeled_frames.append(LabeledFrame(video=labels.video, frame_idx=1))
    assert len(labels) == 2

    # Positional keys: int, slice, list and numpy array.
    assert labels[0].frame_idx == 0
    assert len(labels[:2]) == 2
    assert len(labels[[0, 1]]) == 2
    assert len(labels[np.array([0, 1])]) == 2

    # (video, frame_idx) key: a hit returns the frame, a miss raises.
    assert labels[labels.video, 0].frame_idx == 0
    with pytest.raises(IndexError):
        labels[labels.video, 2000]

    # Video key: returns all frames for that video, or raises if there are none.
    assert len(labels[labels.video]) == 2
    with pytest.raises(IndexError):
        labels[Video(filename="test")]

    # Unsupported key type.
    with pytest.raises(IndexError):
        labels[None]


def test_labels_save(tmp_path, slp_typical):
    """`Labels.save` should create a file at the requested path."""
    out_path = tmp_path / "test.slp"
    labels = load_slp(slp_typical)
    labels.save(out_path)
    assert out_path.exists()


def test_labels_clean_unchanged(slp_real_data):
    """A full `Labels.clean` must leave this fixture's contents unchanged."""

    def assert_expected_state(labels):
        # The same expectations are checked before and after cleaning.
        assert len(labels) == 10
        assert labels[0].frame_idx == 0
        assert len(labels[0]) == 2
        assert labels[1].frame_idx == 990
        assert len(labels[1]) == 2
        assert len(labels.skeletons) == 1
        assert len(labels.videos) == 1
        assert len(labels.tracks) == 0

    labels = load_slp(slp_real_data)
    assert_expected_state(labels)

    labels.clean(
        frames=True, empty_instances=True, skeletons=True, tracks=True, videos=True
    )
    assert_expected_state(labels)


def test_labels_clean_frames(slp_real_data):
    """`Labels.clean(frames=True)` should drop a frame that has no instances."""
    labels = load_slp(slp_real_data)
    assert labels[0].frame_idx == 0
    assert len(labels[0]) == 2

    # Empty out the first frame so it becomes a removal candidate.
    labels[0].instances = []

    frames_only = dict(
        frames=True, empty_instances=False, skeletons=False, tracks=False, videos=False
    )
    labels.clean(**frames_only)

    # The former second frame is now first.
    assert len(labels) == 9
    assert labels[0].frame_idx == 990
    assert len(labels[0]) == 2


def test_labels_clean_empty_instances(slp_real_data):
    """`Labels.clean(empty_instances=True)` should strip all-NaN instances."""
    labels = load_slp(slp_real_data)
    assert labels[0].frame_idx == 0
    assert len(labels[0]) == 2

    # Replace the first frame's instances with a single all-NaN instance.
    nan_points = np.full((len(labels.skeleton), 2), np.nan)
    labels[0].instances = [Instance.from_numpy(nan_points, skeleton=labels.skeleton)]

    # Without frame removal, the frame stays but ends up empty.
    labels.clean(
        frames=False, empty_instances=True, skeletons=False, tracks=False, videos=False
    )
    assert len(labels) == 10
    assert labels[0].frame_idx == 0
    assert len(labels[0]) == 0

    # With frame removal enabled, the now-empty frame is dropped too.
    labels.clean(
        frames=True, empty_instances=True, skeletons=False, tracks=False, videos=False
    )
    assert len(labels) == 9


def test_labels_clean_skeletons(slp_real_data):
    """`Labels.clean(skeletons=True)` should drop a skeleton no instance uses."""
    labels = load_slp(slp_real_data)

    # Add a skeleton that no instance refers to.
    extra_skeleton = Skeleton(["A", "B"])
    labels.skeletons.append(extra_skeleton)
    assert len(labels.skeletons) == 2

    skeletons_only = dict(
        frames=False, empty_instances=False, skeletons=True, tracks=False, videos=False
    )
    labels.clean(**skeletons_only)

    assert len(labels) == 10
    assert len(labels.skeletons) == 1


def test_labels_clean_tracks(slp_real_data):
    """`Labels.clean(tracks=True)` should keep only tracks assigned to instances."""
    labels = load_slp(slp_real_data)
    for track_name in ("test1", "test2"):
        labels.tracks.append(Track(name=track_name))
    assert len(labels.tracks) == 2

    # Assign only the second track; the first one stays unused.
    labels[0].instances[0].track = labels.tracks[1]

    tracks_only = dict(
        frames=False, empty_instances=False, skeletons=False, tracks=True, videos=False
    )
    labels.clean(**tracks_only)

    assert len(labels) == 10
    assert len(labels.tracks) == 1
    assert labels[0].instances[0].track == labels.tracks[0]
    assert labels.tracks[0].name == "test2"


def test_labels_clean_videos(slp_real_data):
    """`Labels.clean(videos=True)` should drop a video with no labeled frames."""
    labels = load_slp(slp_real_data)

    # Add a video that no labeled frame refers to.
    unused_video = Video(filename="test2")
    labels.videos.append(unused_video)
    assert len(labels.videos) == 2

    videos_only = dict(
        frames=False, empty_instances=False, skeletons=False, tracks=False, videos=True
    )
    labels.clean(**videos_only)

    assert len(labels) == 10
    assert len(labels.videos) == 1
    assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4"


def test_labels_remove_predictions(slp_real_data):
    """`Labels.remove_predictions` should strip predictions, with optional cleanup."""

    def count_predictions(labels):
        return sum(len(lf.predicted_instances) for lf in labels)

    # Without cleanup: predictions disappear but every frame is kept.
    labels = load_slp(slp_real_data)
    assert len(labels) == 10
    assert count_predictions(labels) == 12
    labels.remove_predictions(clean=False)
    assert len(labels) == 10
    assert count_predictions(labels) == 0

    # With cleanup: frames left empty are removed as well.
    labels = load_slp(slp_real_data)
    labels.remove_predictions(clean=True)
    assert len(labels) == 5
    assert count_predictions(labels) == 0
Loading