Merge pull request #19 from smart-on-fhir/mikix/print-by-default
feat!(accuracy): print results by default
mikix authored Jan 18, 2024
2 parents 6d943e2 + d95d2a5 commit c55f057
Showing 7 changed files with 44 additions and 20 deletions.
16 changes: 12 additions & 4 deletions chart_review/agree.py
@@ -1,6 +1,7 @@
 from collections.abc import Collection, Iterable
+from typing import Union
 
-from chart_review import simplify, types
+from chart_review import types
 
 
 def confusion_matrix(
@@ -157,10 +158,10 @@ def score_reviewer(
 def csv_table(score: dict, class_labels: Iterable):
     table = list()
     table.append(csv_header(False, True))
-    table.append(csv_row_score(score))
+    table.append(csv_row_score(score, as_string=True))
 
     for label in sorted(class_labels):
-        table.append(csv_row_score(score[label], label))
+        table.append(csv_row_score(score[label], label, as_string=True))
     return "\n".join(table) + "\n"


@@ -181,16 +182,23 @@ def csv_header(pick_label=False, as_string=False):
     return "\t".join(header)
 
 
-def csv_row_score(score: dict, pick_label=None) -> str:
+def csv_row_score(
+    score: dict, pick_label: str = None, as_string: bool = False
+) -> Union[str, list[str]]:
     """
     Table Row entry
     F1, PPV (precision), Recall (sensitivity), True Pos, False Pos, False Neg
     :param score: dict result from F1 scoring
     :param pick_label: default= None means '*' all classes
+    :param as_string: whether to return a list of string scores or one single string
     :return: str representation of the score
     """
     row = [score[header] for header in csv_header()]
     row = [str(value) for value in row]
 
+    if not as_string:
+        return row
+
     row.append(pick_label if pick_label else "*")
     return "\t".join(row)

3 changes: 2 additions & 1 deletion chart_review/cli.py
@@ -49,6 +49,7 @@ def define_parser() -> argparse.ArgumentParser:
 def add_accuracy_subparser(subparsers) -> None:
     parser = subparsers.add_parser("accuracy")
     add_project_args(parser)
+    parser.add_argument("--save", action="store_true", default=False)
     parser.add_argument("truth_annotator")
     parser.add_argument("annotator")
     parser.set_defaults(func=run_accuracy)
@@ -57,7 +58,7 @@ def add_accuracy_subparser(subparsers) -> None:
 def run_accuracy(args: argparse.Namespace) -> None:
     proj_config = config.ProjectConfig(args.project_dir, config_path=args.config)
     reader = cohort.CohortReader(proj_config)
-    accuracy(reader, args.truth_annotator, args.annotator)
+    accuracy(reader, args.truth_annotator, args.annotator, save=args.save)
 
 
 ###############################################################################
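
With the new flag in place, accuracy results print to the console unless --save is passed. A short usage sketch mirroring the updated tests (the annotator names come from the test fixtures; the project path is a placeholder):

from chart_review import cli

# New default: print the accuracy table to the console.
cli.main_cli(["accuracy", "--project-dir", "/path/to/project", "jill", "jane"])

# Opt into the old behavior: write accuracy-jill-jane.json and .csv into the project dir.
cli.main_cli(["accuracy", "--project-dir", "/path/to/project", "--save", "jill", "jane"])
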
4 changes: 2 additions & 2 deletions chart_review/cohort.py
@@ -130,11 +130,11 @@ def score_reviewer_table_csv(self, truth: str, annotator: str, note_range) -> st
         table.append(agree.csv_header(False, True))
 
         score = self.score_reviewer(truth, annotator, note_range)
-        table.append(agree.csv_row_score(score))
+        table.append(agree.csv_row_score(score, as_string=True))
 
         for label in self.class_labels:
             score = self.score_reviewer(truth, annotator, note_range, label)
-            table.append(agree.csv_row_score(score, label))
+            table.append(agree.csv_row_score(score, label, as_string=True))
 
         return "\n".join(table) + "\n"

32 changes: 23 additions & 9 deletions chart_review/commands/accuracy.py
@@ -2,10 +2,13 @@
 
 import os
 
+import rich
+import rich.table
+
 from chart_review import agree, cohort, common
 
 
-def accuracy(reader: cohort.CohortReader, truth: str, annotator: str) -> None:
+def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None:
     """
     High-level accuracy calculation between two annotators.
@@ -14,10 +17,11 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str) -> None:
     :param reader: the cohort configuration
     :param truth: the truth annotator
     :param annotator: the other annotator to compare against truth
+    :param save: whether to write the results to disk vs just printing them
     """
     # Grab the intersection of ranges
-    note_range = set(reader.config.note_ranges[truth])
-    note_range &= set(reader.config.note_ranges[annotator])
+    note_range = set(reader.note_range[truth])
+    note_range &= set(reader.note_range[annotator])
 
     # All labels first
     table = agree.score_matrix(reader.confusion_matrix(truth, annotator, note_range))
@@ -28,9 +32,19 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str) -> None:
             reader.confusion_matrix(truth, annotator, note_range, label)
         )
 
-    # And write out the results
-    output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
-    common.write_json(f"{output_stem}.json", table)
-    print(f"Wrote {output_stem}.json")
-    common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
-    print(f"Wrote {output_stem}.csv")
+    result_name = f"accuracy-{truth}-{annotator}"
+    if save:
+        # Write the results out to disk
+        output_stem = os.path.join(reader.project_dir, result_name)
+        common.write_json(f"{output_stem}.json", table)
+        print(f"Wrote {output_stem}.json")
+        common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
+        print(f"Wrote {output_stem}.csv")
+    else:
+        # Print the results out to the console
+        print(f"{result_name}:")
+        rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
+        rich_table.add_row(*agree.csv_row_score(table), "*")
+        for label in sorted(reader.class_labels):
+            rich_table.add_row(*agree.csv_row_score(table[label]), label)
+        rich.get_console().print(rich_table)
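
The console path above builds a borderless rich table whose columns come straight from csv_header(). A standalone sketch of the same rich calls, with made-up column names and values purely for illustration:

import rich
import rich.table

# Column headers as positional args; box=None and pad_edge=False keep the output plain.
demo = rich.table.Table("F1", "Recall", "Label", box=None, pad_edge=False)
demo.add_row("0.91", "0.88", "*")  # one string per column
demo.add_row("0.75", "0.70", "cough")
rich.get_console().print(demo)
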
4 changes: 2 additions & 2 deletions chart_review/common.py
@@ -1,6 +1,6 @@
"""Utility methods"""
from enum import Enum, EnumMeta
from typing import Optional
from typing import Optional, Union
from collections.abc import Iterable
import logging
import json
@@ -10,7 +10,7 @@
 ###############################################################################
 
 
-def read_json(path: str) -> dict | list[dict]:
+def read_json(path: str) -> Union[dict, list[dict]]:
     """
     Reads json from a file
     :param path: filesystem path
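
The move from the | syntax back to typing.Union is a compatibility fix rather than a style choice: pyproject.toml declares requires-python = ">= 3.9", and X | Y annotations only evaluate at runtime on Python 3.10+ (PEP 604). An illustrative sketch with hypothetical helper names:

from typing import Union

# Fine on Python 3.9 and later.
def read_json_union(path: str) -> Union[dict, list[dict]]:
    ...

# Equivalent on 3.10+, but the annotation itself raises TypeError when defined on 3.9:
# def read_json_pipe(path: str) -> dict | list[dict]:
#     ...
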
1 change: 1 addition & 0 deletions pyproject.toml
@@ -5,6 +5,7 @@ requires-python = ">= 3.9"
 dependencies = [
     "ctakesclient",
     "pyyaml >= 6",
+    "rich",
 ]
 description = "Medical Record Chart Review Calculator"
 readme = "README.md"
4 changes: 2 additions & 2 deletions tests/test_cli.py
@@ -20,7 +20,7 @@ def setUp(self):
     def test_accuracy(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             shutil.copytree(f"{DATA_DIR}/cold", tmpdir, dirs_exist_ok=True)
-            cli.main_cli(["accuracy", "--project-dir", tmpdir, "jill", "jane"])
+            cli.main_cli(["accuracy", "--project-dir", tmpdir, "--save", "jill", "jane"])
 
             accuracy_json = common.read_json(f"{tmpdir}/accuracy-jill-jane.json")
             self.assertEqual(
@@ -85,7 +85,7 @@ def test_accuracy(self):
     def test_ignored_ids(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             shutil.copytree(f"{DATA_DIR}/ignore", tmpdir, dirs_exist_ok=True)
-            cli.main_cli(["accuracy", "--project-dir", tmpdir, "allison", "adam"])
+            cli.main_cli(["accuracy", "--project-dir", tmpdir, "--save", "allison", "adam"])
 
             # Only two of the five notes should be considered, and we should have full agreement.
             accuracy_json = common.read_json(f"{tmpdir}/accuracy-allison-adam.json")
