
Commit

removed "postbin" from training code, evaluation is separated and don…
Browse files Browse the repository at this point in the history
…e w/o/ postbins
  • Loading branch information
keighrim committed Mar 15, 2024
1 parent eb966a2 commit b4651c4
Showing 5 changed files with 97 additions and 130 deletions.
2 changes: 1 addition & 1 deletion modeling/classify.py
@@ -43,7 +43,7 @@ class Classifier:
     def __init__(self, **config):
         self.config = config
         self.model_config = yaml.safe_load(open(config["model_config_file"]))
-        self.prebin_labels = train.pre_bin_label_names(self.model_config, FRAME_TYPES)
+        self.prebin_labels = train.pretraining_binned_label(self.model_config)
         self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.featurizer = data_loader.FeatureExtractor(
             img_enc_name=self.model_config["img_enc_name"],
17 changes: 8 additions & 9 deletions modeling/config/trainer.yml
@@ -41,12 +41,11 @@ pos_enc_dim: 512
 max_input_length: 5640000

 bins:
-  pre:
-    slate:
-      - "S"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
+  slate:
+    - "S"
+  chyron:
+    - "I"
+    - "N"
+    - "Y"
+  credit:
+    - "C"
60 changes: 60 additions & 0 deletions modeling/evaluate.py
@@ -0,0 +1,60 @@
+import csv
+import logging
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import IO, List
+
+import torch
+from torch import Tensor
+from torchmetrics import functional as metrics
+from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
+
+
+def evaluate(model, valid_loader, labelset, export_fname=None):
+    model.eval()
+    # valid_loader is currently expected to be a single batch
+    vfeats, vlabels = next(iter(valid_loader))
+    outputs = model(vfeats)
+    _, preds = torch.max(outputs, 1)
+    p = metrics.precision(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    r = metrics.recall(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    f = metrics.f1_score(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    # m = metrics.confusion_matrix(preds, vlabels, 'multiclass', num_classes=len(labelset))
+
+    if not export_fname:
+        export_f = sys.stdout
+    else:
+        path = Path(export_fname)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        export_f = open(path, 'w', encoding='utf8')
+    export_train_result(out=export_f, preds=preds, golds=vlabels,
+                        labelset=labelset, img_enc_name=valid_loader.dataset.img_enc_name)
+    logging.info(f"Exported to {export_f.name}")
+    return p, r, f
+
+
+def export_train_result(out: IO, preds: Tensor, golds: Tensor, labelset: List[str], img_enc_name: str):
+    """Exports the data into a human-readable format.
+    """
+
+    label_metrics = defaultdict(dict)
+
+    for i, label in enumerate(labelset):
+        pred_labels = torch.where(preds == i, 1, 0)
+        true_labels = torch.where(golds == i, 1, 0)
+        binary_acc = BinaryAccuracy()
+        binary_prec = BinaryPrecision()
+        binary_recall = BinaryRecall()
+        binary_f1 = BinaryF1Score()
+        label_metrics[label] = {"Model_Name": img_enc_name,
+                                "Label": label,
+                                "Accuracy": binary_acc(pred_labels, true_labels).item(),
+                                "Precision": binary_prec(pred_labels, true_labels).item(),
+                                "Recall": binary_recall(pred_labels, true_labels).item(),
+                                "F1-Score": binary_f1(pred_labels, true_labels).item()}
+
+    writer = csv.DictWriter(out, fieldnames=["Model_Name", "Label", "Accuracy", "Precision", "Recall", "F1-Score"])
+    writer.writeheader()
+    for label, metrics in label_metrics.items():
+        writer.writerow(metrics)
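A small worked sketch (not from the repo) of the one-vs-rest binarization that export_train_result() applies per label: predictions and golds are collapsed with torch.where to 1 (is this label) / 0 (is not), then scored with torchmetrics binary metrics. The tensors below are made-up examples:

import torch
from torchmetrics.classification import BinaryF1Score

preds = torch.tensor([0, 2, 1, 0, 3])        # predicted class indices
golds = torch.tensor([0, 2, 2, 0, 3])        # gold class indices

i = 2                                        # score label index 2 one-vs-rest
pred_bin = torch.where(preds == i, 1, 0)     # -> [0, 1, 0, 0, 0]
gold_bin = torch.where(golds == i, 1, 0)     # -> [0, 1, 1, 0, 0]
print(BinaryF1Score()(pred_bin, gold_bin))   # tensor(0.6667): P=1.0, R=0.5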
6 changes: 4 additions & 2 deletions modeling/stitch.py
@@ -10,8 +10,10 @@


 import operator
+
 import yaml
-from modeling import train, negative_label, FRAME_TYPES
+
+from modeling import train, negative_label


 class Stitcher:
@@ -24,7 +26,7 @@ def __init__(self, **config):
         self.min_timeframe_score = config.get("minTimeframeScore")
         self.min_frame_count = config.get("minFrameCount")
         self.static_frames = self.config.get("staticFrames")
-        self.prebin_labels = train.pre_bin_label_names(self.model_config, FRAME_TYPES)
+        self.prebin_labels = train.pretraining_binned_label(self.model_config)
         self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.use_postbinning = "post" in self.model_config["bins"]
         self.debug = False
142 changes: 24 additions & 118 deletions modeling/train.py
@@ -1,28 +1,22 @@
 import argparse
-import csv
 import json
 import logging
 import platform
 import shutil
-import sys
 import time
-from collections import defaultdict
 from pathlib import Path
-from typing import List, IO
 import copy

 import numpy as np
 import torch
 import torch.nn as nn
 import yaml
-from torch import Tensor
 from torch.utils.data import Dataset, DataLoader
-from torchmetrics import functional as metrics
-from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
 from tqdm import tqdm

 import modeling
 from modeling import data_loader, FRAME_TYPES
+from modeling.evaluate import evaluate

 logging.basicConfig(
     level=logging.WARNING,
@@ -59,37 +53,13 @@ def get_guids(data_dir):
     return sorted(guids)


-def pre_bin(label, specs):
-    if specs is None or "pre" not in specs["bins"]:
+def pretraining_bin(label, specs):
+    if specs is None or "bins" not in specs:
         return int_encode(label)
-    for i, bin in enumerate(specs["bins"]["pre"].values()):
-        if label and label in bin:
+    for i, ptbin in enumerate(specs["bins"].values()):
+        if label and label in ptbin:
             return i
-    return len(specs["bins"]["pre"].keys())
-
-
-def post_bin(label, specs):
-    if specs is None:
-        return int_encode(label)
-    # If no post binning method, just return the label
-    if "post" not in specs["bins"]:
-        return label
-    # If there was no pre-binning, use default int encoding
-    if type(label) != str and "pre" not in specs["bins"]:
-        if label >= len(FRAME_TYPES):
-            return len(FRAME_TYPES)
-        label_name = FRAME_TYPES[label]
-    # Otherwise, get label name from pre-binning
-    else:
-        pre_bins = specs["bins"]["pre"].keys()
-        if label >= len(pre_bins):
-            return len(pre_bins)
-        label_name = list(pre_bins)[label]
-
-    for i, post_bin in enumerate(specs["bins"]["post"].values()):
-        if label_name in post_bin:
-            return i
-    return len(specs["bins"]["post"].keys())
+    return len(specs["bins"].keys())


 def load_config(config):
@@ -132,7 +102,7 @@ def get_net(in_dim, n_labels, num_layers, dropout=0.0):
 def prepare_datasets(indir, train_guids, validation_guids, configs):
     """
     Given a directory of pre-computed dense feature vectors,
-    prepare the training and validation datasets. The preparation incluses
+    prepare the training and validation datasets. The preparation includes
     1. positional encodings are applied.
     2. 'gold' labels are attached to each vector.
     3. split of vectors into training and validation sets (at video-level, meaning all frames from a video are either in training or validation set).
@@ -142,8 +112,8 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):
     train_labels = []
     valid_vectors = []
     valid_labels = []
-    if configs and 'bins' in configs and 'pre' in configs['bins']:
-        pre_bin_size = len(configs['bins']['pre'].keys()) + 1
+    if configs and 'bins' in configs:
+        pre_bin_size = len(configs['bins'].keys()) + 1
     else:
         pre_bin_size = len(FRAME_TYPES) + 1
     train_vimg = valid_vimg = 0
@@ -163,7 +133,7 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):
         total_video_len = labels['duration']
         for i, vec in enumerate(feature_vecs):
             if not labels['frames'][i]['mod']:  # "transitional" frames
-                pre_binned_label = pre_bin(labels['frames'][i]['label'], configs)
+                pre_binned_label = pretraining_bin(labels['frames'][i]['label'], configs)
                 vector = torch.from_numpy(vec)
                 position = labels['frames'][i]['curr_time']
                 vector = extractor.encode_position(position, total_video_len, vector)
@@ -213,11 +183,11 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y
         logger.info(f'Split {i}: training on {len(train_guids)} videos, validating on {validation_guids}')
         export_csv_file = f"{outdir}/{train_id}.kfold_{i:03d}.csv"
         export_model_file = f"{outdir}/{train_id}.kfold_{i:03d}.pt"
-        model, p, r, f = train_model(
+        model = train_model(
             get_net(train.feat_dim, labelset_size, configs['num_layers'], configs['dropouts']),
-            loss, device, train_loader, valid_loader, configs, labelset_size,
-            export_fname=export_csv_file)
+            loss, device, train_loader, configs)
         torch.save(model.state_dict(), export_model_file)
+        p, r, f = evaluate(model, valid_loader, pretraining_binned_label(configs), export_fname=export_csv_file)
         val_set_spec.append(validation_guids)
         p_scores.append(p)
         r_scores.append(r)
@@ -230,9 +200,6 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y


 def export_kfold_config(config_file: str, configs: dict, outfile: str):  #, train_id: str):
-    #config_path = Path(f"{outdir}", f"{train_id}.kfold_config.yml")
-    #print('>>>', config_path)
-    #config_path.parent.mkdir(parents=True, exist_ok=True)
     if config_file is None:
         configs_copy = copy.deepcopy(configs)
         with open(outfile, 'w') as fh:
@@ -261,34 +228,22 @@ def export_kfold_results(trial_specs, p_scores, r_scores, f_scores, p_results, *
     out.write(f'\trecall = {sum(r_scores) / len(r_scores)}\n')


-def pre_bin_label_names(config, raw_labels=None):
-    if 'pre' in config["bins"]:
-        return list(config["bins"]["pre"].keys()) + [modeling.negative_label]
-    elif raw_labels is not None:
-        return raw_labels + [modeling.negative_label]
-    else:
-        return []
+def pretraining_binned_label(config):
+    if 'bins' in config:
+        return list(config["bins"].keys()) + [modeling.negative_label]
+    return modeling.FRAME_TYPES + [modeling.negative_label]


 def post_bin_label_names(config):
     post_labels = list(config["bins"].get("post", {}).keys())
     if post_labels:
         return post_labels + [modeling.negative_label]
-    else:
-        return pre_bin_label_names(config)
-
+    return pretraining_binned_label(config)

-def get_final_label_names(config):
-    if config and "post" in config["bins"]:
-        return post_bin_label_names(config)
-    elif config and "pre" in config["bins"]:
-        return pre_bin_label_names(config)
-    else:
-        return FRAME_TYPES + [modeling.negative_label]
-
-
-def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_labels, export_fname=None):
-    since = time.time()
+
+def train_model(model, loss_fn, device, train_loader, configs):
+    since = time.perf_counter()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

     for num_epoch in tqdm(range(configs['num_epochs'])):
@@ -314,62 +269,13 @@ def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_l
            logger.debug(f'Loss: {loss.sum().item():.4f}')

         epoch_loss = running_loss / len(train_loader)
-
-        model.eval()
-        for vfeats, vlabels in valid_loader:
-            outputs = model(vfeats)
-            _, preds = torch.max(outputs, 1)
-            # post-binning
-            preds = torch.from_numpy(np.vectorize(post_bin)(preds, configs))
-            vlabels = torch.from_numpy(np.vectorize(post_bin)(vlabels, configs))
-        p = metrics.precision(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-        r = metrics.recall(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-        f = metrics.f1_score(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-        # m = metrics.confusion_matrix(preds, vlabels, 'multiclass', num_classes=n_labels)
-
-        final_classes = get_final_label_names(configs)
-
         logger.debug(f'Loss: {epoch_loss:.4f} after {num_epoch+1} epochs')
-    time_elapsed = time.time() - since
+    time_elapsed = time.perf_counter() - since
     logger.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

-    if not export_fname:
-        export_f = sys.stdout
-    else:
-        path = Path(export_fname)
-        path.parent.mkdir(parents=True, exist_ok=True)
-        export_f = open(path, 'w', encoding='utf8')
-    export_train_result(out=export_f, predictions=preds, labels=vlabels,
-                        labelset=final_classes, img_enc_name=train_loader.dataset.img_enc_name)
-    logger.info(f"Exported to {export_f.name}")
-
-    return model, p, r, f
-
-
-def export_train_result(out: IO, predictions: Tensor, labels: Tensor, labelset: List[str], img_enc_name: str):
-    """Exports the data into a human-readable format.
-    """
-
-    label_metrics = defaultdict(dict)
-
-    for i, label in enumerate(labelset):
-        pred_labels = torch.where(predictions == i, 1, 0)
-        true_labels = torch.where(labels == i, 1, 0)
-        binary_acc = BinaryAccuracy()
-        binary_prec = BinaryPrecision()
-        binary_recall = BinaryRecall()
-        binary_f1 = BinaryF1Score()
-        label_metrics[label] = {"Model_Name": img_enc_name,
-                                "Label": label,
-                                "Accuracy": binary_acc(pred_labels, true_labels).item(),
-                                "Precision": binary_prec(pred_labels, true_labels).item(),
-                                "Recall": binary_recall(pred_labels, true_labels).item(),
-                                "F1-Score": binary_f1(pred_labels, true_labels).item()}
-
-    writer = csv.DictWriter(out, fieldnames=["Model_Name", "Label", "Accuracy", "Precision", "Recall", "F1-Score"])
-    writer.writeheader()
-    for label, metrics in label_metrics.items():
-        writer.writerow(metrics)
-
+    model.eval()
+    return model


 if __name__ == "__main__":
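Taken together, train_model() now only optimizes and returns the model, while scoring lives in modeling.evaluate. Below is a runnable toy sketch (not from the repo) of the new division of labor, assuming the repo layout above; the demo network, random tensors, and label names are invented for illustration:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from modeling.evaluate import evaluate

class DemoSet(TensorDataset):
    img_enc_name = "demo_enc"  # evaluate() reads this attribute off the dataset

feats = torch.randn(64, 8)                          # 64 frames, 8-dim features
golds = torch.randint(0, 4, (64,))                  # 4 bins incl. negative
train_loader = DataLoader(DemoSet(feats, golds), batch_size=16)
valid_loader = DataLoader(DemoSet(feats, golds), batch_size=64)  # one batch, as evaluate() expects

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(5):                              # stands in for train_model()
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
model.eval()

torch.save(model.state_dict(), "demo_model.pt")     # persist, then score separately
p, r, f = evaluate(model, valid_loader, ["slate", "chyron", "credit", "NEG"])
print(f"macro P={p.item():.3f} R={r.item():.3f} F1={f.item():.3f}")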
