Skip to content

Commit

Permalink
Merge pull request #17 from smart-on-fhir/mikix/custom-config
Browse files Browse the repository at this point in the history
feat: support secondary config files with --config
  • Loading branch information
mikix authored Jan 16, 2024
2 parents f21086e + 64169e8 commit 6ece63c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 42 deletions.
15 changes: 11 additions & 4 deletions chart_review/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import sys

from chart_review import cohort
from chart_review import cohort, config
from chart_review.commands.accuracy import accuracy


Expand All @@ -18,8 +18,14 @@ def add_project_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--project-dir",
default=".",
help="Directory holding project files, "
"like config.yaml and labelstudio-export.json (default: current dir)",
metavar="DIR",
help=(
"Directory holding project files, "
"like labelstudio-export.json (default: current dir)"
),
)
parser.add_argument(
"--config", "-c", metavar="PATH", help="Config file (default: [project-dir]/config.yaml)"
)


Expand Down Expand Up @@ -49,7 +55,8 @@ def add_accuracy_subparser(subparsers) -> None:


def run_accuracy(args: argparse.Namespace) -> None:
reader = cohort.CohortReader(args.project_dir)
proj_config = config.ProjectConfig(args.project_dir, config_path=args.config)
reader = cohort.CohortReader(proj_config)
accuracy(reader, args.truth_annotator, args.annotator)


Expand Down
29 changes: 7 additions & 22 deletions chart_review/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,24 @@


class CohortReader:
def __init__(self, project_dir: str):
def __init__(self, proj_config: config.ProjectConfig):
"""
:param project_dir: str like /opt/labelstudio/study_name
:param proj_config: parsed project configuration
"""
self.project_dir = project_dir
self.config = config.ProjectConfig(project_dir)
self.labelstudio_json = self.path(
"labelstudio-export.json"
) # TODO: refactor labelstudio.py
self.config = proj_config
self.project_dir = self.config.project_dir
self.labelstudio_json = self.config.path("labelstudio-export.json")
self.annotator = self.config.annotators
self.note_range = self.config.note_ranges
self.class_labels = self.config.class_labels
self.annotations = None
self.annotations = simplify.simplify_full(self.labelstudio_json, self.annotator)

saved = common.read_json(self.labelstudio_json)
if isinstance(saved, list):
self.annotations = simplify.simplify_full(self.labelstudio_json, self.annotator)
else:
# TODO: int keys can't be saved in JSON; compatibility hack -- use LabelStudio.py instead
compat = dict()
compat["files"] = saved["files"]
compat["annotations"] = dict()
for k in saved["annotations"].keys():
compat["annotations"][int(k)] = saved["annotations"][k]
self.annotations = compat

# Load external annotations (i.e. from NLP tags or ICD10 codes)
for name, value in self.config.external_annotations.items():
self.annotations = external.merge_external(
self.annotations, saved, project_dir, name, value
self.annotations, saved, self.project_dir, name, value
)

# Detect note ranges if they were not defined in the project config
Expand All @@ -67,9 +55,6 @@ def __init__(self, project_dir: str):
continue
self.ignored_notes.add(ls_id)

def path(self, filename):
return os.path.join(self.project_dir, filename)

def calc_term_freq(self, annotator) -> dict:
"""
Calculate Term Frequency of highlighted mentions.
Expand Down
38 changes: 24 additions & 14 deletions chart_review/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import re
import sys
from typing import Iterable, Union
from typing import Iterable, Optional, Union

import yaml

Expand All @@ -13,22 +13,12 @@ class ProjectConfig:
_NUMBER_REGEX = re.compile(r"\d+")
_RANGE_REGEX = re.compile(r"\d+-\d+")

def __init__(self, project_dir: str):
def __init__(self, project_dir: str, config_path: Optional[str] = None):
"""
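:param project_dir: str like /opt/labelstudio/study_name
:param config_path: optional path to a config file; when None, config.json or config.yaml is loaded from project_dir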
"""
self._data = None

for filename in ("config.yaml", "config.json"):
try:
path = os.path.join(project_dir, filename)
with open(path, "r", encoding="utf8") as f:
self._data = yaml.safe_load(f)
except FileNotFoundError:
continue

if self._data is None:
raise FileNotFoundError(f"No config.yaml or config.json file found in {project_dir}")
self.project_dir = project_dir
self._data = self._load_config(config_path)

# ** Annotators **
# Internally, we're often dealing with numeric ID as the primary annotator identifier,
Expand Down Expand Up @@ -57,6 +47,26 @@ def __init__(self, project_dir: str):
value = {value}
self.implied_labels[key] = set(value)

def path(self, filename: str) -> str:
    """
    Build the location of a file that lives in the project directory.

    :param filename: name (or relative path) to join onto the project dir
    :return: `filename` joined onto `self.project_dir` with os.path.join
    """
    base_dir = self.project_dir
    return os.path.join(base_dir, filename)

@staticmethod
def _read_yaml(path) -> dict:
    """
    Parse the file at `path` with yaml.safe_load.

    The file is opened as UTF-8 text and closed again before returning.

    :param path: location of the file to read
    :return: whatever yaml.safe_load produced for the file's contents
    """
    with open(path, encoding="utf8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def _load_config(self, config_path: Optional[str]) -> dict:
    """
    Load the project configuration.

    :param config_path: explicit config file to read, or None to look for
      config.json and then config.yaml inside the project dir
    :return: the parsed contents of the chosen config file
    :raises FileNotFoundError: if the explicit file is missing, or if neither
      default file exists in the project dir
    """
    if config_path is not None:
        # Don't resolve config_path relative to the project dir, because
        # this will have come from the command line and will resolve relative to `pwd`.
        return self._read_yaml(config_path)

    # Support config.json in case folks prefer that (it is tried first).
    for filename in ("config.json", "config.yaml"):
        try:
            return self._read_yaml(self.path(filename))
        except FileNotFoundError:
            continue

    # Name both candidates, rather than surfacing only the last open() failure.
    raise FileNotFoundError(f"No config.yaml or config.json file found in {self.project_dir}")

def _parse_note_range(self, value: Union[str, int, list[Union[str, int]]]) -> Iterable[int]:
if isinstance(value, list):
return list(itertools.chain.from_iterable(self._parse_note_range(v) for v in value))
Expand Down
15 changes: 15 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,18 @@ def test_ignored_ids(self):
self.assertEqual(0, accuracy_json["FN"])
self.assertEqual(2, accuracy_json["TN"])
self.assertEqual(0, accuracy_json["FP"])

def test_custom_config(self):
    """Run `accuracy` with -c pointing at a config file outside the project dir."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Only the export is copied in; the config stays in the test data dir
        # and is handed over explicitly via -c.
        shutil.copy(f"{DATA_DIR}/cold/labelstudio-export.json", tmpdir)
        argv = [
            "accuracy",
            "--project-dir",
            tmpdir,
            "-c",
            f"{DATA_DIR}/cold/config.yaml",
            "jane",
            "john",
        ]
        cli.main_cli(argv)  # just confirm it doesn't error out
4 changes: 2 additions & 2 deletions tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
import unittest

from chart_review import cohort
from chart_review import cohort, config

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

Expand All @@ -20,7 +20,7 @@ def setUp(self):
def test_basic_read(self):
with tempfile.TemporaryDirectory() as tmpdir:
shutil.copytree(f"{DATA_DIR}/external", tmpdir, dirs_exist_ok=True)
reader = cohort.CohortReader(tmpdir)
reader = cohort.CohortReader(config.ProjectConfig(tmpdir))

self.assertEqual(
{
Expand Down

0 comments on commit 6ece63c

Please sign in to comment.