From d4cd0028e954367c6b42b6c5c31f495d5dcee294 Mon Sep 17 00:00:00 2001 From: Ben Lansdell Date: Fri, 2 Feb 2024 14:41:01 -0700 Subject: [PATCH] Add type hints to video.py (#12) --- ethome/video.py | 57 +++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/ethome/video.py b/ethome/video.py index 18b6512..3e0780e 100644 --- a/ethome/video.py +++ b/ethome/video.py @@ -8,6 +8,7 @@ import warnings import types +from typing import Any from glob import glob from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut, PredefinedSplit @@ -49,12 +50,12 @@ } -def _add_item_to_dict(tracking_files, metadata, k, item): +def _add_item_to_dict(tracking_files: list, metadata:dict, k:Any, item:Any): for fn in tracking_files: metadata[fn][k] = item -def _add_items_to_dict(tracking_files, metadata, k, items): +def _add_items_to_dict(tracking_files: list, metadata:dict, k:Any, items:Any): for fn, item in zip(tracking_files, items): metadata[fn][k] = item @@ -95,7 +96,7 @@ def create_metadata(tracking_files: list, **kwargs) -> dict: return metadata -def _convert_units(df): +def _convert_units(df:pd.DataFrame): # if 'frame_width', 'resolution' and 'frame_width_units' are provided, then we convert tracks to these units. if len(df.metadata.details) == 0: return @@ -110,7 +111,7 @@ def _convert_units(df): df.drop(columns="scale_factor", inplace=True) -def _validate_metadata(metadata, req_cols): +def _validate_metadata(metadata:dict, req_cols:list): has_all_dim_cols_count = 0 should_rescale = None @@ -215,7 +216,7 @@ def _validate_metadata(metadata, req_cols): @pd.api.extensions.register_dataframe_accessor("metadata") class EthologyMetadataAccessor(object): - def __init__(self, pandas_obj): + def __init__(self, pandas_obj:pd.DataFrame): self._obj = pandas_obj if "metadata__details" not in self._obj.attrs: self._obj.attrs["metadata__details"] = {} @@ -231,11 +232,11 @@ def label_key(self): return self._obj.attrs["metadata__label_key"] @details.setter - def details(self, val): + def details(self, val:Any): self._obj.attrs["metadata__details"] = val @label_key.setter - def label_key(self, val): + def label_key(self, val:Any): self._obj.attrs["metadata__label_key"] = val @property @@ -253,7 +254,7 @@ def reverse_label_key(self): @pd.api.extensions.register_dataframe_accessor("features") class EthologyFeaturesAccessor(object): - def __init__(self, pandas_obj): + def __init__(self, pandas_obj:pd.DataFrame): self._obj = pandas_obj if "features__active" not in self._obj.attrs: self._obj.attrs["features__active"] = None @@ -263,7 +264,7 @@ def active(self): return self._obj.attrs["features__active"] @active.setter - def active(self, val): + def active(self, val:Any): self._obj.attrs["features__active"] = val def activate(self, name: str) -> list: @@ -315,7 +316,7 @@ def deactivate(self, name: str) -> list: # Set features by individual or by group names def add( self, - feature_maker, + feature_maker: Any, featureset_name: str = None, add_to_features=True, required_columns=[], @@ -395,7 +396,7 @@ def deactivate_cols(self, col_names: list) -> list: @pd.api.extensions.register_dataframe_accessor("pose") class EthologyPoseAccessor(object): - def __init__(self, pandas_obj): + def __init__(self, pandas_obj:pd.DataFrame): self._obj = pandas_obj if "pose__body_parts" not in self._obj.attrs: @@ -415,7 +416,7 @@ def body_parts(self): return self._obj.attrs["pose__body_parts"] @body_parts.setter - def body_parts(self, val): + def body_parts(self, val:Any): self._obj.attrs["pose__body_parts"] = val @property @@ -423,7 +424,7 @@ def animals(self): return self._obj.attrs["pose__animals"] @animals.setter - def animals(self, val): + def animals(self, val:Any): self._obj.attrs["pose__animals"] = val @property @@ -431,7 +432,7 @@ def animal_setup(self): return self._obj.attrs["pose__animal_setup"] @animal_setup.setter - def animal_setup(self, val): + def animal_setup(self, val:Any): self._obj.attrs["pose__animal_setup"] = val @property @@ -439,13 +440,13 @@ def raw_track_columns(self): return self._obj.attrs["pose__raw_track_columns"] @raw_track_columns.setter - def raw_track_columns(self, val): + def raw_track_columns(self, val:Any): self._obj.attrs["pose__raw_track_columns"] = val @pd.api.extensions.register_dataframe_accessor("ml") class EthologyMLAccessor(object): - def __init__(self, pandas_obj): + def __init__(self, pandas_obj:pd.DataFrame): self._obj = pandas_obj if "ml__label_cols" not in self._obj.attrs: @@ -459,7 +460,7 @@ def label_cols(self): return self._obj.attrs["ml__label_cols"] @label_cols.setter - def label_cols(self, val): + def label_cols(self, val:Any): self._obj.attrs["ml__label_cols"] = val @property @@ -467,7 +468,7 @@ def fold_cols(self): return self._obj.attrs["ml__fold_cols"] @fold_cols.setter - def fold_cols(self, val): + def fold_cols(self, val:Any): self._obj.attrs["ml__fold_cols"] = val @property @@ -514,7 +515,7 @@ def _make_predefined_split(self, folds): # pragma: no cover @pd.api.extensions.register_dataframe_accessor("io") class EthologyIOAccessor(object): - def __init__(self, pandas_obj): + def __init__(self, pandas_obj:pd.DataFrame): self._obj = pandas_obj def save(self, fn_out: str) -> None: @@ -531,7 +532,7 @@ def save(self, fn_out: str) -> None: # file.write(pickle.dumps(df.__dict__, protocol = 4)) file.write(dill.dumps(df.__dict__, protocol=4)) - def to_dlc_csv(self, base_dir: str, save_h5_too=False) -> None: + def to_dlc_csv(self, base_dir: str, save_h5_too:bool=False) -> None: """Save ExperimentDataFrame tracking files to DLC csv format. Only save tracking data, not other computed features. @@ -588,7 +589,7 @@ def load(fn_in: str) -> pd.DataFrame: return load_experiment(fn_in) def save_movie( - self, label_columns, path_out: str, video_filenames=None + self, label_columns:list, path_out: str, video_filenames:list=None ) -> None: # pragma: no cover """Given columns indicating behavior predictions or whatever else, make a video with these predictions overlaid. @@ -655,7 +656,7 @@ def save_movie( os.system(cmd) -def _create_from_dict(metadata, part_renamer, animal_renamer): +def _create_from_dict(metadata:dict, part_renamer:dict, animal_renamer:dict): df = pd.DataFrame() # req_cols = ['fps'] # Drop requirement this is provided. Just omit addition of time column, if fps omitted @@ -678,7 +679,7 @@ def _create_from_dict(metadata, part_renamer, animal_renamer): return df -def _create_from_list(input, part_renamer, animal_renamer, **kwargs): +def _create_from_list(input:list, part_renamer:dict, animal_renamer:dict, **kwargs): if len(input) == 0: return pd.DataFrame() supported_exts = [".csv", ".h5", ".nwb", ".hdf5"] @@ -740,7 +741,7 @@ def create_dataset( return df -def _load_nwb(nwb_files, part_renamer, animal_renamer, set_as_label=True): +def _load_nwb(nwb_files:list, part_renamer:dict, animal_renamer:dict, set_as_label:bool=True): metadata = {} dfs = [] col_names_old = None @@ -786,7 +787,7 @@ def _load_nwb(nwb_files, part_renamer, animal_renamer, set_as_label=True): return df -def _load_tracks(df, part_renamer, animal_renamer, rescale=False): +def _load_tracks(df:pd.DataFrame, part_renamer:dict, animal_renamer:dict, rescale:bool=False): """Add tracks to DataFrame""" dfs = [] col_names_old = None @@ -872,12 +873,12 @@ def _load_tracks(df, part_renamer, animal_renamer, rescale=False): return col_names -def _load_labels(df, col_name="label", set_as_label=False): +def _load_labels(df:pd.DataFrame, col_name:str="label", set_as_label:bool=False): # For the moment only BORIS support return _load_labels_boris(df, col_name, set_as_label) -def _load_labels_boris(df, prefix="label", set_as_label=False): +def _load_labels_boris(df:pd.DataFrame, prefix:str="label", set_as_label:bol=False): """Add behavior label data to DataFrame""" label_cols = [] @@ -944,7 +945,7 @@ def get_sample_openfield_data(): return create_dataset(metadata) -def _make_dense_values_into_pairs(predictions, rate): # pragma: no cover +def _make_dense_values_into_pairs(predictions:list, rate:int): # pragma: no cover # Put into start/stop pairs pairs = [] in_pair = False