From 529db7a094d8acebf794efbabdc4bcd37d43e8e4 Mon Sep 17 00:00:00 2001
From: Michael Terry
Date: Thu, 6 Jun 2024 10:12:54 -0400
Subject: [PATCH] feat: add chart range output to `accuracy`

When running the accuracy command, print a nicer header giving some
basic info like: who is truth and who is merely annotator? How many
charts are in their overlap and what are the ID ranges?

This should give some peace of mind that the stats are covering what
they should be covering.

When passing in --save, we print the header to the console, but not to
the written file, which is still just a csv/json of the stats.

Also:
- If a project didn't specify labels, the docs say we use all found
  labels. And we were... but only when scoring the "all labels" line
  item. We didn't score each found label on its own line. Now that's
  been fixed and we do score each found label separately.
- If an unknown annotator name is passed to the accuracy command, we
  now give a nice understandable error.
---
 .gitignore                        |  2 ++
 README.md                         |  5 ++++-
 chart_review/cohort.py            |  6 ++----
 chart_review/commands/accuracy.py | 21 +++++++++++++++----
 chart_review/commands/info.py     | 32 ++++++-----------------
 chart_review/console_utils.py     | 34 +++++++++++++++++++++++++++
 chart_review/simplify.py          |  4 ++++
 docs/accuracy.md                  |  5 ++++-
 tests/test_agree.py               |  2 +-
 tests/test_cohort.py              | 31 ++++++++++++++++++++++++++++
 10 files changed, 105 insertions(+), 37 deletions(-)
 create mode 100644 chart_review/console_utils.py

diff --git a/.gitignore b/.gitignore
index 4358842..7be7b27 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
 /.idea/
+/dist/
+/site/
 __pycache__/
diff --git a/README.md b/README.md
index ba1e80a..62a92d6 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,10 @@ $ ls
 config.yaml
 labelstudio-export.json
 $ chart-review accuracy jill jane
-accuracy-jill-jane:
+Comparing 3 charts (1, 3–4)
+Truth: jill
+Annotator: jane
+
 F1     Sens   Spec   PPV    NPV    Kappa  TP  FN  TN  FP  Label
 0.667  0.75   0.6    0.6    0.75   0.341  3   1   3   2   *
 0.667  0.5    1.0    1.0    0.5    0.4    1   1   1   0   Cough
diff --git a/chart_review/cohort.py b/chart_review/cohort.py
index 4ccc322..aa92b88 100644
--- a/chart_review/cohort.py
+++ b/chart_review/cohort.py
@@ -93,14 +93,12 @@ def calc_label_freq(self, annotator) -> dict:
     def calc_term_label_confusion(self, annotator) -> dict:
         return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator))
 
-    def _select_labels(self, label_pick: str = None) -> Optional[Iterable[str]]:
+    def _select_labels(self, label_pick: str = None) -> Iterable[str]:
         if label_pick:
             guard_in(label_pick, self.class_labels)
             return [label_pick]
-        elif self.class_labels:
-            return self.class_labels
         else:
-            return None
+            return self.class_labels
 
     def confusion_matrix(
         self, truth: str, annotator: str, note_range: Iterable, label_pick: str = None
diff --git a/chart_review/commands/accuracy.py b/chart_review/commands/accuracy.py
index a5308cc..5d71dd7 100644
--- a/chart_review/commands/accuracy.py
+++ b/chart_review/commands/accuracy.py
@@ -5,7 +5,7 @@
 import rich
 import rich.table
 
-from chart_review import agree, cohort, common
+from chart_review import agree, cohort, common, console_utils
 
 
 def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None:
@@ -19,6 +19,13 @@
     :param annotator: the other annotator to compare against truth
     :param save: whether to write the results to disk vs just printing them
     """
+    if truth not in reader.note_range:
+        print(f"Unrecognized annotator '{truth}'")
+        return
+    if annotator not in reader.note_range:
+        print(f"Unrecognized annotator '{annotator}'")
+        return
+
     # Grab the intersection of ranges
     note_range = set(reader.note_range[truth])
     note_range &= set(reader.note_range[annotator])
@@ -32,17 +39,23 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: boo
             reader.confusion_matrix(truth, annotator, note_range, label)
         )
 
-    result_name = f"accuracy-{truth}-{annotator}"
+    note_count = len(note_range)
+    chart_word = "chart" if note_count == 1 else "charts"
+    pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else ""
+    print(f"Comparing {note_count} {chart_word}{pretty_ranges}")
+    print(f"Truth: {truth}")
+    print(f"Annotator: {annotator}")
+    print()
+
     if save:
         # Write the results out to disk
-        output_stem = os.path.join(reader.project_dir, result_name)
+        output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
         common.write_json(f"{output_stem}.json", table)
         print(f"Wrote {output_stem}.json")
         common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
         print(f"Wrote {output_stem}.csv")
     else:
         # Print the results out to the console
-        print(f"{result_name}:")
         rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
         rich_table.add_row(*agree.csv_row_score(table), "*")
         for label in sorted(reader.class_labels):
diff --git a/chart_review/commands/info.py b/chart_review/commands/info.py
index 47cc459..ef04e9f 100644
--- a/chart_review/commands/info.py
+++ b/chart_review/commands/info.py
@@ -4,7 +4,7 @@
 import rich.box
 import rich.table
 
-from chart_review import cohort
+from chart_review import cohort, console_utils
 
 
 def info(reader: cohort.CohortReader) -> None:
@@ -28,7 +28,11 @@ def info(reader: cohort.CohortReader) -> None:
     )
     for annotator in sorted(reader.note_range):
         notes = reader.note_range[annotator]
-        chart_table.add_row(annotator, str(len(notes)), pretty_note_range(notes))
+        chart_table.add_row(
+            annotator,
+            str(len(notes)),
+            console_utils.pretty_note_range(notes),
+        )
 
     console.print(chart_table)
     console.print()
@@ -38,27 +42,3 @@ def info(reader: cohort.CohortReader) -> None:
         console.print(", ".join(sorted(reader.class_labels, key=str.casefold)))
     else:
         console.print("None", style="italic", highlight=False)
-
-
-def pretty_note_range(notes: set[int]) -> str:
-    ranges = []
-    range_start = None
-    prev_note = None
-
-    def end_range() -> None:
-        if prev_note is None:
-            return
-        if range_start == prev_note:
-            ranges.append(str(prev_note))
-        else:
-            ranges.append(f"{range_start}–{prev_note}")  # en dash
-
-    for note in sorted(notes):
-        if prev_note is None or prev_note + 1 != note:
-            end_range()
-            range_start = note
-        prev_note = note
-
-    end_range()
-
-    return ", ".join(ranges)
diff --git a/chart_review/console_utils.py b/chart_review/console_utils.py
new file mode 100644
index 0000000..6ba6232
--- /dev/null
+++ b/chart_review/console_utils.py
@@ -0,0 +1,34 @@
+"""Helper methods for printing to the console."""
+
+
+def pretty_note_range(notes: set[int]) -> str:
+    """
+    Returns a pretty, human-readable string for a set of notes.
+
+    If no notes, this returns an empty string.
+
+    Example:
+        pretty_note_range({1, 2, 3, 7, 9, 10})
+        -> "1–3, 7, 9–10"
+    """
+    ranges = []
+    range_start = None
+    prev_note = None
+
+    def end_range() -> None:
+        if prev_note is None:
+            return
+        if range_start == prev_note:
+            ranges.append(str(prev_note))
+        else:
+            ranges.append(f"{range_start}–{prev_note}")  # en dash
+
+    for note in sorted(notes):
+        if prev_note is None or prev_note + 1 != note:
+            end_range()
+            range_start = note
+        prev_note = note
+
+    end_range()
+
+    return ", ".join(ranges)
diff --git a/chart_review/simplify.py b/chart_review/simplify.py
index 1321a3e..2fb91bb 100644
--- a/chart_review/simplify.py
+++ b/chart_review/simplify.py
@@ -18,6 +18,7 @@
     """
     annotations = types.ProjectAnnotations()
     annotations.labels = proj_config.class_labels
+    grab_all_labels = not annotations.labels
 
     for entry in exported_json:
         note_id = int(entry.get("id"))
@@ -38,6 +39,9 @@
                 labels |= result_labels
                 text_tags.append(types.LabeledText(result_text, result_labels))
 
+        if grab_all_labels:
+            annotations.labels |= labels
+
         # Store these mentions in the main annotations list, by author & note
         annotator = proj_config.annotators[completed_by]
         annotator_mentions = annotations.mentions.setdefault(annotator, types.Mentions())
diff --git a/docs/accuracy.md b/docs/accuracy.md
index 33a3081..de96849 100644
--- a/docs/accuracy.md
+++ b/docs/accuracy.md
@@ -18,7 +18,10 @@ your accuracy scores will be printed to the console.
 
 ```shell
 $ chart-review accuracy jill jane
-accuracy-jill-jane:
+Comparing 3 charts (1, 3–4)
+Truth: jill
+Annotator: jane
+
 F1     Sens   Spec   PPV    NPV    Kappa  TP  FN  TN  FP  Label
 0.667  0.75   0.6    0.6    0.75   0.341  3   1   3   2   *
 0.667  0.5    1.0    1.0    0.5    0.4    1   1   1   0   Cough
diff --git a/tests/test_agree.py b/tests/test_agree.py
index 147db34..0a7782a 100644
--- a/tests/test_agree.py
+++ b/tests/test_agree.py
@@ -26,7 +26,7 @@ class TestAgreement(unittest.TestCase):
         (
             "bob",
             "alice",
-            None,
+            {},
             {
                 "FN": [{1: "Headache"}, {2: "Cough"}],
                 "FP": [{1: "Cough"}],
diff --git a/tests/test_cohort.py b/tests/test_cohort.py
index b946dcb..69ef360 100644
--- a/tests/test_cohort.py
+++ b/tests/test_cohort.py
@@ -16,6 +16,37 @@ def setUp(self):
         super().setUp()
         self.maxDiff = None
 
+    def test_no_specified_label(self):
+        """Verify that no label setup grabs all found labels from the export."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            common.write_json(
+                f"{tmpdir}/config.json",
+                {
+                    "annotators": {"bob": 1, "alice": 2},
+                },
+            )
+            common.write_json(
+                f"{tmpdir}/labelstudio-export.json",
+                [
+                    {
+                        "id": 1,
+                        "annotations": [
+                            {
+                                "completed_by": 1,
+                                "result": [
+                                    {
+                                        "value": {"labels": ["Label A", "Label B"]},
+                                    }
+                                ],
+                            },
+                        ],
+                    },
+                ],
+            )
+            reader = cohort.CohortReader(config.ProjectConfig(tmpdir))
+
+        self.assertEqual({"Label A", "Label B"}, reader.class_labels)
+
     def test_ignored_ids(self):
         reader = cohort.CohortReader(config.ProjectConfig(f"{DATA_DIR}/ignore"))
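
As a quick sanity check on the new helper, here is a doctest-style sketch of the behavior described in its docstring (illustrative only, not part of the patch; the import path matches the new chart_review/console_utils.py file above):

    >>> from chart_review.console_utils import pretty_note_range
    >>> pretty_note_range({1, 2, 3, 7, 9, 10})  # mixed runs and singletons
    '1–3, 7, 9–10'
    >>> pretty_note_range(set())  # empty input yields an empty string
    ''
    >>> pretty_note_range({5})  # a lone note prints without a dash
    '5'

Moving the helper out of info.py into its own console_utils module is what lets the accuracy command reuse it without importing one command module from another.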