diff --git a/cumulus_etl/errors.py b/cumulus_etl/errors.py
index 075fc05..6e2e9b3 100644
--- a/cumulus_etl/errors.py
+++ b/cumulus_etl/errors.py
@@ -34,6 +34,7 @@
 SERVICE_MISSING = 33  # generic init-check service is missing
 COMPLETION_ARG_MISSING = 34
 TASK_HELP = 35
+MISSING_REQUESTED_RESOURCES = 36
 
 
 class FatalError(Exception):
diff --git a/cumulus_etl/etl/cli.py b/cumulus_etl/etl/cli.py
index 3569d78..a20fb5c 100644
--- a/cumulus_etl/etl/cli.py
+++ b/cumulus_etl/etl/cli.py
@@ -2,10 +2,10 @@
 
 import argparse
 import datetime
+import logging
 import os
 import shutil
 import sys
-from collections.abc import Iterable
 
 import rich
 import rich.table
@@ -16,6 +16,9 @@ from cumulus_etl.etl.config import JobConfig, JobSummary
 from cumulus_etl.etl.tasks import task_factory
 
+TaskList = list[type[tasks.EtlTask]]
+
+
 ###############################################################################
 #
 # Main Pipeline (run all tasks)
 #
@@ -24,7 +27,7 @@
 
 
 async def etl_job(
-    config: JobConfig, selected_tasks: list[type[tasks.EtlTask]], use_philter: bool = False
+    config: JobConfig, selected_tasks: TaskList, use_philter: bool = False
 ) -> list[JobSummary]:
     """
     :param config: job config
@@ -68,7 +71,7 @@ def check_mstool() -> None:
         raise SystemExit(errors.MSTOOL_MISSING)
 
 
-async def check_requirements(selected_tasks: Iterable[type[tasks.EtlTask]]) -> None:
+async def check_requirements(selected_tasks: TaskList) -> None:
     """
     Verifies that all external services and programs are ready
 
@@ -118,6 +121,11 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--errors-to", metavar="DIR", help="where to put resources that could not be processed"
     )
+    parser.add_argument(
+        "--allow-missing-resources",
+        action="store_true",
+        help="run tasks even if their resources are not present",
+    )
 
     cli_utils.add_aws(parser)
     cli_utils.add_auth(parser)
@@ -143,7 +151,7 @@ def define_etl_parser(parser: argparse.ArgumentParser) -> None:
 
 
 def print_config(
-    args: argparse.Namespace, job_datetime: datetime.datetime, all_tasks: Iterable[tasks.EtlTask]
+    args: argparse.Namespace, job_datetime: datetime.datetime, all_tasks: TaskList
 ) -> None:
     """
     Prints the ETL configuration to the console.
@@ -214,6 +222,49 @@ def handle_completion_args(
     return export_group_name, export_datetime
 
 
+async def check_available_resources(
+    loader: loaders.Loader,
+    *,
+    requested_resources: set[str],
+    args: argparse.Namespace,
+    is_default_tasks: bool,
+) -> set[str]:
+    # Here we try to reconcile which resources the user requested and which resources are actually
+    # available in the input root.
+    #  - If the user didn't specify a specific task, we'll scope down the requested resources to
+    #    what is actually present in the input.
+    #  - If they did, we'll complain if their required resources are not available.
+    #
+    # Reconciling is helpful for performance reasons (don't need to finalize untouched tables),
+    # UX reasons (can tell user if they made a CLI mistake), and completion tracking (don't
+    # mark a resource as complete if we didn't even export it)
+    if args.allow_missing_resources:
+        return requested_resources
+
+    detected = await loader.detect_resources()
+    if detected is None:
+        return requested_resources  # likely we haven't run bulk export yet
+
+    if missing_resources := requested_resources - detected:
+        for resource in sorted(missing_resources):
+            # Log the same message that common.py would print if we ran the tasks anyway
+            logging.warning("No %s files found in %s", resource, loader.root.path)
+
+        if is_default_tasks:
+            requested_resources -= missing_resources  # scope down to detected resources
+            if not requested_resources:
+                errors.fatal(
+                    "No supported resources found.",
+                    errors.MISSING_REQUESTED_RESOURCES,
+                )
+        else:
+            msg = "Required resources not found.\n"
+            msg += "Add --allow-missing-resources to run related tasks anyway with no input."
+            errors.fatal(msg, errors.MISSING_REQUESTED_RESOURCES)
+
+    return requested_resources
+
+
 async def etl_main(args: argparse.Namespace) -> None:
     # Set up some common variables
 
@@ -227,6 +278,7 @@ async def etl_main(args: argparse.Namespace) -> None:
 
     job_datetime = common.datetime_now()  # grab timestamp before we do anything
     selected_tasks = task_factory.get_selected_tasks(args.task, args.task_filter)
+    is_default_tasks = not args.task and not args.task_filter
 
     # Print configuration
     print_config(args, job_datetime, selected_tasks)
@@ -261,8 +313,17 @@
         resume=args.resume,
     )
 
+    required_resources = await check_available_resources(
+        config_loader,
+        args=args,
+        is_default_tasks=is_default_tasks,
+        requested_resources=required_resources,
+    )
+    # Drop any tasks that we didn't find resources for
+    selected_tasks = [t for t in selected_tasks if t.resource in required_resources]
+
     # Pull down resources from any remote location (like s3), convert from i2b2, or do a bulk export
-    loader_results = await config_loader.load_all(list(required_resources))
+    loader_results = await config_loader.load_resources(required_resources)
 
     # Establish the group name and datetime of the loaded dataset (from CLI args or Loader)
     export_group_name, export_datetime = handle_completion_args(args, loader_results)
diff --git a/cumulus_etl/etl/tasks/base.py b/cumulus_etl/etl/tasks/base.py
index 9589f07..c6343d0 100644
--- a/cumulus_etl/etl/tasks/base.py
+++ b/cumulus_etl/etl/tasks/base.py
@@ -260,9 +260,6 @@ def _delete_requested_ids(self):
             self.formatters[index].delete_records(deleted_ids)
 
     def _update_completion_table(self) -> None:
-        # TODO: what about empty sets - do we assume the export gave 0 results or skip it?
-        #       Is there a difference we could notice? (like empty input file vs no file at all)
-
         if not self.completion_tracking_enabled:
             return
 
diff --git a/cumulus_etl/loaders/base.py b/cumulus_etl/loaders/base.py
index 2a262a6..2767d0e 100644
--- a/cumulus_etl/loaders/base.py
+++ b/cumulus_etl/loaders/base.py
@@ -46,7 +46,15 @@ def __init__(self, root: store.Root):
         self.root = root
 
     @abc.abstractmethod
-    async def load_all(self, resources: list[str]) -> LoaderResults:
+    async def detect_resources(self) -> set[str] | None:
+        """
+        Inspect which resources are available for use.
+
+        :returns: the types of resources detected (or None if that can't be determined yet)
+        """
+
+    @abc.abstractmethod
+    async def load_resources(self, resources: set[str]) -> LoaderResults:
         """
         Loads the listed remote resources and places them into a local folder as FHIR ndjson
 
diff --git a/cumulus_etl/loaders/fhir/bulk_export.py b/cumulus_etl/loaders/fhir/bulk_export.py
index 2816f16..4f4cd89 100644
--- a/cumulus_etl/loaders/fhir/bulk_export.py
+++ b/cumulus_etl/loaders/fhir/bulk_export.py
@@ -30,7 +30,7 @@ class BulkExporter:
     def __init__(
         self,
         client: fhir.FhirClient,
-        resources: list[str],
+        resources: set[str],
         url: str,
         destination: str,
         *,
@@ -81,7 +81,7 @@ def format_kickoff_url(
         self,
         url: str,
         *,
-        resources: list[str],
+        resources: set[str],
         since: str | None,
         until: str | None,
         prefer_url_resources: bool,
diff --git a/cumulus_etl/loaders/fhir/ndjson_loader.py b/cumulus_etl/loaders/fhir/ndjson_loader.py
index 8e4a309..9107841 100644
--- a/cumulus_etl/loaders/fhir/ndjson_loader.py
+++ b/cumulus_etl/loaders/fhir/ndjson_loader.py
@@ -2,6 +2,8 @@
 
 import tempfile
 
+import cumulus_fhir_support
+
 from cumulus_etl import cli_utils, common, errors, fhir, store
 from cumulus_etl.loaders import base
 from cumulus_etl.loaders.fhir.bulk_export import BulkExporter
@@ -37,7 +39,18 @@ def __init__(
         self.until = until
         self.resume = resume
 
-    async def load_all(self, resources: list[str]) -> base.LoaderResults:
+    async def detect_resources(self) -> set[str] | None:
+        if self.root.protocol in {"http", "https"}:
+            # We haven't done the export yet, so there are no files to inspect yet.
+            # Returning None means "dunno" (i.e. "just accept whatever you eventually get").
+            return None
+
+        found_files = cumulus_fhir_support.list_multiline_json_in_dir(
+            self.root.path, fsspec_fs=self.root.fs
+        )
+        return {resource for resource in found_files.values() if resource}
+
+    async def load_resources(self, resources: set[str]) -> base.LoaderResults:
         # Are we doing a bulk FHIR export from a server?
         if self.root.protocol in ["http", "https"]:
             bulk_dir = await self.load_from_bulk_export(resources)
@@ -61,14 +74,14 @@ async def load_all(self, resources: list[str]) -> base.LoaderResults:
         # TemporaryDirectory gets discarded), but that seems reasonable.
         print("Copying ndjson input files…")
         tmpdir = tempfile.TemporaryDirectory()
-        filenames = common.ls_resources(input_root, set(resources), warn_if_empty=True)
+        filenames = common.ls_resources(input_root, resources, warn_if_empty=True)
         for filename in filenames:
             input_root.get(filename, f"{tmpdir.name}/")
 
         return self.read_loader_results(input_root, tmpdir)
 
     async def load_from_bulk_export(
-        self, resources: list[str], prefer_url_resources: bool = False
+        self, resources: set[str], prefer_url_resources: bool = False
     ) -> common.Directory:
         """
         Performs a bulk export and drops the results in an export dir.
diff --git a/cumulus_etl/loaders/i2b2/loader.py b/cumulus_etl/loaders/i2b2/loader.py
index 42f7240..343dc30 100644
--- a/cumulus_etl/loaders/i2b2/loader.py
+++ b/cumulus_etl/loaders/i2b2/loader.py
@@ -34,7 +34,29 @@ def __init__(self, root: store.Root, export_to: str | None = None):
         super().__init__(root)
         self.export_to = export_to
 
-    async def load_all(self, resources: list[str]) -> base.LoaderResults:
+    async def detect_resources(self) -> set[str] | None:
+        if self.root.protocol in {"tcp"}:
+            # We haven't done the export yet, so there are no files to inspect yet.
+            # Returning None means "dunno" (i.e. "just accept whatever you eventually get").
+            return None
+
+        filenames = {
+            "observation_fact_diagnosis.csv": "Condition",
+            "observation_fact_lab_views.csv": "Observation",
+            "observation_fact_medications.csv": "MedicationRequest",
+            "observation_fact_notes.csv": "DocumentReference",
+            "observation_fact_vitals.csv": "Observation",
+            "patient_dimension.csv": "Patient",
+            "visit_dimension.csv": "Encounter",
+        }
+
+        return {
+            resource
+            for path, resource in filenames.items()
+            if self.root.exists(self.root.joinpath(path))
+        }
+
+    async def load_resources(self, resources: set[str]) -> base.LoaderResults:
         if self.root.protocol in ["tcp"]:
             directory = self._load_all_from_oracle(resources)
         else:
@@ -43,7 +65,7 @@
 
     def _load_all_with_extractors(
         self,
-        resources: list[str],
+        resources: set[str],
         conditions: I2b2ExtractorCallable,
         lab_views: I2b2ExtractorCallable,
         medicationrequests: I2b2ExtractorCallable,
@@ -139,7 +161,7 @@ def _loop(
     #
     ###################################################################################################################
 
-    def _load_all_from_csv(self, resources: list[str]) -> common.Directory:
+    def _load_all_from_csv(self, resources: set[str]) -> common.Directory:
         path = self.root.path
         return self._load_all_with_extractors(
             resources,
@@ -177,7 +199,7 @@ def _load_all_from_csv(self, resources: list[str]) -> common.Directory:
     #
     ###################################################################################################################
 
-    def _load_all_from_oracle(self, resources: list[str]) -> common.Directory:
+    def _load_all_from_oracle(self, resources: set[str]) -> common.Directory:
         path = self.root.path
         return self._load_all_with_extractors(
             resources,
diff --git a/cumulus_etl/upload_notes/downloader.py b/cumulus_etl/upload_notes/downloader.py
index 96d2a14..47af041 100644
--- a/cumulus_etl/upload_notes/downloader.py
+++ b/cumulus_etl/upload_notes/downloader.py
@@ -27,7 +27,7 @@ async def download_docrefs_from_fhir_server(
     else:
         # else we'll download the entire target path as a bulk export (presumably the user has scoped a Group)
         ndjson_loader = loaders.FhirNdjsonLoader(root_input, client, export_to=export_to)
-        return await ndjson_loader.load_all(["DocumentReference"])
+        return await ndjson_loader.load_resources({"DocumentReference"})
 
 
 async def _download_docrefs_from_fake_ids(
diff --git a/tests/etl/test_etl_cli.py b/tests/etl/test_etl_cli.py
index 74678e5..2dd1f9d 100644
--- a/tests/etl/test_etl_cli.py
+++ b/tests/etl/test_etl_cli.py
@@ -103,17 +103,17 @@ async def test_failed_task(self):
 
     async def test_single_task(self):
         # Grab all observations before we mock anything
-        observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(
-            ["Observation"]
+        observations = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_resources(
+            {"Observation"}
         )
 
-        def fake_load_all(internal_self, resources):
+        def fake_load_resources(internal_self, resources):
             del internal_self
             # Confirm we only tried to load one resource
-            self.assertEqual(["Observation"], resources)
+            self.assertEqual({"Observation"}, resources)
             return observations
 
-        with mock.patch.object(loaders.FhirNdjsonLoader, "load_all", new=fake_load_all):
+        with mock.patch.object(loaders.FhirNdjsonLoader, "load_resources", new=fake_load_resources):
             await self.run_etl(tasks=["observation"])
 
         # Confirm we only wrote the one resource
@@ -126,17 +126,17 @@ def fake_load_all(internal_self, resources):
 
     async def test_multiple_tasks(self):
         # Grab all observations before we mock anything
-        loaded = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(
-            ["Observation", "Patient"]
+        loaded = loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_resources(
+            {"Observation", "Patient"}
         )
 
-        def fake_load_all(internal_self, resources):
+        def fake_load_resources(internal_self, resources):
             del internal_self
             # Confirm we only tried to load two resources
-            self.assertEqual({"Observation", "Patient"}, set(resources))
+            self.assertEqual({"Observation", "Patient"}, resources)
             return loaded
 
-        with mock.patch.object(loaders.FhirNdjsonLoader, "load_all", new=fake_load_all):
+        with mock.patch.object(loaders.FhirNdjsonLoader, "load_resources", new=fake_load_resources):
             await self.run_etl(tasks=["observation", "patient"])
 
         # Confirm we only wrote the two resources
@@ -267,8 +267,8 @@ async def test_task_init_checks(self, mock_check):
     async def test_completion_args(self, etl_args, loader_vals, expected_vals):
         """Verify that we parse completion args with the correct fallbacks and checks."""
         # Grab all observations before we mock anything
-        observations = await loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_all(
-            ["Observation"]
+        observations = await loaders.FhirNdjsonLoader(store.Root(self.input_path)).load_resources(
+            {"Observation"}
         )
         observations.group_name = loader_vals[0]
         observations.export_datetime = loader_vals[1]
@@ -276,7 +276,9 @@
         with (
             self.assertRaises(SystemExit) as cm,
             mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job,
-            mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=observations),
+            mock.patch.object(
+                loaders.FhirNdjsonLoader, "load_resources", return_value=observations
+            ),
         ):
             await self.run_etl(tasks=["observation"], **etl_args)
 
@@ -297,7 +299,7 @@ async def test_deleted_ids_passed_down(self):
         with (
             self.assertRaises(SystemExit),
             mock.patch("cumulus_etl.etl.cli.etl_job", side_effect=SystemExit) as mock_etl_job,
-            mock.patch.object(loaders.FhirNdjsonLoader, "load_all", return_value=results),
+            mock.patch.object(loaders.FhirNdjsonLoader, "load_resources", return_value=results),
         ):
             await self.run_etl(tasks=["observation"])
 
@@ -305,6 +307,28 @@
         config = mock_etl_job.call_args[0][0]
         self.assertEqual({"Observation": {"obs1"}}, config.deleted_ids)
 
+    @ddt.data(["patient"], None)
+    async def test_missing_resources(self, tasks):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            with self.assertRaises(SystemExit) as cm:
+                await self.run_etl(tasks=tasks, input_path=tmpdir)
+            self.assertEqual(errors.MISSING_REQUESTED_RESOURCES, cm.exception.code)
+
+    async def test_allow_missing_resources(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            await self.run_etl("--allow-missing-resources", tasks=["patient"], input_path=tmpdir)
+
+        self.assertEqual("", common.read_text(f"{self.output_path}/patient/patient.000.ndjson"))
+
+    async def test_missing_resources_skips_tasks(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            common.write_json(f"{tmpdir}/p.ndjson", {"id": "A", "resourceType": "Patient"})
+            await self.run_etl(input_path=tmpdir)
+
+        self.assertEqual(
+            {"etl__completion", "patient", "JobConfig"}, set(os.listdir(self.output_path))
+        )
+
 
 class TestEtlJobConfig(BaseEtlSimple):
     """Test case for the job config logging data"""
diff --git a/tests/loaders/i2b2/test_i2b2_loader.py b/tests/loaders/i2b2/test_i2b2_loader.py
index 4e1c31f..269d1b8 100644
--- a/tests/loaders/i2b2/test_i2b2_loader.py
+++ b/tests/loaders/i2b2/test_i2b2_loader.py
@@ -22,7 +22,7 @@ async def test_missing_files(self):
             vitals = f"{self.datadir}/i2b2/input/observation_fact_vitals.csv"
             shutil.copy(vitals, tmpdir)
 
-            results = await i2b2_loader.load_all(["Observation", "Patient"])
+            results = await i2b2_loader.load_resources({"Observation", "Patient"})
 
             self.assertEqual(["Observation.1.ndjson"], os.listdir(results.path))
 
@@ -37,7 +37,24 @@ async def test_duplicate_ids(self):
                 "PATIENT_NUM,BIRTH_DATE\n" "123,1982-10-16\n" "123,1983-11-17\n" "456,2000-01-13\n",
             )
 
-            results = await i2b2_loader.load_all(["Patient"])
+            results = await i2b2_loader.load_resources({"Patient"})
             rows = common.read_resource_ndjson(store.Root(results.path), "Patient")
             values = [(r["id"], r["birthDate"]) for r in rows]
             self.assertEqual(values, [("123", "1982-10-16"), ("456", "2000-01-13")])
+
+    async def test_detect_resources(self):
+        """Verify we can inspect a folder and find all resources."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            common.write_text(f"{tmpdir}/visit_dimension.csv", "")
+            common.write_text(f"{tmpdir}/unrelated.csv", "")
+            common.write_text(f"{tmpdir}/observation_fact_lab_views.csv", "")
+
+            i2b2_loader = loader.I2b2Loader(store.Root(tmpdir))
+            resources = await i2b2_loader.detect_resources()
+
+        self.assertEqual(resources, {"Encounter", "Observation"})
+
+    async def test_detect_resources_tcp(self):
+        """Verify we skip trying to detect resources before exporting from oracle."""
+        i2b2_loader = loader.I2b2Loader(store.Root("tcp://localhost"))
+        self.assertIsNone(await i2b2_loader.detect_resources())
diff --git a/tests/loaders/i2b2/test_i2b2_oracle_extract.py b/tests/loaders/i2b2/test_i2b2_oracle_extract.py
index 8c2071a..e533597 100644
--- a/tests/loaders/i2b2/test_i2b2_oracle_extract.py
+++ b/tests/loaders/i2b2/test_i2b2_oracle_extract.py
@@ -93,7 +93,7 @@ async def test_loader(self, mock_extract):
 
         root = store.Root("tcp://localhost/foo")
         oracle_loader = loader.I2b2Loader(root)
-        results = await oracle_loader.load_all(["Condition", "Encounter", "Patient"])
+        results = await oracle_loader.load_resources({"Condition", "Encounter", "Patient"})
 
         # Check results
         self.assertEqual(
diff --git a/tests/loaders/ndjson/test_ndjson_loader.py b/tests/loaders/ndjson/test_ndjson_loader.py
index 916c7e5..a75195d 100644
--- a/tests/loaders/ndjson/test_ndjson_loader.py
+++ b/tests/loaders/ndjson/test_ndjson_loader.py
@@ -64,7 +64,7 @@ async def test_local_happy_path(self):
                 writer.write(patient)
 
             loader = loaders.FhirNdjsonLoader(store.Root(tmpdir))
-            results = await loader.load_all(["Patient"])
+            results = await loader.load_resources({"Patient"})
 
             self.assertEqual(["Patient.ndjson"], os.listdir(results.path))
             self.assertEqual(patient, common.read_json(f"{results.path}/Patient.ndjson"))
@@ -82,7 +82,7 @@ async def test_log_parsing_is_non_fatal(self):
             self._write_log_file(f"{tmpdir}/log.2.ndjson", "G2", "2002-02-02")
 
             loader = loaders.FhirNdjsonLoader(store.Root(tmpdir))
-            results = await loader.load_all([])
+            results = await loader.load_resources(set())
 
             # We used neither log and didn't error out.
             self.assertIsNone(results.group_name)
@@ -280,7 +280,7 @@ async def test_fatal_errors_are_fatal(self):
         with self.assertRaises(SystemExit) as cm:
             await loaders.FhirNdjsonLoader(
                 store.Root("http://localhost:9999"), mock.AsyncMock()
-            ).load_all(["Patient"])
+            ).load_resources({"Patient"})
 
         self.assertEqual(1, self.mock_exporter.export.call_count)
         self.assertEqual(errors.BULK_EXPORT_FAILED, cm.exception.code)
@@ -301,7 +301,7 @@ async def fake_export() -> None:
         loader = loaders.FhirNdjsonLoader(
             store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=target
         )
-        results = await loader.load_all(["Patient"])
+        results = await loader.load_resources({"Patient"})
 
         # Confirm export folder still has the data (and log) we created above in the mock
         self.assertTrue(os.path.isdir(target))
@@ -327,7 +327,7 @@ async def fake_export() -> None:
         self.mock_exporter.export.side_effect = fake_export
 
         loader = loaders.FhirNdjsonLoader(store.Root("http://localhost:9999"), mock.AsyncMock())
-        results = await loader.load_all(["Patient"])
+        results = await loader.load_resources({"Patient"})
 
         # Confirm the returned dir has only the data (we don't want to confuse MS tool with logs)
         self.assertEqual({"Patient.ndjson"}, set(os.listdir(results.path)))
@@ -341,7 +341,7 @@ async def test_export_to_folder_has_contents(self):
             store.Root("http://localhost:9999"), mock.AsyncMock(), export_to=tmpdir
         )
         with self.assertRaises(SystemExit) as cm:
-            await loader.load_all([])
+            await loader.load_resources(set())
         self.assertEqual(cm.exception.code, errors.FOLDER_NOT_EMPTY)
 
     async def test_export_to_folder_not_local(self):
@@ -350,7 +350,7 @@
             store.Root("http://localhost:9999"), mock.AsyncMock(), export_to="http://foo"
         )
         with self.assertRaises(SystemExit) as cm:
-            await loader.load_all([])
+            await loader.load_resources(set())
         self.assertEqual(cm.exception.code, errors.BULK_EXPORT_FOLDER_NOT_LOCAL)
 
     async def test_reads_deleted_ids(self):
@@ -393,6 +393,18 @@
             },
         )
         loader = loaders.FhirNdjsonLoader(store.Root(tmpdir))
-        results = await loader.load_all(["Patient"])
+        results = await loader.load_resources({"Patient"})
 
         self.assertEqual(results.deleted_ids, {"Patient": {"pat1"}, "Condition": {"con1", "con2"}})
+
+    async def test_detect_resources(self):
+        """Verify we can inspect a folder and find all resources."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            common.write_json(f"{tmpdir}/p.ndjson", {"id": "A", "resourceType": "Patient"})
+            common.write_json(f"{tmpdir}/unrelated.ndjson", {"num_cats": 5})
+            common.write_json(f"{tmpdir}/c.ndjson", {"id": "A", "resourceType": "Condition"})
+
+            loader = loaders.FhirNdjsonLoader(store.Root(tmpdir))
+            resources = await loader.detect_resources()
+
+        self.assertEqual(resources, {"Condition", "Patient"})
diff --git a/tests/upload_notes/test_upload_cli.py b/tests/upload_notes/test_upload_cli.py
index 628e77d..1d1df63 100644
--- a/tests/upload_notes/test_upload_cli.py
+++ b/tests/upload_notes/test_upload_cli.py
@@ -259,12 +259,12 @@ async def test_gather_real_docrefs_from_server(self, respx_mock):
     @mock.patch("cumulus_etl.upload_notes.downloader.loaders.FhirNdjsonLoader")
     async def test_gather_all_docrefs_from_server(self, mock_loader):
         # Mock out the bulk export loading, as that's well tested elsewhere
-        async def load_all(*args):
+        async def load_resources(*args):
             del args
             return common.RealDirectory(self.input_path)
 
-        load_all_mock = mock_loader.return_value.load_all
-        load_all_mock.side_effect = load_all
+        load_resources_mock = mock_loader.return_value.load_resources
+        load_resources_mock.side_effect = load_resources
 
         # Do the actual upload-notes push
         await self.run_upload_notes(input_path="https://localhost")
@@ -273,7 +273,7 @@ async def load_all(*args):
         self.assertEqual(1, mock_loader.call_count)
         self.assertEqual("https://localhost", mock_loader.call_args[0][0].path)
         self.assertEqual(self.export_path, mock_loader.call_args[1]["export_to"])
-        self.assertEqual([mock.call(["DocumentReference"])], load_all_mock.call_args_list)
+        self.assertEqual([mock.call({"DocumentReference"})], load_resources_mock.call_args_list)
 
         # Make sure we do read the result and push the docrefs out
         self.assertEqual({"43", "44"}, self.get_pushed_ids())
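
For reviewers, here is a brief illustrative sketch (not part of the patch) of how the reworked loader API above is meant to be driven: detect_resources() reports what is actually present so the requested set can be narrowed before load_resources() runs. The FhirNdjsonLoader/store.Root usage mirrors the tests above; the input path and the requested set are placeholder values, and the scoping line simply mirrors what check_available_resources() does for the default-task path.

# Illustrative only: mirrors the detect-then-load flow added in cumulus_etl/etl/cli.py.
import asyncio

from cumulus_etl import loaders, store


async def sketch(input_path: str, requested: set[str]) -> None:
    loader = loaders.FhirNdjsonLoader(store.Root(input_path))

    # None means the loader can't tell yet (e.g. a bulk export that hasn't run).
    detected = await loader.detect_resources()
    if detected is not None:
        # Scope down to what is actually present, like check_available_resources() does.
        requested &= detected

    results = await loader.load_resources(requested)
    print(f"Loaded ndjson into {results.path}")


asyncio.run(sketch("./ndjson-input", {"Patient", "Observation"}))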