From 8b7f578098f01a1612bb1fce69c800ca646b159a Mon Sep 17 00:00:00 2001 From: Keigh Rim Date: Thu, 14 Mar 2024 22:13:57 -0400 Subject: [PATCH] TimePoint outputs now include softmax for `NEG` label --- app.py | 3 +-- modeling/stitch.py | 2 +- modeling/train.py | 5 +++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index 6820e87..bddc739 100644 --- a/app.py +++ b/app.py @@ -117,9 +117,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif: if self.logger.isEnabledFor(logging.DEBUG): self.logger.debug(f"Processing took {time.perf_counter() - t} seconds") - tp_labelset = FRAME_TYPES new_view.new_contain(AnnotationTypes.TimePoint, - document=vd.id, timeUnit='milliseconds', labelset=tp_labelset) + document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels) for prediction in predictions: timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint) diff --git a/modeling/stitch.py b/modeling/stitch.py index a94ce7b..c0f074c 100644 --- a/modeling/stitch.py +++ b/modeling/stitch.py @@ -61,7 +61,7 @@ def collect_timeframes(self, predictions: list) -> list: if self.debug: print('>>> labels', labels) timeframes = [] - open_frames = { label: TimeFrame(label, self) for label in labels} + open_frames = {label: TimeFrame(label, self) for label in labels} for prediction in predictions: if self.debug: print(prediction) diff --git a/modeling/train.py b/modeling/train.py index 8dba721..eff0d40 100644 --- a/modeling/train.py +++ b/modeling/train.py @@ -265,10 +265,11 @@ def pre_bin_label_names(config, raw_labels=None): if 'pre' in config["bins"]: return list(config["bins"]["pre"].keys()) + [modeling.negative_label] elif raw_labels is not None: - return raw_labels + return raw_labels + [modeling.negative_label] else: return [] + def post_bin_label_names(config): post_labels = list(config["bins"].get("post", {}).keys()) if post_labels: @@ -276,6 +277,7 @@ def post_bin_label_names(config): else: return pre_bin_label_names(config) + def get_final_label_names(config): if config and "post" in config["bins"]: return post_bin_label_names(config) @@ -346,7 +348,6 @@ def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_l def export_train_result(out: IO, predictions: Tensor, labels: Tensor, labelset: List[str], img_enc_name: str): """Exports the data into a human-readable format. - @return: class-based accuracy metrics for each label, organized into a csv. """ label_metrics = defaultdict(dict)