Skip to content

Commit

Permalink
Rely on probe_image_inputs and parse_image_inputs in prpy
Browse files Browse the repository at this point in the history
  • Loading branch information
prouast committed Nov 12, 2024
1 parent 633e870 commit 4088244
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 191 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"importlib_resources",
"numpy",
"onnxruntime",
"prpy[ffmpeg,numpy_min]>=0.2.12",
"prpy[ffmpeg,numpy_min]>=0.2.14",
"python-dotenv",
"pyyaml",
"requests",
Expand Down
38 changes: 17 additions & 21 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
# SOFTWARE.

import numpy as np
from prpy.numpy.image import parse_image_inputs, probe_image_inputs
import pytest

import sys
sys.path.append('../vitallens-python')

from vitallens.client import Method
from vitallens.utils import load_config, probe_video_inputs, parse_video_inputs
from vitallens.utils import load_config
from vitallens.utils import merge_faces, check_faces, check_faces_in_roi

@pytest.mark.parametrize("method", [m for m in Method])
Expand All @@ -37,52 +38,47 @@ def test_load_config(method):
def test_probe_video_inputs(request, file):
if file:
test_video_path = request.getfixturevalue('test_video_path')
video_shape, fps, i = probe_video_inputs(test_video_path)
video_shape, fps, i = probe_image_inputs(test_video_path)
else:
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
video_shape, fps, i = probe_video_inputs(test_video_ndarray, fps=test_video_fps)
video_shape, fps, i = probe_image_inputs(test_video_ndarray, fps=test_video_fps)
assert video_shape == (360, 480, 768, 3)
assert fps == 30
assert i == False

def test_probe_video_inputs_no_file():
with pytest.raises(Exception):
_ = probe_video_inputs("does_not_exist", fps="fps")
_ = probe_image_inputs("does_not_exist", fps="fps")

def test_probe_video_inputs_wrong_fps(request):
with pytest.raises(Exception):
test_video_path = request.getfixturevalue('test_video_path')
_ = probe_video_inputs(test_video_path, fps="fps")
_ = probe_image_inputs(test_video_path, fps="fps")

def test_probe_video_inputs_no_fps(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
with pytest.raises(Exception):
_ = probe_video_inputs(test_video_ndarray)
_ = probe_image_inputs(test_video_ndarray)

def test_probe_video_inputs_wrong_dtype(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
with pytest.raises(Exception):
_ = probe_video_inputs(test_video_ndarray.astype(np.float32), fps=30.)
_ = probe_image_inputs(test_video_ndarray.astype(np.float32), fps=30.)

def test_probe_video_inputs_wrong_shape_1(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
with pytest.raises(Exception):
_ = probe_video_inputs(test_video_ndarray[np.newaxis], fps=30.)
_ = probe_image_inputs(test_video_ndarray[np.newaxis], fps=30.)

def test_probe_video_inputs_wrong_shape_2(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
with pytest.raises(Exception):
_ = probe_video_inputs(test_video_ndarray[...,0:1], fps=30.)

def test_probe_video_inputs_wrong_shape_3(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
with pytest.raises(Exception):
_ = probe_video_inputs(test_video_ndarray[:10], fps=30.)
_ = probe_image_inputs(test_video_ndarray[...,0:1], fps=30.)

def test_probe_video_inputs_wrong_type():
with pytest.raises(Exception):
_ = probe_video_inputs(12345, fps=30.)
_ = probe_image_inputs(12345, fps=30.)

@pytest.mark.parametrize("file", [True, False])
@pytest.mark.parametrize("roi", [None, (200, 0, 500, 350)])
Expand All @@ -91,13 +87,13 @@ def test_probe_video_inputs_wrong_type():
def test_parse_video_inputs(request, file, roi, target_size, target_fps):
if file:
test_video_path = request.getfixturevalue('test_video_path')
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
test_video_path, roi=roi, target_size=target_size, target_fps=target_fps)
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_image_inputs(
inputs=test_video_path, roi=roi, target_size=target_size, target_fps=target_fps)
else:
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_video_inputs(
test_video_ndarray, fps=test_video_fps, roi=roi, target_size=target_size,
parsed, fps_in, video_shape_in, ds_factor, idxs = parse_image_inputs(
inputs=test_video_ndarray, fps=test_video_fps, roi=roi, target_size=target_size,
target_fps=target_fps)
assert parsed.shape == (360 if target_fps is None else 360 // 2,
200 if target_size is not None else (350 if roi is not None else 480),
Expand All @@ -110,11 +106,11 @@ def test_parse_video_inputs(request, file, roi, target_size, target_fps):

def test_parse_video_inputs_no_file():
with pytest.raises(Exception):
_ = parse_video_inputs("does_not_exist")
_ = parse_image_inputs("does_not_exist")

def test_parse_video_inputs_wrong_type():
with pytest.raises(Exception):
_ = parse_video_inputs(12345, fps=30.)
_ = parse_image_inputs(12345, fps=30.)

def test_merge_faces():
np.testing.assert_equal(
Expand Down
11 changes: 6 additions & 5 deletions tests/test_vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import base64
import json
import numpy as np
from prpy.numpy.image import parse_image_inputs
import pytest
import requests
from unittest.mock import Mock, patch
Expand All @@ -30,7 +31,7 @@

from vitallens.constants import API_MAX_FRAMES, API_MIN_FRAMES, API_URL
from vitallens.methods.vitallens import VitalLensRPPGMethod
from vitallens.utils import load_config, parse_video_inputs
from vitallens.utils import load_config

def create_mock_response(
status_code: int,
Expand Down Expand Up @@ -137,8 +138,8 @@ def test_VitalLens_API_valid_response(request, process_signals):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
test_video_faces = request.getfixturevalue('test_video_faces')
frames, *_ = parse_video_inputs(
video=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'],
frames, *_ = parse_image_inputs(
inputs=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'],
roi=test_video_faces[0].tolist(), library='prpy', scale_algorithm='bilinear')
headers = {"x-api-key": api_key}
payload = {"video": base64.b64encode(frames[:16].tobytes()).decode('utf-8')}
Expand Down Expand Up @@ -168,8 +169,8 @@ def test_VitalLens_API_wrong_api_key(request):
test_video_ndarray = request.getfixturevalue('test_video_ndarray')
test_video_fps = request.getfixturevalue('test_video_fps')
test_video_faces = request.getfixturevalue('test_video_faces')
frames, *_ = parse_video_inputs(
video=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'],
frames, *_ = parse_image_inputs(
inputs=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'],
roi=test_video_faces[0].tolist(), library='prpy', scale_algorithm='bilinear')
headers = {"x-api-key": "WRONG_API_KEY"}
payload = {"video": base64.b64encode(frames[:16].tobytes()).decode('utf-8')}
Expand Down
13 changes: 7 additions & 6 deletions vitallens/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np
import os
from prpy.constants import SECONDS_PER_MINUTE
from prpy.numpy.image import probe_image_inputs
from typing import Union

from vitallens.constants import DISCLAIMER
Expand All @@ -36,7 +37,7 @@
from vitallens.methods.vitallens import VitalLensRPPGMethod
from vitallens.signal import windowed_freq, windowed_mean
from vitallens.ssd import FaceDetector
from vitallens.utils import load_config, probe_video_inputs, check_faces, convert_ndarray_to_list
from vitallens.utils import load_config, check_faces, convert_ndarray_to_list

class Method(IntEnum):
VITALLENS = 1
Expand Down Expand Up @@ -118,10 +119,10 @@ def __call__(
video file. Note that aggressive video encoding destroys the rPPG signal.
faces: Face boxes in flat point form, containing [x0, y0, x1, y1] coords.
Ignored unless detect_faces=False. Pass a list or np.ndarray of
shape (n_faces, n_frames, 4) for multiple faces detected on multiple frames,
shape (n_frames, 4) for single face detected on mulitple frames, or
shape (4,) for a single face detected globally, or
`None` to assume all frames already cropped to the same single face detection.
- shape (n_faces, n_frames, 4) for multiple faces detected on multiple frames,
- shape (n_frames, 4) for single face detected on mulitple frames, or
- shape (4,) for a single face detected globally, or
- `None` to assume all frames already cropped to the same single face detection.
fps: Sampling frequency of the input video. Required if type(video) == np.ndarray.
override_fps_target: Target fps at which rPPG inference should be run (optional).
If not provided, will use default of the selected method.
Expand Down Expand Up @@ -172,7 +173,7 @@ def __call__(
]
"""
# Probe inputs
inputs_shape, fps, _ = probe_video_inputs(video=video, fps=fps)
inputs_shape, fps, _ = probe_image_inputs(video, fps=fps, allow_image=False)
# TODO: Optimize performance of simple rPPG methods for long videos
# Warning if using long video
target_fps = override_fps_target if override_fps_target is not None else self.rppg.fps_target
Expand Down
1 change: 0 additions & 1 deletion vitallens/methods/rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# SOFTWARE.

import abc
import numpy as np

class RPPGMethod(metaclass=abc.ABCMeta):
"""Abstract superclass for rPPG methods"""
Expand Down
12 changes: 7 additions & 5 deletions vitallens/methods/simple_rppg_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import numpy as np
from prpy.constants import SECONDS_PER_MINUTE
from prpy.numpy.face import get_roi_from_det
from prpy.numpy.image import reduce_roi
from prpy.numpy.image import reduce_roi, parse_image_inputs
from prpy.numpy.signal import interpolate_cubic_spline, estimate_freq
from typing import Union, Tuple

from vitallens.constants import CALC_HR_MIN, CALC_HR_MAX
from vitallens.methods.rppg_method import RPPGMethod
from vitallens.utils import parse_video_inputs, merge_faces
from vitallens.utils import merge_faces

class SimpleRPPGMethod(RPPGMethod):
"""A simple rPPG method using a handcrafted algorithm based on RGB signal trace"""
Expand Down Expand Up @@ -89,9 +89,11 @@ def __call__(
u_roi = merge_faces(faces)
faces = faces - [u_roi[0], u_roi[1], u_roi[0], u_roi[1]]
# Parse the inputs
frames_ds, fps, inputs_shape, ds_factor, _ = parse_video_inputs(
video=frames, fps=fps, target_size=None, roi=u_roi,
target_fps=override_fps_target if override_fps_target is not None else self.fps_target)
frames_ds, fps, inputs_shape, ds_factor, _ = parse_image_inputs(
inputs=frames, fps=fps, roi=u_roi, target_size=None,
target_fps=override_fps_target if override_fps_target is not None else self.fps_target,
preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear',
trim=None, allow_image=False, videodims=True)
assert inputs_shape[0] == faces.shape[0], "Need same number of frames as face detections"
faces_ds = faces[0::ds_factor]
assert frames_ds.shape[0] == faces_ds.shape[0], "Need same number of frames as face detections"
Expand Down
26 changes: 15 additions & 11 deletions vitallens/methods/vitallens.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import numpy as np
from prpy.constants import SECONDS_PER_MINUTE
from prpy.numpy.face import get_roi_from_det
from prpy.numpy.image import probe_image_inputs, parse_image_inputs
from prpy.numpy.signal import detrend, moving_average, standardize
from prpy.numpy.signal import interpolate_cubic_spline, estimate_freq
from prpy.numpy.utils import enough_memory_for_ndarray
Expand All @@ -39,7 +40,7 @@
from vitallens.signal import detrend_lambda_for_hr_response, detrend_lambda_for_rr_response
from vitallens.signal import moving_average_size_for_hr_response, moving_average_size_for_rr_response
from vitallens.signal import reassemble_from_windows
from vitallens.utils import probe_video_inputs, parse_video_inputs, check_faces_in_roi
from vitallens.utils import check_faces_in_roi

class VitalLensRPPGMethod(RPPGMethod):
"""RPPG method using the VitalLens API for inference"""
Expand Down Expand Up @@ -83,17 +84,18 @@ def __call__(
- out_unit: The estimation unit for each signal.
- out_conf: The estimation confidence for each signal.
- out_note: An explanatory note for each signal.
- live: The face live confidence. Shape (1, n_frames)
- live: The face live confidence. Shape (n_frames,)
"""
inputs_shape, fps, video_issues = probe_video_inputs(video=frames, fps=fps)
video_fits_in_memory = enough_memory_for_ndarray(
shape=(inputs_shape[0], self.input_size, self.input_size, 3), dtype=np.uint8)
inputs_shape, fps, video_issues = probe_image_inputs(frames, fps=fps)
# Check the number of frames to be processed
inputs_n = inputs_shape[0]
fps_target = override_fps_target if override_fps_target is not None else self.fps_target
expected_ds_factor = round(fps / fps_target)
expected_ds_n = math.ceil(inputs_n / expected_ds_factor)
# Check if we can parse the video globally
video_fits_in_memory = enough_memory_for_ndarray(
shape=(expected_ds_n, self.input_size, self.input_size, 3), dtype=np.uint8,
max_fraction_of_available_memory_to_use=0.1)
global_face = faces[np.argmin(np.linalg.norm(faces - np.median(faces, axis=0), axis=1))]
global_roi = get_roi_from_det(
global_face, roi_method=self.roi_method, clip_dims=(inputs_shape[2], inputs_shape[1]))
Expand All @@ -102,9 +104,10 @@ def __call__(
if override_global_parse is not None: global_parse = override_global_parse
if global_parse:
# Parse entire video for inference globally
frames, _, _, _, idxs = parse_video_inputs(
video=frames, fps=fps, target_size=self.input_size, roi=global_roi, target_fps=fps_target,
library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0))
frames, _, _, _, idxs = parse_image_inputs(
inputs=frames, fps=fps, roi=global_roi, target_size=self.input_size, target_fps=fps_target,
preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear',
trim=None, allow_image=False, videodims=True)
# Longer videos are split up with small overlaps
n_splits = 1 if expected_ds_n <= API_MAX_FRAMES else math.ceil((expected_ds_n - API_MAX_FRAMES) / (API_MAX_FRAMES - API_OVERLAP)) + 1
split_len = expected_ds_n if n_splits == 1 else math.ceil((inputs_n + (n_splits-1) * API_OVERLAP * expected_ds_factor) / n_splits)
Expand Down Expand Up @@ -228,10 +231,11 @@ def process_api_batch(
idxs = list(range(0, inputs_shape[0], ds_factor))
else:
# Inputs have not been parsed globally. Parse the inputs
frames_ds, _, _, ds_factor, idxs = parse_video_inputs(
video=inputs, fps=fps, target_size=self.input_size, roi=roi, target_fps=fps_target,
frames_ds, _, _, ds_factor, idxs = parse_image_inputs(
inputs=inputs, fps=fps, roi=roi, target_size=self.input_size, target_fps=fps_target,
preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear',
trim=(start, end) if start is not None and end is not None else None,
library='prpy', scale_algorithm='bilinear', dim_deltas=(API_OVERLAP, 0, 0))
allow_image=False, videodims=True)
# Make sure we have the correct number of frames
expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor)
if frames_ds.shape[0] != expected_n or len(idxs) != expected_n:
Expand Down
10 changes: 5 additions & 5 deletions vitallens/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import math
import numpy as np
import os
from prpy.numpy.image import parse_image_inputs
from prpy.numpy.signal import interpolate_vals
import sys
from typing import Tuple
Expand All @@ -32,8 +33,6 @@
else:
from importlib_resources import files

from vitallens.utils import parse_video_inputs

INPUT_SIZE = (240, 320)
MAX_SCAN_FRAMES = 60

Expand Down Expand Up @@ -299,9 +298,10 @@ def scan_batch(
"""
logging.debug("Batch {}/{}...".format(batch, n_batches))
# Parse the inputs
inputs, fps, _, _, idxs = parse_video_inputs(
video=inputs, fps=fps, target_size=INPUT_SIZE, target_fps=self.fs,
library='prpy', scale_algorithm='bilinear', trim=(start, end))
inputs, fps, _, _, idxs = parse_image_inputs(
inputs=inputs, fps=fps, roi=None, target_size=INPUT_SIZE, target_fps=self.fs,
preserve_aspect_ratio=False, library='prpy', scale_algorithm='bilinear',
trim=(start, end), allow_image=False, videodims=True)
# Forward pass
onnx_inputs = {"args_0": (inputs.astype(np.float32) - 127.0) / 128.0}
onnx_outputs = self.model.run(None, onnx_inputs)[0]
Expand Down
Loading

0 comments on commit 4088244

Please sign in to comment.