Merge pull request #45 from smart-on-fhir/mikix/csv-arg
feat: add --csv flag to all existing commands
mikix authored Jun 24, 2024
2 parents 60f38f4 + 3206698 commit 032f08e
Showing 13 changed files with 406 additions and 167 deletions.
40 changes: 40 additions & 0 deletions chart_review/cli_utils.py
@@ -1,6 +1,11 @@
"""Helper methods for CLI parsing."""

import argparse
import csv
import sys

import rich.box
import rich.table

from chart_review import cohort, config

@@ -26,6 +31,41 @@ def add_project_args(parser: argparse.ArgumentParser, is_global: bool = False) -
)


def add_output_args(parser: argparse.ArgumentParser):
"""Returns an exclusive option group if you want to add custom output arguments"""
group = parser.add_argument_group("output")
exclusive = group.add_mutually_exclusive_group()
exclusive.add_argument("--csv", action="store_true", help="print results in CSV format")
return exclusive


def get_cohort_reader(args: argparse.Namespace) -> cohort.CohortReader:
proj_config = config.ProjectConfig(project_dir=args.project_dir, config_path=args.config)
return cohort.CohortReader(proj_config)


def create_table(*headers) -> rich.table.Table:
"""
Creates a table with standard chart-review formatting.
You can use your own table formatting if you have particular needs,
but this should be your default table creator.
"""
table = rich.table.Table(box=rich.box.ROUNDED)
for header in headers:
table.add_column(header, overflow="fold")
return table


def print_table_as_csv(table: rich.table.Table) -> None:
"""Prints a Rich table as a CSV to stdout"""
writer = csv.writer(sys.stdout)

# First the headers
headers = [str(col.header).lower().replace(" ", "_") for col in table.columns]
writer.writerow(headers)

# And then each row
cells_by_row = zip(*[col.cells for col in table.columns])
for row in cells_by_row:
writer.writerow(row)
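
Taken together, these helpers give every subcommand a uniform output path: build a Rich table, then either render it to the console or dump it as CSV. Here is a minimal sketch of how a new subcommand might wire them up (the `example` command and its row data are hypothetical; the pattern mirrors the `ids` and `labels` commands below):

```python
import argparse

import rich

from chart_review import cli_utils


def make_subparser(parser: argparse.ArgumentParser) -> None:
    cli_utils.add_project_args(parser)
    cli_utils.add_output_args(parser)  # provides args.csv
    parser.set_defaults(func=print_example)


def print_example(args: argparse.Namespace) -> None:
    # "Chart ID" becomes the CSV header "chart_id" via print_table_as_csv()
    table = cli_utils.create_table("Chart ID", "Label")
    table.add_row("1", "Cough")  # hypothetical row data

    if args.csv:
        cli_utils.print_table_as_csv(table)
    else:
        rich.get_console().print(table)
```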
112 changes: 55 additions & 57 deletions chart_review/commands/accuracy.py
@@ -8,27 +8,29 @@
import rich.table
import rich.text

from chart_review import agree, cli_utils, cohort, common, config, console_utils
from chart_review import agree, cli_utils, common, console_utils


def accuracy(
reader: cohort.CohortReader,
truth: str,
annotator: str,
save: bool = False,
verbose: bool = False,
) -> None:
def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
output_group = cli_utils.add_output_args(parser)
output_group.add_argument("--save", action="store_true", help=argparse.SUPPRESS)
parser.add_argument("--verbose", action="store_true", help="show each chart’s labels")
parser.add_argument("truth_annotator")
parser.add_argument("annotator")
parser.set_defaults(func=print_accuracy)


def print_accuracy(args: argparse.Namespace) -> None:
"""
High-level accuracy calculation between two annotators.
The results will be written to the project directory.
:param reader: the cohort configuration
:param truth: the truth annotator
:param annotator: the other annotator to compare against truth
:param save: whether to write the results to disk vs just printing them
:param verbose: whether to print per-chart/per-label classifications
"""
reader = cli_utils.get_cohort_reader(args)
truth = args.truth_annotator
annotator = args.annotator

if truth not in reader.note_range:
print(f"Unrecognized annotator '{truth}'")
return
@@ -48,64 +50,60 @@ def accuracy(
matrices[label] = reader.confusion_matrix(truth, annotator, note_range, label)

# Now score them
scores = agree.score_matrix(matrices[None])
scores = {None: agree.score_matrix(matrices[None])}
for label in labels:
scores[label] = agree.score_matrix(matrices[label])

console = rich.get_console()

note_count = len(note_range)
chart_word = "chart" if note_count == 1 else "charts"
pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else ""
console.print(f"Comparing {note_count} {chart_word}{pretty_ranges}")
console.print(f"Truth: {truth}")
console.print(f"Annotator: {annotator}")

console.print()
if save:
# Write the results out to disk
output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")
common.write_json(f"{output_stem}.json", scores)
console.print(f"Wrote {output_stem}.json")
common.write_text(f"{output_stem}.csv", agree.csv_table(scores, reader.class_labels))
console.print(f"Wrote {output_stem}.csv")
else:
# Print the results out to the console
rich_table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
rich_table.add_row(*agree.csv_row_score(scores), "*")
for label in labels:
rich_table.add_row(*agree.csv_row_score(scores[label]), label)
console.print(rich_table)

if verbose:
if args.verbose:
# Print a table of each chart/label combo - useful for reviewing where an annotator
# went wrong.
verbose_table = rich.table.Table(
"Chart ID", "Label", "Classification", box=rich.box.ROUNDED
)
table = cli_utils.create_table("Chart ID", "Label", "Classification")
for note_id in sorted(note_range):
verbose_table.add_section()
table.add_section()
for label in labels:
for classification in ["TN", "TP", "FN", "FP"]:
if {note_id: label} in matrices[label][classification]:
style = "bold" if classification[0] == "F" else None # highlight errors
class_text = rich.text.Text(classification, style=style)
verbose_table.add_row(str(note_id), label, class_text)
table.add_row(str(note_id), label, class_text)
break
console.print()
console.print(verbose_table)
else:
# Normal F1/Kappa scores
table = rich.table.Table(*agree.csv_header(), "Label", box=None, pad_edge=False)
table.add_row(*agree.csv_row_score(scores[None]), "*")
for label in labels:
table.add_row(*agree.csv_row_score(scores[label]), label)

if args.csv:
cli_utils.print_table_as_csv(table)
return

def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
parser.add_argument("--save", action="store_true", help="Write stats to CSV & JSON files")
parser.add_argument("--verbose", action="store_true", help="Explain each chart’s labels")
parser.add_argument("truth_annotator")
parser.add_argument("annotator")
parser.set_defaults(func=run_accuracy)
# OK we aren't printing a CSV file to stdout, so we can include a bit more explanation
# as a little header to the real results.
note_count = len(note_range)
chart_word = "chart" if note_count == 1 else "charts"
pretty_ranges = f" ({console_utils.pretty_note_range(note_range)})" if note_count > 0 else ""
console.print(f"Comparing {note_count} {chart_word}{pretty_ranges}")
console.print(f"Truth: {truth}")
console.print(f"Annotator: {annotator}")
console.print()

if args.save: # deprecated/hidden since 2.0, but still supported for now
output_stem = os.path.join(reader.project_dir, f"accuracy-{truth}-{annotator}")

def run_accuracy(args: argparse.Namespace) -> None:
proj_config = config.ProjectConfig(args.project_dir, config_path=args.config)
reader = cohort.CohortReader(proj_config)
accuracy(reader, args.truth_annotator, args.annotator, save=args.save, verbose=args.verbose)
# JSON: Historically, this has been formatted with the global label results intermixed
# with the specific label names, so reproduce that historical formatting here.
# Note: this could bite us if the user ever has a label like "Kappa", which is why the
# above code avoids intermixing, but we'll keep this as-is for now.
scores.update(scores[None])
del scores[None]
common.write_json(f"{output_stem}.json", scores)
console.print(f"Wrote {output_stem}.json")

# CSV: we should really use a .tsv suffix here, but keeping .csv for historical reasons
common.write_text(f"{output_stem}.csv", agree.csv_table(scores, reader.class_labels))
console.print(f"Wrote {output_stem}.csv")
else:
console.print(table)
22 changes: 12 additions & 10 deletions chart_review/commands/ids.py
@@ -1,12 +1,13 @@
import argparse
import csv
import sys

import rich.table

from chart_review import cli_utils


def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
cli_utils.add_output_args(parser)
parser.set_defaults(func=print_ids)


@@ -20,10 +21,8 @@ def print_ids(args: argparse.Namespace) -> None:
"""
reader = cli_utils.get_cohort_reader(args)

writer = csv.writer(sys.stdout)
writer.writerow(["chart_id", "original_fhir_id", "anonymized_fhir_id"])
table = cli_utils.create_table("Chart ID", "Original FHIR ID", "Anonymized FHIR ID")

# IDS
for chart in reader.ls_export:
chart_id = str(chart["id"])
chart_data = chart.get("data", {})
@@ -33,18 +32,21 @@
orig_id = f"Encounter/{chart_data['enc_id']}" if "enc_id" in chart_data else ""
anon_id = f"Encounter/{chart_data['anon_id']}" if "anon_id" in chart_data else ""
if orig_id or anon_id:
writer.writerow([chart_id, orig_id, anon_id])
table.add_row(chart_id, orig_id, anon_id)
printed = True

# Now each DocRef ID
for orig_id, anon_id in chart_data.get("docref_mappings", {}).items():
writer.writerow(
[chart_id, f"DocumentReference/{orig_id}", f"DocumentReference/{anon_id}"]
)
table.add_row(chart_id, f"DocumentReference/{orig_id}", f"DocumentReference/{anon_id}")
printed = True

if not printed:
# Guarantee that every Chart ID shows up at least once - so it's clearer that the
# chart ID is included in the Label Studio export but that it does not have any
# IDs mapped to it.
writer.writerow([chart_id, None, None])
table.add_row(chart_id, None, None)

if args.csv:
cli_utils.print_table_as_csv(table)
else:
rich.get_console().print(table)
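
With this change, `chart-review ids --csv` should still emit the exact header row the old `csv.writer` code produced, since `print_table_as_csv()` lowercases the column names and replaces spaces with underscores (rows elided; they depend on the Label Studio export):

```shell
$ chart-review ids --csv
chart_id,original_fhir_id,anonymized_fhir_id
...
```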
15 changes: 7 additions & 8 deletions chart_review/commands/labels.py
@@ -10,6 +10,7 @@

def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
cli_utils.add_output_args(parser)
parser.set_defaults(func=print_labels)


@@ -28,12 +29,7 @@ def print_labels(args: argparse.Namespace) -> None:
label_notes[annotator][name] = note_ids
any_annotator_note_sets.setdefault(name, types.NoteSet()).update(note_ids)

label_table = rich.table.Table(
"Annotator",
"Chart Count",
"Label",
box=rich.box.ROUNDED,
)
label_table = cli_utils.create_table("Annotator", "Chart Count", "Label")

# First add summary entries, for counts across the union of all annotators
for name in label_names:
@@ -47,5 +43,8 @@
count = str(len(note_set))
label_table.add_row(annotator, count, name)

rich.get_console().print(label_table)
console_utils.print_ignored_charts(reader)
if args.csv:
cli_utils.print_table_as_csv(label_table)
else:
rich.get_console().print(label_table)
console_utils.print_ignored_charts(reader)
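
The `labels` command follows the same pattern; its CSV headers derive from the table columns above (rows elided; they depend on the project's annotators and labels):

```shell
$ chart-review labels --csv
annotator,chart_count,label
...
```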
47 changes: 33 additions & 14 deletions docs/accuracy.md
@@ -31,13 +31,7 @@ F1 Sens Spec PPV NPV Kappa TP FN TN FP Label

## Options

### `--save`

Use this to write a JSON and CSV file to the project directory,
rather than printing to the console.
Useful for passing results around in a machine-parsable format.

### `--verbose`
### `--verbose`

Use this to also print out a table of per-chart/per-label classifications.
This is helpful for investigating where specifically the two annotators agreed or not.
@@ -50,12 +44,6 @@ Comparing 3 charts (1, 3–4)
Truth: jill
Annotator: jane

F1 Sens Spec PPV NPV Kappa TP FN TN FP Label
0.667 0.75 0.6 0.6 0.75 0.341 3 1 3 2 *
0.667 0.5 1.0 1.0 0.5 0.4 1 1 1 0 Cough
1.0 1.0 1.0 1.0 1.0 1.0 2 0 1 0 Fatigue
0 0 0 0 0 0 0 0 1 2 Headache

╭──────────┬──────────┬────────────────╮
│ Chart ID │ Label │ Classification │
├──────────┼──────────┼────────────────┤
@@ -71,4 +59,35 @@ F1 Sens Spec PPV NPV Kappa TP FN TN FP Label
│ 4 │ Fatigue │ TP │
│ 4 │ Headache │ FP │
╰──────────┴──────────┴────────────────╯
```

### `--csv`

Print the accuracy table in a machine-parsable CSV format.

It can be used with either the default or `--verbose` mode.

#### Examples

```shell
$ chart-review accuracy jill jane --csv
f1,sens,spec,ppv,npv,kappa,tp,fn,tn,fp,label
0.667,0.75,0.6,0.6,0.75,0.341,3,1,3,2,*
0.667,0.5,1.0,1.0,0.5,0.4,1,1,1,0,Cough
1.0,1.0,1.0,1.0,1.0,1.0,2,0,1,0,Fatigue
0,0,0,0,0,0,0,0,1,2,Headache
```

```shell
$ chart-review accuracy jill jane --verbose --csv
chart_id,label,classification
1,Cough,TP
1,Fatigue,TP
1,Headache,FP
3,Cough,TN
3,Fatigue,TN
3,Headache,TN
4,Cough,FN
4,Fatigue,TP
4,Headache,FP
```