diff --git a/batch_run/dispatcher.py b/batch_run/dispatcher.py
index bce7ff9..7c4570b 100644
--- a/batch_run/dispatcher.py
+++ b/batch_run/dispatcher.py
@@ -35,7 +35,7 @@
         "overwrite_config": True,
         "update_project_paths": True,
         "initialize_visualization": False,
-        "use_GPU": False,
+        "use_GPU": True,
         "random_seed": 0,
         "verbose": 2,
     },
@@ -84,16 +84,16 @@
         },
     },
     "PointTracker": {
-        "contiguous": False,
+        "contiguous": True,
         "params_optical_flow": {
             "method": "lucas_kanade",
-            "mesh_rigidity": 0.025,
-            "mesh_n_neighbors": 8,
+            "mesh_rigidity": 0.045,
+            "mesh_n_neighbors": 12,
             "relaxation": 0.0015,
             "kwargs_method": {
                 "winSize": [
-                    20,
-                    20,
+                    25,
+                    25,
                 ],
                 "maxLevel": 2,
                 "criteria": [
@@ -110,19 +110,19 @@
         },
         "params_outlier_handling": {
             "threshold_displacement": 150,
-            "framesHalted_before": 10,
-            "framesHalted_after": 10,
+            "framesHalted_before": 40,
+            "framesHalted_after": 40,
         },
         "verbose": 2,
     },
     "VQT_Analyzer": {
         "params_VQT": {
-            "Q_lowF": 4,
-            "Q_highF": 10,
+            "Q_lowF": 3,
+            "Q_highF": 5,
             "F_min": 1.0,
             "F_max": 60,
             "n_freq_bins": 36,
-            "win_size": 501,
+            "win_size": 701,
             "symmetry": 'left',
             "taper_asymmetric": True,
             "plot_pref": False,
@@ -163,7 +163,7 @@
         "fit": {
             "method": "CP_NN_HALS",
             "params_method": {
-                "rank": 12,
+                "rank": 10,
                 "n_iter_max": 200,
                 "init": "random",
                 "svd": "truncated_svd",
@@ -241,12 +241,12 @@
 #SBATCH --job-name={name_slurm}
 #SBATCH --output={path}
 #SBATCH --constraint=intel
-#SBATCH --gres=gpu:1,vram:23G
-#SBATCH --partition=gpu_requeue
-#SBATCH -c 4
+#SBATCH --partition=gpu_quad
+#SBATCH --gres=gpu:1,vram:31G
+#SBATCH -c 8
 #SBATCH -n 1
-#SBATCH --mem=36GB
-#SBATCH --time=0-0:30:00
+#SBATCH --mem=48GB
+#SBATCH --time=0-7:30:00
 
 unset XDG_RUNTIME_DIR
 
diff --git a/face_rhythm/__init__.py b/face_rhythm/__init__.py
index aa60cf1..e0406b1 100644
--- a/face_rhythm/__init__.py
+++ b/face_rhythm/__init__.py
@@ -22,8 +22,8 @@
 
 
 ## Prepare cv2.imshow
-import pkg_resources
-installed_packages = {pkg.key for pkg in pkg_resources.working_set}
+import importlib.metadata
+installed_packages = {dist.metadata['Name'] for dist in importlib.metadata.distributions()}
 has_cv2_headless = 'opencv-contrib-python-headless' in installed_packages
 has_cv2_normal = 'opencv-contrib-python' in installed_packages
 if has_cv2_normal and not has_cv2_headless:
@@ -66,4 +66,4 @@ def prepare_cv2_imshow():
     exec('from . import ' + pkg)
 
 
-__version__ = '0.2.4'
\ No newline at end of file
+__version__ = '0.2.5'
\ No newline at end of file
diff --git a/face_rhythm/spectral_analysis.py b/face_rhythm/spectral_analysis.py
index cf09457..a47fda7 100644
--- a/face_rhythm/spectral_analysis.py
+++ b/face_rhythm/spectral_analysis.py
@@ -14,6 +14,20 @@ class VQT_Analyzer(FR_Module):
     displacement traces. The spectrograms are generated using the Variable
     Q-Transform (VQT) algorithm.
     RH 2022
+
+    Args:
+        params_VQT (dict):
+            A dictionary of parameters to pass to the VQT class (Variable
+            Q-Transform) in helpers.py.
+        normalization_factor (float):
+            A float between 0 and 1 to normalize the spectrograms. 0 means no
+            normalization. 1 means every time point has the same power.
+        spectrogram_exponent (float):
+            A float to raise the spectrogram to before normalizing.
+        one_over_f_exponent (float):
+            A float to raise the frequency axis to before doing 1/f correction.
+        verbose (int):
+            An integer to control the verbosity of the class.
     """
     def __init__(
         self,
diff --git a/face_rhythm/visualization.py b/face_rhythm/visualization.py
index 1004d6d..e63094d 100644
--- a/face_rhythm/visualization.py
+++ b/face_rhythm/visualization.py
@@ -1,6 +1,7 @@
 from typing import Union
 from pathlib import Path
 import copy
+import gc
 
 import numpy as np
 import cv2
@@ -120,7 +121,7 @@ def __init__(
                 If list: Each element is a thickness for a different text. Length of
                     list must match the length of text.
         """
-        ## Stor arguments
+        ## Store arguments
         self.point_sizes = point_sizes if point_sizes is not None else None
         self.points_colors = points_colors if points_colors is not None else None
         self.alpha = alpha if alpha is not None else None
@@ -415,8 +416,11 @@ def visualize_image_with_points(
 
     def close(self):
         if self.video_writer is not None:
-            self.video_writer.release()
             cv2.destroyWindow(self.handle_cv2Imshow)
+            try:
+                self.video_writer.release()
+            except:
+                pass
 
     def __call__(self, *args, **kwds):
         """
@@ -426,11 +430,7 @@
         self.visualize_image_with_points(*args, **kwds)
     def __del__(self):
         self.close()
-    def __exit__(self):
-        self.close()
-    def __enter__(self):
-        return self
-
+
 
     def __repr__(self):
         return f'FrameVisualizer(handle_cv2Imshow={self.handle_cv2Imshow}, display={self.display}, video_writer={self.video_writer}, path_video={self.path_save}, frame_rate={self.frame_rate}, frame_height_width={self.frame_height_width})'
@@ -477,15 +477,26 @@ def play_video_with_points(
     ### Set buffered video reader to first frame
     bufferedVideoReader.set_iterator_frame_idx(int(idx_frames[0]))
     ### Iterate through frames
-    for idx_frame in tqdm(idx_frames):
-        frame = bufferedVideoReader[idx_frame][0]
-        frame = frame.numpy() if isinstance(frame, torch.Tensor) else frame
-        p = points_int[idx_frame] if points_int is not None else None
-        frameVisualizer.visualize_image_with_points(
-            image=frame,
-            points=[p],
-        )
-    frameVisualizer.close()
+    ### Use context manager to close frameVisualizer
+    class CM:
+        def __init__(self, frameVisualizer):
+            self.frameVisualizer = frameVisualizer
+        def __enter__(self):
+            return self.frameVisualizer
+        def __exit__(self, exc_type, exc_value, traceback):
+            self.frameVisualizer.close()
+            cv2.destroyWindow(self.frameVisualizer.handle_cv2Imshow)
+            gc.collect()
+    with CM(frameVisualizer) as f:
+        for idx_frame in tqdm(idx_frames):
+            frame = bufferedVideoReader[idx_frame][0]
+            frame = frame.numpy() if isinstance(frame, torch.Tensor) else frame
+            p = points_int[idx_frame] if points_int is not None else None
+            f.visualize_image_with_points(
+                image=frame,
+                points=[p],
+            )
+        f.close()
 
 
 # def display_toggle_image_stack(images, clim=None, **kwargs):
diff --git a/requirements.txt b/requirements.txt
index 0fe1996..05696e8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ notebook<7
 tensorly==0.8.1
 opencv_contrib_python==4.9.0.80
 matplotlib
-scikit_learn==1.4.1.post1
+scikit_learn==1.4.2
 scikit_image
 pyyaml
 tqdm