From 166e581db7032a2879967f37827fead168b9e36e Mon Sep 17 00:00:00 2001
From: Michael Terry
Date: Mon, 13 Feb 2023 12:20:45 -0500
Subject: [PATCH] feat: support DocumentReference URL attachments

Previously, we only supported DocumentReferences with inlined notes.
Now, we will properly download URL attachments.

Also:
- Expands the mimetypes we look for, from just text/plain to text/plain,
  text/*, and application/xml, in that order.
- Adds --fhir-url to specify the FHIR server when you are using an
  externally downloaded folder.
- Renames and moves the BackendServiceServer in the ndjson loader to
  FhirClient in toplevel code.
- Moves some credential argument handling out of the ndjson loader into
  etl.py code.
- Eases credential requirement checking, so that you don't even need
  credentials, as long as the server doesn't complain (e.g. we can even
  run against Cerner's public sandbox, which doesn't need auth).
- Makes tasks async.
- Bumps FhirClient's timeout from 5 seconds to 5 minutes, for safety.

Note:
- This implementation is a little naive. It just downloads each URL as
  it sees it, with no caching. If we grow another NLP task, we'll need
  to be more clever. And even without that, we could maybe be smarter
  about looking for a cached NLP result first.
---
 cumulus/config.py                          |   4 +-
 cumulus/ctakes.py                          |  96 +++++-
 cumulus/deid/ms-config.json                |   1 +
 cumulus/deid/scrubber.py                   |   4 +-
 cumulus/errors.py                          |   1 +
 cumulus/etl.py                             | 123 ++++---
 .../backend_service.py => fhir_client.py}  | 126 ++++---
 cumulus/loaders/fhir/bulk_export.py        |  12 +-
 cumulus/loaders/fhir/fhir_ndjson.py        |  33 +-
 cumulus/tasks.py                           |  48 ++-
 docs/howtos/epic-tips.md                   |  14 +
 docs/howtos/run-cumulus-etl.md             |   1 +
 tests/test_bulk_export.py                  | 307 +++++++++++-------
 tests/test_tasks.py                        | 106 ++++--
 14 files changed, 585 insertions(+), 291 deletions(-)
 rename cumulus/{loaders/fhir/backend_service.py => fhir_client.py} (77%)

diff --git a/cumulus/config.py b/cumulus/config.py
index a83a6da4..f6713847 100644
--- a/cumulus/config.py
+++ b/cumulus/config.py
@@ -5,7 +5,7 @@
 from socket import gethostname
 from typing import List
 
-from cumulus import common, formats, store
+from cumulus import common, fhir_client, formats, store
 
 
 class JobConfig:
@@ -25,6 +25,7 @@ def __init__(
         dir_phi: str,
         input_format: str,
         output_format: str,
+        client: fhir_client.FhirClient,
         timestamp: datetime.datetime = None,
         comment: str = None,
         batch_size: int = 1,  # this default is never really used - overridden by command line args
@@ -36,6 +37,7 @@ def __init__(
         self.dir_phi = dir_phi
         self._input_format = input_format
         self._output_format = output_format
+        self.client = client
         self.timestamp = common.timestamp_filename(timestamp)
         self.hostname = gethostname()
         self.comment = comment or ""
diff --git a/cumulus/ctakes.py b/cumulus/ctakes.py
index 7a85d5dc..ffc967d4 100644
--- a/cumulus/ctakes.py
+++ b/cumulus/ctakes.py
@@ -5,17 +5,18 @@
 import hashlib
 import logging
 import os
-from typing import List
+from typing import List, Optional
 
 import ctakesclient
 
-from cumulus import common, fhir_common, store
+from cumulus import common, fhir_client, fhir_common, store
 
 
-def covid_symptoms_extract(cache: store.Root, docref: dict) -> List[dict]:
+async def covid_symptoms_extract(client: fhir_client.FhirClient, cache: store.Root, docref: dict) -> List[dict]:
     """
     Extract a list of Observations from NLP-detected symptoms in physician notes
 
+    :param client: a client ready to talk to a FHIR server
     :param cache: Where to cache NLP results
     :param docref: Physician Note
     :return: list of NLP results encoded as FHIR
observations
@@ -25,20 +26,14 @@ def covid_symptoms_extract(cache: store.Root, docref: dict) -> List[dict]:
     encounters = docref.get("context", {}).get("encounter", [])
     if not encounters:
-        logging.warning("No valid encounters for symptoms")  # ideally would print identifier, but it's PHI...
+        logging.warning("No encounters for docref %s", docref_id)
         return []
     _, encounter_id = fhir_common.unref_resource(encounters[0])
 
     # Find the physician note among the attachments
-    for content in docref["content"]:
-        if "contentType" in content["attachment"] and "data" in content["attachment"]:
-            mimetype, params = cgi.parse_header(content["attachment"]["contentType"])
-            if mimetype == "text/plain":  # just grab first text we find
-                charset = params.get("charset", "utf8")
-                physician_note = base64.standard_b64decode(content["attachment"]["data"]).decode(charset)
-                break
-    else:
-        logging.warning("No text/plain content in docref %s", docref_id)
+    physician_note = await get_docref_note(client, [content["attachment"] for content in docref["content"]])
+    if physician_note is None:
+        logging.warning("No text content in docref %s", docref_id)
         return []
 
     # Strip this "line feed" character that often shows up in notes and is confusing for cNLP.
@@ -97,6 +92,81 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText):
     return positive_matches
 
 
+def parse_content_type(content_type: str) -> (str, str):
+    """Returns (mimetype, encoding)"""
+    # TODO: switch to message.Message parsing, since cgi is deprecated
+    mimetype, params = cgi.parse_header(content_type)
+    return mimetype, params.get("charset", "utf8")
+
+
+def mimetype_priority(mimetype: str) -> int:
+    """
+    Returns priority of mimetypes for docref notes.
+
+    0 means "ignore"
+    Higher numbers are higher priority
+    """
+    if mimetype == "text/plain":
+        return 3
+    elif mimetype.startswith("text/"):
+        return 2
+    elif mimetype in ("application/xml", "application/xhtml+xml"):
+        return 1
+    return 0
+
+
+async def get_docref_note(client: fhir_client.FhirClient, attachments: List[dict]) -> Optional[str]:
+    # Find the best attachment to use, based on mimetype.
+    # We prefer basic text documents, to avoid confusing cTAKES with extra formatting (like HTML tags).
+    best_attachment_index = -1
+    best_attachment_priority = 0
+    for index, attachment in enumerate(attachments):
+        if "contentType" in attachment:
+            mimetype, _ = parse_content_type(attachment["contentType"])
+            priority = mimetype_priority(mimetype)
+            if priority > best_attachment_priority:
+                best_attachment_priority = priority
+                best_attachment_index = index
+
+    if best_attachment_index >= 0:
+        return await get_docref_note_from_attachment(client, attachments[best_attachment_index])
+
+    # We didn't find _any_ of our target text content types.
+    # A content type isn't required by the spec with external URLs... so it's possible an unmarked link could be good.
+    # But let's optimistically enforce the need for a content type ourselves by bailing here.
+    # If we find a real-world need to be more permissive, we can change this later.
+    # But note that if we do, we'll need to handle downloading Binary FHIR objects, in addition to arbitrary URLs.
+    return None
+
+
+async def get_docref_note_from_attachment(client: fhir_client.FhirClient, attachment: dict) -> Optional[str]:
+    """
+    Decodes or downloads a note from an attachment.
+
+    Note that it is assumed a contentType is provided.
+ + :returns: the attachment's note text + """ + mimetype, charset = parse_content_type(attachment["contentType"]) + + if "data" in attachment: + return base64.standard_b64decode(attachment["data"]).decode(charset) + + # TODO: At some point we should centralize the downloading of attachments -- once we have multiple NLP tasks, + # we may not want to re-download the overlapping notes. When we do that, it should not be part of our bulk + # exporter, since we may be given already-exported ndjson. + # + # TODO: There are future optimizations to try to use our ctakes cache to avoid downloading in the first place: + # - use attachment["hash"] if available (algorithm mismatch though... maybe we should switch to sha1...) + # - send a HEAD request with "Want-Digest: sha-256" but Cerner at least does not support that + if "url" in attachment: + # We need to pass Accept to get the raw data, not a Binary object. See https://www.hl7.org/fhir/binary.html + response = await client.request("GET", attachment["url"], headers={"Accept": mimetype}) + return response.text + + return None + + def extract(cache: store.Root, namespace: str, sentence: str) -> ctakesclient.typesystem.CtakesJSON: """ This is a version of ctakesclient.client.extract() that also uses a cache diff --git a/cumulus/deid/ms-config.json b/cumulus/deid/ms-config.json index b0069ef5..f67967bd 100644 --- a/cumulus/deid/ms-config.json +++ b/cumulus/deid/ms-config.json @@ -5,6 +5,7 @@ "fhirPathRules": [ {"path": "nodesByName('modifierExtension')", "method": "keep", "comment": "Cumulus: keep these so we can ignore resources with modifiers we don't understand"}, {"path": "DocumentReference.nodesByType('Attachment').data", "method": "keep", "comment": "Cumulus: needed to run NLP on physician notes"}, + {"path": "DocumentReference.nodesByType('Attachment').url", "method": "keep", "comment": "Cumulus: needed to run NLP on physician notes"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-birthsex')", "method": "keep", "comment": "Cumulus: useful for studies"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-ethnicity')", "method": "keep", "comment": "Cumulus: useful for studies"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-genderIdentity')", "method": "keep", "comment": "Cumulus: useful for studies"}, diff --git a/cumulus/deid/scrubber.py b/cumulus/deid/scrubber.py index 979fd9f2..1e831b9e 100644 --- a/cumulus/deid/scrubber.py +++ b/cumulus/deid/scrubber.py @@ -140,8 +140,8 @@ def _check_ids(self, node_path: str, node: dict, key: str, value: Any) -> None: @staticmethod def _check_attachments(node_path: str, node: dict, key: str) -> None: """Strip any attachment data""" - if node_path == "root.content.attachment" and key == "data": - del node["data"] + if node_path == "root.content.attachment" and key in {"data", "url"}: + del node[key] @staticmethod def _check_security(node_path: str, node: dict, key: str, value: Any) -> None: diff --git a/cumulus/errors.py b/cumulus/errors.py index ec5459cf..14514013 100644 --- a/cumulus/errors.py +++ b/cumulus/errors.py @@ -16,3 +16,4 @@ TASK_SET_EMPTY = 21 ARGS_CONFLICT = 22 ARGS_INVALID = 23 +FHIR_URL_MISSING = 24 diff --git a/cumulus/etl.py b/cumulus/etl.py index a48696c0..426d229e 100644 --- a/cumulus/etl.py +++ b/cumulus/etl.py @@ -11,12 +11,12 @@ import sys import tempfile import time -from typing import List, Type +from typing import Iterable, List, Type from 
urllib.parse import urlparse import ctakesclient -from cumulus import common, context, deid, errors, loaders, store, tasks +from cumulus import common, context, deid, errors, fhir_client, loaders, store, tasks from cumulus.config import JobConfig, JobSummary @@ -27,9 +27,7 @@ ############################################################################### -async def load_and_deidentify( - loader: loaders.Loader, selected_tasks: List[Type[tasks.EtlTask]] -) -> tempfile.TemporaryDirectory: +async def load_and_deidentify(loader: loaders.Loader, resources: Iterable[str]) -> tempfile.TemporaryDirectory: """ Loads the input directory and does a first-pass de-identification @@ -37,17 +35,14 @@ async def load_and_deidentify( :returns: a temporary directory holding the de-identified files in FHIR ndjson format """ - # Grab a list of all required resource types for the tasks we are running - required_resources = set(t.resource for t in selected_tasks) - # First step is loading all the data into a local ndjson format - loaded_dir = await loader.load_all(list(required_resources)) + loaded_dir = await loader.load_all(list(resources)) # Second step is de-identifying that data (at a bulk level) return await deid.Scrubber.scrub_bulk_data(loaded_dir.name) -def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> List[JobSummary]: +async def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> List[JobSummary]: """ :param config: job config :param selected_tasks: the tasks to run @@ -58,7 +53,7 @@ def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> Lis scrubber = deid.Scrubber(config.dir_phi) for task_class in selected_tasks: task = task_class(config, scrubber) - summary = task.run() + summary = await task.run() summary_list.append(summary) path = os.path.join(config.dir_job_config(), f"{summary.label}.json") @@ -195,6 +190,9 @@ def make_parser() -> argparse.ArgumentParser: metavar="PATH", help="Bearer token for custom bearer authentication", ) + export.add_argument( + "--fhir-url", metavar="URL", help="FHIR server base URL, only needed if you exported separately" + ) export.add_argument("--since", help="Start date for export from the FHIR server") export.add_argument("--until", help="End date for export from the FHIR server") @@ -213,6 +211,39 @@ def make_parser() -> argparse.ArgumentParser: return parser +def create_fhir_client(args, root_input, resources): + client_base_url = args.fhir_url + if root_input.protocol in {"http", "https"}: + if args.fhir_url and not root_input.path.startswith(args.fhir_url): + print( + "You provided both an input FHIR server and a different --fhir-url. 
Try dropping --fhir-url.", + file=sys.stderr, + ) + raise SystemExit(errors.ARGS_CONFLICT) + client_base_url = root_input.path + + try: + try: + # Try to load client ID from file first (some servers use crazy long ones, like SMART's bulk-data-server) + smart_client_id = common.read_text(args.smart_client_id).strip() if args.smart_client_id else None + except FileNotFoundError: + smart_client_id = args.smart_client_id + + smart_jwks = common.read_json(args.smart_jwks) if args.smart_jwks else None + bearer_token = common.read_text(args.bearer_token).strip() if args.bearer_token else None + except OSError as exc: + print(exc, file=sys.stderr) + raise SystemExit(errors.ARGS_INVALID) from exc + + return fhir_client.FhirClient( + client_base_url, + resources, + client_id=smart_client_id, + jwks=smart_jwks, + bearer_token=bearer_token, + ) + + async def main(args: List[str]): parser = make_parser() args = parser.parse_args(args) @@ -233,45 +264,49 @@ async def main(args: List[str]): job_context = context.JobContext(root_phi.joinpath("context.json")) job_datetime = common.datetime_now() # grab timestamp before we do anything - if args.input_format == "i2b2": - config_loader = loaders.I2b2Loader(root_input, args.batch_size) - else: - config_loader = loaders.FhirNdjsonLoader( - root_input, - client_id=args.smart_client_id, - jwks=args.smart_jwks, - bearer_token=args.bearer_token, - since=args.since, - until=args.until, - ) - # Check which tasks are being run, allowing comma-separated values task_names = args.task and set(itertools.chain.from_iterable(t.split(",") for t in args.task)) task_filters = args.task_filter and list(itertools.chain.from_iterable(t.split(",") for t in args.task_filter)) selected_tasks = tasks.EtlTask.get_selected_tasks(task_names, task_filters) - # Pull down resources and run the MS tool on them - deid_dir = await load_and_deidentify(config_loader, selected_tasks) - - # Prepare config for jobs - config = JobConfig( - args.dir_input, - deid_dir.name, - args.dir_output, - args.dir_phi, - args.input_format, - args.output_format, - comment=args.comment, - batch_size=args.batch_size, - timestamp=job_datetime, - tasks=[t.name for t in selected_tasks], - ) - common.write_json(config.path_config(), config.as_json(), indent=4) - common.print_header("Configuration:") - print(json.dumps(config.as_json(), indent=4)) + # Grab a list of all required resource types for the tasks we are running + required_resources = set(t.resource for t in selected_tasks) + + # Create a client to talk to a FHIR server. + # This is useful even if we aren't doing a bulk export, because some resources like DocumentReference can still + # reference external resources on the server (like the document text). + # If we don't need this client (e.g. we're using local data and don't download any attachments), this is a no-op. 
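To illustrate what the client setup described in the comments above looks like in isolation, here is a minimal sketch of using FhirClient directly; the server URL, resource list, and resource ID are hypothetical and not taken from this patch:

```python
import asyncio

from cumulus import fhir_client


async def demo():
    # With no credentials given, the default Auth is a no-op; the client only
    # fails if the server itself rejects us (e.g. open sandboxes work fine).
    async with fhir_client.FhirClient("https://example.com/fhir", ["DocumentReference"]) as client:
        response = await client.request("GET", "DocumentReference/example-id")
        print(response.json())


asyncio.run(demo())
```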
+    client = create_fhir_client(args, root_input, required_resources)
+
+    async with client:
+        if args.input_format == "i2b2":
+            config_loader = loaders.I2b2Loader(root_input, args.batch_size)
+        else:
+            config_loader = loaders.FhirNdjsonLoader(root_input, client, since=args.since, until=args.until)
+
+        # Pull down resources and run the MS tool on them
+        deid_dir = await load_and_deidentify(config_loader, required_resources)
+
+        # Prepare config for jobs
+        config = JobConfig(
+            args.dir_input,
+            deid_dir.name,
+            args.dir_output,
+            args.dir_phi,
+            args.input_format,
+            args.output_format,
+            client,
+            comment=args.comment,
+            batch_size=args.batch_size,
+            timestamp=job_datetime,
+            tasks=[t.name for t in selected_tasks],
+        )
+        common.write_json(config.path_config(), config.as_json(), indent=4)
+        common.print_header("Configuration:")
+        print(json.dumps(config.as_json(), indent=4))
 
-    # Finally, actually run the meat of the pipeline! (Filtered down to requested tasks)
-    summaries = etl_job(config, selected_tasks)
+        # Finally, actually run the meat of the pipeline! (Filtered down to requested tasks)
+        summaries = await etl_job(config, selected_tasks)
 
     # Print results to the console
     common.print_header("Results:")
diff --git a/cumulus/loaders/fhir/backend_service.py b/cumulus/fhir_client.py
similarity index 77%
rename from cumulus/loaders/fhir/backend_service.py
rename to cumulus/fhir_client.py
index 3ce38d27..0e96a399 100644
--- a/cumulus/loaders/fhir/backend_service.py
+++ b/cumulus/fhir_client.py
@@ -1,16 +1,15 @@
-"""Support for SMART App Launch Backend Services"""
+"""HTTP client that talks to a FHIR server"""
 
-import abc
 import re
 import sys
 import time
 import urllib.parse
 import uuid
 from json import JSONDecodeError
-from typing import List, Optional
+from typing import Iterable, Optional
 
+import fhirclient.client
 import httpx
-from fhirclient.client import FHIRClient
 from jwcrypto import jwk, jwt
 
 from cumulus import errors
@@ -20,22 +19,42 @@ class FatalError(Exception):
     """An unrecoverable error"""
 
 
-class Auth(abc.ABC):
-    """Abstracted authentication for a FHIR server"""
+def _urljoin(base: str, path: str) -> str:
+    """Basically just urllib.parse.urljoin, but with some extra error checking"""
+    path_is_absolute = bool(urllib.parse.urlparse(path).netloc)
+    if path_is_absolute:
+        return path
 
-    @abc.abstractmethod
-    async def authorize(self, session: httpx.AsyncClient) -> None:
+    if not base:
+        print("You must provide a base FHIR server URL with --fhir-url", file=sys.stderr)
+        raise SystemExit(errors.FHIR_URL_MISSING)
+    return urllib.parse.urljoin(base, path)
+
+
+class Auth:
+    """Abstracted authentication for a FHIR server. By default, does nothing."""
+
+    async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None:
         """Authorize (or re-authorize) against the server"""
+        del session
+
+        if reauthorize:
+            # Abort because we clearly need authentication tokens, but have not been given any parameters for them.
+            print(
+                "You must provide some authentication parameters (like --smart-client-id) to connect to a server.",
+                file=sys.stderr,
+            )
+            raise SystemExit(errors.SMART_CREDENTIALS_MISSING)
 
-    @abc.abstractmethod
     def sign_headers(self, headers: Optional[dict]) -> dict:
         """Add signature token to request headers"""
+        return headers
 
 
 class JwksAuth(Auth):
     """Authentication with a JWK Set (typical backend service profile)"""
 
-    def __init__(self, server_root: str, client_id: str, jwks: dict, resources: List[str]):
+    def __init__(self, server_root: str, client_id: str, jwks: dict, resources: Iterable[str]):
         super().__init__()
         self._server_root = server_root
         self._client_id = client_id
@@ -44,7 +63,7 @@ def __init__(self, server_root: str, client_id: str, jwks: dict, resources: List
         self._server = None
         self._token_endpoint = None
 
-    async def authorize(self, session: httpx.AsyncClient) -> None:
+    async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None:
         """
         Authenticates against a SMART FHIR server using the Backend Services profile.
 
@@ -53,13 +72,13 @@ async def authorize(self, session: httpx.AsyncClient) -> None:
         # Have we authorized before?
         if self._token_endpoint is None:
             self._token_endpoint = await self._get_token_endpoint(session)
-        if self._server is not None and self._server.reauthorize():
+        if reauthorize and self._server.reauthorize():
             return
 
         # Else we must not have been issued a refresh token, let's just authorize from scratch below
         signed_jwt = self._make_signed_jwt()
         scope = " ".join([f"system/{resource}.read" for resource in self._resources])
-        client = FHIRClient(
+        client = fhirclient.client.FHIRClient(
             settings={
                 "api_base": self._server_root,
                 "app_id": self._client_id,
@@ -100,7 +119,7 @@ async def _get_token_endpoint(self, session: httpx.AsyncClient) -> str:
         :returns: URL for the server's oauth2 token endpoint
         """
         response = await session.get(
-            urllib.parse.urljoin(self._server_root, ".well-known/smart-configuration"),
+            _urljoin(self._server_root, ".well-known/smart-configuration"),
             headers={
                 "Accept": "application/json",
             },
@@ -161,7 +180,7 @@ def __init__(self, bearer_token: str):
         super().__init__()
         self._bearer_token = bearer_token
 
-    async def authorize(self, session: httpx.AsyncClient) -> None:
+    async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None:
         pass
 
     def sign_headers(self, headers: Optional[dict]) -> dict:
@@ -170,9 +189,11 @@ def sign_headers(self, headers: Optional[dict]) -> dict:
         return headers
 
 
-class BackendServiceServer:
+class FhirClient:
     """
-    Manages authentication and requests for a server that supports the Backend Service SMART profile.
+    Manages authentication and requests for a FHIR server.
+
+    Supports a few different auth methods, but most notably the Backend Service SMART profile.
 
     Use this as a context manager (like you would an httpx.AsyncClient instance).
 
@@ -180,7 +201,12 @@ class FhirClient:
     """
 
     def __init__(
-        self, url: str, resources: List[str], client_id: str = None, jwks: dict = None, bearer_token: str = None
+        self,
+        url: Optional[str],
+        resources: Iterable[str],
+        client_id: str = None,
+        jwks: dict = None,
+        bearer_token: str = None,
    ):
        """
        Initialize and authorize a FhirClient context manager.
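For reference, the signed JWT that `JwksAuth._make_signed_jwt` produces (used in the authorization flow above) follows the SMART Backend Services profile. A rough sketch with jwcrypto follows; the claim values and the assumption that the first key in the set is the signing key are illustrative, and the real method may differ in details:

```python
import time
import uuid

from jwcrypto import jwk, jwt


def make_signed_jwt(jwks: dict, client_id: str, token_endpoint: str) -> str:
    # Assumes the first key in the set is the private RS384 signing key.
    key = jwk.JWK(**jwks["keys"][0])
    claims = {
        "iss": client_id,
        "sub": client_id,
        "aud": token_endpoint,
        "exp": int(time.time()) + 299,  # the spec wants expiry under five minutes out
        "jti": str(uuid.uuid4()),
    }
    token = jwt.JWT(header={"alg": "RS384", "kid": key.key_id, "typ": "JWT"}, claims=claims)
    token.make_signed_token(key)
    return token.serialize()
```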
@@ -191,20 +217,23 @@ def __init__(
         :param jwks: content of a JWK Set file, containing the private key for the registered public key
         :param bearer_token: a bearer token, containing the secret key to sign https requests (instead of JWKS)
         """
+        # Allow url to be None for fully local ETL runs, in which case this class is basically a no-op
         self._base_url = url  # all requests are relative to this URL
-        if not self._base_url.endswith("/"):
+        if self._base_url and not self._base_url.endswith("/"):
             self._base_url += "/"
 
         # The base URL may not be the server root (like it may be a Group export URL). Let's find the root.
         self._server_root = self._base_url
-        self._server_root = re.sub(r"/Patient/$", "/", self._server_root)
-        self._server_root = re.sub(r"/Group/[^/]+/$", "/", self._server_root)
+        if self._server_root:
+            self._server_root = re.sub(r"/Patient/$", "/", self._server_root)
+            self._server_root = re.sub(r"/Group/[^/]+/$", "/", self._server_root)
+
         self._auth = self._make_auth(resources, client_id, jwks, bearer_token)
         self._session: Optional[httpx.AsyncClient] = None
 
     async def __aenter__(self):
         # Limit the number of connections open at once, because EHRs tend to be very busy.
         limits = httpx.Limits(max_connections=5)
-        self._session = httpx.AsyncClient(limits=limits)
+        self._session = httpx.AsyncClient(limits=limits, timeout=300)  # five minutes to be generous
         await self._auth.authorize(self._session)
         return self
 
@@ -212,32 +241,12 @@ async def __aexit__(self, exc_type, exc_value, traceback):
         if self._session:
             await self._session.aclose()
 
-    def _make_auth(self, resources: List[str], client_id: str, jwks: dict, bearer_token: str) -> Auth:
-        """Determine which auth method to use based on user provided arguments"""
-
-        if bearer_token and (client_id or jwks):
-            print("--bearer-token cannot be used with --smart-client-id or --smart-jwks", file=sys.stderr)
-            raise SystemExit(errors.ARGS_CONFLICT)
-
-        if bearer_token:
-            return BearerAuth(bearer_token)
-
-        # Confirm that all required SMART arguments were provided
-        error_list = []
-        if not client_id:
-            error_list.append("You must provide a client ID with --smart-client-id to connect to a SMART FHIR server.")
-        if jwks is None:
-            error_list.append("You must provide a JWKS file with --smart-jwks to connect to a SMART FHIR server.")
-        if error_list:
-            print("\n".join(error_list), file=sys.stderr)
-            raise SystemExit(errors.SMART_CREDENTIALS_MISSING)
-
-        return JwksAuth(self._server_root, client_id, jwks, resources)
-
     async def request(self, method: str, path: str, headers: dict = None, stream: bool = False) -> httpx.Response:
         """
         Issues an HTTP request.
 
+        The default Accept type is application/fhir+json, but can be overridden by a provided header.
+
         This is a lightly modified version of FHIRServer._get(), but additionally supports streaming and
         reauthorization.
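A quick usage sketch of the Accept override described above; this is the hook that the new attachment download in ctakes.py relies on. The URL is hypothetical:

```python
async def fetch_note_text(client):
    # Fetch raw note text instead of a FHIR Binary resource by overriding Accept:
    response = await client.request(
        "GET", "https://example.com/fhir/Binary/note-1", headers={"Accept": "text/plain"}
    )
    return response.text
```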
@@ -249,7 +258,7 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo :param stream: whether to stream content in or load it all into memory at once :returns: The response object """ - url = urllib.parse.urljoin(self._base_url, path) + url = _urljoin(self._base_url, path) final_headers = { "Accept": "application/fhir+json", @@ -262,7 +271,7 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo # Check if our access token expired and thus needs to be refreshed if response.status_code == 401: - await self._auth.authorize(self._session) + await self._auth.authorize(self._session, reauthorize=True) if stream: await response.aclose() response = await self._request_with_signed_headers(method, url, final_headers, stream=stream) @@ -300,6 +309,28 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo # ################################################################################################################### + def _make_auth(self, resources: Iterable[str], client_id: str, jwks: dict, bearer_token: str) -> Auth: + """Determine which auth method to use based on user provided arguments""" + valid_jwks = jwks is not None + + if bearer_token and (client_id or valid_jwks): + print("--bearer-token cannot be used with --smart-client-id or --smart-jwks", file=sys.stderr) + raise SystemExit(errors.ARGS_CONFLICT) + + if bearer_token: + return BearerAuth(bearer_token) + + if client_id and valid_jwks: + return JwksAuth(self._server_root, client_id, jwks, resources) + elif client_id or valid_jwks: + print( + "You must provide both --smart-client-id and --smart-jwks to connect to a SMART FHIR server.", + file=sys.stderr, + ) + raise SystemExit(errors.SMART_CREDENTIALS_MISSING) + + return Auth() + async def _request_with_signed_headers( self, method: str, url: str, headers: dict = None, **kwargs ) -> httpx.Response: @@ -311,6 +342,9 @@ async def _request_with_signed_headers( :param headers: header dictionary :returns: The response object """ + if not self._session: + raise RuntimeError("FhirClient must be used as a context manager") + headers = self._auth.sign_headers(headers) request = self._session.build_request(method, url, headers=headers) # Follow redirects by default -- some EHRs definitely use them for bulk download files, diff --git a/cumulus/loaders/fhir/bulk_export.py b/cumulus/loaders/fhir/bulk_export.py index 44d4828b..11533f55 100644 --- a/cumulus/loaders/fhir/bulk_export.py +++ b/cumulus/loaders/fhir/bulk_export.py @@ -9,7 +9,7 @@ import httpx from cumulus import common -from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError +from cumulus.fhir_client import FatalError, FhirClient class BulkExporter: @@ -24,19 +24,19 @@ class BulkExporter: _TIMEOUT_THRESHOLD = 60 * 60 * 24 # a day, which is probably an overly generous timeout def __init__( - self, server: BackendServiceServer, resources: List[str], destination: str, since: str = None, until: str = None + self, client: FhirClient, resources: List[str], destination: str, since: str = None, until: str = None ): """ Initialize a bulk exporter (but does not start an export). 
- :param server: a server instance ready to make requests + :param client: a client ready to make requests :param resources: a list of resource names to export :param destination: a local folder to store all the files :param since: start date for export :param until: end date for export """ super().__init__() - self._server = server + self._client = client self._resources = resources self._destination = destination self._total_wait_time = 0 # in seconds, across all our requests @@ -128,7 +128,7 @@ async def _request_with_delay( :returns: the HTTP response """ while self._total_wait_time < self._TIMEOUT_THRESHOLD: - response = await self._server.request(method, path, headers=headers) + response = await self._client.request(method, path, headers=headers) if response.status_code == target_status_code: return response @@ -203,7 +203,7 @@ async def _download_ndjson_file(self, url: str, filename: str) -> None: :param url: URL location of file to download :param filename: local path to write data to """ - response = await self._server.request("GET", url, headers={"Accept": "application/fhir+ndjson"}, stream=True) + response = await self._client.request("GET", url, headers={"Accept": "application/fhir+ndjson"}, stream=True) try: with open(filename, "w", encoding="utf8") as file: async for block in response.aiter_text(): diff --git a/cumulus/loaders/fhir/fhir_ndjson.py b/cumulus/loaders/fhir/fhir_ndjson.py index dbcee5d1..f23325d7 100644 --- a/cumulus/loaders/fhir/fhir_ndjson.py +++ b/cumulus/loaders/fhir/fhir_ndjson.py @@ -6,8 +6,8 @@ from typing import List from cumulus import common, errors, store +from cumulus.fhir_client import FatalError, FhirClient from cumulus.loaders import base -from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError from cumulus.loaders.fhir.bulk_export import BulkExporter @@ -22,34 +22,18 @@ class FhirNdjsonLoader(base.Loader): def __init__( self, root: store.Root, - client_id: str = None, - jwks: str = None, - bearer_token: str = None, + client: FhirClient, since: str = None, until: str = None, ): """ :param root: location to load ndjson from - :param client_id: client ID for a SMART server - :param jwks: path to a JWKS file for a SMART server - :param bearer_token: path to a file with a bearer token for a FHIR server + :param client: client ready to talk to a FHIR server :param since: export start date for a FHIR server :param until: export end date for a FHIR server """ super().__init__(root) - - try: - try: - self.client_id = common.read_text(client_id).strip() if client_id else None - except FileNotFoundError: - self.client_id = client_id - - self.jwks = common.read_json(jwks) if jwks else None - self.bearer_token = common.read_text(bearer_token).strip() if bearer_token else None - except OSError as exc: - print(exc, file=sys.stderr) - raise SystemExit(errors.ARGS_INVALID) from exc - + self.client = client self.since = since self.until = until @@ -58,7 +42,7 @@ async def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory: if self.root.protocol in ["http", "https"]: return await self._load_from_bulk_export(resources) - if self.client_id or self.jwks or self.bearer_token or self.since or self.until: + if self.since or self.until: print("You provided FHIR bulk export parameters but did not provide a FHIR server", file=sys.stderr) raise SystemExit(errors.ARGS_CONFLICT) @@ -84,11 +68,8 @@ async def _load_from_bulk_export(self, resources: List[str]) -> tempfile.Tempora tmpdir = tempfile.TemporaryDirectory() # pylint: 
disable=consider-using-with try: - async with BackendServiceServer( - self.root.path, resources, client_id=self.client_id, jwks=self.jwks, bearer_token=self.bearer_token - ) as server: - bulk_exporter = BulkExporter(server, resources, tmpdir.name, self.since, self.until) - await bulk_exporter.export() + bulk_exporter = BulkExporter(self.client, resources, tmpdir.name, self.since, self.until) + await bulk_exporter.export() except FatalError as exc: print(str(exc), file=sys.stderr) raise SystemExit(errors.BULK_EXPORT_FAILED) from exc diff --git a/cumulus/tasks.py b/cumulus/tasks.py index 5fe7544e..876bcf6a 100644 --- a/cumulus/tasks.py +++ b/cumulus/tasks.py @@ -5,7 +5,7 @@ import os import re import sys -from typing import Iterable, Iterator, List, Set, Type, TypeVar, Union +from typing import AsyncIterable, AsyncIterator, Iterable, Iterator, List, Set, Type, TypeVar, Union import pandas @@ -15,7 +15,7 @@ AnyTask = TypeVar("AnyTask", bound="EtlTask") -def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: +async def _batch_slice(iterable: AsyncIterable[Union[List[T], T]], n: int) -> AsyncIterator[T]: """ Returns the first n elements of iterable, flattening lists, but including an entire list if we would end in middle. @@ -24,9 +24,10 @@ def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: Note that this will only flatten elements that are actual Python lists (isinstance list is True) """ count = 0 - for item in iterable: + async for item in iterable: if isinstance(item, list): - yield from item + for x in item: + yield x count += len(item) else: yield item @@ -36,7 +37,14 @@ def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: return -def _batch_iterate(iterable: Iterable[Union[List[T], T]], size: int) -> Iterator[Iterator[T]]: +async def _async_chain(first: T, rest: AsyncIterator[T]) -> AsyncIterator[T]: + """An asynchronous version of itertools.chain([first], rest)""" + yield first + async for x in rest: + yield x + + +async def _batch_iterate(iterable: AsyncIterable[Union[List[T], T]], size: int) -> AsyncIterator[AsyncIterator[T]]: """ Yields sub-iterators, each roughly {size} elements from iterable. @@ -59,14 +67,17 @@ def _batch_iterate(iterable: Iterable[Union[List[T], T]], size: int) -> Iterator if size < 1: raise ValueError("Must iterate by at least a batch of 1") - true_iterable = iter(iterable) # in case it's actually a list (we want to iterate only once through) + # aiter() and anext() were added in python 3.10 + # pylint: disable=unnecessary-dunder-call + + true_iterable = iterable.__aiter__() # get a real once-through iterable (we want to iterate only once) while True: iter_slice = _batch_slice(true_iterable, size) try: - peek = next(iter_slice) - except StopIteration: + peek = await iter_slice.__anext__() + except StopAsyncIteration: return # we're done! - yield itertools.chain([peek], iter_slice) + yield _async_chain(peek, iter_slice) class EtlTask: @@ -182,7 +193,7 @@ def __init__(self, task_config: config.JobConfig, scrubber: deid.Scrubber): self.task_config = task_config self.scrubber = scrubber - def run(self) -> config.JobSummary: + async def run(self) -> config.JobSummary: """ Executes a single task and returns the summary. @@ -198,9 +209,10 @@ def run(self) -> config.JobSummary: # At this point we have a giant iterable of de-identified FHIR objects, ready to be written out. # We want to batch them up, to allow resuming from interruptions more easily. 
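To make the batching semantics above concrete, here is a small illustration of `_batch_iterate`, assuming it is run inside this module (the sample data mirrors the unit tests for this helper):

```python
import asyncio


async def entries():
    for item in [1, [2.1, 2.2], 3, 4, 5]:
        yield item


async def demo():
    async for batch in _batch_iterate(entries(), 2):
        print([item async for item in batch])
    # Prints [1, 2.1, 2.2], then [3, 4], then [5]: lists are flattened but
    # never split across batch boundaries.


asyncio.run(demo())
```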
-        for index, batch in enumerate(_batch_iterate(entries, self.task_config.batch_size)):
+        index = 0
+        async for batch in _batch_iterate(entries, self.task_config.batch_size):
             # Stuff de-identified FHIR json into one big pandas DataFrame
-            dataframe = pandas.DataFrame(batch)
+            dataframe = pandas.DataFrame([row async for row in batch])
 
             # Checkpoint scrubber data before writing to the store, because if we get interrupted, it's safer to have an
             # updated codebook with no data than data with an inaccurate codebook.
@@ -210,6 +222,7 @@ def run(self) -> config.JobSummary:
             formatter.write_records(dataframe, index)
             print(f"  {summary.success:,} processed for {self.name}")
+            index += 1
 
         # All data is written, now do any final cleanup the formatter wants
         formatter.finalize()
@@ -238,7 +251,7 @@ def read_ndjson(self) -> Iterator[dict]:
             for line in f:
                 yield json.loads(line)
 
-    def read_entries(self) -> Iterator[Union[List[dict], dict]]:
+    async def read_entries(self) -> AsyncIterator[Union[List[dict], dict]]:
         """
         Reads input entries for the job.
 
@@ -248,7 +261,8 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]:
         the elements of which will be guaranteed to all be in the same output batch.
         See comments for EtlTask.group_field for why you might do this.
         """
-        return filter(self.scrubber.scrub_resource, self.read_ndjson())
+        for x in filter(self.scrubber.scrub_resource, self.read_ndjson()):
+            yield x
 
 
     ##########################################################################################
@@ -332,7 +346,7 @@ def is_ed_coding(cls, coding):
         """Returns true if this is a coding for an emergency department note"""
         return coding.get("code") in cls.ED_CODES.get(coding.get("system"), {})
 
-    def read_entries(self) -> Iterator[Union[List[dict], dict]]:
+    async def read_entries(self) -> AsyncIterator[Union[List[dict], dict]]:
         """Passes physician notes through NLP and returns any symptoms found"""
         phi_root = store.Root(self.task_config.dir_phi, create=True)
 
@@ -340,7 +354,7 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]:
             # Check that the note is one of our special allow-listed types (we do this here rather than on the output
             # side to save needing to run everything through NLP).
             # We check both type and category for safety -- we aren't sure yet how EHRs are using these fields.
-            codings = docref.get("category", {}).get("coding", [])
+            codings = list(itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])]))
             codings += docref.get("type", {}).get("coding", [])
             is_er_note = any(self.is_ed_coding(x) for x in codings)
             if not is_er_note:
@@ -352,4 +366,4 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]:
         # Yield the whole set of symptoms at once, to allow for more easily replacing a previous set of symptoms.
         # This way we don't need to worry about symptoms from the same note crossing batch boundaries.
         # The Format class will replace all existing symptoms from this note at once (because we set group_field).
-        yield ctakes.covid_symptoms_extract(phi_root, docref)
+        yield await ctakes.covid_symptoms_extract(self.task_config.client, phi_root, docref)
diff --git a/docs/howtos/epic-tips.md b/docs/howtos/epic-tips.md
index ef16b222..7a63fc15 100644
--- a/docs/howtos/epic-tips.md
+++ b/docs/howtos/epic-tips.md
@@ -2,6 +2,20 @@
 
 # Epic Tips & Tricks
 
+## Frequent Bulk Exporting
+
+You may encounter this error:
+`Error processing Bulk Data Kickoff request: Request not allowed: The Client requested this Group too recently.`.
+
+If so, you will want to update the `FHIR_BULK_CLIENT_REQUEST_WINDOW_TBL` setting to a shorter time.
+The default is 24 hours.
+
+## Long IDs
+
+In rare cases, Epic's bulk FHIR export can generate IDs that are longer than the mandated 64-character limit.
+Cumulus ETL itself will not mind this, but if you use another bulk export client, you may find that it complains.
+If so, you may have to reach out to Epic to cap them at 64 characters.
+
 ## Batch Updates
 
 Epic has not yet (as of early 2023) implemented the `_since` or `_typeFilter` parameters for bulk exports.
diff --git a/docs/howtos/run-cumulus-etl.md b/docs/howtos/run-cumulus-etl.md
index ab4ec353..ab8444de 100644
--- a/docs/howtos/run-cumulus-etl.md
+++ b/docs/howtos/run-cumulus-etl.md
@@ -328,6 +328,7 @@
 The [SMART Bulk Data Client](https://github.com/smart-on-fhir/bulk-data-client) supports more
 options than Cumulus ETL's built-in exporter offers.
 If you use this tool, pass Cumulus ETL the folder that holds the downloaded data as the input path.
+And you may need to specify `--fhir-url=` so that external document notes can be downloaded.
 
 ## EHR-Specific Advice
diff --git a/tests/test_bulk_export.py b/tests/test_bulk_export.py
index f35ae175..b487f9ac 100644
--- a/tests/test_bulk_export.py
+++ b/tests/test_bulk_export.py
@@ -2,7 +2,6 @@
 
 import contextlib
 import io
-import os
 import tempfile
 import time
 import unittest
@@ -17,7 +16,7 @@
 from jwcrypto import jwk, jwt
 
 from cumulus import common, errors, etl, loaders, store
-from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError
+from cumulus.fhir_client import FatalError, FhirClient
 from cumulus.loaders.fhir.bulk_export import BulkExporter
 
 
@@ -41,40 +40,69 @@ def make_response(status_code=200, json=None, text=None, reason=None, headers=No
     )
 
 
-class TestBulkLoader(unittest.IsolatedAsyncioTestCase):
+class TestBulkEtl(unittest.IsolatedAsyncioTestCase):
     """
     Test case for bulk export support in the etl pipeline and ndjson loader.
 
-    i.e. tests for fhir_ndjson.py
+    i.e. tests for etl.py & fhir_ndjson.py
 
     This does no actual bulk loading.
     """
 
     def setUp(self):
         super().setUp()
-        self.root = store.Root("http://localhost:9999")
-
         self.jwks_file = tempfile.NamedTemporaryFile()  # pylint: disable=consider-using-with
         self.jwks_path = self.jwks_file.name
         self.jwks_file.write(b'{"fake":"jwks"}')
         self.jwks_file.flush()
 
-        # Mock out the backend service and bulk export code by default. We don't care about actually doing any
+        # Mock out the bulk export code by default. We don't care about actually doing any
         # bulk work in this test case, just confirming the flow.
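The patching idiom used in this setUp, shown in isolation (a sketch; the module path is taken from the imports above):

```python
from unittest import mock

from cumulus.loaders.fhir.bulk_export import BulkExporter

# Patch the class with spec= so attribute typos fail loudly, then hand back an
# AsyncMock instance so its methods (like export()) are awaitable:
exporter_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BulkExporter", spec=BulkExporter)
mock_exporter_class = exporter_patcher.start()
mock_exporter = mock.AsyncMock()
mock_exporter_class.return_value = mock_exporter
# ... exercise code that constructs a BulkExporter, then:
exporter_patcher.stop()
```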
- - server_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BackendServiceServer") - self.addCleanup(server_patcher.stop) - self.mock_server = server_patcher.start() - - exporter_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BulkExporter") + exporter_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BulkExporter", spec=BulkExporter) self.addCleanup(exporter_patcher.stop) - self.mock_exporter = exporter_patcher.start() + self.mock_exporter_class = exporter_patcher.start() + self.mock_exporter = mock.AsyncMock() + self.mock_exporter_class.return_value = self.mock_exporter + @mock.patch("cumulus.etl.fhir_client.FhirClient") @mock.patch("cumulus.etl.loaders.FhirNdjsonLoader") - async def test_etl_passes_args(self, mock_loader): + async def test_etl_passes_args(self, mock_loader, mock_client): """Verify that we are passed the client ID and JWKS from the command line""" mock_loader.side_effect = ValueError # just to stop the etl pipeline once we get this far + with tempfile.NamedTemporaryFile(buffering=0) as bt_file: + bt_file.write(b"bt") + + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--input-format=ndjson", + "--smart-client-id=x", + f"--smart-jwks={self.jwks_path}", + f"--bearer-token={bt_file.name}", + "--since=2018", + "--until=2020", + ] + ) + + self.assertEqual(1, mock_client.call_count) + self.assertEqual("x", mock_client.call_args[1]["client_id"]) + self.assertEqual({"fake": "jwks"}, mock_client.call_args[1]["jwks"]) + self.assertEqual("bt", mock_client.call_args[1]["bearer_token"]) + self.assertEqual(1, mock_loader.call_count) + self.assertEqual("2018", mock_loader.call_args[1]["since"]) + self.assertEqual("2020", mock_loader.call_args[1]["until"]) + + @mock.patch("cumulus.etl.fhir_client.FhirClient") + async def test_reads_client_id_from_file(self, mock_client): + """Verify that we try to read a client ID from a file.""" + mock_client.side_effect = ValueError # just to stop the etl pipeline once we get this far + + # First, confirm string is used directly if file doesn't exist with self.assertRaises(ValueError): await etl.main( [ @@ -82,92 +110,127 @@ async def test_etl_passes_args(self, mock_loader): "/tmp/output", "/tmp/phi", "--skip-init-checks", - "--input-format=ndjson", - "--smart-client-id=x", - "--smart-jwks=y", - "--bearer-token=bt", - "--since=2018", - "--until=2020", + "--smart-client-id=/direct-string", ] ) - - self.assertEqual(1, mock_loader.call_count) - self.assertEqual("x", mock_loader.call_args[1]["client_id"]) - self.assertEqual("y", mock_loader.call_args[1]["jwks"]) - self.assertEqual("bt", mock_loader.call_args[1]["bearer_token"]) - self.assertEqual("2018", mock_loader.call_args[1]["since"]) - self.assertEqual("2020", mock_loader.call_args[1]["until"]) - - def test_reads_client_id_from_file(self): - """Verify that we require both a client ID and a JWK Set.""" - # First, confirm string is used directly if file doesn't exist - loader = loaders.FhirNdjsonLoader(self.root, client_id="/direct-string") - self.assertEqual("/direct-string", loader.client_id) + self.assertEqual("/direct-string", mock_client.call_args[1]["client_id"]) # Now read from a file that exists - with tempfile.NamedTemporaryFile() as file: + with tempfile.NamedTemporaryFile(buffering=0) as file: file.write(b"\ninside-file\n") - file.flush() - loader = loaders.FhirNdjsonLoader(self.root, client_id=file.name) - self.assertEqual("inside-file", loader.client_id) + with 
self.assertRaises(ValueError):
+                await etl.main(
+                    [
+                        "http://localhost:9999",
+                        "/tmp/output",
+                        "/tmp/phi",
+                        "--skip-init-checks",
+                        f"--smart-client-id={file.name}",
+                    ]
+                )
+            self.assertEqual("inside-file", mock_client.call_args[1]["client_id"])
 
-    def test_reads_bearer_token(self):
+    @mock.patch("cumulus.etl.fhir_client.FhirClient")
+    async def test_reads_bearer_token(self, mock_client):
         """Verify that we read the bearer token file"""
-        with tempfile.NamedTemporaryFile() as file:
+        mock_client.side_effect = ValueError  # just to stop the etl pipeline once we get this far
+
+        with tempfile.NamedTemporaryFile(buffering=0) as file:
             file.write(b"\ninside-file\n")
-            file.flush()
-            loader = loaders.FhirNdjsonLoader(self.root, bearer_token=file.name)
-            self.assertEqual("inside-file", loader.bearer_token)
+            with self.assertRaises(ValueError):
+                await etl.main(
+                    [
+                        "http://localhost:9999",
+                        "/tmp/output",
+                        "/tmp/phi",
+                        "--skip-init-checks",
+                        f"--bearer-token={file.name}",
+                    ]
+                )
+            self.assertEqual("inside-file", mock_client.call_args[1]["bearer_token"])
 
-    async def test_export_flow(self):
-        """
-        Verify that we make all the right calls into the bulk export helper classes.
+    @mock.patch("cumulus.etl.fhir_client.FhirClient")
+    async def test_fhir_url(self, mock_client):
+        """Verify that we handle the user-provided --fhir-url correctly"""
+        mock_client.side_effect = ValueError  # just to stop the etl pipeline once we get this far
 
-        This is a little lower-level than I would normally test, but the benefit of ensuring this flow here is that
-        the other test cases can focus on just the helper classes and trust that the flow works, without us needing to
-        do the full flow each time.
-        """
-        mock_server_instance = mock.AsyncMock()
-        self.mock_server.return_value = mock_server_instance
-        mock_exporter_instance = mock.AsyncMock()
-        self.mock_exporter.return_value = mock_exporter_instance
+        # Confirm that we don't allow conflicting URLs
+        with self.assertRaises(SystemExit):
+            await etl.main(
+                [
+                    "http://localhost:9999",
+                    "/tmp/output",
+                    "/tmp/phi",
+                    "--skip-init-checks",
+                    "--fhir-url=https://example.com/hello",
+                ]
+            )
 
-        loader = loaders.FhirNdjsonLoader(self.root, client_id="foo", jwks=self.jwks_path)
-        await loader.load_all(["Condition", "Encounter"])
+        # But a --fhir-url that is a prefix of the input URL is fine
+        with self.assertRaises(ValueError):
+            await etl.main(
+                [
+                    "https://example.com/hello/Group/1234",
+                    "/tmp/output",
+                    "/tmp/phi",
+                    "--skip-init-checks",
+                    "--fhir-url=https://example.com/hello",
+                ]
+            )
+        self.assertEqual("https://example.com/hello/Group/1234", mock_client.call_args[0][0])
 
-        expected_resources = [
-            "Condition",
-            "Encounter",
-        ]
+        # Now do a normal use of --fhir-url
+        mock_client.side_effect = ValueError  # just to stop the etl pipeline once we get this far
+        with self.assertRaises(ValueError):
+            await etl.main(
+                [
+                    "/tmp/input",
+                    "/tmp/output",
+                    "/tmp/phi",
+                    "--skip-init-checks",
+                    "--fhir-url=https://example.com/hello",
+                ]
+            )
+        self.assertEqual("https://example.com/hello", mock_client.call_args[0][0])
 
-        self.assertEqual(
-            [
-                mock.call(
-                    self.root.path, expected_resources, client_id="foo", jwks={"fake": "jwks"}, bearer_token=None
-                ),
-            ],
-            self.mock_server.call_args_list,
-        )
+    @mock.patch("cumulus.etl.fhir_client.FhirClient")
+    async def test_export_flow(self, mock_client):
+        """
+        Verify that we make the right calls down as far as the bulk export helper classes, with the right resources.
+ """ + self.mock_exporter.export.side_effect = ValueError # stop us when we get this far, but also confirm we call it - self.assertEqual(1, self.mock_exporter.call_count) - self.assertEqual(expected_resources, self.mock_exporter.call_args[0][1]) + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--task=condition,encounter", + ] + ) - self.assertEqual(1, mock_exporter_instance.export.call_count) + expected_resources = {"Condition", "Encounter"} + self.assertEqual(1, mock_client.call_count) + self.assertEqual(expected_resources, mock_client.call_args[0][1]) + self.assertEqual(1, self.mock_exporter_class.call_count) + self.assertEqual(expected_resources, set(self.mock_exporter_class.call_args[0][1])) async def test_fatal_errors_are_fatal(self): """Verify that when a FatalError is raised, we do really quit""" - self.mock_server.side_effect = FatalError + self.mock_exporter.export.side_effect = FatalError with self.assertRaises(SystemExit) as cm: - await loaders.FhirNdjsonLoader(self.root, client_id="foo", jwks=self.jwks_path).load_all(["Patient"]) + await loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock()).load_all(["Patient"]) - self.assertEqual(1, self.mock_server.call_count) + self.assertEqual(1, self.mock_exporter.export.call_count) self.assertEqual(errors.BULK_EXPORT_FAILED, cm.exception.code) @ddt.ddt @freezegun.freeze_time("Sep 15th, 2021 1:23:45") -@mock.patch("cumulus.loaders.fhir.backend_service.uuid.uuid4", new=lambda: "1234") +@mock.patch("cumulus.fhir_client.uuid.uuid4", new=lambda: "1234") class TestBulkServer(unittest.IsolatedAsyncioTestCase): """ Test case for bulk export server oauth2 / request support. @@ -226,7 +289,7 @@ def setUp(self): # Set up mocks for fhirclient (we don't need to test its oauth code by mocking server responses there) self.mock_client = mock.MagicMock() # FHIRClient instance self.mock_server = self.mock_client.server # FHIRServer instance - client_patcher = mock.patch("cumulus.loaders.fhir.backend_service.FHIRClient") + client_patcher = mock.patch("cumulus.fhir_client.fhirclient.client.FHIRClient") self.addCleanup(client_patcher.stop) self.mock_client_class = client_patcher.start() # FHIRClient class self.mock_client_class.return_value = self.mock_client @@ -239,30 +302,41 @@ def mock_session(server, *args, **kwargs): async def test_required_arguments(self): """Verify that we require both a client ID and a JWK Set""" - # No SMART args at all - with self.assertRaises(SystemExit): - async with BackendServiceServer(self.server_url, []): - pass + # Deny any actual requests during this test + self.respx_mock.get(f"{self.server_url}/test").respond(status_code=401) + + # Simple helper to open and make a call on client. 
+        async def use_client(request=False, code=None, url=self.server_url, **kwargs):
+            try:
+                async with FhirClient(url, [], **kwargs) as client:
+                    if request:
+                        await client.request("GET", "test")
+            except SystemExit as exc:
+                if code is None:
+                    raise
+                self.assertEqual(code, exc.code)
+
+        # No SMART args at all doesn't cause any problem, if we don't make calls
+        await use_client()
+
+        # No SMART args at all will raise though if we do make a call
+        await use_client(code=errors.SMART_CREDENTIALS_MISSING, request=True)
+
+        # No base URL
+        await use_client(code=errors.FHIR_URL_MISSING, request=True, url=None)
 
         # No JWKS
-        with self.assertRaises(SystemExit):
-            async with BackendServiceServer(self.server_url, [], client_id="foo"):
-                pass
+        await use_client(code=errors.SMART_CREDENTIALS_MISSING, client_id="foo")
 
         # No client ID
-        with self.assertRaises(SystemExit):
-            async with BackendServiceServer(self.server_url, [], jwks=self.jwks):
-                pass
+        await use_client(code=errors.SMART_CREDENTIALS_MISSING, jwks=self.jwks)
 
         # Works fine if both given
-        async with BackendServiceServer(self.server_url, [], client_id="foo", jwks=self.jwks):
-            pass
+        await use_client(client_id="foo", jwks=self.jwks)
 
     async def test_auth_initial_authorize(self):
         """Verify that we authorize correctly upon class initialization"""
-        async with BackendServiceServer(
-            self.server_url, ["Condition", "Patient"], client_id=self.client_id, jwks=self.jwks
-        ):
+        async with FhirClient(self.server_url, ["Condition", "Patient"], client_id=self.client_id, jwks=self.jwks):
             pass
 
         # Check initialization of FHIRClient
@@ -292,7 +366,7 @@ async def test_auth_with_bearer_token(self):
             headers={"Authorization": "Bearer fob"},
         )
 
-        async with BackendServiceServer(self.server_url, ["Condition", "Patient"], bearer_token="fob") as server:
+        async with FhirClient(self.server_url, ["Condition", "Patient"], bearer_token="fob") as server:
             await server.request("GET", "foo")
 
     async def test_get_with_new_header(self):
         """Verify that we issue a GET correctly for the happy path"""
         # This is mostly confirming that we call mocks correctly, but that's important since we're mocking out all
         # of fhirclient. Since we do that, we need to confirm we're driving it well.
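In outline, the driving pattern the next few tests use looks like the sketch below. It assumes a client built without credentials (so headers pass through unchanged) and reuses the `make_response` helper defined at the top of this file; the assertion details are illustrative:

```python
from unittest import mock


async def check_header_passthrough(server):
    # Swap in a mocked httpx session, call the public API, then inspect how the
    # request was built (this mirrors the mock_session helper used here):
    session = mock.MagicMock()
    session.send = mock.AsyncMock(return_value=make_response())
    server._session = session  # test-only reach into private state
    await server.request("GET", "foo", headers={"Test": "Value"})
    assert session.build_request.call_args[1]["headers"]["Test"] == "Value"
```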
- async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server) as mock_session: # With new header and stream await server.request("GET", "foo", headers={"Test": "Value"}, stream=True) @@ -330,7 +404,7 @@ async def test_get_with_new_header(self): async def test_get_with_overriden_header(self): """Verify that we issue a GET correctly for the happy path""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server) as mock_session: # With overriding a header and default stream (False) await server.request("GET", "bar", headers={"Accept": "text/plain"}) @@ -366,7 +440,7 @@ async def test_get_with_overriden_header(self): ) async def test_jwks_without_suitable_key(self, bad_jwks): with self.assertRaisesRegex(FatalError, "No private ES384 or RS384 key found"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=bad_jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=bad_jwks): pass @ddt.data( @@ -394,7 +468,7 @@ async def test_bad_smart_config(self, bad_config_override): ) with self.assertRaisesRegex(FatalError, "does not support the client-confidential-asymmetric protocol"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_authorize_error_with_response(self): @@ -404,19 +478,19 @@ async def test_authorize_error_with_response(self): error.response.json.return_value = {"error_description": "Ouch!"} self.mock_client.authorize.side_effect = error with self.assertRaisesRegex(FatalError, "Could not authenticate with the FHIR server: Ouch!"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_authorize_error_without_response(self): """Verify that we translate authorize non-response errors into FatalErrors.""" self.mock_client.authorize.side_effect = Exception("no memory") with self.assertRaisesRegex(FatalError, "Could not authenticate with the FHIR server: no memory"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_get_error_401(self): """Verify that an expired token is refreshed.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: # Check that we correctly tried to re-authenticate with self.mock_session(server) as mock_session: mock_session.send.side_effect = [make_response(status_code=401), make_response()] @@ -431,7 +505,7 @@ async def test_get_error_401(self): async def test_get_error_429(self): """Verify that 429 errors are passed through and not treated as exceptions.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: 
# Confirm 429 passes with self.mock_session(server, status_code=429): response = await server.request("GET", "foo") @@ -450,7 +524,7 @@ async def test_get_error_429(self): ) async def test_get_error_other(self, response_args): """Verify that other http errors are FatalErrors.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server, status_code=500, **response_args): with self.assertRaisesRegex(FatalError, "testmsg"): await server.request("GET", "foo") @@ -672,6 +746,7 @@ async def test_delete_if_interrupted(self): ) +@mock.patch("cumulus.deid.codebook.secrets.token_hex", new=lambda x: "1234") # just to not waste entropy class TestBulkExportEndToEnd(unittest.IsolatedAsyncioTestCase): """ Test case for doing an entire bulk export loop, without mocking python code. @@ -803,16 +878,28 @@ def set_up_requests(self, respx_mock): status_code=202, ) + @mock.patch("cumulus.deid.codebook.secrets.token_hex", new=lambda x: "1234") @responses.mock.activate(assert_all_requests_are_fired=True) async def test_successful_bulk_export(self): """Verify a happy path bulk export, from toe to tip""" - loader = loaders.FhirNdjsonLoader(self.root, client_id=self.client_id, jwks=self.jwks_path) - - with respx.mock(assert_all_called=True) as respx_mock: - self.set_up_requests(respx_mock) - tmpdir = await loader.load_all(["Patient"]) + with tempfile.TemporaryDirectory() as tmpdir: + with respx.mock(assert_all_called=True) as respx_mock: + self.set_up_requests(respx_mock) + + await etl.main( + [ + self.root.path, + f"{tmpdir}/output", + f"{tmpdir}/phi", + "--skip-init-checks", + "--output-format=ndjson", + "--task=patient", + f"--smart-client-id={self.client_id}", + f"--smart-jwks={self.jwks_path}", + ] + ) - self.assertEqual( - {"id": "testPatient1", "resourceType": "Patient"}, - common.read_json(os.path.join(tmpdir.name, "Patient.000.ndjson")), - ) + self.assertEqual( + {"id": "4342abf315cf6f243e11f4d460303e36c6c3663a25c91cc6b1a8002476c850dd", "resourceType": "Patient"}, + common.read_json(f"{tmpdir}/output/patient/patient.000.ndjson"), + ) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index facc0da8..7037421e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -4,23 +4,26 @@ import shutil import tempfile import unittest +from typing import AsyncIterator, List from unittest import mock import ddt +import respx -from cumulus import common, config, deid, errors, tasks +from cumulus import common, config, deid, errors, fhir_client, tasks from tests.ctakesmock import CtakesMixin from tests import i2b2_mock_data @ddt.ddt -class TestTasks(CtakesMixin, unittest.TestCase): +class TestTasks(CtakesMixin, unittest.IsolatedAsyncioTestCase): """Test case for task methods""" def setUp(self) -> None: super().setUp() + client = fhir_client.FhirClient("http://localhost/", []) script_dir = os.path.dirname(__file__) data_dir = os.path.join(script_dir, "data/simple") self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with @@ -30,7 +33,7 @@ def setUp(self) -> None: os.makedirs(self.phi_dir) self.job_config = config.JobConfig( - self.input_dir, self.input_dir, self.tmpdir.name, self.phi_dir, "ndjson", "ndjson", batch_size=5 + self.input_dir, self.input_dir, self.tmpdir.name, self.phi_dir, "ndjson", "ndjson", client, batch_size=5 ) self.format = mock.MagicMock() @@ -51,60 +54,80 @@ def make_json(self, filename, 
resource_id, **kwargs):
             os.path.join(self.input_dir, f"{filename}.ndjson"), {"resourceType": "Test", **kwargs, "id": resource_id}
         )
 
-    def test_batch_iterate(self):
+    async def test_batch_iterate(self):
         """Check a bunch of edge cases for the _batch_iterate helper"""
         # pylint: disable=protected-access
-        self.assertEqual([], [list(x) for x in tasks._batch_iterate([], 2)])
+        # Tiny little convenience method to turn sync lists into async iterators.
+        async def async_iter(values: List) -> AsyncIterator:
+            for x in values:
+                yield x
 
-        self.assertEqual(
+        # Handles converting all the async code into synchronous lists for ease of testing
+        async def assert_batches_equal(expected: List, values: List, batch_size: int) -> None:
+            collected = []
+            async for batch in tasks._batch_iterate(async_iter(values), batch_size):
+                batch_list = []
+                async for item in batch:
+                    batch_list.append(item)
+                collected.append(batch_list)
+            self.assertEqual(expected, collected)
+
+        await assert_batches_equal([], [], 2)
+
+        await assert_batches_equal(
             [
                 [1, 2.1, 2.2],
                 [3, 4],
             ],
-            [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4], 2)],
+            [1, [2.1, 2.2], 3, 4],
+            2,
         )
 
-        self.assertEqual(
+        await assert_batches_equal(
             [
                 [1.1, 1.2],
                 [2.1, 2.2],
                 [3, 4],
             ],
-            [list(x) for x in tasks._batch_iterate([[1.1, 1.2], [2.1, 2.2], 3, 4], 2)],
+            [[1.1, 1.2], [2.1, 2.2], 3, 4],
+            2,
         )
 
-        self.assertEqual(
+        await assert_batches_equal(
             [
                 [1, 2.1, 2.2],
                 [3, 4],
                 [5],
             ],
-            [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4, 5], 2)],
+            [1, [2.1, 2.2], 3, 4, 5],
+            2,
         )
 
-        self.assertEqual(
+        await assert_batches_equal(
             [
                 [1, 2.1, 2.2],
                 [3, 4],
             ],
-            [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4], 3)],
+            [1, [2.1, 2.2], 3, 4],
+            3,
         )
 
-        self.assertEqual(
+        await assert_batches_equal(
             [
                 [1],
                 [2.1, 2.2],
                 [3],
             ],
-            [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3], 1)],
+            [1, [2.1, 2.2], 3],
+            1,
         )
 
         with self.assertRaises(ValueError):
-            list(tasks._batch_iterate([1, 2, 3], 0))
+            await assert_batches_equal([], [1, 2, 3], 0)
 
         with self.assertRaises(ValueError):
-            list(tasks._batch_iterate([1, 2, 3], -1))
+            await assert_batches_equal([], [1, 2, 3], -1)
 
     def test_read_ndjson(self):
         """Verify we recognize all expected ndjson filename formats"""
@@ -122,19 +145,19 @@ def test_read_ndjson(self):
         resources = tasks.EncounterTask(self.job_config, self.scrubber).read_ndjson()
         self.assertEqual([], list(resources))
 
-    def test_unknown_modifier_extensions_skipped_for_patients(self):
+    async def test_unknown_modifier_extensions_skipped_for_patients(self):
         """Verify we ignore unknown modifier extensions during a normal task (like patients)"""
         self.make_json("Patient.0", "0")
         self.make_json("Patient.1", "1", modifierExtension=[{"url": "unrecognized"}])
 
-        tasks.PatientTask(self.job_config, self.scrubber).run()
+        await tasks.PatientTask(self.job_config, self.scrubber).run()
 
         # Confirm that only patient 0 got stored
         self.assertEqual(1, self.format.write_records.call_count)
         df = self.format.write_records.call_args[0][0]
         self.assertEqual([self.codebook.db.patient("0")], list(df.id))
 
-    def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self):
+    async def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self):
         """Verify we ignore unknown modifier extensions during a custom read task (nlp symptoms)"""
         docref0 = i2b2_mock_data.documentreference()
         docref0["subject"]["reference"] = "Patient/1234"
@@ -144,7 +167,7 @@ def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self):
         docref1["modifierExtension"] = 
[{"url": "unrecognized"}] self.make_json("DocumentReference.1", "1", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() # Confirm that only symptoms from docref 0 got stored self.assertEqual(1, self.format.write_records.call_count) @@ -165,25 +188,25 @@ def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): ([{"system": "nope", "code": "nope"}, {"system": "http://loinc.org", "code": "57053-1"}], True), ) @ddt.unpack - def test_ed_note_filtering_for_nlp(self, codings, expected): + async def test_ed_note_filtering_for_nlp(self, codings, expected): """Verify we filter out any non-emergency-department note""" # Use one doc with category set, and one with type set. Either should work. docref0 = i2b2_mock_data.documentreference() - docref0["category"] = {"coding": codings} + docref0["category"] = [{"coding": codings}] del docref0["type"] self.make_json("DocumentReference.0", "0", **docref0) docref1 = i2b2_mock_data.documentreference() docref1["type"] = {"coding": codings} self.make_json("DocumentReference.1", "1", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() self.assertEqual(1 if expected else 0, self.format.write_records.call_count) if expected: df = self.format.write_records.call_args[0][0] self.assertEqual(4, len(df)) - def test_non_ed_visit_is_skipped_for_covid_symptoms(self): + async def test_non_ed_visit_is_skipped_for_covid_symptoms(self): """Verify we ignore non ED visits for the covid symptoms NLP""" docref0 = i2b2_mock_data.documentreference() docref0["type"]["coding"][0]["code"] = "NOTE:nope" # pylint: disable=unsubscriptable-object @@ -192,7 +215,7 @@ def test_non_ed_visit_is_skipped_for_covid_symptoms(self): docref1["type"]["coding"][0]["code"] = "NOTE:149798455" # pylint: disable=unsubscriptable-object self.make_json("DocumentReference.1", "present", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() # Confirm that only symptoms from docref 'present' got stored self.assertEqual(1, self.format.write_records.call_count) @@ -215,3 +238,34 @@ def test_filtered_but_named_task(self): with self.assertRaises(SystemExit) as cm: tasks.EtlTask.get_selected_tasks(names=["condition"], filter_tags=["gpu"]) self.assertEqual(errors.TASK_FILTERED_OUT, cm.exception.code) + + @ddt.data( + # list of (URL, contentType), expected text + ([("http://localhost/file-cough", "text/plain")], "cough"), # handles absolute URL + ([("file-cough", "text/html")], "cough"), # handles text/* + ([("file-cough", "application/xml")], "cough"), # handles app/xml + ([("file-cough", "text/html"), ("file-fever", "text/plain")], "fever"), # prefers text/plain to text/* + ([("file-cough", "application/xml"), ("file-fever", "text/blarg")], "fever"), # prefers text/* to app/xml + ([("file-cough", "nope/nope")], None), # ignores unsupported mimetypes + ) + @ddt.unpack + @respx.mock + async def test_note_urls_downloaded(self, attachments, expected_text): + """Verify that we download any attachments with URLs""" + # We return three words due to how our cTAKES mock works. It wants 3 words -- fever word is in middle. 
+ respx.get("http://localhost/file-cough").respond(text="has cough bad") + respx.get("http://localhost/file-fever").respond(text="has fever bad") + + docref0 = i2b2_mock_data.documentreference() + docref0["content"] = [{"attachment": {"url": a[0], "contentType": a[1]}} for a in attachments] + self.make_json("DocumentReference.0", "doc0", **docref0) + + async with self.job_config.client: + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + + if expected_text: + self.assertEqual(1, self.format.write_records.call_count) + df = self.format.write_records.call_args[0][0] + self.assertEqual(expected_text, df.iloc[0].match["text"]) + else: + self.assertEqual(0, self.format.write_records.call_count)
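
A note on the attachment handling that test_note_urls_downloaded exercises: its ddt cases encode a
preference order of text/plain first, then any other text/*, then application/xml, with unsupported
mimetypes skipped and both relative and absolute attachment URLs accepted. The sketch below only
illustrates that selection order; the helper name pick_best_attachment and its shape are hypothetical,
not the actual cumulus.ctakes code.

    from typing import Optional

    def pick_best_attachment(docref: dict) -> Optional[dict]:
        """Return the preferred attachment of a DocumentReference, or None if none qualify.

        Mirrors the order the test cases above encode: text/plain, then any text/*,
        then application/xml. (Illustrative sketch only, not the real implementation.)
        """
        attachments = [content.get("attachment", {}) for content in docref.get("content", [])]

        # One predicate per preference tier, most-preferred first.
        tiers = [
            lambda mimetype: mimetype == "text/plain",
            lambda mimetype: mimetype.startswith("text/"),
            lambda mimetype: mimetype == "application/xml",
        ]
        for matches in tiers:
            for attachment in attachments:
                if matches(attachment.get("contentType", "")):
                    return attachment
        return None

Given content entries of ("file-cough", "application/xml") and ("file-fever", "text/blarg"), this
returns the text/blarg attachment, matching the "prefers text/* to app/xml" case; a lone
("file-cough", "nope/nope") entry yields None, matching the "ignores unsupported mimetypes" case.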