Skip to content

Commit

Permalink
Either parse video globally or not depending on circumstances; Add ro…
Browse files Browse the repository at this point in the history
…bustness for videos with issues which result in unexpected number of frames; Allow overwrite of global parse setting; Return more informative errors
  • Loading branch information
prouast committed Jul 23, 2024
1 parent 3733269 commit c3d9608
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 65 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"importlib_resources",
"numpy",
"onnxruntime",
"prpy[ffmpeg,numpy_min]>=0.2.8",
"prpy[ffmpeg,numpy_min]==0.2.10",
"python-dotenv",
"pyyaml",
"requests",
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_video_fps():

@pytest.fixture(scope='session')
def test_video_shape():
_, n, w, h, _, _, _ = probe_video(TEST_VIDEO_PATH)
_, n, w, h, *_ = probe_video(TEST_VIDEO_PATH)
return (n, h, w, 3)

@pytest.fixture(scope='session')
Expand Down
1 change: 0 additions & 1 deletion tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# SOFTWARE.

import numpy as np
from prpy.ffmpeg.probe import probe_video
import pytest

import sys
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ def test_load_config(method):
def test_probe_video_inputs(request, file):
if file:
test_video_path = request.getfixturevalue('test_video_path')
video_shape, fps = probe_video_inputs(test_video_path)
video_shape, fps, i = probe_video_inputs(test_video_path)
else:
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
video_shape, fps = probe_video_inputs(test_video_ndarray, fps=test_video_fps)
video_shape, fps, i = probe_video_inputs(test_video_ndarray, fps=test_video_fps)
assert video_shape == (360, 480, 768, 3)
assert fps == 30
assert i == False

def test_probe_video_inputs_no_file():
with pytest.raises(Exception):
Expand Down
11 changes: 7 additions & 4 deletions tests/test_vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ def create_mock_api_response(
@pytest.mark.parametrize("file", [True, False])
@pytest.mark.parametrize("override_fps_target", [None, 15, 10])
@pytest.mark.parametrize("long", [False, True])
@pytest.mark.parametrize("override_global_parse", [False, True])
@patch('requests.post', side_effect=create_mock_api_response)
def test_VitalLensRPPGMethod_mock(mock_post, request, file, override_fps_target, long):
def test_VitalLensRPPGMethod_mock(mock_post, request, file, long, override_fps_target, override_global_parse):
if long and file:
pytest.skip("Skip because parameter combination does not work")
config = load_config("vitallens.yaml")
Expand All @@ -108,15 +109,17 @@ def test_VitalLensRPPGMethod_mock(mock_post, request, file, override_fps_target,
if file:
data, unit, conf, note, live = method(
frames=test_video_path, faces=test_video_faces,
override_fps_target=override_fps_target)
else:
override_fps_target=override_fps_target,
override_global_parse=override_global_parse)
else:
if long:
n_repeats = (API_MAX_FRAMES * 3) // test_video_ndarray.shape[0] + 1
test_video_ndarray = np.repeat(test_video_ndarray, repeats=n_repeats, axis=0)
test_video_faces = np.repeat(test_video_faces, repeats=n_repeats, axis=0)
data, unit, conf, note, live = method(
frames=test_video_ndarray, faces=test_video_faces,
fps=test_video_fps, override_fps_target=override_fps_target)
fps=test_video_fps, override_fps_target=override_fps_target,
override_global_parse=override_global_parse)
assert all(key in data for key in method.signals)
assert all(key in unit for key in method.signals)
assert all(key in conf for key in method.signals)
Expand Down
9 changes: 7 additions & 2 deletions vitallens/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __call__(
faces: Union[np.ndarray, list] = None,
fps: float = None,
override_fps_target: float = None,
override_global_parse: bool = None,
export_filename: str = None
) -> list:
"""Run rPPG inference.
Expand All @@ -124,6 +125,8 @@ def __call__(
fps: Sampling frequency of the input video. Required if type(video) == np.ndarray.
override_fps_target: Target fps at which rPPG inference should be run (optional).
If not provided, will use default of the selected method.
override_global_parse: If True, always use global parse. If False, don't use global parse.
If None, choose based on video.
export_filename: Filename for json export if applicable.
Returns:
result: Analysis results as a list of faces in the following format:
Expand Down Expand Up @@ -169,7 +172,7 @@ def __call__(
]
"""
# Probe inputs
inputs_shape, fps = probe_video_inputs(video=video, fps=fps)
inputs_shape, fps, _ = probe_video_inputs(video=video, fps=fps)
# TODO: Optimize performance of simple rPPG methods for long videos
# Warning if using long video
target_fps = override_fps_target if override_fps_target is not None else self.rppg.fps_target
Expand All @@ -194,7 +197,9 @@ def __call__(
for face in faces:
# Run selected rPPG method
data, unit, conf, note, live = self.rppg(
frames=video, faces=face, fps=fps, override_fps_target=override_fps_target)
frames=video, faces=face, fps=fps,
override_fps_target=override_fps_target,
override_global_parse=override_global_parse)
# Parse face results
face_result = {'face': {
'coordinates': face,
Expand Down
3 changes: 3 additions & 0 deletions vitallens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,8 @@
if 'API_URL' in os.environ:
API_URL = os.getenv('API_URL')

# Video error message
VIDEO_PARSE_ERROR = "Unable to parse input video. There may be an issue with the video file."

# Disclaimer message
DISCLAIMER = "The provided values are estimates and should be interpreted according to the provided confidence levels ranging from 0 to 1. The VitalLens API is not a medical device and its estimates are not intended for any medical purposes."
4 changes: 3 additions & 1 deletion vitallens/methods/simple_rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __call__(
frames: Union[np.ndarray, str],
faces: np.ndarray,
fps: float,
override_fps_target: float = None
override_fps_target: float = None,
override_global_parse: float = None,
) -> Tuple[dict, dict, dict, dict, np.ndarray]:
"""Estimate pulse signal from video frames using the subclass algorithm.
Expand All @@ -66,6 +67,7 @@ def __call__(
faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1)
fps: The rate at which video was sampled.
override_fps_target: Override the method's default inference fps (optional).
override_global_parse: Has no effect here.
Returns:
data: A dictionary with the values of the estimated vital signs.
unit: A dictionary with the units of the estimated vital signs.
Expand Down
107 changes: 68 additions & 39 deletions vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from prpy.numpy.face import get_roi_from_det
from prpy.numpy.signal import detrend, moving_average, standardize
from prpy.numpy.signal import interpolate_cubic_spline, estimate_freq
from prpy.numpy.utils import enough_memory_for_ndarray
import json
import logging
import requests
Expand All @@ -38,7 +39,7 @@
from vitallens.signal import detrend_lambda_for_hr_response, detrend_lambda_for_rr_response
from vitallens.signal import moving_average_size_for_hr_response, moving_average_size_for_rr_response
from vitallens.signal import reassemble_from_windows
from vitallens.utils import probe_video_inputs, parse_video_inputs
from vitallens.utils import probe_video_inputs, parse_video_inputs, check_faces_in_roi

class VitalLensRPPGMethod(RPPGMethod):
def __init__(
Expand All @@ -56,7 +57,8 @@ def __call__(
frames: Union[np.ndarray, str],
faces: np.ndarray,
fps: float = None,
override_fps_target: float = None
override_fps_target: float = None,
override_global_parse: bool = None
) -> Tuple[dict, dict, dict, dict, np.ndarray]:
"""Estimate vitals from video frames using the VitalLens API.
Expand All @@ -66,44 +68,54 @@ def __call__(
faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1)
fps: The rate at which video was sampled.
override_fps_target: Override the method's default inference fps (optional).
override_global_parse: If True, always use global parse. If False, don't use global parse.
If None, choose based on video.
Returns:
out_data: The estimated data/value for each signal.
out_unit: The estimation unit for each signal.
out_conf: The estimation confidence for each signal.
out_note: An explanatory note for each signal.
live: The face live confidence. Shape (1, n_frames)
"""
inputs_shape, fps = probe_video_inputs(video=frames, fps=fps)
inputs_shape, fps, video_issues = probe_video_inputs(video=frames, fps=fps)
video_fits_in_memory = enough_memory_for_ndarray(
shape=(inputs_shape[0], self.input_size, self.input_size, 3), dtype=np.uint8)
# Check the number of frames to be processed
inputs_n = inputs_shape[0]
fps_target = override_fps_target if override_fps_target is not None else self.fps_target
expected_ds_factor = round(fps / fps_target)
expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
if expected_ds_n <= API_MAX_FRAMES:
# API supports up to MAX_FRAMES at once - process all frames
sig_ds, conf_ds, live_ds, idxs = self.process_api_batch(
batch=1, n_batches=1, inputs=frames, inputs_shape=inputs_shape,
faces=faces, fps_target=fps_target, fps=fps)
else:
# Longer videos are split up with small overlaps
n_splits = math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
split_len = math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
# start_idxs = [i for i in range(0, expected_ds_len - n_splits * API_OVERLAP, split_len - API_OVERLAP)]
start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
start_idxs = [max(0, end - split_len) for end in end_idxs]
logging.info("Running inference for {} frames using {} requests...".format(expected_ds_n, n_splits))
# Process the splits in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(lambda i: self.process_api_batch(
batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape,
faces=faces, fps_target=fps_target, start=start_idxs[i], end=end_idxs[i],
fps=fps), range(n_splits)))
# Aggregate the results
sig_results, conf_results, live_results, idxs_results = zip(*results)
sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results)
conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results)
live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0]
# Check if we can parse the video globally
global_face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
global_roi = get_roi_from_det(
global_face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
global_faces_in_roi = check_faces_in_roi(faces=faces, roi=global_roi)
global_parse = isinstance(frames, str) and video_fits_in_memory and (video_issues or global_faces_in_roi)
if override_global_parse is not None: global_parse = override_global_parse
if global_parse:
# Parse entire video for inference globally
frames, _, _, _, idxs = parse_video_inputs(
video=frames, fps=fps, target_size=self.input_size, roi=global_roi, target_fps=fps_target,
library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0))
# Longer videos are split up with small overlaps
n_splits = 1 if expected_ds_n <= API_MAX_FRAMES else math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
split_len = expected_ds_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
start_idxs = [i * (split_len - API_OVERLAP * expected_ds_factor) for i in range(n_splits)]
end_idxs = [min(start + split_len, inputs_n) for start in start_idxs]
start_idxs = [max(0, end - split_len) for end in end_idxs]
logging.info("Running inference for {} frames using {} request(s)...".format(expected_ds_n, n_splits))
# Process the splits in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
results = list(executor.map(lambda i: self.process_api_batch(
batch=i, n_batches=n_splits, inputs=frames, inputs_shape=inputs_shape,
faces=faces, fps_target=fps_target, fps=fps, global_parse=global_parse,
start=None if n_splits == 1 else start_idxs[i],
end=None if n_splits == 1 else end_idxs[i]), range(n_splits)))
# Aggregate the results
sig_results, conf_results, live_results, idxs_results = zip(*results)
sig_ds, idxs = reassemble_from_windows(x=sig_results, idxs=idxs_results)
conf_ds, _ = reassemble_from_windows(x=conf_results, idxs=idxs_results)
live_ds = reassemble_from_windows(x=np.asarray(live_results)[:,np.newaxis], idxs=idxs_results)[0][0]
# Interpolate to original sampling rate (n_frames,)
sig = interpolate_cubic_spline(
x=idxs, y=sig_ds, xs=np.arange(inputs_n), axis=1)
Expand Down Expand Up @@ -159,7 +171,8 @@ def process_api_batch(
fps_target: float,
start: int = None,
end: int = None,
fps: float = None
fps: float = None,
global_parse: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Process a batch of frames with the VitalLens API.
Expand All @@ -168,12 +181,13 @@ def process_api_batch(
n_batches: The total number of batches.
inputs: The video to analyze. Either a np.ndarray of shape (n_frames, h, w, 3)
with a sequence of frames in unscaled uint8 RGB format, or a path to a video file.
inputs_shape: The shape of the inputs.
inputs_shape: The original shape of the inputs.
faces: The face detection boxes as np.int64. Shape (n_frames, 4) in form (x0, y0, x1, y1)
fps_target: The target frame rate at which to run inference.
start: The index of first frame of the video to analyze in this batch.
end: The index of the last frame of the video to analyze in this batch.
fps: The frame rate of the input video. Required if type(video) == np.ndarray
global_parse: Flag that indicates whether video has already been parsed.
Returns:
sig: Estimated signals. Shape (n_sig, n_frames)
conf: Estimation confidences. Shape (n_sig, n_frames)
Expand All @@ -188,16 +202,31 @@ def process_api_batch(
face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
roi = get_roi_from_det(
face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
if np.any(np.logical_or(
(faces[:,2] - faces[:,0]) * 0.5 < np.maximum(0, faces[:,0] - roi[0]) + np.maximum(0, faces[:,2] - roi[2]),
(faces[:,3] - faces[:,1]) * 0.5 < np.maximum(0, faces[:,1] - roi[1]) + np.maximum(0, faces[:,3] - roi[3]))):
if not check_faces_in_roi(faces=faces, roi=roi):
logging.warning("Large face movement detected")
# Parse the inputs
frames_ds, fps, inputs_shape, _, idxs = parse_video_inputs(
video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target,
trim=(start, end) if start is not None and end is not None else None,
library='prpy', scale_algorithm='bilinear')
assert frames_ds.shape[0] <= API_MAX_FRAMES
if global_parse:
# Inputs have already been parsed globally.
assert isinstance(inputs, np.ndarray)
frames_ds = inputs
ds_factor = math.ceil(inputs_shape[0] / frames_ds.shape[0])
# Trim frames to batch if necessary
if start is not None and end is not None:
start_ds = start // ds_factor
end_ds = math.ceil((end-start)/ds_factor) + start_ds
frames_ds = frames_ds[start_ds:end_ds]
idxs = list(range(start, end, ds_factor))
else:
idxs = list(range(0, inputs_shape[0], ds_factor))
else:
# Inputs have not been parsed globally. Parse the inputs
frames_ds, _, _, ds_factor, idxs = parse_video_inputs(
video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target,
trim=(start, end) if start is not None and end is not None else None,
library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0))
# Make sure we have the correct number of frames
expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor)
if frames_ds.shape[0] != expected_n or len(idxs) != expected_n:
raise ValueError("Unexpected number of frames returned. Try to set `override_global_parse` to `True` or `False`.")
# Prepare API header and payload
headers = {"x-api-key": self.api_key}
payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')}
Expand Down
Loading

0 comments on commit c3d9608

Please sign in to comment.