Skip to content

Commit

Permalink
More type hints (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
benlansdell committed Feb 2, 2024
1 parent 3812c08 commit 4a03342
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 32 deletions.
12 changes: 6 additions & 6 deletions ethome/features/generic_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np


def _diff_within_group(df, sort_key, diff_col, **kwargs):
def _diff_within_group(df, sort_key: str, diff_col:str, **kwargs):
return df.groupby(sort_key)[diff_col].transform(lambda x: x.diff(**kwargs).bfill())


Expand Down Expand Up @@ -45,7 +45,7 @@ def compute_centerofmass_interanimal_distances(


def compute_centerofmass_interanimal_speed(
df: pd.DataFrame, raw_col_names: list, n_shifts=5, **kwargs
df: pd.DataFrame, raw_col_names: list, n_shifts:int=5, **kwargs
) -> pd.DataFrame:
"""Speeds between all animals' centroids"""
animal_setup = df.pose.animal_setup
Expand Down Expand Up @@ -123,7 +123,7 @@ def compute_centerofmass(


def compute_centerofmass_velocity(
df: pd.DataFrame, raw_col_names: list, n_shifts=5, bodyparts: list = [], **kwargs
df: pd.DataFrame, raw_col_names: list, n_shifts:int=5, bodyparts: list = [], **kwargs
) -> pd.DataFrame:
"""Velocity of all animals' centroids"""
animal_setup = df.pose.animal_setup
Expand Down Expand Up @@ -161,7 +161,7 @@ def compute_centerofmass_velocity(


def compute_part_velocity(
df: pd.DataFrame, raw_col_names: list, n_shifts=5, bodyparts: list = [], **kwargs
df: pd.DataFrame, raw_col_names: list, n_shifts:int=5, bodyparts: list = [], **kwargs
) -> pd.DataFrame:
"""Velocity of all animals' bodyparts"""
animal_setup = df.pose.animal_setup
Expand Down Expand Up @@ -198,7 +198,7 @@ def compute_part_velocity(


def compute_part_speed(
df: pd.DataFrame, raw_col_names: list, n_shifts=5, bodyparts: list = [], **kwargs
df: pd.DataFrame, raw_col_names: list, n_shifts:int=5, bodyparts: list = [], **kwargs
) -> pd.DataFrame:
"""Speed of all animals' bodyparts"""
animal_setup = df.pose.animal_setup
Expand Down Expand Up @@ -235,7 +235,7 @@ def compute_part_speed(


def compute_speed_features(
df: pd.DataFrame, raw_col_names: list, n_shifts=5, **kwargs
df: pd.DataFrame, raw_col_names: list, n_shifts:int=5, **kwargs
) -> pd.DataFrame:
"""Speeds between all body parts pairs (within and between animals)"""
animal_setup = df.pose.animal_setup
Expand Down
37 changes: 19 additions & 18 deletions ethome/features/mars_features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import numpy as np

from typing import Callable
from ethome.io import XY_IDS

from itertools import product
Expand All @@ -12,9 +13,9 @@


# The decorator maker, so we can provide arguments
def augment_features(window_size=5, n_shifts=3, mode="shift"):
def augment_features(window_size:int=5, n_shifts:int=3, mode:str="shift"):
# The decorator
def decorator(feature_function):
def decorator(feature_function:Callable):
# What is called instead of the actual function, assumes the feature making
# function returns the names of the columns just made
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -85,7 +86,7 @@ def wrapper(*args, **kwargs):
from pandas.api.types import is_numeric_dtype


def boiler_plate(features_df):
def boiler_plate(features_df:pd.DataFrame):
reversemap = None

to_drop = ["Unnamed: 0"]
Expand All @@ -101,7 +102,7 @@ def boiler_plate(features_df):

@augment_features()
def _compute_centroid(
df, name, animal_setup, body_parts=None, n_shifts=3, mode="shift"
df:pd.DataFrame, name:str, animal_setup:dict, body_parts:list=None, n_shifts:int=3, mode:str="shift"
):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
Expand All @@ -121,7 +122,7 @@ def _compute_centroid(

@augment_features()
def _compute_abs_angle(
df, name, animal_setup, bps, centroid=True, n_shifts=3, mode="shift"
    df:pd.DataFrame, name:str, animal_setup:dict, bps:list, centroid:bool=True, n_shifts:int=3, mode:str="shift"
):
mouse_ids = animal_setup["mouse_ids"]
df = df.copy()
Expand All @@ -142,7 +143,7 @@ def _compute_abs_angle(

@augment_features()
def _compute_rel_angle(
df, name, animal_setup, bps, centroid=False, n_shifts=3, mode="shift"
df:pd.DataFrame, name:str, animal_setup:dict, bps:list, centroid:bool=False, n_shifts:int=3, mode:str="shift"
):
mouse_ids = animal_setup["mouse_ids"]
df = df.copy()
Expand Down Expand Up @@ -172,7 +173,7 @@ def _compute_rel_angle(


@augment_features()
def _compute_ellipsoid(df, animal_setup, n_shifts=3, mode="shift"):
def _compute_ellipsoid(df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand Down Expand Up @@ -213,7 +214,7 @@ def _compute_ellipsoid(df, animal_setup, n_shifts=3, mode="shift"):


# Recall framerate is 30 fps
def _compute_kinematics(df, names, animal_setup, window_size=5, n_shifts=3):
def _compute_kinematics(df:pd.DataFrame, names:list, animal_setup:dict, window_size:int=5, n_shifts:int=3):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand All @@ -235,7 +236,7 @@ def _compute_kinematics(df, names, animal_setup, window_size=5, n_shifts=3):

@augment_features()
def _compute_relative_body_motions(
df, animal_setup, window_size=3, n_shifts=3, mode="shift"
df:pd.DataFrame, animal_setup:dict, window_size:int=3, n_shifts:int=3, mode:str="shift"
):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
Expand Down Expand Up @@ -268,7 +269,7 @@ def _compute_relative_body_motions(


@augment_features()
def _compute_relative_body_angles(df, animal_setup, n_shifts=3, mode="shift"):
def _compute_relative_body_angles(df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand Down Expand Up @@ -313,7 +314,7 @@ def _compute_relative_body_angles(df, animal_setup, n_shifts=3, mode="shift"):


@augment_features()
def _compute_iou(df, animal_setup, n_shifts=3, mode="shift"):
def _compute_iou(df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand Down Expand Up @@ -355,7 +356,7 @@ def _compute_iou(df, animal_setup, n_shifts=3, mode="shift"):
# Which can change from video to video, train to test, etc. So perhaps not useful
@augment_features()
def _compute_cage_distances(
features_df, animal_setup, n_shifts=3, mode="shift"
    features_df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"
): # pragma: no cover
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
Expand Down Expand Up @@ -388,7 +389,7 @@ def _compute_cage_distances(
return features_df


def make_features_distances(df, animal_setup):
def make_features_distances(df:pd.DataFrame, animal_setup:dict):
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand Down Expand Up @@ -430,7 +431,7 @@ def make_features_distances(df, animal_setup):
return features_df


def make_features_mars(df, animal_setup, n_shifts=3, mode="shift"):
def make_features_mars(df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"):
features_df = df.copy()

#######################
Expand Down Expand Up @@ -541,11 +542,11 @@ def make_features_mars(df, animal_setup, n_shifts=3, mode="shift"):
return features_df


def make_features_mars_distr(df: pd.DataFrame, animal_setup: dict) -> pd.DataFrame:
    """MARS feature set computed in ``"distr"`` mode.

    Thin convenience wrapper around ``make_features_mars`` with fixed
    ``n_shifts=3`` and ``mode="distr"``.

    Args:
        df: Pose DataFrame to featurize.
        animal_setup: Dict describing body-part ids, mouse ids and column names.

    Returns:
        DataFrame of MARS features.
    """
    return make_features_mars(df, animal_setup, n_shifts=3, mode="distr")


def make_features_mars_reduced(df, animal_setup, n_shifts=2, mode="diff"):
def make_features_mars_reduced(df:pd.DataFrame, animal_setup:dict, n_shifts:int=2, mode:str="diff"):
features_df = df.copy()

#######################
Expand Down Expand Up @@ -641,7 +642,7 @@ def make_features_mars_reduced(df, animal_setup, n_shifts=2, mode="diff"):
return features_df


def make_features_velocities(df, animal_setup, n_shifts=5): # pragma: no cover
def make_features_velocities(df:pd.DataFrame, animal_setup:dict, n_shifts:int=5): # pragma: no cover
bodypart_ids = animal_setup["bodypart_ids"]
mouse_ids = animal_setup["mouse_ids"]
colnames = animal_setup["colnames"]
Expand Down Expand Up @@ -699,7 +700,7 @@ def make_features_velocities(df, animal_setup, n_shifts=5): # pragma: no cover
return features_df


def make_features_social(df, animal_setup, n_shifts=3, mode="shift"):
def make_features_social(df:pd.DataFrame, animal_setup:dict, n_shifts:int=3, mode:str="shift"):
features_df = df.copy()
colnames = animal_setup["colnames"]

Expand Down
17 changes: 9 additions & 8 deletions ethome/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import pickle
import numpy as np
from typing import Sequence
from itertools import product
from joblib import dump, load
import os
Expand All @@ -15,22 +16,22 @@
XYLIKELIHOOD_IDS = ["x", "y", "likelihood"]


def uniquifier(seq: Sequence) -> list:
    """Return a list of ``seq``'s elements with duplicates removed, keeping first-seen order.

    Args:
        seq: Any iterable of hashable items.

    Returns:
        A new list containing each element of ``seq`` exactly once, in the
        order of first appearance.
    """
    seen = set()
    seen_add = seen.add  # bind once: avoids an attribute lookup per element
    return [x for x in seq if not (x in seen or seen_add(x))]


def _list_replace(ls, renamer):
def _list_replace(ls:list, renamer:dict):
"""Replace elements in a list according to provided dictionary"""
for i, word in enumerate(ls):
if word in renamer.keys():
ls[i] = renamer[word]
return ls


def save_sklearn_model(model, fn_out): # pragma: no cover
def save_sklearn_model(model, fn_out:str): # pragma: no cover
"""Save sklearn model to file
Args:
Expand All @@ -40,7 +41,7 @@ def save_sklearn_model(model, fn_out): # pragma: no cover
dump(model, fn_out)


def load_sklearn_model(fn_in): # pragma: no cover
def load_sklearn_model(fn_in: str): # pragma: no cover
"""Load sklearn model from file
Args:
Expand Down Expand Up @@ -152,7 +153,7 @@ def read_sleap_tracks(
# POSSIBILITY OF SUCH DAMAGE.


def _load_sleap_data(path, multi_animal=False):
def _load_sleap_data(path:str, multi_animal:bool=False):
"""loads sleap data h5 format from file path and returns it as pd.DataFrame
As sleap body parts (nodes) are not ordered in a particular way, we sort them alphabetically.
As sleap tracks do not export a score/likelihood but cut off automatically (nan), we are simulating a likelihood
Expand Down Expand Up @@ -408,7 +409,7 @@ def _read_DLC_tracks(
# SOFTWARE.


def _convert_nwb_to_h5_all(nwbfile):
def _convert_nwb_to_h5_all(nwbfile:str):
"""
Convert a NWB data file back to DeepLabCut's h5 data format.
Expand Down Expand Up @@ -546,7 +547,7 @@ def load_data(fn: str): # pragma: no cover


# Only used to making a test dataframe for testing and dev purposes
def _make_sample_dataframe(fn_out="sample_dataframe.pkl"): # pragma: no cover
def _make_sample_dataframe(fn_out:str="sample_dataframe.pkl"): # pragma: no cover
from ethome import create_dataset, create_metadata

cur_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -662,7 +663,7 @@ def read_boris_annotation(
return ground_truth


def create_behavior_labels(boris_files):
def create_behavior_labels(boris_files:list):
"""Create behavior labels from BORIS exported csv files.
Args:
Expand Down

0 comments on commit 4a03342

Please sign in to comment.