Skip to content

Commit

Permalink
removed stitcher's dependency to model_config, "postbin" now complete…
Browse files Browse the repository at this point in the history
…ly decoupled from model_config
  • Loading branch information
keighrim committed Jun 27, 2024
1 parent 02a7ac5 commit dec63bd
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions modeling/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
which takes a list of predictions from the classifier and creates TimeFrames.
"""


import itertools
import operator

import yaml
Expand All @@ -20,12 +19,10 @@ class Stitcher:

def __init__(self, **config):
self.config = config
self.model_config = yaml.safe_load(open(config["model_config_file"]))
self.sample_rate = config.get("sampleRate")
self.min_frame_score = config.get("minFrameScore")
self.min_timeframe_score = config.get("minTimeframeScore")
self.min_frame_count = config.get("minFrameCount")
self.model_labels = train.pretraining_binned_label(self.model_config)
self.stitch_labels = config.get("postbin")
self.allow_overlap = config.get('allowOverlap')
self.debug = False
Expand All @@ -37,7 +34,7 @@ def __str__(self):

def create_timeframes(self, predictions: list) -> list:
if self.debug:
print('>>> TimePoint labels:', ' '.join(self.model_labels))
print('>>> TimePoint labels:', ' '.join(list(itertools.chain(self.stitch_labels.values()))))
print('>>> TimeFrame labels:', ' '.join(list(self.stitch_labels.keys())))
timeframes = self.collect_timeframes(predictions)
if self.debug:
Expand All @@ -57,13 +54,12 @@ def create_timeframes(self, predictions: list) -> list:
def collect_timeframes(self, predictions: list) -> list:
"""Find sequences of frames for all labels where the score of each frame
is at least the mininum value as defined in self.min_frame_score."""
labels = self.stitch_labels if self.stitch_labels is not None else self.model_labels
if self.debug:
print('>>> labels', labels)
print('>>> labels', self.stitch_labels)
timeframes = []
open_frames = {label: TimeFrame(label, self) for label in labels}
open_frames = {label: TimeFrame(label, self) for label in self.stitch_labels}
for prediction in predictions:
for label in [label for label in labels if label != negative_label]:
for label in [label for label in self.stitch_labels if label != negative_label]:
score = self._score_for_label(label, prediction)
if score < self.min_frame_score:
# the second part checks whether there is something in the timeframe
Expand All @@ -72,7 +68,7 @@ def collect_timeframes(self, predictions: list) -> list:
open_frames[label] = TimeFrame(label, self)
else:
open_frames[label].add_prediction(prediction, score)
for label in labels:
for label in self.stitch_labels:
if open_frames[label]:
timeframes.append(open_frames[label])
for tf in timeframes:
Expand Down

0 comments on commit dec63bd

Please sign in to comment.