
Commit

Add type hints to video.py (#12)
benlansdell committed Feb 2, 2024
1 parent 4a03342 commit d4cd002
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions ethome/video.py
@@ -8,6 +8,7 @@
import warnings
import types

from typing import Any
from glob import glob
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict, LeaveOneGroupOut, PredefinedSplit
@@ -49,12 +50,12 @@
}


def _add_item_to_dict(tracking_files, metadata, k, item):
def _add_item_to_dict(tracking_files: list, metadata:dict, k:Any, item:Any):
for fn in tracking_files:
metadata[fn][k] = item


def _add_items_to_dict(tracking_files, metadata, k, items):
def _add_items_to_dict(tracking_files: list, metadata:dict, k:Any, items:Any):
for fn, item in zip(tracking_files, items):
metadata[fn][k] = item
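
The hints added in this hunk lean on typing.Any, which places no constraint on a value: they document intent without changing runtime behavior. A minimal sketch of the helper above in use, assuming hypothetical tracking-file names and metadata keys:

from typing import Any

def _add_item_to_dict(tracking_files: list, metadata: dict, k: Any, item: Any) -> None:
    # Same helper as above: write one key/value pair into every file's metadata entry.
    for fn in tracking_files:
        metadata[fn][k] = item

files = ["session1.csv", "session2.csv"]       # hypothetical tracking files
meta = {fn: {} for fn in files}
_add_item_to_dict(files, meta, "fps", 30)      # int value is accepted
_add_item_to_dict(files, meta, "units", "cm")  # so is a str; Any imposes no restriction
print(meta["session1.csv"])                    # {'fps': 30, 'units': 'cm'}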

@@ -95,7 +96,7 @@ def create_metadata(tracking_files: list, **kwargs) -> dict:
return metadata


def _convert_units(df):
def _convert_units(df:pd.DataFrame):
# if 'frame_width', 'resolution' and 'frame_width_units' are provided, then we convert tracks to these units.
if len(df.metadata.details) == 0:
return
Expand All @@ -110,7 +111,7 @@ def _convert_units(df):
df.drop(columns="scale_factor", inplace=True)


def _validate_metadata(metadata, req_cols):
def _validate_metadata(metadata:dict, req_cols:list):
has_all_dim_cols_count = 0

should_rescale = None
@@ -215,7 +216,7 @@ def _validate_metadata(metadata, req_cols):

@pd.api.extensions.register_dataframe_accessor("metadata")
class EthologyMetadataAccessor(object):
def __init__(self, pandas_obj):
def __init__(self, pandas_obj:pd.DataFrame):
self._obj = pandas_obj
if "metadata__details" not in self._obj.attrs:
self._obj.attrs["metadata__details"] = {}
@@ -231,11 +232,11 @@ def label_key(self):
return self._obj.attrs["metadata__label_key"]

@details.setter
def details(self, val):
def details(self, val:Any):
self._obj.attrs["metadata__details"] = val

@label_key.setter
def label_key(self, val):
def label_key(self, val:Any):
self._obj.attrs["metadata__label_key"] = val

@property
@@ -253,7 +254,7 @@ def reverse_label_key(self):

@pd.api.extensions.register_dataframe_accessor("features")
class EthologyFeaturesAccessor(object):
def __init__(self, pandas_obj):
def __init__(self, pandas_obj:pd.DataFrame):
self._obj = pandas_obj
if "features__active" not in self._obj.attrs:
self._obj.attrs["features__active"] = None
@@ -263,7 +264,7 @@ def active(self):
return self._obj.attrs["features__active"]

@active.setter
def active(self, val):
def active(self, val:Any):
self._obj.attrs["features__active"] = val

def activate(self, name: str) -> list:
@@ -315,7 +316,7 @@ def deactivate(self, name: str) -> list:
# Set features by individual or by group names
def add(
self,
feature_maker,
feature_maker: Any,
featureset_name: str = None,
add_to_features=True,
required_columns=[],
@@ -395,7 +396,7 @@ def deactivate_cols(self, col_names: list) -> list:

@pd.api.extensions.register_dataframe_accessor("pose")
class EthologyPoseAccessor(object):
def __init__(self, pandas_obj):
def __init__(self, pandas_obj:pd.DataFrame):
self._obj = pandas_obj

if "pose__body_parts" not in self._obj.attrs:
@@ -415,37 +416,37 @@ def body_parts(self):
return self._obj.attrs["pose__body_parts"]

@body_parts.setter
def body_parts(self, val):
def body_parts(self, val:Any):
self._obj.attrs["pose__body_parts"] = val

@property
def animals(self):
return self._obj.attrs["pose__animals"]

@animals.setter
def animals(self, val):
def animals(self, val:Any):
self._obj.attrs["pose__animals"] = val

@property
def animal_setup(self):
return self._obj.attrs["pose__animal_setup"]

@animal_setup.setter
def animal_setup(self, val):
def animal_setup(self, val:Any):
self._obj.attrs["pose__animal_setup"] = val

@property
def raw_track_columns(self):
return self._obj.attrs["pose__raw_track_columns"]

@raw_track_columns.setter
def raw_track_columns(self, val):
def raw_track_columns(self, val:Any):
self._obj.attrs["pose__raw_track_columns"] = val


@pd.api.extensions.register_dataframe_accessor("ml")
class EthologyMLAccessor(object):
def __init__(self, pandas_obj):
def __init__(self, pandas_obj:pd.DataFrame):
self._obj = pandas_obj

if "ml__label_cols" not in self._obj.attrs:
@@ -459,15 +460,15 @@ def label_cols(self):
return self._obj.attrs["ml__label_cols"]

@label_cols.setter
def label_cols(self, val):
def label_cols(self, val:Any):
self._obj.attrs["ml__label_cols"] = val

@property
def fold_cols(self):
return self._obj.attrs["ml__fold_cols"]

@fold_cols.setter
def fold_cols(self, val):
def fold_cols(self, val:Any):
self._obj.attrs["ml__fold_cols"] = val

@property
@@ -514,7 +515,7 @@ def _make_predefined_split(self, folds): # pragma: no cover

@pd.api.extensions.register_dataframe_accessor("io")
class EthologyIOAccessor(object):
def __init__(self, pandas_obj):
def __init__(self, pandas_obj:pd.DataFrame):
self._obj = pandas_obj

def save(self, fn_out: str) -> None:
@@ -531,7 +532,7 @@ def save(self, fn_out: str) -> None:
# file.write(pickle.dumps(df.__dict__, protocol = 4))
file.write(dill.dumps(df.__dict__, protocol=4))

def to_dlc_csv(self, base_dir: str, save_h5_too=False) -> None:
def to_dlc_csv(self, base_dir: str, save_h5_too:bool=False) -> None:
"""Save ExperimentDataFrame tracking files to DLC csv format.
Only save tracking data, not other computed features.
@@ -588,7 +589,7 @@ def load(fn_in: str) -> pd.DataFrame:
return load_experiment(fn_in)

def save_movie(
self, label_columns, path_out: str, video_filenames=None
self, label_columns:list, path_out: str, video_filenames:list=None
) -> None: # pragma: no cover
"""Given columns indicating behavior predictions or whatever else, make a video
with these predictions overlaid.
Expand Down Expand Up @@ -655,7 +656,7 @@ def save_movie(
os.system(cmd)
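
Taken together, the io methods annotated in the hunks above are the user-facing round trip for an experiment. A hedged usage sketch, not from the repo's docs (the import path, file names and fps keyword are assumptions; the method names and signatures are those shown in this diff):

from ethome.video import create_metadata, create_dataset  # assumed import path

metadata = create_metadata(["tracks.csv"], fps=30)  # hypothetical DLC tracking file and frame rate
df = create_dataset(metadata)
df.io.to_dlc_csv("./out", save_h5_too=False)        # write tracking data back out in DLC csv format
df.io.save("experiment.pkl")                        # serialize the DataFrame and its attrs via dill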


def _create_from_dict(metadata, part_renamer, animal_renamer):
def _create_from_dict(metadata:dict, part_renamer:dict, animal_renamer:dict):
df = pd.DataFrame()
# req_cols = ['fps']
# Drop requirement this is provided. Just omit addition of time column, if fps omitted
@@ -678,7 +679,7 @@ def _create_from_dict(metadata, part_renamer, animal_renamer):
return df


def _create_from_list(input, part_renamer, animal_renamer, **kwargs):
def _create_from_list(input:list, part_renamer:dict, animal_renamer:dict, **kwargs):
if len(input) == 0:
return pd.DataFrame()
supported_exts = [".csv", ".h5", ".nwb", ".hdf5"]
@@ -740,7 +741,7 @@ def create_dataset(
return df


def _load_nwb(nwb_files, part_renamer, animal_renamer, set_as_label=True):
def _load_nwb(nwb_files:list, part_renamer:dict, animal_renamer:dict, set_as_label:bool=True):
metadata = {}
dfs = []
col_names_old = None
Expand Down Expand Up @@ -786,7 +787,7 @@ def _load_nwb(nwb_files, part_renamer, animal_renamer, set_as_label=True):
return df


def _load_tracks(df, part_renamer, animal_renamer, rescale=False):
def _load_tracks(df:pd.DataFrame, part_renamer:dict, animal_renamer:dict, rescale:bool=False):
"""Add tracks to DataFrame"""
dfs = []
col_names_old = None
@@ -872,12 +873,12 @@ def _load_tracks(df, part_renamer, animal_renamer, rescale=False):
return col_names


def _load_labels(df, col_name="label", set_as_label=False):
def _load_labels(df:pd.DataFrame, col_name:str="label", set_as_label:bool=False):
# For the moment only BORIS support
return _load_labels_boris(df, col_name, set_as_label)


def _load_labels_boris(df, prefix="label", set_as_label=False):
def _load_labels_boris(df:pd.DataFrame, prefix:str="label", set_as_label:bool=False):
"""Add behavior label data to DataFrame"""

label_cols = []
@@ -944,7 +945,7 @@ def get_sample_openfield_data():
return create_dataset(metadata)


def _make_dense_values_into_pairs(predictions, rate): # pragma: no cover
def _make_dense_values_into_pairs(predictions:list, rate:int): # pragma: no cover
# Put into start/stop pairs
pairs = []
in_pair = False
