diff --git a/modeling/classify.py b/modeling/classify.py
index e4699b2..8e2c752 100644
--- a/modeling/classify.py
+++ b/modeling/classify.py
@@ -7,31 +7,31 @@
 import cv2
 import numpy as np
 import torch
-import yaml
 from PIL import Image
 
-from modeling import backbones
-from modeling.train import get_net
+from modeling import data_loader, train
 
 
 class Classifier:
 
-    def __init__(self, configs):
-        config = yaml.safe_load(configs)
-        self.step_size = config["step_size"]
-        self.minimum_score = config["minimum_score"]
-        self.score_mapping = config["score_mapping"]
-        if "safe_frames" in config:
-            self.safe_frames = config["safe_frames"]
-        else:
-            self.safe_frames = False
-        if "dribble" in config:
-            self.dribble = config["dribble"]
-        else:
-            self.dribble = False
-        model_type, self.label_mappings, self.model = read_model_config(config["model_config"])
-        self.model.load_state_dict(torch.load(config["model"]))
-        self.featurizer = backbones.model_map[model_type]()
+    def __init__(self, **config):
+        self.classifier = train.get_net(
+            in_dim=config["feature_dim"],
+            n_labels=len(config['prebin']) if 'prebin' in config else len(config["labels"]),
+            num_layers=config["num_layers"],
+            dropout=config["dropouts"],
+        )
+        self.classifier.load_state_dict(torch.load(config["model_file"]))
+        self.featurizer = data_loader.FeatureExtractor(
+            img_enc_name=config["img_enc_name"],
+            pos_enc_name=config.get("pos_enc_name", None),
+            pos_enc_dim=config.get("pos_enc_dim", None),
+            max_input_length=config.get("max_input_length", None),
+            position_unit=config.get("position_unit", None),
+        )
+        self.sample_rate = config["sample_rate"]
+        self.labels = config["labels"]
+        self.postbin = config.get("postbin", None)
 
     def process_video(self, mp4_file: str):
         """Loops over the frames in a video and for each frame extract the features
@@ -40,42 +40,25 @@ def process_video(self, mp4_file: str):
         print(f'Processing {mp4_file}...')
         logging.info(f'processing {mp4_file}...')
         all_predictions = []
-        for n, image in get_frames(mp4_file, self.step_size):
+        vidcap = cv2.VideoCapture(mp4_file)
+        fps = round(vidcap.get(cv2.CAP_PROP_FPS), 2)
+        fc = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
+        dur = round(fc / fps, 3) * 1000
+        for ms in range(0, sys.maxsize, self.sample_rate):
+            vidcap.set(cv2.CAP_PROP_POS_MSEC, ms)
+            success, image = vidcap.read()
+            if not success:
+                break
             img = Image.fromarray(image[:,:,::-1])
-            features = self.extract_features(img, self.featurizer)
-            prediction = self.model(features)
-            prediction = Prediction(n, prediction)
-            if self.dribble:
-                print(f'{n:07d}', prediction)
-            all_predictions.append(prediction)
-            if self.safe_frames:
-                cv2.imwrite(f"frames/frame-{n:06d}.jpg", image)
-        logging.info(f'number of predictions = {len(all_predictions)}')
-        return(all_predictions)
+            features = self.featurizer.get_full_feature_vectors(img, ms, dur)
+            softmax = torch.nn.Softmax()
+            output = self.classifier(features).detach()
+            prediction = softmax(output)
+            top1 = torch.argmax(prediction).item()
+            label = self.labels[self.postbin[top1]] if self.postbin else self.labels[top1]
+            all_predictions.append((ms, label, prediction[top1]))
+        return all_predictions
 
-    def extract_features(self, frame_vec: np.ndarray, model: torch.nn.Sequential) -> torch.Tensor:
-        """Extract the features of a single frame. Based on, but not identical to, the
-        process_frame() method of the FeatureExtractor class in data_ingestion.py."""
-        frame_vec = model.preprocess(frame_vec)
-        frame_vec = frame_vec.unsqueeze(0)
-        if torch.cuda.is_available():
-            if self.dribble:
-                print('CUDA is available')
-            frame_vec = frame_vec.to('cuda')
-            model.model.to('cuda')
-        with torch.no_grad():
-            feature_vec = model.model(frame_vec)
-        return feature_vec.cpu()
-
-    def save_predictions(self, predictions: list, filename: str):
-        json_obj = []
-        for prediction in predictions:
-            json_obj.append(prediction.as_json())
-        with open(filename, 'w') as fh:
-            json.dump(json_obj, fh)
-        if self.dribble:
-            print(f'Saved predictions to {filename}')
-
     def compute_labels(self, scores: list):
         return (
             ('slate', self.scale(scores[0])),
@@ -108,8 +91,8 @@ def extract_timeframes(self, predictions):
 
     def collect_timeframes(self, predictions: list) -> dict:
         """Find sequences of frames for all labels where the score is not 0."""
-        timeframes = { label: [] for label in self.labels}
-        open_frames = { label: [] for label in self.labels}
+        timeframes = {label: [] for label in self.labels}
+        open_frames = {label: [] for label in self.labels}
         for prediction in predictions:
             timepoint = prediction.timepoint
             bins = prediction.data[-1]
@@ -156,44 +139,18 @@ def remove_overlapping_timeframes(self, timeframes: dict) -> list:
                 continue
             final_frames.append(frame)
             start, end, _, _ = frame
-            for p in range(start, end + self.step_size, self.step_size):
+            for p in range(start, end + self.sample_rate, self.sample_rate):
                 outlawed_timepoints.add(p)
         return all_frames
 
     def is_included(self, frame, outlawed_timepoints):
         start, end, _, _ = frame
-        for i in range(start, end + self.step_size, self.step_size):
+        for i in range(start, end + self.sample_rate, self.sample_rate):
             if i in outlawed_timepoints:
                 return True
         return False
 
 
-def read_model_config(configs):
-    with open(configs) as f:
-        config = yaml.safe_load(configs)
-    labels = config["labels"]
-    in_dim = config["in_dim"]
-    n_labels = len(labels)
-    num_layers = config["num_layers"]
-    dropout = config["dropout"]
-    model = get_net(in_dim, n_labels, num_layers, dropout)
-    model_type = config["model_type"]
-    label_mappings = {i: label for i, label in enumerate(labels)}
-    return model_type, label_mappings, model
-
-
-def get_frames(mp4_file: str, step: int = 1000):
-    """Generator to get frames from an mp4 file. The step parameter defines the number
-    of milliseconds between the frames."""
-    vidcap = cv2.VideoCapture(mp4_file)
-    for n in range(0, sys.maxsize, step):
-        vidcap.set(cv2.CAP_PROP_POS_MSEC, n)
-        success, image = vidcap.read()
-        if not success:
-            break
-        yield n, image
-
-
 def softmax(x):
     return np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()
@@ -243,7 +200,7 @@ def as_json(self):
 
 if __name__ == '__main__':
     # purely for debugging purposes
     parser = argparse.ArgumentParser()
-    parser.add_argument("-c", "--config", help="The YAML config file")
+    parser.add_argument("-c", "--config", help="The YAML config file", default='modeling/config/classifier.yml')
     args = parser.parse_args()
     classifier = Classifier(args.config)
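
A minimal sketch of how the reworked Classifier would be driven from the YAML
config named in __main__. Note that the debugging block still calls
Classifier(args.config) with a positional path, which the new keyword-only
__init__(self, **config) cannot accept; splatting the parsed YAML, as below,
matches the new signature. The config path is the default added in the diff;
example.mp4 and the exact key set of the YAML file are assumptions here.

    import yaml

    from modeling.classify import Classifier

    # Parse the YAML config; assumed to hold the keys the new __init__
    # reads (feature_dim, labels, num_layers, dropouts, model_file,
    # img_enc_name, sample_rate, plus the optional pos_enc_*,
    # max_input_length, position_unit, prebin, and postbin entries).
    with open('modeling/config/classifier.yml') as f:
        config = yaml.safe_load(f)

    # Splat the parsed mapping into the keyword-only constructor.
    classifier = Classifier(**config)

    # process_video() now returns (millisecond, label, score) tuples,
    # where score is a 0-dim tensor (prediction[top1]).
    for ms, label, score in classifier.process_video('example.mp4'):
        print(f'{ms:8d}ms  {label}  {score.item():.4f}')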