Skip to content

Commit

Permalink
Merge pull request #89 from clamsproject/58-decouple-train-eval
Browse files Browse the repository at this point in the history
decoupling "postbin" and evaluation from training code
  • Loading branch information
keighrim authored Mar 17, 2024
2 parents 1ed0d2c + 667bbbe commit 68059f8
Show file tree
Hide file tree
Showing 13 changed files with 376 additions and 436 deletions.
27 changes: 7 additions & 20 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
configs['model_file'] = default_model_storage / f'{parameters["modelName"]}.pt'
# model files from k-fold training have the fold number as three-digit suffix, trim it
configs['model_config_file'] = default_model_storage / f'{parameters["modelName"][:-4]}_config.yml'
# TODO (krim @ 2024-03-14): make this into a runtime parameter once
# https://github.com/clamsproject/clams-python/issues/197 is resolved
configs['postbin'] = configs['postbins'].get(parameters['modelName'], None)
t = time.perf_counter()
self.logger.info(f"Initiating classifier with {configs['model_file']}")
if self.logger.isEnabledFor(logging.DEBUG):
Expand Down Expand Up @@ -117,8 +120,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Processing took {time.perf_counter() - t} seconds")

new_view.new_contain(AnnotationTypes.TimePoint,
document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels)
new_view.new_contain(AnnotationTypes.TimePoint,
document=vd.id, timeUnit='milliseconds', labelset=FRAME_TYPES + [negative_label])

for prediction in predictions:
timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
Expand All @@ -133,12 +136,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
if not configs.get('useStitcher'):
return mmif

labelset = self.classifier.postbin_labels
bins = self.classifier.model_config['bins']
new_view.new_contain(
AnnotationTypes.TimePoint,
document=vd.id, timeUnit='milliseconds', labelset=labelset)

new_view.new_contain(AnnotationTypes.TimeFrame,
document=vd.id, timeUnit='milliseconds', labelset=list(self.stitcher.stitch_label.keys()))
timeframes = self.stitcher.create_timeframes(predictions)
for tf in timeframes:
timeframe_annotation = new_view.new_annotation(AnnotationTypes.TimeFrame)
Expand All @@ -149,18 +148,6 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
[p.annotation.id for p in tf.representative_predictions()])
return mmif

@staticmethod
def _transform(classification: dict, bins: dict):
    """Collapse raw classification scores into scores over post-bin labels.

    For every key of ``bins['post']``, the scores in ``classification`` of
    the raw labels listed under that key are summed. Whatever remains of 1
    after adding up all post-bin scores is stored under ``negative_label``.

    :param classification: mapping from raw label to its score
    :param bins: binning configuration; only ``bins['post']`` is read, a
        mapping from post-bin label to the list of raw labels it absorbs
    :return: dict mapping each post-bin label (plus ``negative_label``)
        to its score
    """
    # TODO: this may not work when there is pre-binning
    transformed = {
        postlabel: sum(classification[raw] for raw in raw_labels)
        for postlabel, raw_labels in bins['post'].items()
    }
    # the leftover probability mass goes to the negative label
    transformed[negative_label] = 1 - sum(transformed.values())
    return transformed


if __name__ == "__main__":

Expand Down
7 changes: 3 additions & 4 deletions modeling/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,16 @@ class Classifier:
def __init__(self, **config):
self.config = config
self.model_config = yaml.safe_load(open(config["model_config_file"]))
self.prebin_labels = train.pre_bin_label_names(self.model_config, FRAME_TYPES)
self.postbin_labels = train.post_bin_label_names(self.model_config)
self.prebin_labels = train.pretraining_binned_label(self.model_config)
self.featurizer = data_loader.FeatureExtractor(
img_enc_name=self.model_config["img_enc_name"],
pos_enc_name=self.model_config.get("pos_enc_name", None),
pos_enc_dim=self.model_config.get("pos_enc_dim", 0),
max_input_length=self.model_config.get("max_input_length", 0),
pos_unit=self.model_config.get("pos_unit", 0))
label_count = len(FRAME_TYPES) + 1
if 'pre' in self.model_config['bins']:
label_count = len(self.model_config['bins']['pre'].keys()) + 1
if 'bins' in self.model_config:
label_count = len(self.model_config['pre'].keys()) + 1
self.classifier = train.get_net(
in_dim=self.featurizer.feature_vector_dim(),
n_labels=label_count,
Expand Down
75 changes: 74 additions & 1 deletion modeling/config/classifier.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,77 @@ minFrameCount: 2
staticFrames: [bars, slate, chyron]

# Set to False to turn off the stitcher
useStitcher: True
useStitcher: True

postbins:
20240126-180026.convnext_lg.kfold_000:
bars:
- B
slate:
- S
- S:H
- S:C
- S:D
- S:G
chyron:
- I
- N
- Y
credits:
- C
20240212-131937.convnext_tiny.kfold_000:
bars:
- "B"
slate:
- "S"
- "S:H"
- "S:C"
- "S:D"
- "S:G"
other_opening:
- "W"
- "L"
- "O"
- "M"
chyron:
- "I"
- "N"
- "Y"
credit:
- "C"
- "R"
other_text:
- "E"
- "K"
- "G"
- 'T'
- 'F'
20240212-132306.convnext_lg.kfold_000:
bars:
- "B"
slate:
- "S"
- "S:H"
- "S:C"
- "S:D"
- "S:G"
other_opening:
- "W"
- "L"
- "O"
- "M"
chyron:
- "I"
- "N"
- "Y"
credit:
- "C"
- "R"
other_text:
- "E"
- "K"
- "G"
- 'T'
- 'F'


17 changes: 8 additions & 9 deletions modeling/config/trainer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,11 @@ pos_enc_dim: 512
max_input_length: 5640000

bins:
pre:
slate:
- "S"
chyron:
- "I"
- "N"
- "Y"
credit:
- "C"
slate:
- "S"
chyron:
- "I"
- "N"
- "Y"
credit:
- "C"
Loading

0 comments on commit 68059f8

Please sign in to comment.