Skip to content

Commit

Permalink
updated classifier to use new config keys and pos_enc configs, also ...
Browse files Browse the repository at this point in the history
* preparing to move smoothing code into app.py
  • Loading branch information
keighrim committed Nov 20, 2023
1 parent 81e0db2 commit 2b874ca
Showing 1 changed file with 41 additions and 84 deletions.
125 changes: 41 additions & 84 deletions modeling/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,31 @@
import cv2
import numpy as np
import torch
import yaml
from PIL import Image

from modeling import backbones
from modeling.train import get_net
from modeling import data_loader, train


class Classifier:

def __init__(self, configs):
config = yaml.safe_load(configs)
self.step_size = config["step_size"]
self.minimum_score = config["minimum_score"]
self.score_mapping = config["score_mapping"]
if "safe_frames" in config:
self.safe_frames = config["safe_frames"]
else:
self.safe_frames = False
if "dribble" in config:
self.dribble = config["dribble"]
else:
self.dribble = False
model_type, self.label_mappings, self.model = read_model_config(config["model_config"])
self.model.load_state_dict(torch.load(config["model"]))
self.featurizer = backbones.model_map[model_type]()
def __init__(self, **config):
self.classifier = train.get_net(
in_dim=config["feature_dim"],
n_labels=len(config['prebin']) if 'prebin' in config else len(config["labels"]),
num_layers=config["num_layers"],
dropout=config["dropouts"],
)
self.classifier.load_state_dict(torch.load(config["model_file"]))
self.featurizer = data_loader.FeatureExtractor(
img_enc_name=config["img_enc_name"],
pos_enc_name=config.get("pos_enc_name", None),
pos_enc_dim=config.get("pos_enc_dim", None),
max_input_length=config.get("max_input_length", None),
position_unit=config.get("position_unit", None),
)
self.sample_rate = config["sample_rate"]
self.labels = config["labels"]
self.postbin = config.get("postbin", None)

def process_video(self, mp4_file: str):
"""Loops over the frames in a video and for each frame extract the features
Expand All @@ -40,42 +40,25 @@ def process_video(self, mp4_file: str):
print(f'Processing {mp4_file}...')
logging.info(f'processing {mp4_file}...')
all_predictions = []
for n, image in get_frames(mp4_file, self.step_size):
vidcap = cv2.VideoCapture(mp4_file)
fps = round(vidcap.get(cv2.CAP_PROP_FPS), 2)
fc = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
dur = round(fc / fps, 3) * 1000
for ms in range(0, sys.maxsize, self.sample_rate):
vidcap.set(cv2.CAP_PROP_POS_MSEC, ms)
success, image = vidcap.read()
if not success:
break
img = Image.fromarray(image[:,:,::-1])
features = self.extract_features(img, self.featurizer)
prediction = self.model(features)
prediction = Prediction(n, prediction)
if self.dribble:
print(f'{n:07d}', prediction)
all_predictions.append(prediction)
if self.safe_frames:
cv2.imwrite(f"frames/frame-{n:06d}.jpg", image)
logging.info(f'number of predictions = {len(all_predictions)}')
return(all_predictions)
features = self.featurizer.get_full_feature_vectors(img, ms, dur)
softmax = torch.nn.Softmax()
output = self.classifier(features).detach()
prediction = softmax(output)
top1 = torch.argmax(prediction).item()
label = self.labels[self.postbin[top1]] if self.postbin else self.labels[top1]
all_predictions.append((ms, label, prediction[top1]))
return all_predictions

def extract_features(self, frame_vec: np.ndarray, model: torch.nn.Sequential) -> torch.Tensor:
"""Extract the features of a single frame. Based on, but not identical to, the
process_frame() method of the FeatureExtractor class in data_ingestion.py."""
frame_vec = model.preprocess(frame_vec)
frame_vec = frame_vec.unsqueeze(0)
if torch.cuda.is_available():
if self.dribble:
print('CUDA is available')
frame_vec = frame_vec.to('cuda')
model.model.to('cuda')
with torch.no_grad():
feature_vec = model.model(frame_vec)
return feature_vec.cpu()

def save_predictions(self, predictions: list, filename: str):
json_obj = []
for prediction in predictions:
json_obj.append(prediction.as_json())
with open(filename, 'w') as fh:
json.dump(json_obj, fh)
if self.dribble:
print(f'Saved predictions to {filename}')

def compute_labels(self, scores: list):
return (
('slate', self.scale(scores[0])),
Expand Down Expand Up @@ -108,8 +91,8 @@ def extract_timeframes(self, predictions):

def collect_timeframes(self, predictions: list) -> dict:
"""Find sequences of frames for all labels where the score is not 0."""
timeframes = { label: [] for label in self.labels}
open_frames = { label: [] for label in self.labels}
timeframes = {label: [] for label in self.labels}
open_frames = {label: [] for label in self.labels}
for prediction in predictions:
timepoint = prediction.timepoint
bins = prediction.data[-1]
Expand Down Expand Up @@ -156,44 +139,18 @@ def remove_overlapping_timeframes(self, timeframes: dict) -> list:
continue
final_frames.append(frame)
start, end, _, _ = frame
for p in range(start, end + self.step_size, self.step_size):
for p in range(start, end + self.sample_rate, self.sample_rate):
outlawed_timepoints.add(p)
return all_frames

def is_included(self, frame, outlawed_timepoints):
start, end, _, _ = frame
for i in range(start, end + self.step_size, self.step_size):
for i in range(start, end + self.sample_rate, self.sample_rate):
if i in outlawed_timepoints:
return True
return False


def read_model_config(configs):
with open(configs) as f:
config = yaml.safe_load(configs)
labels = config["labels"]
in_dim = config["in_dim"]
n_labels = len(labels)
num_layers = config["num_layers"]
dropout = config["dropout"]
model = get_net(in_dim, n_labels, num_layers, dropout)
model_type = config["model_type"]
label_mappings = {i: label for i, label in enumerate(labels)}
return model_type, label_mappings, model


def get_frames(mp4_file: str, step: int = 1000):
"""Generator to get frames from an mp4 file. The step parameter defines the number
of milliseconds between the frames."""
vidcap = cv2.VideoCapture(mp4_file)
for n in range(0, sys.maxsize, step):
vidcap.set(cv2.CAP_PROP_POS_MSEC, n)
success, image = vidcap.read()
if not success:
break
yield n, image


def softmax(x):
return np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()

Expand Down Expand Up @@ -243,7 +200,7 @@ def as_json(self):
if __name__ == '__main__':
# purely for debugging purposes
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="The YAML config file")
parser.add_argument("-c", "--config", help="The YAML config file", default='modeling/config/classifier.yml')
args = parser.parse_args()

classifier = Classifier(args.config)
Expand Down

0 comments on commit 2b874ca

Please sign in to comment.