diff --git a/examples/live.py b/examples/live.py index 2b2635b..5b014f0 100644 --- a/examples/live.py +++ b/examples/live.py @@ -2,6 +2,7 @@ import concurrent.futures import cv2 import numpy as np +from prpy.constants import SECONDS_PER_MINUTE from prpy.numpy.face import get_upper_body_roi_from_det from prpy.numpy.signal import estimate_freq import sys @@ -13,6 +14,7 @@ from vitallens import VitalLens, Mode, Method from vitallens.buffer import SignalBuffer, MultiSignalBuffer from vitallens.constants import API_MIN_FRAMES +from vitallens.constants import CALC_HR_MIN, CALC_HR_MAX, CALC_RR_MIN, CALC_RR_MAX def draw_roi(frame, roi): roi = np.asarray(roi).astype(np.int32) @@ -49,9 +51,10 @@ def draw_fps(frame, fps, text, draw_area_bl_x, draw_area_bl_y): cv2.putText(frame, text='{}: {:.1f}'.format(text, fps), org=(draw_area_bl_x, draw_area_bl_y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=(0,255,0), thickness=1) -def draw_vital(frame, sig, text, sig_name, fps, mult, color, draw_area_bl_x, draw_area_bl_y): +def draw_vital(frame, sig, text, sig_name, fps, color, draw_area_bl_x, draw_area_bl_y): if sig_name in sig: - val = estimate_freq(x=sig[sig_name], f_s=fps, f_res=0.0167, method='periodogram') * mult + f_range = (CALC_HR_MIN/SECONDS_PER_MINUTE, CALC_HR_MAX/SECONDS_PER_MINUTE) if 'heart' in sig_name else (CALC_RR_MIN/SECONDS_PER_MINUTE, CALC_RR_MAX/SECONDS_PER_MINUTE) + val = estimate_freq(x=sig[sig_name], f_s=fps, f_res=0.1/SECONDS_PER_MINUTE, f_range=f_range, method='periodogram') * SECONDS_PER_MINUTE cv2.putText(frame, text='{}: {:.1f}'.format(text, val), org=(draw_area_bl_x, draw_area_bl_y), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.6, color=color, thickness=1) @@ -131,7 +134,7 @@ def run(args): # Start next prediction if len(frame_buffer) >= (API_MIN_FRAMES if args.method == Method.VITALLENS else 1): n_frames = len(frame_buffer) - future = executor.submit(vl, frame_buffer.copy(), fps) + executor.submit(vl, frame_buffer.copy(), fps) frame_buffer.clear() # Sample frames if i % ds_factor == 0: @@ -149,8 +152,8 @@ def run(args): draw_area_tl_x=roi[2]+20, draw_area_tl_y=int(roi[1]+(roi[3]-roi[1])/2.0), color=(255, 0, 0)) draw_fps(frame, fps=fps, text="fps", draw_area_bl_x=roi[0], draw_area_bl_y=roi[3]+20) draw_fps(frame, fps=p_fps, text="p_fps", draw_area_bl_x=int(roi[0]+0.4*(roi[2]-roi[0])), draw_area_bl_y=roi[3]+20) - draw_vital(frame, sig=signals, text="hr [bpm]", sig_name='ppg_waveform_sig', fps=fps, mult=60., color=(0,0,255), draw_area_bl_x=roi[2]+20, draw_area_bl_y=int(roi[1]+(roi[3]-roi[1])/2.0)) - draw_vital(frame, sig=signals, text="rr [rpm]", sig_name='respiratory_waveform_sig', fps=fps, mult=60., color=(255,0,0), draw_area_bl_x=roi[2]+20, draw_area_bl_y=roi[3]) + draw_vital(frame, sig=signals, text="hr [bpm]", sig_name='ppg_waveform_sig', fps=fps, color=(0,0,255), draw_area_bl_x=roi[2]+20, draw_area_bl_y=int(roi[1]+(roi[3]-roi[1])/2.0)) + draw_vital(frame, sig=signals, text="rr [rpm]", sig_name='respiratory_waveform_sig', fps=fps, color=(255,0,0), draw_area_bl_x=roi[2]+20, draw_area_bl_y=roi[3]) cv2.imshow('Live', frame) c = cv2.waitKey(1) if c == 27: diff --git a/tests/test_vitallens.py b/tests/test_vitallens.py index 13d6f82..2b5531a 100644 --- a/tests/test_vitallens.py +++ b/tests/test_vitallens.py @@ -135,7 +135,8 @@ def test_VitalLensRPPGMethod_mock(mock_post, request, file, long, override_fps_t assert live.shape == (test_video_ndarray.shape[0],) @pytest.mark.parametrize("process_signals", [True, False]) -def test_VitalLens_API_valid_response(request, process_signals): +@pytest.mark.parametrize("n_frames", [16, 250]) +def test_VitalLens_API_valid_response(request, process_signals, n_frames): config = load_config("vitallens.yaml") api_key = request.getfixturevalue('test_dev_api_key') test_video_ndarray = request.getfixturevalue('test_video_ndarray') @@ -145,7 +146,7 @@ def test_VitalLens_API_valid_response(request, process_signals): inputs=test_video_ndarray, fps=test_video_fps, target_size=config['input_size'], roi=test_video_faces[0].tolist(), library='prpy', scale_algorithm='bilinear') headers = {"x-api-key": api_key} - payload = {"video": base64.b64encode(frames[:16].tobytes()).decode('utf-8')} + payload = {"video": base64.b64encode(frames[:n_frames].tobytes()).decode('utf-8')} if process_signals: payload['fps'] = str(30) response = requests.post(API_URL, headers=headers, json=payload) response_body = json.loads(response.text) @@ -157,13 +158,14 @@ def test_VitalLens_API_valid_response(request, process_signals): ppg_waveform_conf = np.asarray(response_body["vital_signs"]["ppg_waveform"]["confidence"]) resp_waveform_data = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["data"]) resp_waveform_conf = np.asarray(response_body["vital_signs"]["respiratory_waveform"]["confidence"]) - assert ppg_waveform_data.shape == (16,) - assert ppg_waveform_conf.shape == (16,) - assert resp_waveform_data.shape == (16,) - assert resp_waveform_conf.shape == (16,) - assert all((key in vital_signs) if process_signals else (key not in vital_signs) for key in ["heart_rate", "respiratory_rate"]) + assert ppg_waveform_data.shape == (n_frames,) + assert ppg_waveform_conf.shape == (n_frames,) + assert resp_waveform_data.shape == (n_frames,) + assert resp_waveform_conf.shape == (n_frames,) + t = n_frames/test_video_fps + assert all((key in vital_signs) if (process_signals and t > 8.) else (key not in vital_signs) for key in ["heart_rate", "respiratory_rate"]) live = np.asarray(response_body["face"]["confidence"]) - assert live.shape == (16,) + assert live.shape == (n_frames,) state = np.asarray(response_body["state"]["data"]) assert state.shape == (2, 128) diff --git a/vitallens/client.py b/vitallens/client.py index 0e1eb1c..94790e0 100644 --- a/vitallens/client.py +++ b/vitallens/client.py @@ -173,8 +173,8 @@ def __call__( # Probe inputs if self.mode == Mode.BURST and not isinstance(video, np.ndarray): raise ValueError("Must provide `np.ndarray` inputs for burst mode.") - if self.mode == Mode.BURST and video.shape[0] > API_MAX_FRAMES: - raise ValueError(f"Maximum number of frames in burst mode is {API_MAX_FRAMES}, but received {video.shape[0]}.") + if self.mode == Mode.BURST and video.shape[0] > (API_MAX_FRAMES - self.rppg.n_inputs + 1): + raise ValueError(f"Maximum number of frames in burst mode is {API_MAX_FRAMES - self.rppg.n_inputs + 1}, but received {video.shape[0]}.") inputs_shape, fps, _ = probe_image_inputs(video, fps=fps, allow_image=False) # TODO: Optimize performance of simple rPPG methods for long videos # Warning if using long video diff --git a/vitallens/configs/vitallens.yaml b/vitallens/configs/vitallens.yaml index 88c6610..06fdd18 100644 --- a/vitallens/configs/vitallens.yaml +++ b/vitallens/configs/vitallens.yaml @@ -27,6 +27,8 @@ model: 'vitallens' # Size of the input input_size: 40 +# Number of inputs +n_inputs: 4 # List estimated signals signals: ['heart_rate', 'respiratory_rate', 'ppg_waveform', 'respiratory_waveform'] diff --git a/vitallens/methods/rppg_method.py b/vitallens/methods/rppg_method.py index b6566e2..4d85743 100644 --- a/vitallens/methods/rppg_method.py +++ b/vitallens/methods/rppg_method.py @@ -37,6 +37,7 @@ def __init__( """ self.fps_target = config['fps_target'] self.op_mode = mode + self.n_inputs = 1 self.est_window_length = config['est_window_length'] self.est_window_overlap = config['est_window_overlap'] self.est_window_flexible = self.est_window_length == 0 diff --git a/vitallens/methods/vitallens.py b/vitallens/methods/vitallens.py index 888663b..0e0917c 100644 --- a/vitallens/methods/vitallens.py +++ b/vitallens/methods/vitallens.py @@ -60,10 +60,12 @@ def __init__( self.api_key = api_key self.model = config['model'] self.input_size = config['input_size'] + self.n_inputs = config['n_inputs'] self.roi_method = config['roi_method'] self.signals = config['signals'] if mode == Mode.BURST: self.state = None + self.input_buffer = None def __call__( self, inputs: Union[np.ndarray, str], @@ -140,8 +142,9 @@ def __call__( x=idxs, y=conf_ds, xs=np.arange(inputs_n), axis=1) live = interpolate_cubic_spline( x=idxs, y=live_ds, xs=np.arange(inputs_n), axis=0) - # Filter (2, n_frames) - sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])]) + # Filter only in batch mode (2, n_frames) + if self.op_mode == Mode.BATCH: + sig = np.asarray([self.postprocess(p, fps, type=name) for p, name in zip(sig, ['ppg', 'resp'])]) # Assemble and return the results return assemble_results(sig=sig, conf=conf, @@ -209,6 +212,16 @@ def process_api_batch( else: idxs = list(range(0, inputs_shape[0], ds_factor)) else: + # Buffer inputs for burst mode + if self.op_mode == Mode.BURST: + # Inputs in burst mode are always np.ndarray + if self.state is not None: + # State has been initialized + assert self.input_buffer is not None + if inputs.shape[1:] != self.input_buffer.shape[1:]: + raise ValueError("In burst mode, input dimensions must be consistent.") + inputs = np.concatenate([self.input_buffer, inputs], axis=0) + self.input_buffer = inputs[-(self.n_inputs-1):] # Inputs have not been parsed globally. Parse the inputs frames_ds, _, _, ds_factor, idxs = parse_image_inputs( inputs=inputs, fps=fps, roi=roi, target_size=self.input_size, target_fps=fps_target, @@ -216,14 +229,21 @@ def process_api_batch( trim=(start, end) if start is not None and end is not None else None, allow_image=False, videodims=True) # Make sure we have the correct number of frames + idxs = np.asarray(idxs) expected_n = math.ceil(((end-start) if start is not None and end is not None else inputs_shape[0]) / ds_factor) - if frames_ds.shape[0] != expected_n or len(idxs) != expected_n: + if (self.op_mode == Mode.BURST and self.state is not None): expected_n += (self.n_inputs - 1) + if frames_ds.shape[0] != expected_n or idxs.shape[0] != expected_n: raise ValueError("Unexpected number of frames returned. Try to set `override_global_parse` to `True` or `False`.") # Prepare API header and payload headers = {"x-api-key": self.api_key} payload = {"video": base64.b64encode(frames_ds.tobytes()).decode('utf-8')} - if self.op_mode == Mode.BURST and self.state is not None: - payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8') + if self.op_mode == Mode.BURST: + if self.state is not None: + # State and frame buffer have been initialized + assert self.input_buffer is not None + payload["state"] = base64.b64encode(self.state.astype(np.float32).tobytes()).decode('utf-8') + # Adjust idxs + idxs = idxs[3:] - 3 # Ask API to process video response = requests.post(API_URL, headers=headers, json=payload) response_body = json.loads(response.text) @@ -250,7 +270,6 @@ def process_api_batch( live_ds = np.asarray(response_body["face"]["confidence"]) if self.op_mode == Mode.BURST: self.state = np.asarray(response_body["state"]["data"], dtype=np.float32) - idxs = np.asarray(idxs) return sig_ds, conf_ds, live_ds, idxs def postprocess( self, @@ -293,3 +312,4 @@ def reset(self): """Reset""" if self.op_mode == Mode.BURST: self.state = None + self.input_buffer = None