diff --git a/chart_review/cli.py b/chart_review/cli.py index 0b98816..ccfe8b0 100644 --- a/chart_review/cli.py +++ b/chart_review/cli.py @@ -3,7 +3,7 @@ import argparse import sys -from chart_review import cohort +from chart_review import cohort, config from chart_review.commands.accuracy import accuracy @@ -18,8 +18,14 @@ def add_project_args(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--project-dir", default=".", - help="Directory holding project files, " - "like config.yaml and labelstudio-export.json (default: current dir)", + metavar="DIR", + help=( + "Directory holding project files, " + "like labelstudio-export.json (default: current dir)" + ), + ) + parser.add_argument( + "--config", "-c", metavar="PATH", help="Config file (default: [project-dir]/config.yaml)" ) @@ -49,7 +55,8 @@ def add_accuracy_subparser(subparsers) -> None: def run_accuracy(args: argparse.Namespace) -> None: - reader = cohort.CohortReader(args.project_dir) + proj_config = config.ProjectConfig(args.project_dir, config_path=args.config) + reader = cohort.CohortReader(proj_config) accuracy(reader, args.truth_annotator, args.annotator) diff --git a/chart_review/cohort.py b/chart_review/cohort.py index cf58330..adfd637 100644 --- a/chart_review/cohort.py +++ b/chart_review/cohort.py @@ -11,36 +11,24 @@ class CohortReader: - def __init__(self, project_dir: str): + def __init__(self, proj_config: config.ProjectConfig): """ - :param project_dir: str like /opt/labelstudio/study_name + :param proj_config: parsed project configuration """ - self.project_dir = project_dir - self.config = config.ProjectConfig(project_dir) - self.labelstudio_json = self.path( - "labelstudio-export.json" - ) # TODO: refactor labelstudio.py + self.config = proj_config + self.project_dir = self.config.project_dir + self.labelstudio_json = self.config.path("labelstudio-export.json") self.annotator = self.config.annotators self.note_range = self.config.note_ranges self.class_labels = self.config.class_labels - self.annotations = None + self.annotations = simplify.simplify_full(self.labelstudio_json, self.annotator) saved = common.read_json(self.labelstudio_json) - if isinstance(saved, list): - self.annotations = simplify.simplify_full(self.labelstudio_json, self.annotator) - else: - # TODO: int keys cant be saved in JSON, compatability hack use instead LabelStudio.py - compat = dict() - compat["files"] = saved["files"] - compat["annotations"] = dict() - for k in saved["annotations"].keys(): - compat["annotations"][int(k)] = saved["annotations"][k] - self.annotations = compat # Load external annotations (i.e. from NLP tags or ICD10 codes) for name, value in self.config.external_annotations.items(): self.annotations = external.merge_external( - self.annotations, saved, project_dir, name, value + self.annotations, saved, self.project_dir, name, value ) # Detect note ranges if they were not defined in the project config @@ -67,9 +55,6 @@ def __init__(self, project_dir: str): continue self.ignored_notes.add(ls_id) - def path(self, filename): - return os.path.join(self.project_dir, filename) - def calc_term_freq(self, annotator) -> dict: """ Calculate Term Frequency of highlighted mentions. diff --git a/chart_review/config.py b/chart_review/config.py index 3a8cfcd..5eafc97 100644 --- a/chart_review/config.py +++ b/chart_review/config.py @@ -2,7 +2,7 @@ import os import re import sys -from typing import Iterable, Union +from typing import Iterable, Optional, Union import yaml @@ -13,22 +13,12 @@ class ProjectConfig: _NUMBER_REGEX = re.compile(r"\d+") _RANGE_REGEX = re.compile(r"\d+-\d+") - def __init__(self, project_dir: str): + def __init__(self, project_dir: str, config_path: Optional[str] = None): """ :param project_dir: str like /opt/labelstudio/study_name """ - self._data = None - - for filename in ("config.yaml", "config.json"): - try: - path = os.path.join(project_dir, filename) - with open(path, "r", encoding="utf8") as f: - self._data = yaml.safe_load(f) - except FileNotFoundError: - continue - - if self._data is None: - raise FileNotFoundError(f"No config.yaml or config.json file found in {project_dir}") + self.project_dir = project_dir + self._data = self._load_config(config_path) # ** Annotators ** # Internally, we're often dealing with numeric ID as the primary annotator identifier, @@ -57,6 +47,26 @@ def __init__(self, project_dir: str): value = {value} self.implied_labels[key] = set(value) + def path(self, filename: str) -> str: + return os.path.join(self.project_dir, filename) + + @staticmethod + def _read_yaml(path) -> dict: + with open(path, encoding="utf8") as f: + return yaml.safe_load(f) + + def _load_config(self, config_path: Optional[str]) -> dict: + if config_path is None: + # Support config.json in case folks prefer that + try: + return self._read_yaml(self.path("config.json")) + except FileNotFoundError: + return self._read_yaml(self.path("config.yaml")) + + # Don't resolve config_path relative to the project dir, because + # this will have come from the command line and will resolve relative to `pwd`. + return self._read_yaml(config_path) + def _parse_note_range(self, value: Union[str, int, list[Union[str, int]]]) -> Iterable[int]: if isinstance(value, list): return list(itertools.chain.from_iterable(self._parse_note_range(v) for v in value)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 862bb45..65f80b5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -94,3 +94,18 @@ def test_ignored_ids(self): self.assertEqual(0, accuracy_json["FN"]) self.assertEqual(2, accuracy_json["TN"]) self.assertEqual(0, accuracy_json["FP"]) + + def test_custom_config(self): + with tempfile.TemporaryDirectory() as tmpdir: + shutil.copy(f"{DATA_DIR}/cold/labelstudio-export.json", tmpdir) + cli.main_cli( + [ + "accuracy", + "--project-dir", + tmpdir, + "-c", + f"{DATA_DIR}/cold/config.yaml", + "jane", + "john", + ] + ) # just confirm it doesn't error out diff --git a/tests/test_external.py b/tests/test_external.py index 83492d3..7e5c08f 100644 --- a/tests/test_external.py +++ b/tests/test_external.py @@ -5,7 +5,7 @@ import tempfile import unittest -from chart_review import cohort +from chart_review import cohort, config DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @@ -20,7 +20,7 @@ def setUp(self): def test_basic_read(self): with tempfile.TemporaryDirectory() as tmpdir: shutil.copytree(f"{DATA_DIR}/external", tmpdir, dirs_exist_ok=True) - reader = cohort.CohortReader(tmpdir) + reader = cohort.CohortReader(config.ProjectConfig(tmpdir)) self.assertEqual( {