From eb966a27035986fb6170bc6c669424c15ab8b07c Mon Sep 17 00:00:00 2001
From: Keigh Rim
Date: Thu, 14 Mar 2024 23:07:57 -0400
Subject: [PATCH 1/4] moved some developer scripts to `scripts` directory,
 minor cleanups

---
 app.py                                             |   6 ------
 modeling/stitch.py                                 |   2 --
 modeling/evaluate.py => scripts/dev.evaluate.py    |   0
 modeling/visualize.py => scripts/dev.visualize.py  |   0
 modeling/show-results.py => scripts/see_results.py |   0
 5 files changed, 8 deletions(-)
 rename modeling/evaluate.py => scripts/dev.evaluate.py (100%)
 rename modeling/visualize.py => scripts/dev.visualize.py (100%)
 rename modeling/show-results.py => scripts/see_results.py (100%)

diff --git a/app.py b/app.py
index bddc739..6172e93 100644
--- a/app.py
+++ b/app.py
@@ -133,12 +133,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         if not configs.get('useStitcher'):
             return mmif
 
-        labelset = self.classifier.postbin_labels
-        bins = self.classifier.model_config['bins']
-        new_view.new_contain(
-            AnnotationTypes.TimePoint,
-            document=vd.id, timeUnit='milliseconds', labelset=labelset)
-
         timeframes = self.stitcher.create_timeframes(predictions)
         for tf in timeframes:
             timeframe_annotation = new_view.new_annotation(AnnotationTypes.TimeFrame)
diff --git a/modeling/stitch.py b/modeling/stitch.py
index c0f074c..808554b 100644
--- a/modeling/stitch.py
+++ b/modeling/stitch.py
@@ -56,8 +56,6 @@ def collect_timeframes(self, predictions: list) -> list:
         """Find sequences of frames for all labels where the score of each
         frame is at least the minimum value as defined in self.min_frame_score."""
         labels = self.postbin_labels if self.use_postbinning else self.prebin_labels
-        if self.use_postbinning:
-            postbins = self.model_config['bins']['post']
         if self.debug:
             print('>>> labels', labels)
         timeframes = []
diff --git a/modeling/evaluate.py b/scripts/dev.evaluate.py
similarity index 100%
rename from modeling/evaluate.py
rename to scripts/dev.evaluate.py
diff --git a/modeling/visualize.py b/scripts/dev.visualize.py
similarity index 100%
rename from modeling/visualize.py
rename to scripts/dev.visualize.py
diff --git a/modeling/show-results.py b/scripts/see_results.py
similarity index 100%
rename from modeling/show-results.py
rename to scripts/see_results.py

From b4651c477653c7e4038eb8492e1aa9aaa561fa41 Mon Sep 17 00:00:00 2001
From: Keigh Rim
Date: Thu, 14 Mar 2024 23:52:01 -0400
Subject: [PATCH 2/4] removed "postbin" from training code, evaluation is
 separated and done w/o postbins

---
 modeling/classify.py        |   2 +-
 modeling/config/trainer.yml |  17 ++---
 modeling/evaluate.py        |  60 +++++++++++++++
 modeling/stitch.py          |   6 +-
 modeling/train.py           | 142 ++++++------------------------
 5 files changed, 97 insertions(+), 130 deletions(-)
 create mode 100644 modeling/evaluate.py

diff --git a/modeling/classify.py b/modeling/classify.py
index 6b2b5a6..dceb494 100644
--- a/modeling/classify.py
+++ b/modeling/classify.py
@@ -43,7 +43,7 @@ class Classifier:
     def __init__(self, **config):
         self.config = config
         self.model_config = yaml.safe_load(open(config["model_config_file"]))
-        self.prebin_labels = train.pre_bin_label_names(self.model_config, FRAME_TYPES)
+        self.prebin_labels = train.pretraining_binned_label(self.model_config)
         self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.featurizer = data_loader.FeatureExtractor(
             img_enc_name=self.model_config["img_enc_name"],
diff --git a/modeling/config/trainer.yml b/modeling/config/trainer.yml
index 4e8953e..bca6c97 100644
--- a/modeling/config/trainer.yml
+++ b/modeling/config/trainer.yml
@@ -41,12 +41,11 @@ pos_enc_dim: 512
 max_input_length: 5640000
 
 bins:
-  pre:
-    slate:
-      - "S"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
+  slate:
+    - "S"
+  chyron:
+    - "I"
+    - "N"
+    - "Y"
+  credit:
+    - "C"
diff --git a/modeling/evaluate.py b/modeling/evaluate.py
new file mode 100644
index 0000000..f88c92a
--- /dev/null
+++ b/modeling/evaluate.py
@@ -0,0 +1,60 @@
+import csv
+import logging
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import IO, List
+
+import torch
+from torch import Tensor
+from torchmetrics import functional as metrics
+from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
+
+
+def evaluate(model, valid_loader, labelset, export_fname=None):
+    model.eval()
+    # valid_loader is currently expected to be a single batch
+    vfeats, vlabels = next(iter(valid_loader))
+    outputs = model(vfeats)
+    _, preds = torch.max(outputs, 1)
+    p = metrics.precision(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    r = metrics.recall(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    f = metrics.f1_score(preds, vlabels, 'multiclass', num_classes=len(labelset), average='macro')
+    # m = metrics.confusion_matrix(preds, vlabels, 'multiclass', num_classes=len(labelset))
+
+    if not export_fname:
+        export_f = sys.stdout
+    else:
+        path = Path(export_fname)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        export_f = open(path, 'w', encoding='utf8')
+    export_train_result(out=export_f, preds=preds, golds=vlabels,
+                        labelset=labelset, img_enc_name=valid_loader.dataset.img_enc_name)
+    logging.info(f"Exported to {export_f.name}")
+    return p, r, f
+
+
+def export_train_result(out: IO, preds: Tensor, golds: Tensor, labelset: List[str], img_enc_name: str):
+    """Exports the data into a human-readable format.
+    """
+
+    label_metrics = defaultdict(dict)
+
+    for i, label in enumerate(labelset):
+        pred_labels = torch.where(preds == i, 1, 0)
+        true_labels = torch.where(golds == i, 1, 0)
+        binary_acc = BinaryAccuracy()
+        binary_prec = BinaryPrecision()
+        binary_recall = BinaryRecall()
+        binary_f1 = BinaryF1Score()
+        label_metrics[label] = {"Model_Name": img_enc_name,
+                                "Label": label,
+                                "Accuracy": binary_acc(pred_labels, true_labels).item(),
+                                "Precision": binary_prec(pred_labels, true_labels).item(),
+                                "Recall": binary_recall(pred_labels, true_labels).item(),
+                                "F1-Score": binary_f1(pred_labels, true_labels).item()}
+
+    writer = csv.DictWriter(out, fieldnames=["Model_Name", "Label", "Accuracy", "Precision", "Recall", "F1-Score"])
+    writer.writeheader()
+    for label, metrics in label_metrics.items():
+        writer.writerow(metrics)
diff --git a/modeling/stitch.py b/modeling/stitch.py
index 808554b..913769c 100644
--- a/modeling/stitch.py
+++ b/modeling/stitch.py
@@ -10,8 +10,10 @@
 import operator
+
 import yaml
-from modeling import train, negative_label, FRAME_TYPES
+
+from modeling import train, negative_label
 
 
 class Stitcher:
@@ -24,7 +26,7 @@ def __init__(self, **config):
         self.min_timeframe_score = config.get("minTimeframeScore")
         self.min_frame_count = config.get("minFrameCount")
         self.static_frames = self.config.get("staticFrames")
-        self.prebin_labels = train.pre_bin_label_names(self.model_config, FRAME_TYPES)
+        self.prebin_labels = train.pretraining_binned_label(self.model_config)
         self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.use_postbinning = "post" in self.model_config["bins"]
         self.debug = False
diff --git a/modeling/train.py b/modeling/train.py
index eff0d40..5ff6f75 100644
--- a/modeling/train.py
+++ b/modeling/train.py
@@ -1,28 +1,22 @@
 import argparse
-import csv
 import json
 import logging
 import platform
 import shutil
-import sys
 import time
-from collections import defaultdict
 from pathlib import Path
-from typing import List, IO
 
 import copy
 import numpy as np
 import torch
 import torch.nn as nn
 import yaml
-from torch import Tensor
 from torch.utils.data import Dataset, DataLoader
-from torchmetrics import functional as metrics
-from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
 from tqdm import tqdm
 
 import modeling
 from modeling import data_loader, FRAME_TYPES
+from modeling.evaluate import evaluate
 
 logging.basicConfig(
     level=logging.WARNING,
@@ -59,37 +53,13 @@ def get_guids(data_dir):
     return sorted(guids)
 
 
-def pre_bin(label, specs):
-    if specs is None or "pre" not in specs["bins"]:
+def pretraining_bin(label, specs):
+    if specs is None or "bins" not in specs:
         return int_encode(label)
-    for i, bin in enumerate(specs["bins"]["pre"].values()):
-        if label and label in bin:
+    for i, ptbin in enumerate(specs["bins"].values()):
+        if label and label in ptbin:
             return i
-    return len(specs["bins"]["pre"].keys())
-
-
-def post_bin(label, specs):
-    if specs is None:
-        return int_encode(label)
-    # If no post binning method, just return the label
-    if "post" not in specs["bins"]:
-        return label
-    # If there was no pre-binning, use default int encoding
-    if type(label) != str and "pre" not in specs["bins"]:
-        if label >= len(FRAME_TYPES):
-            return len(FRAME_TYPES)
-        label_name = FRAME_TYPES[label]
-    # Otherwise, get label name from pre-binning
-    else:
-        pre_bins = specs["bins"]["pre"].keys()
-        if label >= len(pre_bins):
-            return len(pre_bins)
-        label_name = list(pre_bins)[label]
-
-    for i, post_bin in enumerate(specs["bins"]["post"].values()):
-        if label_name in post_bin:
-            return i
-    return len(specs["bins"]["post"].keys())
+    return len(specs["bins"].keys())
 
 
 def load_config(config):
@@ -132,7 +102,7 @@ def get_net(in_dim, n_labels, num_layers, dropout=0.0):
 
 def prepare_datasets(indir, train_guids, validation_guids, configs):
     """ Given a directory of pre-computed dense feature vectors,
-    prepare the training and validation datasets. The preparation incluses
+    prepare the training and validation datasets. The preparation includes
    1. positional encodings are applied.
    2. 'gold' labels are attached to each vector.
    3. split of vectors into training and validation sets (at video-level, meaning all frames from a video are either in training or validation set).
@@ -142,8 +112,8 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):
     train_labels = []
     valid_vectors = []
     valid_labels = []
-    if configs and 'bins' in configs and 'pre' in configs['bins']:
-        pre_bin_size = len(configs['bins']['pre'].keys()) + 1
+    if configs and 'bins' in configs:
+        pre_bin_size = len(configs['bins'].keys()) + 1
     else:
         pre_bin_size = len(FRAME_TYPES) + 1
     train_vimg = valid_vimg = 0
@@ -163,7 +133,7 @@ def prepare_datasets(indir, train_guids, validation_guids, configs):
         total_video_len = labels['duration']
         for i, vec in enumerate(feature_vecs):
             if not labels['frames'][i]['mod']:  # "transitional" frames
-                pre_binned_label = pre_bin(labels['frames'][i]['label'], configs)
+                pre_binned_label = pretraining_bin(labels['frames'][i]['label'], configs)
                 vector = torch.from_numpy(vec)
                 position = labels['frames'][i]['curr_time']
                 vector = extractor.encode_position(position, total_video_len, vector)
@@ -213,11 +183,11 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y
         logger.info(f'Split {i}: training on {len(train_guids)} videos, validating on {validation_guids}')
         export_csv_file = f"{outdir}/{train_id}.kfold_{i:03d}.csv"
         export_model_file = f"{outdir}/{train_id}.kfold_{i:03d}.pt"
-        model, p, r, f = train_model(
+        model = train_model(
             get_net(train.feat_dim, labelset_size, configs['num_layers'], configs['dropouts']),
-            loss, device, train_loader, valid_loader, configs, labelset_size,
-            export_fname=export_csv_file)
+            loss, device, train_loader, configs)
         torch.save(model.state_dict(), export_model_file)
+        p, r, f = evaluate(model, valid_loader, pretraining_binned_label(configs), export_fname=export_csv_file)
         val_set_spec.append(validation_guids)
         p_scores.append(p)
         r_scores.append(r)
@@ -230,9 +200,6 @@ def k_fold_train(indir, outdir, config_file, configs, train_id=time.strftime("%Y
 
 
 def export_kfold_config(config_file: str, configs: dict, outfile: str):#, train_id: str):
-    #config_path = Path(f"{outdir}", f"{train_id}.kfold_config.yml")
-    #print('>>>', config_path)
-    #config_path.parent.mkdir(parents=True, exist_ok=True)
     if config_file is None:
         configs_copy = copy.deepcopy(configs)
         with open(outfile, 'w') as fh:
@@ -261,13 +228,10 @@ def export_kfold_results(trial_specs, p_scores, r_scores, f_scores, p_results, *
         out.write(f'\trecall = {sum(r_scores) / len(r_scores)}\n')
 
 
-def pre_bin_label_names(config, raw_labels=None):
-    if 'pre' in config["bins"]:
-        return list(config["bins"]["pre"].keys()) + [modeling.negative_label]
-    elif raw_labels is not None:
-        return raw_labels + [modeling.negative_label]
-    else:
-        return []
+def pretraining_binned_label(config):
+    if 'bins' in config:
+        return list(config["bins"].keys()) + [modeling.negative_label]
+    return modeling.FRAME_TYPES + [modeling.negative_label]
 
 
 def post_bin_label_names(config):
@@ -275,20 +239,11 @@ def post_bin_label_names(config):
     post_labels = list(config["bins"].get("post", {}).keys())
     if post_labels:
         return post_labels + [modeling.negative_label]
     else:
-        return pre_bin_label_names(config)
-
+        return pretraining_binned_label(config)
 
-def get_final_label_names(config):
-    if config and "post" in config["bins"]:
-        return post_bin_label_names(config)
-    elif config and "pre" in config["bins"]:
-        return pre_bin_label_names(config)
-    else:
-        return FRAME_TYPES + [modeling.negative_label]
-
-def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_labels, export_fname=None):
-    since = time.time()
+def train_model(model, loss_fn, device, train_loader, configs):
+    since = time.perf_counter()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 
     for num_epoch in tqdm(range(configs['num_epochs'])):
@@ -314,62 +269,13 @@ def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_l
             logger.debug(f'Loss: {loss.sum().item():.4f}')
 
         epoch_loss = running_loss / len(train_loader)
-
-        model.eval()
-        for vfeats, vlabels in valid_loader:
-            outputs = model(vfeats)
-            _, preds = torch.max(outputs, 1)
-            # post-binning
-            preds = torch.from_numpy(np.vectorize(post_bin)(preds, configs))
-            vlabels = torch.from_numpy(np.vectorize(post_bin)(vlabels, configs))
-            p = metrics.precision(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-            r = metrics.recall(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-            f = metrics.f1_score(preds, vlabels, 'multiclass', num_classes=n_labels, average='macro')
-            # m = metrics.confusion_matrix(preds, vlabels, 'multiclass', num_classes=n_labels)
-
-        final_classes = get_final_label_names(configs)
         logger.debug(f'Loss: {epoch_loss:.4f} after {num_epoch+1} epochs')
 
-    time_elapsed = time.time() - since
+    time_elapsed = time.perf_counter() - since
     logger.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
-
-    if not export_fname:
-        export_f = sys.stdout
-    else:
-        path = Path(export_fname)
-        path.parent.mkdir(parents=True, exist_ok=True)
-        export_f = open(path, 'w', encoding='utf8')
-    export_train_result(out=export_f, predictions=preds, labels=vlabels,
-                        labelset=final_classes, img_enc_name=train_loader.dataset.img_enc_name)
-    logger.info(f"Exported to {export_f.name}")
-
-    return model, p, r, f
-
-
-def export_train_result(out: IO, predictions: Tensor, labels: Tensor, labelset: List[str], img_enc_name: str):
-    """Exports the data into a human-readable format.
- """ - - label_metrics = defaultdict(dict) - - for i, label in enumerate(labelset): - pred_labels = torch.where(predictions == i, 1, 0) - true_labels = torch.where(labels == i, 1, 0) - binary_acc = BinaryAccuracy() - binary_prec = BinaryPrecision() - binary_recall = BinaryRecall() - binary_f1 = BinaryF1Score() - label_metrics[label] = {"Model_Name": img_enc_name, - "Label": label, - "Accuracy": binary_acc(pred_labels, true_labels).item(), - "Precision": binary_prec(pred_labels, true_labels).item(), - "Recall": binary_recall(pred_labels, true_labels).item(), - "F1-Score": binary_f1(pred_labels, true_labels).item()} - - writer = csv.DictWriter(out, fieldnames=["Model_Name", "Label", "Accuracy", "Precision", "Recall", "F1-Score"]) - writer.writeheader() - for label, metrics in label_metrics.items(): - writer.writerow(metrics) + + model.eval() + return model if __name__ == "__main__": From 5925f029c6f13446a78e7144ae30f146354186a2 Mon Sep 17 00:00:00 2001 From: Keigh Rim Date: Thu, 14 Mar 2024 23:44:11 -0400 Subject: [PATCH 3/4] got rid of "postbin" from trainer configuration file (temporarily placed under classifier config yaml) --- app.py | 19 ++--- modeling/classify.py | 5 +- modeling/config/classifier.yml | 75 ++++++++++++++++++- ...240126-180026.convnext_lg.kfold_config.yml | 16 ---- ...0212-131937.convnext_tiny.kfold_config.yml | 29 ------- ...240212-132306.convnext_lg.kfold_config.yml | 29 ------- modeling/stitch.py | 16 ++-- modeling/train.py | 8 -- 8 files changed, 88 insertions(+), 109 deletions(-) diff --git a/app.py b/app.py index 6172e93..6af470e 100644 --- a/app.py +++ b/app.py @@ -76,6 +76,9 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: configs['model_file'] = default_model_storage / f'{parameters["modelName"]}.pt' # model files from k-fold training have the fold number as three-digit suffix, trim it configs['model_config_file'] = default_model_storage / f'{parameters["modelName"][:-4]}_config.yml' + # TODO (krim @ 2024-03-14): make this into a runtime parameter once + # https://github.com/clamsproject/clams-python/issues/197 is resolved + configs['postbin'] = configs['postbins'].get(parameters['modelName'], None) t = time.perf_counter() self.logger.info(f"Initiating classifier with {configs['model_file']}") if self.logger.isEnabledFor(logging.DEBUG): @@ -117,8 +120,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug(f"Processing took {time.perf_counter() - t} seconds") - new_view.new_contain(AnnotationTypes.TimePoint, - document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels) + new_view.new_contain(AnnotationTypes.TimePoint, + document=vd.id, timeUnit='milliseconds', labelset=self.stitcher.stitch_label) for prediction in predictions: timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint) @@ -143,18 +146,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: [p.annotation.id for p in tf.representative_predictions()]) return mmif - @staticmethod - def _transform(classification: dict, bins: dict): - """Take the raw classification and turn it into a classification of user - labels. 
-        labels. Also includes modeling.negative_label."""
-        # TODO: this may not work when there is pre-binning
-        transformed = {}
-        for postlabel in bins['post'].keys():
-            score = sum([classification[lbl] for lbl in bins['post'][postlabel]])
-            transformed[postlabel] = score
-        transformed[negative_label] = 1 - sum(transformed.values())
-        return transformed
 
 
 if __name__ == "__main__":
diff --git a/modeling/classify.py b/modeling/classify.py
index dceb494..467040d 100644
--- a/modeling/classify.py
+++ b/modeling/classify.py
@@ -44,7 +44,6 @@ def __init__(self, **config):
         self.config = config
         self.model_config = yaml.safe_load(open(config["model_config_file"]))
         self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.featurizer = data_loader.FeatureExtractor(
             img_enc_name=self.model_config["img_enc_name"],
             pos_enc_name=self.model_config.get("pos_enc_name", None),
@@ -52,8 +51,8 @@ def __init__(self, **config):
             max_input_length=self.model_config.get("max_input_length", 0),
             pos_unit=self.model_config.get("pos_unit", 0))
         label_count = len(FRAME_TYPES) + 1
-        if 'pre' in self.model_config['bins']:
-            label_count = len(self.model_config['bins']['pre'].keys()) + 1
+        if 'bins' in self.model_config:
+            label_count = len(self.model_config['bins'].keys()) + 1
         self.classifier = train.get_net(
             in_dim=self.featurizer.feature_vector_dim(),
             n_labels=label_count,
diff --git a/modeling/config/classifier.yml b/modeling/config/classifier.yml
index 7e9e97c..2575e08 100644
--- a/modeling/config/classifier.yml
+++ b/modeling/config/classifier.yml
@@ -18,4 +18,77 @@ minFrameCount: 2
 staticFrames: [bars, slate, chyron]
 
 # Set to False to turn off the stitcher
-useStitcher: True
\ No newline at end of file
+useStitcher: True
+
+postbins:
+  20240126-180026.convnext_lg.kfold_000:
+    bars:
+      - B
+    slate:
+      - S
+      - S:H
+      - S:C
+      - S:D
+      - S:G
+    chyron:
+      - I
+      - N
+      - Y
+    credits:
+      - C
+  20240212-131937.convnext_tiny.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+  20240212-132306.convnext_lg.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+
+
\ No newline at end of file
diff --git a/modeling/models/20240126-180026.convnext_lg.kfold_config.yml b/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
index 4f6b1cc..9b6fd19 100644
--- a/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
+++ b/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
@@ -31,19 +31,3 @@ block_guids_valid:
   - cpb-aacip-512-4b2x34nt7g
   - cpb-aacip-512-3n20c4tr34
   - cpb-aacip-512-3f4kk9534t
-bins:
-  post:
-    bars:
-      - B
-    slate:
-      - S
-      - S:H
-      - S:C
-      - S:D
-      - S:G
-    chyron:
-      - I
-      - N
-      - Y
-    credits:
-      - C
diff --git a/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml b/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
index 892fff0..b3abb1d 100644
--- a/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
+++ b/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
diff --git a/modeling/models/20240212-132306.convnext_lg.kfold_config.yml b/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
index 40724e6..bb46562 100644
--- a/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
+++ b/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
diff --git a/modeling/stitch.py b/modeling/stitch.py
index 913769c..2c3e27c 100644
--- a/modeling/stitch.py
+++ b/modeling/stitch.py
@@ -26,9 +26,8 @@ def __init__(self, **config):
         self.min_timeframe_score = config.get("minTimeframeScore")
         self.min_frame_count = config.get("minFrameCount")
         self.static_frames = self.config.get("staticFrames")
-        self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
-        self.use_postbinning = "post" in self.model_config["bins"]
+        self.model_label = train.pretraining_binned_label(self.model_config)
+        self.stitch_label = config.get("postbin")
         self.debug = False
 
     def __str__(self):
@@ -38,8 +37,8 @@
 
     def create_timeframes(self, predictions: list) -> list:
         if self.debug:
-            print('pre-bin labels', self.prebin_labels)
-            print('post-bin labels', self.postbin_labels)
+            print('TimePoint labels', self.model_label)
+            print('TimeFrame labels', list(self.stitch_label.keys()))
         timeframes = self.collect_timeframes(predictions)
         if self.debug:
             print_timeframes('Collected frames', timeframes)
@@ -57,7 +56,7 @@ def collect_timeframes(self, predictions: list) -> list:
         """Find sequences of frames for all labels where the score of each
         frame is at least the minimum value as defined in self.min_frame_score."""
-        labels = self.postbin_labels if self.use_postbinning else self.prebin_labels
+        labels = self.stitch_label if self.stitch_label is not None else self.model_label
         if self.debug:
             print('>>> labels', labels)
         timeframes = []
@@ -112,11 +111,10 @@ def is_included(self, frame, outlawed_timepoints: set) -> bool:
 
     def _score_for_label(self, label: str, prediction):
         """Return the score for the label, this is somewhat more complicated
         when postbinning is used."""
-        if not self.use_postbinning:
+        if self.stitch_label is None:
             return prediction.score_for_label(label)
         else:
-            postbins = self.model_config['bins']['post']
-            return prediction.score_for_labels(postbins[label])
+            return prediction.score_for_labels(self.stitch_label[label])
 
 
 class TimeFrame:
diff --git a/modeling/train.py b/modeling/train.py
index 5ff6f75..06d6b7e 100644
--- a/modeling/train.py
+++ b/modeling/train.py
@@ -234,14 +234,6 @@ def pretraining_binned_label(config):
     return modeling.FRAME_TYPES + [modeling.negative_label]
 
 
-def post_bin_label_names(config):
-    post_labels = list(config["bins"].get("post", {}).keys())
-    if post_labels:
-        return post_labels + [modeling.negative_label]
-    else:
-        return pretraining_binned_label(config)
-
-
 def train_model(model, loss_fn, device, train_loader, configs):
     since = time.perf_counter()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

From 667bbbea2bb22f4bb0869dcb392deaca141f0680 Mon Sep 17 00:00:00 2001
From: Keigh Rim
Date: Sat, 16 Mar 2024 17:16:51 -0400
Subject: [PATCH 4/4] fixes for wrong property names and values

---
 app.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/app.py b/app.py
index 6af470e..49fa0f1 100644
--- a/app.py
+++ b/app.py
@@ -121,7 +121,7 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
             self.logger.debug(f"Processing took {time.perf_counter() - t} seconds")
 
         new_view.new_contain(AnnotationTypes.TimePoint,
-                             document=vd.id, timeUnit='milliseconds', labelset=self.stitcher.stitch_label)
+                             document=vd.id, timeUnit='milliseconds', labelset=FRAME_TYPES + [negative_label])
 
         for prediction in predictions:
             timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
@@ -136,6 +136,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         if not configs.get('useStitcher'):
             return mmif
 
+        new_view.new_contain(AnnotationTypes.TimeFrame,
+                             document=vd.id, timeUnit='milliseconds', labelset=list(self.stitcher.stitch_label.keys()))
         timeframes = self.stitcher.create_timeframes(predictions)
         for tf in timeframes:
             timeframe_annotation = new_view.new_annotation(AnnotationTypes.TimeFrame)
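
Note: to make the net effect of the series concrete, below is a minimal, standalone Python sketch (not part of the patches) of the post-binning arithmetic these commits move out of training and into the stitcher. The postbin dict and score values are illustrative assumptions taken from the classifier.yml shapes above; score_for_labels here is a stand-in for the Prediction method that the patched Stitcher._score_for_label calls.

    # coarse TimeFrame label -> raw model (prebin) labels, as in classifier.yml postbins
    postbin = {
        "slate": ["S", "S:H", "S:C", "S:D", "S:G"],
        "chyron": ["I", "N", "Y"],
    }

    def score_for_labels(classification: dict, raw_labels: list) -> float:
        # sum the raw-label scores that fall into one postbin, mirroring
        # prediction.score_for_labels(self.stitch_label[label]) in stitch.py
        return sum(classification.get(lbl, 0.0) for lbl in raw_labels)

    # one TimePoint's classification over the model's raw labels (made-up scores)
    classification = {"S": 0.55, "S:H": 0.2, "I": 0.1, "N": 0.1, "Y": 0.05}
    for coarse, members in postbin.items():
        print(coarse, score_for_labels(classification, members))
    # -> slate 0.75
    #    chyron 0.25

Because the mapping now lives in the classifier config keyed by model name (configs['postbins'].get(parameters['modelName'])), a model trained without any binning can still be stitched into coarse TimeFrames at runtime, which is what lets PATCH 2 delete post_bin from the training code entirely.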