Skip to content

Commit

Permalink
Using post label classification and sorting time frames (issues #72 and
Browse files Browse the repository at this point in the history
  • Loading branch information
marcverhagen committed Feb 22, 2024
1 parent 0ea2925 commit 7838415
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
30 changes: 24 additions & 6 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mmif import Mmif, View, AnnotationTypes, DocumentTypes
from mmif.utils import video_document_helper as vdh

from modeling import classify, stitch
from modeling import classify, stitch, negative_label

logging.basicConfig(filename='swt.log', level=logging.DEBUG)

Expand Down Expand Up @@ -54,23 +54,30 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
predictions = self.classifier.process_video(vcap)
timeframes = self.stitcher.create_timeframes(predictions)

labelset = self.classifier.postbin_labels
bins = self.classifier.model_config['bins']

new_view.new_contain(
AnnotationTypes.TimeFrame, document=vd.id, timeUnit='milliseconds')
AnnotationTypes.TimeFrame,
document=vd.id, timeUnit='milliseconds', labelset=labelset)
new_view.new_contain(
AnnotationTypes.TimePoint, document=vd.id, timeUnit='milliseconds')
AnnotationTypes.TimePoint,
document=vd.id, timeUnit='milliseconds', labelset=labelset)

for tf in timeframes:
timeframe_annotation = new_view.new_annotation(AnnotationTypes.TimeFrame)
timeframe_annotation.add_property("frameType", tf.label),
timeframe_annotation.add_property("score", tf.score)
timeframe_annotation.add_property("scores", tf.scores)
timeframe_annotation.add_property("label", tf.label),
timeframe_annotation.add_property('classification', {tf.label: tf.score})
#timeframe_annotation.add_property("score", tf.score)
#timeframe_annotation.add_property("scores", tf.scores)
timepoint_annotations = []
for prediction in tf.targets:
timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
prediction.annotation = timepoint_annotation
scores = [prediction.score_for_label(lbl) for lbl in prediction.labels]
label = self._label_with_highest_score(prediction.labels, scores)
classification = {l:s for l, s in zip(prediction.labels, scores)}
classification = self._transform(classification, bins)
timepoint_annotation.add_property('timePoint', prediction.timepoint)
timepoint_annotation.add_property('label', label)
timepoint_annotation.add_property('classification', classification)
Expand Down Expand Up @@ -108,6 +115,17 @@ def _label_with_highest_score(labels: list, scores: list) -> str:
sorted_scores = list(sorted(zip(scores, labels), reverse=True))
return sorted_scores[0][1]

@staticmethod
def _transform(classification: dict, bins: dict):
"""Take the raw classification and turn it into a classification of user
labels. Also includes modeling.negative_label."""
# TODO: this may not work when there is pre-binning
transformed = {}
for postlabel in bins['post'].keys():
score = sum([classification[lbl] for lbl in bins['post'][postlabel]])
transformed[postlabel] = score
transformed[negative_label] = 1 - sum(transformed.values())
return transformed


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions modeling/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def create_timeframes(self, predictions: list) -> list:
timeframes = self.remove_overlapping_timeframes(timeframes)
for tf in timeframes:
tf.set_representatives()
timeframes = list(sorted(timeframes, key=(lambda tf: tf.start)))
if self.debug:
print_timeframes('Final frames', timeframes)
return timeframes
Expand Down

0 comments on commit 7838415

Please sign in to comment.