diff --git a/.gitignore b/.gitignore index 4358842..7be7b27 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /.idea/ +/dist/ +/site/ __pycache__/ diff --git a/README.md b/README.md index ba1e80a..62a92d6 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,10 @@ $ ls config.yaml labelstudio-export.json $ chart-review accuracy jill jane -accuracy-jill-jane: +Comparing 3 charts (1, 3–4) +Truth: jill +Annotator: jane + F1 Sens Spec PPV NPV Kappa TP FN TN FP Label 0.667 0.75 0.6 0.6 0.75 0.341 3 1 3 2 * 0.667 0.5 1.0 1.0 0.5 0.4 1 1 1 0 Cough diff --git a/chart_review/cohort.py b/chart_review/cohort.py index 4ccc322..aa92b88 100644 --- a/chart_review/cohort.py +++ b/chart_review/cohort.py @@ -93,14 +93,12 @@ def calc_label_freq(self, annotator) -> dict: def calc_term_label_confusion(self, annotator) -> dict: return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator)) - def _select_labels(self, label_pick: str = None) -> Optional[Iterable[str]]: + def _select_labels(self, label_pick: str = None) -> Iterable[str]: if label_pick: guard_in(label_pick, self.class_labels) return [label_pick] - elif self.class_labels: - return self.class_labels else: - return None + return self.class_labels def confusion_matrix( self, truth: str, annotator: str, note_range: Iterable, label_pick: str = None diff --git a/chart_review/commands/accuracy.py b/chart_review/commands/accuracy.py index a5308cc..5d71dd7 100644 --- a/chart_review/commands/accuracy.py +++ b/chart_review/commands/accuracy.py @@ -5,7 +5,7 @@ import rich import rich.table -from chart_review import agree, cohort, common +from chart_review import agree, cohort, common, console_utils def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None: @@ -19,6 +19,13 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool :param annotator: the other annotator to compare against truth :param save: whether to write the results to disk vs just printing them """ + if truth not in reader.note_range: + print(f"Unrecognized annotator '{truth}'") + return + if annotator not in reader.note_range: + print(f"Unrecognized annotator '{annotator}'") + return + # Grab the intersection of ranges note_range = set(reader.note_range[truth]) note_range &= set(reader.note_range[annotator]) @@ -32,17 +39,23 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool reader.confusion_matrix(truth, annotator, note_range, label) ) - result_name = f"accuracy-{truth}-{annotator}" + note_count = len(note_range) + chart_word = "chart" if note_count == 1 else "charts" + pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else "" + print(f"Comparing {note_count} {chart_word}{pretty_ranges}") + print(f"Truth: {truth}") + print(f"Annotator: {annotator}") + print() + if save: # Write the results out to disk - output_stem = os.path.join(reader.project_dir, result_name) + output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}") common.write_json(f"{output_stem}.json", table) print(f"Wrote {output_stem}.json") common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels)) print(f"Wrote {output_stem}.csv") else: # Print the results out to the console - print(f"{result_name}:") rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False) rich_table.add_row(*agree.csv_row_score(table), "*") for label in sorted(reader.class_labels): diff --git a/chart_review/commands/info.py b/chart_review/commands/info.py index 47cc459..ef04e9f 100644 --- a/chart_review/commands/info.py +++ b/chart_review/commands/info.py @@ -4,7 +4,7 @@ import rich.box import rich.table -from chart_review import cohort +from chart_review import cohort, console_utils def info(reader: cohort.CohortReader) -> None: @@ -28,7 +28,11 @@ def info(reader: cohort.CohortReader) -> None: ) for annotator in sorted(reader.note_range): notes = reader.note_range[annotator] - chart_table.add_row(annotator, str(len(notes)), pretty_note_range(notes)) + chart_table.add_row( + annotator, + str(len(notes)), + console_utils.pretty_note_range(notes), + ) console.print(chart_table) console.print() @@ -38,27 +42,3 @@ def info(reader: cohort.CohortReader) -> None: console.print(", ".join(sorted(reader.class_labels, key=str.casefold))) else: console.print("None", style="italic", highlight=False) - - -def pretty_note_range(notes: set[int]) -> str: - ranges = [] - range_start = None - prev_note = None - - def end_range() -> None: - if prev_note is None: - return - if range_start == prev_note: - ranges.append(str(prev_note)) - else: - ranges.append(f"{range_start}–{prev_note}") # en dash - - for note in sorted(notes): - if prev_note is None or prev_note + 1 != note: - end_range() - range_start = note - prev_note = note - - end_range() - - return ", ".join(ranges) diff --git a/chart_review/console_utils.py b/chart_review/console_utils.py new file mode 100644 index 0000000..6ba6232 --- /dev/null +++ b/chart_review/console_utils.py @@ -0,0 +1,34 @@ +"""Helper methods for printing to the console.""" + + +def pretty_note_range(notes: set[int]) -> str: + """ + Returns a pretty, human-readable string for a set of notes. + + If no notes, this returns an empty string. + + Example: + pretty_note_range({1, 2, 3, 7, 9, 10}) + -> "1–3, 7, 9–10" + """ + ranges = [] + range_start = None + prev_note = None + + def end_range() -> None: + if prev_note is None: + return + if range_start == prev_note: + ranges.append(str(prev_note)) + else: + ranges.append(f"{range_start}–{prev_note}") # en dash + + for note in sorted(notes): + if prev_note is None or prev_note + 1 != note: + end_range() + range_start = note + prev_note = note + + end_range() + + return ", ".join(ranges) diff --git a/chart_review/simplify.py b/chart_review/simplify.py index 1321a3e..2fb91bb 100644 --- a/chart_review/simplify.py +++ b/chart_review/simplify.py @@ -18,6 +18,7 @@ def simplify_export( """ annotations = types.ProjectAnnotations() annotations.labels = proj_config.class_labels + grab_all_labels = not annotations.labels for entry in exported_json: note_id = int(entry.get("id")) @@ -38,6 +39,9 @@ def simplify_export( labels |= result_labels text_tags.append(types.LabeledText(result_text, result_labels)) + if grab_all_labels: + annotations.labels |= labels + # Store these mentions in the main annotations list, by author & note annotator = proj_config.annotators[completed_by] annotator_mentions = annotations.mentions.setdefault(annotator, types.Mentions()) diff --git a/docs/accuracy.md b/docs/accuracy.md index 33a3081..de96849 100644 --- a/docs/accuracy.md +++ b/docs/accuracy.md @@ -18,7 +18,10 @@ your accuracy scores will be printed to the console. ```shell $ chart-review accuracy jill jane -accuracy-jill-jane: +Comparing 3 charts (1, 3–4) +Truth: jill +Annotator: jane + F1 Sens Spec PPV NPV Kappa TP FN TN FP Label 0.667 0.75 0.6 0.6 0.75 0.341 3 1 3 2 * 0.667 0.5 1.0 1.0 0.5 0.4 1 1 1 0 Cough diff --git a/tests/test_agree.py b/tests/test_agree.py index 147db34..0a7782a 100644 --- a/tests/test_agree.py +++ b/tests/test_agree.py @@ -26,7 +26,7 @@ class TestAgreement(unittest.TestCase): ( "bob", "alice", - None, + {}, { "FN": [{1: "Headache"}, {2: "Cough"}], "FP": [{1: "Cough"}], diff --git a/tests/test_cohort.py b/tests/test_cohort.py index b946dcb..69ef360 100644 --- a/tests/test_cohort.py +++ b/tests/test_cohort.py @@ -16,6 +16,37 @@ def setUp(self): super().setUp() self.maxDiff = None + def test_no_specified_label(self): + """Verify that no label setup grabs all found labels from the export.""" + with tempfile.TemporaryDirectory() as tmpdir: + common.write_json( + f"{tmpdir}/config.json", + { + "annotators": {"bob": 1, "alice": 2}, + }, + ) + common.write_json( + f"{tmpdir}/labelstudio-export.json", + [ + { + "id": 1, + "annotations": [ + { + "completed_by": 1, + "result": [ + { + "value": {"labels": ["Label A", "Label B"]}, + } + ], + }, + ], + }, + ], + ) + reader = cohort.CohortReader(config.ProjectConfig(tmpdir)) + + self.assertEqual({"Label A", "Label B"}, reader.class_labels) + def test_ignored_ids(self): reader = cohort.CohortReader(config.ProjectConfig(f"{DATA_DIR}/ignore"))