diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py index ea56b38..7b0b898 100644 --- a/sleap_roots/__init__.py +++ b/sleap_roots/__init__.py @@ -19,7 +19,13 @@ OlderMonocotPipeline, MultipleDicotPipeline, ) -from sleap_roots.series import Series, find_all_series +from sleap_roots.series import ( + Series, + find_all_h5_paths, + find_all_slp_paths, + load_series_from_h5s, + load_series_from_slps, +) # Define package version. # This is read dynamically by setuptools in pyproject.toml to determine the release version. diff --git a/sleap_roots/series.py b/sleap_roots/series.py index a3e97d3..43f4db2 100644 --- a/sleap_roots/series.py +++ b/sleap_roots/series.py @@ -130,19 +130,22 @@ def load( h5_path = Path(h5_path) # Check if the file exists if h5_path.exists(): + # Make the h5_path POSIX-compliant + h5_path = h5_path.as_posix() # Load the video - video = sio.Video.from_filename(h5_path.as_posix()) + video = sio.Video.from_filename(h5_path) + # Replace the filename in the labels with the h5_path + for labels in [primary_labels, lateral_labels, crown_labels]: + if labels is not None: + labels.video.replace_filename(h5_path) else: - print(f"Video file not found: {h5_path.as_posix()}") + print(f"Video file not found: {h5_path}") except Exception as e: print(f"Error loading video file {h5_path}: {e}") - # Replace the filename in the labels with the h5_path if it is provided. - if h5_path: - for labels in [primary_labels, lateral_labels, crown_labels]: - if labels is not None: - if not labels.video.exists(): - labels.video.replace_filename(h5_path) + # Make the csv path POSIX-compliant + if csv_path: + csv_path = Path(csv_path).as_posix() return cls( series_name=series_name, @@ -204,7 +207,14 @@ def qc_fail(self) -> Union[int, float]: def __len__(self) -> int: """Length of the series (number of images).""" - return len(self.video) + if self.video is None: + for labels in [self.primary_labels, self.lateral_labels, self.crown_labels]: + if labels is not None: + return len(labels) + else: + return 0 + else: + return len(self.video) def __getitem__(self, idx: int) -> Dict[str, Optional[sio.LabeledFrame]]: """Return labeled frames for primary and/or lateral and/or crown predictions.""" @@ -262,6 +272,10 @@ def plot(self, frame_idx: int, scale: float = 1.0, **kwargs): scale: Relative size of the visualized image. Useful for plotting smaller images within notebooks. """ + # Check if the video is available + if self.video is None: + raise ValueError("Video is not available. Specify the h5_path to load it.") + # Retrieve all available frames frames = self.get_frame(frame_idx) @@ -391,6 +405,115 @@ def find_all_h5_paths(data_folders: Union[str, List[str]]) -> List[str]: return h5_paths +def find_all_slp_paths(data_folders: Union[str, List[str]]) -> List[str]: + """Find all .slp paths from a list of folders. + + Args: + data_folders: Path or list of paths to folders containing .slp paths. + + Returns: + A list of filenames to .slp paths. + """ + if type(data_folders) != list: + data_folders = [data_folders] + + slp_paths = [] + for data_folder in data_folders: + slp_paths.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.slp")]) + return slp_paths + + +def load_series_from_h5s( + h5_paths: List[str], model_id: str, csv_path: Optional[str] = None +) -> List[Series]: + """Load a list of Series from a list of .h5 paths. + + To load the `Series`, the files must be named with the following convention: + h5_path: '/path/to/scan/series_name.h5' + primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp' + lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp' + crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp' + + Our pipeline outputs prediction files with this format: + //scan{scan_id}.model{model_id}.root{model_type}.slp + + Args: + h5_paths: List of paths to .h5 files. + csv_path: Optional path to the CSV file containing the expected plant count. + + Returns: + A list of Series loaded with the specified .h5 files. + """ + series_list = [] + for h5_path in h5_paths: + # Extract the series name from the h5 path + series_name = Path(h5_path).name.split(".")[0] + # Generate the paths for the primary, lateral, and crown predictions + primary_path = h5_path.replace(".h5", f".model{model_id}.rootprimary.slp") + lateral_path = h5_path.replace(".h5", f".model{model_id}.rootlateral.slp") + crown_path = h5_path.replace(".h5", f".model{model_id}.rootcrown.slp") + # Load the Series + series = Series.load( + series_name, + h5_path=h5_path, + primary_path=primary_path, + lateral_path=lateral_path, + crown_path=crown_path, + csv_path=csv_path, + ) + series_list.append(series) + return series_list + + +def load_series_from_slps( + slp_paths: List[str], h5s: bool, csv_path: Optional[str] = None +) -> List[Series]: + """Load a list of Series from a list of .slp paths. + + To load the `Series`, the files must be named with the following convention: + h5_path: '/path/to/scan/series_name.h5' + primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp' + lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp' + crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp' + Note that everything is expected to be in the same folder. + + Our pipeline outputs prediction files with this format: + //scan{scan_id}.model{model_id}.root{model_type}.slp + + + Args: + slp_paths: List of paths to .slp files. + h5s: Boolean flag to indicate if the .h5 files are available. + csv_path: Optional path to the CSV file containing the expected plant count. + """ + series_list = [] + series_names = list(set([Path(p).name.split(".")[0] for p in slp_paths])) + for series_name in series_names: + # Generate the paths for the primary, lateral, and crown predictions + primary_path = [p for p in slp_paths if series_name in p and "primary" in p] + lateral_path = [p for p in slp_paths if series_name in p and "lateral" in p] + crown_path = [p for p in slp_paths if series_name in p and "crown" in p] + # Check if the .h5 files are available + if h5s: + # Get directory of the h5s + h5_dir = Path(slp_paths[0]).parent + # Generate the path to the .h5 file + h5_path = h5_dir / f"{series_name}.h5" + else: + h5_path = None + # Load the Series + series = Series.load( + series_name, + primary_path=primary_path[0] if primary_path else None, + lateral_path=lateral_path[0] if lateral_path else None, + crown_path=crown_path[0] if crown_path else None, + h5_path=h5_path, + csv_path=csv_path, + ) + series_list.append(series) + return series_list + + def imgfig( size: Union[float, Tuple] = 6, dpi: int = 72, scale: float = 1.0 ) -> matplotlib.figure.Figure: diff --git a/tests/test_series.py b/tests/test_series.py index 6c861b2..01064a1 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1,7 +1,13 @@ import sleap_io as sio import numpy as np import pytest -from sleap_roots.series import Series, find_all_series +from sleap_roots.series import ( + Series, + find_all_slp_paths, + find_all_h5_paths, + load_series_from_h5s, + load_series_from_slps, +) from pathlib import Path from typing import Literal from contextlib import redirect_stdout @@ -11,40 +17,45 @@ @pytest.fixture def series_instance(): # Create a Series instance with dummy data - return Series(h5_path="dummy.h5") + return Series( + series_name="dummy", + h5_path="dummy.h5", + primary_path="dummy.model1.rootprimary.slp", + lateral_path="dummy.model1.rootlateral.slp", + ) @pytest.fixture def dummy_video_path(tmp_path): - video_path = tmp_path / "dummy_video.mp4" + video_path = tmp_path / "dummy.h5" video_path.write_text("This is a dummy video file.") return str(video_path) @pytest.fixture(params=["primary", "lateral", "crown"]) def label_type(request): - """Yields label types for tests, one by one.""" + """Yields root types for tests, one by one.""" return request.param @pytest.fixture def dummy_labels_path(tmp_path, label_type): - labels_path = tmp_path / f"dummy.{label_type}.predictions.slp" + labels_path = tmp_path / f"dummy.model1.root{label_type}.slp" # Simulate the structure of a SLEAP labels file. labels_path.write_text("Dummy SLEAP labels content.") return str(labels_path) @pytest.fixture -def dummy_series(dummy_video_path, dummy_labels_path): - # Assuming dummy_labels_path names are formatted as "{label_type}.predictions.slp" - # Extract the label type (primary, lateral, crown) from the filename - label_type = Path(dummy_labels_path).stem.split(".")[1] +def dummy_series(dummy_video_path, label_type, dummy_labels_path): + # Assuming dummy_labels_path names are formatted as + # "dummy.model1.root{label_type}.slp" # Construct the keyword argument for Series.load() kwargs = { + "series_name": "dummy", "h5_path": dummy_video_path, - f"{label_type}_name": dummy_labels_path, + f"{label_type}_path": dummy_labels_path, } return Series.load(**kwargs) @@ -66,7 +77,11 @@ def test_primary_prediction_not_found(tmp_path): # Create a dummy Series instance with a non-existent primary prediction file output = io.StringIO() with redirect_stdout(output): - Series.load(h5_path=dummy_video_path, primary_name="nonexistent") + Series.load( + series_name="dummy_video", + h5_path=dummy_video_path, + primary_path="dummy_video.model1.rootprimary.slp", + ) # format file path string for assert statement new_file_path = Path(dummy_video_path).with_suffix("").as_posix() @@ -74,7 +89,7 @@ def test_primary_prediction_not_found(tmp_path): assert ( output.getvalue() - == f"Primary prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + == f"Primary prediction file not found: dummy_video.model1.rootprimary.slp\n" ) @@ -85,14 +100,18 @@ def test_lateral_prediction_not_found(tmp_path): # Create a dummy Series instance with a non-existent primary prediction file output = io.StringIO() with redirect_stdout(output): - Series.load(h5_path=dummy_video_path, lateral_name="nonexistent") + Series.load( + series_name="dummy_video", + h5_path=dummy_video_path, + lateral_path="dummy_video.model1.rootlateral.slp", + ) # format file path string for assert statement new_file_path = Path(dummy_video_path).with_suffix("").as_posix() assert ( output.getvalue() - == f"Lateral prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + == f"Lateral prediction file not found: dummy_video.model1.rootlateral.slp\n" ) @@ -103,14 +122,18 @@ def test_crown_prediction_not_found(tmp_path): # Create a dummy Series instance with a non-existent primary prediction file output = io.StringIO() with redirect_stdout(output): - Series.load(h5_path=dummy_video_path, crown_name="nonexistent") + Series.load( + series_name="dummy_video", + h5_path=dummy_video_path, + crown_path="dummy_video.model1.rootcrown.slp", + ) # format file path string for assert statement new_file_path = Path(dummy_video_path).with_suffix("").as_posix() assert ( output.getvalue() - == f"Crown prediction file not found: {new_file_path}.nonexistent.predictions.slp\n" + == f"Crown prediction file not found: dummy_video.model1.rootcrown.slp\n" ) @@ -120,18 +143,10 @@ def test_video_loading_error(tmp_path): output = io.StringIO() with redirect_stdout(output): - Series.load(h5_path=invalid_video_path) + Series.load(series_name="invalid_video", h5_path=invalid_video_path) # Check if the correct error message is output - assert ( - output.getvalue() - == f"Error loading video file {invalid_video_path}: File not found\n" - ) - - -def test_series_name(dummy_series): - expected_name = "dummy_video" # Based on the dummy_video_path fixture - assert dummy_series.series_name == expected_name + assert output.getvalue() == f"Video file not found: {invalid_video_path}\n" def test_get_frame(dummy_series): @@ -143,11 +158,6 @@ def test_get_frame(dummy_series): assert "crown" in frames -def test_series_name_property(): - series = Series(h5_path="mock_path/file_name.h5") - assert series.series_name == "file_name" - - def test_series_name(series_instance): assert series_instance.series_name == "dummy" @@ -172,25 +182,43 @@ def test_qc_cylinder(series_instance, csv_path): assert series_instance.qc_fail == 0 -def test_len(): - series = Series(video=["frame1", "frame2"]) +def test_len_video(): + series = Series(series_name="test_video", video=["frame1", "frame2"]) assert len(series) == 2 -def test_series_load_canola(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): - series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") +def test_len_no_video(): + series = Series(series_name="test_no_video", video=None) + assert len(series) == 0 + + +def test_series_load_canola( + canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"], + canola_primary_slp: Literal[ + "tests/data/canola_7do/919QDUH.primary.predictions.slp" + ], +): + series = Series.load( + series_name="919QDUH", h5_path=canola_h5, primary_path=canola_primary_slp + ) assert len(series) == 72 -def test_find_all_series_canola(canola_folder: Literal["tests/data/canola_7do"]): - all_series_files = find_all_series(canola_folder) - assert len(all_series_files) == 1 +def test_find_all_series_from_h5s_canola( + canola_folder: Literal["tests/data/canola_7do"], +): + h5_paths = find_all_h5_paths(canola_folder) + all_series = load_series_from_h5s(h5_paths) + assert len(all_series) == 1 def test_load_rice_10do( rice_main_10do_h5: Literal["tests/data/rice_10do/0K9E8BI.h5"], + rice_main_10do_slp: Literal["tests/data/rice_10do/0K9E8BI.crown.predictions.slp"], ): - series = Series.load(rice_main_10do_h5, crown_name="crown") + series = Series.load( + series_name="0K9E8BI", h5_path=rice_main_10do_h5, crown_path=rice_main_10do_slp + ) expected_video = sio.Video.from_filename(rice_main_10do_h5) assert len(series) == 72 @@ -211,13 +239,37 @@ def test_get_frame_rice_10do( expected_labeled_frame = expected_labels[0] # Load the series - series = Series.load(rice_main_10do_h5, crown_name="crown") + series = Series.load( + series_name="0K9E8BI", h5_path=rice_main_10do_h5, crown_path=rice_main_10do_slp + ) # Retrieve all available frames frames = series.get_frame(frame_idx) # Get the crown labeled frame crown_lf = frames.get("crown") + + assert crown_lf == expected_labeled_frame + assert series.series_name == "0K9E8BI" + + +def test_get_frame_rice_10do_no_video( + rice_main_10do_slp: Literal["tests/data/rice_10do/0K9E8BI.crown.predictions.slp"], +): + # Set the frame index to 0 + frame_idx = 0 + + # Load the expected Labels object for comparison + expected_labels = sio.load_slp(rice_main_10do_slp) + # Get the first labeled frame + expected_labeled_frame = expected_labels[frame_idx] + + # Load the series + series = Series.load(series_name="0K9E8BI", crown_path=rice_main_10do_slp) + # Retrieve all available frames + frames = series.get_frame(frame_idx) + # Get the crown labeled frame + crown_lf = frames.get("crown") + assert crown_lf == expected_labeled_frame - # Check the series name property assert series.series_name == "0K9E8BI" diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index c16e814..41b9c18 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -6,7 +6,13 @@ OlderMonocotPipeline, MultipleDicotPipeline, ) -from sleap_roots.series import Series, find_all_series +from sleap_roots.series import ( + Series, + find_all_h5_paths, + find_all_slp_paths, + load_series_from_h5s, + load_series_from_slps, +) def test_dicot_pipeline(canola_h5, soy_h5):