Merge pull request #35 from smart-on-fhir/mikix/accuracy-ranges
feat: add chart range output to `accuracy`
mikix authored Jun 7, 2024
2 parents 41fcc7d + 529db7a commit f31f098
Showing 10 changed files with 105 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
 /.idea/
+/dist/
+/site/
 __pycache__/
5 changes: 4 additions & 1 deletion README.md
@@ -30,7 +30,10 @@ $ ls
 config.yaml labelstudio-export.json

 $ chart-review accuracy jill jane
-accuracy-jill-jane:
+Comparing 3 charts (1, 3–4)
+Truth: jill
+Annotator: jane
+
 F1     Sens   Spec   PPV    NPV    Kappa  TP  FN  TN  FP  Label
 0.667  0.75   0.6    0.6    0.75   0.341  3   1   3   2   *
 0.667  0.5    1.0    1.0    0.5    0.4    1   1   1   0   Cough
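As a sanity check on the `*` row above (TP=3, FN=1, TN=3, FP=2), the printed scores match the usual confusion-matrix formulas. The definitions below are the standard ones, assumed rather than taken from chart-review's source, but the numbers bear them out:

```python
# Sanity check of the "*" row above; assumes the standard metric definitions.
TP, FN, TN, FP = 3, 1, 3, 2
N = TP + FN + TN + FP  # 9 label decisions total

sens = TP / (TP + FN)               # 0.75
spec = TN / (TN + FP)               # 0.6
ppv = TP / (TP + FP)                # 0.6
npv = TN / (TN + FN)                # 0.75
f1 = 2 * ppv * sens / (ppv + sens)  # 0.667

# Cohen's kappa: observed agreement vs agreement expected by chance
po = (TP + TN) / N
pe = ((TP + FN) * (TP + FP) + (TN + FP) * (TN + FN)) / N**2
kappa = (po - pe) / (1 - pe)

print(round(f1, 3), round(kappa, 3))  # 0.667 0.341
```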
6 changes: 2 additions & 4 deletions chart_review/cohort.py
@@ -93,14 +93,12 @@ def calc_label_freq(self, annotator) -> dict:
     def calc_term_label_confusion(self, annotator) -> dict:
         return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator))

-    def _select_labels(self, label_pick: str = None) -> Optional[Iterable[str]]:
+    def _select_labels(self, label_pick: str = None) -> Iterable[str]:
         if label_pick:
             guard_in(label_pick, self.class_labels)
             return [label_pick]
-        elif self.class_labels:
-            return self.class_labels
         else:
-            return None
+            return self.class_labels

     def confusion_matrix(
         self, truth: str, annotator: str, note_range: Iterable, label_pick: str = None
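This change tightens `_select_labels` from `Optional[Iterable[str]]` to a possibly-empty iterable. A sketch of what that buys a caller, using an invented `FakeReader` stand-in (not from the repo):

```python
from typing import Iterable, Optional


class FakeReader:
    """Hypothetical stand-in for CohortReader, to show the contract change."""

    class_labels: set = set()  # e.g. nothing configured and nothing found

    def select_labels_old(self) -> Optional[Iterable[str]]:
        return self.class_labels or None  # old contract: may be None

    def select_labels_new(self) -> Iterable[str]:
        return self.class_labels  # new contract: possibly-empty iterable


reader = FakeReader()

# The old contract forced a None guard before iterating:
for label in reader.select_labels_old() or []:
    print(label)

# The new contract lets callers iterate directly:
for label in reader.select_labels_new():
    print(label)
```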
21 changes: 17 additions & 4 deletions chart_review/commands/accuracy.py
@@ -5,7 +5,7 @@
 import rich
 import rich.table

-from chart_review import agree, cohort, common
+from chart_review import agree, cohort, common, console_utils


 def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None:
@@ -19,6 +19,13 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool
     :param annotator: the other annotator to compare against truth
     :param save: whether to write the results to disk vs just printing them
     """
+    if truth not in reader.note_range:
+        print(f"Unrecognized annotator '{truth}'")
+        return
+    if annotator not in reader.note_range:
+        print(f"Unrecognized annotator '{annotator}'")
+        return
+
     # Grab the intersection of ranges
     note_range = set(reader.note_range[truth])
     note_range &= set(reader.note_range[annotator])
@@ -32,17 +39,23 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool
             reader.confusion_matrix(truth, annotator, note_range, label)
         )

-    result_name = f"accuracy-{truth}-{annotator}"
+    note_count = len(note_range)
+    chart_word = "chart" if note_count == 1 else "charts"
+    pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else ""
+    print(f"Comparing {note_count} {chart_word}{pretty_ranges}")
+    print(f"Truth: {truth}")
+    print(f"Annotator: {annotator}")
+    print()
+
     if save:
         # Write the results out to disk
-        output_stem = os.path.join(reader.project_dir, result_name)
+        output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
         common.write_json(f"{output_stem}.json", table)
         print(f"Wrote {output_stem}.json")
         common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
         print(f"Wrote {output_stem}.csv")
     else:
         # Print the results out to the console
-        print(f"{result_name}:")
         rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
         rich_table.add_row(*agree.csv_row_score(table), "*")
         for label in sorted(reader.class_labels):
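One visible effect of the new guard: a mistyped annotator name now prints a message and returns cleanly instead of raising a KeyError on `reader.note_range`. An illustrative session (the name `nobody` is made up):

```shell
$ chart-review accuracy jill nobody
Unrecognized annotator 'nobody'
```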
32 changes: 6 additions & 26 deletions chart_review/commands/info.py
@@ -4,7 +4,7 @@
 import rich.box
 import rich.table

-from chart_review import cohort
+from chart_review import cohort, console_utils


 def info(reader: cohort.CohortReader) -> None:
@@ -28,7 +28,11 @@
     )
     for annotator in sorted(reader.note_range):
         notes = reader.note_range[annotator]
-        chart_table.add_row(annotator, str(len(notes)), pretty_note_range(notes))
+        chart_table.add_row(
+            annotator,
+            str(len(notes)),
+            console_utils.pretty_note_range(notes),
+        )
     console.print(chart_table)
     console.print()

@@ -38,27 +42,3 @@ def info(reader: cohort.CohortReader) -> None:
console.print(", ".join(sorted(reader.class_labels, key=str.casefold)))
else:
console.print("None", style="italic", highlight=False)


def pretty_note_range(notes: set[int]) -> str:
ranges = []
range_start = None
prev_note = None

def end_range() -> None:
if prev_note is None:
return
if range_start == prev_note:
ranges.append(str(prev_note))
else:
ranges.append(f"{range_start}{prev_note}") # en dash

for note in sorted(notes):
if prev_note is None or prev_note + 1 != note:
end_range()
range_start = note
prev_note = note

end_range()

return ", ".join(ranges)
34 changes: 34 additions & 0 deletions chart_review/console_utils.py
@@ -0,0 +1,34 @@
"""Helper methods for printing to the console."""


def pretty_note_range(notes: set[int]) -> str:
"""
Returns a pretty, human-readable string for a set of notes.
If no notes, this returns an empty string.
Example:
pretty_note_range({1, 2, 3, 7, 9, 10})
-> "1–3, 7, 9–10"
"""
ranges = []
range_start = None
prev_note = None

def end_range() -> None:
if prev_note is None:
return
if range_start == prev_note:
ranges.append(str(prev_note))
else:
ranges.append(f"{range_start}{prev_note}") # en dash

for note in sorted(notes):
if prev_note is None or prev_note + 1 != note:
end_range()
range_start = note
prev_note = note

end_range()

return ", ".join(ranges)
4 changes: 4 additions & 0 deletions chart_review/simplify.py
@@ -18,6 +18,7 @@ def simplify_export(
"""
annotations = types.ProjectAnnotations()
annotations.labels = proj_config.class_labels
grab_all_labels = not annotations.labels

for entry in exported_json:
note_id = int(entry.get("id"))
@@ -38,6 +39,9 @@ def simplify_export(
                 labels |= result_labels
                 text_tags.append(types.LabeledText(result_text, result_labels))

+        if grab_all_labels:
+            annotations.labels |= labels
+
         # Store these mentions in the main annotations list, by author & note
         annotator = proj_config.annotators[completed_by]
         annotator_mentions = annotations.mentions.setdefault(annotator, types.Mentions())
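The `grab_all_labels` flag means a project config with no declared labels falls back to collecting every label seen in the export. A standalone sketch of that pattern with made-up data (not the real `simplify_export` loop, which does this per note and annotator):

```python
# Standalone sketch of the grab-all-labels fallback, with made-up data.
configured: set[str] = set()      # like proj_config.class_labels when empty
grab_all_labels = not configured  # no labels configured -> collect all seen

for result_labels in [{"Cough"}, {"Cough", "Fever"}]:  # labels per annotation
    if grab_all_labels:
        configured |= result_labels

print(sorted(configured))  # ['Cough', 'Fever']
```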
5 changes: 4 additions & 1 deletion docs/accuracy.md
@@ -18,7 +18,10 @@ your accuracy scores will be printed to the console.

 ```shell
 $ chart-review accuracy jill jane
-accuracy-jill-jane:
+Comparing 3 charts (1, 3–4)
+Truth: jill
+Annotator: jane
+
 F1     Sens   Spec   PPV    NPV    Kappa  TP  FN  TN  FP  Label
 0.667  0.75   0.6    0.6    0.75   0.341  3   1   3   2   *
 0.667  0.5    1.0    1.0    0.5    0.4    1   1   1   0   Cough
2 changes: 1 addition & 1 deletion tests/test_agree.py
@@ -26,7 +26,7 @@ class TestAgreement(unittest.TestCase):
             (
                 "bob",
                 "alice",
-                None,
+                {},
                 {
                     "FN": [{1: "Headache"}, {2: "Cough"}],
                     "FP": [{1: "Cough"}],
31 changes: 31 additions & 0 deletions tests/test_cohort.py
@@ -16,6 +16,37 @@ def setUp(self):
         super().setUp()
         self.maxDiff = None

+    def test_no_specified_label(self):
+        """Verify that no label setup grabs all found labels from the export."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            common.write_json(
+                f"{tmpdir}/config.json",
+                {
+                    "annotators": {"bob": 1, "alice": 2},
+                },
+            )
+            common.write_json(
+                f"{tmpdir}/labelstudio-export.json",
+                [
+                    {
+                        "id": 1,
+                        "annotations": [
+                            {
+                                "completed_by": 1,
+                                "result": [
+                                    {
+                                        "value": {"labels": ["Label A", "Label B"]},
+                                    }
+                                ],
+                            },
+                        ],
+                    },
+                ],
+            )
+            reader = cohort.CohortReader(config.ProjectConfig(tmpdir))
+
+            self.assertEqual({"Label A", "Label B"}, reader.class_labels)
+
     def test_ignored_ids(self):
         reader = cohort.CohortReader(config.ProjectConfig(f"{DATA_DIR}/ignore"))

