diff --git a/modeling/data_loader.py b/modeling/data_loader.py
index 7f1d7f6..6fbbaff 100644
--- a/modeling/data_loader.py
+++ b/modeling/data_loader.py
@@ -1,15 +1,16 @@
 import argparse
 import csv
 import json
+import logging
 import os
 from collections import defaultdict
 from pathlib import Path
-from typing import List, Union, Tuple, Dict, ClassVar, Optional
-import logging
+from typing import List, Union, Tuple, Dict, ClassVar
 
 import av
 import numpy as np
 import torch
+from PIL import Image
 from tqdm import tqdm
 
 from modeling import backbones
@@ -29,6 +30,7 @@ class AnnotatedImage:
 
     def __init__(self, filename: str, label: str, subtype_label: str, mod: bool = False):
         self.image = None
+        self.filename = filename
         self.guid, self.total_time, self.curr_time = self.split_name(filename)
         self.total_time = int(self.total_time)
         self.curr_time = int(self.curr_time)
@@ -44,7 +46,14 @@ def split_name(filename:str) -> List[str]:
-        :param filename: filename of the format **GUID_TOTAL_CURR**
+        :param filename: filename of the format **GUID_TOTAL_CURR** or **GUID_TOTAL_SOUGHT_CURR**
         :return: a tuple containing all the significant metadata
         """
-        guid, total, curr = filename.split("_")
+
+        split_string = filename.split("_")
+        if len(split_string) == 3:
+            guid, total, curr = split_string
+        elif len(split_string) == 4:
+            guid, total, sought, curr = split_string
+        else:
+            raise ValueError(f'unexpected filename format: {filename}')
         curr = curr[:-4]
         return guid, total, curr
 
@@ -155,47 +162,52 @@ def __init__(self, model_name: str):
             self.models = [FeatureExtractor(model_name)]
         else:
             raise ValueError("No valid model found")
-        print(f'using model(s): {[model.img_encoder.name for model in self.models]}')
+        logger.info(f'using model(s): {[model.img_encoder.name for model in self.models]}')
 
-    def process_video(self,
-                      vid_path: Union[os.PathLike, str],
-                      csv_path: Union[os.PathLike, str],) -> Tuple[Dict, Dict[str, np.ndarray]]:
+    def process_input(self,
+                      input_path: Union[os.PathLike, str],
+                      csv_path: Union[os.PathLike, str]):
        """
-        Extract the features for every annotated timepoint in a video.
-        :param vid_path: filename of the video
-        :param csv_path: filename of the csv containing timepoints
-        :return: A list of metadata dictionaries and associated feature matrix
+        Extract the features for every annotated timepoint in a video or in a directory of extracted still images.
+        :param input_path: path to the video file or to the directory of extracted images
+        :param csv_path: csv file containing timepoint-wise annotations
+        :return: a generator of (metadata dict, feature matrix dict) pairs, one pair per GUID
         """
+        if Path(input_path).is_dir():
+            logger.info(f'processing directory: {input_path}')
+        else:
+            logger.info(f'processing video: {input_path}')
         frame_metadata = {'frames': []}
         frame_vecs = defaultdict(list)
-        print(f'processing video: {vid_path}')
-        for i, frame in tqdm(enumerate(self.get_stills(vid_path, csv_path))):
+        for frame in tqdm(self.get_stills(input_path, csv_path)):
+            if 'guid' in frame_metadata and frame.guid != frame_metadata['guid']:
+                frame_mats = {k: np.vstack(v) for k, v in frame_vecs.items()}
+                yield frame_metadata, frame_mats
+                frame_metadata = {'frames': []}
+                frame_vecs = defaultdict(list)
             if 'guid' not in frame_metadata:
                 frame_metadata['guid'] = frame.guid
             if 'duration' not in frame_metadata:
                 frame_metadata['duration'] = frame.total_time
-
             for extractor in self.models:
                 frame_vecs[extractor.img_encoder.name].append(extractor.get_img_vector(frame.image, as_numpy=True))
-            frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time"}
-            frame_dict['vec_idx'] = i
+            frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time" and k != "filename"}
+            frame_dict['vec_idx'] = len(frame_metadata['frames'])
             frame_metadata["frames"].append(frame_dict)
-        frame_mats = {k: np.vstack(v) for k, v in frame_vecs.items()}
-        return frame_metadata, frame_mats
+        yield frame_metadata, frame_mats
 
-    def get_stills(self, vid_path: Union[os.PathLike, str],
+    def get_stills(self, media_path: Union[os.PathLike, str],
                    csv_path: Union[os.PathLike, str]) -> List[AnnotatedImage]:
         """
-        Extract stills at given timepoints from a video file
-        :param vid_path: the filename of the video
-        :param timepoints: a list of the video's annotated timepoints
-        :return: a list of Frame objects
+        Extract stills at given timepoints from a video file or from a directory of extracted still images
+        :param media_path: path to the video file or to the directory of extracted images
+        :param csv_path: path to the csv file containing timepoint-wise annotations
+        :return: a generator of AnnotatedImage objects, each carrying the raw image and the metadata from the annotations
         """
-
         with open(csv_path, encoding='utf8') as f:
             reader = csv.reader(f)
             next(reader)
@@ -204,51 +215,65 @@ def get_stills(self, vid_path: Union[os.PathLike, str],
                                              subtype_label=row[3],
                                              mod=row[4].lower() == 'true') for row in reader if row[1] == 'true']
             # CSV rows with mod=True should be discarded (taken as "unseen")
-            # maybe we can throw away the video with the least (88) frames annotation from B2 to make 20/20 split on dense vs sparse annotation
-
-            # this part is doing the same thing as the get_stills function in getstills.py
-            # (copied from https://github.com/WGBH-MLA/keystrokelabeler/blob/df4d2bc936fa3a73cdf3004803a0c35c290caf93/getstills.py#L36 )
-            container = av.open(vid_path)
-            video_stream = next((s for s in container.streams if s.type == 'video'), None)
-            if video_stream is None:
-                raise Exception("No video stream found in {}".format(vid_path))
-            fps = video_stream.average_rate.numerator / video_stream.average_rate.denominator
-            cur_target_frame = 0
-            fcount = 0
-            for frame in container.decode(video=0):
-                if cur_target_frame == len(frame_list):
-                    break
-                ftime = int(frame.time * 1000)
-                if ftime == frame_list[cur_target_frame].curr_time:
-                    frame_list[cur_target_frame].image = frame.to_image()
-                    yield frame_list[cur_target_frame]
-                    cur_target_frame += 1
-                fcount += 1
+            # maybe we can throw away the video with the fewest annotated frames (88) from B2
+            # to make a 20/20 split on dense vs sparse annotation
+
+            if Path(media_path).is_dir():
+                # process as a directory of pre-extracted still images
+                for frame in frame_list:
+                    image_path = Path(media_path) / frame.filename
+                    if image_path.exists():
+                        # copy so the underlying file handle can be closed; see https://stackoverflow.com/a/30376272
+                        i = Image.open(image_path)
+                        frame.image = i.copy()
+                        yield frame
+                    else:
+                        logger.warning(f"Image file not found for annotation: {frame.filename}")
+
+            else:
+                # this part is doing the same thing as the get_stills function in getstills.py
+                # (copied from https://github.com/WGBH-MLA/keystrokelabeler/blob/df4d2bc936fa3a73cdf3004803a0c35c290caf93/getstills.py#L36 )
+                container = av.open(media_path)
+                video_stream = next((s for s in container.streams if s.type == 'video'), None)
+                if video_stream is None:
+                    raise Exception("No video stream found in {}".format(media_path))
+                fps = video_stream.average_rate.numerator / video_stream.average_rate.denominator
+                cur_target_frame = 0
+                fcount = 0
+                for frame in container.decode(video=0):
+                    if cur_target_frame == len(frame_list):
+                        break
+                    ftime = int(fcount / fps * 1000)
+                    if ftime == frame_list[cur_target_frame].curr_time:
+                        frame_list[cur_target_frame].image = frame.to_image()
+                        yield frame_list[cur_target_frame]
+                        cur_target_frame += 1
+                    fcount += 1
 
 
 def main(args):
-    in_file = args.input_file
-    metadata_file = args.csv_file
-    featurizer = TrainingDataPreprocessor(args.model_name)
-    print('extractor ready')
-    feat_metadata, feat_mats = featurizer.process_video(in_file, metadata_file)
-    print('extraction complete')
-
-    if not os.path.exists(args.outdir):
-        os.makedirs(args.outdir, exist_ok=True)
-    with open(f"{args.outdir}/{feat_metadata['guid']}.json", 'w', encoding='utf8') as f:
-        json.dump(feat_metadata, f)
-    for name, vectors in feat_mats.items():
-        np.save(f"{args.outdir}/{feat_metadata['guid']}.{name}", vectors)
+    in_file = args.input_data
+    metadata_file = args.annotation_csv
+    featurizer = TrainingDataPreprocessor(args.model)
+    logger.info('extractor ready')
+
+    Path(args.outdir).mkdir(parents=True, exist_ok=True)
+
+    for feat_metadata, feat_mats in featurizer.process_input(in_file, metadata_file):
+        logger.info(f'{feat_metadata["guid"]} extraction complete')
+        with open(f"{args.outdir}/{feat_metadata['guid']}.json", 'w', encoding='utf8') as f:
+            json.dump(feat_metadata, f)
+        for name, vectors in feat_mats.items():
+            np.save(f"{args.outdir}/{feat_metadata['guid']}.{name}", vectors)
+    logger.info('all extraction complete')
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="CLI for preprocessing a video file and its associated manual SWT "
-                                                 "annotations to pre-generate (CNN) image feature vectors with manual"
+                                                 "annotations to pre-generate (CNN) image feature vectors with manual "
                                                  "labels attached for later training.")
-    parser.add_argument("-i", "--input-video",
-                        help="filepath for the video to be processed.",
-                        required=True)
+    parser.add_argument("-i", "--input-data",
+                        help="filepath for the video file or a directory of extracted images to be processed.",
+                        required=True)
     parser.add_argument("-c", "--annotation-csv",
                         help="filepath for the csv containing timepoints + labels.",
                         required=True)
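
Usage note (illustrative, not part of the patch): process_input is now a generator that yields one (metadata, feature-matrices) pair per GUID instead of returning a single pair, so callers iterate over it the way the rewritten main does; yielding per GUID also keeps memory bounded when one annotation CSV covers stills from several videos. A minimal consumer sketch, where the model name "convnext_lg" and the file paths are hypothetical placeholders:

    from pathlib import Path
    import json

    import numpy as np

    from modeling.data_loader import TrainingDataPreprocessor

    # "convnext_lg" and the paths below are placeholders for illustration only
    featurizer = TrainingDataPreprocessor("convnext_lg")
    outdir = Path("features_out")
    outdir.mkdir(parents=True, exist_ok=True)

    # one (metadata, matrices) pair is yielded per GUID, so a directory mixing
    # stills from several videos produces several JSON/.npy output sets
    for feat_metadata, feat_mats in featurizer.process_input("stills_dir", "annotations.csv"):
        guid = feat_metadata["guid"]
        with open(outdir / f"{guid}.json", "w", encoding="utf8") as f:
            json.dump(feat_metadata, f)
        for name, vectors in feat_mats.items():
            np.save(outdir / f"{guid}.{name}", vectors)  # np.save appends the .npy suffix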