Skip to content

Commit

Permalink
Merge pull request #88 from clamsproject/87-fix-neg-score
Browse files Browse the repository at this point in the history
TimePoint outputs now include softmax for `NEG` label
  • Loading branch information
keighrim authored Mar 15, 2024
2 parents 06a0e81 + 8b7f578 commit 1ed0d2c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
3 changes: 1 addition & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,8 @@ def _annotate(self, mmif: Union[str, dict, Mmif], **parameters) -> Mmif:
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Processing took {time.perf_counter() - t} seconds")

tp_labelset = FRAME_TYPES
new_view.new_contain(AnnotationTypes.TimePoint,
document=vd.id, timeUnit='milliseconds', labelset=tp_labelset)
document=vd.id, timeUnit='milliseconds', labelset=self.classifier.postbin_labels)

for prediction in predictions:
timepoint_annotation = new_view.new_annotation(AnnotationTypes.TimePoint)
Expand Down
2 changes: 1 addition & 1 deletion modeling/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def collect_timeframes(self, predictions: list) -> list:
if self.debug:
print('>>> labels', labels)
timeframes = []
open_frames = { label: TimeFrame(label, self) for label in labels}
open_frames = {label: TimeFrame(label, self) for label in labels}
for prediction in predictions:
if self.debug:
print(prediction)
Expand Down
5 changes: 3 additions & 2 deletions modeling/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,19 @@ def pre_bin_label_names(config, raw_labels=None):
if 'pre' in config["bins"]:
return list(config["bins"]["pre"].keys()) + [modeling.negative_label]
elif raw_labels is not None:
return raw_labels
return raw_labels + [modeling.negative_label]
else:
return []


def post_bin_label_names(config):
post_labels = list(config["bins"].get("post", {}).keys())
if post_labels:
return post_labels + [modeling.negative_label]
else:
return pre_bin_label_names(config)


def get_final_label_names(config):
if config and "post" in config["bins"]:
return post_bin_label_names(config)
Expand Down Expand Up @@ -346,7 +348,6 @@ def train_model(model, loss_fn, device, train_loader, valid_loader, configs, n_l

def export_train_result(out: IO, predictions: Tensor, labels: Tensor, labelset: List[str], img_enc_name: str):
"""Exports the data into a human-readable format.
@return: class-based accuracy metrics for each label, organized into a csv.
"""

label_metrics = defaultdict(dict)
Expand Down

0 comments on commit 1ed0d2c

Please sign in to comment.