feat: add new frequencies command to show term frequencies #48

Merged 1 commit on Jun 27, 2024

51 changes: 0 additions & 51 deletions chart_review/agree.py
Contributor Author commented:

Just removing some unused code that I noticed.

@@ -151,38 +151,6 @@ def score_matrix(matrix: dict, sig_digits=3) -> dict:
}


def avg_scores(first: dict, second: dict, sig_digits=3) -> dict:
merged = {}
for header in csv_header():
added = first[header] + second[header]
if header in ["TP", "FP", "FN", "TN"]:
merged[header] = added
else:
merged[header] = round(added / 2, sig_digits)
return merged


def score_reviewer(
annotations: types.ProjectAnnotations,
truth: str,
annotator: str,
note_range: Collection[int],
labels: Iterable[str] = None,
) -> dict:
"""
Score reliability of an annotator against a truth annotator.

:param annotations: prepared map of annotators and mentions
:param truth: annotator to use as the ground truth
:param annotator: another annotator to compare with truth
:param note_range: collection of LabelStudio document ID
:param labels: (optional) set of labels to score
:return: dict, keys f1, precision, recall and vals= %score
"""
truth_matrix = confusion_matrix(annotations, truth, annotator, note_range, labels=labels)
return score_matrix(truth_matrix)


def csv_table(score: dict, class_labels: types.LabelSet):
table = list()
table.append(csv_header(False, True))
@@ -229,22 +197,3 @@ def csv_row_score(

row.append(pick_label if pick_label else "*")
return "\t".join(row)


def true_prevalence(prevalence_apparent: float, sensitivity: float, specificity: float):
"""
See paper: "The apparent prevalence, the true prevalence"
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9195606

Using Eq. 4. it can be calculated:
True prevalence = (Apparent prevalence + Sp - 1)/(Se + Sp - 1)

:param prevalence_apparent: estimated prevalence, concretely:
the %NLP labled positives / cohort

:param: sensitivity: of the class label (where prevalence was measured)
:param: specificity: of the class label (where prevalence was measured)

:return: float adjusted prevalence
"""
return round((prevalence_apparent + specificity - 1) / (sensitivity + specificity - 1), 5)
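
Aside (not part of the diff): a quick worked check of the removed true_prevalence() formula, as a standalone snippet. The variable names and numbers below are made up purely for illustration.

# Hypothetical values: apparent prevalence 0.30, sensitivity 0.80, specificity 0.95
apparent, sensitivity, specificity = 0.30, 0.80, 0.95
# True prevalence = (Apparent prevalence + Sp - 1) / (Se + Sp - 1)
true_prev = (apparent + specificity - 1) / (sensitivity + specificity - 1)
print(round(true_prev, 5))  # 0.33333, i.e. 0.25 / 0.75
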
5 changes: 4 additions & 1 deletion chart_review/cli.py
@@ -3,7 +3,7 @@
import argparse
import sys

from chart_review.commands import accuracy, default, ids, labels, mentions
from chart_review.commands import accuracy, default, frequency, ids, labels, mentions


def define_parser() -> argparse.ArgumentParser:
@@ -13,6 +13,9 @@ def define_parser() -> argparse.ArgumentParser:

subparsers = parser.add_subparsers()
accuracy.make_subparser(subparsers.add_parser("accuracy", help="calculate F1 and Kappa scores"))
frequency.make_subparser(
subparsers.add_parser("frequency", help="show counts of each text mention")
)
ids.make_subparser(subparsers.add_parser("ids", help="map Label Studio IDs to FHIR IDs"))
labels.make_subparser(subparsers.add_parser("labels", help="show label usage by annotator"))
mentions.make_subparser(subparsers.add_parser("mentions", help="show each mention of a label"))
57 changes: 2 additions & 55 deletions chart_review/cohort.py
@@ -1,7 +1,7 @@
from typing import Iterable

from chart_review.common import guard_str, guard_iter, guard_in
from chart_review import agree, common, config, errors, external, term_freq, simplify, types
from chart_review.common import guard_iter, guard_in
from chart_review import agree, common, config, errors, external, simplify, types


class CohortReader:
@@ -84,25 +84,6 @@ def _collect_note_ranges(
def class_labels(self):
return self.annotations.labels

def calc_term_freq(self, annotator) -> dict:
"""
Calculate Term Frequency of highlighted mentions.
:param annotator: an annotator name
:return: dict key=TERM val= {label, list of chart_id}
"""
return term_freq.calc_term_freq(self.annotations, guard_str(annotator))

def calc_label_freq(self, annotator) -> dict:
"""
Calculate Term Frequency of highlighted mentions.
:param annotator: an annotator name
:return: dict key=TERM val= {label, list of chart_id}
"""
return term_freq.calc_label_freq(self.calc_term_freq(annotator))

def calc_term_label_confusion(self, annotator) -> dict:
return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator))

def _select_labels(self, label_pick: str = None) -> Iterable[str]:
if label_pick:
guard_in(label_pick, self.class_labels)
@@ -131,37 +112,3 @@ def confusion_matrix(
note_range,
labels=labels,
)

def score_reviewer(self, truth: str, annotator: str, note_range, label_pick: str = None):
"""
Score reliability of rater at the level of all symptom *PREVALENCE*
:param truth: annotator to use as the ground truth
:param annotator: another annotator to compare with truth
:param note_range: default= all in corpus
:param label_pick: (optional) of the CLASS_LABEL to score separately
:return: dict, keys f1, precision, recall and vals= %score
"""
labels = self._select_labels(label_pick)
note_range = set(guard_iter(note_range))
return agree.score_reviewer(self.annotations, truth, annotator, note_range, labels=labels)

def score_reviewer_table_csv(self, truth: str, annotator: str, note_range) -> str:
table = list()
table.append(agree.csv_header(False, True))

score = self.score_reviewer(truth, annotator, note_range)
table.append(agree.csv_row_score(score, as_string=True))

for label in self.class_labels:
score = self.score_reviewer(truth, annotator, note_range, label)
table.append(agree.csv_row_score(score, label, as_string=True))

return "\n".join(table) + "\n"

def score_reviewer_table_dict(self, truth, annotator, note_range) -> dict:
table = self.score_reviewer(truth, annotator, note_range)

for label in self.class_labels:
table[label] = self.score_reviewer(truth, annotator, note_range, label)

return table
2 changes: 1 addition & 1 deletion chart_review/commands/default.py
@@ -26,7 +26,7 @@ def print_info(args: argparse.Namespace) -> None:
notes = reader.note_range[annotator]
chart_table.add_row(
annotator,
str(len(notes)),
f"{len(notes):,}",
console_utils.pretty_note_range(notes),
)

77 changes: 77 additions & 0 deletions chart_review/commands/frequency.py
@@ -0,0 +1,77 @@
import argparse

import rich
import rich.box
import rich.table
import rich.text

from chart_review import cli_utils, console_utils, types


def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
cli_utils.add_output_args(parser)
parser.set_defaults(func=print_frequency)


def print_frequency(args: argparse.Namespace) -> None:
"""
Print counts of each text mention.
"""
reader = cli_utils.get_cohort_reader(args)

frequencies = {} # annotator -> label -> text -> count
all_annotator_frequencies = {} # label -> text -> count
text_labels = {} # text -> labelset (to flag term confusion)
for annotator in reader.annotations.original_text_mentions:
annotator_mentions = reader.annotations.original_text_mentions[annotator]
for labeled_texts in annotator_mentions.values():
for labeled_text in labeled_texts:
text = (labeled_text.text or "").strip().casefold()
for label in labeled_text.labels:
if label in reader.annotations.labels:
# Count the mention for this annotator
label_to_text = frequencies.setdefault(annotator, {})
text_to_count = label_to_text.setdefault(label, {})
text_to_count[text] = text_to_count.get(text, 0) + 1

# Count the mention for our running all-annotators total
all_text_to_count = all_annotator_frequencies.setdefault(label, {})
all_text_to_count[text] = all_text_to_count.get(text, 0) + 1

# And finally, add it to our running term-confusion tracker
text_labels.setdefault(text, types.LabelSet()).add(label)

# Now group up the data into a formatted table
table = cli_utils.create_table("Annotator", "Label", "Mention", "Count")
has_term_confusion = False # whether multiple labels are used for the same text

# Helper method to add all the info for a single annotator to our table
def add_annotator_to_table(name, label_to_text: dict) -> None:
nonlocal has_term_confusion

Review comment:

First time I've seen this keyword - I get the approach, and I think it makes sense in this use case, but it makes me itchy.

Maybe I just need time to come around on it, like our friend the walrus.
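
Aside (not part of the PR): a minimal sketch of what nonlocal does in the helper above - it lets the nested function rebind a name from the enclosing function's scope instead of creating a new local. Standalone toy example, names made up:

def outer():
    flag = False

    def inner():
        nonlocal flag  # rebind outer()'s "flag" instead of shadowing it with a new local
        flag = True

    inner()
    return flag

print(outer())  # prints True, because inner() mutated the enclosing variable
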

table.add_section()
for label in sorted(label_to_text, key=str.casefold):
text_to_count = label_to_text[label]
for text, count in sorted(
text_to_count.items(), key=lambda t: (t[1], t[0]), reverse=True
):
is_confused = not args.csv and text and len(text_labels[text]) > 1
if is_confused:
text = rich.text.Text(text + "*", style="bold")
has_term_confusion = True
table.add_row(name, label, text, f"{count:,}")

# Add each annotator
add_annotator_to_table(rich.text.Text("All", style="italic"), all_annotator_frequencies)
for annotator in sorted(frequencies, key=str.casefold):
add_annotator_to_table(annotator, frequencies[annotator])

if args.csv:
cli_utils.print_table_as_csv(table)
else:
rich.get_console().print(table)
console_utils.print_ignored_charts(reader)
if has_term_confusion:
rich.get_console().print(
f" * This text has multiple associated labels.", style="italic"
)
2 changes: 1 addition & 1 deletion chart_review/commands/labels.py
@@ -33,7 +33,7 @@ def print_labels(args: argparse.Namespace) -> None:

# First add summary entries, for counts across the union of all annotators
for name in label_names:
count = str(len(any_annotator_note_sets.get(name, {})))
count = f"{len(any_annotator_note_sets.get(name, {})):,}"
label_table.add_row(rich.text.Text("Any", style="italic"), name, count)

# Now do each annotator as their own little boxed section
19 changes: 13 additions & 6 deletions chart_review/commands/mentions.py
@@ -5,7 +5,7 @@
import rich.table
import rich.text

from chart_review import cli_utils, console_utils, types
from chart_review import cli_utils, console_utils


def make_subparser(parser: argparse.ArgumentParser) -> None:
@@ -24,12 +24,19 @@ def print_mentions(args: argparse.Namespace) -> None:

for annotator in sorted(reader.annotations.original_text_mentions, key=str.casefold):
table.add_section()
mentions = reader.annotations.original_text_mentions[annotator]
for note_id, labeled_texts in mentions.items():
for label_text in labeled_texts:
for label in sorted(label_text.labels, key=str.casefold):
annotator_mentions = reader.annotations.original_text_mentions[annotator]
for note_id, labeled_texts in annotator_mentions.items():
# Gather all combos of text/label (i.e. all mentions) in this note
note_mentions = set()
for labeled_text in labeled_texts:
text = labeled_text.text and labeled_text.text.casefold()
for label in labeled_text.labels:
if label in reader.annotations.labels:
table.add_row(annotator, str(note_id), label_text.text, label)
note_mentions.add((text, label))

# Now add each mention to the table
for note_mention in sorted(note_mentions, key=lambda m: (m[0], m[1].casefold())):
table.add_row(annotator, str(note_id), note_mention[0], note_mention[1])

if args.csv:
cli_utils.print_table_as_csv(table)
11 changes: 0 additions & 11 deletions chart_review/common.py
@@ -104,17 +104,6 @@ def print_line(heading=None) -> None:
###############################################################################
# Helper Functions: enum type smoothing
###############################################################################
def guard_str(object) -> str:
if isinstance(object, Enum):
return str(object.name)
elif isinstance(object, EnumMeta):
return str(object.name)
elif isinstance(object, str):
return object
else:
raise Exception(f"expected str|Enum but got {type(object)}")


def guard_iter(object) -> Iterable:
if isinstance(object, Enum):
return guard_iter(object.value)