Skip to content

Commit

Permalink
Improve docstrings with Tuple returns
Browse files Browse the repository at this point in the history
  • Loading branch information
prouast committed Oct 25, 2024
1 parent cc49fe1 commit 7348fe9
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 49 deletions.
10 changes: 8 additions & 2 deletions vitallens/methods/chrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,22 @@
from vitallens.signal import detrend_lambda_for_hr_response

class CHROMRPPGMethod(SimpleRPPGMethod):
"""The CHROM algorithm by De Haan and Jeanne (2013)"""
def __init__(
self,
config: dict
):
"""Initialize the `CHROMRPPGMethod`
Args:
config: The configuration dict
"""
super(CHROMRPPGMethod, self).__init__(config=config)
def algorithm(
self,
rgb: np.ndarray,
fps: float
):
) -> np.ndarray:
"""Use CHROM algorithm to estimate pulse from rgb signal.
Args:
Expand Down Expand Up @@ -77,7 +83,7 @@ def pulse_filter(
self,
sig: np.ndarray,
fps: float
):
) -> np.ndarray:
"""Apply filters to the estimated pulse signal.
Args:
Expand Down
6 changes: 6 additions & 0 deletions vitallens/methods/g.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@
from vitallens.signal import moving_average_size_for_hr_response

class GRPPGMethod(SimpleRPPGMethod):
"""The G algorithm by Verkruysse (2008)"""
def __init__(
self,
config: dict
):
"""Initialize the `GRPPGMethod`
Args:
config: The configuration dict
"""
super(GRPPGMethod, self).__init__(config=config)
def algorithm(
self,
Expand Down
6 changes: 6 additions & 0 deletions vitallens/methods/pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,16 @@
from vitallens.signal import moving_average_size_for_hr_response

class POSRPPGMethod(SimpleRPPGMethod):
"""The POS algorithm by Wang et al. (2017)"""
def __init__(
self,
config: dict
):
"""Initialize the `POSRPPGMethod`
Args:
config: The configuration dict
"""
super(POSRPPGMethod, self).__init__(config=config)
def algorithm(
self,
Expand Down
12 changes: 10 additions & 2 deletions vitallens/methods/rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@
# SOFTWARE.

import abc
import numpy as np

class RPPGMethod(metaclass=abc.ABCMeta):
def __init__(self, config):
"""Abstract superclass for rPPG methods"""
def __init__(self, config: dict):
"""Initialize the `RPPGMethod`
Args:
config: The configuration dict
"""
self.fps_target = config['fps_target']
self.est_window_length = config['est_window_length']
self.est_window_overlap = config['est_window_overlap']
self.est_window_flexible = self.est_window_length == 0
@abc.abstractmethod
def __call__(self, video, fps, mode):
def __call__(self, frames, faces, fps, override_fps_target, override_global_parse):
"""Run inference. Abstract method to be implemented in subclasses."""
pass
19 changes: 14 additions & 5 deletions vitallens/methods/simple_rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@
from vitallens.utils import parse_video_inputs, merge_faces

class SimpleRPPGMethod(RPPGMethod):
"""A simple rPPG method using a handcrafted algorithm based on RGB signal trace"""
def __init__(
self,
config: dict
):
"""Initialize the `SimpleRPPGMethod`
Args:
config: The configuration dict
"""
super(SimpleRPPGMethod, self).__init__(config=config)
self.model = config['model']
self.roi_method = config['roi_method']
Expand All @@ -45,12 +51,14 @@ def algorithm(
rgb: np.ndarray,
fps: float
):
"""The algorithm. Abstract method to be implemented by subclasses."""
pass
@abc.abstractmethod
def pulse_filter(self,
sig: np.ndarray,
fps: float
) -> np.ndarray:
"""The post-processing filter to be applied to estimated pulse signal. Abstract method to be implemented by subclasses."""
pass
def __call__(
self,
Expand All @@ -70,11 +78,12 @@ def __call__(
override_fps_target: Override the method's default inference fps (optional).
override_global_parse: Has no effect here.
Returns:
data: A dictionary with the values of the estimated vital signs.
unit: A dictionary with the units of the estimated vital signs.
conf: A dictionary with the confidences of the estimated vital signs.
note: A dictionary with notes on the estimated vital signs.
live: Dummy live confidence estimation (set to always 1). Shape (n_frames,)
Tuple of
- data: A dictionary with the values of the estimated vital signs.
- unit: A dictionary with the units of the estimated vital signs.
- conf: A dictionary with the confidences of the estimated vital signs.
- note: A dictionary with notes on the estimated vital signs.
- live: Dummy live confidence estimation (set to always 1). Shape (n_frames,)
"""
# Compute temporal union of ROIs
u_roi = merge_faces(faces)
Expand Down
35 changes: 25 additions & 10 deletions vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,18 @@
from vitallens.utils import probe_video_inputs, parse_video_inputs, check_faces_in_roi

class VitalLensRPPGMethod(RPPGMethod):
"""RPPG method using the VitalLens API for inference"""
def __init__(
self,
config: dict,
api_key: str
):
"""Initialize the `VitalLensRPPGMethod`
Args:
config: The configuration dict
api_key: The API key
"""
super(VitalLensRPPGMethod, self).__init__(config=config)
self.api_key = api_key
self.input_size = config['input_size']
Expand All @@ -71,11 +78,12 @@ def __call__(
override_global_parse: If True, always use global parse. If False, don't use global parse.
If None, choose based on video.
Returns:
out_data: The estimated data/value for each signal.
out_unit: The estimation unit for each signal.
out_conf: The estimation confidence for each signal.
out_note: An explanatory note for each signal.
live: The face live confidence. Shape (1, n_frames)
Tuple of
- out_data: The estimated data/value for each signal.
- out_unit: The estimation unit for each signal.
- out_conf: The estimation confidence for each signal.
- out_note: An explanatory note for each signal.
- live: The face live confidence. Shape (1, n_frames)
"""
inputs_shape, fps, video_issues = probe_video_inputs(video=frames, fps=fps)
video_fits_in_memory = enough_memory_for_ndarray(
Expand Down Expand Up @@ -189,10 +197,11 @@ def process_api_batch(
fps: The frame rate of the input video. Required if type(video) == np.ndarray
global_parse: Flag that indicates whether video has already been parsed.
Returns:
sig: Estimated signals. Shape (n_sig, n_frames)
conf: Estimation confidences. Shape (n_sig, n_frames)
live: Liveness estimation. Shape (n_frames,)
idxs: Indices in inputs that were processed. Shape (n_frames)
Tuple of
- sig: Estimated signals. Shape (n_sig, n_frames)
- conf: Estimation confidences. Shape (n_sig, n_frames)
- live: Liveness estimation. Shape (n_frames,)
- idxs: Indices in inputs that were processed. Shape (n_frames)
"""
logging.debug("Batch {}/{}...".format(batch, n_batches))
# Trim face detections to batch if necessary
Expand Down Expand Up @@ -256,7 +265,13 @@ def process_api_batch(
live_ds = np.asarray(response_body["face"]["confidence"])
idxs = np.asarray(idxs)
return sig_ds, conf_ds, live_ds, idxs
def postprocess(self, sig, fps, type='ppg', filter=True):
def postprocess(
self,
sig: np.ndarray,
fps: float,
type: str = 'ppg',
filter: bool = True
) -> np.ndarray:
"""Apply filters to the estimated signal.
Args:
sig: The estimated signal. Shape (n_frames,)
Expand Down
44 changes: 36 additions & 8 deletions vitallens/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,50 @@

def moving_average_size_for_hr_response(
f_s: Union[float, int]
):
) -> int:
"""Get the moving average window size for a signal with HR information sampled at a given frequency
Args:
f_s: The sampling frequency
Returns:
The moving average size in number of signal vals
"""
return moving_average_size_for_response(f_s, CALC_HR_MAX / SECONDS_PER_MINUTE)

def moving_average_size_for_rr_response(
f_s: Union[float, int]
):
) -> int:
"""Get the moving average window size for a signal with RR information sampled at a given frequency
Args:
f_s: The sampling frequency
Returns:
The moving average size in number of signal vals
"""
return moving_average_size_for_response(f_s, CALC_RR_MAX / SECONDS_PER_MINUTE)

def detrend_lambda_for_hr_response(
f_s: Union[float, int]
):
) -> int:
"""Get the detrending lambda parameter for a signal with HR information sampled at a given frequency
Args:
f_s: The sampling frequency
Returns:
The lambda parameter
"""
return int(0.1614*np.power(f_s, 1.9804))

def detrend_lambda_for_rr_response(
f_s: Union[float, int]
):
) -> int:
"""Get the detrending lambda parameter for a signal with RR information sampled at a given frequency
Args:
f_s: The sampling frequency
Returns:
The lambda parameter
"""
return int(4.4248*np.power(f_s, 2.1253))

def windowed_mean(
Expand Down Expand Up @@ -126,16 +154,16 @@ def windowed_freq(
def reassemble_from_windows(
x: np.ndarray,
idxs: np.ndarray
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray]:
"""Reassemble windowed data using corresponding idxs.
Args:
x: Data generated using a windowing operation. Shape (n_windows, n, window_size)
idxs: Indices of x in the original 1-d array. Shape (n_windows, window_size)
Returns:
out: Reassembled data. Shape (n, n_idxs)
idxs: Reassembled idxs. Shape (n_idxs)
Tuple of
- out: Reassembled data. Shape (n, n_idxs)
- idxs: Reassembled idxs. Shape (n_idxs,)
"""
x = np.asarray(x)
idxs = np.asarray(idxs)
Expand Down
28 changes: 17 additions & 11 deletions vitallens/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def nms(
iou_threshold: Threshold wrt iou for amount of box overlap. Scalar.
score_threshold: Threshold wrt score for removing boxes. Scalar.
Returns:
idxs: The selected indices padded with zero. Shape (n_batch, max_output_size)
Tuple of
- idxs: The selected indices padded with zero. Shape (n_batch, max_output_size)
- num_valid: Number of valid elements per batch. Shape (n_batch,)
"""
n_batch = boxes.shape[0]
# Split up box coordinates
Expand Down Expand Up @@ -108,8 +110,9 @@ def enforce_temporal_consistency(
info: Detection info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
n_frames: Number of frames in the original input.
Returns:
boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4)
info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
Tuple of
- boxes: Processed boxes in point form [0, 1], shape (n_frames, n_faces, 4)
- info: Processed info: idx, scanned, scan_found_face, confidence. Shape (n_frames, n_faces, 4)
"""
# Make sure that enough frames are present
if n_frames == 1:
Expand Down Expand Up @@ -162,8 +165,9 @@ def interpolate_unscanned_frames(
info: Detection info: idx, scanned, scan_found_face, interp_valid, confidence. Shape (n_frames, n_faces, 5)
n_frames: Number of frames in the original input.
Returns:
boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4)
info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4)
Tuple of
- boxes: Processed boxes in point form [0, 1], shape (orig_n_frames, n_faces, 4)
- info: Processed info: idx, scanned, scan_found_face, confidence. Shape (orig_n_frames, n_faces, 4)
"""
_, n_faces, _ = info.shape
# Add rows corresponding to unscanned frames
Expand Down Expand Up @@ -220,8 +224,9 @@ def __call__(
inputs_shape: The shape of the input video as (n_frames, h, w, 3)
fps: Sampling frequency of the input video.
Returns:
boxes: Detected face boxes in relative flat point form (n_frames, n_faces, 4)
info: Tuple (idx, scanned, scan_found_face, interp_valid, confidence) (n_frames, n_faces, 5)
Tuple of
- boxes: Detected face boxes in relative flat point form (n_frames, n_faces, 4)
- info: Tuple (idx, scanned, scan_found_face, interp_valid, confidence) (n_frames, n_faces, 5)
"""
# Determine number of batches
n_frames = inputs_shape[0]
Expand Down Expand Up @@ -275,7 +280,7 @@ def scan_batch(
start: int,
end: int,
fps: float = None,
) -> Tuple[np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, list]:
"""Parse video and run inference for one batch.
Args:
Expand All @@ -287,9 +292,10 @@ def scan_batch(
end: The index of the last frame of the video to analyze in this batch.
fps: Sampling frequency of the input video. Required if type(video) == np.ndarray.
Returns:
boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4)
classes: Detection scores for boxes (n_frames, n_boxes, 2)
idxs: Indices of the scanned frames from the original video
Tuple of
- boxes: Scanned boxes in flat point form (n_frames, n_boxes, 4)
- classes: Detection scores for boxes (n_frames, n_boxes, 2)
- idxs: Indices of the scanned frames from the original video
"""
logging.debug("Batch {}/{}...".format(batch, n_batches))
# Parse the inputs
Expand Down
Loading

0 comments on commit 7348fe9

Please sign in to comment.