Commit

got rid of "postbin" from trainer configuration file (temporarily placed under classifier config yaml)
keighrim committed Mar 15, 2024
1 parent b4651c4 commit 5925f02
Showing 8 changed files with 88 additions and 109 deletions.
19 changes: 5 additions & 14 deletions app.py
@@ -76,6 +76,9 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         configs['model_file'] = default_model_storage / f'{parameters["modelName"]}.pt'
         # model files from k-fold training have the fold number as three-digit suffix, trim it
         configs['model_config_file'] = default_model_storage / f'{parameters["modelName"][:-4]}_config.yml'
+        # TODO (krim @ 2024-03-14): make this into a runtime parameter once
+        # https://github.com/clamsproject/clams-python/issues/197 is resolved
+        configs['postbin'] = configs['postbins'].get(parameters['modelName'], None)
         t = time.perf_counter()
         self.logger.info(f"Initiating classifier with {configs['model_file']}")
         if self.logger.isEnabledFor(logging.DEBUG):
@@ -117,8 +120,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
         if self.logger.isEnabledFor(logging.DEBUG):
             self.logger.debug(f"Processing took {time.perf_counter() - t} seconds")

-        new_view.new_contain(AnnotationTypes.TimePoint,
-                             document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels)
+        new_view.new_contain(AnnotationTypes.TimePoint,
+                             document=vd.id, timeUnit='milliseconds', labelset=self.stitcher.stitch_label)

         for prediction in predictions:
             timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
@@ -143,18 +146,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
                 [p.annotation.id for p in tf.representative_predictions()])
         return mmif

-    @staticmethod
-    def _transform(classification: dict, bins: dict):
-        """Take the raw classification and turn it into a classification of user
-        labels. Also includes modeling.negative_label."""
-        # TODO: this may not work when there is pre-binning
-        transformed = {}
-        for postlabel in bins['post'].keys():
-            score = sum([classification[lbl] for lbl in bins['post'][postlabel]])
-            transformed[postlabel] = score
-        transformed[negative_label] = 1 - sum(transformed.values())
-        return transformed
-

 if __name__ == "__main__":

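The net effect of the app.py change is that postbin resolution moves out of the model config and into the runtime `configs` dict, keyed by model name. A minimal sketch of that lookup, using an entry copied from the classifier.yml addition below (the hand-built dict here is illustrative, not the app's actual object):

# Sketch: resolve a per-model postbin the way the new app.py lines do.
# In the app, configs['postbins'] comes from modeling/config/classifier.yml.
configs = {
    'postbins': {
        '20240126-180026.convnext_lg.kfold_000': {
            'bars': ['B'],
            'slate': ['S', 'S:H', 'S:C', 'S:D', 'S:G'],
            'chyron': ['I', 'N', 'Y'],
            'credits': ['C'],
        },
    },
}

model_name = '20240126-180026.convnext_lg.kfold_000'
# Models without an entry get None, which downstream code reads as
# "no postbinning" (see the stitch.py changes below).
configs['postbin'] = configs['postbins'].get(model_name, None)
print(configs['postbin']['slate'])   # ['S', 'S:H', 'S:C', 'S:D', 'S:G']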
5 changes: 2 additions & 3 deletions modeling/classify.py
@@ -44,16 +44,15 @@ def __init__(self, **config):
         self.config = config
         self.model_config = yaml.safe_load(open(config["model_config_file"]))
         self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
         self.featurizer = data_loader.FeatureExtractor(
             img_enc_name=self.model_config["img_enc_name"],
             pos_enc_name=self.model_config.get("pos_enc_name", None),
             pos_enc_dim=self.model_config.get("pos_enc_dim", 0),
             max_input_length=self.model_config.get("max_input_length", 0),
             pos_unit=self.model_config.get("pos_unit", 0))
         label_count = len(FRAME_TYPES) + 1
-        if 'pre' in self.model_config['bins']:
-            label_count = len(self.model_config['bins']['pre'].keys()) + 1
+        if 'bins' in self.model_config:
+            label_count = len(self.model_config['pre'].keys()) + 1
         self.classifier = train.get_net(
             in_dim=self.featurizer.feature_vector_dim(),
             n_labels=label_count,
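The sizing rule above gives the network one output unit per label plus one for the negative label. A sketch of that rule under the assumption that a prebinned model config keeps its mapping under a `bins:` / `pre:` path (the committed line reads `self.model_config['pre']`, so the exact key path here is an assumption; all values are illustrative):

# Sketch: how many output units the classifier head needs.
FRAME_TYPES = ['B', 'S', 'S:H', 'S:C', 'S:D', 'S:G', 'W', 'L', 'O', 'M',
               'I', 'N', 'Y', 'C', 'R', 'E', 'K', 'G', 'T', 'F']  # illustrative

model_config = {'bins': {'pre': {'bars': ['B'], 'slate': ['S', 'S:H']}}}

label_count = len(FRAME_TYPES) + 1                    # raw labels + negative label
if 'bins' in model_config:
    # one unit per prebin label, plus the negative label (assumed key path)
    label_count = len(model_config['bins']['pre'].keys()) + 1

print(label_count)   # 3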
75 changes: 74 additions & 1 deletion modeling/config/classifier.yml
@@ -18,4 +18,77 @@ minFrameCount: 2
 staticFrames: [bars, slate, chyron]

 # Set to False to turn off the stitcher
-useStitcher: True
\ No newline at end of file
+useStitcher: True
+
+postbins:
+  20240126-180026.convnext_lg.kfold_000:
+    bars:
+      - B
+    slate:
+      - S
+      - S:H
+      - S:C
+      - S:D
+      - S:G
+    chyron:
+      - I
+      - N
+      - Y
+    credits:
+      - C
+  20240212-131937.convnext_tiny.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+  20240212-132306.convnext_lg.kfold_000:
+    bars:
+      - "B"
+    slate:
+      - "S"
+      - "S:H"
+      - "S:C"
+      - "S:D"
+      - "S:G"
+    other_opening:
+      - "W"
+      - "L"
+      - "O"
+      - "M"
+    chyron:
+      - "I"
+      - "N"
+      - "Y"
+    credit:
+      - "C"
+      - "R"
+    other_text:
+      - "E"
+      - "K"
+      - "G"
+      - 'T'
+      - 'F'
+
+
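Each `postbins` entry maps a TimeFrame label to the raw model labels it collapses. A sketch of that collapse (the same arithmetic the removed `_transform` in app.py performed), with made-up scores; `negative_label` stands in for the repo's `modeling.negative_label`:

# Sketch: collapse raw per-label scores into postbin scores. The postbin
# mirrors the 20240126 entry above; the score values are invented.
negative_label = '-'   # assumption: stand-in for modeling.negative_label

postbin = {'bars': ['B'], 'slate': ['S', 'S:H', 'S:C', 'S:D', 'S:G'],
           'chyron': ['I', 'N', 'Y'], 'credits': ['C']}
raw_scores = {'B': 0.05, 'S': 0.5, 'S:H': 0.25, 'S:C': 0.0, 'S:D': 0.0,
              'S:G': 0.0, 'I': 0.125, 'N': 0.0, 'Y': 0.0, 'C': 0.0}

# each bin's score is the sum of its raw labels' scores
binned = {frame_label: sum(raw_scores[lbl] for lbl in raw_labels)
          for frame_label, raw_labels in postbin.items()}
# whatever probability mass is left over goes to the negative label
binned[negative_label] = 1 - sum(binned.values())
print(binned)   # slate collapses to 0.75; the remainder lands on '-'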
16 changes: 0 additions & 16 deletions modeling/models/20240126-180026.convnext_lg.kfold_config.yml
@@ -31,19 +31,3 @@ block_guids_valid:
 - cpb-aacip-512-4b2x34nt7g
 - cpb-aacip-512-3n20c4tr34
 - cpb-aacip-512-3f4kk9534t
-bins:
-  post:
-    bars:
-      - B
-    slate:
-      - S
-      - S:H
-      - S:C
-      - S:D
-      - S:G
-    chyron:
-      - I
-      - N
-      - Y
-    credits:
-      - C
29 changes: 0 additions & 29 deletions modeling/models/20240212-131937.convnext_tiny.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
29 changes: 0 additions & 29 deletions modeling/models/20240212-132306.convnext_lg.kfold_config.yml
@@ -39,32 +39,3 @@ pos_enc_dim: 512
 # cpb-aacip-259-4j09zf95 Duration: 01:33:59.57, start: 0.000000, bitrate: 852 kb/s
 # 94 mins = 5640 secs = 5640000 ms
 max_input_length: 5640000
-
-bins:
-  post:
-    bars:
-      - "B"
-    slate:
-      - "S"
-      - "S:H"
-      - "S:C"
-      - "S:D"
-      - "S:G"
-    other_opening:
-      - "W"
-      - "L"
-      - "O"
-      - "M"
-    chyron:
-      - "I"
-      - "N"
-      - "Y"
-    credit:
-      - "C"
-      - "R"
-    other_text:
-      - "E"
-      - "K"
-      - "G"
-      - 'T'
-      - 'F'
16 changes: 7 additions & 9 deletions modeling/stitch.py
@@ -26,9 +26,8 @@ def __init__(self, **config):
         self.min_timeframe_score = config.get("minTimeframeScore")
         self.min_frame_count = config.get("minFrameCount")
         self.static_frames = self.config.get("staticFrames")
-        self.prebin_labels = train.pretraining_binned_label(self.model_config)
-        self.postbin_labels = train.post_bin_label_names(self.model_config)
-        self.use_postbinning = "post" in self.model_config["bins"]
+        self.model_label = train.pretraining_binned_label(self.model_config)
+        self.stitch_label = config.get("postbin")
         self.debug = False

     def __str__(self):
@@ -38,8 +37,8 @@ def __str__(self):

     def create_timeframes(self, predictions: list) -> list:
         if self.debug:
-            print('pre-bin labels', self.prebin_labels)
-            print('post-bin labels', self.postbin_labels)
+            print('TimePoint labels', self.model_label)
+            print('TimeFrame labels', list(self.stitch_label.keys()))
         timeframes = self.collect_timeframes(predictions)
         if self.debug:
             print_timeframes('Collected frames', timeframes)
@@ -57,7 +56,7 @@ def create_timeframes(self, predictions: list) -> list:
     def collect_timeframes(self, predictions: list) -> list:
         """Find sequences of frames for all labels where the score of each frame
         is at least the mininum value as defined in self.min_frame_score."""
-        labels = self.postbin_labels if self.use_postbinning else self.prebin_labels
+        labels = self.stitch_label if self.stitch_label is not None else self.model_label
         if self.debug:
             print('>>> labels', labels)
         timeframes = []
@@ -112,11 +111,10 @@ def is_included(self, frame, outlawed_timepoints: set) -> bool:
     def _score_for_label(self, label: str, prediction):
         """Return the score for the label, this is somewhat more complicated when
         postbinning is used."""
-        if not self.use_postbinning:
+        if self.stitch_label is None:
             return prediction.score_for_label(label)
         else:
-            postbins = self.model_config['bins']['post']
-            return prediction.score_for_labels(postbins[label])
+            return prediction.score_for_labels(self.stitch_label[label])


 class TimeFrame:
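After this change the stitcher branches on whether a postbin was injected at all, rather than inspecting the model config. A self-contained sketch of `_score_for_label`'s two paths, with a stand-in prediction class (the real one lives elsewhere in the repo; scores are chosen to be exact in binary floating point):

# Sketch: the two scoring paths in _score_for_label.
class Prediction:
    def __init__(self, scores: dict):
        self.scores = scores

    def score_for_label(self, label: str) -> float:
        return self.scores[label]

    def score_for_labels(self, labels: list) -> float:
        return sum(self.scores[lbl] for lbl in labels)

def score_for_label(label: str, prediction, stitch_label=None):
    if stitch_label is None:              # no postbin: raw model label
        return prediction.score_for_label(label)
    # postbin: sum the raw labels that the bin collapses
    return prediction.score_for_labels(stitch_label[label])

p = Prediction({'S': 0.5, 'S:H': 0.25, 'B': 0.125})
print(score_for_label('S', p))                                # 0.5
print(score_for_label('slate', p, {'slate': ['S', 'S:H']}))   # 0.75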
8 changes: 0 additions & 8 deletions modeling/train.py
@@ -234,14 +234,6 @@ def pretraining_binned_label(config):
     return modeling.FRAME_TYPES + [modeling.negative_label]


-def post_bin_label_names(config):
-    post_labels = list(config["bins"].get("post", {}).keys())
-    if post_labels:
-        return post_labels + [modeling.negative_label]
-    else:
-        return pretraining_binned_label(config)
-
-
 def train_model(model, loss_fn, device, train_loader, configs):
     since = time.perf_counter()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
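With `post_bin_label_names` gone from train.py, the TimeFrame labelset now comes straight from the keys of the injected postbin (the `stitch_label` in stitch.py above). A one-line sketch with illustrative values:

# Sketch: the labelset the stitcher reports once the helper is removed.
postbin = {'bars': ['B'], 'slate': ['S', 'S:H'], 'chyron': ['I', 'N', 'Y']}
timeframe_labels = list(postbin.keys())
print(timeframe_labels)   # ['bars', 'slate', 'chyron']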
