
Commit

start moving to futures
sh-rp committed Jun 19, 2024
1 parent 78a5989 commit c265ecc
Showing 1 changed file with 24 additions and 13 deletions.
dlt/load/load.py — 37 changes: 24 additions & 13 deletions
@@ -1,8 +1,8 @@
 import contextlib
 from functools import reduce
 import datetime  # noqa: 251
-from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence
-from concurrent.futures import Executor
+from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Sequence, Union
+from concurrent.futures import Executor, Future
 import os
 from copy import deepcopy
@@ -61,6 +61,7 @@
     filter_new_jobs,
 )
 
+TJobOrFuture = Union[LoadJob, Future[LoadJob]]
 
 class Load(Runnable[Executor], WithStepInfo[LoadMetrics, LoadInfo]):
     pool: Executor
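
The new `TJobOrFuture` alias lets the loader's bookkeeping hold both jobs that already exist as `LoadJob` objects (for example, jobs resumed via `retrieve_jobs`) and jobs still pending inside the thread pool as `Future[LoadJob]`. A minimal, self-contained sketch of the same pattern — `Job`, `make_job`, and `JobOrFuture` are illustrative stand-ins, not dlt API:

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Union

class Job:
    """Illustrative stand-in for LoadJob."""
    def __init__(self, path: str) -> None:
        self.path = path

# mirrors TJobOrFuture; subscripting Future (Future[Job]) needs Python 3.9+
JobOrFuture = Union[Job, Future]

def make_job(path: str) -> Job:
    return Job(path)

pool = ThreadPoolExecutor(max_workers=4)
items: List[JobOrFuture] = [Job("resumed.jsonl")]   # already materialized
items.append(pool.submit(make_job, "new.jsonl"))    # still a pending Future
pool.shutdown(wait=True)
```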
@@ -137,6 +138,7 @@ def w_spool_job(
         self: "Load", file_path: str, load_id: str, schema: Schema
     ) -> Optional[LoadJob]:
         job: LoadJob = None
+        print(file_path)
         try:
             is_staging_destination_job = self.is_staging_destination_job(file_path)
             job_client = self.get_destination_client(schema)
@@ -197,7 +199,8 @@ def w_spool_job(
 
     def spool_new_jobs(
         self, load_id: str, schema: Schema, running_jobs_count: int
-    ) -> List[LoadJob]:
+    ) -> List[Future[LoadJob]]:
+
         # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs
         load_files = self.load_storage.list_new_jobs(load_id)
         file_count = len(load_files)
@@ -208,12 +211,14 @@ def spool_new_jobs(
         load_files = filter_new_jobs(load_files, self.capabilities, self.config, running_jobs_count)
         file_count = len(load_files)
         logger.info(f"Will load additional {file_count}, creating jobs")
-        param_chunk = [(id(self), file, load_id, schema) for file in load_files]
-        # exceptions should not be raised, None as job is a temporary failure
-        # other jobs should not be affected
-        jobs = self.pool.map(Load.w_spool_job, *zip(*param_chunk))
-        # remove None jobs and check the rest
-        return [job for job in jobs if job is not None]
+        jobs: List[Future[LoadJob]] = []
+
+        for file in load_files:
+            params = (id(self), file, load_id, schema)
+            jobs.append(self.pool.submit(Load.w_spool_job, *params))
+
+        # return futures
+        return jobs
 
     def retrieve_jobs(
         self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None
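
The switch from `pool.map` to `pool.submit` is the heart of this commit: `Executor.map` hands back a lazy iterator that blocks, in submission order, as results are consumed, so the caller cannot see which jobs finished early; `Executor.submit` returns a `Future` per job immediately, so completion can be polled without blocking. A small standalone sketch of the difference (the `work` function is a toy, not dlt code):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def work(seconds: float) -> float:
    time.sleep(seconds)
    return seconds

with ThreadPoolExecutor(max_workers=2) as pool:
    # map: results come back in submission order; iterating blocks on the
    # slow first task even though the second finishes sooner
    ordered = list(pool.map(work, [1.0, 0.1]))

    # submit: each call returns a Future at once; the caller decides
    # when (and whether) to block
    futures = [pool.submit(work, s) for s in (1.0, 0.1)]
    time.sleep(0.2)
    done = [f for f in futures if f.done()]            # likely just the 0.1s task
    still_running = [f for f in futures if not f.done()]
    results = [f.result() for f in futures]            # result() blocks and re-raises
```

One behavioral consequence worth noting: with `map`, a worker exception surfaces while iterating the results, whereas with `submit` it stays stored in the `Future` and is re-raised only when `result()` is called — which makes `complete_jobs` below the place where worker errors now materialize.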
@@ -279,16 +284,16 @@ def create_followup_jobs(
         return jobs
 
     def complete_jobs(
-        self, load_id: str, jobs: List[LoadJob], schema: Schema
-    ) -> Tuple[List[LoadJob], Exception]:
+        self, load_id: str, jobs: List[TJobOrFuture], schema: Schema
+    ) -> Tuple[List[TJobOrFuture], Exception]:
         """Run periodically in the main thread to collect job execution statuses.
         After detecting change of status, it commits the job state by moving it to the right folder
         May create one or more followup jobs that get scheduled as new jobs. New jobs are created
         only in terminal states (completed / failed)
         """
         # list of jobs still running
-        remaining_jobs: List[LoadJob] = []
+        remaining_jobs: List[TJobOrFuture] = []
         # if an exception condition was met, return it to the main runner
         pending_exception: Exception = None
 
@@ -313,6 +318,12 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None:
         logger.info(f"Will complete {len(jobs)} for {load_id}")
         for ii in range(len(jobs)):
             job = jobs[ii]
+            if isinstance(job, Future):
+                if not job.done():
+                    remaining_jobs.append(job)
+                    continue
+                job = job.result()
+
             logger.debug(f"Checking state for job {job.job_id()}")
             state: TLoadJobState = job.state()
             if state == "running":
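
The new `isinstance(job, Future)` branch is a standard non-blocking poll: a future that is not yet `done()` goes back into `remaining_jobs` for the next pass, and a finished one is unwrapped with `result()`, which also re-raises any exception raised in the worker thread. The same pattern in isolation — `poll_once` and its names are illustrative, not dlt API:

```python
from concurrent.futures import Future
from typing import Any, List, Tuple

def poll_once(items: List[Any]) -> Tuple[List[Any], List[Any]]:
    """Split items into (still pending, ready) without blocking."""
    pending: List[Any] = []
    ready: List[Any] = []
    for item in items:
        if isinstance(item, Future):
            if not item.done():
                pending.append(item)   # revisit on the next main-loop pass
                continue
            item = item.result()       # worker exceptions re-raise here
        ready.append(item)
    return pending, ready
```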
@@ -455,7 +466,7 @@ def load_single_package(self, load_id: str, schema: Schema) -> None:
             self.load_storage.commit_schema_update(load_id, applied_update)
 
         # collect all unfinished jobs
-        running_jobs: List[LoadJob] = []
+        running_jobs: List[TJobOrFuture] = []
         if self.staging_destination:
             with self.get_staging_destination_client(schema) as staging_client:
                 running_jobs += self.retrieve_jobs(job_client, load_id, staging_client)
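
Taken together, the futures returned by `spool_new_jobs` are meant to be fed back through `complete_jobs` until the package drains. A rough sketch of the resulting main-loop shape, assuming only the signatures visible in this diff (the driver function itself is hypothetical, not dlt's actual run loop):

```python
import time

def drain_package(load: "Load", load_id: str, schema: "Schema") -> None:
    # hypothetical driver: spool once, then poll until nothing remains
    jobs = load.spool_new_jobs(load_id, schema, running_jobs_count=0)
    while jobs:
        jobs, pending_exception = load.complete_jobs(load_id, jobs, schema)
        if pending_exception is not None:
            raise pending_exception
        time.sleep(1.0)  # avoid a hot loop between polls
```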
