diff --git a/chart_review/agree.py b/chart_review/agree.py
index 865a3d7..590a032 100644
--- a/chart_review/agree.py
+++ b/chart_review/agree.py
@@ -1,6 +1,7 @@
 from collections.abc import Collection, Iterable
+from typing import Union
 
-from chart_review import simplify, types
+from chart_review import types
 
 
 def confusion_matrix(
@@ -157,10 +158,10 @@ def score_reviewer(
 def csv_table(score: dict, class_labels: Iterable):
     table = list()
     table.append(csv_header(False, True))
-    table.append(csv_row_score(score))
+    table.append(csv_row_score(score, as_string=True))
 
     for label in sorted(class_labels):
-        table.append(csv_row_score(score[label], label))
+        table.append(csv_row_score(score[label], label, as_string=True))
 
     return "\n".join(table) + "\n"
 
@@ -181,16 +182,23 @@ def csv_header(pick_label=False, as_string=False):
     return "\t".join(header)
 
 
-def csv_row_score(score: dict, pick_label=None) -> str:
+def csv_row_score(
+    score: dict, pick_label: str = None, as_string: bool = False
+) -> Union[str, list[str]]:
     """
     Table Row entry
     F1, PPV (precision), Recall (sensitivity), True Pos, False Pos, False Neg
 
     :param score: dict result from F1 scoring
     :param pick_label: default= None means '*' all classes
+    :param as_string: whether to return a list of string scores or one single string
     :return: str representation of the score
     """
     row = [score[header] for header in csv_header()]
     row = [str(value) for value in row]
+
+    if not as_string:
+        return row
+
     row.append(pick_label if pick_label else "*")
     return "\t".join(row)
diff --git a/chart_review/cli.py b/chart_review/cli.py
index ccfe8b0..1231bcb 100644
--- a/chart_review/cli.py
+++ b/chart_review/cli.py
@@ -49,6 +49,7 @@ def define_parser() -> argparse.ArgumentParser:
 def add_accuracy_subparser(subparsers) -> None:
     parser = subparsers.add_parser("accuracy")
     add_project_args(parser)
+    parser.add_argument("--save", action="store_true", default=False)
     parser.add_argument("truth_annotator")
     parser.add_argument("annotator")
     parser.set_defaults(func=run_accuracy)
@@ -57,7 +58,7 @@ def add_accuracy_subparser(subparsers) -> None:
 def run_accuracy(args: argparse.Namespace) -> None:
     proj_config = config.ProjectConfig(args.project_dir, config_path=args.config)
     reader = cohort.CohortReader(proj_config)
-    accuracy(reader, args.truth_annotator, args.annotator)
+    accuracy(reader, args.truth_annotator, args.annotator, save=args.save)
 
 
 ###############################################################################
diff --git a/chart_review/cohort.py b/chart_review/cohort.py
index 896834e..e865d93 100644
--- a/chart_review/cohort.py
+++ b/chart_review/cohort.py
@@ -130,11 +130,11 @@ def score_reviewer_table_csv(self, truth: str, annotator: str, note_range) -> st
         table.append(agree.csv_header(False, True))
 
         score = self.score_reviewer(truth, annotator, note_range)
-        table.append(agree.csv_row_score(score))
+        table.append(agree.csv_row_score(score, as_string=True))
 
         for label in self.class_labels:
             score = self.score_reviewer(truth, annotator, note_range, label)
-            table.append(agree.csv_row_score(score, label))
+            table.append(agree.csv_row_score(score, label, as_string=True))
 
         return "\n".join(table) + "\n"
 
diff --git a/chart_review/commands/accuracy.py b/chart_review/commands/accuracy.py
index 7c3f324..a5308cc 100644
--- a/chart_review/commands/accuracy.py
+++ b/chart_review/commands/accuracy.py
@@ -2,10 +2,13 @@
 
 import os
 
+import rich
+import rich.table
+
 from chart_review import agree, cohort, common
 
 
-def accuracy(reader: cohort.CohortReader, truth: str, annotator: str) -> None:
+def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None:
     """
     High-level accuracy calculation between two annotators.
 
@@ -14,10 +17,11 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str) -> None:
     :param reader: the cohort configuration
     :param truth: the truth annotator
     :param annotator: the other annotator to compare against truth
+    :param save: whether to write the results to disk vs just printing them
     """
     # Grab the intersection of ranges
-    note_range = set(reader.config.note_ranges[truth])
-    note_range &= set(reader.config.note_ranges[annotator])
+    note_range = set(reader.note_range[truth])
+    note_range &= set(reader.note_range[annotator])
 
     # All labels first
     table = agree.score_matrix(reader.confusion_matrix(truth, annotator, note_range))
@@ -28,9 +32,19 @@
             reader.confusion_matrix(truth, annotator, note_range, label)
         )
 
-    # And write out the results
-    output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
-    common.write_json(f"{output_stem}.json", table)
-    print(f"Wrote {output_stem}.json")
-    common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
-    print(f"Wrote {output_stem}.csv")
+    result_name = f"accuracy-{truth}-{annotator}"
+    if save:
+        # Write the results out to disk
+        output_stem = os.path.join(reader.project_dir, result_name)
+        common.write_json(f"{output_stem}.json", table)
+        print(f"Wrote {output_stem}.json")
+        common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
+        print(f"Wrote {output_stem}.csv")
+    else:
+        # Print the results out to the console
+        print(f"{result_name}:")
+        rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
+        rich_table.add_row(*agree.csv_row_score(table), "*")
+        for label in sorted(reader.class_labels):
+            rich_table.add_row(*agree.csv_row_score(table[label]), label)
+        rich.get_console().print(rich_table)
diff --git a/chart_review/common.py b/chart_review/common.py
index 070ac39..17a3a30 100644
--- a/chart_review/common.py
+++ b/chart_review/common.py
@@ -1,6 +1,6 @@
 """Utility methods"""
 from enum import Enum, EnumMeta
-from typing import Optional
+from typing import Optional, Union
 from collections.abc import Iterable
 import logging
 import json
@@ -10,7 +10,7 @@
 ###############################################################################
 
 
-def read_json(path: str) -> dict | list[dict]:
+def read_json(path: str) -> Union[dict, list[dict]]:
     """
     Reads json from a file
     :param path: filesystem path
diff --git a/pyproject.toml b/pyproject.toml
index a86f0a6..5941d0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,7 @@ requires-python = ">= 3.9"
 dependencies = [
     "ctakesclient",
     "pyyaml >= 6",
+    "rich",
 ]
 description = "Medical Record Chart Review Calculator"
 readme = "README.md"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 65f80b5..24baa4a 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -20,7 +20,7 @@ def setUp(self):
     def test_accuracy(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             shutil.copytree(f"{DATA_DIR}/cold", tmpdir, dirs_exist_ok=True)
-            cli.main_cli(["accuracy", "--project-dir", tmpdir, "jill", "jane"])
+            cli.main_cli(["accuracy", "--project-dir", tmpdir, "--save", "jill", "jane"])
 
             accuracy_json = common.read_json(f"{tmpdir}/accuracy-jill-jane.json")
             self.assertEqual(
@@ -85,7 +85,7 @@ def test_accuracy(self):
     def test_ignored_ids(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             shutil.copytree(f"{DATA_DIR}/ignore", tmpdir, dirs_exist_ok=True)
-            cli.main_cli(["accuracy", "--project-dir", tmpdir, "allison", "adam"])
+            cli.main_cli(["accuracy", "--project-dir", tmpdir, "--save", "allison", "adam"])
 
             # Only two of the five notes should be considered, and we should have full agreement.
             accuracy_json = common.read_json(f"{tmpdir}/accuracy-allison-adam.json")
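
For reviewers, here is a minimal standalone sketch (not part of the patch) of the new console-output path added to chart_review/commands/accuracy.py. The Table options (box=None, pad_edge=False) and the add_row pattern mirror the patch; the header names and score values below are hypothetical stand-ins for agree.csv_header() and agree.csv_row_score(..., as_string=False) output.

    import rich
    import rich.table

    # Hypothetical headers and scores; in the patch these come from
    # agree.csv_header() and agree.csv_row_score(), which return lists of strings.
    headers = ["F1", "Sens", "Spec", "PPV", "NPV", "TP", "FN", "TN", "FP"]
    overall = ["0.9", "0.9", "0.95", "0.92", "0.93", "9", "1", "19", "1"]
    per_label = {"Cough": ["1.0", "1.0", "1.0", "1.0", "1.0", "4", "0", "10", "0"]}

    # Build a borderless rich table: one column per score header, plus a Label column.
    rich_table = rich.table.Table(*headers, "Label", box=None, pad_edge=False)
    rich_table.add_row(*overall, "*")  # "*" marks the all-labels row
    for label in sorted(per_label):
        rich_table.add_row(*per_label[label], label)
    rich.get_console().print(rich_table)

Printing to the console is now the default behavior, with file output gated behind the new --save flag, so the common case stays read-only.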