diff --git a/modeling/backbones.py b/modeling/backbones.py
index 3337770..6180fc6 100644
--- a/modeling/backbones.py
+++ b/modeling/backbones.py
@@ -36,6 +36,7 @@
 # Base Class
 class ExtractorModel:
     name: str
+    dim: int
     model: torch.nn.Module
     preprocess: Callable
 
@@ -47,6 +48,7 @@ class ExtractorModel:
 # ConvNext Models
 class ConvnextBaseExtractor(ExtractorModel):
     name = "convnext_base"
+    dim = 1024
 
     def __init__(self):
         self.model = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
@@ -56,6 +58,7 @@ def __init__(self):
 
 class ConvnextTinyExtractor(ExtractorModel):
     name = "convnext_tiny"
+    dim = 768
 
     def __init__(self):
         self.model = convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
@@ -65,6 +68,7 @@ def __init__(self):
 
 class ConvnextSmallExtractor(ExtractorModel):
     name = "convnext_small"
+    dim = 768
 
     def __init__(self):
         self.model = convnext_small(weights=ConvNeXt_Small_Weights.IMAGENET1K_V1)
@@ -74,6 +78,7 @@ def __init__(self):
 
 class ConvnextLargeExtractor(ExtractorModel):
     name = "convnext_lg"
+    dim = 1536
 
     def __init__(self):
         self.model = convnext_large(weights=ConvNeXt_Large_Weights.IMAGENET1K_V1)
@@ -85,6 +90,7 @@ def __init__(self):
 # DenseNet Models
 class Densenet121Extractor(ExtractorModel):
     name = "densenet121"
+    dim = 1024
 
     def __init__(self):
         self.model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
@@ -92,8 +98,9 @@ def __init__(self):
         self.preprocess = DenseNet121_Weights.IMAGENET1K_V1.transforms()
 
 
-class Densenet161Extractor():
+class Densenet161Extractor(ExtractorModel):
     name = "densenet161"
+    dim = 2208
 
     def __init__(self):
         self.model = densenet161(weights=DenseNet161_Weights.IMAGENET1K_V1)
@@ -101,8 +108,9 @@ def __init__(self):
         self.preprocess = DenseNet161_Weights.IMAGENET1K_V1.transforms()
 
 
-class Densenet169Extractor():
+class Densenet169Extractor(ExtractorModel):
     name = "densenet169"
+    dim = 1664
 
     def __init__(self):
         self.model = densenet169(weights=DenseNet169_Weights.IMAGENET1K_V1)
@@ -110,8 +118,9 @@ def __init__(self):
         self.preprocess = DenseNet169_Weights.IMAGENET1K_V1.transforms()
 
 
-class Densenet201Extractor():
+class Densenet201Extractor(ExtractorModel):
     name = "densenet201"
+    dim = 1920
 
     def __init__(self):
         self.model = densenet201(weights=DenseNet201_Weights.IMAGENET1K_V1)
@@ -123,6 +132,7 @@ def __init__(self):
 # EfficientNet Models
 class EfficientnetSmallExtractor(ExtractorModel):
     name = "efficientnet_small"
+    dim = 1280
 
     def __init__(self):
         self.model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
@@ -132,6 +142,7 @@ def __init__(self):
 
 class EfficientnetMediumExtractor(ExtractorModel):
     name = "efficientnet_med"
+    dim = 1280
 
     def __init__(self):
         self.model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
@@ -141,6 +152,7 @@ def __init__(self):
 
 class EfficientnetLargeExtractor(ExtractorModel):
     name = "efficientnet_large"
+    dim = 1280
 
     def __init__(self):
         self.model = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.IMAGENET1K_V1)
@@ -164,6 +176,7 @@ def __init__(self):
 
 class Resnet18Extractor(ExtractorModel):
     name = "resnet18"
+    dim = 512
 
     def __init__(self):
         self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
@@ -173,6 +186,7 @@ def __init__(self):
 
 class Resnet50Extractor(ExtractorModel):
     name = "resnet50"
+    dim = 2048
 
     def __init__(self):
         self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
@@ -182,6 +196,7 @@ def __init__(self):
 
 class Resnet101Extractor(ExtractorModel):
     name = "resnet101"
+    dim = 2048
 
     def __init__(self):
         self.model = resnet101(weights=ResNet101_Weights.IMAGENET1K_V1)
@@ -191,6 +206,7 @@ def __init__(self):
 
 class Resnet152Extractor(ExtractorModel):
     name = "resnet152"
+    dim = 2048
 
     def __init__(self):
         self.model = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
@@ -202,6 +218,7 @@ def __init__(self):
 # VGG Models
 class Vgg16Extractor(ExtractorModel):
     name = "vgg16"
+    dim = 4096
 
     def __init__(self):
         self.model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
@@ -211,6 +228,7 @@ def __init__(self):
 
 class BN_Vgg16Extractor(ExtractorModel):
     name = "bn_vgg16"
+    dim = 4096
 
     def __init__(self):
         self.model = vgg16_bn(weights=VGG16_BN_Weights.IMAGENET1K_V1)
@@ -220,6 +238,7 @@ def __init__(self):
 
 class Vgg19Extractor(ExtractorModel):
     name = "vgg19"
+    dim = 4096
 
     def __init__(self):
         self.model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
@@ -229,6 +248,7 @@ def __init__(self):
 
 class BN_VGG19Extractor(ExtractorModel):
     name = "bn_vgg19"
+    dim = 4096
 
     def __init__(self):
         self.model = vgg19_bn(weights=VGG19_BN_Weights.IMAGENET1K_V1)
@@ -242,6 +262,10 @@ def __init__(self):
     model.name: model for model
     in sys.modules[__name__].ExtractorModel.__subclasses__() if model.name != 'inceptionv3'}
 
+model_dim_map = {
+    model.name: model.dim for model
+    in sys.modules[__name__].ExtractorModel.__subclasses__() if model.name != 'inceptionv3'}
+
 if __name__ == "__main__":
     import numpy as np
     dummy_guid = 'cpb-aacip-fe9efa663c6'
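# --- illustrative sketch (not part of the patch): the new `dim` attribute makes the
# feature width discoverable from the backbone registry itself, which is what lets the
# hand-maintained ori_feat_dims table in train.py be deleted further down.  Assumes the
# package is importable as `modeling.backbones`.
from modeling import backbones

# model_map values are the ExtractorModel subclasses themselves, so the width is
# available without instantiating (and downloading weights for) a backbone:
print(backbones.model_map["convnext_lg"].dim)   # 1536
print(backbones.model_dim_map["convnext_lg"])   # 1536, same value via the plain name->dim dict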
diff --git a/modeling/classify.py b/modeling/classify.py
deleted file mode 100644
index d50ed46..0000000
--- a/modeling/classify.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import os
-import sys
-import json
-from operator import itemgetter
-
-import torch
-import numpy as np
-import cv2
-from PIL import Image
-
-from mmif.utils import video_document_helper as vdh
-
-import train
-import backbones
-
-
-# For now just using the model from the first fold of a test (and also assuming
-# that it has the labels as listed in config/default.yml).
-BEST_MODEL = 'results-test/20231026-135541.kfold_000.pt'
-BEST_MODEL = 'results-test/20231026-164841.kfold_000.pt'
-
-# Assuming for now that the model's feature extractor uses the VGG16 model.
-MODEL_TYPE = 'vgg16'
-
-# Mappings from prediction indices to label name. Another temporary assumption,
-# it should be read from a config file or an input parameter.
-LABEL_MAPPINGS = {0: 'slate', 1: 'chyron', 2: 'credit', 3: 'other'}
-
-# Milliseconds between frames.
-STEP_SIZE = 1000
-
-# Minimum average score for a timeframe. We require at least one frame score
-# higher than 1.
-MINIMUM_SCORE = 1.01
-
-# For debugging, set to True if you want to save the frames that were extracted.
-SAFE_FRAMES = False
-
-# Set to True if you want the script to be more verbose.
-DRIBBLE = False
-
-# Defining the bins for the labels.
-SCORE_MAPPING = ((0.01, 0), (0.25, 1), (0.50, 2), (0.75, 3), (1.01, 4))
-
-
-# Getting the non-other labels.
-LABELS = {label for label in sorted(LABEL_MAPPINGS.values()) if label != 'other'}
-
-# Loading the model and featurizer.
-model = train.get_net(4096, 4, 3, 0.2)
-model.load_state_dict(torch.load(BEST_MODEL))
-featurizer = backbones.model_map[MODEL_TYPE]()
-
-
-def process_video(mp4_file: str, step: int = 1000):
-    """Loops over the frames in a video and for each frame extract the features
-    and apply the model. Returns a list of predictions, where each prediction is
-    an instance of numpy.ndarray."""
-    print(f'Processing {mp4_file}...')
-    all_predictions = []
-    for n, image in get_frames(mp4_file, step):
-        img = Image.fromarray(image[:,:,::-1])
-        features = extract_features(img, featurizer)
-        prediction = model(features)
-        prediction = Prediction(n, prediction)
-        if DRIBBLE:
-            print(f'{n:07d}', prediction)
-        all_predictions.append(prediction)
-        if SAFE_FRAMES:
-            cv2.imwrite(f"frames/frame-{n:06d}.jpg", image)
-    return(all_predictions)
-
-
-def get_frames(mp4_file: str, step: int = 1000):
-    """Generator to get frames from an mp4 file. The step parameter defines the number
-    of milliseconds between the frames."""
-    vidcap = cv2.VideoCapture(mp4_file)
-    for n in range(0, sys.maxsize, step):
-        vidcap.set(cv2.CAP_PROP_POS_MSEC, n)
-        success, image = vidcap.read()
-        if not success:
-            break
-        yield n, image
-
-
-def extract_features(frame_vec: np.ndarray, model: torch.nn.Sequential) -> torch.Tensor:
-    """Extract the features of a single frame. Based on, but not identical to, the
-    process_frame() method of the FeatureExtractor class in data_ingestion.py."""
-    frame_vec = model.preprocess(frame_vec)
-    frame_vec = frame_vec.unsqueeze(0)
-    if torch.cuda.is_available():
-        if DRIBBLE:
-            print('CUDA is available')
-        frame_vec = frame_vec.to('cuda')
-        model.model.to('cuda')
-    with torch.no_grad():
-        feature_vec = model.model(frame_vec)
-    return feature_vec.cpu()
-
-
-def softmax(x):
-    return(np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum())
-
-
-def save_predictions(predictions: list, filename: str):
-    json_obj = []
-    for prediction in predictions:
-        json_obj.append(prediction.as_json())
-    with open(filename, 'w') as fh:
-        json.dump(json_obj, fh)
-    print(f'Saved output to {filename}')
-
-
-def load_predictions(filename: str) -> list:
-    with open(filename) as fh:
-        predictions = json.load(fh)
-    return predictions
-
-
-def enrich_predictions(predictions: list):
-    """For each prediction, add a nominal score for each label. The scores go from
-    0 through 4. For example if the raw probability score for the slate is in the
-    0.5-0.75 range than ('slate', 3) will be added."""
-    for prediction in predictions:
-        binned_scores = compute_labels(prediction[1])
-        prediction[1].append(binned_scores)
-
-
-def print_predictions(predictions):
-    print('\n        slate  chyron creds  other')
-    for prediction in predictions:
-        milliseconds = prediction[0]
-        p1, p2, p3, p4 = prediction[1][:4]
-        binned_scores = prediction[1][4]
-        labels = ' '.join([f'{label}-{score}' for label, score in binned_scores])
-        print(f'{milliseconds:6}  {p1:.4f} {p2:.4f} {p3:.4f} {p4:.4f}  {labels}')
-    print(f'\nTOTAL PREDICTIONS: {len(predictions)}\n')
-
-
-def compute_labels(scores: list):
-    return (
-        ('slate', scale(scores[0])),
-        ('chyron', scale(scores[1])),
-        ('credit', scale(scores[2])))
-
-
-def scale(score):
-    """Put the score on a scale from 0 through 4, where 0 means the score is less
-    than 0.01 and 1 though 4 are quartiles for score bins 0.01-0.25, 0.25-0.50,
-    0.50-0.75 and 0.75-1.00."""
-    for score_in, score_out in SCORE_MAPPING:
-        if score < score_in:
-            return score_out
-
-
-def collect_timeframes(predictions: list) -> dict:
-    """Find sequences of frames for all labels where the score is not 0."""
-    timeframes = { label: [] for label in LABELS}
-    open_frames = { label: [] for label in LABELS}
-    for prediction in predictions:
-        timepoint = prediction[0]
-        bins = prediction[1][4]
-        for label, score in bins:
-            if score == 0:
-                if open_frames[label]:
-                    timeframes[label].append(open_frames[label])
-                open_frames[label] = []
-            elif score >= 1:
-                open_frames[label].append((timepoint, score, label))
-    for label, score in bins:
-        if open_frames[label]:
-            timeframes[label].append(open_frames[label])
-    return timeframes
-
-def compress_timeframes(timeframes: dict):
-    """Compresses all timeframes from [(t_1, score_1), ... (t_n, score_n)] into the
-    shorter representation (t_1, t_n, average_score)."""
-    for label in LABELS:
-        frames = timeframes[label]
-        for i in range(len(frames)):
-            start = frames[i][0][0]
-            end = frames[i][-1][0]
-            score = sum([e[1] for e in frames[i]]) / len(frames[i])
-            frames[i] = (start, end, score)
-def filter_timeframes(timeframes: dict):
-    """Filter out all timeframes with an average score below the threshold defined
-    in MINIMUM_SCORE."""
-    for label in LABELS:
-        timeframes[label] = [tf for tf in timeframes[label] if tf[2] > MINIMUM_SCORE]
-def remove_overlapping_timeframes(timeframes: dict) -> list:
-    all_frames = []
-    for label in timeframes:
-        for frame in timeframes[label]:
-            all_frames.append(frame + (label,))
-    all_frames = list(sorted(all_frames, key=itemgetter(2), reverse=True))
-    outlawed_timepoints = set()
-    final_frames = []
-    for frame in all_frames:
-        if is_included(frame, outlawed_timepoints):
-            continue
-        final_frames.append(frame)
-        start, end, _, _ = frame
-        for p in range(start, end + STEP_SIZE, STEP_SIZE):
-            outlawed_timepoints.add(p)
-    return all_frames
-def is_included(frame, outlawed_timepoints):
-    start, end, _, _ = frame
-    for i in range(start, end + STEP_SIZE, STEP_SIZE):
-        if i in outlawed_timepoints:
-            return True
-    return False
-
-
-def experiment():
-    """This is an older experiment. It was the first one that I could get to work
-    and it was fully based on the code in data_ingestion.py"""
-    outdir = 'vectorized2'
-    featurizer = FeatureExtractor('vgg16')
-    in_file = 'data/cpb-aacip-690722078b2-shrunk.mp4'
-    #in_file = 'data/cpb-aacip-690722078b2.mp4'
-    metadata_file = 'data/cpb-aacip-690722078b2.csv'
-    feat_metadata, feat_mats = featurizer.process_video(in_file, metadata_file)
-    print('extraction complete')
-    if not os.path.exists(outdir):
-        os.makedirs(outdir, exist_ok=True)
-    for name, vectors in feat_mats.items():
-        with open(f"{outdir}/{feat_metadata['guid']}.json", 'w', encoding='utf8') as f:
-            json.dump(feat_metadata, f)
-        np.save(f"{outdir}/{feat_metadata['guid']}.{name}", vectors)
-        outputs = model(torch.from_numpy(vectors))
-        print(outputs)
-
-
-class Prediction:
-
-    """Class to store a prediction from the model. It is meant to simplify the rest
-    of the code a bit and manage some of the intricacies of the data structures that
-    are involved. One thing it does is to run softmax over the scores in the tensor.
-
-    timepoint  -  the location of the frame, in milliseconds
-    tensor     -  the tensor that results from running the model on the features
-    data       -  the tensor simplified into a simple list with scores for each label
-
-    """
-
-    def __init__(self, timepoint: int, prediction: torch.Tensor):
-        self.timepoint = timepoint
-        self.tensor = prediction
-        self.data = softmax(self.tensor.detach().numpy())[0].tolist()
-
-    def __str__(self):
-        return f''
-
-    def as_json(self):
-        return [self.timepoint, self.data]
-
-
-if __name__ == '__main__':
-
-    create_frame_predictions = False
-    create_timeframes = False
-
-    if create_frame_predictions:
-        predictions = process_video('data/cpb-aacip-690722078b2-shrunk.mp4', step=STEP_SIZE)
-        save_predictions(predictions, 'predictions.json')
-
-    if create_timeframes:
-        predictions = load_predictions('predictions.json')
-        enrich_predictions(predictions)
-        #print_predictions(predictions)
-        timeframes = collect_timeframes(predictions)
-        compress_timeframes(timeframes)
-        filter_timeframes(timeframes)
-        #for label in timeframes:
-        #    print(label, timeframes[label])
-        timeframes = remove_overlapping_timeframes(timeframes)
-        print(timeframes)
-
-
diff --git a/modeling/data_ingestion.py b/modeling/data_ingestion.py
index 247f9ba..3a073bc 100644
--- a/modeling/data_ingestion.py
+++ b/modeling/data_ingestion.py
@@ -22,7 +22,8 @@
 import json
 import os
 from collections import defaultdict
-from typing import List, Union, Tuple, Dict
+from pathlib import Path
+from typing import List, Union, Tuple, Dict, ClassVar
 
 import av
 import numpy as np
@@ -53,22 +54,101 @@ def split_name(filename:str) -> List[str]:
     return guid, total, curr
 
 
-class FeatureExtractor:
-    """Convert an annotated video set into a machine-readable format
-    uses as a backbone to featurize the annotated still images
-    into 4096-dim vectors.
+class FeatureExtractor(object):
+
+    dense_encoder: backbones.ExtractorModel
+    pos_encoder: str
+    max_input_length: int
+    pos_dim: int
+    sinusoidal_embeddings: ClassVar[Dict[Tuple[int, int], torch.Tensor]] = {}
+
+    def __init__(self, dense_encoder_name: str,
+                 positional_encoder: str = None,
+                 positional_embedding_dim: int = 512,
+                 max_input_length: int = 5640000,  # 94 min = the longest video in the first round of annotation
+                 positional_unit: int = 60000):
+        """
+        Initializes the FeatureExtractor object.
+
+        @param: dense_encoder_name = name of the backbone model to use for dense (i.e., CNN) vector extraction
+        @param: positional_encoder = type of positional encoder to use, one of 'fractional', 'sinusoidal-add', 'sinusoidal-concat'; when not given, no positional encoding is used
+        @param: positional_embedding_dim = dimension of the positional embedding, only relevant to the 'sinusoidal-concat' scheme; defaults to 512
+        @param: max_input_length = maximum length of an input video in milliseconds, used to size the positional encoding table
+        @param: positional_unit = unit of positional encoding in milliseconds (e.g., 60000 for minutes, 1000 for seconds)
+        """
+        if dense_encoder_name is None:
+            raise ValueError("A dense vector model must be specified")
+        else:
+            self.dense_encoder: backbones.ExtractorModel = backbones.model_map[dense_encoder_name]()
+        self.pos_encoder = positional_encoder
+        self.pos_dim = positional_embedding_dim
+        if positional_encoder in ['sinusoidal-add', 'sinusoidal-concat']:
+            position_dim = int(max_input_length / positional_unit)
+            if position_dim % 2 == 1:
+                position_dim += 1
+            if positional_encoder == 'sinusoidal-concat':
+                self.pos_vec_lookup = self.get_sinusoidal_embeddings(position_dim, positional_embedding_dim)
+            elif positional_encoder == 'sinusoidal-add':
+                self.pos_vec_lookup = self.get_sinusoidal_embeddings(position_dim, self.dense_encoder.dim)
+
+    def get_sinusoidal_embeddings(self, n_pos, dim):
+        if (n_pos, dim) in self.__class__.sinusoidal_embeddings:
+            return self.__class__.sinusoidal_embeddings[(n_pos, dim)]
+        matrix = torch.zeros(n_pos, dim)
+        position_enc = np.array(
+            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+        matrix[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+        matrix[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        self.__class__.sinusoidal_embeddings[(n_pos, dim)] = matrix
+        return matrix
+
+    def get_dense_vector(self, img_vec):
+        img_vec = self.dense_encoder.preprocess(img_vec)
+        img_vec = img_vec.unsqueeze(0)
+        if torch.cuda.is_available():
+            img_vec = img_vec.to('cuda')
+            self.dense_encoder.model.to('cuda')
+        with torch.no_grad():
+            feature_vec = self.dense_encoder.model(img_vec)
+        return feature_vec.cpu().numpy()
+
+    def encode_position(self, cur_time, tot_time, dense_vec):
+        pos = cur_time / tot_time
+        if self.pos_encoder is None:
+            return dense_vec
+        elif self.pos_encoder == 'fractional':
+            return torch.concat((dense_vec, torch.tensor([pos])))
+        elif self.pos_encoder == 'sinusoidal-add':
+            return torch.add(dense_vec, self.pos_vec_lookup[round(pos)])
+        elif self.pos_encoder == 'sinusoidal-concat':
+            return torch.concat((dense_vec, self.pos_vec_lookup[round(pos)]))
+
+    def feature_vector_dim(self):
+        if self.pos_encoder == 'sinusoidal-add' or self.pos_encoder is None:
+            return self.dense_encoder.dim
+        elif self.pos_encoder == 'sinusoidal-concat':
+            return self.dense_encoder.dim + self.pos_dim
+        elif self.pos_encoder == 'fractional':
+            return self.dense_encoder.dim + 1
+
+    def get_full_feature_vectors(self, img_vec, cur_time, tot_time):
+        dense_vecs = self.get_dense_vector(img_vec)
+        return self.encode_position(cur_time, tot_time, dense_vecs)
+
+
+class TrainingDataPreprocessor(object):
     """
-    models: List[backbones.ExtractorModel]
-
-    def __init__(self, model_name: str = None):
+    A refactor of the early feature extraction code, which only used CNN vectors.
+    """
+    def __init__(self, model_name: str):
         if model_name is None:
-            self.models = [model() for model in backbones.model_map.values()]
+            self.models = [FeatureExtractor(model_name) for model_name in backbones.model_map.keys()]
         else:
             if model_name in backbones.model_map:
-                self.models = [backbones.model_map[model_name]()]
+                self.models = [FeatureExtractor(model_name)]
             else:
                 raise ValueError("No valid model found")
-        print(f'using model(s): {[model.name for model in self.models]}')
+        print(f'using model(s): {[model.dense_encoder.name for model in self.models]}')
 
     def process_video(self,
                       vid_path: Union[os.PathLike, str],
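# --- illustrative sketch (not part of the patch): FeatureExtractor.get_sinusoidal_embeddings()
# builds the standard transformer-style table (sin in even columns, cos in odd columns, one row
# per unit-quantized position) and caches it on the class.  A standalone version of the same
# arithmetic, handy for eyeballing shapes:
import numpy as np
import torch

def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    table = torch.zeros(n_pos, dim)
    angles = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    table[:, 0::2] = torch.FloatTensor(np.sin(angles[:, 0::2]))
    table[:, 1::2] = torch.FloatTensor(np.cos(angles[:, 1::2]))
    return table

# With the defaults above (max_input_length=5640000 ms, positional_unit=60000 ms) there are
# 94 positions (already even, so no padding is needed).  'sinusoidal-concat' uses a
# (94, positional_embedding_dim) table, while 'sinusoidal-add' uses (94, dense_encoder.dim)
# so the embedding can be summed onto the CNN vector.
print(sinusoidal_table(94, 512).shape)   # torch.Size([94, 512])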
@@ -88,8 +168,8 @@ def process_video(self,
             if 'duration' not in frame_metadata:
                 frame_metadata['duration'] = frame.total_time
 
-            for model in self.models:
-                frame_vecs[model.name].append(self.process_frame(frame.image, model))
+            for extractor in self.models:
+                frame_vecs[extractor.dense_encoder.name].append(extractor.get_dense_vector(frame.image))
             frame_dict = {k: v for k, v in frame.__dict__.items() if k != "image" and k != "guid" and k != "total_time"}
             frame_dict['vec_idx'] = i
             frame_metadata["frames"].append(frame_dict)
@@ -97,20 +177,6 @@ def process_video(self,
         frame_mats = {k: np.vstack(v) for k, v in frame_vecs.items()}
         return frame_metadata, frame_mats
 
-    def process_frame(self, frame_vec: np.ndarray, model) -> np.ndarray:
-        """Extract the features of a single frame.
-
-        @param: frame = a frame as a numpy array
-        @returns: a numpy array representing the frame as features"""
-        frame_vec = model.preprocess(frame_vec)
-        frame_vec = frame_vec.unsqueeze(0)
-        if torch.cuda.is_available():
-            frame_vec = frame_vec.to('cuda')
-            model.model.to('cuda')
-        with torch.no_grad():
-            feature_vec = model.model(frame_vec)
-        return feature_vec.cpu().numpy()
-
     def get_stills(self, vid_path: Union[os.PathLike, str],
                    csv_path: Union[os.PathLike, str]) -> List[AnnotatedImage]:
         """Extract stills at given timepoints from a video file
@@ -126,8 +192,7 @@ def get_stills(self, vid_path: Union[os.PathLike, str],
                                label=row[2],
                                subtype_label=row[3],
                                mod=row[4].lower() == 'true') for row in reader if row[1] == 'true']
-        # mod=True should discard (taken as "unseen")
-        # performace jump from additional batch (from 2nd)
+        # CSV rows with mod=True should be discarded (taken as "unseen")
         # maybe we can throw away the video with the least (88) frames annotation from B2 to make 20/20 split on dense vs sparse annotation
 
         # this part is doing the same thing as the get_stills function in getstills.py
@@ -141,8 +206,6 @@ def get_stills(self, vid_path: Union[os.PathLike, str],
         cur_target_frame = 0
         fcount = 0
         for frame in container.decode(video=0):
-            # if fcount % 10000 == 0:
-            #     print(f'processing frame {fcount}')
            if cur_target_frame == len(frame_list):
                 break
             ftime = int(frame.time * 1000)
@@ -153,16 +216,10 @@ def get_stills(self, vid_path: Union[os.PathLike, str],
                 fcount += 1
 
 
-def get_framenum(frame: AnnotatedImage, fps: float) -> int:
-    """Returns the frame number of the given FrameOfInterest
-    (converts from ms to frame#)"""
-    return int(int(frame.curr_time)/1000 * fps)
-
-
 def main(args):
     in_file = args.input_file
     metadata_file = args.csv_file
-    featurizer = FeatureExtractor(args.model_name)
+    featurizer = TrainingDataPreprocessor(args.model_name)
     print('extractor ready')
     feat_metadata, feat_mats = featurizer.process_video(in_file, metadata_file)
     print('extraction complete')
@@ -189,6 +246,6 @@ def main(args):
                         default=None)
     parser.add_argument("-o", "--outdir",
                         help="directory to save output files",
-                        default="vectorized")
+                        default=Path(__file__).parent / "vectorized")
     clargs = parser.parse_args()
     main(clargs)
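# --- illustrative sketch (not part of the patch): how split_dataset() in train.py (below)
# drives the refactored extractor -- a dense CNN vector per frame, then encode_position()
# appends or adds a positional code derived from the frame timestamp.  The import path and
# the availability of pretrained weights are assumptions; the vector here is a dummy
# stand-in for one row of a saved .npy feature matrix.
import torch
from modeling.data_ingestion import FeatureExtractor

extractor = FeatureExtractor('convnext_tiny',                # dense dim 768
                             positional_encoder='sinusoidal-concat',
                             positional_embedding_dim=512,
                             positional_unit=60000)
dense = torch.zeros(extractor.dense_encoder.dim)             # stand-in for a saved feature row
full = extractor.encode_position(120000, 5400000, dense)     # frame at 2 min of a 90-min video
print(full.shape, extractor.feature_vector_dim())            # torch.Size([1280]) 1280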
diff --git a/modeling/train.py b/modeling/train.py
index 15779c3..37219f4 100644
--- a/modeling/train.py
+++ b/modeling/train.py
@@ -6,7 +6,6 @@
 import sys
 import time
 from collections import defaultdict
-from functools import lru_cache
 from pathlib import Path
 from typing import List, IO
 
@@ -20,6 +19,8 @@
 from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
 from tqdm import tqdm
 
+from modeling import data_ingestion
+
 logging.basicConfig(
     level=logging.WARNING,
     format="%(asctime)s %(name)s %(levelname)-8s %(thread)d %(message)s",
@@ -27,30 +28,12 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-ori_feat_dims = {
-    "convnext_base": 1024,
-    "convnext_tiny": 768,
-    "convnext_small": 768,
-    "convnext_lg": 1536,
-    "densenet121": 1024,
-    "efficientnet_small": 1280,
-    "efficientnet_med": 1280,
-    "efficientnet_large": 1280,
-    "resnet18": 512,
-    "resnet50": 2048,
-    "resnet101": 2048,
-    "resnet152": 2048,
-    "vgg16": 4096,
-    "bn_vgg16": 4096,
-    "vgg19": 4096,
-    "bn_vgg19": 4096,
-}
 feat_dims = {}
 
 # full typology from https://github.com/clamsproject/app-swt-detection/issues/1
 FRAME_TYPES = ["B", "S", "S:H", "S:C", "S:D", "S:B", "S:G", "W", "L", "O",
                "M", "I", "N", "E", "P", "Y", "K", "G", "T", "F", "C", "R"]
-RESULTS_DIR = f"results-{platform.node().split('.')[0]}"
+RESULTS_DIR = Path(__file__).parent / f"results-{platform.node().split('.')[0]}"
 
 
 class SWTDataset(Dataset):
@@ -70,28 +53,6 @@ def has_data(self):
         return 0 < len(self.vectors) == len(self.labels)
 
 
-def adjust_dims(configs):
-    additional_dim = 0
-    if configs and 'positional_encoding' in configs:
-        if configs['positional_encoding'] == 'fractional':
-            additional_dim = 1
-        elif configs['positional_encoding'] == 'sinusoidal-concat':
-            if 'embedding_size' in configs:
-                additional_dim = configs['embedding_size']
-    global feat_dims
-    feat_dims = {backbone: dim + additional_dim for backbone, dim in ori_feat_dims.items()}
-    return
-
-
-@lru_cache
-def create_sinusoidal_embeddings(n_pos, dim):
-    matrix = torch.zeros(n_pos, dim)
-    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
-    matrix[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
-    matrix[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
-    return matrix
-
-
 def get_guids(data_dir):
     guids = []
     for j in Path(data_dir).glob('*.json'):
@@ -175,57 +136,37 @@ def split_dataset(indir, train_guids, validation_guids, configs):
     train_labels = []
     valid_vectors = []
     valid_labels = []
-    if configs and 'unit_multiplier' in configs:
-        unit = configs['unit_multiplier']
-    else:
-        unit = 3600000
     if configs and 'bins' in configs and 'pre' in configs['bins']:
         pre_bin_size = len(configs['bins']['pre'].keys()) + 1
     else:
         pre_bin_size = len(FRAME_TYPES) + 1
     train_vnum = train_vimg = valid_vnum = valid_vimg = 0
-    if configs and 'positional_encoding' in configs and configs['positional_encoding'] in ['sinusoidal-add', 'sinusoidal-concat']:
+    logger.warning(f'positional encoding: {configs.get("positional_encoding")}')
+
+    extractor = data_ingestion.FeatureExtractor(
+        dense_encoder_name=configs['backbone_name'],
+        positional_encoder=configs.get('positional_encoding'),
+        positional_unit=configs['unit_multiplier'] if configs and 'unit_multiplier' in configs else 3600000,
+        positional_embedding_dim=configs['embedding_size'] if 'embedding_size' in configs else 512,
         # for now, hard-coding the longest video length in the annotated dataset
         # $ for m in /llc_data/clams/swt-gbh/**/*.mp4; do printf "%s %s\n" "$(basename $m .mp4)" "$(ffmpeg -i $m 2>&1 | grep Duration: )"; done | sort -k 3 -r | head -n 1
         # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
-        # 94 miins = 5640 secs = 5640000 ms
-        logger.warning('POSITIONAL ENCODING IS EXPERIMENTAL')
-        max_len = int(5640000 / unit)
-        if max_len % 2 == 1:
-            max_len += 1
-        if configs['positional_encoding'] == 'sinusoidal-add':
-            embedding_dim = feat_dims[configs['backbone_name']]
-        elif 'embedding_size' in configs:
-            embedding_dim = configs['embedding_size']
-        else:
-            embedding_dim = 512
-        logger.info(f'creating positional encoding: 60,216 x {embedding_dim}')
-        positional_encoding = create_sinusoidal_embeddings(max_len, embedding_dim)
+        # 94 mins = 5640 secs = 5640000 ms
+        max_input_length=5640000
+    )
+
     for j in Path(indir).glob('*.json'):
         guid = j.with_suffix("").name
         feature_vecs = np.load(Path(indir) / f"{guid}.{configs['backbone_name']}.npy")
         labels = json.load(open(Path(indir) / f"{guid}.json"))
-        # posenced_vecs = []
-        # if configs and 'positional_encoding' in configs:
-        #     if configs['positional_encoding'] == 'fractional':
+        total_video_len = labels['duration']
         for i, vec in enumerate(feature_vecs):
             if not labels['frames'][i]['mod']:  # "transitional" frames
                 valid_vimg += 1
                 pre_binned_label = pre_bin(labels['frames'][i]['label'], configs)
                 vector = torch.from_numpy(vec)
                 position = labels['frames'][i]['curr_time']
-                if configs and 'positional_encoding' in configs:
-                    if configs['positional_encoding'] == 'fractional':
-                        total = labels['duration']
-                        fraction = position / total
-                        vector = torch.concat((vector, torch.tensor([fraction])))
-                    elif configs['positional_encoding'] in ['sinusoidal-add', 'sinusoidal-concat']:
-                        position = round(position/unit)
-                        embedding = positional_encoding[position]
-                        if configs['positional_encoding'] == 'sinusoidal-add':
-                            vector = torch.add(vector, embedding)
-                        else:
-                            vector = torch.concat((vector, embedding))
+                vector = extractor.encode_position(position, total_video_len, vector)
                 if guid in validation_guids:
                     valid_vnum += 1
                     valid_vectors.append(vector)
@@ -262,6 +203,7 @@ def k_fold_train(indir, configs, train_id=time.strftime("%Y%m%d-%H%M%S")):
         logger.debug(f'train set: {train_guids}')
         logger.debug(f'dev set: {validation_guids}')
         train, valid, labelset_size = split_dataset(indir, train_guids, validation_guids, configs)
+        # `train` and `valid` vectors DO contain positional encoding after `split_dataset`
         if not train.has_data() or not valid.has_data():
             logger.info(f"Skipping fold {i} due to lack of data")
             continue
@@ -288,10 +230,10 @@
     else:
         export_f = sys.stdout
     export_kfold_results(val_set_spec, p_scores, r_scores, f_scores, out=export_f, **configs)
-    export_config(configs, train_id)
+    export_config(configs, train_id, train.feat_dim)
 
 
-def export_config(configs: dict, train_id: str):
+def export_config(configs: dict, train_id: str, feat_dim):
     backbone = configs["backbone_name"]
     config_path = Path(f"{RESULTS_DIR}", f"{backbone}.{train_id}.kfold_config.yml")
     config_path.parent.mkdir(parents=True, exist_ok=True)
@@ -301,7 +243,7 @@
         fh.write(f'labels: {get_valid_labels(configs)}\n\n')
         # TODO: keeping this for now because some other downstream code depends
        # on it, but remove this after the backbone refactoring is merged in
-        fh.write(f'in_dim: {feat_dims[backbone]}\n\n')
+        fh.write(f'in_dim: {feat_dim}\n\n')
 
 
 def export_kfold_results(trial_specs, p_scores, r_scores, f_scores, out=sys.stdout, **train_spec):
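# --- illustrative sketch (not part of the patch): with ori_feat_dims/adjust_dims gone, the
# classifier input width exported as `in_dim` now comes from the extractor itself.  The
# config below is hypothetical (key names taken from the code above), and building the
# extractor downloads the backbone weights.
from modeling import backbones, data_ingestion

config = {
    'backbone_name': 'convnext_lg',
    'positional_encoding': 'sinusoidal-concat',
    'unit_multiplier': 3600000,
    'embedding_size': 512,
}
extractor = data_ingestion.FeatureExtractor(
    dense_encoder_name=config['backbone_name'],
    positional_encoder=config['positional_encoding'],
    positional_unit=config['unit_multiplier'],
    positional_embedding_dim=config['embedding_size'],
    max_input_length=5640000)
print(backbones.model_dim_map[config['backbone_name']])   # 1536: raw CNN feature width
print(extractor.feature_vector_dim())                     # 2048: 1536 + 512 with 'sinusoidal-concat'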
@@ -433,10 +375,8 @@ def export_train_result(out: IO, predictions: Tensor, labels: Tensor, labelset:
     args = parser.parse_args()
 
     if args.config:
-        adjust_dims(args.config)
         k_fold_train(indir=args.indir, configs=args.config, train_id=time.strftime("%Y%m%d-%H%M%S"))
     else:
         import gridsearch
         for config in gridsearch.configs:
-            adjust_dims(config)
             k_fold_train(indir=args.indir, configs=config, train_id=time.strftime("%Y%m%d-%H%M%S"))