some refactoring of classifier class ...
- app.py no longer handles model file suffixing
- removed unused stand-alone classifier CLI
    - all point-wise classification now in MMIF
    - cli.py can fully replace the CLI in classify.py
- renamed trainer config file to make it clear that it's just an example
keighrim committed Jun 27, 2024
1 parent 2507442 commit 1432c51
Showing 4 changed files with 20 additions and 268 deletions.
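
In practice the calling convention changes as follows; a minimal before/after sketch, where the models/example_model stem and the logger name are illustrative stand-ins (in app.py the stem is actually built as default_model_storage / modelName):

    from pathlib import Path
    from modeling import classify

    # before: app.py resolved the .pt/.yml paths itself and passed one big dict;
    # the config needed 'model_file' and 'model_config_file' keys:
    #     classifier = classify.Classifier(**self.configs)
    # after: the classifier takes a model stem and derives both files from it
    classifier = classify.Classifier(Path('models') / 'example_model',  # loads example_model.yml + example_model.pt
                                     'swt.app')                         # optional logger name, may be None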
13 changes: 3 additions & 10 deletions app.py
@@ -38,7 +38,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         # parameters here is a "refined" dict, so hopefully its values are properly
         # validated and casted at this point.
         self.configs = parameters
-        self._configure_model()
         self._configure_postbin()
         for k, v in self.configs.items():
             self.logger.debug(f"Final Configuration: {k} :: {v}")
@@ -73,11 +72,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:

         return mmif

-    def _configure_model(self):
-        model_name = self.configs["modelName"]
-        self.configs['model_file'] = default_model_storage / f'{model_name}.pt'
-        self.configs['model_config_file'] = default_model_storage / f'{model_name}.yml'
-
     def _configure_postbin(self):
         """
         Set the postbin property of the the configs configuration dictionary, using the
@@ -122,10 +116,9 @@ def _extract_images(self, video):

     def _classify(self, extracted: list, positions: list, total_ms: int):
         t = time.perf_counter()
-        self.logger.info(f"Initiating classifier with {self.configs['model_file']}")
-        if self.logger.isEnabledFor(logging.DEBUG):
-            self.configs['logger_name'] = self.logger.name
-        classifier = classify.Classifier(**self.configs)
+        self.logger.info(f"Initiating classifier with {self.configs['modelName']}")
+        classifier = classify.Classifier(default_model_storage / self.configs['modelName'],
+                                         self.logger.name if self.logger.isEnabledFor(logging.DEBUG) else None)
         if self.logger.isEnabledFor(logging.DEBUG):
             self.logger.debug(f"Classifier initiation took {time.perf_counter() - t:.2f} seconds")
         predictions = classifier.classify_images(extracted, positions, total_ms)
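
For orientation, the refactored classifier can now be driven point-wise outside the app roughly as below; a hedged sketch where the frame, position, and duration values are placeholders (in app.py these come from _extract_images()):

    from pathlib import Path
    from PIL import Image
    from modeling import classify

    images = [Image.new('RGB', (640, 480))]  # placeholder frame; normally sampled from a video
    positions = [0]                          # millisecond offset of each frame
    total_ms = 60000                         # placeholder total video duration in ms

    classifier = classify.Classifier(Path('models') / 'example_model', None)
    for p in classifier.classify_images(images, positions, total_ms):
        print(p.timepoint, p.data)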
227 changes: 17 additions & 210 deletions modeling/classify.py
@@ -1,126 +1,36 @@
 """Classifier module.
 Used by app.py in the parent directory.
-See app.py for hints on how to uses this, the main workhorse method is process_video(),
-which takes a video and returns a list of predictions from the image classification model.
-For debugging you can run this a standalone script from the parent directory:
-$ python -m modeling.classify \
-    --config modeling/config/classifier.yml \
-    --input MP4_FILE \
-    --debug
-The above will also use the stitcher in stitch.py.
-For help on parameters use:
-$ python -m modeling.classify -h
-The requirements are the same as the requirements for ../app.py.
 """

 import time
-import argparse
-import json
 import logging
-import os
-import sys
-import time
 from typing import List

-import cv2
 import torch
 import yaml
 from PIL import Image

-from modeling import train, data_loader, stitch, FRAME_TYPES
+from modeling import train, data_loader, FRAME_TYPES


 class Classifier:

-    def __init__(self, **config):
-        self.config = config
-        self.model_config = yaml.safe_load(open(config["model_config_file"]))
-        self.prebin_labels = train.pretraining_binned_label(self.model_config)
+    def __init__(self, model_stem, logger_name=None):
+        model_config = yaml.safe_load(open(f"{model_stem}.yml"))
+        self.training_labels = train.pretraining_binned_label(model_config)
         self.featurizer = data_loader.FeatureExtractor(
-            img_enc_name=self.model_config["img_enc_name"],
-            pos_enc_name=self.model_config.get("pos_enc_name", None),
-            pos_enc_dim=self.model_config.get("pos_enc_dim", 0),
-            max_input_length=self.model_config.get("max_input_length", 0),
-            pos_unit=self.model_config.get("pos_unit", 0))
+            img_enc_name=model_config["img_enc_name"],
+            pos_enc_name=model_config.get("pos_enc_name", None),
+            pos_enc_dim=model_config.get("pos_enc_dim", 0),
+            max_input_length=model_config.get("max_input_length", 0),
+            pos_unit=model_config.get("pos_unit", 0))
         label_count = len(FRAME_TYPES) + 1
-        if 'bins' in self.model_config:
-            label_count = len(self.model_config['pre'].keys()) + 1
+        if 'bins' in model_config:
+            label_count = len(model_config['bins'].keys()) + 1
         self.classifier = train.get_net(
             in_dim=self.featurizer.feature_vector_dim(),
             n_labels=label_count,
-            num_layers=self.model_config["num_layers"],
-            dropout=self.model_config["dropouts"])
-        self.classifier.load_state_dict(torch.load(config["model_file"]))
-        self.sample_rate = self.config.get("sampleRate")
-        self.start_at = config.get("startAt", 0)
-        self.stop_at = config.get('stopAt', sys.maxsize)
+            num_layers=model_config["num_layers"],
+            dropout=model_config["dropouts"])
+        self.classifier.load_state_dict(torch.load(f'{model_stem}.pt'))
-        self.debug = False
-        self.logger = logging.getLogger(config.get('logger_name', self.__class__.__name__))
-
-    def __str__(self):
-        return (f"<Classifier "
-                + f'img_enc_name="{self.model_config["img_enc_name"]}" '
-                + f'pos_enc_name="{self.model_config["pos_enc_name"]}" '
-                + f'sample_rate={self.sample_rate}>')
-
-    def process_video(self, vidcap: cv2.VideoCapture) -> list:
-        """
-        Image classification for a video without MMIF SDK helpers, for standalone mode without MMIF/CLAMS involved
-        Loops over the frames in a video and for each frame extracts the features
-        and applies the classifier. Returns a list of predictions, where each prediction
-        is an instance of numpy.ndarray.
-        """
-        featurizing_time = 0
-        classifier_time = 0
-        extract_time = 0
-        seek_time = 0
-        if self.debug:
-            print(f'Labels: {self.prebin_labels}')
-        predictions = []
-        fps = round(vidcap.get(cv2.CAP_PROP_FPS), 2)
-        fc = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
-        dur = round(fc / fps, 3) * 1000
-        for ms in range(0, sys.maxsize, self.sample_rate):
-            if ms < self.start_at:
-                continue
-            if ms > self.stop_at:
-                break
-            t = time.perf_counter()
-            vidcap.set(cv2.CAP_PROP_POS_MSEC, ms)
-            if self.logger.isEnabledFor(logging.DEBUG):
-                seek_time += time.perf_counter() - t
-            t = time.perf_counter()
-            success, image = vidcap.read()
-            if self.logger.isEnabledFor(logging.DEBUG):
-                extract_time += time.perf_counter() - t
-            if not success:
-                break
-            img = Image.fromarray(image[:, :, ::-1])
-            t = time.perf_counter()
-            features = self.featurizer.get_full_feature_vectors(img, ms, dur)
-            if self.logger.isEnabledFor(logging.DEBUG):
-                featurizing_time += time.perf_counter() - t
-            t = time.perf_counter()
-            prediction = self.classifier(features).detach()
-            prediction = Prediction(ms, self.prebin_labels, prediction)
-            if self.logger.isEnabledFor(logging.DEBUG):
-                classifier_time += time.perf_counter() - t
-            if self.debug:
-                print(prediction)
-            predictions.append(prediction)
-        self.logger.debug(f'Featurizing time: {featurizing_time:.2f} seconds\n')
-        self.logger.debug(f'Classifier time: {classifier_time:.2f} seconds\n')
-        self.logger.debug(f'Extract time: {extract_time:.2f} seconds\n')
-        self.logger.debug(f'Seeking time: {seek_time:.2f} seconds\n')
-        return predictions
+        self.logger = logging.getLogger(logger_name if logger_name else self.__class__.__name__)

    def classify_images(self, images: List[Image.Image], positions: List[int], final_pos: int) -> list:
        """
@@ -137,50 +47,14 @@ def classify_images(self, images: List[Image.Image], positions: List[int], final_pos: int) -> list:
             featurizing_time += time.perf_counter() - t
             t = time.perf_counter()
             prediction = self.classifier(features).detach()
-            prediction = Prediction(pos, self.prebin_labels, prediction)
+            prediction = Prediction(pos, self.training_labels, prediction)
             if self.logger.isEnabledFor(logging.DEBUG):
                 classifier_time += time.perf_counter() - t
             predictions.append(prediction)
         self.logger.debug(f'Featurizing time: {featurizing_time:.2f} seconds\n')
         self.logger.debug(f'Classifier time: {classifier_time:.2f} seconds\n')
         return predictions
-
-    def pp(self):
-        # debugging method
-        print(f"Classifier {self.model_file}")
-        print(f" sample_rate = {self.sample_rate}")
-        print(f" min_frame_score = {self.min_timeframe_score}")
-        print(f" min_frame_count = {self.min_frame_count}")
-
-
-def save_predictions(predictions: list, filename: str):
-    json_obj = []
-    for prediction in predictions:
-        json_obj.append(prediction.as_json())
-    with open(filename, 'w') as fh:
-        json.dump(json_obj, fh)
-    print(f'Saved predictions to {filename}')
-
-
-def load_predictions(filename: str, labels: list) -> list:
-    predictions = []
-    with open(filename) as fh:
-        for (n, tensor, data) in json.load(fh):
-            p = Prediction(n, labels, torch.Tensor(tensor), data=data)
-            predictions.append(p)
-    return predictions
-
-
-def print_predictions(predictions, filename=None):
-    # Debugging method
-    fh = sys.stdout if filename is None else open(filename, 'w')
-    fh.write('\n slate chyron creds other\n')
-    for prediction in predictions:
-        milliseconds = prediction.timepoint
-        p1, p2, p3, p4 = prediction.data[:4]
-        fh.write(f'{milliseconds:7} {p1:.4f} {p2:.4f} {p3:.4f} {p4:.4f}\n')
-    fh.write(f'\nTOTAL PREDICTIONS: {len(predictions)}\n')


class Prediction:

@@ -235,70 +109,3 @@ def score_for_labels(self, labels: list):

     def as_json(self):
         return [self.timepoint, self.tensor.detach().numpy().tolist(), self.data]
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    default_config = 'modeling/config/classifier.yml'
-    conf_help = "the YAML config file"
-    pred1_help = "use saved predictions"
-    pred2_help = "save predictions"
-    parser.add_argument("--input", help="input video file")
-    parser.add_argument("--config", default=default_config, help=conf_help)
-    parser.add_argument("--start", default=0, help="start N milliseconds into the video")
-    parser.add_argument("--stop", default=None, help="stop N milliseconds into the video")
-    parser.add_argument("--use-predictions", action='store_true', help=pred1_help)
-    parser.add_argument("--save-predictions", action='store_true', help=pred2_help)
-    parser.add_argument("--debug", action='store_true', help="turn on debugging")
-    return parser.parse_args()
-
-
-def add_parameters(args: argparse.Namespace, classifier: Classifier, stitcher: stitch.Stitcher):
-    """Add arguments to the classifier and the stitcher."""
-    if args.debug:
-        classifier.debug = True
-        stitcher.debug = True
-    if args.start:
-        classifier.start_at = int(args.start)
-    if args.stop:
-        classifier.stop_at = int(args.stop)
-
-
-def open_mp4_file(mp4_file, verbose=False):
-    if verbose:
-        print(f'Processing {args.input}...')
-    mp4_vidcap = cv2.VideoCapture(mp4_file)
-    if not mp4_vidcap.isOpened():
-        raise IOError(f'Could not open {mp4_file}')
-    return mp4_vidcap
-
-
-if __name__ == '__main__':
-
-    args = parse_args()
-    configs = yaml.safe_load(open(args.config))
-    classifier = Classifier(**configs)
-    stitcher = stitch.Stitcher(**configs)
-    add_parameters(args, classifier, stitcher)
-
-    if args.debug:
-        print(classifier)
-        print(stitcher)
-    mp4_vidcap = open_mp4_file(args.input, args.debug)
-    if not mp4_vidcap.isOpened():
-        raise IOError(f'Could not open {args.input}')
-
-    input_basename, extension = os.path.splitext(args.input)
-    predictions_file = f'{input_basename}.json'
-    if args.use_predictions:
-        predictions = load_predictions(predictions_file, classifier.prebin_labels)
-    else:
-        predictions = classifier.process_video(mp4_vidcap)
-        if args.save_predictions:
-            save_predictions(predictions, predictions_file)
-
-    timeframes = stitcher.create_timeframes(predictions)
-
-    if not args.debug:
-        for timeframe in timeframes:
-            print(timeframe)
48 changes: 0 additions & 48 deletions modeling/config/README.md

This file was deleted.

File renamed without changes.
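
Relatedly, the shape of the <model_stem>.yml file that the new constructor reads can be inferred from the keys accessed in __init__; a hypothetical example, shown as the dict that yaml.safe_load would return (all values are made up):

    # hypothetical contents of <model_stem>.yml after yaml.safe_load();
    # keys mirror what Classifier.__init__ reads, values are illustrative
    model_config = {
        'img_enc_name': 'convnext_lg',        # required; image encoder for the FeatureExtractor
        'pos_enc_name': 'sinusoidal-concat',  # optional positional encoding, default None
        'pos_enc_dim': 256,                   # optional, default 0
        'max_input_length': 5640000,          # optional, default 0
        'pos_unit': 60000,                    # optional, default 0
        'num_layers': 4,                      # required; depth passed to train.get_net
        'dropouts': 0.1,                      # required; note the plural key feeding the 'dropout' kwarg
        'bins': {'bars': ['bars']},           # optional; when present, label count = len(bins) + 1
    }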
