diff --git a/app.py b/app.py
index 6172e93..6af470e 100644
--- a/app.py
+++ b/app.py
@@ -76,6 +76,9 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         configs['model_file'] = default_model_storage / f'{parameters["modelName"]}.pt'
         # model files from k-fold training have the fold number as three-digit suffix, trim it
         configs['model_config_file'] = default_model_storage / f'{parameters["modelName"][:-4]}_config.yml'
+        # TODO (krim @ 2024-03-14): make this into a runtime parameter once
+        # https://github.com/clamsproject/clams-python/issues/197 is resolved
+        configs['postbin'] = configs['postbins'].get(parameters['modelName'], None)
         t = time.perf_counter()
         self.logger.info(f"Initiating classifier with {configs['model_file']}")
         if self.logger.isEnabledFor(logging.DEBUG):
@@ -117,8 +120,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         if self.logger.isEnabledFor(logging.DEBUG):
             self.logger.debug(f"Processing took {time.perf_counter() - t} seconds")
 
-        new_view.new_contain(AnnotationTypes.TimePoint,
-                             document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels)
+        new_view.new_contain(AnnotationTypes.TimePoint,
+                             document=vd.id, timeUnit='milliseconds', labelset=self.stitcher.stitch_label)
 
         for prediction in predictions:
             timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
@@ -143,18 +146,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
                                              [p.annotation.id for p in tf.representative_predictions()])
         return mmif
 
-    @staticmethod
-    def _transform(classification: dict, bins: dict):
-        """Take the raw classification and turn it into a classification of user
-        labels. Also includes modeling.negative_label."""
-        # TODO: this may not work when there is pre-binning
-        transformed = {}
-        for postlabel in bins['post'].keys():
-            score = sum([classification[lbl] for lbl in bins['post'][postlabel]])
-            transformed[postlabel] = score
-        transformed[negative_label] = 1 - sum(transformed.values())
-        return transformed
-
 
 
 if __name__ == "__main__":
diff --git a/modeling/classify.py b/modeling/classify.py
index dceb494..467040e 100644
--- a/modeling/classify.py
+++ b/modeling/classify.py
@@ -44,7 +44,6 @@ def __init__(self, **config):
         self.config = config
         self.model_config = yaml.safe_load(open(config["model_config_file"]))
         self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.featurizer = data_loader.FeatureExtractor(
             img_enc_name=self.model_config["img_enc_name"],
             pos_enc_name=self.model_config.get("pos_enc_name", None),
@@ -52,8 +51,8 @@ def __init__(self, **config):
             max_input_length=self.model_config.get("max_input_length", 0),
             pos_unit=self.model_config.get("pos_unit", 0))
         label_count = len(FRAME_TYPES) + 1
-        if 'pre' in self.model_config['bins']:
-            label_count = len(self.model_config['bins']['pre'].keys()) + 1
+        if 'bins' in self.model_config:
+            label_count = len(self.model_config['bins']['pre'].keys()) + 1
         self.classifier = train.get_net(
             in_dim=self.featurizer.feature_vector_dim(),
             n_labels=label_count,
diff --git a/modeling/config/classifier.yml b/modeling/config/classifier.yml
index 7e9e97c..2575e08 100644
--- a/modeling/config/classifier.yml
+++ b/modeling/config/classifier.yml
@@ -18,4 +18,77 @@ minFrameCount: 2
 staticFrames: [bars, slate, chyron]
 
 # Set to False to turn off the stitcher
-useStitcher: True
\ No newline at end of file
+useStitcher: True
+
+postbins:
+  20240126-180026.convnext_lg.kfold_000:
+    bars:
+      - B
+    slate:
+      - S
+      - S:H
+      - S:C
+      - S:D
+      - S:G
+    chyron:
+      - I
+      - N
+      - Y
+    credits:
+      - C
+  20240212-131937.convnext_tiny.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+  20240212-132306.convnext_lg.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+
+
\ No newline at end of file
diff --git a/modeling/models/20240126-180026.convnext_lg.kfold_config.yml b/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
index 4f6b1cc..9b6fd19 100644
--- a/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
+++ b/modeling/models/20240126-180026.convnext_lg.kfold_config.yml
@@ -31,19 +31,3 @@ block_guids_valid:
 - cpb-aacip-512-4b2x34nt7g
 - cpb-aacip-512-3n20c4tr34
 - cpb-aacip-512-3f4kk9534t
-bins:
-  post:
-    bars:
-      - B
-    slate:
-      - S
-      - S:H
-      - S:C
-      - S:D
-      - S:G
-    chyron:
-      - I
-      - N
-      - Y
-    credits:
-      - C
diff --git a/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml b/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
index 892fff0..b3abb1d 100644
--- a/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
+++ b/modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
diff --git a/modeling/models/20240212-132306.convnext_lg.kfold_config.yml b/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
index 40724e6..bb46562 100644
--- a/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
+++ b/modeling/models/20240212-132306.convnext_lg.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
diff --git a/modeling/stitch.py b/modeling/stitch.py
index 913769c..2c3e27c 100644
--- a/modeling/stitch.py
+++ b/modeling/stitch.py
@@ -26,9 +26,8 @@ def __init__(self, **config):
         self.min_timeframe_score = config.get("minTimeframeScore")
         self.min_frame_count = config.get("minFrameCount")
         self.static_frames = self.config.get("staticFrames")
-        self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
-        self.use_postbinning = "post" in self.model_config["bins"]
+        self.model_label = train.pretraining_binned_label(self.model_config)
+        self.stitch_label = config.get("postbin")
         self.debug = False
 
     def __str__(self):
@@ -38,8 +37,8 @@ def __str__(self):
 
     def create_timeframes(self, predictions: list) -> list:
         if self.debug:
-            print('pre-bin labels', self.prebin_labels)
-            print('post-bin labels', self.postbin_labels)
+            print('TimePoint labels', self.model_label)
+            print('TimeFrame labels', list(self.stitch_label.keys()) if self.stitch_label else self.model_label)
         timeframes = self.collect_timeframes(predictions)
         if self.debug:
             print_timeframes('Collected frames', timeframes)
@@ -57,7 +56,7 @@ def create_timeframes(self, predictions: list) -> list:
     def collect_timeframes(self, predictions: list) -> list:
         """Find sequences of frames for all labels where the score of each frame
         is at least the mininum value as defined in self.min_frame_score."""
-        labels = self.postbin_labels if self.use_postbinning else self.prebin_labels
+        labels = self.stitch_label if self.stitch_label is not None else self.model_label
         if self.debug:
             print('>>> labels', labels)
         timeframes = []
@@ -112,11 +111,10 @@ def is_included(self, frame, outlawed_timepoints: set) -> bool:
     def _score_for_label(self, label: str, prediction):
         """Return the score for the label, this is somewhat more complicated
         when postbinning is used."""
-        if not self.use_postbinning:
+        if self.stitch_label is None:
             return prediction.score_for_label(label)
         else:
-            postbins = self.model_config['bins']['post']
-            return prediction.score_for_labels(postbins[label])
+            return prediction.score_for_labels(self.stitch_label[label])
 
 
 class TimeFrame:
diff --git a/modeling/train.py b/modeling/train.py
index 5ff6f75..06d6b7e 100644
--- a/modeling/train.py
+++ b/modeling/train.py
@@ -234,14 +234,6 @@ def pretraining_binned_label(config):
     return modeling.FRAME_TYPES + [modeling.negative_label]
 
 
-def post_bin_label_names(config):
-    post_labels = list(config["bins"].get("post", {}).keys())
-    if post_labels:
-        return post_labels + [modeling.negative_label]
-    else:
-        return pretraining_binned_label(config)
-
-
 def train_model(model, loss_fn, device, train_loader, configs):
     since = time.perf_counter()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
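# --- Illustrative sketch, not part of the patch -------------------------------
# A minimal, hypothetical example of how the new per-model `postbins` section in
# modeling/config/classifier.yml is looked up at runtime and how raw model-label
# scores are collapsed into post-bin (TimeFrame) labels. It mirrors the lookup
# added in app.py (configs['postbins'].get(parameters['modelName'], None)) and
# the aggregation in Stitcher._score_for_label, under the assumption that
# Prediction.score_for_labels sums the scores of the given raw labels. The
# helper names below are illustrative and do not exist in the app code.
from typing import Optional

import yaml


def resolve_postbin(classifier_config: str, model_name: str) -> Optional[dict]:
    """Return the {timeframe_label: [raw model labels]} map for a model, or None."""
    with open(classifier_config) as f:
        config = yaml.safe_load(f)
    return config.get('postbins', {}).get(model_name)


def postbin_scores(raw_scores: dict, postbin: dict) -> dict:
    """Collapse raw per-label scores into one score per post-bin label by summing."""
    return {tf_label: sum(raw_scores.get(lbl, 0.0) for lbl in raw_labels)
            for tf_label, raw_labels in postbin.items()}


if __name__ == '__main__':
    postbin = resolve_postbin('modeling/config/classifier.yml',
                              '20240126-180026.convnext_lg.kfold_000')
    # e.g. {'bars': ['B'], 'slate': ['S', 'S:H', ...], 'chyron': [...], 'credits': ['C']}
    raw = {'B': 0.05, 'S': 0.55, 'S:H': 0.20, 'I': 0.10, 'C': 0.10}
    if postbin is not None:
        print(postbin_scores(raw, postbin))  # {'bars': 0.05, 'slate': 0.75, ...}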