feat: add chart range output to accuracy
When running the accuracy command, print a nicer header giving some
basic info like: who is truth and who is merely annotator? How many
charts are in their overlap and what are the ID ranges?

This should give some peace of mind that the stats are covering what
they should be covering.
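
For example, with the README's sample project (same output as the README
update below), the console output now starts like this, with the stats
table following as before:

```shell
$ chart-review accuracy jill jane
Comparing 3 charts (1, 3–4)
Truth: jill
Annotator: jane
```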

When passing --save, we print the header to the console but not to the
written file, which remains just a CSV/JSON of the stats.
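
Roughly what that looks like (a sketch only: the output paths come from
the project directory, so "./" below is just illustrative):

```shell
$ chart-review accuracy jill jane --save
Comparing 3 charts (1, 3–4)
Truth: jill
Annotator: jane

Wrote ./accuracy-jill-jane.json
Wrote ./accuracy-jill-jane.csv
```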

Also:
- If a project didn't specify labels, the docs say we use all found
  labels. And we did, but only when scoring the "all labels" line
  item; we didn't score each found label on its own line. Now that's
  been fixed and we do score each found label separately.
- If an unknown annotator name is passed to the accuracy command, we
  now give a nice understandable error.
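
  For instance (using a made-up annotator name here):

  ```shell
  $ chart-review accuracy jill nobody
  Unrecognized annotator 'nobody'
  ```
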
mikix committed Jun 6, 2024
1 parent 41fcc7d commit 529db7a
Showing 10 changed files with 105 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
/.idea/
/dist/
/site/
__pycache__/
5 changes: 4 additions & 1 deletion README.md
@@ -30,7 +30,10 @@ $ ls
config.yaml labelstudio-export.json

$ chart-review accuracy jill jane
accuracy-jill-jane:
Comparing 3 charts (1, 3–4)
Truth: jill
Annotator: jane

F1 Sens Spec PPV NPV Kappa TP FN TN FP Label
0.667 0.75 0.6 0.6 0.75 0.341 3 1 3 2 *
0.667 0.5 1.0 1.0 0.5 0.4 1 1 1 0 Cough
6 changes: 2 additions & 4 deletions chart_review/cohort.py
@@ -93,14 +93,12 @@ def calc_label_freq(self, annotator) -> dict:
def calc_term_label_confusion(self, annotator) -> dict:
return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator))

def _select_labels(self, label_pick: str = None) -> Optional[Iterable[str]]:
def _select_labels(self, label_pick: str = None) -> Iterable[str]:
if label_pick:
guard_in(label_pick, self.class_labels)
return [label_pick]
elif self.class_labels:
return self.class_labels
else:
return None
return self.class_labels

def confusion_matrix(
self, truth: str, annotator: str, note_range: Iterable, label_pick: str = None
21 changes: 17 additions & 4 deletions chart_review/commands/accuracy.py
@@ -5,7 +5,7 @@
import rich
import rich.table

from chart_review import agree, cohort, common
from chart_review import agree, cohort, common, console_utils


def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool = False) -> None:
@@ -19,6 +19,13 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool
:param annotator: the other annotator to compare against truth
:param save: whether to write the results to disk vs just printing them
"""
if truth not in reader.note_range:
print(f"Unrecognized annotator '{truth}'")
return
if annotator not in reader.note_range:
print(f"Unrecognized annotator '{annotator}'")
return

# Grab the intersection of ranges
note_range = set(reader.note_range[truth])
note_range &= set(reader.note_range[annotator])
@@ -32,17 +32,23 @@ def accuracy(reader: cohort.CohortReader, truth: str, annotator: str, save: bool
reader.confusion_matrix(truth, annotator, note_range, label)
)

result_name = f"accuracy-{truth}-{annotator}"
note_count = len(note_range)
chart_word = "chart" if note_count == 1 else "charts"
pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else ""
print(f"Comparing {note_count} {chart_word}{pretty_ranges}")
print(f"Truth: {truth}")
print(f"Annotator: {annotator}")
print()

if save:
# Write the results out to disk
output_stem = os.path.join(reader.project_dir, result_name)
output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
common.write_json(f"{output_stem}.json", table)
print(f"Wrote {output_stem}.json")
common.write_text(f"{output_stem}.csv", agree.csv_table(table, reader.class_labels))
print(f"Wrote {output_stem}.csv")
else:
# Print the results out to the console
print(f"{result_name}:")
rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
rich_table.add_row(*agree.csv_row_score(table), "*")
for label in sorted(reader.class_labels):
32 changes: 6 additions & 26 deletions chart_review/commands/info.py
@@ -4,7 +4,7 @@
import rich.box
import rich.table

from chart_review import cohort
from chart_review import cohort, console_utils


def info(reader: cohort.CohortReader) -> None:
@@ -28,7 +28,11 @@ def info(reader: cohort.CohortReader) -> None:
)
for annotator in sorted(reader.note_range):
notes = reader.note_range[annotator]
chart_table.add_row(annotator, str(len(notes)), pretty_note_range(notes))
chart_table.add_row(
annotator,
str(len(notes)),
console_utils.pretty_note_range(notes),
)
console.print(chart_table)
console.print()

@@ -38,27 +42,3 @@ def info(reader: cohort.CohortReader) -> None:
console.print(", ".join(sorted(reader.class_labels, key=str.casefold)))
else:
console.print("None", style="italic", highlight=False)


def pretty_note_range(notes: set[int]) -> str:
ranges = []
range_start = None
prev_note = None

def end_range() -> None:
if prev_note is None:
return
if range_start == prev_note:
ranges.append(str(prev_note))
else:
ranges.append(f"{range_start}–{prev_note}")  # en dash

for note in sorted(notes):
if prev_note is None or prev_note + 1 != note:
end_range()
range_start = note
prev_note = note

end_range()

return ", ".join(ranges)
34 changes: 34 additions & 0 deletions chart_review/console_utils.py
@@ -0,0 +1,34 @@
"""Helper methods for printing to the console."""


def pretty_note_range(notes: set[int]) -> str:
"""
Returns a pretty, human-readable string for a set of notes.
If no notes, this returns an empty string.
Example:
pretty_note_range({1, 2, 3, 7, 9, 10})
-> "1–3, 7, 9–10"
"""
ranges = []
range_start = None
prev_note = None

def end_range() -> None:
if prev_note is None:
return
if range_start == prev_note:
ranges.append(str(prev_note))
else:
ranges.append(f"{range_start}–{prev_note}")  # en dash

for note in sorted(notes):
if prev_note is None or prev_note + 1 != note:
end_range()
range_start = note
prev_note = note

end_range()

return ", ".join(ranges)
4 changes: 4 additions & 0 deletions chart_review/simplify.py
@@ -18,6 +18,7 @@ def simplify_export(
"""
annotations = types.ProjectAnnotations()
annotations.labels = proj_config.class_labels
grab_all_labels = not annotations.labels

for entry in exported_json:
note_id = int(entry.get("id"))
@@ -38,6 +39,9 @@ def simplify_export(
labels |= result_labels
text_tags.append(types.LabeledText(result_text, result_labels))

if grab_all_labels:
annotations.labels |= labels

# Store these mentions in the main annotations list, by author & note
annotator = proj_config.annotators[completed_by]
annotator_mentions = annotations.mentions.setdefault(annotator, types.Mentions())
5 changes: 4 additions & 1 deletion docs/accuracy.md
@@ -18,7 +18,10 @@ your accuracy scores will be printed to the console.

```shell
$ chart-review accuracy jill jane
accuracy-jill-jane:
Comparing 3 charts (1, 3–4)
Truth: jill
Annotator: jane

F1 Sens Spec PPV NPV Kappa TP FN TN FP Label
0.667 0.75 0.6 0.6 0.75 0.341 3 1 3 2 *
0.667 0.5 1.0 1.0 0.5 0.4 1 1 1 0 Cough
2 changes: 1 addition & 1 deletion tests/test_agree.py
@@ -26,7 +26,7 @@ class TestAgreement(unittest.TestCase):
(
"bob",
"alice",
None,
{},
{
"FN": [{1: "Headache"}, {2: "Cough"}],
"FP": [{1: "Cough"}],
31 changes: 31 additions & 0 deletions tests/test_cohort.py
@@ -16,6 +16,37 @@ def setUp(self):
super().setUp()
self.maxDiff = None

def test_no_specified_label(self):
"""Verify that no label setup grabs all found labels from the export."""
with tempfile.TemporaryDirectory() as tmpdir:
common.write_json(
f"{tmpdir}/config.json",
{
"annotators": {"bob": 1, "alice": 2},
},
)
common.write_json(
f"{tmpdir}/labelstudio-export.json",
[
{
"id": 1,
"annotations": [
{
"completed_by": 1,
"result": [
{
"value": {"labels": ["Label A", "Label B"]},
}
],
},
],
},
],
)
reader = cohort.CohortReader(config.ProjectConfig(tmpdir))

self.assertEqual({"Label A", "Label B"}, reader.class_labels)

def test_ignored_ids(self):
reader = cohort.CohortReader(config.ProjectConfig(f"{DATA_DIR}/ignore"))

