diff --git a/configs/i3d.yml b/configs/i3d.yml
index d273036..abfc026 100644
--- a/configs/i3d.yml
+++ b/configs/i3d.yml
@@ -6,9 +6,13 @@ streams: null # Streams to use for feature extraction (e.g. 'rgb' or 'flow'). Bo
 flow_type: 'pwc' # Flow to use in I3D. 'pwc' (PWCNet) is faster while 'raft' (RAFT) is more accurate.
 extraction_fps: null # For original video fps, leave as "null" (None)
 
+# Feature Augmentation
+augment: null # Augmentation to use for feature extraction. Can be ['ten_crop', 'five_crop']; null uses a center crop
+
 # Extraction Parameters
 device: 'cuda:0' # device as in `torch`, can be 'cpu'
 on_extraction: 'print' # what to do once the features are extracted. Can be ['print', 'save_numpy', 'save_pickle']
+save_option: null # what to save. If you only want the rgb features, set 'rgb_only', otherwise leave as null
 output_path: './output' # where to store results if saved
 tmp_path: './tmp' # folder to store the temporary files used for extraction (frames or aud files)
 keep_tmp_files: false # to keep temp files after feature extraction.
diff --git a/docs/models/i3d.md b/docs/models/i3d.md
index cd3474d..dc1926a 100644
--- a/docs/models/i3d.md
+++ b/docs/models/i3d.md
@@ -44,6 +44,7 @@ You may test it yourself by providing `--show_pred` flag.
 | `video_paths` | `null` | A list of videos for feature extraction. E.g. `"[./sample/v_ZNVhz7ctTq0.mp4, ./sample/v_GGSY1Qvo990.mp4]"` or just one path `"./sample/v_GGSY1Qvo990.mp4"`. |
 | `file_with_video_paths` | `null` | A path to a text file with video paths (one path per line). Hint: given a folder `./dataset` with `.mp4` files one could use: `find ./dataset -name "*mp4" > ./video_paths.txt`. |
 | `on_extraction` | `print` | If `print`, the features are printed to the terminal. If `save_numpy` or `save_pickle`, the features are saved to either `.npy` file or `.pkl`. |
+| `save_option` | `null` | If `rgb_only`, only the rgb features will be saved. |
 | `output_path` | `"./output"` | A path to a folder for storing the extracted features (if `on_extraction` is either `save_numpy` or `save_pickle`). |
 | `keep_tmp_files` | `false` | If `true`, the reencoded videos will be kept in `tmp_path`. |
 | `tmp_path` | `"./tmp"` | A path to a folder for storing temporal files (e.g. reencoded videos). |
diff --git a/models/_base/base_extractor.py b/models/_base/base_extractor.py
index a8d7972..b8dcf83 100644
--- a/models/_base/base_extractor.py
+++ b/models/_base/base_extractor.py
@@ -18,9 +18,11 @@ def __init__(self,
                  output_path: str,
                  keep_tmp_files: bool,
                  device: str,
+                 save_option=None,
                  ) -> None:
         self.feature_type = feature_type
         self.on_extraction = on_extraction
+        self.save_option = save_option
         self.tmp_path = tmp_path
         self.output_path = output_path
         self.keep_tmp_files = keep_tmp_files
@@ -76,6 +78,11 @@ def action_on_extraction(
             return
 
         for key, value in feats_dict.items():
+            if self.save_option == 'rgb_only':
+                if key != 'rgb':
+                    continue
+                else:
+                    key = None
             if self.on_extraction == 'print':
                 print(key)
                 print(value)
@@ -84,11 +91,18 @@ def action_on_extraction(
             elif self.on_extraction in ['save_numpy', 'save_pickle']:
                 # make dir if doesn't exist
                 os.makedirs(self.output_path, exist_ok=True)
-                fpath = make_path(self.output_path, video_path, key, action2ext[self.on_extraction])
-                if key != 'fps' and len(value) == 0:
-                    print(f'Warning: the value is empty for {key} @ {fpath}')
                 # save the info behind the each key
-                action2savefn[self.on_extraction](fpath, value)
+                if len(value.shape) < 3:
+                    fpath = make_path(self.output_path, video_path, key, action2ext[self.on_extraction])
+                    if key != 'fps' and len(value) == 0:
+                        print(f'Warning: the value is empty for {key} @ {fpath}')
+                    action2savefn[self.on_extraction](fpath, value)
+                else:
+                    for i in range(value.shape[0]):
+                        fpath = make_path(self.output_path, video_path, key, action2ext[self.on_extraction], i)
+                        if key != 'fps' and len(value) == 0:
+                            print(f'Warning: the value is empty for {key} @ {fpath}')
+                        action2savefn[self.on_extraction](fpath, value[i, :])
             else:
                 raise NotImplementedError(f'on_extraction: {self.on_extraction} is not implemented')
 
diff --git a/models/i3d/extract_i3d.py b/models/i3d/extract_i3d.py
index fb40a80..081cb00 100644
--- a/models/i3d/extract_i3d.py
+++ b/models/i3d/extract_i3d.py
@@ -24,6 +24,7 @@ def __init__(self, args) -> None:
         super().__init__(
             feature_type=args.feature_type,
             on_extraction=args.on_extraction,
+            save_option=args.save_option,
             tmp_path=args.tmp_path,
             output_path=args.output_path,
             keep_tmp_files=args.keep_tmp_files,
@@ -38,15 +39,24 @@ def __init__(self, args) -> None:
         self.extraction_fps = args.extraction_fps
         self.step_size = 64 if args.step_size is None else args.step_size
         self.stack_size = 64 if args.stack_size is None else args.stack_size
+        self.aug_type = args.augment
         self.resize_transforms = torchvision.transforms.Compose([
             torchvision.transforms.ToPILImage(),
             ResizeImproved(self.min_side_size),
             PILToTensor(),
             ToFloat(),
         ])
+        if self.aug_type is None:
+            aug_transform = TensorCenterCrop(self.central_crop_size)
+        elif self.aug_type == 'five_crop':
+            aug_transform = torchvision.transforms.FiveCrop(self.central_crop_size)
+            self.num_crop = 5
+        elif self.aug_type == 'ten_crop':
+            aug_transform = torchvision.transforms.TenCrop(self.central_crop_size)
+            self.num_crop = 10
         self.i3d_transforms = {
             'rgb': torchvision.transforms.Compose([
-                TensorCenterCrop(self.central_crop_size),
+                aug_transform,
                 ScaleTo1_1(),
                 PermuteAndUnsqueeze()
             ]),
@@ -82,8 +92,12 @@ def extract(self, video_path: str) -> Dict[str, np.ndarray]:
         # timestamp when the last frame in the stack begins (when the old frame of the last pair ends)
         timestamps_ms = []
         rgb_stack = []
-        feats_dict = {stream: [] for stream in self.streams}
-
+
+        if self.aug_type is not None:
+            feats_dict = {stream: [[] for _ in range(self.num_crop)] for stream in self.streams}
+        else:
+            feats_dict = {stream: [] for stream in self.streams}
+
         # sometimes when the target fps is 1 or 2, the first frame of the reencoded video is missing
         # and cap.read returns None but the rest of the frames are ok. timestep is 0.0 for the 2nd frame in
         # this case
@@ -113,7 +127,11 @@ def extract(self, video_path: str) -> Dict[str, np.ndarray]:
             if len(rgb_stack) - 1 == self.stack_size:
                 batch_feats_dict = self.run_on_a_stack(rgb_stack, stack_counter, padder)
                 for stream in self.streams:
-                    feats_dict[stream].extend(batch_feats_dict[stream].tolist())
+                    if isinstance(batch_feats_dict[stream], tuple):
+                        for i in range(len(batch_feats_dict[stream])):
+                            feats_dict[stream][i].extend(batch_feats_dict[stream][i].tolist())
+                    else:
+                        feats_dict[stream].extend(batch_feats_dict[stream].tolist())
                 # leaving the elements if step_size < stack_size so they will not be loaded again
                 # if step_size == stack_size one element is left because the flow between the last element
                 # in the prev list and the first element in the current list
@@ -161,8 +179,11 @@ def run_on_a_stack(self, rgb_stack, stack_counter, padder=None) -> Dict[str, torch.Tensor]:
                 raise NotImplementedError
             # apply transforms depending on the stream (flow or rgb)
             stream_slice = self.i3d_transforms[stream](stream_slice)
-            # extract features for a stream
-            batch_feats_dict[stream] = models[stream](stream_slice, features=True) # (B, 1024)
+            if isinstance(stream_slice, tuple):
+                # extract features for a stream
+                batch_feats_dict[stream] = tuple([models[stream](stream_crop, features=True) for stream_crop in stream_slice])
+            else:
+                batch_feats_dict[stream] = models[stream](stream_slice, features=True) # (B, 1024)
             # add features to the output dict
             self.maybe_show_pred(stream_slice, self.name2module['model'][stream], stack_counter)
 
diff --git a/models/transforms.py b/models/transforms.py
index 943dd21..b1ea297 100644
--- a/models/transforms.py
+++ b/models/transforms.py
@@ -145,13 +145,17 @@ def __call__(self, tensor: torch.FloatTensor) -> torch.FloatTensor:
 
 
 class ScaleTo1_1(object):
 
-    def __call__(self, tensor: torch.FloatTensor) -> torch.FloatTensor:
+    def __call__(self, tensor):
+        if isinstance(tensor, tuple):
+            return tuple([(2 * t / 255) - 1 for t in tensor])
         return (2 * tensor / 255) - 1
 
 
 class PermuteAndUnsqueeze(object):
 
-    def __call__(self, tensor: torch.FloatTensor) -> torch.FloatTensor:
+    def __call__(self, tensor):
+        if isinstance(tensor, tuple):
+            return tuple([t.permute(1, 0, 2, 3).unsqueeze(0) for t in tensor])
         return tensor.permute(1, 0, 2, 3).unsqueeze(0)
 
 
diff --git a/utils/utils.py b/utils/utils.py
index 123778c..291340f 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -50,9 +50,18 @@ def show_predictions_on_dataset(logits: torch.FloatTensor, dataset: Union[str, L
         print(f'{logit:8.3f} | {smax:.3f} | {cls}')
     print()
 
-def make_path(output_root, video_path, output_key, ext):
+def make_path(output_root, video_path, output_key, ext, idx=None):
     # extract file name and change the extention
-    fname = f'{Path(video_path).stem}_{output_key}{ext}'
+    if idx is not None:
+        if output_key is not None:
+            fname = f'{Path(video_path).stem}_{output_key}_{idx}{ext}'
+        else:
+            fname = f'{Path(video_path).stem}_{idx}{ext}'
+    else:
+        if output_key is not None:
+            fname = f'{Path(video_path).stem}_{output_key}{ext}'
+        else:
+            fname = f'{Path(video_path).stem}{ext}'
     # construct the paths to save the features
     return os.path.join(output_root, fname)
 
@@ -131,8 +140,7 @@ def form_list_from_user_input(
     to_shuffle: bool = True,
 ) -> list:
     '''User specifies either list of videos in the cmd or a path to a file with video paths. This function
-    transforms the user input into a list of paths. Files are expected to be formatted with a single
-    video-path in each line.
+    transforms the user input into a list of paths.
 
     Args:
         video_paths (Union[str, ListConfig, None], optional): a list of video paths. Defaults to None.
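
Note on the crop handling above: torchvision's `FiveCrop`/`TenCrop` return a tuple of crops rather than a single tensor, which is why `ScaleTo1_1`, `PermuteAndUnsqueeze`, and `run_on_a_stack` gained `isinstance(..., tuple)` branches, and why `action_on_extraction` now loops over the first array dimension and writes one file per crop. A minimal standalone sketch (assuming torchvision >= 0.8, where the crop transforms accept tensors; the shapes below are illustrative, not taken from the repo):

import torch
import torchvision

# A stack of 64 resized RGB frames, (T, C, H, W), as produced by the resize transforms.
frames = torch.rand(64, 3, 256, 256)

crops = torchvision.transforms.TenCrop(224)(frames)

# TenCrop returns a tuple of 10 tensors (4 corners + center, plus their flips),
# each of shape (64, 3, 224, 224) -- hence the tuple branches in the transforms.
assert isinstance(crops, tuple) and len(crops) == 10
assert crops[0].shape == (64, 3, 224, 224)

With `augment: 'ten_crop'` and `save_option: 'rgb_only'`, the rgb key is dropped from the filename by the `key = None` branch, so the saved files would be named `<video_stem>_0.npy` through `<video_stem>_9.npy`, one per crop.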