Skip to content

Commit

Permalink
false pos and neg items now correctly updated
Browse files Browse the repository at this point in the history
  • Loading branch information
RichJackson committed Oct 24, 2024
1 parent e26487c commit 64c9ffb
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions kazu/training/train_multilabel_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,17 @@ def calculate_metrics(
all_results[f"{clazz}_precision"] = result.precision
all_results[f"{clazz}_recall"] = result.recall
all_results[f"{clazz}_support"] = support
all_results["false_positives"] = {}
all_results["false_negatives"] = {}

false_positives: defaultdict[str, dict[str, int]] = defaultdict(dict)
for match, count in result.fp_info:
false_positives[clazz][match] = count
all_results["false_positives"] = dict(false_positives)
all_results["false_positives"].update(dict(false_positives))
false_negatives: defaultdict[str, dict[str, int]] = defaultdict(dict)
for match, count in result.fn_info:
false_negatives[clazz][match] = count
all_results["false_negatives"] = dict(false_negatives)
all_results["false_negatives"].update(dict(false_negatives))
label_set = set(label_list)
label_set.remove(ENTITY_OUTSIDE_SYMBOL)
if len(ner_results) != len(label_set):
Expand Down

0 comments on commit 64c9ffb

Please sign in to comment.