diff --git a/modeling/stitch.py b/modeling/stitch.py index 5d86a2e..7afe411 100644 --- a/modeling/stitch.py +++ b/modeling/stitch.py @@ -7,8 +7,7 @@ which takes a list of predictions from the classifier and creates TimeFrames. """ - - +import itertools import operator import yaml @@ -20,12 +19,10 @@ class Stitcher: def __init__(self, **config): self.config = config - self.model_config = yaml.safe_load(open(config["model_config_file"])) self.sample_rate = config.get("sampleRate") self.min_frame_score = config.get("minFrameScore") self.min_timeframe_score = config.get("minTimeframeScore") self.min_frame_count = config.get("minFrameCount") - self.model_labels = train.pretraining_binned_label(self.model_config) self.stitch_labels = config.get("postbin") self.allow_overlap = config.get('allowOverlap') self.debug = False @@ -37,7 +34,7 @@ def __str__(self): def create_timeframes(self, predictions: list) -> list: if self.debug: - print('>>> TimePoint labels:', ' '.join(self.model_labels)) + print('>>> TimePoint labels:', ' '.join(list(itertools.chain(self.stitch_labels.values())))) print('>>> TimeFrame labels:', ' '.join(list(self.stitch_labels.keys()))) timeframes = self.collect_timeframes(predictions) if self.debug: @@ -57,13 +54,12 @@ def create_timeframes(self, predictions: list) -> list: def collect_timeframes(self, predictions: list) -> list: """Find sequences of frames for all labels where the score of each frame is at least the mininum value as defined in self.min_frame_score.""" - labels = self.stitch_labels if self.stitch_labels is not None else self.model_labels if self.debug: - print('>>> labels', labels) + print('>>> labels', self.stitch_labels) timeframes = [] - open_frames = {label: TimeFrame(label, self) for label in labels} + open_frames = {label: TimeFrame(label, self) for label in self.stitch_labels} for prediction in predictions: - for label in [label for label in labels if label != negative_label]: + for label in [label for label in self.stitch_labels if label != negative_label]: score = self._score_for_label(label, prediction) if score < self.min_frame_score: # the second part checks whether there is something in the timeframe @@ -72,7 +68,7 @@ def collect_timeframes(self, predictions: list) -> list: open_frames[label] = TimeFrame(label, self) else: open_frames[label].add_prediction(prediction, score) - for label in labels: + for label in self.stitch_labels: if open_frames[label]: timeframes.append(open_frames[label]) for tf in timeframes: