diff --git a/cumulus/config.py b/cumulus/config.py index a83a6da4..f6713847 100644 --- a/cumulus/config.py +++ b/cumulus/config.py @@ -5,7 +5,7 @@ from socket import gethostname from typing import List -from cumulus import common, formats, store +from cumulus import common, fhir_client, formats, store class JobConfig: @@ -25,6 +25,7 @@ def __init__( dir_phi: str, input_format: str, output_format: str, + client: fhir_client.FhirClient, timestamp: datetime.datetime = None, comment: str = None, batch_size: int = 1, # this default is never really used - overridden by command line args @@ -36,6 +37,7 @@ def __init__( self.dir_phi = dir_phi self._input_format = input_format self._output_format = output_format + self.client = client self.timestamp = common.timestamp_filename(timestamp) self.hostname = gethostname() self.comment = comment or "" diff --git a/cumulus/ctakes.py b/cumulus/ctakes.py index 7a85d5dc..ffc967d4 100644 --- a/cumulus/ctakes.py +++ b/cumulus/ctakes.py @@ -5,17 +5,18 @@ import hashlib import logging import os -from typing import List +from typing import List, Optional import ctakesclient -from cumulus import common, fhir_common, store +from cumulus import common, fhir_client, fhir_common, store -def covid_symptoms_extract(cache: store.Root, docref: dict) -> List[dict]: +async def covid_symptoms_extract(client: fhir_client.FhirClient, cache: store.Root, docref: dict) -> List[dict]: """ Extract a list of Observations from NLP-detected symptoms in physician notes + :param client: a client ready to talk to a FHIR server :param cache: Where to cache NLP results :param docref: Physician Note :return: list of NLP results encoded as FHIR observations @@ -25,20 +26,14 @@ def covid_symptoms_extract(cache: store.Root, docref: dict) -> List[dict]: encounters = docref.get("context", {}).get("encounter", []) if not encounters: - logging.warning("No valid encounters for symptoms") # ideally would print identifier, but it's PHI... + logging.warning("No encounters for docref %s", docref_id) return [] _, encounter_id = fhir_common.unref_resource(encounters[0]) # Find the physician note among the attachments - for content in docref["content"]: - if "contentType" in content["attachment"] and "data" in content["attachment"]: - mimetype, params = cgi.parse_header(content["attachment"]["contentType"]) - if mimetype == "text/plain": # just grab first text we find - charset = params.get("charset", "utf8") - physician_note = base64.standard_b64decode(content["attachment"]["data"]).decode(charset) - break - else: - logging.warning("No text/plain content in docref %s", docref_id) + physician_note = await get_docref_note(client, [content["attachment"] for content in docref["content"]]) + if physician_note is None: + logging.warning("No text content in docref %s", docref_id) return [] # Strip this "line feed" character that often shows up in notes and is confusing for cNLP. @@ -97,6 +92,81 @@ def is_covid_match(m: ctakesclient.typesystem.MatchText): return positive_matches +def parse_content_type(content_type: str) -> (str, str): + """Returns (mimetype, encoding)""" + # TODO: switch to message.Message parsing, since cgi is deprecated + mimetype, params = cgi.parse_header(content_type) + return mimetype, params.get("charset", "utf8") + + +def mimetype_priority(mimetype: str) -> int: + """ + Returns priority of mimetypes for docref notes. 
+ + 0 means "ignore" + Higher numbers are higher priority + """ + if mimetype == "text/plain": + return 3 + elif mimetype.startswith("text/"): + return 2 + elif mimetype in ("application/xml", "application/xhtml+xml"): + return 1 + return 0 + + +async def get_docref_note(client: fhir_client.FhirClient, attachments: List[dict]) -> Optional[str]: + # Find the best attachment to use, based on mimetype. + # We prefer basic text documents, to avoid confusing cTAKES with extra formatting (like ). + best_attachment_index = -1 + best_attachment_priority = 0 + for index, attachment in enumerate(attachments): + if "contentType" in attachment: + mimetype, _ = parse_content_type(attachment["contentType"]) + priority = mimetype_priority(mimetype) + if priority > best_attachment_priority: + best_attachment_priority = priority + best_attachment_index = index + + if best_attachment_index >= 0: + return await get_docref_note_from_attachment(client, attachments[best_attachment_index]) + + # We didn't find _any_ of our target text content types. + # A content type isn't required by the spec with external URLs... so it's possible an unmarked link could be good. + # But let's optimistically enforce the need for a content type ourselves by bailing here. + # If we find a real-world need to be more permissive, we can change this later. + # But note that if we do, we'll need to handle downloading Binary FHIR objects, in addition to arbitrary URLs. + return None + + +async def get_docref_note_from_attachment(client: fhir_client.FhirClient, attachment: dict) -> Optional[str]: + """ + Decodes or downloads a note from an attachment. + + Note that it is assumed a contentType is provided. + + :returns: the attachment's note text + """ + mimetype, charset = parse_content_type(attachment["contentType"]) + + if "data" in attachment: + return base64.standard_b64decode(attachment["data"]).decode(charset) + + # TODO: At some point we should centralize the downloading of attachments -- once we have multiple NLP tasks, + # we may not want to re-download the overlapping notes. When we do that, it should not be part of our bulk + # exporter, since we may be given already-exported ndjson. + # + # TODO: There are future optimizations to try to use our ctakes cache to avoid downloading in the first place: + # - use attachment["hash"] if available (algorithm mismatch though... maybe we should switch to sha1...) + # - send a HEAD request with "Want-Digest: sha-256" but Cerner at least does not support that + if "url" in attachment: + # We need to pass Accept to get the raw data, not a Binary object. 
See https://www.hl7.org/fhir/binary.html + response = await client.request("GET", attachment["url"], headers={"Accept": mimetype}) + return response.text + + return None + + def extract(cache: store.Root, namespace: str, sentence: str) -> ctakesclient.typesystem.CtakesJSON: """ This is a version of ctakesclient.client.extract() that also uses a cache diff --git a/cumulus/deid/ms-config.json b/cumulus/deid/ms-config.json index b0069ef5..f67967bd 100644 --- a/cumulus/deid/ms-config.json +++ b/cumulus/deid/ms-config.json @@ -5,6 +5,7 @@ "fhirPathRules": [ {"path": "nodesByName('modifierExtension')", "method": "keep", "comment": "Cumulus: keep these so we can ignore resources with modifiers we don't understand"}, {"path": "DocumentReference.nodesByType('Attachment').data", "method": "keep", "comment": "Cumulus: needed to run NLP on physician notes"}, + {"path": "DocumentReference.nodesByType('Attachment').url", "method": "keep", "comment": "Cumulus: needed to run NLP on physician notes"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-birthsex')", "method": "keep", "comment": "Cumulus: useful for studies"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-ethnicity')", "method": "keep", "comment": "Cumulus: useful for studies"}, {"path": "Patient.extension.where(url='http://hl7.org/fhir/us/core/StructureDefinition/us-core-genderIdentity')", "method": "keep", "comment": "Cumulus: useful for studies"}, diff --git a/cumulus/deid/scrubber.py b/cumulus/deid/scrubber.py index 979fd9f2..1e831b9e 100644 --- a/cumulus/deid/scrubber.py +++ b/cumulus/deid/scrubber.py @@ -140,8 +140,8 @@ def _check_ids(self, node_path: str, node: dict, key: str, value: Any) -> None: @staticmethod def _check_attachments(node_path: str, node: dict, key: str) -> None: """Strip any attachment data""" - if node_path == "root.content.attachment" and key == "data": - del node["data"] + if node_path == "root.content.attachment" and key in {"data", "url"}: + del node[key] @staticmethod def _check_security(node_path: str, node: dict, key: str, value: Any) -> None: diff --git a/cumulus/errors.py b/cumulus/errors.py index ec5459cf..14514013 100644 --- a/cumulus/errors.py +++ b/cumulus/errors.py @@ -16,3 +16,4 @@ TASK_SET_EMPTY = 21 ARGS_CONFLICT = 22 ARGS_INVALID = 23 +FHIR_URL_MISSING = 24 diff --git a/cumulus/etl.py b/cumulus/etl.py index a48696c0..426d229e 100644 --- a/cumulus/etl.py +++ b/cumulus/etl.py @@ -11,12 +11,12 @@ import sys import tempfile import time -from typing import List, Type +from typing import Iterable, List, Type from urllib.parse import urlparse import ctakesclient -from cumulus import common, context, deid, errors, loaders, store, tasks +from cumulus import common, context, deid, errors, fhir_client, loaders, store, tasks from cumulus.config import JobConfig, JobSummary @@ -27,9 +27,7 @@ ############################################################################### -async def load_and_deidentify( - loader: loaders.Loader, selected_tasks: List[Type[tasks.EtlTask]] -) -> tempfile.TemporaryDirectory: +async def load_and_deidentify(loader: loaders.Loader, resources: Iterable[str]) -> tempfile.TemporaryDirectory: """ Loads the input directory and does a first-pass de-identification @@ -37,17 +35,14 @@ async def load_and_deidentify( :returns: a temporary directory holding the de-identified files in FHIR ndjson format """ - # Grab a list of all required resource types for the tasks we are running - required_resources = 
set(t.resource for t in selected_tasks) - # First step is loading all the data into a local ndjson format - loaded_dir = await loader.load_all(list(required_resources)) + loaded_dir = await loader.load_all(list(resources)) # Second step is de-identifying that data (at a bulk level) return await deid.Scrubber.scrub_bulk_data(loaded_dir.name) -def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> List[JobSummary]: +async def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> List[JobSummary]: """ :param config: job config :param selected_tasks: the tasks to run @@ -58,7 +53,7 @@ def etl_job(config: JobConfig, selected_tasks: List[Type[tasks.EtlTask]]) -> Lis scrubber = deid.Scrubber(config.dir_phi) for task_class in selected_tasks: task = task_class(config, scrubber) - summary = task.run() + summary = await task.run() summary_list.append(summary) path = os.path.join(config.dir_job_config(), f"{summary.label}.json") @@ -195,6 +190,9 @@ def make_parser() -> argparse.ArgumentParser: metavar="PATH", help="Bearer token for custom bearer authentication", ) + export.add_argument( + "--fhir-url", metavar="URL", help="FHIR server base URL, only needed if you exported separately" + ) export.add_argument("--since", help="Start date for export from the FHIR server") export.add_argument("--until", help="End date for export from the FHIR server") @@ -213,6 +211,39 @@ def make_parser() -> argparse.ArgumentParser: return parser +def create_fhir_client(args, root_input, resources): + client_base_url = args.fhir_url + if root_input.protocol in {"http", "https"}: + if args.fhir_url and not root_input.path.startswith(args.fhir_url): + print( + "You provided both an input FHIR server and a different --fhir-url. Try dropping --fhir-url.", + file=sys.stderr, + ) + raise SystemExit(errors.ARGS_CONFLICT) + client_base_url = root_input.path + + try: + try: + # Try to load client ID from file first (some servers use crazy long ones, like SMART's bulk-data-server) + smart_client_id = common.read_text(args.smart_client_id).strip() if args.smart_client_id else None + except FileNotFoundError: + smart_client_id = args.smart_client_id + + smart_jwks = common.read_json(args.smart_jwks) if args.smart_jwks else None + bearer_token = common.read_text(args.bearer_token).strip() if args.bearer_token else None + except OSError as exc: + print(exc, file=sys.stderr) + raise SystemExit(errors.ARGS_INVALID) from exc + + return fhir_client.FhirClient( + client_base_url, + resources, + client_id=smart_client_id, + jwks=smart_jwks, + bearer_token=bearer_token, + ) + + async def main(args: List[str]): parser = make_parser() args = parser.parse_args(args) @@ -233,45 +264,49 @@ async def main(args: List[str]): job_context = context.JobContext(root_phi.joinpath("context.json")) job_datetime = common.datetime_now() # grab timestamp before we do anything - if args.input_format == "i2b2": - config_loader = loaders.I2b2Loader(root_input, args.batch_size) - else: - config_loader = loaders.FhirNdjsonLoader( - root_input, - client_id=args.smart_client_id, - jwks=args.smart_jwks, - bearer_token=args.bearer_token, - since=args.since, - until=args.until, - ) - # Check which tasks are being run, allowing comma-separated values task_names = args.task and set(itertools.chain.from_iterable(t.split(",") for t in args.task)) task_filters = args.task_filter and list(itertools.chain.from_iterable(t.split(",") for t in args.task_filter)) selected_tasks = tasks.EtlTask.get_selected_tasks(task_names, task_filters) - 
# Pull down resources and run the MS tool on them - deid_dir = await load_and_deidentify(config_loader, selected_tasks) - - # Prepare config for jobs - config = JobConfig( - args.dir_input, - deid_dir.name, - args.dir_output, - args.dir_phi, - args.input_format, - args.output_format, - comment=args.comment, - batch_size=args.batch_size, - timestamp=job_datetime, - tasks=[t.name for t in selected_tasks], - ) - common.write_json(config.path_config(), config.as_json(), indent=4) - common.print_header("Configuration:") - print(json.dumps(config.as_json(), indent=4)) + # Grab a list of all required resource types for the tasks we are running + required_resources = set(t.resource for t in selected_tasks) + + # Create a client to talk to a FHIR server. + # This is useful even if we aren't doing a bulk export, because some resources like DocumentReference can still + # reference external resources on the server (like the document text). + # If we don't need this client (e.g. we're using local data and don't download any attachments), this is a no-op. + client = create_fhir_client(args, root_input, required_resources) + + async with client: + if args.input_format == "i2b2": + config_loader = loaders.I2b2Loader(root_input, args.batch_size) + else: + config_loader = loaders.FhirNdjsonLoader(root_input, client, since=args.since, until=args.until) + + # Pull down resources and run the MS tool on them + deid_dir = await load_and_deidentify(config_loader, required_resources) + + # Prepare config for jobs + config = JobConfig( + args.dir_input, + deid_dir.name, + args.dir_output, + args.dir_phi, + args.input_format, + args.output_format, + client, + comment=args.comment, + batch_size=args.batch_size, + timestamp=job_datetime, + tasks=[t.name for t in selected_tasks], + ) + common.write_json(config.path_config(), config.as_json(), indent=4) + common.print_header("Configuration:") + print(json.dumps(config.as_json(), indent=4)) - # Finally, actually run the meat of the pipeline! (Filtered down to requested tasks) - summaries = etl_job(config, selected_tasks) + # Finally, actually run the meat of the pipeline! 
(Filtered down to requested tasks) + summaries = await etl_job(config, selected_tasks) # Print results to the console common.print_header("Results:") diff --git a/cumulus/loaders/fhir/backend_service.py b/cumulus/fhir_client.py similarity index 77% rename from cumulus/loaders/fhir/backend_service.py rename to cumulus/fhir_client.py index 3ce38d27..0e96a399 100644 --- a/cumulus/loaders/fhir/backend_service.py +++ b/cumulus/fhir_client.py @@ -1,16 +1,15 @@ -"""Support for SMART App Launch Backend Services""" +"""HTTP client that talk to a FHIR server""" -import abc import re import sys import time import urllib.parse import uuid from json import JSONDecodeError -from typing import List, Optional +from typing import Iterable, Optional +import fhirclient.client import httpx -from fhirclient.client import FHIRClient from jwcrypto import jwk, jwt from cumulus import errors @@ -20,22 +19,42 @@ class FatalError(Exception): """An unrecoverable error""" -class Auth(abc.ABC): - """Abstracted authentication for a FHIR server""" +def _urljoin(base: str, path: str) -> str: + """Basically just urllib.parse.urljoin, but with some extra error checking""" + path_is_absolute = bool(urllib.parse.urlparse(path).netloc) + if path_is_absolute: + return path - @abc.abstractmethod - async def authorize(self, session: httpx.AsyncClient) -> None: + if not base: + print("You must provide a base FHIR server URL with --fhir-url", file=sys.stderr) + raise SystemExit(errors.FHIR_URL_MISSING) + return urllib.parse.urljoin(base, path) + + +class Auth: + """Abstracted authentication for a FHIR server. By default, does nothing.""" + + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: """Authorize (or re-authorize) against the server""" + del session + + if reauthorize: + # Abort because we clearly need authentication tokens, but have not been given any parameters for them. + print( + "You must provide some authentication parameters (like --smart-client-id) to connect to a server.", + file=sys.stderr, + ) + raise SystemExit(errors.SMART_CREDENTIALS_MISSING) - @abc.abstractmethod def sign_headers(self, headers: Optional[dict]) -> dict: """Add signature token to request headers""" + return headers class JwksAuth(Auth): """Authentication with a JWK Set (typical backend service profile)""" - def __init__(self, server_root: str, client_id: str, jwks: dict, resources: List[str]): + def __init__(self, server_root: str, client_id: str, jwks: dict, resources: Iterable[str]): super().__init__() self._server_root = server_root self._client_id = client_id @@ -44,7 +63,7 @@ def __init__(self, server_root: str, client_id: str, jwks: dict, resources: List self._server = None self._token_endpoint = None - async def authorize(self, session: httpx.AsyncClient) -> None: + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: """ Authenticates against a SMART FHIR server using the Backend Services profile. @@ -53,13 +72,13 @@ async def authorize(self, session: httpx.AsyncClient) -> None: # Have we authorized before? 
if self._token_endpoint is None: self._token_endpoint = await self._get_token_endpoint(session) - if self._server is not None and self._server.reauthorize(): + if reauthorize and self._server.reauthorize(): return # Else we must not have been issued a refresh token, let's just authorize from scratch below signed_jwt = self._make_signed_jwt() scope = " ".join([f"system/{resource}.read" for resource in self._resources]) - client = FHIRClient( + client = fhirclient.client.FHIRClient( settings={ "api_base": self._server_root, "app_id": self._client_id, @@ -100,7 +119,7 @@ async def _get_token_endpoint(self, session: httpx.AsyncClient) -> str: :returns: URL for the server's oauth2 token endpoint """ response = await session.get( - urllib.parse.urljoin(self._server_root, ".well-known/smart-configuration"), + _urljoin(self._server_root, ".well-known/smart-configuration"), headers={ "Accept": "application/json", }, @@ -161,7 +180,7 @@ def __init__(self, bearer_token: str): super().__init__() self._bearer_token = bearer_token - async def authorize(self, session: httpx.AsyncClient) -> None: + async def authorize(self, session: httpx.AsyncClient, reauthorize=False) -> None: pass def sign_headers(self, headers: Optional[dict]) -> dict: @@ -170,9 +189,11 @@ def sign_headers(self, headers: Optional[dict]) -> dict: return headers -class BackendServiceServer: +class FhirClient: """ - Manages authentication and requests for a server that supports the Backend Service SMART profile. + Manages authentication and requests for a FHIR server. + + Supports a few different auth methods, but most notably the Backend Service SMART profile. Use this as a context manager (like you would an httpx.AsyncClient instance). @@ -180,7 +201,12 @@ class BackendServiceServer: """ def __init__( - self, url: str, resources: List[str], client_id: str = None, jwks: dict = None, bearer_token: str = None + self, + url: Optional[str], + resources: Iterable[str], + client_id: str = None, + jwks: dict = None, + bearer_token: str = None, ): """ Initialize and authorize a BackendServiceServer context manager. @@ -191,20 +217,23 @@ def __init__( :param jwks: content of a JWK Set file, containing the private key for the registered public key :param bearer_token: a bearer token, containing the secret key to sign https requests (instead of JWKS) """ + # Allow url to be None in the case we are a fully local ETL run, and this class is basically a no-op self._base_url = url # all requests are relative to this URL - if not self._base_url.endswith("/"): + if self._base_url and not self._base_url.endswith("/"): self._base_url += "/" # The base URL may not be the server root (like it may be a Group export URL). Let's find the root. self._server_root = self._base_url - self._server_root = re.sub(r"/Patient/$", "/", self._server_root) - self._server_root = re.sub(r"/Group/[^/]+/$", "/", self._server_root) + if self._server_root: + self._server_root = re.sub(r"/Patient/$", "/", self._server_root) + self._server_root = re.sub(r"/Group/[^/]+/$", "/", self._server_root) + self._auth = self._make_auth(resources, client_id, jwks, bearer_token) self._session: Optional[httpx.AsyncClient] = None async def __aenter__(self): # Limit the number of connections open at once, because EHRs tend to be very busy. 
limits = httpx.Limits(max_connections=5) - self._session = httpx.AsyncClient(limits=limits) + self._session = httpx.AsyncClient(limits=limits, timeout=300) # five minutes to be generous await self._auth.authorize(self._session) return self @@ -212,32 +241,12 @@ async def __aexit__(self, exc_type, exc_value, traceback): if self._session: await self._session.aclose() - def _make_auth(self, resources: List[str], client_id: str, jwks: dict, bearer_token: str) -> Auth: - """Determine which auth method to use based on user provided arguments""" - - if bearer_token and (client_id or jwks): - print("--bearer-token cannot be used with --smart-client-id or --smart-jwks", file=sys.stderr) - raise SystemExit(errors.ARGS_CONFLICT) - - if bearer_token: - return BearerAuth(bearer_token) - - # Confirm that all required SMART arguments were provided - error_list = [] - if not client_id: - error_list.append("You must provide a client ID with --smart-client-id to connect to a SMART FHIR server.") - if jwks is None: - error_list.append("You must provide a JWKS file with --smart-jwks to connect to a SMART FHIR server.") - if error_list: - print("\n".join(error_list), file=sys.stderr) - raise SystemExit(errors.SMART_CREDENTIALS_MISSING) - - return JwksAuth(self._server_root, client_id, jwks, resources) - async def request(self, method: str, path: str, headers: dict = None, stream: bool = False) -> httpx.Response: """ Issues an HTTP request. + The default Accept type is application/fhir+json, but can be overridden by a provided header. + This is a lightly modified version of FHIRServer._get(), but additionally supports streaming and reauthorization. @@ -249,7 +258,7 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo :param stream: whether to stream content in or load it all into memory at once :returns: The response object """ - url = urllib.parse.urljoin(self._base_url, path) + url = _urljoin(self._base_url, path) final_headers = { "Accept": "application/fhir+json", @@ -262,7 +271,7 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo # Check if our access token expired and thus needs to be refreshed if response.status_code == 401: - await self._auth.authorize(self._session) + await self._auth.authorize(self._session, reauthorize=True) if stream: await response.aclose() response = await self._request_with_signed_headers(method, url, final_headers, stream=stream) @@ -300,6 +309,28 @@ async def request(self, method: str, path: str, headers: dict = None, stream: bo # ################################################################################################################### + def _make_auth(self, resources: Iterable[str], client_id: str, jwks: dict, bearer_token: str) -> Auth: + """Determine which auth method to use based on user provided arguments""" + valid_jwks = jwks is not None + + if bearer_token and (client_id or valid_jwks): + print("--bearer-token cannot be used with --smart-client-id or --smart-jwks", file=sys.stderr) + raise SystemExit(errors.ARGS_CONFLICT) + + if bearer_token: + return BearerAuth(bearer_token) + + if client_id and valid_jwks: + return JwksAuth(self._server_root, client_id, jwks, resources) + elif client_id or valid_jwks: + print( + "You must provide both --smart-client-id and --smart-jwks to connect to a SMART FHIR server.", + file=sys.stderr, + ) + raise SystemExit(errors.SMART_CREDENTIALS_MISSING) + + return Auth() + async def _request_with_signed_headers( self, method: str, url: str, headers: dict = None, 
**kwargs ) -> httpx.Response: @@ -311,6 +342,9 @@ async def _request_with_signed_headers( :param headers: header dictionary :returns: The response object """ + if not self._session: + raise RuntimeError("FhirClient must be used as a context manager") + headers = self._auth.sign_headers(headers) request = self._session.build_request(method, url, headers=headers) # Follow redirects by default -- some EHRs definitely use them for bulk download files, diff --git a/cumulus/loaders/fhir/bulk_export.py b/cumulus/loaders/fhir/bulk_export.py index 44d4828b..11533f55 100644 --- a/cumulus/loaders/fhir/bulk_export.py +++ b/cumulus/loaders/fhir/bulk_export.py @@ -9,7 +9,7 @@ import httpx from cumulus import common -from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError +from cumulus.fhir_client import FatalError, FhirClient class BulkExporter: @@ -24,19 +24,19 @@ class BulkExporter: _TIMEOUT_THRESHOLD = 60 * 60 * 24 # a day, which is probably an overly generous timeout def __init__( - self, server: BackendServiceServer, resources: List[str], destination: str, since: str = None, until: str = None + self, client: FhirClient, resources: List[str], destination: str, since: str = None, until: str = None ): """ Initialize a bulk exporter (but does not start an export). - :param server: a server instance ready to make requests + :param client: a client ready to make requests :param resources: a list of resource names to export :param destination: a local folder to store all the files :param since: start date for export :param until: end date for export """ super().__init__() - self._server = server + self._client = client self._resources = resources self._destination = destination self._total_wait_time = 0 # in seconds, across all our requests @@ -128,7 +128,7 @@ async def _request_with_delay( :returns: the HTTP response """ while self._total_wait_time < self._TIMEOUT_THRESHOLD: - response = await self._server.request(method, path, headers=headers) + response = await self._client.request(method, path, headers=headers) if response.status_code == target_status_code: return response @@ -203,7 +203,7 @@ async def _download_ndjson_file(self, url: str, filename: str) -> None: :param url: URL location of file to download :param filename: local path to write data to """ - response = await self._server.request("GET", url, headers={"Accept": "application/fhir+ndjson"}, stream=True) + response = await self._client.request("GET", url, headers={"Accept": "application/fhir+ndjson"}, stream=True) try: with open(filename, "w", encoding="utf8") as file: async for block in response.aiter_text(): diff --git a/cumulus/loaders/fhir/fhir_ndjson.py b/cumulus/loaders/fhir/fhir_ndjson.py index dbcee5d1..f23325d7 100644 --- a/cumulus/loaders/fhir/fhir_ndjson.py +++ b/cumulus/loaders/fhir/fhir_ndjson.py @@ -6,8 +6,8 @@ from typing import List from cumulus import common, errors, store +from cumulus.fhir_client import FatalError, FhirClient from cumulus.loaders import base -from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError from cumulus.loaders.fhir.bulk_export import BulkExporter @@ -22,34 +22,18 @@ class FhirNdjsonLoader(base.Loader): def __init__( self, root: store.Root, - client_id: str = None, - jwks: str = None, - bearer_token: str = None, + client: FhirClient, since: str = None, until: str = None, ): """ :param root: location to load ndjson from - :param client_id: client ID for a SMART server - :param jwks: path to a JWKS file for a SMART server - :param bearer_token: path 
to a file with a bearer token for a FHIR server + :param client: client ready to talk to a FHIR server :param since: export start date for a FHIR server :param until: export end date for a FHIR server """ super().__init__(root) - - try: - try: - self.client_id = common.read_text(client_id).strip() if client_id else None - except FileNotFoundError: - self.client_id = client_id - - self.jwks = common.read_json(jwks) if jwks else None - self.bearer_token = common.read_text(bearer_token).strip() if bearer_token else None - except OSError as exc: - print(exc, file=sys.stderr) - raise SystemExit(errors.ARGS_INVALID) from exc - + self.client = client self.since = since self.until = until @@ -58,7 +42,7 @@ async def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory: if self.root.protocol in ["http", "https"]: return await self._load_from_bulk_export(resources) - if self.client_id or self.jwks or self.bearer_token or self.since or self.until: + if self.since or self.until: print("You provided FHIR bulk export parameters but did not provide a FHIR server", file=sys.stderr) raise SystemExit(errors.ARGS_CONFLICT) @@ -84,11 +68,8 @@ async def _load_from_bulk_export(self, resources: List[str]) -> tempfile.Tempora tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with try: - async with BackendServiceServer( - self.root.path, resources, client_id=self.client_id, jwks=self.jwks, bearer_token=self.bearer_token - ) as server: - bulk_exporter = BulkExporter(server, resources, tmpdir.name, self.since, self.until) - await bulk_exporter.export() + bulk_exporter = BulkExporter(self.client, resources, tmpdir.name, self.since, self.until) + await bulk_exporter.export() except FatalError as exc: print(str(exc), file=sys.stderr) raise SystemExit(errors.BULK_EXPORT_FAILED) from exc diff --git a/cumulus/tasks.py b/cumulus/tasks.py index 5fe7544e..876bcf6a 100644 --- a/cumulus/tasks.py +++ b/cumulus/tasks.py @@ -5,7 +5,7 @@ import os import re import sys -from typing import Iterable, Iterator, List, Set, Type, TypeVar, Union +from typing import AsyncIterable, AsyncIterator, Iterable, Iterator, List, Set, Type, TypeVar, Union import pandas @@ -15,7 +15,7 @@ AnyTask = TypeVar("AnyTask", bound="EtlTask") -def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: +async def _batch_slice(iterable: AsyncIterable[Union[List[T], T]], n: int) -> AsyncIterator[T]: """ Returns the first n elements of iterable, flattening lists, but including an entire list if we would end in middle. @@ -24,9 +24,10 @@ def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: Note that this will only flatten elements that are actual Python lists (isinstance list is True) """ count = 0 - for item in iterable: + async for item in iterable: if isinstance(item, list): - yield from item + for x in item: + yield x count += len(item) else: yield item @@ -36,7 +37,14 @@ def _batch_slice(iterable: Iterable[Union[List[T], T]], n: int) -> Iterator[T]: return -def _batch_iterate(iterable: Iterable[Union[List[T], T]], size: int) -> Iterator[Iterator[T]]: +async def _async_chain(first: T, rest: AsyncIterator[T]) -> AsyncIterator[T]: + """An asynchronous version of itertools.chain([first], rest)""" + yield first + async for x in rest: + yield x + + +async def _batch_iterate(iterable: AsyncIterable[Union[List[T], T]], size: int) -> AsyncIterator[AsyncIterator[T]]: """ Yields sub-iterators, each roughly {size} elements from iterable. 
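For orientation, the new async batching helper is consumed with nested `async for` loops, the same way the updated `EtlTask.run()` below uses it. A minimal sketch (not part of this patch; assumes the private `_batch_iterate` helper from `cumulus/tasks.py` is importable):

```python
import asyncio

from cumulus.tasks import _batch_iterate


async def example_entries():
    # Toy stand-in for EtlTask.read_entries(): single items plus one grouped list.
    for item in [1, [2.1, 2.2], 3, 4, 5]:
        yield item


async def demo():
    # The outer loop yields batches; the inner loop yields the items within each batch.
    async for batch in _batch_iterate(example_entries(), 2):
        print([item async for item in batch])


asyncio.run(demo())  # prints [1, 2.1, 2.2], then [3, 4], then [5]
```

Note how the grouped list stays inside a single batch, which is what lets the covid symptom task keep all symptoms from one note in the same output batch.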
@@ -59,14 +67,17 @@ def _batch_iterate(iterable: Iterable[Union[List[T], T]], size: int) -> Iterator if size < 1: raise ValueError("Must iterate by at least a batch of 1") - true_iterable = iter(iterable) # in case it's actually a list (we want to iterate only once through) + # aiter() and anext() were added in python 3.10 + # pylint: disable=unnecessary-dunder-call + + true_iterable = iterable.__aiter__() # get a real once-through iterable (we want to iterate only once) while True: iter_slice = _batch_slice(true_iterable, size) try: - peek = next(iter_slice) - except StopIteration: + peek = await iter_slice.__anext__() + except StopAsyncIteration: return # we're done! - yield itertools.chain([peek], iter_slice) + yield _async_chain(peek, iter_slice) class EtlTask: @@ -182,7 +193,7 @@ def __init__(self, task_config: config.JobConfig, scrubber: deid.Scrubber): self.task_config = task_config self.scrubber = scrubber - def run(self) -> config.JobSummary: + async def run(self) -> config.JobSummary: """ Executes a single task and returns the summary. @@ -198,9 +209,10 @@ def run(self) -> config.JobSummary: # At this point we have a giant iterable of de-identified FHIR objects, ready to be written out. # We want to batch them up, to allow resuming from interruptions more easily. - for index, batch in enumerate(_batch_iterate(entries, self.task_config.batch_size)): + index = 0 + async for batch in _batch_iterate(entries, self.task_config.batch_size): # Stuff de-identified FHIR json into one big pandas DataFrame - dataframe = pandas.DataFrame(batch) + dataframe = pandas.DataFrame([row async for row in batch]) # Checkpoint scrubber data before writing to the store, because if we get interrupted, it's safer to have an # updated codebook with no data than data with an inaccurate codebook. @@ -210,6 +222,7 @@ def run(self) -> config.JobSummary: formatter.write_records(dataframe, index) print(f" {summary.success:,} processed for {self.name}") + index += 1 # All data is written, now do any final cleanup the formatter wants formatter.finalize() @@ -238,7 +251,7 @@ def read_ndjson(self) -> Iterator[dict]: for line in f: yield json.loads(line) - def read_entries(self) -> Iterator[Union[List[dict], dict]]: + async def read_entries(self) -> AsyncIterator[Union[List[dict], dict]]: """ Reads input entries for the job. @@ -248,7 +261,8 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]: the elements of which will be guaranteed to all be in the same output batch. See comments for EtlTask.group_field for why you might do this. """ - return filter(self.scrubber.scrub_resource, self.read_ndjson()) + for x in filter(self.scrubber.scrub_resource, self.read_ndjson()): + yield x ########################################################################################## @@ -332,7 +346,7 @@ def is_ed_coding(cls, coding): """Returns true if this is a coding for an emergency department note""" return coding.get("code") in cls.ED_CODES.get(coding.get("system"), {}) - def read_entries(self) -> Iterator[Union[List[dict], dict]]: + async def read_entries(self) -> AsyncIterator[Union[List[dict], dict]]: """Passes physician notes through NLP and returns any symptoms found""" phi_root = store.Root(self.task_config.dir_phi, create=True) @@ -340,7 +354,7 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]: # Check that the note is one of our special allow-listed types (we do this here rather than on the output # side to save needing to run everything through NLP). 
# We check both type and category for safety -- we aren't sure yet how EHRs are using these fields. - codings = docref.get("category", {}).get("coding", []) + codings = list(itertools.chain.from_iterable([cat.get("coding", []) for cat in docref.get("category", [])])) codings += docref.get("type", {}).get("coding", []) is_er_note = any(self.is_ed_coding(x) for x in codings) if not is_er_note: @@ -352,4 +366,4 @@ def read_entries(self) -> Iterator[Union[List[dict], dict]]: # Yield the whole set of symptoms at once, to allow for more easily replacing previous a set of symptoms. # This way we don't need to worry about symptoms from the same note crossing batch boundaries. # The Format class will replace all existing symptoms from this note at once (because we set group_field). - yield ctakes.covid_symptoms_extract(phi_root, docref) + yield await ctakes.covid_symptoms_extract(self.task_config.client, phi_root, docref) diff --git a/docs/howtos/epic-tips.md b/docs/howtos/epic-tips.md index ef16b222..7a63fc15 100644 --- a/docs/howtos/epic-tips.md +++ b/docs/howtos/epic-tips.md @@ -2,6 +2,20 @@ # Epic Tips & Tricks +## Frequent Bulk Exporting + +You may encounter this error: +`Error processing Bulk Data Kickoff request: Request not allowed: The Client requested this Group too recently.`. + +If so, you will want to update the `FHIR_BULK_CLIENT_REQUEST_WINDOW_TBL` to a longer time. +The default is 24 hours. + +## Long IDs + +In rare cases, Epic's bulk FHIR export can generate IDs that are longer than the mandated 64-character limit. +Cumulus ETL itself will not mind this, but if you use another bulk export client, you may find that it complains. +If so, you have to reach out to Epic to cap it at 64 characters. + ## Batch Updates Epic has not yet (as of early 2023) implemented the `_since` or `_typeFilter` parameters for bulk exports. diff --git a/docs/howtos/run-cumulus-etl.md b/docs/howtos/run-cumulus-etl.md index ab4ec353..ab8444de 100644 --- a/docs/howtos/run-cumulus-etl.md +++ b/docs/howtos/run-cumulus-etl.md @@ -328,6 +328,7 @@ The [SMART Bulk Data Client](https://github.com/smart-on-fhir/bulk-data-client) options than Cumulus ETL's built-in exporter offers. If you use this tool, pass Cumulus ETL the folder that holds the downloaded data as the input path. +And you may need to specify `--fhir-url=` so that external document notes can be downloaded. ## EHR-Specific Advice diff --git a/tests/test_bulk_export.py b/tests/test_bulk_export.py index f35ae175..b487f9ac 100644 --- a/tests/test_bulk_export.py +++ b/tests/test_bulk_export.py @@ -2,7 +2,6 @@ import contextlib import io -import os import tempfile import time import unittest @@ -17,7 +16,7 @@ from jwcrypto import jwk, jwt from cumulus import common, errors, etl, loaders, store -from cumulus.loaders.fhir.backend_service import BackendServiceServer, FatalError +from cumulus.fhir_client import FatalError, FhirClient from cumulus.loaders.fhir.bulk_export import BulkExporter @@ -41,40 +40,69 @@ def make_response(status_code=200, json=None, text=None, reason=None, headers=No ) -class TestBulkLoader(unittest.IsolatedAsyncioTestCase): +class TestBulkEtl(unittest.IsolatedAsyncioTestCase): """ Test case for bulk export support in the etl pipeline and ndjson loader. - i.e. tests for fhir_ndjson.py + i.e. tests for etl.py & fhir_ndjson.py This does no actual bulk loading. 
""" def setUp(self): super().setUp() - self.root = store.Root("http://localhost:9999") - self.jwks_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with self.jwks_path = self.jwks_file.name self.jwks_file.write(b'{"fake":"jwks"}') self.jwks_file.flush() - # Mock out the backend service and bulk export code by default. We don't care about actually doing any + # Mock out the bulk export code by default. We don't care about actually doing any # bulk work in this test case, just confirming the flow. - - server_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BackendServiceServer") - self.addCleanup(server_patcher.stop) - self.mock_server = server_patcher.start() - - exporter_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BulkExporter") + exporter_patcher = mock.patch("cumulus.loaders.fhir.fhir_ndjson.BulkExporter", spec=BulkExporter) self.addCleanup(exporter_patcher.stop) - self.mock_exporter = exporter_patcher.start() + self.mock_exporter_class = exporter_patcher.start() + self.mock_exporter = mock.AsyncMock() + self.mock_exporter_class.return_value = self.mock_exporter + @mock.patch("cumulus.etl.fhir_client.FhirClient") @mock.patch("cumulus.etl.loaders.FhirNdjsonLoader") - async def test_etl_passes_args(self, mock_loader): + async def test_etl_passes_args(self, mock_loader, mock_client): """Verify that we are passed the client ID and JWKS from the command line""" mock_loader.side_effect = ValueError # just to stop the etl pipeline once we get this far + with tempfile.NamedTemporaryFile(buffering=0) as bt_file: + bt_file.write(b"bt") + + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--input-format=ndjson", + "--smart-client-id=x", + f"--smart-jwks={self.jwks_path}", + f"--bearer-token={bt_file.name}", + "--since=2018", + "--until=2020", + ] + ) + + self.assertEqual(1, mock_client.call_count) + self.assertEqual("x", mock_client.call_args[1]["client_id"]) + self.assertEqual({"fake": "jwks"}, mock_client.call_args[1]["jwks"]) + self.assertEqual("bt", mock_client.call_args[1]["bearer_token"]) + self.assertEqual(1, mock_loader.call_count) + self.assertEqual("2018", mock_loader.call_args[1]["since"]) + self.assertEqual("2020", mock_loader.call_args[1]["until"]) + + @mock.patch("cumulus.etl.fhir_client.FhirClient") + async def test_reads_client_id_from_file(self, mock_client): + """Verify that we try to read a client ID from a file.""" + mock_client.side_effect = ValueError # just to stop the etl pipeline once we get this far + + # First, confirm string is used directly if file doesn't exist with self.assertRaises(ValueError): await etl.main( [ @@ -82,92 +110,127 @@ async def test_etl_passes_args(self, mock_loader): "/tmp/output", "/tmp/phi", "--skip-init-checks", - "--input-format=ndjson", - "--smart-client-id=x", - "--smart-jwks=y", - "--bearer-token=bt", - "--since=2018", - "--until=2020", + "--smart-client-id=/direct-string", ] ) - - self.assertEqual(1, mock_loader.call_count) - self.assertEqual("x", mock_loader.call_args[1]["client_id"]) - self.assertEqual("y", mock_loader.call_args[1]["jwks"]) - self.assertEqual("bt", mock_loader.call_args[1]["bearer_token"]) - self.assertEqual("2018", mock_loader.call_args[1]["since"]) - self.assertEqual("2020", mock_loader.call_args[1]["until"]) - - def test_reads_client_id_from_file(self): - """Verify that we require both a client ID and a JWK Set.""" - # First, confirm string is used directly if file doesn't exist - loader = 
loaders.FhirNdjsonLoader(self.root, client_id="/direct-string") - self.assertEqual("/direct-string", loader.client_id) + self.assertEqual("/direct-string", mock_client.call_args[1]["client_id"]) # Now read from a file that exists - with tempfile.NamedTemporaryFile() as file: + with tempfile.NamedTemporaryFile(buffering=0) as file: file.write(b"\ninside-file\n") - file.flush() - loader = loaders.FhirNdjsonLoader(self.root, client_id=file.name) - self.assertEqual("inside-file", loader.client_id) + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + f"--smart-client-id={file.name}", + ] + ) + self.assertEqual("inside-file", mock_client.call_args[1]["client_id"]) - def test_reads_bearer_token(self): + @mock.patch("cumulus.etl.fhir_client.FhirClient") + async def test_reads_bearer_token(self, mock_client): """Verify that we read the bearer token file""" - with tempfile.NamedTemporaryFile() as file: + mock_client.side_effect = ValueError # just to stop the etl pipeline once we get this far + + with tempfile.NamedTemporaryFile(buffering=0) as file: file.write(b"\ninside-file\n") - file.flush() - loader = loaders.FhirNdjsonLoader(self.root, bearer_token=file.name) - self.assertEqual("inside-file", loader.bearer_token) + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + f"--bearer-token={file.name}", + ] + ) + self.assertEqual("inside-file", mock_client.call_args[1]["bearer_token"]) - async def test_export_flow(self): - """ - Verify that we make all the right calls into the bulk export helper classes. + @mock.patch("cumulus.etl.fhir_client.FhirClient") + async def test_fhir_url(self, mock_client): + """Verify that we handle the user provided --fhir-client correctly""" + mock_client.side_effect = ValueError # just to stop the etl pipeline once we get this far - This is a little lower-level than I would normally test, but the benefit of ensuring this flow here is that - the other test cases can focus on just the helper classes and trust that the flow works, without us needing to - do the full flow each time. 
- """ - mock_server_instance = mock.AsyncMock() - self.mock_server.return_value = mock_server_instance - mock_exporter_instance = mock.AsyncMock() - self.mock_exporter.return_value = mock_exporter_instance + # Confirm that we don't allow conflicting URLs + with self.assertRaises(SystemExit): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--fhir-url=https://example.com/hello", + ] + ) - loader = loaders.FhirNdjsonLoader(self.root, client_id="foo", jwks=self.jwks_path) - await loader.load_all(["Condition", "Encounter"]) + # But a subset --fhir-url is fine + with self.assertRaises(ValueError): + await etl.main( + [ + "https://example.com/hello/Group/1234", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--fhir-url=https://example.com/hello", + ] + ) + self.assertEqual("https://example.com/hello/Group/1234", mock_client.call_args[0][0]) - expected_resources = [ - "Condition", - "Encounter", - ] + # Now do a normal use of --fhir-url + mock_client.side_effect = ValueError # just to stop the etl pipeline once we get this far + with self.assertRaises(ValueError): + await etl.main( + [ + "/tmp/input", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--fhir-url=https://example.com/hello", + ] + ) + self.assertEqual("https://example.com/hello", mock_client.call_args[0][0]) - self.assertEqual( - [ - mock.call( - self.root.path, expected_resources, client_id="foo", jwks={"fake": "jwks"}, bearer_token=None - ), - ], - self.mock_server.call_args_list, - ) + @mock.patch("cumulus.etl.fhir_client.FhirClient") + async def test_export_flow(self, mock_client): + """ + Verify that we make the right calls down as far as the bulk export helper classes, with the right resources. + """ + self.mock_exporter.export.side_effect = ValueError # stop us when we get this far, but also confirm we call it - self.assertEqual(1, self.mock_exporter.call_count) - self.assertEqual(expected_resources, self.mock_exporter.call_args[0][1]) + with self.assertRaises(ValueError): + await etl.main( + [ + "http://localhost:9999", + "/tmp/output", + "/tmp/phi", + "--skip-init-checks", + "--task=condition,encounter", + ] + ) - self.assertEqual(1, mock_exporter_instance.export.call_count) + expected_resources = {"Condition", "Encounter"} + self.assertEqual(1, mock_client.call_count) + self.assertEqual(expected_resources, mock_client.call_args[0][1]) + self.assertEqual(1, self.mock_exporter_class.call_count) + self.assertEqual(expected_resources, set(self.mock_exporter_class.call_args[0][1])) async def test_fatal_errors_are_fatal(self): """Verify that when a FatalError is raised, we do really quit""" - self.mock_server.side_effect = FatalError + self.mock_exporter.export.side_effect = FatalError with self.assertRaises(SystemExit) as cm: - await loaders.FhirNdjsonLoader(self.root, client_id="foo", jwks=self.jwks_path).load_all(["Patient"]) + await loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock()).load_all(["Patient"]) - self.assertEqual(1, self.mock_server.call_count) + self.assertEqual(1, self.mock_exporter.export.call_count) self.assertEqual(errors.BULK_EXPORT_FAILED, cm.exception.code) @ddt.ddt @freezegun.freeze_time("Sep 15th, 2021 1:23:45") -@mock.patch("cumulus.loaders.fhir.backend_service.uuid.uuid4", new=lambda: "1234") +@mock.patch("cumulus.fhir_client.uuid.uuid4", new=lambda: "1234") class TestBulkServer(unittest.IsolatedAsyncioTestCase): """ Test case for bulk export server oauth2 / request support. 
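The tests below exercise the renamed `FhirClient` end to end: auth selection, request signing, and error handling. For reference, typical use outside the test suite roughly follows this pattern (a sketch, not part of this patch; the URL, resource list, and token are placeholder values):

```python
import asyncio

from cumulus import fhir_client


async def fetch_one_docref():
    # Bearer-token auth shown here; SMART backend-services auth would pass client_id= and jwks= instead.
    client = fhir_client.FhirClient(
        "https://example.com/fhir/", ["DocumentReference"], bearer_token="secret-token"
    )
    async with client:  # FhirClient is an async context manager, like httpx.AsyncClient
        response = await client.request("GET", "DocumentReference/example-id")
        return response.json()


# asyncio.run(fetch_one_docref())  # would issue real HTTPS requests against the placeholder URL
```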
@@ -226,7 +289,7 @@ def setUp(self): # Set up mocks for fhirclient (we don't need to test its oauth code by mocking server responses there) self.mock_client = mock.MagicMock() # FHIRClient instance self.mock_server = self.mock_client.server # FHIRServer instance - client_patcher = mock.patch("cumulus.loaders.fhir.backend_service.FHIRClient") + client_patcher = mock.patch("cumulus.fhir_client.fhirclient.client.FHIRClient") self.addCleanup(client_patcher.stop) self.mock_client_class = client_patcher.start() # FHIRClient class self.mock_client_class.return_value = self.mock_client @@ -239,30 +302,41 @@ def mock_session(server, *args, **kwargs): async def test_required_arguments(self): """Verify that we require both a client ID and a JWK Set""" - # No SMART args at all - with self.assertRaises(SystemExit): - async with BackendServiceServer(self.server_url, []): - pass + # Deny any actual requests during this test + self.respx_mock.get(f"{self.server_url}/test").respond(status_code=401) + + # Simple helper to open and make a call on client. + async def use_client(request=False, code=None, url=self.server_url, **kwargs): + try: + async with FhirClient(url, [], **kwargs) as client: + if request: + await client.request("GET", "test") + except SystemExit as exc: + if code is None: + raise + self.assertEqual(code, exc.code) + + # No SMART args at all doesn't cause any problem, if we don't make calls + await use_client() + + # No SMART args at all will raise though if we do make a call + await use_client(code=errors.SMART_CREDENTIALS_MISSING, request=True) + + # No client ID + await use_client(code=errors.FHIR_URL_MISSING, request=True, url=None) # No JWKS - with self.assertRaises(SystemExit): - async with BackendServiceServer(self.server_url, [], client_id="foo"): - pass + await use_client(code=errors.SMART_CREDENTIALS_MISSING, client_id="foo") # No client ID - with self.assertRaises(SystemExit): - async with BackendServiceServer(self.server_url, [], jwks=self.jwks): - pass + await use_client(code=errors.SMART_CREDENTIALS_MISSING, jwks=self.jwks) # Works fine if both given - async with BackendServiceServer(self.server_url, [], client_id="foo", jwks=self.jwks): - pass + await use_client(client_id="foo", jwks=self.jwks) async def test_auth_initial_authorize(self): """Verify that we authorize correctly upon class initialization""" - async with BackendServiceServer( - self.server_url, ["Condition", "Patient"], client_id=self.client_id, jwks=self.jwks - ): + async with FhirClient(self.server_url, ["Condition", "Patient"], client_id=self.client_id, jwks=self.jwks): pass # Check initialization of FHIRClient @@ -292,7 +366,7 @@ async def test_auth_with_bearer_token(self): headers={"Authorization": "Bearer fob"}, ) - async with BackendServiceServer(self.server_url, ["Condition", "Patient"], bearer_token="fob") as server: + async with FhirClient(self.server_url, ["Condition", "Patient"], bearer_token="fob") as server: await server.request("GET", "foo") async def test_get_with_new_header(self): @@ -300,7 +374,7 @@ async def test_get_with_new_header(self): # This is mostly confirming that we call mocks correctly, but that's important since we're mocking out all # of fhirclient. Since we do that, we need to confirm we're driving it well. 
- async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server) as mock_session: # With new header and stream await server.request("GET", "foo", headers={"Test": "Value"}, stream=True) @@ -330,7 +404,7 @@ async def test_get_with_new_header(self): async def test_get_with_overriden_header(self): """Verify that we issue a GET correctly for the happy path""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server) as mock_session: # With overriding a header and default stream (False) await server.request("GET", "bar", headers={"Accept": "text/plain"}) @@ -366,7 +440,7 @@ async def test_get_with_overriden_header(self): ) async def test_jwks_without_suitable_key(self, bad_jwks): with self.assertRaisesRegex(FatalError, "No private ES384 or RS384 key found"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=bad_jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=bad_jwks): pass @ddt.data( @@ -394,7 +468,7 @@ async def test_bad_smart_config(self, bad_config_override): ) with self.assertRaisesRegex(FatalError, "does not support the client-confidential-asymmetric protocol"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_authorize_error_with_response(self): @@ -404,19 +478,19 @@ async def test_authorize_error_with_response(self): error.response.json.return_value = {"error_description": "Ouch!"} self.mock_client.authorize.side_effect = error with self.assertRaisesRegex(FatalError, "Could not authenticate with the FHIR server: Ouch!"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_authorize_error_without_response(self): """Verify that we translate authorize non-response errors into FatalErrors.""" self.mock_client.authorize.side_effect = Exception("no memory") with self.assertRaisesRegex(FatalError, "Could not authenticate with the FHIR server: no memory"): - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks): + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks): pass async def test_get_error_401(self): """Verify that an expired token is refreshed.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: # Check that we correctly tried to re-authenticate with self.mock_session(server) as mock_session: mock_session.send.side_effect = [make_response(status_code=401), make_response()] @@ -431,7 +505,7 @@ async def test_get_error_401(self): async def test_get_error_429(self): """Verify that 429 errors are passed through and not treated as exceptions.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: 
# Confirm 429 passes with self.mock_session(server, status_code=429): response = await server.request("GET", "foo") @@ -450,7 +524,7 @@ async def test_get_error_429(self): ) async def test_get_error_other(self, response_args): """Verify that other http errors are FatalErrors.""" - async with BackendServiceServer(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: + async with FhirClient(self.server_url, [], client_id=self.client_id, jwks=self.jwks) as server: with self.mock_session(server, status_code=500, **response_args): with self.assertRaisesRegex(FatalError, "testmsg"): await server.request("GET", "foo") @@ -672,6 +746,7 @@ async def test_delete_if_interrupted(self): ) +@mock.patch("cumulus.deid.codebook.secrets.token_hex", new=lambda x: "1234") # just to not waste entropy class TestBulkExportEndToEnd(unittest.IsolatedAsyncioTestCase): """ Test case for doing an entire bulk export loop, without mocking python code. @@ -803,16 +878,28 @@ def set_up_requests(self, respx_mock): status_code=202, ) + @mock.patch("cumulus.deid.codebook.secrets.token_hex", new=lambda x: "1234") @responses.mock.activate(assert_all_requests_are_fired=True) async def test_successful_bulk_export(self): """Verify a happy path bulk export, from toe to tip""" - loader = loaders.FhirNdjsonLoader(self.root, client_id=self.client_id, jwks=self.jwks_path) - - with respx.mock(assert_all_called=True) as respx_mock: - self.set_up_requests(respx_mock) - tmpdir = await loader.load_all(["Patient"]) + with tempfile.TemporaryDirectory() as tmpdir: + with respx.mock(assert_all_called=True) as respx_mock: + self.set_up_requests(respx_mock) + + await etl.main( + [ + self.root.path, + f"{tmpdir}/output", + f"{tmpdir}/phi", + "--skip-init-checks", + "--output-format=ndjson", + "--task=patient", + f"--smart-client-id={self.client_id}", + f"--smart-jwks={self.jwks_path}", + ] + ) - self.assertEqual( - {"id": "testPatient1", "resourceType": "Patient"}, - common.read_json(os.path.join(tmpdir.name, "Patient.000.ndjson")), - ) + self.assertEqual( + {"id": "4342abf315cf6f243e11f4d460303e36c6c3663a25c91cc6b1a8002476c850dd", "resourceType": "Patient"}, + common.read_json(f"{tmpdir}/output/patient/patient.000.ndjson"), + ) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index facc0da8..7037421e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -4,23 +4,26 @@ import shutil import tempfile import unittest +from typing import AsyncIterator, List from unittest import mock import ddt +import respx -from cumulus import common, config, deid, errors, tasks +from cumulus import common, config, deid, errors, fhir_client, tasks from tests.ctakesmock import CtakesMixin from tests import i2b2_mock_data @ddt.ddt -class TestTasks(CtakesMixin, unittest.TestCase): +class TestTasks(CtakesMixin, unittest.IsolatedAsyncioTestCase): """Test case for task methods""" def setUp(self) -> None: super().setUp() + client = fhir_client.FhirClient("http://localhost/", []) script_dir = os.path.dirname(__file__) data_dir = os.path.join(script_dir, "data/simple") self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with @@ -30,7 +33,7 @@ def setUp(self) -> None: os.makedirs(self.phi_dir) self.job_config = config.JobConfig( - self.input_dir, self.input_dir, self.tmpdir.name, self.phi_dir, "ndjson", "ndjson", batch_size=5 + self.input_dir, self.input_dir, self.tmpdir.name, self.phi_dir, "ndjson", "ndjson", client, batch_size=5 ) self.format = mock.MagicMock() @@ -51,60 +54,80 @@ def make_json(self, filename, 
resource_id, **kwargs): os.path.join(self.input_dir, f"{filename}.ndjson"), {"resourceType": "Test", **kwargs, "id": resource_id} ) - def test_batch_iterate(self): + async def test_batch_iterate(self): """Check a bunch of edge cases for the _batch_iterate helper""" # pylint: disable=protected-access - self.assertEqual([], [list(x) for x in tasks._batch_iterate([], 2)]) + # Tiny little convenience method to be turn sync lists into async iterators. + async def async_iter(values: List) -> AsyncIterator: + for x in values: + yield x - self.assertEqual( + # Handles converting all the async code into synchronous lists for ease of testing + async def assert_batches_equal(expected: List, values: List, batch_size: int) -> None: + collected = [] + async for batch in tasks._batch_iterate(async_iter(values), batch_size): + batch_list = [] + async for item in batch: + batch_list.append(item) + collected.append(batch_list) + self.assertEqual(expected, collected) + + await assert_batches_equal([], [], 2) + + await assert_batches_equal( [ [1, 2.1, 2.2], [3, 4], ], - [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4], 2)], + [1, [2.1, 2.2], 3, 4], + 2, ) - self.assertEqual( + await assert_batches_equal( [ [1.1, 1.2], [2.1, 2.2], [3, 4], ], - [list(x) for x in tasks._batch_iterate([[1.1, 1.2], [2.1, 2.2], 3, 4], 2)], + [[1.1, 1.2], [2.1, 2.2], 3, 4], + 2, ) - self.assertEqual( + await assert_batches_equal( [ [1, 2.1, 2.2], [3, 4], [5], ], - [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4, 5], 2)], + [1, [2.1, 2.2], 3, 4, 5], + 2, ) - self.assertEqual( + await assert_batches_equal( [ [1, 2.1, 2.2], [3, 4], ], - [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3, 4], 3)], + [1, [2.1, 2.2], 3, 4], + 3, ) - self.assertEqual( + await assert_batches_equal( [ [1], [2.1, 2.2], [3], ], - [list(x) for x in tasks._batch_iterate([1, [2.1, 2.2], 3], 1)], + [1, [2.1, 2.2], 3], + 1, ) with self.assertRaises(ValueError): - list(tasks._batch_iterate([1, 2, 3], 0)) + await assert_batches_equal([], [1, 2, 3], 0) with self.assertRaises(ValueError): - list(tasks._batch_iterate([1, 2, 3], -1)) + await assert_batches_equal([], [1, 2, 3], -1) def test_read_ndjson(self): """Verify we recognize all expected ndjson filename formats""" @@ -122,19 +145,19 @@ def test_read_ndjson(self): resources = tasks.EncounterTask(self.job_config, self.scrubber).read_ndjson() self.assertEqual([], list(resources)) - def test_unknown_modifier_extensions_skipped_for_patients(self): + async def test_unknown_modifier_extensions_skipped_for_patients(self): """Verify we ignore unknown modifier extensions during a normal task (like patients)""" self.make_json("Patient.0", "0") self.make_json("Patient.1", "1", modifierExtension=[{"url": "unrecognized"}]) - tasks.PatientTask(self.job_config, self.scrubber).run() + await tasks.PatientTask(self.job_config, self.scrubber).run() # Confirm that only patient 0 got stored self.assertEqual(1, self.format.write_records.call_count) df = self.format.write_records.call_args[0][0] self.assertEqual([self.codebook.db.patient("0")], list(df.id)) - def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): + async def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): """Verify we ignore unknown modifier extensions during a custom read task (nlp symptoms)""" docref0 = i2b2_mock_data.documentreference() docref0["subject"]["reference"] = "Patient/1234" @@ -144,7 +167,7 @@ def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): docref1["modifierExtension"] = 
[{"url": "unrecognized"}] self.make_json("DocumentReference.1", "1", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() # Confirm that only symptoms from docref 0 got stored self.assertEqual(1, self.format.write_records.call_count) @@ -165,25 +188,25 @@ def test_unknown_modifier_extensions_skipped_for_nlp_symptoms(self): ([{"system": "nope", "code": "nope"}, {"system": "http://loinc.org", "code": "57053-1"}], True), ) @ddt.unpack - def test_ed_note_filtering_for_nlp(self, codings, expected): + async def test_ed_note_filtering_for_nlp(self, codings, expected): """Verify we filter out any non-emergency-department note""" # Use one doc with category set, and one with type set. Either should work. docref0 = i2b2_mock_data.documentreference() - docref0["category"] = {"coding": codings} + docref0["category"] = [{"coding": codings}] del docref0["type"] self.make_json("DocumentReference.0", "0", **docref0) docref1 = i2b2_mock_data.documentreference() docref1["type"] = {"coding": codings} self.make_json("DocumentReference.1", "1", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() self.assertEqual(1 if expected else 0, self.format.write_records.call_count) if expected: df = self.format.write_records.call_args[0][0] self.assertEqual(4, len(df)) - def test_non_ed_visit_is_skipped_for_covid_symptoms(self): + async def test_non_ed_visit_is_skipped_for_covid_symptoms(self): """Verify we ignore non ED visits for the covid symptoms NLP""" docref0 = i2b2_mock_data.documentreference() docref0["type"]["coding"][0]["code"] = "NOTE:nope" # pylint: disable=unsubscriptable-object @@ -192,7 +215,7 @@ def test_non_ed_visit_is_skipped_for_covid_symptoms(self): docref1["type"]["coding"][0]["code"] = "NOTE:149798455" # pylint: disable=unsubscriptable-object self.make_json("DocumentReference.1", "present", **docref1) - tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() # Confirm that only symptoms from docref 'present' got stored self.assertEqual(1, self.format.write_records.call_count) @@ -215,3 +238,34 @@ def test_filtered_but_named_task(self): with self.assertRaises(SystemExit) as cm: tasks.EtlTask.get_selected_tasks(names=["condition"], filter_tags=["gpu"]) self.assertEqual(errors.TASK_FILTERED_OUT, cm.exception.code) + + @ddt.data( + # list of (URL, contentType), expected text + ([("http://localhost/file-cough", "text/plain")], "cough"), # handles absolute URL + ([("file-cough", "text/html")], "cough"), # handles text/* + ([("file-cough", "application/xml")], "cough"), # handles app/xml + ([("file-cough", "text/html"), ("file-fever", "text/plain")], "fever"), # prefers text/plain to text/* + ([("file-cough", "application/xml"), ("file-fever", "text/blarg")], "fever"), # prefers text/* to app/xml + ([("file-cough", "nope/nope")], None), # ignores unsupported mimetypes + ) + @ddt.unpack + @respx.mock + async def test_note_urls_downloaded(self, attachments, expected_text): + """Verify that we download any attachments with URLs""" + # We return three words due to how our cTAKES mock works. It wants 3 words -- fever word is in middle. 
+ respx.get("http://localhost/file-cough").respond(text="has cough bad") + respx.get("http://localhost/file-fever").respond(text="has fever bad") + + docref0 = i2b2_mock_data.documentreference() + docref0["content"] = [{"attachment": {"url": a[0], "contentType": a[1]}} for a in attachments] + self.make_json("DocumentReference.0", "doc0", **docref0) + + async with self.job_config.client: + await tasks.CovidSymptomNlpResultsTask(self.job_config, self.scrubber).run() + + if expected_text: + self.assertEqual(1, self.format.write_records.call_count) + df = self.format.write_records.call_args[0][0] + self.assertEqual(expected_text, df.iloc[0].match["text"]) + else: + self.assertEqual(0, self.format.write_records.call_count)
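
For readers tracking the BackendServiceServer -> FhirClient rename above, the shape of the client API these tests exercise is roughly the following. This is a usage sketch only, not the client's implementation; the base URL, scope list, client id, and jwks values are placeholders, and only the constructor and request() signatures visible in the tests are assumed.

from cumulus import fhir_client


async def fetch_example(jwks: dict):
    # The two positional arguments are the FHIR base URL and a list of resources/scopes
    # (these tests pass an empty list). client_id= and jwks= enable SMART
    # backend-service auth; omitting them, as TestTasks.setUp does, gives an
    # unauthenticated client.
    async with fhir_client.FhirClient(
        "https://fhir.example.com/R4", [], client_id="example-client", jwks=jwks
    ) as server:
        # request() takes an HTTP verb and a path, plus optional headers= and stream=,
        # matching the GET tests above. Per the error-handling tests, a 429 response is
        # handed back to the caller, a 401 triggers one re-authentication and retry, and
        # other HTTP errors surface as FatalError.
        return await server.request(
            "GET", "Patient", headers={"Accept": "application/fhir+json"}, stream=True
        )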
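
Both TestBulkExportEndToEnd and its test method patch cumulus.deid.codebook.secrets.token_hex to a constant so that the de-identified Patient ID asserted at the end ("4342abf3...") stays stable across runs. The snippet below is only an analogy for why that pin matters, using a made-up hashing scheme; the real codebook logic is not part of this diff and may differ.

import hashlib


def fake_deid_id(resource_id: str, salt: str = "1234") -> str:
    # Hypothetical: if de-identified IDs come from a salted hash, pinning the salt
    # source (secrets.token_hex -> "1234") is what makes an expected value
    # reproducible; with a random salt, the assertion could never be written down.
    return hashlib.sha256(f"{salt}:{resource_id}".encode()).hexdigest()


print(fake_deid_id("testPatient1"))  # stable only because the salt is pinned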
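
The rewritten test_batch_iterate drives tasks._batch_iterate as an async generator whose batches are themselves async iterables. The assertions above encode the intended semantics: grouped sub-lists stay inside a single batch, a batch closes once it holds at least batch_size flattened items, and a batch_size below 1 raises ValueError. A stand-in that reproduces those semantics (not the actual implementation in cumulus/tasks.py; it also buffers each batch in memory rather than streaming) could look like:

from typing import AsyncIterator, List


async def batch_iterate_sketch(source: AsyncIterator, batch_size: int) -> AsyncIterator[AsyncIterator]:
    if batch_size < 1:
        raise ValueError("batch size must be at least 1")

    async def drain(items: List) -> AsyncIterator:
        for item in items:
            yield item

    batch: List = []
    count = 0
    async for entry in source:
        # A grouped sub-list counts as multiple items but is never split across batches.
        group = entry if isinstance(entry, list) else [entry]
        batch.extend(group)
        count += len(group)
        if count >= batch_size:
            yield drain(batch)
            batch, count = [], 0
    if batch:  # emit any short final batch
        yield drain(batch)


# e.g. feeding [1, [2.1, 2.2], 3, 4] with batch_size=2 yields the batches
# [1, 2.1, 2.2] and [3, 4], matching the second assert_batches_equal case above.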