From 78a5989a25dc707cdea565b3de50fdb77603a4b9 Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 19 Jun 2024 16:52:23 +0200 Subject: [PATCH] add support for starting load jobs as slots free up --- dlt/common/runtime/signals.py | 5 + dlt/load/load.py | 143 +++++++++++++++------------- dlt/load/utils.py | 8 +- tests/load/test_dummy_client.py | 69 ++++++-------- tests/load/test_parallelism_util.py | 44 ++++++--- 5 files changed, 152 insertions(+), 117 deletions(-) diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 8d1cb3803e..a8fa70936e 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -32,6 +32,11 @@ def raise_if_signalled() -> None: raise SignalReceivedException(_received_signal) +def signal_received() -> bool: + """check if a signal was received""" + return True if _received_signal else False + + def sleep(sleep_seconds: float) -> None: """A signal-aware version of sleep function. Will raise SignalReceivedException if signal was received during sleep period.""" # do not allow sleeping if signal was received diff --git a/dlt/load/load.py b/dlt/load/load.py index abbeee5ddf..bf26ca8aab 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -43,6 +43,7 @@ DestinationTerminalException, DestinationTransientException, ) +from dlt.common.runtime import signals from dlt.destinations.job_impl import EmptyLoadJob @@ -194,26 +195,29 @@ def w_spool_job( self.load_storage.normalized_packages.start_job(load_id, job.file_name()) return job - def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]: + def spool_new_jobs( + self, load_id: str, schema: Schema, running_jobs_count: int + ) -> List[LoadJob]: # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs - load_files = filter_new_jobs( - self.load_storage.list_new_jobs(load_id), self.capabilities, self.config - ) + load_files = self.load_storage.list_new_jobs(load_id) file_count = len(load_files) if file_count == 0: logger.info(f"No new jobs found in {load_id}") - return 0, [] - logger.info(f"Will load {file_count}, creating jobs") + return [] + + load_files = filter_new_jobs(load_files, self.capabilities, self.config, running_jobs_count) + file_count = len(load_files) + logger.info(f"Will load additional {file_count}, creating jobs") param_chunk = [(id(self), file, load_id, schema) for file in load_files] # exceptions should not be raised, None as job is a temporary failure # other jobs should not be affected jobs = self.pool.map(Load.w_spool_job, *zip(*param_chunk)) # remove None jobs and check the rest - return file_count, [job for job in jobs if job is not None] + return [job for job in jobs if job is not None] def retrieve_jobs( self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None - ) -> Tuple[int, List[LoadJob]]: + ) -> List[LoadJob]: jobs: List[LoadJob] = [] # list all files that were started but not yet completed @@ -221,7 +225,7 @@ def retrieve_jobs( logger.info(f"Found {len(started_jobs)} that are already started and should be continued") if len(started_jobs) == 0: - return 0, jobs + return jobs for file_path in started_jobs: try: @@ -237,7 +241,7 @@ def retrieve_jobs( raise jobs.append(job) - return len(jobs), jobs + return jobs def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: return [ @@ -274,14 +278,19 @@ def create_followup_jobs( jobs = jobs + starting_job.create_followup_jobs(state) return jobs - def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + def complete_jobs( + self, load_id: str, jobs: List[LoadJob], schema: Schema + ) -> Tuple[List[LoadJob], Exception]: """Run periodically in the main thread to collect job execution statuses. After detecting change of status, it commits the job state by moving it to the right folder May create one or more followup jobs that get scheduled as new jobs. New jobs are created only in terminal states (completed / failed) """ + # list of jobs still running remaining_jobs: List[LoadJob] = [] + # if an exception condition was met, return it to the main runner + pending_exception: Exception = None def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: for followup_job in followup_jobs: @@ -323,6 +332,13 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: f"Job for {job.job_id()} failed terminally in load {load_id} with message" f" {failed_message}" ) + # schedule exception on job failure + if self.config.raise_on_failed_jobs: + pending_exception = LoadClientJobFailed( + load_id, + job.job_file_info().job_id(), + failed_message, + ) elif state == "retry": # try to get exception message from job retry_message = job.exception() @@ -331,6 +347,16 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: logger.warning( f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}" ) + # possibly schedule exception on too many retries + if self.config.raise_on_max_retries: + r_c = job.job_file_info().retry_count + 1 + if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: + pending_exception = LoadClientJobRetry( + load_id, + job.job_file_info().job_id(), + r_c, + self.config.raise_on_max_retries, + ) elif state == "completed": # create followup jobs _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) @@ -346,7 +372,7 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: "Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed" ) - return remaining_jobs + return remaining_jobs, pending_exception def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) -> None: # do not commit load id for aborted packages @@ -371,6 +397,18 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) + def update_loadpackage_info(self, load_id: str) -> None: + # update counter we only care about the jobs that are scheduled to be loaded + package_info = self.load_storage.normalized_packages.get_load_package_info(load_id) + total_jobs = reduce(lambda p, c: p + len(c), package_info.jobs.values(), 0) + no_failed_jobs = len(package_info.jobs["failed_jobs"]) + no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs + self.collector.update("Jobs", no_completed_jobs, total_jobs) + if no_failed_jobs > 0: + self.collector.update( + "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" + ) + def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) @@ -414,72 +452,49 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: drop_tables=dropped_tables, truncate_tables=truncated_tables, ) - self.load_storage.commit_schema_update(load_id, applied_update) - # initialize staging destination and spool or retrieve unfinished jobs + # collect all unfinished jobs + running_jobs: List[LoadJob] = [] if self.staging_destination: with self.get_staging_destination_client(schema) as staging_client: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id, staging_client) - else: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id) - - if not jobs: - # jobs count is a total number of jobs including those that could not be initialized - jobs_count, jobs = self.spool_new_jobs(load_id, schema) - # if there are no existing or new jobs we complete the package - if jobs_count == 0: - self.complete_package(load_id, schema, False) - return - # update counter we only care about the jobs that are scheduled to be loaded - package_info = self.load_storage.normalized_packages.get_load_package_info(load_id) - total_jobs = reduce(lambda p, c: p + len(c), package_info.jobs.values(), 0) - no_failed_jobs = len(package_info.jobs["failed_jobs"]) - no_completed_jobs = len(package_info.jobs["completed_jobs"]) + no_failed_jobs - self.collector.update("Jobs", no_completed_jobs, total_jobs) - if no_failed_jobs > 0: - self.collector.update( - "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" - ) + running_jobs += self.retrieve_jobs(job_client, load_id, staging_client) + running_jobs += self.retrieve_jobs(job_client, load_id) + # loop until all jobs are processed while True: try: - remaining_jobs = self.complete_jobs(load_id, jobs, schema) - if len(remaining_jobs) == 0: - # get package status - package_info = self.load_storage.normalized_packages.get_load_package_info( - load_id - ) - # possibly raise on failed jobs - if self.config.raise_on_failed_jobs: - if package_info.jobs["failed_jobs"]: - failed_job = package_info.jobs["failed_jobs"][0] - raise LoadClientJobFailed( - load_id, - failed_job.job_file_info.job_id(), - failed_job.failed_message, - ) - # possibly raise on too many retries - if self.config.raise_on_max_retries: - for new_job in package_info.jobs["new_jobs"]: - r_c = new_job.job_file_info.retry_count - if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: - raise LoadClientJobRetry( - load_id, - new_job.job_file_info.job_id(), - r_c, - self.config.raise_on_max_retries, - ) + # we continously spool new jobs and complete finished ones + running_jobs, pending_exception = self.complete_jobs(load_id, running_jobs, schema) + # do not spool new jobs if there was a signal + if not signals.signal_received() and not pending_exception: + running_jobs += self.spool_new_jobs(load_id, schema, len(running_jobs)) + self.update_loadpackage_info(load_id) + + if len(running_jobs) == 0: + # if a pending exception was discovered during completion of jobs + # we can raise it now + if pending_exception: + raise pending_exception break - # process remaining jobs again - jobs = remaining_jobs # this will raise on signal - sleep(1) + sleep(0.5) except LoadClientJobFailed: # the package is completed and skipped + self.update_loadpackage_info(load_id) self.complete_package(load_id, schema, True) raise + # always update load package info + self.update_loadpackage_info(load_id) + + # complete the package if no new or started jobs present after loop exit + if ( + len(self.load_storage.list_new_jobs(load_id)) == 0 + and len(self.load_storage.normalized_packages.list_started_jobs(load_id)) == 0 + ): + self.complete_package(load_id, schema, False) + def run(self, pool: Optional[Executor]) -> TRunMetrics: # store pool self.pool = pool or NullExecutor() diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 4e5099855b..39ef5f7507 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -225,6 +225,7 @@ def filter_new_jobs( file_names: Sequence[str], capabilities: DestinationCapabilitiesContext, config: LoaderConfiguration, + running_jobs_count: int, ) -> Sequence[str]: """Filters the list of new jobs to adhere to max_workers and parallellism strategy""" """NOTE: in the current setup we only filter based on settings for the final destination""" @@ -242,6 +243,11 @@ def filter_new_jobs( if mp := capabilities.max_parallel_load_jobs: max_workers = min(max_workers, mp) + # if all slots are full, do not create new jobs + if running_jobs_count >= max_workers: + return [] + max_jobs = max_workers - running_jobs_count + # regular sequential works on all jobs eligible_jobs = file_names @@ -257,4 +263,4 @@ def filter_new_jobs( ) ] - return eligible_jobs[:max_workers] + return eligible_jobs[:max_jobs] diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 30de51f069..63b3171df2 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -96,15 +96,15 @@ def test_unsupported_write_disposition() -> None: load.load_storage.normalized_packages.save_schema(load_id, schema) with ThreadPoolExecutor() as pool: load.run(pool) - # job with unsupported write disp. is failed + # job with unsupported write disp. is failed and job is completed already exception_file = [ f - for f in load.load_storage.normalized_packages.list_failed_jobs(load_id) + for f in load.load_storage.loaded_packages.list_failed_jobs(load_id) if f.endswith(".exception") ][0] assert ( "LoadClientUnsupportedWriteDisposition" - in load.load_storage.normalized_packages.storage.load(exception_file) + in load.load_storage.loaded_packages.storage.load(exception_file) ) @@ -175,7 +175,7 @@ def test_spool_job_failed() -> None: ) jobs.append(job) # complete files - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + remaining_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 for job in jobs: assert load.load_storage.normalized_packages.storage.has_file( @@ -253,8 +253,7 @@ def test_spool_job_retry_spool_new() -> None: # call higher level function that returns jobs and counts with ThreadPoolExecutor() as pool: load.pool = pool - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.spool_new_jobs(load_id, schema, 0) assert len(jobs) == 2 @@ -280,7 +279,7 @@ def test_spool_job_retry_started() -> None: files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 0 # should retry, that moves jobs into new folder - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + remaining_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 # clear retry flag dummy_impl.JOBS = {} @@ -307,19 +306,19 @@ def test_try_retrieve_job() -> None: # dummy client may retrieve jobs that it created itself, jobs in started folder are unknown # and returned as terminal with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 + jobs = load.retrieve_jobs(c, load_id) + assert len(jobs) == 2 for j in jobs: assert j.state() == "failed" # new load package load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) load.pool = ThreadPoolExecutor() - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.spool_new_jobs(load_id, schema, 0) + assert len(jobs) == 2 # now jobs are known with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 + jobs = load.retrieve_jobs(c, load_id) + assert len(jobs) == 2 for j in jobs: assert j.state() == "running" @@ -386,21 +385,19 @@ def test_retry_on_new_loop() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) with ThreadPoolExecutor() as pool: # 1st retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 # 2nd retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 - # jobs will be completed + # package will be completed load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) load.run(pool) - files = load.load_storage.normalized_packages.list_new_jobs(load_id) - assert len(files) == 0 - # complete package - load.run(pool) assert not load.load_storage.normalized_packages.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) ) @@ -409,13 +406,14 @@ def test_retry_on_new_loop() -> None: for fn in load.load_storage.loaded_packages.storage.list_folder_files( os.path.join(completed_path, PackageStorage.COMPLETED_JOBS_FOLDER) ): - # we update a retry count in each case - assert ParsedLoadJobFileName.parse(fn).retry_count == 2 + # we update a retry count in each case (5 times for each loop run) + assert ParsedLoadJobFileName.parse(fn).retry_count == 10 def test_retry_exceptions() -> None: load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) prepare_load_package(load.load_storage, NORMALIZED_FILES) + with ThreadPoolExecutor() as pool: # 1st retry with pytest.raises(LoadClientJobRetry) as py_ex: @@ -423,7 +421,6 @@ def test_retry_exceptions() -> None: load.run(pool) # configured to retry 5 times before exception assert py_ex.value.max_retry_count == py_ex.value.retry_count == 5 - # we can do it again with pytest.raises(LoadClientJobRetry) as py_ex: while True: @@ -764,22 +761,7 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) as complete_load: with ThreadPoolExecutor() as pool: load.run(pool) - # did process schema update - assert load.load_storage.storage.has_file( - os.path.join( - load.load_storage.get_normalized_package_path(load_id), - PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, - ) - ) - # will finalize the whole package - load.run(pool) - # may have followup jobs or staging destination - if ( - load.initial_client_config.create_followup_jobs # type:ignore[attr-defined] - or load.staging_destination - ): - # run the followup jobs - load.run(pool) + # moved to loaded assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) @@ -787,6 +769,15 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No completed_path = load.load_storage.loaded_packages.get_job_folder_path( load_id, "completed_jobs" ) + + # should have migrated the schema + assert load.load_storage.storage.has_file( + os.path.join( + load.load_storage.get_loaded_package_path(load_id), + PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, + ) + ) + if should_delete_completed: # package was deleted assert not load.load_storage.loaded_packages.storage.has_folder(completed_path) diff --git a/tests/load/test_parallelism_util.py b/tests/load/test_parallelism_util.py index b8f43d0743..8968061544 100644 --- a/tests/load/test_parallelism_util.py +++ b/tests/load/test_parallelism_util.py @@ -26,19 +26,19 @@ def test_max_workers() -> None: caps, conf = get_caps_conf() # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 20 # we can change it conf.workers = 35 - assert len(filter_new_jobs(job_names, caps, conf)) == 35 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 35 # destination may override this caps.max_parallel_load_jobs = 15 - assert len(filter_new_jobs(job_names, caps, conf)) == 15 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 15 # lowest value will prevail conf.workers = 5 - assert len(filter_new_jobs(job_names, caps, conf)) == 5 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 5 def test_table_sequential_parallelism_strategy() -> None: @@ -51,17 +51,17 @@ def test_table_sequential_parallelism_strategy() -> None: caps, conf = get_caps_conf() # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 20 # table sequential will give us 8, one for each table conf.parallelism_strategy = "table-sequential" - filtered = filter_new_jobs(job_names, caps, conf) + filtered = filter_new_jobs(job_names, caps, conf, 0) assert len(filtered) == 8 assert len({ParsedLoadJobFileName.parse(j).table_name for j in job_names}) == 8 # max workers also are still applied conf.workers = 3 - assert len(filter_new_jobs(job_names, caps, conf)) == 3 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 3 def test_strategy_preference() -> None: @@ -72,22 +72,40 @@ def test_strategy_preference() -> None: caps, conf = get_caps_conf() # nothing set will default to parallel - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 20 caps.loader_parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 8 caps.loader_parallelism_strategy = "sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 1 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 1 # config may override (will go back to default 20) conf.parallelism_strategy = "parallel" - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 20 conf.parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert len(filter_new_jobs(job_names, caps, conf, 0)) == 8 def test_no_input() -> None: caps, conf = get_caps_conf() - assert filter_new_jobs([], caps, conf) == [] + assert filter_new_jobs([], caps, conf, 0) == [] + + +def test_existing_jobs_count() -> None: + jobs = [f"job{i}" for i in range(50)] + caps, conf = get_caps_conf() + + # default is 20 jobs + assert len(filter_new_jobs(jobs, caps, conf, 0)) == 20 + + # if 5 are already running, just return 15 + assert len(filter_new_jobs(jobs, caps, conf, 5)) == 15 + + # ...etc + assert len(filter_new_jobs(jobs, caps, conf, 16)) == 4 + + assert len(filter_new_jobs(jobs, caps, conf, 300)) == 0 + assert len(filter_new_jobs(jobs, caps, conf, 20)) == 0 + assert len(filter_new_jobs(jobs, caps, conf, 19)) == 1