diff --git a/sleap_roots/series.py b/sleap_roots/series.py index 5b98228..75a9eba 100644 --- a/sleap_roots/series.py +++ b/sleap_roots/series.py @@ -97,25 +97,34 @@ def load( # Attempt to load the predictions, with error handling try: if primary_name: - primary_path = ( - Path(series_name).with_suffix(f".{primary_name}.slp").as_posix() - ) + if Path(primary_name).as_posix().endswith(".slp"): + primary_path = primary_name + else: + primary_path = ( + Path(series_name).with_suffix(f".{primary_name}.slp").as_posix() + ) if Path(primary_path).exists(): primary_labels = sio.load_slp(primary_path) else: print(f"Primary prediction file not found: {primary_path}") if lateral_name: - lateral_path = ( - Path(series_name).with_suffix(f".{lateral_name}.slp").as_posix() - ) + if lateral_name.endswith(".slp"): + lateral_path = lateral_name + else: + lateral_path = ( + Path(series_name).with_suffix(f".{lateral_name}.slp").as_posix() + ) if Path(lateral_path).exists(): lateral_labels = sio.load_slp(lateral_path) else: print(f"Lateral prediction file not found: {lateral_path}") if crown_name: - crown_path = ( - Path(series_name).with_suffix(f".{crown_name}.slp").as_posix() - ) + if crown_name.endswith(".slp"): + crown_path = crown_name + else: + crown_path = ( + Path(series_name).with_suffix(f".{crown_name}.slp").as_posix() + ) if Path(crown_path).exists(): crown_labels = sio.load_slp(crown_path) else: @@ -385,7 +394,21 @@ def find_all_h5_series(data_folders: Union[str, List[str]]) -> List[str]: def find_all_slp_series(data_folders: Union[str, List[str]]) -> List[str]: - """Find all .slp series from a list of folders.""" + """Find all .slp series from a list of folders. + + Args: + data_folders: Path or list of paths to folders containing .slp series. + + Returns: + A list of unique series names derived from the filenames. + """ + if type(data_folders) != list: + data_folders = [data_folders] + + slp_series = [] + for data_folder in data_folders: + slp_series.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.slp")]) + return list(set([Path(p).name.split(".")[0] for p in slp_series])) def imgfig(