Enable image-based data preprocessing in data_loader.py #115

Merged
merged 8 commits on Sep 2, 2024
146 changes: 87 additions & 59 deletions modeling/data_loader.py
@@ -1,15 +1,16 @@
import argparse
import csv
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Tuple, Dict, ClassVar, Optional
import logging
from typing import List, Union, Tuple, Dict, ClassVar

import av
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from modeling import backbones
@@ -29,6 +30,7 @@ class AnnotatedImage:

def __init__(self, filename: str, label: str, subtype_label: str, mod: bool = False):
self.image = None
self.filename = filename
self.guid, self.total_time, self.curr_time = self.split_name(filename)
self.total_time = int(self.total_time)
self.curr_time = int(self.curr_time)
@@ -44,7 +46,12 @@ def split_name(filename:str) -> List[str]:
:param filename: filename of the format **GUID_TOTAL_CURR**, or **GUID_TOTAL_SOUGHT_CURR.ext** for extracted image files
:return: a (guid, total, curr) tuple of the significant metadata
"""
guid, total, curr = filename.split("_")

split_string = filename.split("_")
if len(split_string) == 3:
guid, total, curr = split_string
elif len(split_string) == 4:
guid, total, sought, curr = split_string
curr = curr[:-4]
return guid, total, curr
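
For illustration, a standalone sketch of the two filename shapes the rewritten split_name now accepts: the original three-field video form and the four-field form used for extracted image files, where a four-character extension such as ".png" is stripped (the GUIDs and times below are made up):

def parse_name(filename: str):
    parts = filename.split("_")
    if len(parts) == 3:      # video annotation: GUID_TOTAL_CURR
        guid, total, curr = parts
    elif len(parts) == 4:    # image file: GUID_TOTAL_SOUGHT_CURR.ext
        guid, total, sought, curr = parts
        curr = curr[:-4]     # drop the 4-character extension, e.g. ".png"
    return guid, total, curr

print(parse_name("cpb-aacip-123_2700000_45000"))            # ('cpb-aacip-123', '2700000', '45000')
print(parse_name("cpb-aacip-123_2700000_44000_45000.png"))  # ('cpb-aacip-123', '2700000', '45000')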

@@ -155,47 +162,51 @@ def __init__(self, model_name: str):
self.models = [FeatureExtractor(model_name)]
else:
raise ValueError("No valid model found")
print(f'using model(s): {[model.img_encoder.name for model in self.models]}')
logger.info(f'using model(s): {[model.img_encoder.name for model in self.models]}')

def process_video(self,
vid_path: Union[os.PathLike, str],
csv_path: Union[os.PathLike, str],) -> Tuple[Dict, Dict[str, np.ndarray]]:
def process_input(self,
input_path: Union[os.PathLike, str],
csv_path: Union[os.PathLike, str]):
"""
Extract features for every annotated timepoint in a video or a directory of extracted images.

:param vid_path: filename of the video
:param csv_path: filename of the csv containing timepoints
:return: A list of metadata dictionaries and associated feature matrix
:param input_path: filename of the input
:param csv_path: csv file containing timepoint-wise annotations
"""
if Path(input_path).is_dir():
logger.info(f'processing directory: {input_path}')
else:
logger.info(f'processing video: {input_path}')

frame_metadata = {'frames': []}
frame_vecs = defaultdict(list)
print(f'processing video: {vid_path}')
for i, frame in tqdm(enumerate(self.get_stills(vid_path, csv_path))):
for frame in tqdm(self.get_stills(input_path, csv_path)):
if 'guid' in frame_metadata and frame.guid != frame_metadata['guid']:
frame_mats = {k: np.vstack(v) for k, v in frame_vecs.items()}
yield frame_metadata, frame_mats
frame_metadata = {'frames': []}
frame_vecs = defaultdict(list)
if 'guid' not in frame_metadata:
frame_metadata['guid'] = frame.guid
if 'duration' not in frame_metadata:
frame_metadata['duration'] = frame.total_time

for extractor in self.models:
frame_vecs[extractor.img_encoder.name].append(extractor.get_img_vector(frame.image, as_numpy=True))
frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time"}
frame_dict['vec_idx'] = i
frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time" and k != "filename"}
frame_dict['vec_idx'] = len(frame_metadata['frames'])
frame_metadata["frames"].append(frame_dict)

frame_mats = {k: np.vstack(v) for k, v in frame_vecs.items()}
return frame_metadata, frame_mats
yield frame_metadata, frame_mats
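
Because process_input is now a generator, yielding one (metadata, matrices) pair each time the frame stream moves on to a new GUID, callers iterate over it rather than unpacking a single return value. A minimal consumption sketch (the model name and paths are placeholders):

preprocessor = TrainingDataPreprocessor("convnext_lg")  # hypothetical model name
for metadata, matrices in preprocessor.process_input("stills_dir/", "annotations.csv"):
    for model_name, mat in matrices.items():
        # mat stacks one feature row per annotated frame of this GUID;
        # each frame's 'vec_idx' in metadata['frames'] is its row index
        print(metadata['guid'], model_name, mat.shape)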

def get_stills(self, vid_path: Union[os.PathLike, str],
def get_stills(self, media_path: Union[os.PathLike, str],
csv_path: Union[os.PathLike, str]) -> List[AnnotatedImage]:
"""
Extract stills at annotated timepoints from a video file or a directory of pre-extracted images

:param vid_path: the filename of the video
:param timepoints: a list of the video's annotated timepoints
:return: a list of Frame objects
:param media_path: path to the video file, or to a directory of extracted still images
:param csv_path: path to the csv file containing timepoint-wise annotations
:return: a generator of AnnotatedImage objects, each carrying the raw image and the metadata from its annotation
"""

with open(csv_path, encoding='utf8') as f:
reader = csv.reader(f)
next(reader)
@@ -204,51 +215,68 @@ def get_stills(self, vid_path: Union[os.PathLike, str],
subtype_label=row[3],
mod=row[4].lower() == 'true') for row in reader if row[1] == 'true']
# CSV rows with mod=True should be discarded (taken as "unseen")
# maybe we can throw away the video with the least (88) frames annotation from B2 to make 20/20 split on dense vs sparse annotation
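
For reference, a hedged sketch of the annotation CSV shape implied by the reader above: a header row that next(reader) skips, then one row per still, with column 1 as the seen filter, column 3 as the subtype label, and column 4 as the mod flag (the filename and primary-label columns sit on argument lines outside this hunk, so the header names below are guesses):

import csv
import io

# hypothetical rows matching the column accesses in the hunk above
sample = io.StringIO(
    "filename,seen,label,subtype,mod\n"
    "cpb-aacip-123_2700000_44000_45000.png,true,S,open,false\n"
    "cpb-aacip-123_2700000_46000_47000.png,false,B,,true\n"
)
reader = csv.reader(sample)
next(reader)  # skip header, as the loader does
kept = [row for row in reader if row[1] == 'true']
print(len(kept))  # 1: only rows flagged 'true' in column 1 survive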

# this part is doing the same thing as the get_stills function in getstills.py
# (copied from https://github.com/WGBH-MLA/keystrokelabeler/blob/df4d2bc936fa3a73cdf3004803a0c35c290caf93/getstills.py#L36 )

container = av.open(vid_path)
video_stream = next((s for s in container.streams if s.type == 'video'), None)
if video_stream is None:
raise Exception("No video stream found in {}".format(vid_path))
fps = video_stream.average_rate.numerator / video_stream.average_rate.denominator
cur_target_frame = 0
fcount = 0
for frame in container.decode(video=0):
if cur_target_frame == len(frame_list):
break
ftime = int(frame.time * 1000)
if ftime == frame_list[cur_target_frame].curr_time:
frame_list[cur_target_frame].image = frame.to_image()
yield frame_list[cur_target_frame]
cur_target_frame += 1
fcount += 1
# maybe we can throw away the video with the least (88) frames annotation from B2
# to make 20/20 split on dense vs sparse annotation

if Path(media_path).is_dir():
# Process as directory of images
for frame in frame_list:
image_path = Path(media_path) / frame.filename
if image_path.exists():
# see https://stackoverflow.com/a/30376272
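# (Image.open is lazy; .copy() forces the pixel read so the file handle can be released)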
i = Image.open(image_path)
frame.image = i.copy()
yield frame
else:
logger.warning(f"Image file not found for annotation: {frame.filename}")

else:
# this part is doing the same thing as the get_stills function in getstills.py
# (copied from https://github.com/WGBH-MLA/keystrokelabeler/blob/df4d2bc936fa3a73cdf3004803a0c35c290caf93/getstills.py#L36 )
container = av.open(media_path)
video_stream = next((s for s in container.streams if s.type == 'video'), None)
if video_stream is None:
raise Exception("No video stream found in {}".format(media_path))
fps = video_stream.average_rate.numerator / video_stream.average_rate.denominator
cur_target_frame = 0
fcount = 0
for frame in container.decode(video=0):
if cur_target_frame == len(frame_list):
break
ftime = int(fcount/fps * 1000)
if ftime == frame_list[cur_target_frame].curr_time:
frame_list[cur_target_frame].image = frame.to_image()
yield frame_list[cur_target_frame]
cur_target_frame += 1
fcount += 1
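
One timing detail worth noting in the rewritten loop: the earlier version took frame.time from the decoder, while this version derives milliseconds from the running frame count and the stream's average rate, so a hit requires the annotated curr_time to land exactly on a computed frame boundary. A standalone sketch of that arithmetic (29.97 fps is assumed purely for illustration):

fps = 30000 / 1001  # NTSC 29.97, i.e. average_rate.numerator / average_rate.denominator
for fcount in range(5):
    ftime = int(fcount / fps * 1000)  # millisecond timestamp of the fcount-th frame
    print(fcount, ftime)              # 0 0, 1 33, 2 66, 3 100, 4 133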


def main(args):
in_file = args.input_file
metadata_file = args.csv_file
featurizer = TrainingDataPreprocessor(args.model_name)
print('extractor ready')
feat_metadata, feat_mats = featurizer.process_video(in_file, metadata_file)
print('extraction complete')

if not os.path.exists(args.outdir):
os.makedirs(args.outdir, exist_ok=True)
with open(f"{args.outdir}/{feat_metadata['guid']}.json", 'w', encoding='utf8') as f:
json.dump(feat_metadata, f)
for name, vectors in feat_mats.items():
np.save(f"{args.outdir}/{feat_metadata['guid']}.{name}", vectors)
in_file = args.input_data
metadata_file = args.annotation_csv
featurizer = TrainingDataPreprocessor(args.model)
logger.info('extractor ready')

Path(args.outdir).mkdir(parents=True, exist_ok=True)

for feat_metadata, feat_mats in featurizer.process_input(in_file, metadata_file):
logger.info(f'{feat_metadata["guid"]} extraction complete')
with open(f"{args.outdir}/{feat_metadata['guid']}.json", 'w', encoding='utf8') as f:
json.dump(feat_metadata, f)
for name, vectors in feat_mats.items():
np.save(f"{args.outdir}/{feat_metadata['guid']}.{name}", vectors)
logger.info('all extraction complete')
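
Each GUID thus yields one <guid>.json metadata file plus one <guid>.<model>.npy matrix per extractor (np.save appends the .npy suffix itself). A sketch of reading a pair back, with placeholder directory, GUID, and model name:

import json
import numpy as np

outdir, guid, model_name = "features", "cpb-aacip-123", "convnext_lg"  # placeholders
with open(f"{outdir}/{guid}.json", encoding="utf8") as f:
    meta = json.load(f)
vecs = np.load(f"{outdir}/{guid}.{model_name}.npy")
for frame in meta["frames"][:3]:
    # 'vec_idx' maps each annotated frame back to its feature row
    print(frame["curr_time"], vecs[frame["vec_idx"]].shape)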


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CLI for preprocessing a video file (or a directory of extracted "
"images) and its associated manual SWT annotations to pre-generate "
"(CNN) image feature vectors with manual labels attached for later training.")
parser.add_argument("-i", "--input-video",
help="filepath for the video to be processed.",
required=True)
parser.add_argument("-i", "--input-data",
help="filepath for the video file or a directory of extracted images to be processed.",
required=True)
parser.add_argument("-c", "--annotation-csv",
help="filepath for the csv containing timepoints + labels.",
required=True)