Skip to content

Commit

Permalink
add option to load Series from h5s or slps
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed May 11, 2024
1 parent f7d66d5 commit 21884dc
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 52 deletions.
8 changes: 7 additions & 1 deletion sleap_roots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
OlderMonocotPipeline,
MultipleDicotPipeline,
)
from sleap_roots.series import Series, find_all_series
from sleap_roots.series import (
Series,
find_all_h5_paths,
find_all_slp_paths,
load_series_from_h5s,
load_series_from_slps,
)

# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
Expand Down
141 changes: 132 additions & 9 deletions sleap_roots/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,19 +130,22 @@ def load(
h5_path = Path(h5_path)
# Check if the file exists
if h5_path.exists():
# Make the h5_path POSIX-compliant
h5_path = h5_path.as_posix()
# Load the video
video = sio.Video.from_filename(h5_path.as_posix())
video = sio.Video.from_filename(h5_path)
# Replace the filename in the labels with the h5_path
for labels in [primary_labels, lateral_labels, crown_labels]:
if labels is not None:
labels.video.replace_filename(h5_path)
else:
print(f"Video file not found: {h5_path.as_posix()}")
print(f"Video file not found: {h5_path}")
except Exception as e:
print(f"Error loading video file {h5_path}: {e}")

# Replace the filename in the labels with the h5_path if it is provided.
if h5_path:
for labels in [primary_labels, lateral_labels, crown_labels]:
if labels is not None:
if not labels.video.exists():
labels.video.replace_filename(h5_path)
# Make the csv path POSIX-compliant
if csv_path:
csv_path = Path(csv_path).as_posix()

return cls(
series_name=series_name,
Expand Down Expand Up @@ -204,7 +207,14 @@ def qc_fail(self) -> Union[int, float]:

def __len__(self) -> int:
"""Length of the series (number of images)."""
return len(self.video)
if self.video is None:
for labels in [self.primary_labels, self.lateral_labels, self.crown_labels]:
if labels is not None:
return len(labels)
else:
return 0
else:
return len(self.video)

def __getitem__(self, idx: int) -> Dict[str, Optional[sio.LabeledFrame]]:
"""Return labeled frames for primary and/or lateral and/or crown predictions."""
Expand Down Expand Up @@ -262,6 +272,10 @@ def plot(self, frame_idx: int, scale: float = 1.0, **kwargs):
scale: Relative size of the visualized image. Useful for plotting smaller
images within notebooks.
"""
# Check if the video is available
if self.video is None:
raise ValueError("Video is not available. Specify the h5_path to load it.")

# Retrieve all available frames
frames = self.get_frame(frame_idx)

Expand Down Expand Up @@ -391,6 +405,115 @@ def find_all_h5_paths(data_folders: Union[str, List[str]]) -> List[str]:
return h5_paths


def find_all_slp_paths(data_folders: Union[str, List[str]]) -> List[str]:
"""Find all .slp paths from a list of folders.
Args:
data_folders: Path or list of paths to folders containing .slp paths.
Returns:
A list of filenames to .slp paths.
"""
if type(data_folders) != list:
data_folders = [data_folders]

slp_paths = []
for data_folder in data_folders:
slp_paths.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.slp")])
return slp_paths


def load_series_from_h5s(
h5_paths: List[str], model_id: str, csv_path: Optional[str] = None
) -> List[Series]:
"""Load a list of Series from a list of .h5 paths.
To load the `Series`, the files must be named with the following convention:
h5_path: '/path/to/scan/series_name.h5'
primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp'
lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp'
crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp'
Our pipeline outputs prediction files with this format:
/<output_folder>/scan{scan_id}.model{model_id}.root{model_type}.slp
Args:
h5_paths: List of paths to .h5 files.
csv_path: Optional path to the CSV file containing the expected plant count.
Returns:
A list of Series loaded with the specified .h5 files.
"""
series_list = []
for h5_path in h5_paths:
# Extract the series name from the h5 path
series_name = Path(h5_path).name.split(".")[0]
# Generate the paths for the primary, lateral, and crown predictions
primary_path = h5_path.replace(".h5", f".model{model_id}.rootprimary.slp")
lateral_path = h5_path.replace(".h5", f".model{model_id}.rootlateral.slp")
crown_path = h5_path.replace(".h5", f".model{model_id}.rootcrown.slp")
# Load the Series
series = Series.load(
series_name,
h5_path=h5_path,
primary_path=primary_path,
lateral_path=lateral_path,
crown_path=crown_path,
csv_path=csv_path,
)
series_list.append(series)
return series_list


def load_series_from_slps(
slp_paths: List[str], h5s: bool, csv_path: Optional[str] = None
) -> List[Series]:
"""Load a list of Series from a list of .slp paths.
To load the `Series`, the files must be named with the following convention:
h5_path: '/path/to/scan/series_name.h5'
primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp'
lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp'
crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp'
Note that everything is expected to be in the same folder.
Our pipeline outputs prediction files with this format:
/<output_folder>/scan{scan_id}.model{model_id}.root{model_type}.slp
Args:
slp_paths: List of paths to .slp files.
h5s: Boolean flag to indicate if the .h5 files are available.
csv_path: Optional path to the CSV file containing the expected plant count.
"""
series_list = []
series_names = list(set([Path(p).name.split(".")[0] for p in slp_paths]))
for series_name in series_names:
# Generate the paths for the primary, lateral, and crown predictions
primary_path = [p for p in slp_paths if series_name in p and "primary" in p]
lateral_path = [p for p in slp_paths if series_name in p and "lateral" in p]
crown_path = [p for p in slp_paths if series_name in p and "crown" in p]
# Check if the .h5 files are available
if h5s:
# Get directory of the h5s
h5_dir = Path(slp_paths[0]).parent
# Generate the path to the .h5 file
h5_path = h5_dir / f"{series_name}.h5"
else:
h5_path = None
# Load the Series
series = Series.load(
series_name,
primary_path=primary_path[0] if primary_path else None,
lateral_path=lateral_path[0] if lateral_path else None,
crown_path=crown_path[0] if crown_path else None,
h5_path=h5_path,
csv_path=csv_path,
)
series_list.append(series)
return series_list


def imgfig(
size: Union[float, Tuple] = 6, dpi: int = 72, scale: float = 1.0
) -> matplotlib.figure.Figure:
Expand Down
Loading

0 comments on commit 21884dc

Please sign in to comment.