diff --git a/cumulus_etl/cli_utils.py b/cumulus_etl/cli_utils.py index b8a362ec..e7455269 100644 --- a/cumulus_etl/cli_utils.py +++ b/cumulus_etl/cli_utils.py @@ -1,7 +1,6 @@ """Helper methods for CLI parsing.""" import argparse -import os import socket import tempfile import time @@ -9,7 +8,7 @@ import rich.progress -from cumulus_etl import common, errors +from cumulus_etl import common, errors, store def add_auth(parser: argparse.ArgumentParser) -> None: @@ -60,22 +59,21 @@ def make_export_dir(export_to: str = None) -> common.Directory: # If we were to relax this requirement, we'd want to copy the exported files over to a local dir. errors.fatal(f"The target export folder '{export_to}' must be local. ", errors.BULK_EXPORT_FOLDER_NOT_LOCAL) - confirm_dir_is_empty(export_to) + confirm_dir_is_empty(store.Root(export_to, create=True)) return common.RealDirectory(export_to) -def confirm_dir_is_empty(path: str) -> None: - """Errors out if the dir exists with contents, but creates empty dir if not present yet""" +def confirm_dir_is_empty(root: store.Root) -> None: + """Errors out if the dir exists with contents""" try: - if os.listdir(path): + if root.ls(): errors.fatal( - f"The target folder '{path}' already has contents. Please provide an empty folder.", + f"The target folder '{root.path}' already has contents. Please provide an empty folder.", errors.FOLDER_NOT_EMPTY, ) except FileNotFoundError: - # Target folder doesn't exist, so let's make it - os.makedirs(path, mode=0o700) + pass def is_url_available(url: str, retry: bool = True) -> bool: diff --git a/cumulus_etl/completion/__init__.py b/cumulus_etl/completion/__init__.py new file mode 100644 index 00000000..6b8afbbf --- /dev/null +++ b/cumulus_etl/completion/__init__.py @@ -0,0 +1,21 @@ +""" +Helpers for implementing completion-tracking. + +Completion tracking allows downstream consumers to know when ETL runs are +"complete enough" for their purposes. + +For example, the `core` study may want to not expose Encounters whose +Conditions have not yet been loaded. These metadata tables allow that. + +Although these metadata tables aren't themselves tasks, they need a +lot of the same information that tasks need. This module provides that. 
+""" + +from .output import ( + COMPLETION_TABLE, + COMPLETION_ENCOUNTERS_TABLE, + completion_encounters_output_args, + completion_encounters_schema, + completion_format_args, + completion_schema, +) diff --git a/cumulus_etl/completion/output.py b/cumulus_etl/completion/output.py new file mode 100644 index 00000000..e2c18daf --- /dev/null +++ b/cumulus_etl/completion/output.py @@ -0,0 +1,62 @@ +"""Schemas and Format helpers for writing completion tables.""" + +import pyarrow + + +COMPLETION_TABLE = "etl__completion" +COMPLETION_ENCOUNTERS_TABLE = "etl__completion_encounters" + + +# FORMATTERS + + +def completion_format_args() -> dict: + """Returns kwargs to pass to the Format class initializer of your choice""" + return { + "dbname": COMPLETION_TABLE, + "uniqueness_fields": {"table_name", "group_name"}, + } + + +# OUTPUT TABLES + + +def completion_encounters_output_args() -> dict: + """Returns output table kwargs for the etl__completion_encounters table""" + return { + "name": COMPLETION_ENCOUNTERS_TABLE, + "uniqueness_fields": {"encounter_id", "group_name"}, + "update_existing": False, # we want to keep the first export time we make for a group + "resource_type": None, + "visible": False, + } + + +# SCHEMAS + + +def completion_schema() -> pyarrow.Schema: + """Returns a schema for the etl__completion table""" + return pyarrow.schema( + [ + pyarrow.field("table_name", pyarrow.string()), + pyarrow.field("group_name", pyarrow.string()), + # You might think this is an opportunity to use pyarrow.timestamp(), + # but because ndjson output formats (which can't natively represent a + # datetime) would then require conversion to and fro, it's easier to + # just mirror our FHIR tables and use strings for timestamps. + pyarrow.field("export_time", pyarrow.string()), + ] + ) + + +def completion_encounters_schema() -> pyarrow.Schema: + """Returns a schema for the etl__completion_encounters table""" + return pyarrow.schema( + [ + pyarrow.field("encounter_id", pyarrow.string()), + pyarrow.field("group_name", pyarrow.string()), + # See note above for why this isn't a pyarrow.timestamp() field. 
+ pyarrow.field("export_time", pyarrow.string()), + ] + ) diff --git a/cumulus_etl/errors.py b/cumulus_etl/errors.py index 6807305b..7d70a24d 100644 --- a/cumulus_etl/errors.py +++ b/cumulus_etl/errors.py @@ -32,6 +32,7 @@ LABEL_STUDIO_MISSING = 31 FHIR_AUTH_FAILED = 32 SERVICE_MISSING = 33 # generic init-check service is missing +COMPLETION_ARG_MISSING = 34 class FhirConnectionError(Exception): diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py index e98b8f02..fb3be80b 100644 --- a/cumulus_etl/etl/cli.py +++ b/cumulus_etl/etl/cli.py @@ -124,6 +124,12 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None: export.add_argument("--since", help="Start date for export from the FHIR server") export.add_argument("--until", help="End date for export from the FHIR server") + group = parser.add_argument_group("external export identification") + group.add_argument("--export-group", help=argparse.SUPPRESS) + group.add_argument("--export-timestamp", help=argparse.SUPPRESS) + # Temporary explicit opt-in flag during the development of the completion-tracking feature + group.add_argument("--write-completion", action="store_true", default=False, help=argparse.SUPPRESS) + cli_utils.add_nlp(parser) task = parser.add_argument_group("task selection") @@ -180,6 +186,30 @@ def print_config(args: argparse.Namespace, job_datetime: datetime.datetime, all_ rich.get_console().print(table) +def handle_completion_args(args: argparse.Namespace, loader: loaders.Loader) -> (str, datetime.datetime): + """Returns (group_name, datetime)""" + # Grab completion options from CLI or loader + export_group_name = args.export_group or loader.group_name + export_datetime = ( + datetime.datetime.fromisoformat(args.export_timestamp) if args.export_timestamp else loader.export_datetime + ) + + # Disable entirely if asked to + if not args.write_completion: + export_group_name = None + export_datetime = None + + # Error out if we have mismatched args + has_group_name = export_group_name is not None + has_datetime = bool(export_datetime) + if has_group_name and not has_datetime: + errors.fatal("Missing --export-datetime argument.", errors.COMPLETION_ARG_MISSING) + elif not has_group_name and has_datetime: + errors.fatal("Missing --export-group argument.", errors.COMPLETION_ARG_MISSING) + + return export_group_name, export_datetime + + async def etl_main(args: argparse.Namespace) -> None: # Set up some common variables store.set_user_fs_options(vars(args)) # record filesystem options like --s3-region before creating Roots @@ -200,7 +230,7 @@ async def etl_main(args: argparse.Namespace) -> None: common.print_header() # all "prep" comes in this next section, like connecting to server, bulk export, and de-id if args.errors_to: - cli_utils.confirm_dir_is_empty(args.errors_to) + cli_utils.confirm_dir_is_empty(store.Root(args.errors_to, create=True)) # Check that cTAKES is running and any other services or binaries we require if not args.skip_init_checks: @@ -226,6 +256,9 @@ async def etl_main(args: argparse.Namespace) -> None: # Pull down resources from any remote location (like s3), convert from i2b2, or do a bulk export loaded_dir = await config_loader.load_all(list(required_resources)) + # Establish the group name and datetime of the loaded dataset (from CLI args or Loader) + export_group_name, export_datetime = handle_completion_args(args, config_loader) + # If *any* of our tasks need bulk MS de-identification, run it if any(t.needs_bulk_deid for t in selected_tasks): loaded_dir = await 
deid.Scrubber.scrub_bulk_data(loaded_dir.name) @@ -248,6 +281,8 @@ async def etl_main(args: argparse.Namespace) -> None: ctakes_overrides=args.ctakes_overrides, dir_errors=args.errors_to, tasks=[t.name for t in selected_tasks], + export_group_name=export_group_name, + export_datetime=export_datetime, ) common.write_json(config.path_config(), config.as_json(), indent=4) diff --git a/cumulus_etl/etl/config.py b/cumulus_etl/etl/config.py index 1440408e..c468c54d 100644 --- a/cumulus_etl/etl/config.py +++ b/cumulus_etl/etl/config.py @@ -31,6 +31,8 @@ def __init__( ctakes_overrides: str = None, dir_errors: str = None, tasks: list[str] = None, + export_group_name: str = None, + export_datetime: datetime.datetime = None, ): self._dir_input_orig = dir_input_orig self.dir_input = dir_input_deid @@ -46,14 +48,16 @@ def __init__( self.batch_size = batch_size self.ctakes_overrides = ctakes_overrides self.tasks = tasks or [] + self.export_group_name = export_group_name + self.export_datetime = export_datetime # initialize format class self._output_root = store.Root(self._dir_output, create=True) self._format_class = formats.get_format_class(self._output_format) self._format_class.initialize_class(self._output_root) - def create_formatter(self, dbname: str, group_field: str = None, resource_type: str = None) -> formats.Format: - return self._format_class(self._output_root, dbname, group_field=group_field, resource_type=resource_type) + def create_formatter(self, dbname: str, **kwargs) -> formats.Format: + return self._format_class(self._output_root, dbname, **kwargs) def path_config(self) -> str: return os.path.join(self.dir_job_config(), "job_config.json") @@ -74,6 +78,8 @@ def as_json(self): "comment": self.comment, "batch_size": self.batch_size, "tasks": ",".join(self.tasks), + "export_group_name": self.export_group_name, + "export_timestamp": self.export_datetime and self.export_datetime.isoformat(), } diff --git a/cumulus_etl/etl/convert/cli.py b/cumulus_etl/etl/convert/cli.py index 17240d3b..3bde7191 100644 --- a/cumulus_etl/etl/convert/cli.py +++ b/cumulus_etl/etl/convert/cli.py @@ -7,15 +7,21 @@ import argparse import os import tempfile +from functools import partial +from typing import Callable +import pyarrow import rich.progress -from cumulus_etl import cli_utils, common, errors, formats, store +from cumulus_etl import cli_utils, common, completion, errors, formats, store from cumulus_etl.etl import tasks from cumulus_etl.etl.tasks import task_factory -def make_batch(task: type[tasks.EtlTask], formatter: formats.Format, index: int, path: str) -> formats.Batch: +def make_batch( + path: str, + schema_func: Callable[[list[dict]], pyarrow.Schema], +) -> formats.Batch: metadata_path = path.removesuffix(".ndjson") + ".meta" try: metadata = common.read_json(metadata_path) @@ -24,50 +30,88 @@ def make_batch(task: type[tasks.EtlTask], formatter: formats.Format, index: int, rows = list(common.read_ndjson(path)) groups = set(metadata.get("groups", [])) - return task.make_batch_from_rows(formatter, rows, groups=groups, index=index) + schema = schema_func(rows) + return formats.Batch(rows, groups=groups, schema=schema) -def convert_task_table( - task: type[tasks.EtlTask], - table: tasks.OutputTable, + +def convert_folder( input_root: store.Root, - output_root: store.Root, - formatter_class: type[formats.Format], + *, + table_name: str, + schema_func: Callable[[list[dict]], pyarrow.Schema], + formatter: formats.Format, progress: rich.progress.Progress, ) -> None: - """Converts a task's output folder (like 
output/observation/ or output/covid_symptom__nlp_results/)""" - # Does the task dir even exist? - task_input_dir = input_root.joinpath(table.get_name(task)) - if not input_root.exists(task_input_dir): + table_input_dir = input_root.joinpath(table_name) + if not input_root.exists(table_input_dir): # Don't error out in this case -- it's not the user's fault if the folder doesn't exist. # We're just checking all task folders. return # Grab all the files in the task dir - all_paths = store.Root(task_input_dir).ls() + all_paths = store.Root(table_input_dir).ls() ndjson_paths = sorted(filter(lambda x: x.endswith(".ndjson"), all_paths)) if not ndjson_paths: # Again, don't error out in this case -- if the ETL made an empty dir, it's not a user-visible error return - # Let's convert! Make the formatter and chew through the files + # Let's convert! Start chewing through the files + count = len(ndjson_paths) + 1 # add one for finalize step + progress_task = progress.add_task(table_name, total=count) + + for ndjson_path in ndjson_paths: + batch = make_batch(ndjson_path, schema_func) + formatter.write_records(batch) + progress.update(progress_task, advance=1) + + formatter.finalize() + progress.update(progress_task, advance=1) + + +def convert_task_table( + task: type[tasks.EtlTask], + table: tasks.OutputTable, + input_root: store.Root, + output_root: store.Root, + formatter_class: type[formats.Format], + progress: rich.progress.Progress, +) -> None: + """Converts a task's output folder (like output/observation/ or output/covid_symptom__nlp_results/)""" + + # Start with a formatter formatter = formatter_class( output_root, table.get_name(task), group_field=table.group_field, - resource_type=table.get_schema(task), + uniqueness_fields=table.uniqueness_fields, + update_existing=table.update_existing, ) - count = len(ndjson_paths) + 1 # add one for finalize step - progress_task = progress.add_task(table.get_name(task), total=count) + # And then run the conversion + convert_folder( + input_root, + table_name=table.get_name(task), + schema_func=partial(task.get_schema, table.get_resource_type(task)), + formatter=formatter, + progress=progress, + ) - for index, ndjson_path in enumerate(ndjson_paths): - batch = make_batch(task, formatter, index, ndjson_path) - formatter.write_records(batch) - progress.update(progress_task, advance=1) - formatter.finalize() - progress.update(progress_task, advance=1) +def convert_completion( + input_root: store.Root, + output_root: store.Root, + formatter_class: type[formats.Format], + progress: rich.progress.Progress, +) -> None: + """Converts the etl__completion metadata table""" + convert_folder( + input_root, + table_name=completion.COMPLETION_TABLE, + schema_func=lambda rows: completion.completion_schema(), + formatter=formatter_class(output_root, **completion.completion_format_args()), + progress=progress, + ) def copy_job_configs(input_root: store.Root, output_root: store.Root) -> None: @@ -90,6 +134,9 @@ def walk_tree(input_root: store.Root, output_root: store.Root, formatter_class: for table in task.outputs: convert_task_table(task, table, input_root, output_root, formatter_class, progress) + # And afterward, copy over the completion metadata tables + convert_completion(input_root, output_root, formatter_class, progress) + # Copy JobConfig files over too. # To consider: Marking the job_config.json file in these JobConfig directories as "converted" in some way. # They already will be detectable by having "output_format: ndjson", but maybe we could do more. 
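A quick sketch of the refactored convert_folder() contract above: it no longer needs a task at all, only a schema_func callback that maps one batch of ndjson rows to a pyarrow schema. The sketch below is illustrative only (it assumes the cumulus_etl package is importable and uses ObservationTask purely as an example task) and shows the two callback shapes this file passes in:

    from functools import partial

    from cumulus_etl import completion
    from cumulus_etl.etl.tasks.basic_tasks import ObservationTask

    # Task tables bind their FHIR resource type up front; the rows then refine the schema.
    observation_schema_func = partial(ObservationTask.get_schema, "Observation")

    # The etl__completion metadata table has a fixed schema and ignores the rows entirely.
    completion_schema_func = lambda rows: completion.completion_schema()

    rows = [{"resourceType": "Observation", "id": "obs1", "status": "final"}]
    observation_schema_func(rows)  # schema inferred from the Observation resource plus these rows
    completion_schema_func(rows)   # always: table_name, group_name, export_time
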
diff --git a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py index 167cd9b9..aade7bdf 100644 --- a/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py +++ b/cumulus_etl/etl/studies/covid_symptom/covid_tasks.py @@ -7,7 +7,7 @@ import rich.progress from ctakesclient.transformer import TransformerModel -from cumulus_etl import formats, nlp, store +from cumulus_etl import nlp, store from cumulus_etl.etl import tasks from cumulus_etl.etl.studies.covid_symptom import covid_ctakes @@ -109,7 +109,7 @@ class BaseCovidSymptomNlpResultsTask(tasks.BaseNlpTask): # cNLP: smartonfhir/cnlp-transformers:negation-0.4.0 # ctakesclient: 3.0 - outputs = [tasks.OutputTable(schema=None, group_field="docref_id")] + outputs = [tasks.OutputTable(resource_type=None, group_field="docref_id")] async def prepare_task(self) -> bool: bsv_path = ctakesclient.filesystem.covid_symptoms_path() @@ -155,7 +155,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task yield symptoms @classmethod - def get_schema(cls, formatter: formats.Format, rows: list[dict]) -> pyarrow.Schema: + def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema: return pyarrow.schema( [ pyarrow.field("id", pyarrow.string()), diff --git a/cumulus_etl/etl/studies/hftest/hf_tasks.py b/cumulus_etl/etl/studies/hftest/hf_tasks.py index 1a3bab0b..2b67d49a 100644 --- a/cumulus_etl/etl/studies/hftest/hf_tasks.py +++ b/cumulus_etl/etl/studies/hftest/hf_tasks.py @@ -4,7 +4,7 @@ import pyarrow import rich.progress -from cumulus_etl import common, errors, formats, nlp +from cumulus_etl import common, errors, nlp from cumulus_etl.etl import tasks @@ -86,7 +86,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task } @classmethod - def get_schema(cls, formatter: formats.Format, rows: list[dict]) -> pyarrow.Schema: + def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema: return pyarrow.schema( [ pyarrow.field("id", pyarrow.string()), diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py index eb425f53..1e4d720d 100644 --- a/cumulus_etl/etl/tasks/base.py +++ b/cumulus_etl/etl/tasks/base.py @@ -12,7 +12,7 @@ import rich.table import rich.text -from cumulus_etl import cli_utils, common, deid, formats, store +from cumulus_etl import cli_utils, common, completion, deid, formats, store from cumulus_etl.etl import config from cumulus_etl.etl.tasks import batching @@ -33,16 +33,16 @@ class OutputTable: def get_name(self, task): return self.name or task.name - # *** schema *** + # *** resource_type *** # This field determines the schema of the output table. # Put a FHIR resource name (like "Observation") here, to fill the output table with an appropriate schema. # - None disables using a schema # - "__same__" means to use the same resource name as our input # - Or use any FHIR resource name - schema: str | None = "__same__" + resource_type: str | None = "__same__" - def get_schema(self, task): - return task.resource if self.schema == "__same__" else self.schema + def get_resource_type(self, task): + return task.resource if self.resource_type == "__same__" else self.resource_type # *** group_field *** # Set group_field if your task generates a group of interdependent records (like NLP results from a document). @@ -64,6 +64,19 @@ def get_schema(self, task): # group at one time, and not split across batches. 
group_field: str | None = None + # *** uniqueness_fields *** + # The set of fields which together, determine a unique row. There should be no duplicates that + # share the same value for all these fields. Default is ["id"] + uniqueness_fields: set[str] | None = None + + # *** update_existing *** + # Whether to update existing rows or (if False) to ignore them and leave them in place. + update_existing: bool = True + + # *** visible *** + # Whether this table should be user-visible in the progress output. + visible: bool = True + class EtlTask: """ @@ -93,6 +106,9 @@ def __init__(self, task_config: config.JobConfig, scrubber: deid.Scrubber): self.scrubber = scrubber self.formatters: list[formats.Format | None] = [None] * len(self.outputs) # create format placeholders self.summaries: list[config.JobSummary] = [config.JobSummary(output.get_name(self)) for output in self.outputs] + self.completion_tracking_enabled = ( + self.task_config.export_group_name is not None and self.task_config.export_datetime + ) async def run(self) -> list[config.JobSummary]: """ @@ -126,6 +142,9 @@ async def run(self) -> list[config.JobSummary]: # changes. (The reason it's nice if the table & schema exist is so that downstream SQL can be dumber.) self._touch_remaining_tables() + # Mark this group & resource combo as complete + self._update_completion_table() + # All data is written, now do any final cleanup the formatters want for formatter in self.formatters: formatter.finalize() @@ -133,9 +152,9 @@ async def run(self) -> list[config.JobSummary]: return self.summaries @classmethod - def make_batch_from_rows(cls, formatter: formats.Format, rows: list[dict], groups: set[str] = None, index: int = 0): - schema = cls.get_schema(formatter, rows) - return formats.Batch(rows, groups=groups, schema=schema, index=index) + def make_batch_from_rows(cls, resource_type: str | None, rows: list[dict], groups: set[str] = None): + schema = cls.get_schema(resource_type, rows) + return formats.Batch(rows, groups=groups, schema=schema) ########################################################################################## # @@ -155,7 +174,9 @@ async def _write_tables_in_batches( """Writes all entries to each output tables in batches""" def update_status(): - status.plain = "\n".join(f"{x.success:,} written to {x.label}" for x in self.summaries) + status.plain = "\n".join( + f"{x.success:,} written to {x.label}" for i, x in enumerate(self.summaries) if self.outputs[i].visible + ) batch_index = 0 format_progress_task = None @@ -193,6 +214,32 @@ def _touch_remaining_tables(self): if formatter is None: # No data got written yet self._write_one_table_batch([], table_index, 0) # just write an empty dataframe (should be fast) + def _update_completion_table(self) -> None: + # TODO: what about empty sets - do we assume the export gave 0 results or skip it? + # Is there a difference we could notice? 
(like empty input file vs no file at all) + + if not self.completion_tracking_enabled: + return + + # Create completion rows + batch = formats.Batch( + rows=[ + { + "table_name": output.get_name(self), + "group_name": self.task_config.export_group_name, + "export_time": self.task_config.export_datetime.isoformat(), + } + for output in self.outputs + if not output.get_name(self).startswith("etl__") + ], + schema=completion.completion_schema(), + ) + + # Write it out + formatter = self.task_config.create_formatter(**completion.completion_format_args()) + formatter.write_records(batch) + formatter.finalize() + def _get_formatter(self, table_index: int) -> formats.Format: """ Lazily create output table formatters. @@ -206,14 +253,16 @@ def _get_formatter(self, table_index: int) -> formats.Format: self.formatters[table_index] = self.task_config.create_formatter( table_info.get_name(self), group_field=table_info.group_field, - resource_type=table_info.get_schema(self), + uniqueness_fields=table_info.uniqueness_fields, + update_existing=table_info.update_existing, ) return self.formatters[table_index] - def _uniquify_rows(self, rows: list[dict]) -> list[dict]: + @staticmethod + def _uniquify_rows(rows: list[dict], uniqueness_fields: set[str]) -> list[dict]: """ - Drop duplicates inside the batch to guarantee to the formatter that the "id" column is unique. + Drop duplicates inside the batch to guarantee to the formatter that each row is unique. This does not fix uniqueness across batches, but formatters that care about that can control for it. @@ -226,12 +275,14 @@ def _uniquify_rows(self, rows: list[dict]) -> list[dict]: - Other backends like ndjson can currently just live with duplicates across batches, that's fine. """ id_set = set() + uniqueness_fields = sorted(uniqueness_fields) if uniqueness_fields else ["id"] def is_unique(row): nonlocal id_set - if row["id"] in id_set: + row_id = tuple(row[field] for field in uniqueness_fields) + if row_id in id_set: return False - id_set.add(row["id"]) + id_set.add(row_id) return True return [row for row in rows if is_unique(row)] @@ -241,26 +292,27 @@ def _write_one_table_batch(self, rows: list[dict], table_index: int, batch_index # updated codebook with no data than data with an inaccurate codebook. self.scrubber.save() + output = self.outputs[table_index] formatter = self._get_formatter(table_index) - rows = self._uniquify_rows(rows) + rows = self._uniquify_rows(rows, formatter.uniqueness_fields) groups = self.pop_current_group_values(table_index) - batch = self.make_batch_from_rows(formatter, rows, groups=groups, index=batch_index) + batch = self.make_batch_from_rows(output.get_resource_type(self), rows, groups=groups) # Now we write that batch to the target folder, in the requested format (e.g. ndjson). 
success = formatter.write_records(batch) if not success: # We should write the "bad" batch to the error dir, for later review - self._write_errors(batch) + self._write_errors(batch, batch_index) return success - def _write_errors(self, batch: formats.Batch) -> None: + def _write_errors(self, batch: formats.Batch, batch_index: int) -> None: """Takes the dataframe and writes it to the error dir, if one was provided""" if not self.task_config.dir_errors: return error_root = store.Root(os.path.join(self.task_config.dir_errors, self.name), create=True) - error_path = error_root.joinpath(f"write-error.{batch.index:03}.ndjson") + error_path = error_root.joinpath(f"write-error.{batch_index:03}.ndjson") common.write_rows_to_ndjson(error_path, batch.rows) ########################################################################################## @@ -351,12 +403,12 @@ def pop_current_group_values(self, table_index: int) -> set[str]: return set() @classmethod - def get_schema(cls, formatter: formats.Format, rows: list[dict]) -> pyarrow.Schema | None: + def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema | None: """ Creates a properly-schema'd Table from the provided batch. Can be overridden as needed for non-FHIR outputs. """ - if formatter.resource_type: - return cumulus_fhir_support.pyarrow_schema_from_rows(formatter.resource_type, rows) + if resource_type: + return cumulus_fhir_support.pyarrow_schema_from_rows(resource_type, rows) return None diff --git a/cumulus_etl/etl/tasks/basic_tasks.py b/cumulus_etl/etl/tasks/basic_tasks.py index cb9f08a9..064793f9 100644 --- a/cumulus_etl/etl/tasks/basic_tasks.py +++ b/cumulus_etl/etl/tasks/basic_tasks.py @@ -4,9 +4,10 @@ import logging import os +import pyarrow import rich.progress -from cumulus_etl import common, fhir, store +from cumulus_etl import common, completion, fhir, store from cumulus_etl.etl import tasks @@ -41,10 +42,43 @@ class DocumentReferenceTask(tasks.EtlTask): class EncounterTask(tasks.EtlTask): + """Processes Encounter FHIR resources""" + name = "encounter" resource = "Encounter" tags = {"cpu"} + # Encounters are a little more complicated than normal FHIR resources. + # We also write out a table tying Encounters to a group name, for completion tracking. + + outputs = [ + # Write completion data out first, so that if an encounter is being completion-tracked, + # there's never a gap where it doesn't have an entry. This will help downstream users + # know if an Encounter is tracked or not - by simply looking at this table. + tasks.OutputTable(**completion.completion_encounters_output_args()), + tasks.OutputTable(), + ] + + async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator: + async for encounter in super().read_entries(progress=progress): + if self.completion_tracking_enabled: + completion_info = { + "encounter_id": encounter["id"], + "group_name": self.task_config.export_group_name, + "export_time": self.task_config.export_datetime.isoformat(), + } + else: + completion_info = None + + yield completion_info, encounter + + @classmethod + def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema | None: + if resource_type: + return super().get_schema(resource_type, rows) + else: + return completion.completion_encounters_schema() + class ImmunizationTask(tasks.EtlTask): name = "immunization" @@ -66,8 +100,9 @@ class MedicationRequestTask(tasks.EtlTask): # and many EHRs don't let you simply bulk export them. 
outputs = [ + # Write medication out first, to avoid a moment where links are broken + tasks.OutputTable(name="medication", resource_type="Medication"), tasks.OutputTable(), - tasks.OutputTable(name="medication", schema="Medication"), ] def __init__(self, *args, **kwargs): @@ -132,7 +167,7 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task continue medication = await self.fetch_medication(orig_resource) - yield resource, medication + yield medication, resource class ObservationTask(tasks.EtlTask): diff --git a/cumulus_etl/etl/tasks/nlp_task.py b/cumulus_etl/etl/tasks/nlp_task.py index f7bc65fe..b017367b 100644 --- a/cumulus_etl/etl/tasks/nlp_task.py +++ b/cumulus_etl/etl/tasks/nlp_task.py @@ -19,7 +19,7 @@ class BaseNlpTask(EtlTask): needs_bulk_deid = False # You may want to override these in your subclass - outputs = [OutputTable(schema=None)] # maybe a group_field? (remember to call self.seen_docrefs.add() if so) + outputs = [OutputTable(resource_type=None)] # maybe a group_field? (remember to call self.seen_docrefs.add() if so) tags = {"gpu"} # maybe a study identifier? # Task Version diff --git a/cumulus_etl/formats/base.py b/cumulus_etl/formats/base.py index 6d3af0ec..ee72ae9d 100644 --- a/cumulus_etl/formats/base.py +++ b/cumulus_etl/formats/base.py @@ -2,6 +2,7 @@ import abc import logging +from collections.abc import Collection from cumulus_etl import store from cumulus_etl.formats.batch import Batch @@ -22,7 +23,14 @@ def initialize_class(cls, root: store.Root) -> None: (e.g. some expensive setup that can be shared across per-table format instances, or eventually across threads) """ - def __init__(self, root: store.Root, dbname: str, group_field: str = None, resource_type: str = None): + def __init__( + self, + root: store.Root, + dbname: str, + group_field: str = None, + uniqueness_fields: Collection[str] = None, + update_existing: bool = True, + ): """ Initialize a new Format class :param root: the base location to write data to @@ -31,18 +39,21 @@ def __init__(self, root: store.Root, dbname: str, group_field: str = None, resou deleted -- for example "docref_id" will mean that any existing rows matching docref_id will be deleted before inserting any from this dataframe. Make sure that all records for a given group are in one single dataframe. See the comments for the EtlTask.group_field class attribute for more context. - :param resource_type: the name of the FHIR resource being stored, in case a fuller schema is needed + :param uniqueness_fields: a set of fields that together identify a unique row (defaults to {"id"}) + :param update_existing: whether to update existing rows or (if False) to ignore them and leave them in place """ self.root = root self.dbname = dbname self.group_field = group_field - self.resource_type = resource_type + self.uniqueness_fields = uniqueness_fields or {"id"} + self.update_existing = update_existing def write_records(self, batch: Batch) -> bool: """ Writes a single batch of data to the output root. - The batch must contain a unique (no duplicates) "id" column. + The batch must contain no duplicate rows + (i.e. rows with the same values in all the `uniqueness_fields` columns). :param batch: the batch of data :returns: whether the batch was successfully written diff --git a/cumulus_etl/formats/batch.py b/cumulus_etl/formats/batch.py index e8e60a03..96fada27 100644 --- a/cumulus_etl/formats/batch.py +++ b/cumulus_etl/formats/batch.py @@ -15,11 +15,10 @@ class Batch: - Written to the target location as one piece (e.g. 
one ndjson file or one Delta Lake update chunk) """ - def __init__(self, rows: list[dict], groups: set[str] = None, schema: pyarrow.Schema = None, index: int = 0): + def __init__(self, rows: list[dict], groups: set[str] = None, schema: pyarrow.Schema = None): self.rows = rows # `groups` is the set of the values of the format's `group_field` represented by `rows`. # We can't just get this from rows directly because there might be groups that now have zero entries. # And those won't be in rows, but should still be removed from the database. self.groups = groups self.schema = schema - self.index = index diff --git a/cumulus_etl/formats/batched_files.py b/cumulus_etl/formats/batched_files.py index 2a48fc75..9369a195 100644 --- a/cumulus_etl/formats/batched_files.py +++ b/cumulus_etl/formats/batched_files.py @@ -1,9 +1,10 @@ """An implementation of Format designed to write in batches of files""" import abc +import os import re -from cumulus_etl import errors, store +from cumulus_etl import cli_utils, store from cumulus_etl.formats.base import Format from cumulus_etl.formats.batch import Batch @@ -36,42 +37,49 @@ def write_format(self, batch: Batch, path: str) -> None: # ########################################################################################## + @classmethod + def initialize_class(cls, root: store.Root) -> None: + # The ndjson formatter has a few main use cases: + # - unit testing + # - manual testing + # - an initial ETL run, manual inspection, then converting that to deltalake + # + # In all those use cases, we don't really need to re-use the same directory. + # And re-using the target directory can cause problems: + # - accidentally overriding important data + # - how should we handle the 2nd ETL run writing less / different batched files? + # + # So we just confirm that the output folder is empty - let's avoid the whole thing. + # But we do it in class-init rather than object-init because other tasks will create + # files here during the ETL run. + cli_utils.confirm_dir_is_empty(root) + def __init__(self, *args, **kwargs) -> None: """Performs any preparation before any batches have been written.""" super().__init__(*args, **kwargs) - # Let's clear out any existing files before writing any new ones. - # Note: There is a real issue here where Athena will see invalid results until we've written all - # our files out. Use the deltalake format to get atomic updates. - parent_dir = self.root.joinpath(self.dbname) - self._confirm_no_unknown_files_exist(parent_dir) - try: - self.root.rm(parent_dir, recursive=True) - except FileNotFoundError: - pass + self.dbroot = store.Root(self.root.joinpath(self.dbname)) - def _confirm_no_unknown_files_exist(self, folder: str) -> None: - """ - Errors out if any unknown files exist in the target dir already. + # Grab the next available batch index to write. + # You might wonder why we do this, if we already checked that the output folder is empty + # during class initialization. + # But some output tables (like etl__completion) are written to in many small batches + # spread over the whole ETL run - so we need to support that workflow. + self._index = self._get_next_index() - This is designed to prevent accidents. - """ + def _get_next_index(self) -> int: try: - filenames = [path.split("/")[-1] for path in store.Root(folder).ls()] + basenames = [os.path.basename(path) for path in self.dbroot.ls()] except FileNotFoundError: - return # folder doesn't exist, we're good! 
- - allowed_pattern = re.compile(rf"{self.dbname}\.[0-9]+\.({self.suffix}|meta)") - if not all(map(allowed_pattern.fullmatch, filenames)): - errors.fatal( - f"There are unexpected files in the output folder '{folder}'.\n" - f"Please confirm you are using the right output format.\n" - f"If so, delete the output folder and try again.", - errors.FOLDER_NOT_EMPTY, - ) + return 0 + pattern = re.compile(rf"{self.dbname}\.([0-9]+)\.{self.suffix}") + matches = [pattern.match(basename) for basename in basenames] + numbers = [int(match.group(1)) for match in matches if match] + return max(numbers, default=-1) + 1 def _write_one_batch(self, batch: Batch) -> None: """Writes the whole dataframe to a single file""" - self.root.makedirs(self.root.joinpath(self.dbname)) - full_path = self.root.joinpath(f"{self.dbname}/{self.dbname}.{batch.index:03}.{self.suffix}") + self.root.makedirs(self.dbroot.path) + full_path = self.dbroot.joinpath(f"{self.dbname}.{self._index:03}.{self.suffix}") self.write_format(batch, full_path) + self._index += 1 diff --git a/cumulus_etl/formats/deltalake.py b/cumulus_etl/formats/deltalake.py index ec8d118d..c9267396 100644 --- a/cumulus_etl/formats/deltalake.py +++ b/cumulus_etl/formats/deltalake.py @@ -99,13 +99,19 @@ def update_delta_table(self, updates: pyspark.sql.DataFrame, groups: set[str]) - # Load table -- this will trigger an AnalysisException if the table doesn't exist yet table = delta.DeltaTable.forPath(self.spark, full_path) + # Determine merge condition + conditions = [f"table.{field} = updates.{field}" for field in self.uniqueness_fields] + condition = " AND ".join(conditions) + # Merge in new data merge = ( table.alias("table") - .merge(source=updates.alias("updates"), condition="table.id = updates.id") - .whenMatchedUpdateAll(condition=self._get_match_condition(updates.schema)) + .merge(source=updates.alias("updates"), condition=condition) .whenNotMatchedInsertAll() ) + if self.update_existing: + update_condition = self._get_update_condition(updates.schema) + merge = merge.whenMatchedUpdateAll(condition=update_condition) if self.group_field and groups: # Delete any entries for groups touched by this update that are no longer present in the group @@ -145,7 +151,7 @@ def _table_path(self, dbname: str) -> str: return self.root.joinpath(dbname).replace("s3://", "s3a://") # hadoop uses the s3a: scheme instead of s3: @staticmethod - def _get_match_condition(schema: pyspark.sql.types.StructType) -> str | None: + def _get_update_condition(schema: pyspark.sql.types.StructType) -> str | None: """ Determine what (if any) whenMatchedUpdateAll condition to use for the given update schema. 
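To make the Delta Lake change above concrete, here is a rough sketch of the merge condition it builds when given the etl__completion_encounters settings from completion/output.py (field order is illustrative, since uniqueness_fields is a set):

    # Settings mirroring completion_encounters_output_args()
    uniqueness_fields = {"encounter_id", "group_name"}
    update_existing = False  # keep the first export_time recorded for a group

    condition = " AND ".join(
        f"table.{field} = updates.{field}" for field in sorted(uniqueness_fields)
    )
    # condition == "table.encounter_id = updates.encounter_id AND table.group_name = updates.group_name"

    # With update_existing=False, only whenNotMatchedInsertAll() is attached to the merge builder,
    # so rows already present in the lake keep their original export_time.
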
diff --git a/cumulus_etl/loaders/base.py b/cumulus_etl/loaders/base.py index 858fd70e..c2f555a8 100644 --- a/cumulus_etl/loaders/base.py +++ b/cumulus_etl/loaders/base.py @@ -21,6 +21,10 @@ def __init__(self, root: store.Root): """ self.root = root + # Public properties (potentially set when loading) for reporting back to caller + self.group_name = None + self.export_datetime = None + @abc.abstractmethod async def load_all(self, resources: list[str]) -> common.Directory: """ diff --git a/cumulus_etl/loaders/fhir/bulk_export.py b/cumulus_etl/loaders/fhir/bulk_export.py index 14997358..64b502e5 100644 --- a/cumulus_etl/loaders/fhir/bulk_export.py +++ b/cumulus_etl/loaders/fhir/bulk_export.py @@ -1,6 +1,7 @@ """Support for FHIR bulk exports""" import asyncio +import datetime import json import os import urllib.parse @@ -54,6 +55,17 @@ def __init__( self._since = since self._until = until + # Public properties, to be read after the export: + self.export_datetime = None + + if "/Group/" in self._url: + latter_half = self._url.split("/Group/", 2)[-1] + self.group_name = latter_half.split("/")[0] + else: + # Global exports don't seem realistic, but if the user does one, + # we'll use the empty string as the default group name. + self.group_name = "" + async def export(self) -> None: """ Bulk export resources from a FHIR server into local ndjson files. @@ -100,6 +112,8 @@ async def export(self) -> None: # Finished! We're done waiting and can download all the files response_json = response.json() + self.export_datetime = datetime.datetime.fromisoformat(response_json["transactionTime"]) + # Were there any server-side errors during the export? # The spec acknowledges that "error" is perhaps misleading for an array that can contain info messages. error_texts, warning_texts = await self._gather_all_messages(response_json.get("error", [])) diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py index ed4170d4..4f14f741 100644 --- a/cumulus_etl/loaders/fhir/ndjson_loader.py +++ b/cumulus_etl/loaders/fhir/ndjson_loader.py @@ -74,6 +74,11 @@ async def _load_from_bulk_export(self, resources: list[str]) -> common.Directory self.client, resources, self.root.path, target_dir.name, self.since, self.until ) await bulk_exporter.export() + + # Copy back these settings from the export + self.group_name = bulk_exporter.group_name + self.export_datetime = bulk_exporter.export_datetime + except errors.FatalError as exc: errors.fatal(str(exc), errors.BULK_EXPORT_FAILED) diff --git a/docs/setup/cumulus-aws-template.yaml b/docs/setup/cumulus-aws-template.yaml index 319f289c..4b433ca9 100644 --- a/docs/setup/cumulus-aws-template.yaml +++ b/docs/setup/cumulus-aws-template.yaml @@ -198,6 +198,8 @@ Resources: - !Sub "s3://${S3Bucket}/${EtlSubdir}/servicerequest" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results" - !Sub "s3://${S3Bucket}/${EtlSubdir}/covid_symptom__nlp_results_term_exists" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/etl__completion" + - !Sub "s3://${S3Bucket}/${EtlSubdir}/etl__completion_encounters" CreateNativeDeltaTable: True WriteManifest: False diff --git a/tests/convert/test_convert_cli.py b/tests/convert/test_convert_cli.py index 4cc47fcc..b825958e 100644 --- a/tests/convert/test_convert_cli.py +++ b/tests/convert/test_convert_cli.py @@ -29,7 +29,14 @@ def prepare_original_dir(self) -> str: """Returns the job timestamp used, for easier inspection""" # Fill in original dir, including a non-default output folder
shutil.copytree(f"{self.datadir}/simple/output", self.original_path) - shutil.copytree(f"{self.datadir}/covid/term-exists", self.original_path, dirs_exist_ok=True) + shutil.copytree( + f"{self.datadir}/covid/term-exists/covid_symptom__nlp_results_term_exists", + f"{self.original_path}/covid_symptom__nlp_results_term_exists", + ) + shutil.copyfile( + f"{self.datadir}/covid/term-exists/etl__completion/etl__completion.000.ndjson", + f"{self.original_path}/etl__completion/etl__completion.covid.ndjson", + ) os.makedirs(f"{self.original_path}/ignored") # just to confirm we only copy what we understand job_timestamp = "2023-02-28__19.53.08" @@ -87,6 +94,13 @@ async def test_happy_path(self): symptoms = utils.read_delta_lake(f"{self.target_path}/covid_symptom__nlp_results_term_exists") # and covid self.assertEqual(2, len(symptoms)) self.assertEqual("for", symptoms[0]["match"]["text"]) + completion = utils.read_delta_lake(f"{self.target_path}/etl__completion") # and completion + self.assertEqual(14, len(completion)) + self.assertEqual("allergyintolerance", completion[0]["table_name"]) + comp_enc = utils.read_delta_lake(f"{self.target_path}/etl__completion_encounters") + self.assertEqual(2, len(comp_enc)) + self.assertEqual("08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", comp_enc[0]["encounter_id"]) + self.assertEqual("2020-10-13T12:00:20-05:00", comp_enc[0]["export_time"]) # Now make a second small, partial output folder to layer into the existing Delta Lake delta_timestamp = "2023-02-29__19.53.08" @@ -95,6 +109,18 @@ async def test_happy_path(self): with common.NdjsonWriter(f"{delta_path}/patient/new.ndjson") as writer: writer.write({"resourceType": "Patient", "id": "1de9ea66-70d3-da1f-c735-df5ef7697fb9", "birthDate": "1800"}) writer.write({"resourceType": "Patient", "id": "z-gen", "birthDate": "2005"}) + os.makedirs(f"{delta_path}/etl__completion_encounters") + with common.NdjsonWriter(f"{delta_path}/etl__completion_encounters/new.ndjson") as writer: + # Newer timestamp for the existing row + writer.write( + { + "encounter_id": "08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", + "group_name": "test-group", + "export_time": "2021-10-13T17:00:20+00:00", + } + ) + # Totally new encounter + writer.write({"encounter_id": "NEW", "group_name": "NEW", "export_time": "2021-12-12T17:00:20+00:00"}) delta_config_dir = f"{delta_path}/JobConfig/{delta_timestamp}" os.makedirs(delta_config_dir) common.write_json(f"{delta_config_dir}/job_config.json", {"delta": "yup"}) @@ -113,6 +139,14 @@ async def test_happy_path(self): conditions = utils.read_delta_lake(f"{self.target_path}/condition") # and conditions shouldn't change at all self.assertEqual(2, len(conditions)) self.assertEqual("2010-03-02", conditions[0]["recordedDate"]) + comp_enc = utils.read_delta_lake( + f"{self.target_path}/etl__completion_encounters" + ) # and *some* enc mappings did + self.assertEqual(3, len(comp_enc)) + self.assertEqual("08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", comp_enc[0]["encounter_id"]) + self.assertEqual("2020-10-13T12:00:20-05:00", comp_enc[0]["export_time"]) # confirm this *didn't* get updated + self.assertEqual("NEW", comp_enc[1]["encounter_id"]) + self.assertEqual("2021-12-12T17:00:20+00:00", comp_enc[1]["export_time"]) # but the new row did get inserted @mock.patch("cumulus_etl.formats.Format.write_records") async def test_batch_metadata(self, mock_write): diff --git a/tests/data/covid/output/etl__completion/etl__completion.000.ndjson b/tests/data/covid/output/etl__completion/etl__completion.000.ndjson new file mode 100644 index 
00000000..1fa8285e --- /dev/null +++ b/tests/data/covid/output/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "covid_symptom__nlp_results", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/covid/term-exists/etl__completion/etl__completion.000.ndjson b/tests/data/covid/term-exists/etl__completion/etl__completion.000.ndjson new file mode 100644 index 00000000..7cf8331e --- /dev/null +++ b/tests/data/covid/term-exists/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "covid_symptom__nlp_results_term_exists", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.000.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.000.ndjson new file mode 100644 index 00000000..7cfd70e5 --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "encounter", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.001.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.001.ndjson new file mode 100644 index 00000000..5d6ada9f --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.001.ndjson @@ -0,0 +1 @@ +{"table_name": "patient", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.002.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.002.ndjson new file mode 100644 index 00000000..0d222bb9 --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.002.ndjson @@ -0,0 +1 @@ +{"table_name": "condition", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.003.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.003.ndjson new file mode 100644 index 00000000..af658bc9 --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.003.ndjson @@ -0,0 +1 @@ +{"table_name": "documentreference", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.004.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.004.ndjson new file mode 100644 index 00000000..6b54b0a4 --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.004.ndjson @@ -0,0 +1,2 @@ +{"table_name": "medication", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} +{"table_name": "medicationrequest", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion/etl__completion.005.ndjson b/tests/data/i2b2/output/etl__completion/etl__completion.005.ndjson new file mode 100644 index 00000000..f40f112a --- /dev/null +++ b/tests/data/i2b2/output/etl__completion/etl__completion.005.ndjson @@ -0,0 +1 @@ +{"table_name": "observation", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/i2b2/output/etl__completion_encounters/etl__completion_encounters.000.ndjson b/tests/data/i2b2/output/etl__completion_encounters/etl__completion_encounters.000.ndjson new file mode 100644 index 00000000..d1f17a75 --- /dev/null +++ b/tests/data/i2b2/output/etl__completion_encounters/etl__completion_encounters.000.ndjson @@ -0,0 +1,2 @@ +{"encounter_id": 
"82ebb5bb976239c0a7c3b37f50362b58b9c210b35753bb82fbb477f93c43b423", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} +{"encounter_id": "fb29ea2a68ca2e1e4bbe22bdeedf021d94ec89f7e3d38ecbe908a8f2b3d89687", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.000.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.000.ndjson new file mode 100644 index 00000000..7cfd70e5 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "encounter", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.001.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.001.ndjson new file mode 100644 index 00000000..5d6ada9f --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.001.ndjson @@ -0,0 +1 @@ +{"table_name": "patient", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.002.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.002.ndjson new file mode 100644 index 00000000..0d222bb9 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.002.ndjson @@ -0,0 +1 @@ +{"table_name": "condition", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.003.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.003.ndjson new file mode 100644 index 00000000..af658bc9 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.003.ndjson @@ -0,0 +1 @@ +{"table_name": "documentreference", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.004.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.004.ndjson new file mode 100644 index 00000000..6b54b0a4 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.004.ndjson @@ -0,0 +1,2 @@ +{"table_name": "medication", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} +{"table_name": "medicationrequest", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.005.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.005.ndjson new file mode 100644 index 00000000..f40f112a --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.005.ndjson @@ -0,0 +1 @@ +{"table_name": "observation", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.006.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.006.ndjson new file mode 100644 index 00000000..2de4de83 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.006.ndjson @@ -0,0 +1 @@ +{"table_name": "procedure", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion/etl__completion.007.ndjson b/tests/data/simple/batched-output/etl__completion/etl__completion.007.ndjson 
new file mode 100644 index 00000000..3c343dee --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion/etl__completion.007.ndjson @@ -0,0 +1 @@ +{"table_name": "servicerequest", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.000.ndjson b/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.000.ndjson new file mode 100644 index 00000000..4343588e --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.000.ndjson @@ -0,0 +1 @@ +{"encounter_id": "08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} diff --git a/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.001.ndjson b/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.001.ndjson new file mode 100644 index 00000000..cd966a82 --- /dev/null +++ b/tests/data/simple/batched-output/etl__completion_encounters/etl__completion_encounters.001.ndjson @@ -0,0 +1 @@ +{"encounter_id": "af1e6186-3f9a-1fa9-3c73-cfa56c84a056", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.000.ndjson b/tests/data/simple/output/etl__completion/etl__completion.000.ndjson new file mode 100644 index 00000000..7cfd70e5 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.000.ndjson @@ -0,0 +1 @@ +{"table_name": "encounter", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.001.ndjson b/tests/data/simple/output/etl__completion/etl__completion.001.ndjson new file mode 100644 index 00000000..5d6ada9f --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.001.ndjson @@ -0,0 +1 @@ +{"table_name": "patient", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.002.ndjson b/tests/data/simple/output/etl__completion/etl__completion.002.ndjson new file mode 100644 index 00000000..baaa7357 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.002.ndjson @@ -0,0 +1 @@ +{"table_name": "allergyintolerance", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.003.ndjson b/tests/data/simple/output/etl__completion/etl__completion.003.ndjson new file mode 100644 index 00000000..0d222bb9 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.003.ndjson @@ -0,0 +1 @@ +{"table_name": "condition", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.004.ndjson b/tests/data/simple/output/etl__completion/etl__completion.004.ndjson new file mode 100644 index 00000000..61934fef --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.004.ndjson @@ -0,0 +1 @@ +{"table_name": "device", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.005.ndjson b/tests/data/simple/output/etl__completion/etl__completion.005.ndjson new file mode 100644 index 00000000..ecaca931 --- /dev/null +++ 
b/tests/data/simple/output/etl__completion/etl__completion.005.ndjson @@ -0,0 +1 @@ +{"table_name": "diagnosticreport", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.006.ndjson b/tests/data/simple/output/etl__completion/etl__completion.006.ndjson new file mode 100644 index 00000000..af658bc9 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.006.ndjson @@ -0,0 +1 @@ +{"table_name": "documentreference", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.007.ndjson b/tests/data/simple/output/etl__completion/etl__completion.007.ndjson new file mode 100644 index 00000000..786ef05d --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.007.ndjson @@ -0,0 +1 @@ +{"table_name": "immunization", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.008.ndjson b/tests/data/simple/output/etl__completion/etl__completion.008.ndjson new file mode 100644 index 00000000..6b54b0a4 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.008.ndjson @@ -0,0 +1,2 @@ +{"table_name": "medication", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} +{"table_name": "medicationrequest", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.009.ndjson b/tests/data/simple/output/etl__completion/etl__completion.009.ndjson new file mode 100644 index 00000000..f40f112a --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.009.ndjson @@ -0,0 +1 @@ +{"table_name": "observation", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.010.ndjson b/tests/data/simple/output/etl__completion/etl__completion.010.ndjson new file mode 100644 index 00000000..2de4de83 --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.010.ndjson @@ -0,0 +1 @@ +{"table_name": "procedure", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion/etl__completion.011.ndjson b/tests/data/simple/output/etl__completion/etl__completion.011.ndjson new file mode 100644 index 00000000..3c343dee --- /dev/null +++ b/tests/data/simple/output/etl__completion/etl__completion.011.ndjson @@ -0,0 +1 @@ +{"table_name": "servicerequest", "group_name": "test-group", "export_time": "2020-10-13T12:00:20-05:00"} diff --git a/tests/data/simple/output/etl__completion_encounters/etl__completion_encounters.000.ndjson b/tests/data/simple/output/etl__completion_encounters/etl__completion_encounters.000.ndjson new file mode 100644 index 00000000..44531a1c --- /dev/null +++ b/tests/data/simple/output/etl__completion_encounters/etl__completion_encounters.000.ndjson @@ -0,0 +1,2 @@ +{"encounter_id": "08f0ebd4-950c-ddd9-ce97-b5bdf073eed1", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} +{"encounter_id": "af1e6186-3f9a-1fa9-3c73-cfa56c84a056", "export_time": "2020-10-13T12:00:20-05:00", "group_name": "test-group"} diff --git a/tests/etl/base.py b/tests/etl/base.py index 7062e499..7c471e06 100644 --- a/tests/etl/base.py +++ b/tests/etl/base.py @@ -1,5 +1,6 @@ """Base classes for ETL-oriented tests""" +import datetime 
import os import shutil import tempfile @@ -31,12 +32,12 @@ def setUp(self): self.root_path = os.path.join(self.datadir, self.DATA_ROOT) self.input_path = os.path.join(self.root_path, "input") - tmpdir = tempfile.mkdtemp() + self.tmpdir = tempfile.mkdtemp() # Comment out this next line when debugging, to persist directory - self.addCleanup(shutil.rmtree, tmpdir) + self.addCleanup(shutil.rmtree, self.tmpdir) - self.output_path = os.path.join(tmpdir, "output") - self.phi_path = os.path.join(tmpdir, "phi") + self.output_path = os.path.join(self.tmpdir, "output") + self.phi_path = os.path.join(self.tmpdir, "phi") self.enforce_consistent_uuids() @@ -54,6 +55,9 @@ async def run_etl( errors_to=None, export_to: str = None, input_format: str = "ndjson", + export_group: str = "test-group", + export_timestamp: str = "2020-10-13T12:00:20-05:00", + write_completion: bool = True, ) -> None: args = [ input_path or self.input_path, @@ -63,6 +67,12 @@ async def run_etl( f"--input-format={input_format}", f"--ctakes-overrides={self.ctakes_overrides.name}", ] + if export_group is not None: + args.append(f"--export-group={export_group}") + if export_timestamp: + args.append(f"--export-timestamp={export_timestamp}") + if write_completion: + args.append("--write-completion") if output_format: args.append(f"--output-format={output_format}") if comment: @@ -103,6 +113,7 @@ def setUp(self) -> None: client = fhir.FhirClient("http://localhost/", []) self.tmpdir = self.make_tempdir() self.input_dir = os.path.join(self.tmpdir, "input") + self.output_dir = os.path.join(self.tmpdir, "output") self.phi_dir = os.path.join(self.tmpdir, "phi") self.errors_dir = os.path.join(self.tmpdir, "errors") os.makedirs(self.input_dir) @@ -111,17 +122,19 @@ def setUp(self) -> None: self.job_config = JobConfig( self.input_dir, self.input_dir, - self.tmpdir, + self.output_dir, self.phi_dir, "ndjson", "ndjson", client, batch_size=5, dir_errors=self.errors_dir, + export_group_name="test-group", + export_datetime=datetime.datetime(2012, 10, 10, 5, 30, 12, tzinfo=datetime.timezone.utc), ) - def make_formatter(dbname: str, group_field: str = None, resource_type: str = None): - formatter = mock.MagicMock(dbname=dbname, group_field=group_field, resource_type=resource_type) + def make_formatter(dbname: str, **kwargs): + formatter = mock.MagicMock(dbname=dbname, **kwargs) self.format_count += 1 if self.format_count == 1: self.format = self.format or formatter @@ -129,11 +142,15 @@ def make_formatter(dbname: str, group_field: str = None, resource_type: str = No elif self.format_count == 2: self.format2 = self.format2 or formatter return self.format2 + elif self.format_count == 3: + self.format3 = self.format3 or formatter + return self.format3 else: return formatter # stop keeping track self.format = None - self.format2 = None # for tasks that have multiple output streams + self.format2 = None # etl__completion (or second output table for a few tasks) + self.format3 = None # etl__completion for double-table tasks self.format_count = 0 self.create_formatter_mock = mock.MagicMock(side_effect=make_formatter) self.job_config.create_formatter = self.create_formatter_mock diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py index 4d45f6d7..b9db6633 100644 --- a/tests/etl/test_etl_cli.py +++ b/tests/etl/test_etl_cli.py @@ -1,5 +1,6 @@ """Tests for etl/cli.py""" +import datetime import itertools import json import os @@ -105,7 +106,7 @@ def fake_load_all(internal_self, resources): await self.run_etl(tasks=["observation"]) # Confirm we only 
wrote the one resource - self.assertEqual({"observation", "JobConfig"}, set(os.listdir(self.output_path))) + self.assertEqual({"etl__completion", "observation", "JobConfig"}, set(os.listdir(self.output_path))) self.assertEqual(["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation"))) async def test_multiple_tasks(self): @@ -121,8 +122,8 @@ def fake_load_all(internal_self, resources): with mock.patch.object(loaders.FhirNdjsonLoader, "load_all", new=fake_load_all): await self.run_etl(tasks=["observation", "patient"]) - # Confirm we only wrote the one resource - self.assertEqual({"observation", "patient", "JobConfig"}, set(os.listdir(self.output_path))) + # Confirm we only wrote the two resources + self.assertEqual({"etl__completion", "observation", "patient", "JobConfig"}, set(os.listdir(self.output_path))) self.assertEqual(["observation.000.ndjson"], os.listdir(os.path.join(self.output_path, "observation"))) self.assertEqual(["patient.000.ndjson"], os.listdir(os.path.join(self.output_path, "patient"))) @@ -158,8 +159,8 @@ async def test_errors_to_must_be_empty(self): async def test_errors_to_passed_to_tasks(self): with self.assertRaises(SystemExit): with mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job: - await self.run_etl(errors_to=f"{self.output_path}/errors") - self.assertEqual(mock_etl_job.call_args[0][0].dir_errors, f"{self.output_path}/errors") + await self.run_etl(errors_to=f"{self.tmpdir}/errors") + self.assertEqual(mock_etl_job.call_args[0][0].dir_errors, f"{self.tmpdir}/errors") @respx.mock async def test_bulk_no_auth(self): @@ -172,6 +173,59 @@ async def test_bulk_no_auth(self): await self.run_etl(input_path="https://localhost:12345/", tasks=["patient"]) self.assertEqual(errors.FHIR_AUTH_FAILED, cm.exception.code) + @ddt.data( + ( + {"export_group": "X", "export_timestamp": "2020-01-02", "write_completion": True}, + ("L", datetime.datetime(2010, 12, 12)), + ("X", datetime.datetime(2020, 1, 2)), + ), + ( + {"export_group": "X", "export_timestamp": "2020-01-02", "write_completion": False}, + ("L", datetime.datetime(2010, 12, 12)), + (None, None), + ), + ( + {"export_group": None, "export_timestamp": None, "write_completion": True}, + ("L", datetime.datetime(2010, 12, 12)), + ("L", datetime.datetime(2010, 12, 12)), + ), + ( + {"export_group": "X", "export_timestamp": None, "write_completion": True}, + (None, None), + None, # errors out + ), + ( + {"export_group": None, "export_timestamp": "2020-01-02", "write_completion": True}, + (None, None), + None, # errors out + ), + ) + @ddt.unpack + async def test_completion_args(self, etl_args, loader_vals, expected_vals): + """Verify that we parse completion args with the correct fallbacks and checks.""" + # Grab all observations before we mock anything + observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(["Observation"]) + + def fake_load_all(internal_self, resources): + del resources + internal_self.group_name = loader_vals[0] + internal_self.export_datetime = loader_vals[1] + return observations + + with ( + self.assertRaises(SystemExit) as cm, + mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job, + mock.patch.object(loaders.FhirNdjsonLoader, "load_all", new=fake_load_all), + ): + await self.run_etl(tasks=["observation"], **etl_args) + + if expected_vals is None: + self.assertEqual(errors.COMPLETION_ARG_MISSING, cm.exception.code) + else: + config = mock_etl_job.call_args[0][0] + self.assertEqual(expected_vals[0], 
config.export_group_name) + self.assertEqual(expected_vals[1], config.export_datetime) + class TestEtlJobConfig(BaseEtlSimple): """Test case for the job config logging data""" @@ -185,11 +239,30 @@ def read_config_file(self, name: str) -> dict: with open(full_path, "r", encoding="utf8") as f: return json.load(f) - async def test_comment(self): - """Verify that a comment makes it from command line to the log file""" - await self.run_etl(comment="Run by foo on machine bar") + async def test_serialization(self): + """Verify that everything makes it from command line to the log file""" + await self.run_etl( + batch_size=100, + comment="Run by foo on machine bar", + tasks=["condition", "patient"], + ) config_file = self.read_config_file("job_config.json") - self.assertEqual(config_file["comment"], "Run by foo on machine bar") + self.assertEqual( + { + "dir_input": self.input_path, + "dir_output": self.output_path, + "dir_phi": self.phi_path, + "path": f"{self.job_config_path}/job_config.json", + "input_format": "ndjson", + "output_format": "ndjson", + "comment": "Run by foo on machine bar", + "batch_size": 100, + "tasks": "patient,condition", + "export_group_name": "test-group", + "export_timestamp": "2020-10-13T12:00:20-05:00", + }, + config_file, + ) class TestEtlJobContext(BaseEtlSimple): @@ -290,6 +363,9 @@ async def test_etl_job_s3(self): self.assertEqual( { "mockbucket/root/condition/condition.000.ndjson", + "mockbucket/root/etl__completion/etl__completion.000.ndjson", + "mockbucket/root/etl__completion/etl__completion.001.ndjson", + "mockbucket/root/etl__completion/etl__completion.002.ndjson", "mockbucket/root/medication/medication.000.ndjson", "mockbucket/root/medicationrequest/medicationrequest.000.ndjson", "mockbucket/root/patient/patient.000.ndjson", diff --git a/tests/etl/test_tasks.py b/tests/etl/test_tasks.py index 799b7fc9..ebffdde6 100644 --- a/tests/etl/test_tasks.py +++ b/tests/etl/test_tasks.py @@ -123,6 +123,159 @@ async def test_batch_is_given_schema(self): self.assertIn("id", schema.names) +@ddt.ddt +class TestTaskCompletion(TaskTestCase): + """Tests for etl__completion* handling""" + + async def test_encounter_completion(self): + """Verify that we write out completion data correctly""" + self.make_json("Encounter.1", "A") + self.make_json("Encounter.2", "B") + self.make_json("Encounter.3", "C") + self.job_config.batch_size = 4 # two encounters at a time (each encounter makes 2 rows) + + await basic_tasks.EncounterTask(self.job_config, self.scrubber).run() + + comp_enc_format = self.format # etl__completion_encounters + enc_format = self.format2 # encounter + comp_format = self.format3 # etl__completion + + self.assertEqual("etl__completion_encounters", comp_enc_format.dbname) + self.assertEqual({"encounter_id", "group_name"}, comp_enc_format.uniqueness_fields) + self.assertFalse(comp_enc_format.update_existing) + + self.assertEqual("etl__completion", comp_format.dbname) + self.assertEqual({"table_name", "group_name"}, comp_format.uniqueness_fields) + self.assertTrue(comp_format.update_existing) + + self.assertEqual(2, comp_enc_format.write_records.call_count) + self.assertEqual(2, enc_format.write_records.call_count) + self.assertEqual(1, comp_format.write_records.call_count) + + comp_enc_batches = [call[0][0] for call in comp_enc_format.write_records.call_args_list] + self.assertEqual( + [ + { + "encounter_id": self.codebook.db.encounter("A"), + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + }, + { + "encounter_id": 
self.codebook.db.encounter("B"), + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + }, + ], + comp_enc_batches[0].rows, + ) + self.assertEqual( + [ + { + "encounter_id": self.codebook.db.encounter("C"), + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + } + ], + comp_enc_batches[1].rows, + ) + + comp_batch = comp_format.write_records.call_args[0][0] + self.assertEqual( + [ + { + "table_name": "encounter", + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + } + ], + comp_batch.rows, + ) + + async def test_medication_completion(self): + """ + Verify that we write out Medication completion too. + + We just want to verify that we handle multi-output tasks. + """ + self.make_json("MedicationRequest.1", "A") + + await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() + + med_req_format = self.format # MedicationRequest + med_format = self.format2 # Medication (second because no content, touched in finalize) + comp_format = self.format3 # etl__completion + + self.assertEqual("medication", med_format.dbname) + self.assertEqual("medicationrequest", med_req_format.dbname) + self.assertEqual("etl__completion", comp_format.dbname) + + self.assertEqual(1, med_format.write_records.call_count) + self.assertEqual(1, med_req_format.write_records.call_count) + self.assertEqual(1, comp_format.write_records.call_count) + + comp_batch = comp_format.write_records.call_args[0][0] + self.assertEqual( + [ + { + "table_name": "medication", + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + }, + { + "table_name": "medicationrequest", + "group_name": "test-group", + "export_time": "2012-10-10T05:30:12+00:00", + }, + ], + comp_batch.rows, + ) + + @ddt.data("export_datetime", "export_group_name") + async def test_completion_disabled(self, null_field): + """Verify that we don't write completion data if we don't have args for it""" + self.make_json("Encounter.1", "A") + setattr(self.job_config, null_field, None) + + await basic_tasks.EncounterTask(self.job_config, self.scrubber).run() + + # This order is unusual - normally `encounter` is second, + # but because there is no content for etl__completion_encounters, + # it is only touched during task finalization, so it goes second instead. 
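+ # (Contrast test_encounter_completion above, where etl__completion_encounters has rows to write and is therefore created first.)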
+ enc_format = self.format # encounter + comp_enc_format = self.format2 # etl__completion_encounters + comp_format = self.format3 # etl__completion + + self.assertEqual("encounter", enc_format.dbname) + self.assertEqual("etl__completion_encounters", comp_enc_format.dbname) + self.assertIsNone(comp_format) # tasks don't create this when completion is disabled + + self.assertEqual(1, comp_enc_format.write_records.call_count) + self.assertEqual(1, enc_format.write_records.call_count) + + self.assertEqual([], comp_enc_format.write_records.call_args[0][0].rows) + + async def test_allow_empty_group(self): + """Empty groups are (rarely) used to mark a server-wide global export""" + self.make_json("Device.1", "A") + self.job_config.export_group_name = "" + + await basic_tasks.DeviceTask(self.job_config, self.scrubber).run() + + comp_format = self.format2 # etl__completion + + self.assertEqual(1, comp_format.write_records.call_count) + self.assertEqual( + [ + { + "table_name": "device", + "group_name": "", + "export_time": "2012-10-10T05:30:12+00:00", + } + ], + comp_format.write_records.call_args[0][0].rows, + ) + + @ddt.ddt class TestMedicationRequestTask(TaskTestCase): """Test case for MedicationRequestTask, which has some extra logic than normal FHIR resources""" @@ -134,8 +287,11 @@ async def test_inline_codes(self): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() - self.assertEqual(1, self.format.write_records.call_count) - batch = self.format.write_records.call_args[0][0] + med_req_format = self.format + med_format = self.format2 # second because it's empty + + self.assertEqual(1, med_req_format.write_records.call_count) + batch = med_req_format.write_records.call_args[0][0] self.assertEqual( { self.codebook.db.resource_hash("InlineCode"), @@ -145,8 +301,8 @@ async def test_inline_codes(self): ) # Confirm we wrote an empty dataframe to the medication table - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual([], batch.rows) async def test_contained_medications(self): @@ -155,14 +311,17 @@ async def test_contained_medications(self): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() + med_req_format = self.format + med_format = self.format2 # second because it's empty + # Confirm we wrote the basic MedicationRequest - self.assertEqual(1, self.format.write_records.call_count) - batch = self.format.write_records.call_args[0][0] + self.assertEqual(1, med_req_format.write_records.call_count) + batch = med_req_format.write_records.call_args[0][0] self.assertEqual(f'#{self.codebook.db.resource_hash("123")}', batch.rows[0]["medicationReference"]["reference"]) # Confirm we wrote an empty dataframe to the medication table - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual(0, len(batch.rows)) @mock.patch("cumulus_etl.fhir.download_reference") @@ -173,19 +332,33 @@ async def test_external_medications(self, mock_download): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() - # Confirm we made both formatters correctly - self.assertEqual(2, self.create_formatter_mock.call_count) + med_format = self.format + med_req_format = self.format2 + + # Confirm 
we made all table formatters correctly + self.assertEqual(3, self.create_formatter_mock.call_count) self.assertEqual( [ - mock.call("medicationrequest", group_field=None, resource_type="MedicationRequest"), - mock.call("medication", group_field=None, resource_type="Medication"), + mock.call( + "medication", + group_field=None, + uniqueness_fields=None, + update_existing=True, + ), + mock.call( + "medicationrequest", + group_field=None, + uniqueness_fields=None, + update_existing=True, + ), + mock.call(dbname="etl__completion", uniqueness_fields={"group_name", "table_name"}), ], self.create_formatter_mock.call_args_list, ) # Confirm we wrote the basic MedicationRequest - self.assertEqual(1, self.format.write_records.call_count) - batch = self.format.write_records.call_args[0][0] + self.assertEqual(1, med_req_format.write_records.call_count) + batch = med_req_format.write_records.call_args[0][0] self.assertEqual([self.codebook.db.resource_hash("A")], [row["id"] for row in batch.rows]) self.assertEqual( f'Medication/{self.codebook.db.resource_hash("123")}', @@ -193,8 +366,8 @@ async def test_external_medications(self, mock_download): ) # AND that we wrote the downloaded resource! - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual( [self.codebook.db.resource_hash("med1")], [row["id"] for row in batch.rows] ) # meds should be scrubbed too @@ -214,9 +387,11 @@ async def test_external_medication_scrubbed(self, mock_download): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() + med_format = self.format + # Check result - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual( { "resourceType": "Medication", @@ -240,14 +415,17 @@ async def test_external_medications_with_error(self, mock_download): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() + med_format = self.format + med_req_format = self.format2 + # Confirm we still wrote out all three request resources - self.assertEqual(1, self.format.write_records.call_count) - batch = self.format.write_records.call_args[0][0] + self.assertEqual(1, med_req_format.write_records.call_count) + batch = med_req_format.write_records.call_args[0][0] self.assertEqual(3, len(batch.rows)) # Confirm we still wrote out the medication for B - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual([self.codebook.db.resource_hash("medB")], [row["id"] for row in batch.rows]) # And we saved the error? 
@@ -272,6 +450,8 @@ async def test_external_medications_skips_duplicates(self, mock_download): await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() + med_format = self.format + # Confirm we only called the download method twice self.assertEqual( [ @@ -282,10 +462,10 @@ async def test_external_medications_skips_duplicates(self, mock_download): ) # Confirm we wrote just the downloaded resources, and didn't repeat the dup at all - self.assertEqual(2, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args_list[0][0][0] + self.assertEqual(2, med_format.write_records.call_count) + batch = med_format.write_records.call_args_list[0][0][0] self.assertEqual([self.codebook.db.resource_hash("dup")], [row["id"] for row in batch.rows]) - batch = self.format2.write_records.call_args_list[1][0][0] + batch = med_format.write_records.call_args_list[1][0][0] self.assertEqual([self.codebook.db.resource_hash("new")], [row["id"] for row in batch.rows]) @mock.patch("cumulus_etl.fhir.download_reference") @@ -307,6 +487,8 @@ async def test_external_medications_skips_unknown_modifiers(self, mock_download) await basic_tasks.MedicationRequestTask(self.job_config, self.scrubber).run() - self.assertEqual(1, self.format2.write_records.call_count) - batch = self.format2.write_records.call_args[0][0] + med_format = self.format + + self.assertEqual(1, med_format.write_records.call_count) + batch = med_format.write_records.call_args[0][0] self.assertEqual([self.codebook.db.resource_hash("good")], [row["id"] for row in batch.rows]) # no "odd" diff --git a/tests/formats/test_deltalake.py b/tests/formats/test_deltalake.py index b9b9754e..39a11805 100644 --- a/tests/formats/test_deltalake.py +++ b/tests/formats/test_deltalake.py @@ -57,10 +57,9 @@ def get_spark_schema(self, df: pyspark.sql.DataFrame) -> str: def store( self, rows: list[dict], - batch_index: int = 10, schema: pyarrow.Schema = None, - group_field: str = None, groups: set[str] = None, + **kwargs, ) -> bool: """ Writes a single batch of data to the data lake. 
@@ -68,16 +67,15 @@ def store( :param rows: the data to insert - :param batch_index: which batch number this is, defaulting to 10 to avoid triggering any first/last batch logic :param schema: the batch schema, in pyarrow format - :param group_field: a group field name, used to delete non-matching group rows :param groups: all group values for this batch (ignored if group_field is not set) """ - deltalake = DeltaLakeFormat(self.root, "patient", group_field=group_field) - batch = formats.Batch(rows, groups=groups, index=batch_index, schema=schema) + deltalake = DeltaLakeFormat(self.root, "patient", **kwargs) + batch = formats.Batch(rows, groups=groups, schema=schema) return deltalake.write_records(batch) def assert_lake_equal(self, rows: list[dict]) -> None: table_path = os.path.join(self.output_dir, "patient") - rows_by_id = sorted(rows, key=lambda x: x["id"]) + rows_by_id = sorted(rows, key=lambda x: x.get("id", sorted(x.items()))) self.assertListEqual(rows_by_id, utils.read_delta_lake(table_path)) def test_creates_if_empty(self): @@ -349,3 +347,29 @@ def test_group_field(self): d={"group": 'D"', "val": 3}, ) ) + + def test_custom_uniqueness(self): + """Verify that `uniqueness_fields` is properly handled.""" + ids = {"F1", "F2"} + self.store( + [ + {"F1": 1, "F2": 2, "msg": "original value"}, + {"F1": 1, "F2": 9, "msg": "same F1"}, + {"F1": 9, "F2": 2, "msg": "same F2"}, + ], + uniqueness_fields=ids, + ) + self.store([{"F1": 1, "F2": 2, "msg": "new"}], uniqueness_fields=ids) + self.assert_lake_equal( + [ + {"F1": 1, "F2": 2, "msg": "new"}, + {"F1": 1, "F2": 9, "msg": "same F1"}, + {"F1": 9, "F2": 2, "msg": "same F2"}, + ] + ) + + def test_update_existing(self): + """Verify that `update_existing` is properly handled.""" + self.store(self.df(a=1, b=2)) + self.store(self.df(a=999, c=3), update_existing=False) + self.assert_lake_equal(self.df(a=1, b=2, c=3)) diff --git a/tests/formats/test_ndjson.py b/tests/formats/test_ndjson.py index 78d99a98..2c5e412e 100644 --- a/tests/formats/test_ndjson.py +++ b/tests/formats/test_ndjson.py @@ -4,7 +4,7 @@ import ddt -from cumulus_etl import formats, store +from cumulus_etl import common, store from cumulus_etl.formats.ndjson import NdjsonFormat from tests import utils @@ -15,59 +15,44 @@ class TestNdjsonFormat(utils.AsyncTestCase): Test case for the ndjson format writer. i.e. tests for ndjson.py + + Note that a lot of the basics of that formatter get tested in other unit tests. + This class is mostly just for the less typical edge cases. """ def setUp(self): super().setUp() self.output_tempdir = self.make_tempdir() self.root = store.Root(self.output_tempdir) - NdjsonFormat.initialize_class(self.root) - - @staticmethod - def df(**kwargs) -> list[dict]: - """ - Creates a dummy Table with ids & values equal to each kwarg provided. - """ - return [{"id": k, "value": v} for k, v in kwargs.items()] - - def store( - self, - rows: list[dict], - batch_index: int = 10, - ) -> bool: - """ - Writes a single batch of data to the output dir. 
- - :param rows: the data to insert - :param batch_index: which batch number this is, defaulting to 10 to avoid triggering any first/last batch logic - """ - ndjson = NdjsonFormat(self.root, "condition") - batch = formats.Batch(rows, index=batch_index) - return ndjson.write_records(batch) @ddt.data( (None, True), ([], True), - (["condition.1234.ndjson", "condition.22.ndjson"], True), - (["condition.000.meta"], True), - (["condition.ndjson"], False), - (["condition.000.parquet"], False), - (["patient.000.ndjson"], False), + (["condition/condition.000.ndjson", "condition/condition.111.ndjson"], False), + (["condition/readme.txt"], False), + (["my-novel.txt"], False), ) @ddt.unpack - def test_handles_existing_files(self, files: None | list[str], is_ok: bool): - """Verify that we bail out if any weird files already exist in the output""" - dbpath = self.root.joinpath("condition") - if files is not None: - os.makedirs(dbpath) + def test_disallows_existing_files(self, files: None | list[str], is_ok: bool): + """Verify that we bail out if any files already exist in the output""" + if files is None: + # This means we don't want any folder at all for the test + os.rmdir(self.root.path) + else: for file in files: - with open(f"{dbpath}/{file}", "w", encoding="utf8") as f: - f.write('{"id": "A"}') + pieces = file.split("/") + if len(pieces) > 1: + os.makedirs(self.root.joinpath(pieces[0]), exist_ok=True) + # write any old content in there, we just want to create the file. + common.write_text(self.root.joinpath(file), "Hello!") if is_ok: - self.store([{"id": "B"}], batch_index=0) - self.assertEqual(["condition.000.ndjson"], os.listdir(dbpath)) + NdjsonFormat.initialize_class(self.root) + # Test that we didn't adjust/remove/create any of the files on disk + if files is None: + self.assertFalse(os.path.exists(self.root.path)) + else: + self.assertEqual(files or [], os.listdir(self.root.path)) else: with self.assertRaises(SystemExit): - self.store([{"id": "B"}]) - self.assertEqual(files or [], os.listdir(dbpath)) + NdjsonFormat.initialize_class(self.root) diff --git a/tests/test_bulk_export.py b/tests/test_bulk_export.py index 900dadc2..56c574e4 100644 --- a/tests/test_bulk_export.py +++ b/tests/test_bulk_export.py @@ -42,11 +42,12 @@ async def test_happy_path(self): make_response(status_code=202, headers={"Content-Location": "https://example.com/poll"}), # kickoff make_response( json_payload={ + "transactionTime": "2015-02-07T13:28:17.239+02:00", "output": [ {"type": "Condition", "url": "https://example.com/con1"}, {"type": "Condition", "url": "https://example.com/con2"}, {"type": "Patient", "url": "https://example.com/pat1"}, - ] + ], } ), # status make_response(json_payload={"type": "Condition1"}, stream=True), # download @@ -55,7 +56,7 @@ async def test_happy_path(self): make_response(status_code=202), # delete request ] - await self.export() + exporter = await self.export() self.assertListEqual( [ @@ -79,6 +80,9 @@ async def test_happy_path(self): self.server.request.call_args_list, ) + self.assertEqual("", exporter.group_name) # global group name is empty string + self.assertEqual("2015-02-07T13:28:17.239000+02:00", exporter.export_datetime.isoformat()) + self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir}/Condition.000.ndjson")) self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir}/Condition.001.ndjson")) self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir}/Patient.000.ndjson")) @@ -108,6 +112,7 @@ async def test_export_error(self): 
make_response(status_code=202, headers={"Content-Location": "https://example.com/poll"}), # kickoff make_response( json_payload={ + "transactionTime": "2015-02-07T13:28:17.239+02:00", "error": [ {"type": "OperationOutcome", "url": "https://example.com/err1"}, {"type": "OperationOutcome", "url": "https://example.com/err2"}, @@ -156,6 +161,7 @@ async def test_export_warning(self): make_response(status_code=202, headers={"Content-Location": "https://example.com/poll"}), # kickoff make_response( json_payload={ + "transactionTime": "2015-02-07T13:28:17.239+02:00", "error": [ {"type": "OperationOutcome", "url": "https://example.com/warning1"}, ], @@ -246,6 +252,7 @@ def setUp(self) -> None: super().setUp() self.root = store.Root("http://localhost:9999/fhir") + self.input_url = self.root.joinpath("Group/MyGroup") self.client_id = "test-client-id" self.jwks_file = tempfile.NamedTemporaryFile() # pylint: disable=consider-using-with @@ -286,7 +293,7 @@ def set_up_requests(self, respx_mock): # /$export respx_mock.get( - f"{self.root.path}/$export", + f"{self.input_url}/$export", headers={ "Accept": "application/fhir+json", "Authorization": "Bearer 1234567890", @@ -309,6 +316,7 @@ def set_up_requests(self, respx_mock): }, ).respond( json={ + "transactionTime": "2015-02-07T13:28:17+02:00", "output": [{"type": "Patient", "url": f"{self.root.path}/download/patient1"}], }, ) @@ -346,7 +354,7 @@ async def test_successful_bulk_export(self): await cli.main( [ - self.root.path, + self.input_url, f"{tmpdir}/output", f"{tmpdir}/phi", "--skip-init-checks", @@ -354,6 +362,7 @@ async def test_successful_bulk_export(self): "--task=patient", f"--smart-client-id={self.client_id}", f"--smart-jwks={self.jwks_path}", + "--write-completion", ] ) @@ -361,3 +370,12 @@ async def test_successful_bulk_export(self): {"id": "4342abf315cf6f243e11f4d460303e36c6c3663a25c91cc6b1a8002476c850dd", "resourceType": "Patient"}, common.read_json(f"{tmpdir}/output/patient/patient.000.ndjson"), ) + + self.assertEqual( + { + "table_name": "patient", + "group_name": "MyGroup", + "export_time": "2015-02-07T13:28:17+02:00", + }, + common.read_json(f"{tmpdir}/output/etl__completion/etl__completion.000.ndjson"), + ) diff --git a/tests/utils.py b/tests/utils.py index 38d2e57e..f1babbf9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -202,12 +202,16 @@ def read_delta_lake(lake_path: str, *, version: int = None) -> list[dict]: if version is not None: reader = reader.option("versionAsOf", version) - table_spark = reader.format("delta").load(lake_path).sort("id") + table_spark = reader.format("delta").load(lake_path) # Convert the spark table to Python primitives. # Going to rdd or pandas and then to Python keeps inserting spark-specific constructs like Row(). # So instead, convert to a JSON string and then back to Python. - return [json.loads(row) for row in table_spark.toJSON().collect()] + rows = [json.loads(row) for row in table_spark.toJSON().collect()] + + # Try to sort by id, but if that doesn't exist (which happens for some completion tables), + # just use all dict values as a sort key. + return sorted(rows, key=lambda x: x.get("id", sorted(x.items()))) @contextlib.contextmanager