From 8978f1c148a4348e1203720f4b5dc7c43c3ffcda Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Fri, 9 Aug 2024 12:34:32 -0400 Subject: [PATCH] Add two new covid_symptom GPT tasks - covid_symptom__nlp_results_gpt35 - covid_symptom__nlp_results_gpt4 Both call out to Azure ChatGPT and detects COVID symptoms in clinical notes, with the same cohort selection as the other covid_symptom tasks (ED notes). You will need to set the following two environment variables to use them: - AZURE_OPENAI_API_KEY - AZURE_OPENAI_ENDPOINT These tasks will also need to be added to your AWS config (see the new template) and the Crawler run to detect their schemas. The Library covid study has not yet been updated to consume the tables yet, but will be. --- .pylintrc | 437 ------------------ compose.yaml | 8 +- .../etl/studies/covid_symptom/__init__.py | 7 +- .../etl/studies/covid_symptom/covid_ctakes.py | 15 +- .../etl/studies/covid_symptom/covid_tasks.py | 225 ++++++++- cumulus_etl/etl/studies/hftest/hf_tasks.py | 4 +- cumulus_etl/etl/tasks/task_factory.py | 2 + cumulus_etl/nlp/__init__.py | 2 +- cumulus_etl/nlp/utils.py | 37 +- docs/setup/cumulus-aws-template.yaml | 2 + docs/studies/covid-symptom.md | 53 ++- pyproject.toml | 1 + tests/covid_symptom/test_covid_gpt.py | 253 ++++++++++ tests/covid_symptom/test_covid_results.py | 82 +++- tests/etl/base.py | 8 +- tests/hftest/test_hftask.py | 4 +- tests/i2b2_mock_data.py | 6 +- 17 files changed, 658 insertions(+), 488 deletions(-) delete mode 100644 .pylintrc create mode 100644 tests/covid_symptom/test_covid_gpt.py diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 98c6c6bf..00000000 --- a/.pylintrc +++ /dev/null @@ -1,437 +0,0 @@ -# Below is a copy of Google's pylintrc, with the following modifications: -# - indent-string changed to 4 spaces (the shipped version of this config -# file has 2 because that's what Google uses internally, despite the -# public description of their style guide using 4) -# - max-line-length changed to 120 because 80 was driving us crazy -# - wrong-import-order re-enabled, because MT likes order among chaos -# -# BELOW THIS LINE IS A COPY OF https://google.github.io/styleguide/pylintrc - -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MASTER] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=third_party - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - duplicate-code, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat, - import-error, - import-self, - import-star-module-level, - inconsistent-return-statements, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-else-break, - no-else-continue, - no-else-raise, - no-else-return, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - too-few-public-methods, - too-many-ancestors, - too-many-arguments, - too-many-boolean-expressions, - too-many-branches, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-public-methods, - too-many-return-statements, - too-many-statements, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - useless-else-on-loop, - useless-object-inheritance, - useless-suppression, - using-cmp-argument, - xrange-builtin, - zip-builtin-not-iterating, - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=120 - -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub, - TERMIOS, - Bastion, - rexec, - sets - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=builtins.StandardError, - builtins.Exception, - builtins.BaseException diff --git a/compose.yaml b/compose.yaml index 65d7e461..55c66892 100644 --- a/compose.yaml +++ b/compose.yaml @@ -26,11 +26,15 @@ services: context: . target: cumulus-etl environment: + # Environment variobles to pull in from the host - AWS_ACCESS_KEY_ID + - AWS_DEFAULT_PROFILE + - AWS_PROFILE - AWS_SECRET_ACCESS_KEY - AWS_SESSION_TOKEN - - AWS_PROFILE - - AWS_DEFAULT_PROFILE + - AZURE_OPENAI_API_KEY + - AZURE_OPENAI_ENDPOINT + # Internal environment variobles - CUMULUS_HUGGING_FACE_URL=http://llama2:8086/ - URL_CTAKES_REST=http://ctakes-covid:8080/ctakes-web-rest/service/analyze - URL_CNLP_NEGATION=http://cnlpt-negation:8000/negation/process diff --git a/cumulus_etl/etl/studies/covid_symptom/__init__.py b/cumulus_etl/etl/studies/covid_symptom/__init__.py index ec307d2c..bc342a64 100644 --- a/cumulus_etl/etl/studies/covid_symptom/__init__.py +++ b/cumulus_etl/etl/studies/covid_symptom/__init__.py @@ -1,3 +1,8 @@ """The covid_symptom study""" -from .covid_tasks import CovidSymptomNlpResultsTask, CovidSymptomNlpResultsTermExistsTask +from .covid_tasks import ( + CovidSymptomNlpResultsGpt4Task, + CovidSymptomNlpResultsGpt35Task, + CovidSymptomNlpResultsTask, + CovidSymptomNlpResultsTermExistsTask, +) diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py index da5e110e..abd68a24 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py @@ -6,7 +6,7 @@ import httpx from ctakesclient.transformer import TransformerModel -from cumulus_etl import common, fhir, nlp, store +from cumulus_etl import common, nlp, store async def covid_symptoms_extract( @@ -31,14 +31,11 @@ async def covid_symptoms_extract( :param cnlp_http_client: HTTPX client to use for the cNLP transformer server :return: list of NLP results encoded as FHIR observations """ - docref_id = docref["id"] - _, subject_id = fhir.unref_resource(docref["subject"]) - - encounters = docref.get("context", {}).get("encounter", []) - if not encounters: - logging.warning("No encounters for docref %s", docref_id) + try: + docref_id, encounter_id, subject_id = nlp.get_docref_info(docref) + except KeyError as exc: + logging.warning(exc) return None - _, encounter_id = fhir.unref_resource(encounters[0]) # cTAKES cache namespace history (and thus, cache invalidation history): # v1: original cTAKES processing @@ -54,7 +51,7 @@ async def covid_symptoms_extract( case TransformerModel.TERM_EXISTS: cnlp_namespace = f"{ctakes_namespace}-cnlp_term_exists_v1" case _: - logging.warning("Unknown polarity method: %s", polarity_model.value) + logging.warning("Unknown polarity method: %s", polarity_model) return None timestamp = common.datetime_now().isoformat() diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py index 166b3de6..bbe2af54 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py @@ -1,14 +1,19 @@ """Define tasks for the covid_symptom study""" import itertools +import json +import logging +import os from typing import ClassVar import ctakesclient +import openai import pyarrow import rich.progress from ctakesclient.transformer import TransformerModel +from openai.types import chat -from cumulus_etl import nlp, store +from cumulus_etl import common, nlp, store from cumulus_etl.etl import tasks from cumulus_etl.etl.studies.covid_symptom import covid_ctakes @@ -74,9 +79,11 @@ def is_ed_docref(docref): return any(is_ed_coding(x) for x in codings) -class BaseCovidSymptomNlpResultsTask(tasks.BaseNlpTask): +class BaseCovidCtakesTask(tasks.BaseNlpTask): """Covid Symptom study task, to generate symptom lists from ED notes using cTAKES + a polarity check""" + tags: ClassVar = {"covid_symptom", "gpu"} + # Subclasses: set name, tags, and polarity_model yourself polarity_model = None @@ -117,7 +124,7 @@ async def prepare_task(self) -> bool: bsv_path = ctakesclient.filesystem.covid_symptoms_path() success = nlp.restart_ctakes_with_bsv(self.task_config.ctakes_overrides, bsv_path) if not success: - print(f"Skipping {self.name}.") + print(" Skipping.") self.summaries[0].had_errors = True return success @@ -197,11 +204,10 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche ) -class CovidSymptomNlpResultsTask(BaseCovidSymptomNlpResultsTask): +class CovidSymptomNlpResultsTask(BaseCovidCtakesTask): """Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt negation""" name: ClassVar = "covid_symptom__nlp_results" - tags: ClassVar = {"covid_symptom", "gpu"} polarity_model: ClassVar = TransformerModel.NEGATION @classmethod @@ -210,17 +216,216 @@ async def init_check(cls) -> None: nlp.check_negation_cnlpt() -class CovidSymptomNlpResultsTermExistsTask(BaseCovidSymptomNlpResultsTask): +class CovidSymptomNlpResultsTermExistsTask(BaseCovidCtakesTask): """Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt termexists""" name: ClassVar = "covid_symptom__nlp_results_term_exists" polarity_model: ClassVar = TransformerModel.TERM_EXISTS - # Explicitly don't use any tags because this is really a "hidden" task that is mostly for comparing - # polarity model performance more than running a study. So we don't want it to be accidentally run. - tags: ClassVar = {} - @classmethod async def init_check(cls) -> None: nlp.check_ctakes() nlp.check_term_exists_cnlpt() + + +class BaseCovidGptTask(tasks.BaseNlpTask): + """Covid Symptom study task, using GPT""" + + tags: ClassVar = {"covid_symptom", "cpu"} + outputs: ClassVar = [tasks.OutputTable(resource_type=None)] + + # Overridden by child classes + model_id: ClassVar = None + + async def prepare_task(self) -> bool: + api_key = os.environ.get("AZURE_OPENAI_API_KEY") + endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + if not api_key or not endpoint: + if not api_key: + print(" The AZURE_OPENAI_API_KEY environment variable is not set.") + if not endpoint: + print(" The AZURE_OPENAI_ENDPOINT environment variable is not set.") + print(" Skipping.") + self.summaries[0].had_errors = True + return False + return True + + async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator: + """Passes clinical notes through NLP and returns any symptoms found""" + async for orig_docref, docref, clinical_note in self.read_notes( + progress=progress, doc_check=is_ed_docref + ): + try: + docref_id, encounter_id, subject_id = nlp.get_docref_info(docref) + except KeyError as exc: + logging.warning(exc) + self.add_error(orig_docref) + continue + + client = openai.AsyncAzureOpenAI(api_version="2024-06-01") + try: + response = await nlp.cache_wrapper( + self.task_config.dir_phi, + f"{self.name}_v{self.task_version}", + clinical_note, + lambda x: chat.ChatCompletion.model_validate_json(x), + lambda x: x.model_dump_json( + indent=None, round_trip=True, exclude_unset=True, by_alias=True + ), + client.chat.completions.create, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": self.get_prompt(clinical_note)}, + ], + model=self.model_id, + seed=12345, # arbitrary, only specified to improve reproducibility + response_format={"type": "json_object"}, + ) + except openai.APIError as exc: + logging.warning(f"Could not connect to GPT for DocRef {docref['id']}: {exc}") + self.add_error(orig_docref) + continue + + if response.choices[0].finish_reason != "stop": + logging.warning( + f"GPT response didn't complete for DocRef {docref['id']}: " + f"{response.choices[0].finish_reason}" + ) + self.add_error(orig_docref) + continue + + try: + symptoms = json.loads(response.choices[0].message.content) + except json.JSONDecodeError as exc: + logging.warning(f"Could not parse GPT results for DocRef {docref['id']}: {exc}") + self.add_error(orig_docref) + continue + + yield { + "id": docref_id, # keep one results entry per docref + "docref_id": docref_id, + "encounter_id": encounter_id, + "subject_id": subject_id, + "generated_on": common.datetime_now().isoformat(), + "task_version": self.task_version, + "system_fingerprint": response.system_fingerprint, + "symptoms": { + "congestion_or_runny_nose": bool(symptoms.get("Congestion or runny nose")), + "cough": bool(symptoms.get("Cough")), + "diarrhea": bool(symptoms.get("Diarrhea")), + "dyspnea": bool(symptoms.get("Dyspnea")), + "fatigue": bool(symptoms.get("Fatigue")), + "fever_or_chills": bool(symptoms.get("Fever or chills")), + "headache": bool(symptoms.get("Headache")), + "loss_of_taste_or_smell": bool(symptoms.get("Loss of taste or smell")), + "muscle_or_body_aches": bool(symptoms.get("Muscle or body aches")), + "nausea_or_vomiting": bool(symptoms.get("Nausea or vomiting")), + "sore_throat": bool(symptoms.get("Sore throat")), + }, + } + + @staticmethod + def get_prompt(clinical_note: str) -> str: + instructions = ( + "You are a helpful assistant identifying symptoms from emergency " + "department notes that could relate to infectious respiratory diseases.\n" + "Output positively documented symptoms, looking out specifically for the " + "following: Congestion or runny nose, Cough, Diarrhea, Dyspnea, Fatigue, " + "Fever or chills, Headache, Loss of taste or smell, Muscle or body aches, " + "Nausea or vomiting, Sore throat.\nSymptoms only need to be positively " + "mentioned once to be included.\nDo not mention symptoms that are not " + "present in the note.\n\nFollow these rules:\nRule (1): Symptoms must be " + "positively documented and relevant to the presenting illness or reason " + "for visit.\nRule (2): Medical section headings must be specific to the " + "present emergency department encounter.\nInclude positive symptoms from " + 'these medical section headings: "Chief Complaint", "History of ' + 'Present Illness", "HPI", "Review of Systems", "Physical Exam", ' + '"Vital Signs", "Assessment and Plan", "Medical Decision Making".\n' + "Rule (3): Positive symptom mentions must be a definite medical synonym.\n" + 'Include positive mentions of: "anosmia", "loss of taste", "loss of ' + 'smell", "rhinorrhea", "congestion", "discharge", "nose is ' + 'dripping", "runny nose", "stuffy nose", "cough", "tussive or ' + 'post-tussive", "cough is unproductive", "productive cough", "dry ' + 'cough", "wet cough", "producing sputum", "diarrhea", "watery ' + 'stool", "fatigue", "tired", "exhausted", "weary", "malaise", ' + '"feeling generally unwell", "fever", "pyrexia", "chills", ' + '"temperature greater than or equal 100.4 Fahrenheit or 38 celsius", ' + '"Temperature >= 100.4F", "Temperature >= 38C", "headache", "HA", ' + '"migraine", "cephalgia", "head pain", "muscle or body aches", ' + '"muscle aches", "generalized aches and pains", "body aches", ' + '"myalgias", "myoneuralgia", "soreness", "generalized aches and ' + 'pains", "nausea or vomiting", "Nausea", "vomiting", "emesis", ' + '"throwing up", "queasy", "regurgitated", "shortness of breath", ' + '"difficulty breathing", "SOB", "Dyspnea", "breathing is short", ' + '"increased breathing", "labored breathing", "distressed ' + 'breathing", "sore throat", "throat pain", "pharyngeal pain", ' + '"pharyngitis", "odynophagia".\nYour reply must be parsable as JSON.\n' + 'Format your response using only the following JSON schema: {"Congestion ' + 'or runny nose": boolean, "Cough": boolean, "Diarrhea": boolean, ' + '"Dyspnea": boolean, "Fatigue": boolean, "Fever or chills": ' + 'boolean, "Headache": boolean, "Loss of taste or smell": boolean, ' + '"Muscle or body aches": boolean, "Nausea or vomiting": boolean, ' + '"Sore throat": boolean}. Each JSON key should correspond to a symptom, ' + "and each value should be true if that symptom is indicated in the " + "clinical note; false otherwise.\nNever explain yourself, and only reply " + "with JSON." + ) + return f"### Instructions ###\n{instructions}\n### Text ###\n{clinical_note}" + + @classmethod + def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema: + return pyarrow.schema( + [ + pyarrow.field("id", pyarrow.string()), + pyarrow.field("docref_id", pyarrow.string()), + pyarrow.field("encounter_id", pyarrow.string()), + pyarrow.field("subject_id", pyarrow.string()), + pyarrow.field("generated_on", pyarrow.string()), + pyarrow.field("task_version", pyarrow.int32()), + pyarrow.field("system_fingerprint", pyarrow.string()), + pyarrow.field( + "symptoms", + pyarrow.struct( + [ + pyarrow.field("congestion_or_runny_nose", pyarrow.bool_()), + pyarrow.field("cough", pyarrow.bool_()), + pyarrow.field("diarrhea", pyarrow.bool_()), + pyarrow.field("dyspnea", pyarrow.bool_()), + pyarrow.field("fatigue", pyarrow.bool_()), + pyarrow.field("fever_or_chills", pyarrow.bool_()), + pyarrow.field("headache", pyarrow.bool_()), + pyarrow.field("loss_of_taste_or_smell", pyarrow.bool_()), + pyarrow.field("muscle_or_body_aches", pyarrow.bool_()), + pyarrow.field("nausea_or_vomiting", pyarrow.bool_()), + pyarrow.field("sore_throat", pyarrow.bool_()), + ], + ), + ), + ] + ) + + +class CovidSymptomNlpResultsGpt35Task(BaseCovidGptTask): + """Covid Symptom study task, using GPT3.5""" + + name: ClassVar = "covid_symptom__nlp_results_gpt35" + model_id: ClassVar = "gpt-35-turbo-0125" + + task_version: ClassVar = 1 + # Task Version History: + # ** 1 (2024-08): Initial version ** + # model: gpt-35-turbo-0125 + # seed: 12345 + + +class CovidSymptomNlpResultsGpt4Task(BaseCovidGptTask): + """Covid Symptom study task, using GPT4""" + + name: ClassVar = "covid_symptom__nlp_results_gpt4" + model_id: ClassVar = "gpt-4" + + task_version: ClassVar = 1 + # Task Version History: + # ** 1 (2024-08): Initial version ** + # model: gpt-4 + # seed: 12345 diff --git a/cumulus_etl/etl/studies/hftest/hf_tasks.py b/cumulus_etl/etl/studies/hftest/hf_tasks.py index 2b67d49a..9a8b685a 100644 --- a/cumulus_etl/etl/studies/hftest/hf_tasks.py +++ b/cumulus_etl/etl/studies/hftest/hf_tasks.py @@ -63,7 +63,9 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task summary = await nlp.cache_wrapper( self.task_config.dir_phi, f"{self.name}_v{self.task_version}", - user_prompt, + clinical_note, + lambda x: x, # from file: just store the string + lambda x: x, # to file: just read it back nlp.llama2_prompt, system_prompt, user_prompt, diff --git a/cumulus_etl/etl/tasks/task_factory.py b/cumulus_etl/etl/tasks/task_factory.py index cd7ec361..36857475 100644 --- a/cumulus_etl/etl/tasks/task_factory.py +++ b/cumulus_etl/etl/tasks/task_factory.py @@ -22,6 +22,8 @@ def get_all_tasks() -> list[type[AnyTask]]: # Note: tasks will be run in the order listed here. return [ *get_default_tasks(), + covid_symptom.CovidSymptomNlpResultsGpt35Task, + covid_symptom.CovidSymptomNlpResultsGpt4Task, covid_symptom.CovidSymptomNlpResultsTask, covid_symptom.CovidSymptomNlpResultsTermExistsTask, hftest.HuggingFaceTestTask, diff --git a/cumulus_etl/nlp/__init__.py b/cumulus_etl/nlp/__init__.py index 65d80676..c73d9bdf 100644 --- a/cumulus_etl/nlp/__init__.py +++ b/cumulus_etl/nlp/__init__.py @@ -2,7 +2,7 @@ from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity from .huggingface import hf_info, hf_prompt, llama2_prompt -from .utils import cache_wrapper, is_docref_valid +from .utils import cache_wrapper, get_docref_info, is_docref_valid from .watcher import ( check_ctakes, check_negation_cnlpt, diff --git a/cumulus_etl/nlp/utils.py b/cumulus_etl/nlp/utils.py index ef1b10e0..deb4573c 100644 --- a/cumulus_etl/nlp/utils.py +++ b/cumulus_etl/nlp/utils.py @@ -3,8 +3,11 @@ import hashlib import os from collections.abc import Callable +from typing import TypeVar -from cumulus_etl import common, store +from cumulus_etl import common, fhir, store + +Obj = TypeVar("Obj") def is_docref_valid(docref: dict) -> bool: @@ -15,22 +18,44 @@ def is_docref_valid(docref: dict) -> bool: return good_status and good_doc_status +def get_docref_info(docref: dict) -> (str, str, str): + """ + Returns docref_id, encounter_id, subject_id for the given DocRef. + + Raises KeyError if any of them aren't present. + """ + docref_id = docref["id"] + encounters = docref.get("context", {}).get("encounter", []) + if not encounters: + raise KeyError(f"No encounters for docref {docref_id}") + _, encounter_id = fhir.unref_resource(encounters[0]) + _, subject_id = fhir.unref_resource(docref["subject"]) + return docref_id, encounter_id, subject_id + + async def cache_wrapper( - cache_dir: str, namespace: str, content: str, method: Callable, *args, **kwargs -) -> str: + cache_dir: str, + namespace: str, + content: str, + from_file: Callable[[str], Obj], + to_file: Callable[[Obj], str], + method: Callable, + *args, + **kwargs, +) -> Obj: """Looks up an NLP result in the cache first, falling back to actually calling NLP.""" # First, what is our target path for a possible cache file cache_dir = store.Root(cache_dir, create=True) checksum = hashlib.sha256(content.encode("utf8")).hexdigest() - path = f"ctakes-cache/{namespace}/{checksum[0:4]}/sha256-{checksum}.json" # "ctakes-cache" is historical + path = f"nlp-cache/{namespace}/{checksum[0:4]}/sha256-{checksum}.cache" cache_filename = cache_dir.joinpath(path) # And try to read that file, falling back to calling the given method if a cache is not available try: - result = common.read_text(cache_filename) + result = from_file(common.read_text(cache_filename)) except (FileNotFoundError, PermissionError): result = await method(*args, **kwargs) cache_dir.makedirs(os.path.dirname(cache_filename)) - common.write_text(cache_filename, result) + common.write_text(cache_filename, to_file(result)) return result diff --git a/docs/setup/cumulus-aws-template.yaml b/docs/setup/cumulus-aws-template.yaml index 401cbfc0..65fa0d8a 100644 --- a/docs/setup/cumulus-aws-template.yaml +++ b/docs/setup/cumulus-aws-template.yaml @@ -197,6 +197,8 @@ Resources: - !Sub "s3://${S3Bucket}/${EtlSubdir}/procedure" - !Sub "s3://${S3Bucket}/${EtlSubdir}/servicerequest" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_gpt35" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_gpt4" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_term_exists" - !Sub "s3://${S3Bucket}/${EtlSubdir}/etl__completion" - !Sub "s3://${S3Bucket}/${EtlSubdir}/etl__completion_encounters" diff --git a/docs/studies/covid-symptom.md b/docs/studies/covid-symptom.md index 0404339c..49d2d66d 100644 --- a/docs/studies/covid-symptom.md +++ b/docs/studies/covid-symptom.md @@ -10,13 +10,21 @@ nav_order: 1 # The Covid Symptom Study This study uses NLP to identify symptoms of COVID-19 in clinical notes. -Specifically, it uses [cTAKES](https://ctakes.apache.org/) and -[cNLP transformers](https://github.com/Machine-Learning-for-Medical-Language/cnlp_transformers) -to identify clinical terms. -## Preparation +It allows for running different NLP strategies and comparing them: +1. [cTAKES](https://ctakes.apache.org/) and the `negation` +[cNLP transformer](https://github.com/Machine-Learning-for-Medical-Language/cnlp_transformers) +2. cTAKES and the `termexists` cNLP transformer +3. [ChatGPT](https://openai.com/chatgpt/) 3.5 +4. ChatGPT 4 -Because it uses external services like cTAKES, you will want to make sure those are ready. +Each can be run separately, and may require different preparation. +Read more below about the main approaches (cTAKES and ChatGPT). + +## cTAKES Preparation + +Because cTAKES and cNLP transformers are both services separate from the ETL, +you will want to make sure they are ready. From your git clone of the `cumulus-etl` repo, you can run the following to run those services: ```shell export UMLS_API_KEY=your-umls-api-key # don't forget to set this - cTAKES needs it @@ -24,7 +32,7 @@ docker compose --profile covid-symptom-gpu up -d ``` You'll notice the `-gpu` suffix there. -Running NLP is much, much faster with access to a GPU, +Running the transformers is much, much faster with access to a GPU, so we strongly recommend you run this on GPU-enabled hardware. And since we _are_ running the GPU profile, when you do run the ETL, @@ -36,13 +44,40 @@ docker compose run cumulus-etl-gpu … But if you can't use a GPU or you just want to test things out, you can use `--profile covid-symptom` above and the normal `cumulus-etl` run line to use the CPU. -## Task +## ChatGPT Preparation + +1. Make sure you have an Azure ChatGPT account set up. +2. Set the following environment variables: + - `AZURE_OPENAI_API_KEY` + - `AZURE_OPENAI_ENDPOINT` -There is one main task, run with `--task covid_symptom__nlp_results`. +## Running the Tasks -This will need access to clinical notes, +To run any of these individual tasks, use the following names: + +- cTAKES + negation: `covid_symptom__nlp_results` +- cTAKES + termexists: `covid_symptom__nlp_results_term_exists` +- ChatGPT 3.5: `covid_symptom__nlp_results_gpt35` +- ChatGPT 4: `covid_symptom__nlp_results_gpt4` + +For example, your Cumulus ETL command might look like: +```sh +cumulus-etl … --task=covid_symptom__nlp_results +``` + +### Clinical Notes + +All these tasks will need access to clinical notes, which are pulled fresh from your EHR (since the ETL doesn't store clinical notes). This means you will likely have to provide some other FHIR authentication arguments like `--smart-client-id` and `--fhir-url`. See `--help` for more authentication options. + +## Evaluating the Results + +See the [Cumulus Library Covid study repository](https://github.com/smart-on-fhir/cumulus-library-covid) +for more information about processing the raw NLP results that the ETL generates. + +Those instructions will help you set up Label Studio so that you can compare the +different NLP strategies against human reviewers. diff --git a/pyproject.toml b/pyproject.toml index 3bb9d65e..5ed98e82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "inscriptis < 3", "jwcrypto < 2", "label-studio-sdk < 2", + "openai < 2", "oracledb < 3", "philter-lite < 1", "pyarrow < 18", diff --git a/tests/covid_symptom/test_covid_gpt.py b/tests/covid_symptom/test_covid_gpt.py new file mode 100644 index 00000000..ed58531e --- /dev/null +++ b/tests/covid_symptom/test_covid_gpt.py @@ -0,0 +1,253 @@ +"""Tests for GPT covid symptom tasks""" + +import filecmp +import hashlib +import json +import os +from unittest import mock + +import ddt +import openai +from openai.types import chat + +from cumulus_etl.etl.studies import covid_symptom +from tests import i2b2_mock_data +from tests.etl import TaskTestCase + + +@ddt.ddt +@mock.patch.dict( + os.environ, {"AZURE_OPENAI_API_KEY": "test-key", "AZURE_OPENAI_ENDPOINT": "test-endpoint"} +) +class TestCovidSymptomGptResultsTask(TaskTestCase): + """Test case for CovidSymptomNlpResultsGpt*Task""" + + def setUp(self): + super().setUp() + self.mock_client_factory = self.patch("openai.AsyncAzureOpenAI") + self.mock_client = mock.AsyncMock() + self.mock_client_factory.return_value = self.mock_client + self.mock_create = self.mock_client.chat.completions.create + self.responses = [] + + def prep_docs(self, docref: dict | None = None): + """Create two docs for input""" + docref = docref or i2b2_mock_data.documentreference("foo") + self.make_json("DocumentReference", "1", **docref) + self.make_json("DocumentReference", "2", **i2b2_mock_data.documentreference("bar")) + + def mock_response( + self, found: bool = False, finish_reason: str = "stop", content: str | None = None + ) -> None: + symptoms = { + "Congestion or runny nose": found, + "Cough": found, + "Diarrhea": found, + "Dyspnea": found, + "Fatigue": found, + "Fever or chills": found, + "Headache": found, + "Loss of taste or smell": found, + "Muscle or body aches": found, + "Nausea or vomiting": found, + "Sore throat": found, + } + if content is None: + content = json.dumps(symptoms) + + self.responses.append( + chat.ChatCompletion( + id="test-id", + choices=[ + { + "finish_reason": finish_reason, + "index": 0, + "message": {"content": content, "role": "assistant"}, + } + ], + created=1723143708, + model="test-model", + object="chat.completion", + system_fingerprint="test-fp", + ), + ) + self.mock_create.side_effect = self.responses + + async def assert_failed_doc(self, msg: str): + task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber) + with self.assertLogs(level="WARN") as cm: + await task.run() + + # Confirm we printed a warning + self.assertEqual(len(cm.output), 1) + self.assertRegex(cm.output[0], msg) + + # Confirm we flagged and recorded the error + self.assertTrue(task.summaries[0].had_errors) + self.assertTrue( + filecmp.cmp( + f"{self.input_dir}/1.ndjson", f"{self.errors_dir}/{task.name}/nlp-errors.ndjson" + ) + ) + + # Confirm that we did write the second docref out - that we continued instead of exiting. + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual(len(batch.rows), 1) + self.assertEqual(batch.rows[0]["id"], self.codebook.db.resource_hash("2")) + + @ddt.data( + # env vars to set, success + (["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT"], True), + (["AZURE_OPENAI_API_KEY"], False), + (["AZURE_OPENAI_ENDPOINT"], False), + ) + @ddt.unpack + async def test_requires_env(self, names, success): + task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber) + env = {name: "content" for name in names} + with mock.patch.dict(os.environ, env, clear=True): + self.assertEqual(await task.prepare_task(), success) + self.assertEqual(task.summaries[0].had_errors, not success) + + async def test_gpt4_changes(self): + """ + Verify anything that is GPT4 specific. + + The rest of the tests work with gpt 3.5. + """ + self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference()) + self.mock_response() + + task = covid_symptom.CovidSymptomNlpResultsGpt4Task(self.job_config, self.scrubber) + await task.run() + + # Only the model (and sometimes the task version) are unique to the gpt4 task + self.assertEqual(self.mock_create.call_args_list[0][1]["model"], "gpt-4") + batch = self.format.write_records.call_args[0][0] + self.assertEqual(batch.rows[0]["task_version"], 1) + + async def test_happy_path(self): + self.prep_docs() + self.mock_response(found=True) + self.mock_response(found=False) + + task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber) + await task.run() + + # Confirm we send the correct params to the server + prompt = covid_symptom.CovidSymptomNlpResultsGpt35Task.get_prompt("foo") + self.assertEqual( + hashlib.md5(prompt.encode("utf8")).hexdigest(), + "8c5025f98a6cfc94dfedbf10a82396ae", # catch unexpected prompt changes + ) + self.assertEqual(self.mock_create.call_count, 2) + self.assertEqual( + { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + "model": "gpt-35-turbo-0125", + "seed": 12345, + "response_format": {"type": "json_object"}, + }, + self.mock_create.call_args_list[0][1], + ) + + # Confirm we formatted the output correctly + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual( + [ + { + "id": self.codebook.db.resource_hash("1"), + "docref_id": self.codebook.db.resource_hash("1"), + "encounter_id": self.codebook.db.resource_hash("67890"), + "subject_id": self.codebook.db.resource_hash("12345"), + "generated_on": "2021-09-14T21:23:45+00:00", + "task_version": 1, + "system_fingerprint": "test-fp", + "symptoms": { + "congestion_or_runny_nose": True, + "cough": True, + "diarrhea": True, + "dyspnea": True, + "fatigue": True, + "fever_or_chills": True, + "headache": True, + "loss_of_taste_or_smell": True, + "muscle_or_body_aches": True, + "nausea_or_vomiting": True, + "sore_throat": True, + }, + }, + { + "id": self.codebook.db.resource_hash("2"), + "docref_id": self.codebook.db.resource_hash("2"), + "encounter_id": self.codebook.db.resource_hash("67890"), + "subject_id": self.codebook.db.resource_hash("12345"), + "generated_on": "2021-09-14T21:23:45+00:00", + "task_version": 1, + "system_fingerprint": "test-fp", + "symptoms": { + "congestion_or_runny_nose": False, + "cough": False, + "diarrhea": False, + "dyspnea": False, + "fatigue": False, + "fever_or_chills": False, + "headache": False, + "loss_of_taste_or_smell": False, + "muscle_or_body_aches": False, + "nausea_or_vomiting": False, + "sore_throat": False, + }, + }, + ], + batch.rows, + ) + self.assertEqual(batch.groups, set()) + + async def test_caching(self): + # Provide the fist docref using the same content as the second. + # So the second will be cached. + self.prep_docs(i2b2_mock_data.documentreference("bar")) + self.mock_response(found=True) + + task = covid_symptom.CovidSymptomNlpResultsGpt35Task(self.job_config, self.scrubber) + await task.run() + + # Confirm we only asked the server once + self.assertEqual(self.mock_create.call_count, 1) + + # Confirm we round-tripped the data correctly + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual(len(batch.rows), 2) + self.assertEqual(batch.rows[0]["symptoms"], batch.rows[1]["symptoms"]) + + async def test_no_encounter_error(self): + docref = i2b2_mock_data.documentreference("foo") + del docref["context"] + self.prep_docs(docref) + self.mock_response() + await self.assert_failed_doc("No encounters for docref") + + async def test_network_error(self): + self.prep_docs() + self.responses.append(openai.APIError("oops", mock.MagicMock(), body=None)) + self.mock_response() + await self.assert_failed_doc("Could not connect to GPT for DocRef .*: oops") + + async def test_incomplete_response_error(self): + self.prep_docs() + self.mock_response(finish_reason="length") + self.mock_response() + await self.assert_failed_doc("GPT response didn't complete for DocRef .*: length") + + async def test_bad_json_error(self): + self.prep_docs() + self.mock_response(content='{"hello"') + self.mock_response() + await self.assert_failed_doc("Could not parse GPT results for DocRef .*: Expecting ':'") diff --git a/tests/covid_symptom/test_covid_results.py b/tests/covid_symptom/test_covid_results.py index c2bfe8cb..3d81712b 100644 --- a/tests/covid_symptom/test_covid_results.py +++ b/tests/covid_symptom/test_covid_results.py @@ -1,6 +1,7 @@ """Tests for etl/studies/covid_symptom/covid_tasks.py""" import os +from unittest import mock import cumulus_fhir_support import ddt @@ -20,6 +21,82 @@ def setUp(self): super().setUp() self.job_config.ctakes_overrides = self.ctakes_overrides.name + async def test_prepare_failure(self): + """Verify that if ctakes can't be restarted, we skip""" + task = covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber) + with mock.patch("cumulus_etl.nlp.restart_ctakes_with_bsv", return_value=False): + self.assertFalse(await task.prepare_task()) + self.assertEqual(task.summaries[0].had_errors, True) + + @mock.patch("cumulus_etl.nlp.check_ctakes") + @mock.patch("cumulus_etl.nlp.check_negation_cnlpt") + async def test_negation_init_check(self, mock_cnlpt, mock_ctakes): + await covid_symptom.CovidSymptomNlpResultsTask.init_check() + self.assertEqual(mock_ctakes.call_count, 1) + self.assertEqual(mock_cnlpt.call_count, 1) + + @mock.patch("cumulus_etl.nlp.check_ctakes") + @mock.patch("cumulus_etl.nlp.check_term_exists_cnlpt") + async def test_term_exists_init_check(self, mock_cnlpt, mock_ctakes): + await covid_symptom.CovidSymptomNlpResultsTermExistsTask.init_check() + self.assertEqual(mock_ctakes.call_count, 1) + self.assertEqual(mock_cnlpt.call_count, 1) + + @mock.patch("cumulus_etl.nlp.ctakes_extract", side_effect=ValueError("oops")) + async def test_ctakes_error(self, mock_extract): + """Verify we skip docrefs when a cTAKES error happens""" + self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference()) + + task = covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber) + with self.assertLogs(level="WARN") as cm: + await task.run() + + self.assertEqual(len(cm.output), 1) + self.assertRegex( + cm.output[0], r"Could not extract symptoms for docref .* \(ValueError\): oops" + ) + + # Confirm that we skipped the doc + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual(len(batch.rows), 0) + + @mock.patch("cumulus_etl.nlp.list_polarity", side_effect=ValueError("oops")) + async def test_cnlpt_error(self, mock_extract): + """Verify we skip docrefs when a cNLPT error happens""" + self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference()) + + task = covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber) + with self.assertLogs(level="WARN") as cm: + await task.run() + + self.assertEqual(len(cm.output), 1) + self.assertRegex( + cm.output[0], r"Could not check polarity for docref .* \(ValueError\): oops" + ) + + # Confirm that we skipped the doc + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual(len(batch.rows), 0) + + async def test_extract_polarity_type_error(self): + """Verify we bail if a bogus polarity model was given""" + self.make_json("DocumentReference", "doc", **i2b2_mock_data.documentreference()) + + task = covid_symptom.CovidSymptomNlpResultsTask(self.job_config, self.scrubber) + task.polarity_model = None + with self.assertLogs(level="WARN") as cm: + await task.run() + + self.assertEqual(len(cm.output), 1) + self.assertRegex(cm.output[0], "Unknown polarity method: None") + + # Confirm that we skipped the doc + self.assertEqual(self.format.write_records.call_count, 1) + batch = self.format.write_records.call_args[0][0] + self.assertEqual(len(batch.rows), 0) + async def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): """Verify we ignore unknown modifier extensions during a custom read task (nlp symptoms)""" docref0 = i2b2_mock_data.documentreference() @@ -220,8 +297,7 @@ async def test_group_values_noted(self): async def test_zero_symptoms(self): """Verify that we write out a marker for DocRefs we did examine, even if no symptoms appeared""" docref = i2b2_mock_data.documentreference() - docref_no_text = i2b2_mock_data.documentreference() - docref_no_text["content"][0]["attachment"]["data"] = "" + docref_no_text = i2b2_mock_data.documentreference("") self.make_json("DocumentReference", "zero-symptoms", **docref_no_text) self.make_json("DocumentReference", "not-examined", **docref, docStatus="preliminary") @@ -246,7 +322,7 @@ class TestCovidSymptomEtl(BaseEtlSimple): DATA_ROOT = "covid" async def test_basic_run(self): - await self.run_etl(tags=["covid_symptom"]) + await self.run_etl(tasks=["covid_symptom__nlp_results"]) self.assert_output_equal() async def test_term_exists_task(self): diff --git a/tests/etl/base.py b/tests/etl/base.py index bf5a4fa3..44c61d9f 100644 --- a/tests/etl/base.py +++ b/tests/etl/base.py @@ -170,8 +170,6 @@ def make_formatter(dbname: str, **kwargs): def make_json(self, resource_type, resource_id, **kwargs): self.json_file_count += 1 - filename = f"{self.json_file_count}.ndjson" - common.write_json( - os.path.join(self.input_dir, filename), - {"resourceType": resource_type, **kwargs, "id": resource_id}, - ) + filename = f"{self.input_dir}/{self.json_file_count}.ndjson" + with common.NdjsonWriter(filename) as writer: + writer.write({"resourceType": resource_type, **kwargs, "id": resource_id}) diff --git a/tests/hftest/test_hftask.py b/tests/hftest/test_hftask.py index ddaa74b8..7daa4032 100644 --- a/tests/hftest/test_hftask.py +++ b/tests/hftest/test_hftask.py @@ -96,8 +96,8 @@ async def test_caching(self, respx_mock): await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run() self.assertEqual(1, route.call_count) - cache_dir = f"{self.phi_dir}/ctakes-cache/hftest__summary_v0/06ee/" - cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.json" + cache_dir = f"{self.phi_dir}/nlp-cache/hftest__summary_v0/06ee/" + cache_file = f"{cache_dir}/sha256-06ee538c626fbf4bdcec2199b7225c8034f26e2b46a7b5cb7ab385c8e8c00efa.cache" self.assertEqual("Patient has a fever.", common.read_text(cache_file)) await hftest.HuggingFaceTestTask(self.job_config, self.scrubber).run() diff --git a/tests/i2b2_mock_data.py b/tests/i2b2_mock_data.py index 0d38c1d4..d4448959 100644 --- a/tests/i2b2_mock_data.py +++ b/tests/i2b2_mock_data.py @@ -71,8 +71,10 @@ def documentreference_dim() -> transform.ObservationFact: ) -def documentreference() -> dict: - return transform.to_fhir_documentreference(documentreference_dim()) +def documentreference(text: str = DOCREF_TEXT) -> dict: + dim = documentreference_dim() + dim.observation_blob = text + return transform.to_fhir_documentreference(dim) def observation_dim() -> transform.ObservationFact: