diff --git a/.github/weaviate-compose.yml b/.github/weaviate-compose.yml index 8bbedb7b23..8d715c758f 100644 --- a/.github/weaviate-compose.yml +++ b/.github/weaviate-compose.yml @@ -11,8 +11,6 @@ services: image: semitechnologies/weaviate:1.21.1 ports: - 8080:8080 - volumes: - - weaviate_data restart: on-failure:0 environment: QUERY_DEFAULTS_LIMIT: 25 diff --git a/.github/workflows/test_destination_clickhouse.yml b/.github/workflows/test_destination_clickhouse.yml index 4aea5a8e90..5b6848f2fe 100644 --- a/.github/workflows/test_destination_clickhouse.yml +++ b/.github/workflows/test_destination_clickhouse.yml @@ -68,9 +68,9 @@ jobs: # OSS ClickHouse - run: | - docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" up -d + docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" up -d echo "Waiting for ClickHouse to be healthy..." - timeout 30s bash -c 'until docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" ps | grep -q "healthy"; do sleep 1; done' + timeout 30s bash -c 'until docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" ps | grep -q "healthy"; do sleep 1; done' echo "ClickHouse is up and running" name: Start ClickHouse OSS @@ -101,7 +101,7 @@ jobs: - name: Stop ClickHouse OSS if: always() - run: docker-compose -f "tests/load/clickhouse/clickhouse-compose.yml" down -v + run: docker compose -f "tests/load/clickhouse/clickhouse-compose.yml" down -v # ClickHouse Cloud - run: | diff --git a/.github/workflows/test_destination_dremio.yml b/.github/workflows/test_destination_dremio.yml index 1b47268b59..7ec6c4f697 100644 --- a/.github/workflows/test_destination_dremio.yml +++ b/.github/workflows/test_destination_dremio.yml @@ -43,7 +43,7 @@ jobs: uses: actions/checkout@master - name: Start dremio - run: docker-compose -f "tests/load/dremio/docker-compose.yml" up -d + run: docker compose -f "tests/load/dremio/docker-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 @@ -87,4 +87,4 @@ jobs: - name: Stop dremio if: always() - run: docker-compose -f "tests/load/dremio/docker-compose.yml" down -v + run: docker compose -f "tests/load/dremio/docker-compose.yml" down -v diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index b140935d4c..6094f2c0ac 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -60,7 +60,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker-compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f ".github/weaviate-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index f1bf6016bc..78ea23ec1c 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -73,7 +73,7 @@ jobs: uses: actions/checkout@master - name: Start weaviate - run: docker-compose -f ".github/weaviate-compose.yml" up -d + run: docker compose -f ".github/weaviate-compose.yml" up -d - name: Setup Python uses: actions/setup-python@v4 @@ -109,4 +109,4 @@ jobs: - name: Stop weaviate if: always() - run: docker-compose -f ".github/weaviate-compose.yml" down -v + run: docker compose -f ".github/weaviate-compose.yml" down -v diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index ca9d6a2d94..ded7a28ad7 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -27,6 +27,7 @@ 
from dlt.common import logger from dlt.common.configuration.specs.base_configuration import extract_inner_hint from dlt.common.destination.utils import verify_schema_capabilities +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming import NamingConvention from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.utils import ( @@ -42,6 +43,8 @@ InvalidDestinationReference, UnknownDestinationModule, DestinationSchemaTampered, + DestinationTransientException, + DestinationTerminalException, ) from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage @@ -258,11 +261,45 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura """configuration of the staging, if present, injected at runtime""" -TLoadJobState = Literal["running", "failed", "retry", "completed"] +TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] -class LoadJob: - """Represents a job that loads a single file +class LoadJob(ABC): + """ + A stateful load job, represents one job file + """ + + def __init__(self, file_path: str) -> None: + self._file_path = file_path + self._file_name = FileStorage.get_file_name_from_file_path(file_path) + # NOTE: we only accept a full filepath in the constructor + assert self._file_name != self._file_path + self._parsed_file_name = ParsedLoadJobFileName.parse(self._file_name) + + def job_id(self) -> str: + """The job id that is derived from the file name and does not changes during job lifecycle""" + return self._parsed_file_name.job_id() + + def file_name(self) -> str: + """A name of the job file""" + return self._file_name + + def job_file_info(self) -> ParsedLoadJobFileName: + return self._parsed_file_name + + @abstractmethod + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + pass + + @abstractmethod + def exception(self) -> str: + """The exception associated with failed or retry states""" + pass + + +class RunnableLoadJob(LoadJob, ABC): + """Represents a runnable job that loads a single file Each job starts in "running" state and ends in one of terminal states: "retry", "failed" or "completed". Each job is uniquely identified by a file name. The file is guaranteed to exist in "running" state. In terminal state, the file may not be present. @@ -273,39 +310,80 @@ class LoadJob: immediately transition job into "failed" or "retry" state respectively. """ - def __init__(self, file_name: str) -> None: + def __init__(self, file_path: str) -> None: """ File name is also a job id (or job id is deterministically derived) so it must be globally unique """ # ensure file name - assert file_name == FileStorage.get_file_name_from_file_path(file_name) - self._file_name = file_name - self._parsed_file_name = ParsedLoadJobFileName.parse(file_name) + super().__init__(file_path) + self._state: TLoadJobState = "ready" + self._exception: Exception = None - @abstractmethod - def state(self) -> TLoadJobState: - """Returns current state. 
Should poll external resource if necessary.""" - pass + # variables needed by most jobs, set by the loader in set_run_vars + self._schema: Schema = None + self._load_table: TTableSchema = None + self._load_id: str = None + self._job_client: "JobClientBase" = None - def file_name(self) -> str: - """A name of the job file""" - return self._file_name + def set_run_vars(self, load_id: str, schema: Schema, load_table: TTableSchema) -> None: + """ + called by the loader right before the job is run + """ + self._load_id = load_id + self._schema = schema + self._load_table = load_table - def job_id(self) -> str: - """The job id that is derived from the file name and does not changes during job lifecycle""" - return self._parsed_file_name.job_id() + @property + def load_table_name(self) -> str: + return self._load_table["name"] - def job_file_info(self) -> ParsedLoadJobFileName: - return self._parsed_file_name + def run_managed( + self, + job_client: "JobClientBase", + ) -> None: + """ + wrapper around the user implemented run method + """ + # only jobs that are not running or have not reached a final state + # may be started + assert self._state in ("ready", "retry") + self._job_client = job_client + + # filepath is now moved to running + try: + self._state = "running" + self._job_client.prepare_load_job_execution(self) + self.run() + self._state = "completed" + except (DestinationTerminalException, TerminalValueError) as e: + self._state = "failed" + self._exception = e + except (DestinationTransientException, Exception) as e: + self._state = "retry" + self._exception = e + finally: + # sanity check + assert self._state in ("completed", "retry", "failed") @abstractmethod + def run(self) -> None: + """ + run the actual job, this will be executed on a thread and should be implemented by the user + exception will be handled outside of this function + """ + raise NotImplementedError() + + def state(self) -> TLoadJobState: + """Returns current state. Should poll external resource if necessary.""" + return self._state + def exception(self) -> str: """The exception associated with failed or retry states""" - pass + return str(self._exception) -class NewLoadJob(LoadJob): - """Adds a trait that allows to save new job file""" +class FollowupJob: + """Base class for follow up jobs that should be created""" @abstractmethod def new_file_path(self) -> str: @@ -313,35 +391,14 @@ def new_file_path(self) -> str: pass -class FollowupJob: - """Adds a trait that allows to create a followup job""" +class HasFollowupJobs: + """Adds a trait that allows to create single or table chain followup jobs""" - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: """Return list of new jobs. 
`final_state` is state to which this job transits""" return [] -class DoNothingJob(LoadJob): - """The most lazy class of dlt""" - - def __init__(self, file_path: str) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - - -class DoNothingFollowupJob(DoNothingJob, FollowupJob): - """The second most lazy class of dlt""" - - pass - - class JobClientBase(ABC): def __init__( self, @@ -394,13 +451,16 @@ def update_stored_schema( return expected_update @abstractmethod - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - """Creates and starts a load job for a particular `table` with content in `file_path`""" + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + """Creates a load job for a particular `table` with content in `file_path`""" pass - @abstractmethod - def restore_file_load(self, file_path: str) -> LoadJob: - """Finds and restores already started loading job identified by `file_path` if destination supports it.""" + def prepare_load_job_execution( # noqa: B027, optional override + self, job: RunnableLoadJob + ) -> None: + """Prepare the connected job client for the execution of a load job (used for query tags in sql clients)""" pass def should_truncate_table_before_load(self, table: TTableSchema) -> bool: @@ -410,7 +470,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" return [] diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 8d1cb3803e..a8fa70936e 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -32,6 +32,11 @@ def raise_if_signalled() -> None: raise SignalReceivedException(_received_signal) +def signal_received() -> bool: + """check if a signal was received""" + return True if _received_signal else False + + def sleep(sleep_seconds: float) -> None: """A signal-aware version of sleep function. 
Will raise SignalReceivedException if signal was received during sleep period.""" # do not allow sleeping if signal was received diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 4d84094427..b0ed93f734 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -723,19 +723,12 @@ def build_job_file_name( @staticmethod def is_package_partially_loaded(package_info: LoadPackageInfo) -> bool: - """Checks if package is partially loaded - has jobs that are not new.""" - if package_info.state == "normalized": - pending_jobs: Sequence[TJobState] = ["new_jobs"] - else: - pending_jobs = ["completed_jobs", "failed_jobs"] - return ( - sum( - len(package_info.jobs[job_state]) - for job_state in WORKING_FOLDERS - if job_state not in pending_jobs - ) - > 0 - ) + """Checks if package is partially loaded - has jobs that are completed and jobs that are not.""" + all_jobs_count = sum(len(package_info.jobs[job_state]) for job_state in WORKING_FOLDERS) + completed_jobs_count = len(package_info.jobs["completed_jobs"]) + if completed_jobs_count and all_jobs_count - completed_jobs_count > 0: + return True + return False @staticmethod def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: diff --git a/dlt/common/typing.py b/dlt/common/typing.py index fdd27161f7..ee11a77965 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -106,7 +106,7 @@ VARIANT_FIELD_FORMAT = "v_%s" TFileOrPath = Union[str, PathLike, IO[Any]] TSortOrder = Literal["asc", "desc"] -TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] +TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv", "reference"] """known loader file formats""" diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 4225d63fe7..371c1bae22 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -45,10 +45,10 @@ ) from dlt.common.schema.utils import table_schema_has_type from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import LoadJob, DoNothingFollowupJob, DoNothingJob -from dlt.common.destination.reference import NewLoadJob, SupportsStagingDestination +from dlt.common.destination.reference import LoadJob +from dlt.common.destination.reference import FollowupJob, SupportsStagingDestination from dlt.common.data_writers.escape import escape_hive_identifier -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.exceptions import ( @@ -65,6 +65,7 @@ ) from dlt.destinations.typing import DBApiCursor from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils @@ -160,7 +161,7 @@ def __init__(self) -> None: DLTAthenaFormatter._INSTANCE = self -class AthenaMergeJob(SqlMergeJob): +class AthenaMergeJob(SqlMergeFollowupJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: # reproducible name so we know which table to drop @@ -468,7 +469,9 @@ def _get_table_update_sql( LOCATION 
'{location}';""") return sql - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" if table_schema_has_type(table, "time"): raise LoadJobTerminalException( @@ -476,32 +479,38 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> "Athena cannot load TIME columns from parquet tables. Please convert" " `datetime.time` objects in your data to `str` or `datetime.datetime`.", ) - job = super().start_file_load(table, file_path, load_id) + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = ( - DoNothingFollowupJob(file_path) + FinalizedLoadJobWithFollowupJobs(file_path) if self._is_iceberg_table(self.prepare_load_table(table["name"])) - else DoNothingJob(file_path) + else FinalizedLoadJob(file_path) ) return job - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": False} + ) ] return super()._create_append_followup_jobs(table_chain) def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ] return super()._create_replace_followup_jobs(table_chain) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [AthenaMergeJob.from_table_chain(table_chain, self.sql_client)] def _is_iceberg_table(self, table: TTableSchema) -> bool: diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 095974d186..ef4e31acd1 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -1,6 +1,7 @@ import functools import os from pathlib import Path +import time from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, cast import google.cloud.bigquery as bigquery # noqa: I250 @@ -10,14 +11,16 @@ from google.cloud.bigquery.retry import _RETRYABLE_REASONS from dlt.common import logger +from dlt.common.runtime.signals import sleep from dlt.common.json import json from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, TLoadJobState, - LoadJob, + RunnableLoadJob, SupportsStagingDestination, + LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat @@ -33,7 +36,7 @@ DatabaseUndefinedRelation, DestinationSchemaWillNotUpdate, DestinationTerminalException, - LoadJobNotExistsException, + 
DatabaseTerminalException, LoadJobTerminalException, ) from dlt.destinations.impl.bigquery.bigquery_adapter import ( @@ -48,8 +51,8 @@ from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.utils import parse_db_data_type_str_with_precision @@ -104,60 +107,95 @@ def from_db_type( return super().from_db_type(*parse_db_data_type_str_with_precision(db_type)) -class BigQueryLoadJob(LoadJob, FollowupJob): +class BigQueryLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - file_name: str, - bq_load_job: bigquery.LoadJob, + file_path: str, http_timeout: float, retry_deadline: float, ) -> None: - self.bq_load_job = bq_load_job - self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) - self.http_timeout = http_timeout - super().__init__(file_name) - - def state(self) -> TLoadJobState: - if not self.bq_load_job.done(retry=self.default_retry, timeout=self.http_timeout): - return "running" - if self.bq_load_job.output_rows is not None and self.bq_load_job.error_result is None: - return "completed" - reason = self.bq_load_job.error_result.get("reason") - if reason in BQ_TERMINAL_REASONS: - # the job permanently failed for the reason above - return "failed" - elif reason in ["internalError"]: - logger.warning( - f"Got reason {reason} for job {self.file_name}, job considered still" - f" running. ({self.bq_load_job.error_result})" - ) - # the status of the job couldn't be obtained, job still running. - return "running" - else: - # retry on all other reasons, including `backendError` which requires retry when the job is done. - return "retry" - - def bigquery_job_id(self) -> str: - return BigQueryLoadJob.get_job_id_from_file_path(super().file_name()) + super().__init__(file_path) + self._default_retry = bigquery.DEFAULT_RETRY.with_deadline(retry_deadline) + self._http_timeout = http_timeout + self._job_client: "BigQueryClient" = None + self._bq_load_job: bigquery.LoadJob = None + # vars only used for testing + self._created_job = False + self._resumed_job = False + + def run(self) -> None: + # start the job (or retrieve in case it already exists) + try: + self._bq_load_job = self._job_client._create_load_job(self._load_table, self._file_path) + self._created_job = True + except api_core_exceptions.GoogleAPICallError as gace: + reason = BigQuerySqlClient._get_reason_from_errors(gace) + if reason == "notFound": + # google.api_core.exceptions.NotFound: 404 – table not found + raise DatabaseUndefinedRelation(gace) from gace + elif ( + reason == "duplicate" + ): # google.api_core.exceptions.Conflict: 409 PUT – already exists + self._bq_load_job = self._job_client._retrieve_load_job(self._file_path) + self._resumed_job = True + logger.info( + f"Found existing bigquery job for job {self._file_name}, will resume job." 
+ ) + elif reason in BQ_TERMINAL_REASONS: + # google.api_core.exceptions.BadRequest - will not be processed ie bad job name + raise LoadJobTerminalException( + self._file_path, f"The server reason was: {reason}" + ) from gace + else: + raise DatabaseTransientException(gace) from gace + + # we loop on the job thread until we detect a status change + while True: + sleep(1) + # not done yet + if not self._bq_load_job.done(retry=self._default_retry, timeout=self._http_timeout): + continue + # done, break loop and go to completed state + if self._bq_load_job.output_rows is not None and self._bq_load_job.error_result is None: + break + reason = self._bq_load_job.error_result.get("reason") + if reason in BQ_TERMINAL_REASONS: + # the job permanently failed for the reason above + raise DatabaseTerminalException( + Exception( + f"Bigquery Load Job failed, reason reported from bigquery: '{reason}'" + ) + ) + elif reason in ["internalError"]: + logger.warning( + f"Got reason {reason} for job {self._file_name}, job considered still" + f" running. ({self._bq_load_job.error_result})" + ) + continue + else: + raise DatabaseTransientException( + Exception( + f"Bigquery Job needs to be retried, reason reported from bigquer '{reason}'" + ) + ) def exception(self) -> str: - exception: str = json.dumps( + return json.dumps( { - "error_result": self.bq_load_job.error_result, - "errors": self.bq_load_job.errors, - "job_start": self.bq_load_job.started, - "job_end": self.bq_load_job.ended, - "job_id": self.bq_load_job.job_id, + "error_result": self._bq_load_job.error_result, + "errors": self._bq_load_job.errors, + "job_start": self._bq_load_job.started, + "job_end": self._bq_load_job.ended, + "job_id": self._bq_load_job.job_id, } ) - return exception @staticmethod def get_job_id_from_file_path(file_path: str) -> str: return Path(file_path).name.replace(".", "_") -class BigQueryMergeJob(SqlMergeJob): +class BigQueryMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -195,97 +233,46 @@ def __init__( self.sql_client: BigQuerySqlClient = sql_client # type: ignore self.type_mapper = BigQueryTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or restored BigQueryLoadJob - - See base class for details on SqlLoadJob. - BigQueryLoadJob is restored with a job ID derived from `file_path`. - - Args: - file_path (str): a path to a job file. 
- - Returns: - LoadJob: completed SqlLoadJob or restored BigQueryLoadJob - """ - job = super().restore_file_load(file_path) - if not job: - try: - job = BigQueryLoadJob( - FileStorage.get_file_name_from_file_path(file_path), - self._retrieve_load_job(file_path), - self.config.http_timeout, - self.config.retry_deadline, - ) - except api_core_exceptions.GoogleAPICallError as gace: - reason = BigQuerySqlClient._get_reason_from_errors(gace) - if reason == "notFound": - raise LoadJobNotExistsException(file_path) from gace - elif reason in BQ_TERMINAL_REASONS: - raise LoadJobTerminalException( - file_path, f"The server reason was: {reason}" - ) from gace - else: - raise DatabaseTransientException(gace) from gace - return job - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id) if not job: insert_api = table.get("x-insert-api", "default") - try: - if insert_api == "streaming": - if table["write_disposition"] != "append": - raise DestinationTerminalException( - "BigQuery streaming insert can only be used with `append`" - " write_disposition, while the given resource has" - f" `{table['write_disposition']}`." - ) - if file_path.endswith(".jsonl"): - job_cls = DestinationJsonlLoadJob - elif file_path.endswith(".parquet"): - job_cls = DestinationParquetLoadJob # type: ignore - else: - raise ValueError( - f"Unsupported file type for BigQuery streaming inserts: {file_path}" - ) - - job = job_cls( - table, - file_path, - self.config, # type: ignore - self.schema, - destination_state(), - functools.partial(_streaming_load, self.sql_client), - [], + if insert_api == "streaming": + if table["write_disposition"] != "append": + raise DestinationTerminalException( + "BigQuery streaming insert can only be used with `append`" + " write_disposition, while the given resource has" + f" `{table['write_disposition']}`." 
) + if file_path.endswith(".jsonl"): + job_cls = DestinationJsonlLoadJob + elif file_path.endswith(".parquet"): + job_cls = DestinationParquetLoadJob # type: ignore else: - job = BigQueryLoadJob( - FileStorage.get_file_name_from_file_path(file_path), - self._create_load_job(table, file_path), - self.config.http_timeout, - self.config.retry_deadline, + raise ValueError( + f"Unsupported file type for BigQuery streaming inserts: {file_path}" ) - except api_core_exceptions.GoogleAPICallError as gace: - reason = BigQuerySqlClient._get_reason_from_errors(gace) - if reason == "notFound": - # google.api_core.exceptions.NotFound: 404 – table not found - raise DatabaseUndefinedRelation(gace) from gace - elif ( - reason == "duplicate" - ): # google.api_core.exceptions.Conflict: 409 PUT – already exists - return self.restore_file_load(file_path) - elif reason in BQ_TERMINAL_REASONS: - # google.api_core.exceptions.BadRequest - will not be processed ie bad job name - raise LoadJobTerminalException( - file_path, f"The server reason was: {reason}" - ) from gace - else: - raise DatabaseTransientException(gace) from gace + job = job_cls( + file_path, + self.config, # type: ignore + destination_state(), + _streaming_load, # type: ignore + [], + callable_requires_job_client_args=True, + ) + else: + job = BigQueryLoadJob( + file_path, + self.config.http_timeout, + self.config.retry_deadline, + ) return job def _get_table_update_sql( @@ -445,8 +432,8 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load # determine whether we load from local or uri bucket_path = None ext: str = os.path.splitext(file_path)[1][1:] - if NewReferenceJob.is_reference_job(file_path): - bucket_path = NewReferenceJob.resolve_reference(file_path) + if ReferenceFollowupJob.is_reference_job(file_path): + bucket_path = ReferenceFollowupJob.resolve_reference(file_path) ext = os.path.splitext(bucket_path)[1][1:] # Select a correct source format @@ -515,7 +502,7 @@ def _should_autodetect_schema(self, table_name: str) -> bool: def _streaming_load( - sql_client: SqlClientBase[BigQueryClient], items: List[Dict[Any, Any]], table: Dict[str, Any] + items: List[Dict[Any, Any]], table: Dict[str, Any], job_client: BigQueryClient ) -> None: """ Upload the given items into BigQuery table, using streaming API. 
@@ -542,6 +529,8 @@ def _should_retry(exc: api_core_exceptions.GoogleAPICallError) -> bool: reason = exc.errors[0]["reason"] return reason in _RETRYABLE_REASONS + sql_client = job_client.sql_client + full_name = sql_client.make_qualified_table_name(table["name"], escape=False) bq_client = sql_client._client diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 148fca3f1e..5bd34e0e0d 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -18,9 +18,10 @@ from dlt.common.destination.reference import ( SupportsStagingDestination, TLoadJobState, + HasFollowupJobs, + RunnableLoadJob, FollowupJob, LoadJob, - NewLoadJob, ) from dlt.common.schema import Schema, TColumnSchema from dlt.common.schema.typing import ( @@ -51,8 +52,8 @@ SqlJobClientBase, SqlJobClientWithStaging, ) -from dlt.destinations.job_impl import NewReferenceJob, EmptyLoadJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import ReferenceFollowupJob, FinalizedLoadJobWithFollowupJobs +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -123,22 +124,25 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class ClickHouseLoadJob(LoadJob, FollowupJob): +class ClickHouseLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - client: ClickHouseSqlClient, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._job_client: "ClickHouseClient" = None + self._staging_credentials = staging_credentials + + def run(self) -> None: + client = self._job_client.sql_client - qualified_table_name = client.make_qualified_table_name(table_name) + qualified_table_name = client.make_qualified_table_name(self.load_table_name) bucket_path = None + file_name = self._file_name - if NewReferenceJob.is_reference_job(file_path): - bucket_path = NewReferenceJob.resolve_reference(file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path): + bucket_path = ReferenceFollowupJob.resolve_reference(self._file_path) file_name = FileStorage.get_file_name_from_file_path(bucket_path) bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme @@ -152,7 +156,7 @@ def __init__( if not bucket_path: # Local filesystem. 
if ext == "jsonl": - compression = "gz" if FileStorage.is_gzipped(file_path) else "none" + compression = "gz" if FileStorage.is_gzipped(self._file_path) else "none" try: with clickhouse_connect.create_client( host=client.credentials.host, @@ -165,7 +169,7 @@ def __init__( insert_file( clickhouse_connect_client, qualified_table_name, - file_path, + self._file_path, fmt=clickhouse_format, settings={ "allow_experimental_lightweight_delete": 1, @@ -176,7 +180,7 @@ def __init__( ) except clickhouse_connect.driver.exceptions.Error as e: raise LoadJobTerminalException( - file_path, + self._file_path, f"ClickHouse connection failed due to {e}.", ) from e return @@ -188,9 +192,9 @@ def __init__( compression = "none" if config.get("data_writer.disable_compression") else "gz" if bucket_scheme in ("s3", "gs", "gcs"): - if not isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + if not isinstance(self._staging_credentials, AwsCredentialsWithoutDefaults): raise LoadJobTerminalException( - file_path, + self._file_path, dedent( """ Google Cloud Storage buckets must be configured using the S3 compatible access pattern. @@ -201,10 +205,10 @@ def __init__( ) bucket_http_url = convert_storage_to_http_scheme( - bucket_url, endpoint=staging_credentials.endpoint_url + bucket_url, endpoint=self._staging_credentials.endpoint_url ) - access_key_id = staging_credentials.aws_access_key_id - secret_access_key = staging_credentials.aws_secret_access_key + access_key_id = self._staging_credentials.aws_access_key_id + secret_access_key = self._staging_credentials.aws_secret_access_key auth = "NOSIGN" if access_key_id and secret_access_key: auth = f"'{access_key_id}','{secret_access_key}'" @@ -214,24 +218,22 @@ def __init__( ) elif bucket_scheme in ("az", "abfs"): - if not isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + if not isinstance(self._staging_credentials, AzureCredentialsWithoutDefaults): raise LoadJobTerminalException( - file_path, + self._file_path, "Unsigned Azure Blob Storage access from ClickHouse isn't supported as yet.", ) # Authenticated access. 
- account_name = staging_credentials.azure_storage_account_name - storage_account_url = ( - f"https://{staging_credentials.azure_storage_account_name}.blob.core.windows.net" - ) - account_key = staging_credentials.azure_storage_account_key + account_name = self._staging_credentials.azure_storage_account_name + storage_account_url = f"https://{self._staging_credentials.azure_storage_account_name}.blob.core.windows.net" + account_key = self._staging_credentials.azure_storage_account_key # build table func table_function = f"azureBlobStorage('{storage_account_url}','{bucket_url.netloc}','{bucket_url.path}','{account_name}','{account_key}','{clickhouse_format}','{compression}')" else: raise LoadJobTerminalException( - file_path, + self._file_path, f"ClickHouse loader does not support '{bucket_scheme}' filesystem.", ) @@ -239,14 +241,8 @@ def __init__( with client.begin_transaction(): client.execute_sql(statement) - def state(self) -> TLoadJobState: - return "completed" - def exception(self) -> str: - raise NotImplementedError() - - -class ClickHouseMergeJob(SqlMergeJob): +class ClickHouseMergeJob(SqlMergeFollowupJob): @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: return f"CREATE TABLE {temp_table_name} ENGINE = Memory AS {select_sql};" @@ -292,7 +288,7 @@ def __init__( self.active_hints = deepcopy(HINT_TO_CLICKHOUSE_ATTR) self.type_mapper = ClickHouseTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -319,11 +315,11 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non .strip() ) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return super().start_file_load(table, file_path, load_id) or ClickHouseLoadJob( + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return super().create_load_job(table, file_path, load_id, restore) or ClickHouseLoadJob( file_path, - table["name"], - self.sql_client, staging_credentials=( self.config.staging_config.credentials if self.config.staging_config else None ), @@ -374,6 +370,3 @@ def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(ch_t, precision, scale) - - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index fbe7fa4c6b..0a203c21b6 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -4,12 +4,13 @@ from dlt import config from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, TLoadJobState, - LoadJob, + RunnableLoadJob, CredentialsConfiguration, SupportsStagingDestination, + LoadJob, ) from dlt.common.configuration.specs import ( AwsCredentialsWithoutDefaults, @@ -25,12 +26,12 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.job_impl import EmptyLoadJob +from 
dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlMergeJob -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -103,30 +104,31 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DatabricksLoadJob(LoadJob, FollowupJob): +class DatabricksLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - table: TTableSchema, file_path: str, - table_name: str, - load_id: str, - client: DatabricksSqlClient, staging_config: FilesystemConfiguration, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) - staging_credentials = staging_config.credentials + super().__init__(file_path) + self._staging_config = staging_config + self._job_client: "DatabricksClient" = None - qualified_table_name = client.make_qualified_table_name(table_name) + def run(self) -> None: + self._sql_client = self._job_client.sql_client + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) + staging_credentials = self._staging_config.credentials # extract and prepare some vars bucket_path = orig_bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) + ReferenceFollowupJob.resolve_reference(self._file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path) else "" ) file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else self._file_name ) from_clause = "" credentials_clause = "" @@ -166,13 +168,13 @@ def __init__( from_clause = f"FROM '{bucket_path}'" else: raise LoadJobTerminalException( - file_path, + self._file_path, f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" " azure buckets are supported", ) else: raise LoadJobTerminalException( - file_path, + self._file_path, "Cannot load from local file. Databricks does not support loading from local files." " Configure staging with an s3 or azure storage bucket.", ) @@ -183,32 +185,32 @@ def __init__( elif file_name.endswith(".jsonl"): if not config.get("data_writer.disable_compression"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader does not support gzip compressed JSON files. Please disable" " compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - if table_schema_has_type(table, "decimal"): + if table_schema_has_type(self._load_table, "decimal"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load DECIMAL type columns from json files. Switch to" " parquet format to load decimals.", ) - if table_schema_has_type(table, "binary"): + if table_schema_has_type(self._load_table, "binary"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load BINARY type columns from json files. 
Switch to" " parquet format to load byte values.", ) - if table_schema_has_type(table, "complex"): + if table_schema_has_type(self._load_table, "complex"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load complex columns (lists and dicts) from json" " files. Switch to parquet format to load complex types.", ) - if table_schema_has_type(table, "date"): + if table_schema_has_type(self._load_table, "date"): raise LoadJobTerminalException( - file_path, + self._file_path, "Databricks loader cannot load DATE type columns from json files. Switch to" " parquet format to load dates.", ) @@ -216,7 +218,7 @@ def __init__( source_format = "JSON" format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" # Databricks fails when trying to load empty json files, so we have to check the file size - fs, _ = fsspec_from_config(staging_config) + fs, _ = fsspec_from_config(self._staging_config) file_size = fs.size(orig_bucket_path) if file_size == 0: # Empty file, do nothing return @@ -227,16 +229,10 @@ def __init__( FILEFORMAT = {source_format} {format_options_clause} """ - client.execute_sql(statement) + self._sql_client.execute_sql(statement) - def state(self) -> TLoadJobState: - return "completed" - def exception(self) -> str: - raise NotImplementedError() - - -class DatabricksMergeJob(SqlMergeJob): +class DatabricksMergeJob(SqlMergeFollowupJob): @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: return f"CREATE TEMPORARY VIEW {temp_table_name} AS {select_sql};" @@ -271,24 +267,19 @@ def __init__( self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] self.type_mapper = DatabricksTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = DatabricksLoadJob( - table, file_path, - table["name"], - load_id, - self.sql_client, staging_config=cast(FilesystemConfiguration, self.config.staging_config), ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 976dfa4fb5..0c4da81471 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -2,6 +2,7 @@ from types import TracebackType from typing import ClassVar, Optional, Type, Iterable, cast, List +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.common.destination.reference import LoadJob from dlt.common.typing import AnyFun from dlt.common.storages.load_package import destination_state @@ -10,12 +11,10 @@ from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - LoadJob, - DoNothingJob, JobClientBase, + LoadJob, ) -from 
dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.impl.destination.configuration import CustomDestinationClientConfiguration from dlt.destinations.job_impl import ( DestinationJsonlLoadJob, @@ -56,44 +55,49 @@ def update_stored_schema( ) -> Optional[TSchemaTables]: return super().update_stored_schema(only_tables, expected_update) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: # skip internal tables and remove columns from schema if so configured - skipped_columns: List[str] = [] if self.config.skip_dlt_columns_and_tables: if table["name"].startswith(self.schema._dlt_tables_prefix): - return DoNothingJob(file_path) - table = deepcopy(table) - for column in list(table["columns"].keys()): + return FinalizedLoadJob(file_path) + + skipped_columns: List[str] = [] + if self.config.skip_dlt_columns_and_tables: + for column in list(self.schema.tables[table["name"]]["columns"].keys()): if column.startswith(self.schema._dlt_tables_prefix): - table["columns"].pop(column) skipped_columns.append(column) # save our state in destination name scope load_state = destination_state() if file_path.endswith("parquet"): return DestinationParquetLoadJob( - table, file_path, self.config, - self.schema, load_state, self.destination_callable, skipped_columns, ) if file_path.endswith("jsonl"): return DestinationJsonlLoadJob( - table, file_path, self.config, - self.schema, load_state, self.destination_callable, skipped_columns, ) return None - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: + table = super().prepare_load_table(table_name, prepare_for_staging) + if self.config.skip_dlt_columns_and_tables: + for column in list(table["columns"].keys()): + if column.startswith(self.schema._dlt_tables_prefix): + table["columns"].pop(column) + return table def complete_load(self, load_id: str) -> None: ... 
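Illustrative sketch (not part of the patch): the pattern each destination above is converted to. A job subclasses RunnableLoadJob (plus HasFollowupJobs when it can spawn follow-up jobs) and puts all work in run(); run_managed() owns the state machine, turning DestinationTerminalException into "failed", any other exception into "retry", and a normal return into "completed". The client hands back an unstarted job from create_load_job() instead of the removed start_file_load()/restore_file_load() pair. MyCopyJob and MyClient are hypothetical names; the base classes, hooks, and sql_client calls are the ones used elsewhere in this diff, and the import paths follow the modules the patch itself imports from.

# minimal sketch of the new job/client pattern (hypothetical destination)
from dlt.common.destination.reference import LoadJob, RunnableLoadJob, HasFollowupJobs
from dlt.common.schema.typing import TTableSchema
from dlt.destinations.job_client_impl import SqlJobClientWithStaging


class MyCopyJob(RunnableLoadJob, HasFollowupJobs):
    def __init__(self, file_path: str) -> None:
        # the job is created in "ready" state; no work happens here
        super().__init__(file_path)
        self._job_client: "MyClient" = None  # assigned by run_managed()

    def run(self) -> None:
        # executed on a worker thread by run_managed(), which maps
        # DestinationTerminalException -> "failed", other exceptions -> "retry",
        # and a normal return -> "completed"
        sql_client = self._job_client.sql_client
        table_name = sql_client.make_qualified_table_name(self.load_table_name)
        sql_client.execute_sql(f"COPY {table_name} FROM '{self._file_path}';")


class MyClient(SqlJobClientWithStaging):  # hypothetical job client
    def create_load_job(
        self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False
    ) -> LoadJob:
        # return an *unstarted* job; the loader later calls set_run_vars()
        # and run_managed(). restore=True re-attaches to an existing job.
        job = super().create_load_job(table, file_path, load_id, restore)
        return job or MyCopyJob(file_path)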
diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index bea18cdea5..3611665f6c 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -3,11 +3,12 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, + HasFollowupJobs, TLoadJobState, - LoadJob, + RunnableLoadJob, SupportsStagingDestination, - NewLoadJob, + FollowupJob, + LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TColumnSchemaBase @@ -17,9 +18,9 @@ from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration from dlt.destinations.impl.dremio.sql_client import DremioSqlClient from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.job_impl import NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_impl import ReferenceFollowupJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.sql_client import SqlClientBase @@ -69,7 +70,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DremioMergeJob(SqlMergeJob): +class DremioMergeJob(SqlMergeFollowupJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: return sql_client.make_qualified_table_name(f"_temp_{name_prefix}_{uniq_id()}") @@ -83,23 +84,25 @@ def default_order_by(cls) -> str: return "NULL" -class DremioLoadJob(LoadJob, FollowupJob): +class DremioLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - client: DremioSqlClient, stage_name: Optional[str] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._stage_name = stage_name + self._job_client: "DremioClient" = None - qualified_table_name = client.make_qualified_table_name(table_name) + def run(self) -> None: + self._sql_client = self._job_client.sql_client + + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) # extract and prepare some vars bucket_path = ( - NewReferenceJob.resolve_reference(file_path) - if NewReferenceJob.is_reference_job(file_path) + ReferenceFollowupJob.resolve_reference(self._file_path) + if ReferenceFollowupJob.is_reference_job(self._file_path) else "" ) @@ -107,33 +110,29 @@ def __init__( raise RuntimeError("Could not resolve bucket path.") file_name = ( - FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + FileStorage.get_file_name_from_file_path(bucket_path) + if bucket_path + else self._file_name ) bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - if bucket_scheme == "s3" and stage_name: + if bucket_scheme == "s3" and self._stage_name: from_clause = ( - f"FROM '@{stage_name}/{bucket_url.hostname}/{bucket_url.path.lstrip('/')}'" + f"FROM '@{self._stage_name}/{bucket_url.hostname}/{bucket_url.path.lstrip('/')}'" ) else: raise LoadJobTerminalException( - file_path, "Only s3 staging currently supported in Dremio destination" + self._file_path, "Only s3 staging currently supported in Dremio destination" ) source_format = 
file_name.split(".")[-1] - client.execute_sql(f"""COPY INTO {qualified_table_name} + self._sql_client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} FILE_FORMAT '{source_format}' """) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class DremioClient(SqlJobClientWithStaging, SupportsStagingDestination): def __init__( @@ -153,21 +152,18 @@ def __init__( self.sql_client: DremioSqlClient = sql_client # type: ignore self.type_mapper = DremioTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = DremioLoadJob( file_path=file_path, - table_name=table["name"], - client=self.sql_client, stage_name=self.config.staging_data_source, ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def _get_table_update_sql( self, table_name: str, @@ -205,7 +201,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 10d4fc13de..2926435edc 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -5,7 +5,7 @@ from dlt.common.data_types import TDataType from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import maybe_context @@ -113,12 +113,16 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class DuckDbCopyJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) +class DuckDbCopyJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "DuckDbClient" = None - qualified_table_name = sql_client.make_qualified_table_name(table_name) - if file_path.endswith("parquet"): + def run(self) -> None: + self._sql_client = self._job_client.sql_client + + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) + if self._file_path.endswith("parquet"): source_format = "PARQUET" options = "" # lock when creating a new lock @@ -127,27 +131,21 @@ def __init__(self, table_name: str, file_path: str, sql_client: DuckDbSqlClient) lock: threading.Lock = TABLES_LOCKS.setdefault( qualified_table_name, threading.Lock() ) - elif file_path.endswith("jsonl"): + 
elif self._file_path.endswith("jsonl"): # NOTE: loading JSON does not work in practice on duckdb: the missing keys fail the load instead of being interpreted as NULL source_format = "JSON" # newline delimited, compression auto - options = ", COMPRESSION GZIP" if FileStorage.is_gzipped(file_path) else "" + options = ", COMPRESSION GZIP" if FileStorage.is_gzipped(self._file_path) else "" lock = None else: - raise ValueError(file_path) + raise ValueError(self._file_path) with maybe_context(lock): - with sql_client.begin_transaction(): - sql_client.execute_sql( - f"COPY {qualified_table_name} FROM '{file_path}' ( FORMAT" + with self._sql_client.begin_transaction(): + self._sql_client.execute_sql( + f"COPY {qualified_table_name} FROM '{self._file_path}' ( FORMAT" f" {source_format} {options});" ) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class DuckDbClient(InsertValuesJobClient): def __init__( @@ -168,10 +166,12 @@ def __init__( self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = DuckDbTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: - job = DuckDbCopyJob(table["name"], file_path, self.sql_client) + job = DuckDbCopyJob(file_path) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index a9fdb1f47d..7bc1d9e943 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -21,13 +21,29 @@ class DummyClientConfiguration(DestinationClientConfiguration): loader_file_format: TLoaderFileFormat = "jsonl" fail_schema_update: bool = False fail_prob: float = 0.0 + """probability of terminal fail""" retry_prob: float = 0.0 + """probability of job retry""" completed_prob: float = 0.0 + """probablibitly of successful job completion""" exception_prob: float = 0.0 - """probability of exception when checking job status""" + """probability of exception transient exception when running job""" timeout: float = 10.0 - fail_in_init: bool = True + """timeout time""" + fail_terminally_in_init: bool = False + """raise terminal exception in job init""" + fail_transiently_in_init: bool = False + """raise transient exception in job init""" + # new jobs workflows create_followup_jobs: bool = False - + """create followup job for individual jobs""" + fail_followup_job_creation: bool = False + """Raise generic exception during followupjob creation""" + fail_table_chain_followup_job_creation: bool = False + """Raise generic exception during tablechain followupjob creation""" + create_followup_table_chain_sql_jobs: bool = False + """create a table chain merge job which is guaranteed to fail""" + create_followup_table_chain_reference_jobs: bool = False + """create table chain jobs which succeed """ credentials: DummyClientCredentials = None diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index c41b7dca61..7d406c969f 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -12,7 +12,8 @@ Iterable, 
List, ) - +import os +import time from dlt.common.pendulum import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.storages import FileStorage @@ -23,79 +24,88 @@ DestinationTransientException, ) from dlt.common.destination.reference import ( + HasFollowupJobs, FollowupJob, - NewLoadJob, SupportsStagingDestination, TLoadJobState, - LoadJob, + RunnableLoadJob, JobClientBase, WithStagingDataset, + LoadJob, ) +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.exceptions import ( LoadJobNotExistsException, LoadJobInvalidStateTransitionException, ) from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob -class LoadDummyBaseJob(LoadJob): +class LoadDummyBaseJob(RunnableLoadJob): def __init__(self, file_name: str, config: DummyClientConfiguration) -> None: + super().__init__(file_name) self.config = copy(config) - self._status: TLoadJobState = "running" - self._exception: str = None self.start_time: float = pendulum.now().timestamp() - super().__init__(file_name) - if config.fail_in_init: - s = self.state() - if s == "failed": - raise DestinationTerminalException(self._exception) - if s == "retry": - raise DestinationTransientException(self._exception) - - def state(self) -> TLoadJobState: - # this should poll the server for a job status, here we simulate various outcomes - if self._status == "running": + + if self.config.fail_terminally_in_init: + raise DestinationTerminalException(self._exception) + if self.config.fail_transiently_in_init: + raise Exception(self._exception) + + def run(self) -> None: + while True: + # simulate generic exception (equals retry) c_r = random.random() if self.config.exception_prob >= c_r: - raise DestinationTransientException("Dummy job status raised exception") + # this will make the job go to a retry state with a generic exception + raise Exception("Dummy job status raised exception") + + # timeout condition (terminal) n = pendulum.now().timestamp() if n - self.start_time > self.config.timeout: - self._status = "failed" - self._exception = "failed due to timeout" - else: - c_r = random.random() - if self.config.completed_prob >= c_r: - self._status = "completed" - else: - c_r = random.random() - if self.config.retry_prob >= c_r: - self._status = "retry" - self._exception = "a random retry occured" - else: - c_r = random.random() - if self.config.fail_prob >= c_r: - self._status = "failed" - self._exception = "a random fail occured" - - return self._status - - def exception(self) -> str: - # this will typically call server for error messages - return self._exception - - def retry(self) -> None: - if self._status != "retry": - raise LoadJobInvalidStateTransitionException(self._status, "retry") - self._status = "retry" - - -class LoadDummyJob(LoadDummyBaseJob, FollowupJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + # this will make the job go to a failed state + raise DestinationTerminalException("failed due to timeout") + + # success + c_r = random.random() + if self.config.completed_prob >= c_r: + # this will make the run function exit and the job go to a completed state + break + + # retry prob + c_r = random.random() + if self.config.retry_prob >= c_r: + # this will make the job go to a retry state + raise DestinationTransientException("a random retry occured") + + # fail prob + c_r = random.random() + if
self.config.fail_prob >= c_r: + # this will make the job go to a failed state + raise DestinationTerminalException("a random fail occured") + + time.sleep(0.1) + + +class DummyFollowupJob(ReferenceFollowupJob): + def __init__( + self, original_file_name: str, remote_paths: List[str], config: DummyClientConfiguration + ) -> None: + self.config = config + if config.fail_followup_job_creation: + raise Exception("Failed to create followup job") + super().__init__(original_file_name=original_file_name, remote_paths=remote_paths) + + +class LoadDummyJob(LoadDummyBaseJob, HasFollowupJobs): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: if self.config.create_followup_jobs and final_state == "completed": - new_job = NewReferenceJob( - file_name=self.file_name(), status="running", remote_path=self._file_name + new_job = DummyFollowupJob( + original_file_name=self.file_name(), + remote_paths=[self._file_name], + config=self.config, ) CREATED_FOLLOWUP_JOBS[new_job.job_id()] = new_job return [new_job] @@ -103,7 +113,9 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: JOBS: Dict[str, LoadDummyBaseJob] = {} -CREATED_FOLLOWUP_JOBS: Dict[str, NewLoadJob] = {} +CREATED_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} +CREATED_TABLE_CHAIN_FOLLOWUP_JOBS: Dict[str, FollowupJob] = {} +RETRIED_JOBS: Dict[str, LoadDummyBaseJob] = {} class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): @@ -140,31 +152,41 @@ def update_stored_schema( ) return applied_update - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: job_id = FileStorage.get_file_name_from_file_path(file_path) - file_name = FileStorage.get_file_name_from_file_path(file_path) + if restore and job_id not in JOBS: + raise LoadJobNotExistsException(job_id) # return existing job if already there if job_id not in JOBS: - JOBS[job_id] = self._create_job(file_name) + JOBS[job_id] = self._create_job(file_path) else: job = JOBS[job_id] - if job.state == "retry": - job.retry() + # update config of existing job in case it was changed in tests + job.config = self.config + RETRIED_JOBS[job_id] = job return JOBS[job_id] - def restore_file_load(self, file_path: str) -> LoadJob: - job_id = FileStorage.get_file_name_from_file_path(file_path) - if job_id not in JOBS: - raise LoadJobNotExistsException(job_id) - return JOBS[job_id] - def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" + + # if sql job follow up is configured we schedule a merge job that will always fail + if self.config.fail_table_chain_followup_job_creation: + raise Exception("Failed to create table chain followup job") + if self.config.create_followup_table_chain_sql_jobs: + return [SqlMergeFollowupJob.from_table_chain(table_chain, self)] # type: ignore + if self.config.create_followup_table_chain_reference_jobs: + table_job_paths = [job.file_path for job in completed_table_chain_jobs] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + job = ReferenceFollowupJob(file_name, table_job_paths) + CREATED_TABLE_CHAIN_FOLLOWUP_JOBS[job.job_id()] = job + return [job] return [] def
complete_load(self, load_id: str) -> None: @@ -190,7 +212,7 @@ def __exit__( pass def _create_job(self, job_id: str) -> LoadDummyBaseJob: - if NewReferenceJob.is_reference_job(job_id): + if ReferenceFollowupJob.is_reference_job(job_id): return LoadDummyBaseJob(job_id, config=self.config) else: return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/impl/dummy/factory.py b/dlt/destinations/impl/dummy/factory.py index c2792fc432..8cf0408ec1 100644 --- a/dlt/destinations/impl/dummy/factory.py +++ b/dlt/destinations/impl/dummy/factory.py @@ -60,7 +60,9 @@ def adjust_capabilities( ) -> DestinationCapabilitiesContext: caps = super().adjust_capabilities(caps, config, naming) additional_formats: t.List[TLoaderFileFormat] = ( - ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] + ["reference"] + if (config.create_followup_jobs or config.create_followup_table_chain_reference_jobs) + else [] ) caps.preferred_loader_file_format = config.loader_file_format caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index 31b61c6cb1..94e46c770b 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -7,6 +7,7 @@ from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.impl.filesystem.typing import TCurrentDateTime, TExtraPlaceholders +from dlt.common.normalizers.naming.naming import NamingConvention if t.TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient @@ -28,7 +29,7 @@ class filesystem(Destination[FilesystemDestinationClientConfiguration, "Filesyst spec = FilesystemDestinationClientConfiguration def _raw_capabilities(self) -> DestinationCapabilitiesContext: - return DestinationCapabilitiesContext.generic_capabilities( + caps = DestinationCapabilitiesContext.generic_capabilities( preferred_loader_file_format="jsonl", loader_file_format_adapter=loader_file_format_adapter, supported_table_formats=["delta"], @@ -37,6 +38,10 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: # loader file format) supported_merge_strategies=["upsert"], ) + caps.supported_loader_file_formats = list(caps.supported_loader_file_formats) + [ + "reference" + ] + return caps @property def client_class(self) -> t.Type["FilesystemClient"]: diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index ef4702b17d..4f57d25389 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -17,28 +17,28 @@ from dlt.common.storages import FileStorage, fsspec_from_config from dlt.common.storages.load_package import ( LoadJobInfo, - ParsedLoadJobFileName, TPipelineStateDoc, load_package as current_load_package, ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - NewLoadJob, + FollowupJob, TLoadJobState, - LoadJob, + RunnableLoadJob, JobClientBase, - FollowupJob, + HasFollowupJobs, WithStagingDataset, WithStateSync, StorageSchemaInfo, StateInfo, - DoNothingJob, - DoNothingFollowupJob, + LoadJob, ) from dlt.common.destination.exceptions import DestinationUndefinedEntity -from dlt.destinations.job_impl import EmptyLoadJob, NewReferenceJob +from dlt.destinations.job_impl import ( + ReferenceFollowupJob, + FinalizedLoadJob, +) from 
dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations import path_utils from dlt.destinations.fs_client import FSClientBase @@ -46,31 +46,27 @@ FILENAME_SEPARATOR = "__" -class LoadFilesystemJob(LoadJob): +class FilesystemLoadJob(RunnableLoadJob): def __init__( self, - client: "FilesystemClient", - local_path: str, - load_id: str, - table: TTableSchema, + file_path: str, ) -> None: - self.client = client - self.table = table - self.is_local_filesystem = client.config.protocol == "file" + super().__init__(file_path) + self._job_client: FilesystemClient = None + + def run(self) -> None: # pick local filesystem pathlib or posix for buckets + self.is_local_filesystem = self._job_client.config.protocol == "file" self.pathlib = os.path if self.is_local_filesystem else posixpath - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.destination_file_name = path_utils.create_path( - client.config.layout, - file_name, - client.schema.name, - load_id, - current_datetime=client.config.current_datetime, + self._job_client.config.layout, + self._file_name, + self._job_client.schema.name, + self._load_id, + current_datetime=self._job_client.config.current_datetime, load_package_timestamp=dlt.current.load_package()["state"]["created_at"], - extra_placeholders=client.config.extra_placeholders, + extra_placeholders=self._job_client.config.extra_placeholders, ) # We would like to avoid failing for local filesystem where # deeply nested directory will not exist before writing a file. @@ -79,48 +75,26 @@ def __init__( # remote_path = f"{client.config.protocol}://{posixpath.join(dataset_path, destination_file_name)}" remote_path = self.make_remote_path() if self.is_local_filesystem: - client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True) - client.fs_client.put_file(local_path, remote_path) + self._job_client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True) + self._job_client.fs_client.put_file(self._file_path, remote_path) def make_remote_path(self) -> str: """Returns path on the remote filesystem to which copy the file, without scheme. 
For local filesystem a native path is used""" # path.join does not normalize separators and available # normalization functions are very invasive and may string the trailing separator return self.pathlib.join( # type: ignore[no-any-return] - self.client.dataset_path, + self._job_client.dataset_path, path_utils.normalize_path_sep(self.pathlib, self.destination_file_name), ) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - - -class DeltaLoadFilesystemJob(NewReferenceJob): - def __init__( - self, - client: "FilesystemClient", - table: TTableSchema, - table_jobs: Sequence[LoadJobInfo], - ) -> None: - self.client = client - self.table = table - self.table_jobs = table_jobs - ref_file_name = ParsedLoadJobFileName( - table["name"], ParsedLoadJobFileName.new_file_id(), 0, "reference" - ).file_name() +class DeltaLoadFilesystemJob(FilesystemLoadJob): + def __init__(self, file_path: str) -> None: super().__init__( - file_name=ref_file_name, - status="running", - remote_path=self.client.make_remote_uri(self.make_remote_path()), + file_path=file_path, ) - self.write() - - def write(self) -> None: + def run(self) -> None: from dlt.common.libs.pyarrow import pyarrow as pa from dlt.common.libs.deltalake import ( DeltaTable, @@ -137,12 +111,14 @@ def write(self) -> None: ) # create Arrow dataset from Parquet files - file_paths = [job.file_path for job in self.table_jobs] + file_paths = ReferenceFollowupJob.resolve_references(self._file_path) arrow_ds = pa.dataset.dataset(file_paths) # create Delta table object - dt_path = self.client.make_remote_uri(self.make_remote_path()) - storage_options = _deltalake_storage_options(self.client.config) + dt_path = self._job_client.make_remote_uri( + self._job_client.get_table_dir(self.load_table_name) + ) + storage_options = _deltalake_storage_options(self._job_client.config) dt = try_get_deltatable(dt_path, storage_options=storage_options) # explicitly check if there is data @@ -159,15 +135,15 @@ def write(self) -> None: arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader - if self.table["write_disposition"] == "merge" and dt is not None: - assert self.table["x-merge-strategy"] in self.client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] + if self._load_table["write_disposition"] == "merge" and dt is not None: + assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] - if self.table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] - if "parent" in self.table: - unique_column = get_first_column_name_with_prop(self.table, "unique") + if self._load_table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item] + if "parent" in self._load_table: + unique_column = get_first_column_name_with_prop(self._load_table, "unique") predicate = f"target.{unique_column} = source.{unique_column}" else: - primary_keys = get_columns_names_with_prop(self.table, "primary_key") + primary_keys = get_columns_names_with_prop(self._load_table, "primary_key") predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys]) qry = ( @@ -187,26 +163,21 @@ def write(self) -> None: write_delta_table( table_or_uri=dt_path if dt is None else dt, data=arrow_rbr, - write_disposition=self.table["write_disposition"], + write_disposition=self._load_table["write_disposition"], storage_options=storage_options, ) - def make_remote_path(self) -> str: - # directory path, not file path - return 
self.client.get_table_dir(self.table["name"]) - - def state(self) -> TLoadJobState: - return "completed" - -class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): - def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: +class FilesystemLoadJobWithFollowup(HasFollowupJobs, FilesystemLoadJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJob]: jobs = super().create_followup_jobs(final_state) - if final_state == "completed": - ref_job = NewReferenceJob( - file_name=self.file_name(), - status="running", - remote_path=self.client.make_remote_uri(self.make_remote_path()), + if self._load_table.get("table_format") == "delta": + # delta table jobs only require table chain followup jobs + pass + elif final_state == "completed": + ref_job = ReferenceFollowupJob( + original_file_name=self.file_name(), + remote_paths=[self._job_client.make_remote_uri(self.make_remote_path())], ) jobs.append(ref_job) return jobs @@ -287,7 +258,7 @@ def drop_tables(self, *tables: str, delete_schema: bool = True) -> None: self._delete_file(filename) def truncate_tables(self, table_names: List[str]) -> None: - """Truncate a set of tables with given `table_names`""" + """Truncate a set of regular tables with given `table_names`""" table_dirs = set(self.get_table_dirs(table_names)) table_prefixes = [self.get_table_prefix(t) for t in table_names] for table_dir in table_dirs: @@ -383,22 +354,25 @@ def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[ def is_storage_initialized(self) -> bool: return self.fs_client.exists(self.pathlib.join(self.dataset_path, INIT_FILE_NAME)) # type: ignore[no-any-return] - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: # skip the state table, we create a jsonl file in the complete_load step # this does not apply to scenarios where we are using filesystem as staging # where we want to load the state the regular way if table["name"] == self.schema.state_table_name and not self.config.as_staging: - return DoNothingJob(file_path) + return FinalizedLoadJob(file_path) if table.get("table_format") == "delta": import dlt.common.libs.deltalake # assert dependencies are installed - return DoNothingFollowupJob(file_path) - - cls = FollowupFilesystemJob if self.config.as_staging else LoadFilesystemJob - return cls(self, file_path, load_id, table) + # a reference job for a delta table indicates a table chain followup job + if ReferenceFollowupJob.is_reference_job(file_path): + return DeltaLoadFilesystemJob(file_path) + # otherwise just continue + return FilesystemLoadJobWithFollowup(file_path) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") + cls = FilesystemLoadJobWithFollowup if self.config.as_staging else FilesystemLoadJob + return cls(file_path) def make_remote_uri(self, remote_path: str) -> str: """Returns uri to the remote filesystem to which copy the file""" @@ -601,26 +575,18 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: - def get_table_jobs( - table_jobs: Sequence[LoadJobInfo], table_name: str - ) -> Sequence[LoadJobInfo]: - return [job for job in table_jobs if job.job_file_info.table_name == table_name] - + ) -> List[FollowupJob]: assert 
completed_table_chain_jobs is not None jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs ) - table_format = table_chain[0].get("table_format") - if table_format == "delta": - delta_jobs = [ - DeltaLoadFilesystemJob( - self, - table=self.prepare_load_table(table["name"]), - table_jobs=get_table_jobs(completed_table_chain_jobs, table["name"]), - ) - for table in table_chain - ] - jobs.extend(delta_jobs) - + if table_chain[0].get("table_format") == "delta": + for table in table_chain: + table_job_paths = [ + job.file_path + for job in completed_table_chain_jobs + if job.job_file_info.table_name == table["name"] + ] + file_name = FileStorage.get_file_name_from_file_path(table_job_paths[0]) + jobs.append(ReferenceFollowupJob(file_name, table_job_paths)) return jobs diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8265e50fbf..78a37952b9 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -34,10 +34,11 @@ from dlt.common.destination.reference import ( JobClientBase, WithStateSync, - LoadJob, + RunnableLoadJob, StorageSchemaInfo, StateInfo, TLoadJobState, + LoadJob, ) from dlt.common.pendulum import timedelta from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -69,7 +70,7 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: @@ -686,17 +687,12 @@ def complete_load(self, load_id: str) -> None: write_disposition=write_disposition, ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return LoadLanceDBJob( - self.schema, - table, - file_path, + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return LanceDBLoadJob( + file_path=file_path, type_mapper=self.type_mapper, - db_client=self.db_client, - client_config=self.config, model_func=self.model_func, fq_table_name=self.make_qualified_table_name(table["name"]), ) @@ -705,66 +701,56 @@ def table_exists(self, table_name: str) -> bool: return table_name in self.db_client.table_names() -class LoadLanceDBJob(LoadJob): +class LanceDBLoadJob(RunnableLoadJob): arrow_schema: TArrowSchema def __init__( self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, + file_path: str, type_mapper: LanceDBTypeMapper, - db_client: DBConnection, - client_config: LanceDBClientConfiguration, model_func: TextEmbeddingFunction, fq_table_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.schema: Schema = schema - self.table_schema: TTableSchema = table_schema - self.db_client: DBConnection = db_client - self.type_mapper: TypeMapper = type_mapper - self.table_name: str = table_schema["name"] - self.fq_table_name: str = fq_table_name - self.unique_identifiers: Sequence[str] = list_merge_identifiers(table_schema) - self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) - self.embedding_model_func: TextEmbeddingFunction = model_func - self.embedding_model_dimensions: int = client_config.embedding_model_dimensions - 
self.id_field_name: str = client_config.id_field_name - self.write_disposition: TWriteDisposition = cast( - TWriteDisposition, self.table_schema.get("write_disposition", "append") + super().__init__(file_path) + self._type_mapper: TypeMapper = type_mapper + self._fq_table_name: str = fq_table_name + self._model_func = model_func + self._job_client: "LanceDBClient" = None + + def run(self) -> None: + self._db_client: DBConnection = self._job_client.db_client + self._embedding_model_func: TextEmbeddingFunction = self._model_func + self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions + self._id_field_name: str = self._job_client.config.id_field_name + + unique_identifiers: Sequence[str] = list_merge_identifiers(self._load_table) + write_disposition: TWriteDisposition = cast( + TWriteDisposition, self._load_table.get("write_disposition", "append") ) - with FileStorage.open_zipsafe_ro(local_path) as f: + with FileStorage.open_zipsafe_ro(self._file_path) as f: records: List[DictStrAny] = [json.loads(line) for line in f] - if self.table_schema not in self.schema.dlt_tables(): + if self._load_table not in self._schema.dlt_tables(): for record in records: # Add reserved ID fields. uuid_id = ( - generate_uuid(record, self.unique_identifiers, self.fq_table_name) - if self.unique_identifiers + generate_uuid(record, unique_identifiers, self._fq_table_name) + if unique_identifiers else str(uuid.uuid4()) ) - record.update({self.id_field_name: uuid_id}) + record.update({self._id_field_name: uuid_id}) # LanceDB expects all fields in the target arrow table to be present in the data payload. # We add and set these missing fields, that are fields not present in the target schema, to NULL. - missing_fields = set(self.table_schema["columns"]) - set(record) + missing_fields = set(self._load_table["columns"]) - set(record) for field in missing_fields: record[field] = None upload_batch( records, - db_client=db_client, - table_name=self.fq_table_name, - write_disposition=self.write_disposition, - id_field_name=self.id_field_name, + db_client=self._db_client, + table_name=self._fq_table_name, + write_disposition=write_disposition, + id_field_name=self._id_field_name, ) - - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index ec4a54d6f7..a67423a873 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,12 +1,12 @@ from typing import Dict, Optional, Sequence, List, Any from dlt.common.exceptions import TerminalValueError -from dlt.common.destination.reference import NewLoadJob +from dlt.common.destination.reference import FollowupJob from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -85,7 +85,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class MsSqlStagingCopyJob(SqlStagingCopyJob): +class MsSqlStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, @@ -110,7 +110,7 @@ def generate_sql( return sql -class 
MsSqlMergeJob(SqlMergeJob): +class MsSqlMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -127,7 +127,7 @@ def gen_key_table_clauses( f" {staging_root_table_name} WHERE" f" {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})" ] - return SqlMergeJob.gen_key_table_clauses( + return SqlMergeFollowupJob.gen_key_table_clauses( root_table_name, staging_root_table_name, key_clauses, for_delete ) @@ -137,7 +137,7 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: - name = SqlMergeJob._new_temp_table_name(name_prefix, sql_client) + name = SqlMergeFollowupJob._new_temp_table_name(name_prefix, sql_client) return "#" + name @@ -160,7 +160,7 @@ def __init__( self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} self.type_mapper = MsSqlTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( @@ -189,7 +189,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self.config.replace_strategy == "staging-optimized": return [MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index f47549fc4f..5ae5f27a6e 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -6,14 +6,20 @@ DestinationInvalidFileFormat, DestinationTerminalException, ) -from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState +from dlt.common.destination.reference import ( + HasFollowupJobs, + RunnableLoadJob, + FollowupJob, + LoadJob, + TLoadJobState, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration @@ -85,7 +91,7 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class PostgresStagingCopyJob(SqlStagingCopyJob): +class PostgresStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, @@ -110,21 +116,24 @@ def generate_sql( return sql -class PostgresCsvCopyJob(LoadJob, FollowupJob): - def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient") -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - config = client.config - sql_client = client.sql_client - csv_format = config.csv_format or CsvFormatConfiguration() - 
table_name = table["name"] +class PostgresCsvCopyJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: PostgresClient = None + + def run(self) -> None: + self._config = self._job_client.config + sql_client = self._job_client.sql_client + csv_format = self._config.csv_format or CsvFormatConfiguration() + table_name = self.load_table_name sep = csv_format.delimiter if csv_format.on_error_continue: logger.warning( - f"When processing {file_path} on table {table_name} Postgres csv reader does not" - " support on_error_continue" + f"When processing {self._file_path} on table {table_name} Postgres csv reader does" + " not support on_error_continue" ) - with FileStorage.open_zipsafe_ro(file_path, "rb") as f: + with FileStorage.open_zipsafe_ro(self._file_path, "rb") as f: if csv_format.include_header: # all headers in first line headers_row: str = f.readline().decode(csv_format.encoding).strip() @@ -132,12 +141,12 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" else: # read first row to figure out the headers split_first_row: str = f.readline().decode(csv_format.encoding).strip().split(sep) - split_headers = list(client.schema.get_table_columns(table_name).keys()) + split_headers = list(self._job_client.schema.get_table_columns(table_name).keys()) if len(split_first_row) > len(split_headers): raise DestinationInvalidFileFormat( "postgres", "csv", - file_path, + self._file_path, f"First row {split_first_row} has more rows than columns {split_headers} in" f" table {table_name}", ) @@ -158,7 +167,7 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" split_columns = [] # detect columns with NULL to use in FORCE NULL # detect headers that are not in columns - for col in client.schema.get_table_columns(table_name).values(): + for col in self._job_client.schema.get_table_columns(table_name).values(): norm_col = sql_client.escape_column_name(col["name"], escape=True) split_columns.append(norm_col) if norm_col in split_headers and col.get("nullable", True): @@ -168,7 +177,7 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" raise DestinationInvalidFileFormat( "postgres", "csv", - file_path, + self._file_path, f"Following headers {split_unknown_headers} cannot be matched to columns" f" {split_columns} of table {table_name}.", ) @@ -196,12 +205,6 @@ def __init__(self, table: TTableSchema, file_path: str, client: "PostgresClient" with sql_client.native_connection.cursor() as cursor: cursor.copy_expert(copy_sql, f, size=8192) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class PostgresClient(InsertValuesJobClient): def __init__( @@ -222,10 +225,12 @@ def __init__( self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = PostgresTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job and file_path.endswith("csv"): - job = PostgresCsvCopyJob(table, file_path, self) + job = PostgresCsvCopyJob(file_path) return job def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: 
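The hunks above repeat the pattern used throughout this diff: a load job's constructor now receives only the file path, and the actual work moves into run(), where the loader has already injected the job client. A minimal, self-contained sketch of that pattern follows; all names are illustrative only and are not dlt's real classes or signatures.

from abc import ABC, abstractmethod
from typing import Optional


class SketchRunnableJob(ABC):
    """Illustrative stand-in for the runnable-job base class: no I/O in __init__."""

    def __init__(self, file_path: str) -> None:
        self._file_path = file_path                 # only remember where the job file lives
        self._job_client: Optional[object] = None   # injected by the loader before run()

    @abstractmethod
    def run(self) -> None:
        """Do the actual load; raising here maps to the 'failed' or 'retry' state."""


class SketchCopyJob(SketchRunnableJob):
    def run(self) -> None:
        # a real destination would go through its sql client here;
        # this sketch only shows the shape of the statement it would issue
        print(f"COPY my_table FROM '{self._file_path}'")


job = SketchCopyJob("/tmp/my_table.1234.0.parquet")
job.run()  # the loader calls run() on a worker and turns raised exceptions into job states

Raising inside run() is what drives a job to the "failed" or "retry" state, which is exactly what the dummy destination above simulates with its configurable probabilities.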
@@ -241,7 +246,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: if self.config.replace_strategy == "staging-optimized": return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 28d7388701..65019c6626 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -13,12 +13,19 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + RunnableLoadJob, + JobClientBase, + WithStateSync, + LoadJob, +) from dlt.common.destination.exceptions import DestinationUndefinedEntity + from dlt.common.storages import FileStorage from dlt.common.time import precise_time -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.utils import get_pipeline_state_query_columns @@ -30,49 +37,49 @@ from qdrant_client.http.exceptions import UnexpectedResponse -class LoadQdrantJob(LoadJob): +class QDrantLoadJob(RunnableLoadJob): def __init__( self, - table_schema: TTableSchema, - local_path: str, - db_client: QC, + file_path: str, client_config: QdrantClientConfiguration, collection_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.db_client = db_client - self.collection_name = collection_name - self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) - self.unique_identifiers = self._list_unique_identifiers(table_schema) - self.config = client_config - - with FileStorage.open_zipsafe_ro(local_path) as f: + super().__init__(file_path) + self._collection_name = collection_name + self._config = client_config + self._job_client: "QdrantClient" = None + + def run(self) -> None: + embedding_fields = get_columns_names_with_prop(self._load_table, VECTORIZE_HINT) + unique_identifiers = self._list_unique_identifiers(self._load_table) + with FileStorage.open_zipsafe_ro(self._file_path) as f: ids: List[str] docs, payloads, ids = [], [], [] for line in f: data = json.loads(line) point_id = ( - self._generate_uuid(data, self.unique_identifiers, self.collection_name) - if self.unique_identifiers + self._generate_uuid(data, unique_identifiers, self._collection_name) + if unique_identifiers else str(uuid.uuid4()) ) payloads.append(data) ids.append(point_id) - if len(self.embedding_fields) > 0: - docs.append(self._get_embedding_doc(data)) + if len(embedding_fields) > 0: + docs.append(self._get_embedding_doc(data, embedding_fields)) - if len(self.embedding_fields) > 0: - embedding_model = db_client._get_or_init_model(db_client.embedding_model_name) + if len(embedding_fields) > 0: + embedding_model = self._job_client.db_client._get_or_init_model( + self._job_client.db_client.embedding_model_name + ) embeddings = list( embedding_model.embed( docs, - batch_size=self.config.embedding_batch_size, - parallel=self.config.embedding_parallelism, + 
batch_size=self._config.embedding_batch_size, + parallel=self._config.embedding_parallelism, ) ) - vector_name = db_client.get_vector_field_name() + vector_name = self._job_client.db_client.get_vector_field_name() embeddings = [{vector_name: embedding.tolist()} for embedding in embeddings] else: embeddings = [{}] * len(ids) @@ -80,7 +87,7 @@ def __init__( self._upload_data(vectors=embeddings, ids=ids, payloads=payloads) - def _get_embedding_doc(self, data: Dict[str, Any]) -> str: + def _get_embedding_doc(self, data: Dict[str, Any], embedding_fields: List[str]) -> str: """Returns a document to generate embeddings for. Args: @@ -89,7 +96,7 @@ def _get_embedding_doc(self, data: Dict[str, Any]) -> str: Returns: str: A concatenated string of all the fields intended for embedding. """ - doc = "\n".join(str(data[key]) for key in self.embedding_fields) + doc = "\n".join(str(data[key]) for key in embedding_fields) return doc def _list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: @@ -117,14 +124,14 @@ def _upload_data( vectors (Iterable[Any]): Embeddings to be uploaded to the collection payloads (Iterable[Any]): Payloads to be uploaded to the collection """ - self.db_client.upload_collection( - self.collection_name, + self._job_client.db_client.upload_collection( + self._collection_name, ids=ids, payload=payloads, vectors=vectors, - parallel=self.config.upload_parallelism, - batch_size=self.config.upload_batch_size, - max_retries=self.config.upload_max_retries, + parallel=self._config.upload_parallelism, + batch_size=self._config.upload_batch_size, + max_retries=self._config.upload_max_retries, ) def _generate_uuid( @@ -143,12 +150,6 @@ def _generate_uuid( data_id = "_".join(str(data[key]) for key in unique_identifiers) return str(uuid.uuid5(uuid.NAMESPACE_DNS, collection_name + data_id)) - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class QdrantClient(JobClientBase, WithStateSync): """Qdrant Destination Handler""" @@ -438,18 +439,15 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI return None raise - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - return LoadQdrantJob( - table, + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + return QDrantLoadJob( file_path, - db_client=self.db_client, client_config=self.config, collection_name=self._make_qualified_collection_name(table["name"]), ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def complete_load(self, load_id: str) -> None: values = [load_id, self.schema.name, 0, str(pendulum.now()), self.schema.version_hash] assert len(values) == len(self.loads_collection_properties) diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 8eacc76d11..81abd57803 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -14,9 +14,10 @@ from dlt.common.destination.reference import ( - NewLoadJob, + FollowupJob, CredentialsConfiguration, SupportsStagingDestination, + LoadJob, ) from dlt.common.data_types import TDataType from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -27,12 +28,12 @@ from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import 
InsertValuesJobClient -from dlt.destinations.sql_jobs import SqlMergeJob +from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException -from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob, LoadJob +from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -123,16 +124,16 @@ def _maybe_make_terminal_exception_from_data_error( class RedshiftCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None, staging_iam_role: str = None, ) -> None: + super().__init__(file_path, staging_credentials) self._staging_iam_role = staging_iam_role - super().__init__(table, file_path, sql_client, staging_credentials) + self._job_client: "RedshiftClient" = None - def execute(self, table: TTableSchema, bucket_path: str) -> None: + def run(self) -> None: + self._sql_client = self._job_client.sql_client # we assume s3 credentials where provided for the staging credentials = "" if self._staging_iam_role: @@ -148,11 +149,11 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: ) # get format - ext = os.path.splitext(bucket_path)[1][1:] + ext = os.path.splitext(self._bucket_path)[1][1:] file_type = "" dateformat = "" compression = "" - if table_schema_has_type(table, "time"): + if table_schema_has_type(self._load_table, "time"): raise LoadJobTerminalException( self.file_name(), f"Redshift cannot load TIME columns from {ext} files. Switch to direct INSERT file" @@ -160,7 +161,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: " `datetime.datetime`", ) if ext == "jsonl": - if table_schema_has_type(table, "binary"): + if table_schema_has_type(self._load_table, "binary"): raise LoadJobTerminalException( self.file_name(), "Redshift cannot load VARBYTE columns from json files. Switch to parquet to" @@ -170,7 +171,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: dateformat = "dateformat 'auto' timeformat 'auto'" compression = "GZIP" elif ext == "parquet": - if table_schema_has_type_with_precision(table, "binary"): + if table_schema_has_type_with_precision(self._load_table, "binary"): raise LoadJobTerminalException( self.file_name(), f"Redshift cannot load fixed width VARBYTE columns from {ext} files. Switch to" @@ -179,7 +180,7 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: file_type = "PARQUET" # if table contains complex types then SUPER field will be used. 
# https://docs.aws.amazon.com/redshift/latest/dg/ingest-super.html - if table_schema_has_type(table, "complex"): + if table_schema_has_type(self._load_table, "complex"): file_type += " SERIALIZETOJSON" else: raise ValueError(f"Unsupported file type {ext} for Redshift.") @@ -187,19 +188,15 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: with self._sql_client.begin_transaction(): # TODO: if we ever support csv here remember to add column names to COPY self._sql_client.execute_sql(f""" - COPY {self._sql_client.make_qualified_table_name(table['name'])} - FROM '{bucket_path}' + COPY {self._sql_client.make_qualified_table_name(self.load_table_name)} + FROM '{self._bucket_path}' {file_type} {dateformat} {compression} {credentials} MAXERROR 0;""") - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - -class RedshiftMergeJob(SqlMergeJob): +class RedshiftMergeJob(SqlMergeFollowupJob): @classmethod def gen_key_table_clauses( cls, @@ -218,7 +215,7 @@ def gen_key_table_clauses( f" {staging_root_table_name} WHERE" f" {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})" ] - return SqlMergeJob.gen_key_table_clauses( + return SqlMergeFollowupJob.gen_key_table_clauses( root_table_name, staging_root_table_name, key_clauses, for_delete ) @@ -241,7 +238,7 @@ def __init__( self.config: RedshiftClientConfiguration = config self.type_mapper = RedshiftTypeMapper(self.capabilities) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: @@ -255,17 +252,17 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" - job = super().start_file_load(table, file_path, load_id) + job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert NewReferenceJob.is_reference_job( + assert ReferenceFollowupJob.is_reference_job( file_path ), "Redshift must use staging to load files" job = RedshiftCopyFileLoadJob( - table, file_path, - self.sql_client, staging_credentials=self.config.staging_config.credentials, staging_iam_role=self.config.staging_iam_role, ) diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index bf175ba911..904b524791 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -4,10 +4,9 @@ from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( - FollowupJob, - NewLoadJob, - TLoadJobState, + HasFollowupJobs, LoadJob, + RunnableLoadJob, CredentialsConfiguration, SupportsStagingDestination, ) @@ -24,13 +23,13 @@ from dlt.common.storages.load_package import ParsedLoadJobFileName from 
dlt.common.typing import TLoaderFileFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.type_mapping import TypeMapper @@ -79,63 +78,68 @@ def from_db_type( return super().from_db_type(db_type, precision, scale) -class SnowflakeLoadJob(LoadJob, FollowupJob): +class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, file_path: str, - table_name: str, - load_id: str, - client: SnowflakeSqlClient, config: SnowflakeClientConfiguration, stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(file_path) - super().__init__(file_name) + super().__init__(file_path) + self._keep_staged_files = keep_staged_files + self._staging_credentials = staging_credentials + self._config = config + self._stage_name = stage_name + self._job_client: "SnowflakeClient" = None + + def run(self) -> None: + self._sql_client = self._job_client.sql_client + # resolve reference - is_local_file = not NewReferenceJob.is_reference_job(file_path) - file_url = file_path if is_local_file else NewReferenceJob.resolve_reference(file_path) + is_local_file = not ReferenceFollowupJob.is_reference_job(self._file_path) + file_url = ( + self._file_path + if is_local_file + else ReferenceFollowupJob.resolve_reference(self._file_path) + ) # take file name file_name = FileStorage.get_file_name_from_file_path(file_url) file_format = file_name.rsplit(".", 1)[-1] - qualified_table_name = client.make_qualified_table_name(table_name) + qualified_table_name = self._sql_client.make_qualified_table_name(self.load_table_name) # this means we have a local file stage_file_path: str = "" if is_local_file: - if not stage_name: + if not self._stage_name: # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name("%" + table_name) - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' + self._stage_name = self._sql_client.make_qualified_table_name( + "%" + self.load_table_name + ) + stage_file_path = f'@{self._stage_name}/"{self._load_id}"/{file_name}' copy_sql = self.gen_copy_sql( file_url, qualified_table_name, file_format, # type: ignore[arg-type] - client.capabilities.generates_case_sensitive_identifiers(), - stage_name, + self._sql_client.capabilities.generates_case_sensitive_identifiers(), + self._stage_name, stage_file_path, - staging_credentials, - config.csv_format, + self._staging_credentials, + self._config.csv_format, ) - with client.begin_transaction(): + with self._sql_client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy if is_local_file: - client.execute_sql( - f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' - " AUTO_COMPRESS = FALSE" + self._sql_client.execute_sql( + f'PUT file://{self._file_path} @{self._stage_name}/"{self._load_id}" OVERWRITE' + " = TRUE, AUTO_COMPRESS = FALSE" ) - 
client.execute_sql(copy_sql) - if stage_file_path and not keep_staged_files: - client.execute_sql(f"REMOVE {stage_file_path}") - - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() + self._sql_client.execute_sql(copy_sql) + if stage_file_path and not self._keep_staged_files: + self._sql_client.execute_sql(f"REMOVE {stage_file_path}") @classmethod def gen_copy_sql( @@ -267,15 +271,14 @@ def __init__( self.sql_client: SnowflakeSqlClient = sql_client # type: ignore self.type_mapper = SnowflakeTypeMapper(self.capabilities) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: job = SnowflakeLoadJob( file_path, - table["name"], - load_id, - self.sql_client, self.config, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, @@ -285,9 +288,6 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) return job - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 408bfc2b53..d1b38f73bd 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -5,10 +5,7 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import ( - SupportsStagingDestination, - NewLoadJob, -) +from dlt.common.destination.reference import SupportsStagingDestination, FollowupJob, LoadJob from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint from dlt.common.schema.utils import ( @@ -22,9 +19,12 @@ AzureServicePrincipalCredentialsWithoutDefaults, ) -from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.job_impl import ReferenceFollowupJob from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob +from dlt.destinations.job_client_impl import ( + SqlJobClientBase, + CopyRemoteFileLoadJob, +) from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.mssql.mssql import ( @@ -131,7 +131,7 @@ def _get_columstore_valid_column(self, c: TColumnSchema) -> TColumnSchema: def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: + ) -> List[FollowupJob]: return SqlJobClientBase._create_replace_followup_jobs(self, table_chain) def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: @@ -158,16 +158,16 @@ def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSc table[TABLE_INDEX_TYPE_HINT] = self.config.default_table_index_type # type: ignore[typeddict-unknown-key] return table - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: 
+ job = super().create_load_job(table, file_path, load_id, restore) if not job: - assert NewReferenceJob.is_reference_job( + assert ReferenceFollowupJob.is_reference_job( file_path ), "Synapse must use staging to load files" job = SynapseCopyFileLoadJob( - table, file_path, - self.sql_client, self.config.staging_config.credentials, # type: ignore[arg-type] self.config.staging_use_msi, ) @@ -177,22 +177,21 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[ Union[AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults] ] = None, staging_use_msi: bool = False, ) -> None: self.staging_use_msi = staging_use_msi - super().__init__(table, file_path, sql_client, staging_credentials) + super().__init__(file_path, staging_credentials) - def execute(self, table: TTableSchema, bucket_path: str) -> None: + def run(self) -> None: + self._sql_client = self._job_client.sql_client # get format - ext = os.path.splitext(bucket_path)[1][1:] + ext = os.path.splitext(self._bucket_path)[1][1:] if ext == "parquet": - if table_schema_has_type(table, "time"): + if table_schema_has_type(self._load_table, "time"): # Synapse interprets Parquet TIME columns as bigint, resulting in # an incompatibility error. raise LoadJobTerminalException( @@ -216,8 +215,8 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: (AzureCredentialsWithoutDefaults, AzureServicePrincipalCredentialsWithoutDefaults), ) azure_storage_account_name = staging_credentials.azure_storage_account_name - https_path = self._get_https_path(bucket_path, azure_storage_account_name) - table_name = table["name"] + https_path = self._get_https_path(self._bucket_path, azure_storage_account_name) + table_name = self._load_table["name"] if self.staging_use_msi: credential = "IDENTITY = 'Managed Identity'" @@ -252,10 +251,6 @@ def execute(self, table: TTableSchema, bucket_path: str) -> None: """) self._sql_client.execute_sql(sql) - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - def _get_https_path(self, bucket_path: str, storage_account_name: str) -> str: """ Converts a path in the form of az:/// to diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index dfbf83d7e5..b8bf3d62c6 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -38,11 +38,17 @@ version_table, ) from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import TLoadJobState, LoadJob, JobClientBase, WithStateSync +from dlt.common.destination.reference import ( + TLoadJobState, + RunnableLoadJob, + JobClientBase, + WithStateSync, + LoadJob, +) from dlt.common.storages import FileStorage from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError @@ -143,34 +149,31 @@ def _wrap(*args: Any, **kwargs: Any) -> Any: return 
_wrap # type: ignore -class LoadWeaviateJob(LoadJob): +class LoadWeaviateJob(RunnableLoadJob): def __init__( self, - schema: Schema, - table_schema: TTableSchema, - local_path: str, - db_client: weaviate.Client, - client_config: WeaviateClientConfiguration, + file_path: str, class_name: str, ) -> None: - file_name = FileStorage.get_file_name_from_file_path(local_path) - super().__init__(file_name) - self.client_config = client_config - self.db_client = db_client - self.table_name = table_schema["name"] - self.class_name = class_name - self.unique_identifiers = self.list_unique_identifiers(table_schema) + super().__init__(file_path) + self._job_client: WeaviateClient = None + self._class_name = class_name + + def run(self) -> None: + self._db_client = self._job_client.db_client + self._client_config = self._job_client.config + self.unique_identifiers = self.list_unique_identifiers(self._load_table) self.complex_indices = [ i - for i, field in schema.get_table_columns(self.table_name).items() + for i, field in self._schema.get_table_columns(self.load_table_name).items() if field["data_type"] == "complex" ] self.date_indices = [ i - for i, field in schema.get_table_columns(self.table_name).items() + for i, field in self._schema.get_table_columns(self.load_table_name).items() if field["data_type"] == "date" ] - with FileStorage.open_zipsafe_ro(local_path) as f: + with FileStorage.open_zipsafe_ro(self._file_path) as f: self.load_batch(f) @wrap_weaviate_error @@ -188,15 +191,15 @@ def check_batch_result(results: List[StrAny]) -> None: if "error" in result["result"]["errors"]: raise WeaviateGrpcError(result["result"]["errors"]) - with self.db_client.batch( - batch_size=self.client_config.batch_size, - timeout_retries=self.client_config.batch_retries, - connection_error_retries=self.client_config.batch_retries, + with self._db_client.batch( + batch_size=self._client_config.batch_size, + timeout_retries=self._client_config.batch_retries, + connection_error_retries=self._client_config.batch_retries, weaviate_error_retries=weaviate.WeaviateErrorRetryConf( - self.client_config.batch_retries + self._client_config.batch_retries ), - consistency_level=weaviate.ConsistencyLevel[self.client_config.batch_consistency], - num_workers=self.client_config.batch_workers, + consistency_level=weaviate.ConsistencyLevel[self._client_config.batch_consistency], + num_workers=self._client_config.batch_workers, callback=check_batch_result, ) as batch: for line in f: @@ -209,11 +212,11 @@ def check_batch_result(results: List[StrAny]) -> None: if key in data: data[key] = ensure_pendulum_datetime(data[key]).isoformat() if self.unique_identifiers: - uuid = self.generate_uuid(data, self.unique_identifiers, self.class_name) + uuid = self.generate_uuid(data, self.unique_identifiers, self._class_name) else: uuid = None - batch.add_data_object(data, self.class_name, uuid=uuid) + batch.add_data_object(data, self._class_name, uuid=uuid) def list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: if table_schema.get("write_disposition") == "merge": @@ -228,12 +231,6 @@ def generate_uuid( data_id = "_".join([str(data[key]) for key in unique_identifiers]) return generate_uuid5(data_id, class_name) # type: ignore - def state(self) -> TLoadJobState: - return "completed" - - def exception(self) -> str: - raise NotImplementedError() - class WeaviateClient(JobClientBase, WithStateSync): """Weaviate client implementation.""" @@ -677,19 +674,14 @@ def _make_property_schema( **extra_kv, } - def start_file_load(self, table: 
TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: return LoadWeaviateJob( - self.schema, - table, file_path, - db_client=self.db_client, - client_config=self.config, class_name=self.make_qualified_class_name(table["name"]), ) - def restore_file_load(self, file_path: str) -> LoadJob: - return EmptyLoadJob.from_file_path(file_path, "completed") - @wrap_weaviate_error def complete_load(self, load_id: str) -> None: # corresponds to order of the columns in loads_table() diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 652d13f556..6ccc65705b 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -2,35 +2,31 @@ import abc from typing import Any, Iterator, List -from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState +from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob from dlt.common.schema.typing import TTableSchema from dlt.common.storages import FileStorage from dlt.common.utils import chunks from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_client_impl import SqlJobClientWithStaging, SqlJobClientBase -class InsertValuesLoadJob(LoadJob, FollowupJob): - def __init__(self, table_name: str, file_path: str, sql_client: SqlClientBase[Any]) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._sql_client = sql_client +class InsertValuesLoadJob(RunnableLoadJob, HasFollowupJobs): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None + + def run(self) -> None: # insert file content immediately + self._sql_client = self._job_client.sql_client + with self._sql_client.begin_transaction(): for fragments in self._insert( - sql_client.make_qualified_table_name(table_name), file_path + self._sql_client.make_qualified_table_name(self.load_table_name), self._file_path ): self._sql_client.execute_fragments(fragments) - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() - def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[str]]: # WARNING: maximum redshift statement is 16MB https://docs.aws.amazon.com/redshift/latest/dg/c_redshift-sql.html # the procedure below will split the inserts into max_query_length // 2 packs @@ -101,27 +97,12 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st class InsertValuesJobClient(SqlJobClientWithStaging): - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or InsertValuesJob - - Returns completed jobs as SqlLoadJob and InsertValuesJob executed atomically in start_file_load so any jobs that should be recreated are already completed. - Obviously the case of asking for jobs that were never created will not be handled. With correctly implemented loader that cannot happen. 
- - Args: - file_path (str): a path to a job file - - Returns: - LoadJob: Always a restored job completed - """ - job = super().restore_file_load(file_path) - if not job: - job = EmptyLoadJob.from_file_path(file_path, "completed") - return job - - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: - job = super().start_file_load(table, file_path, load_id) + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: + job = super().create_load_job(table, file_path, load_id, restore) if not job: # this is using sql_client internally and will raise a right exception if file_path.endswith("insert_values"): - job = InsertValuesLoadJob(table["name"], file_path, self.sql_client) + job = InsertValuesLoadJob(file_path) return job diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index dd0e783414..7fdd979c5d 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -42,18 +42,20 @@ WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, - NewLoadJob, + FollowupJob, WithStagingDataset, - TLoadJobState, + RunnableLoadJob, LoadJob, JobClientBase, - FollowupJob, + HasFollowupJobs, CredentialsConfiguration, ) from dlt.destinations.exceptions import DatabaseUndefinedRelation -from dlt.destinations.job_impl import EmptyLoadJobWithoutFollowup, NewReferenceJob -from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob +from dlt.destinations.job_impl import ( + ReferenceFollowupJob, +) +from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlStagingCopyFollowupJob from dlt.destinations.typing import TNativeConn from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.utils import ( @@ -66,36 +68,32 @@ DDL_COMMANDS = ["ALTER", "CREATE", "DROP"] -class SqlLoadJob(LoadJob): +class SqlLoadJob(RunnableLoadJob): """A job executing sql statement, without followup trait""" - def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None + + def run(self) -> None: + self._sql_client = self._job_client.sql_client # execute immediately if client present - with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: + with FileStorage.open_zipsafe_ro(self._file_path, "r", encoding="utf-8") as f: sql = f.read() # Some clients (e.g. 
databricks) do not support multiple statements in one execute call - if not sql_client.capabilities.supports_multiple_statements: - sql_client.execute_many(self._split_fragments(sql)) + if not self._sql_client.capabilities.supports_multiple_statements: + self._sql_client.execute_many(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client elif ( not self._string_contains_ddl_queries(sql) - or sql_client.capabilities.supports_ddl_transactions + or self._sql_client.capabilities.supports_ddl_transactions ): # with sql_client.begin_transaction(): - sql_client.execute_sql(sql) + self._sql_client.execute_sql(sql) else: # sql_client.execute_sql(sql) - sql_client.execute_many(self._split_fragments(sql)) - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" - - def exception(self) -> str: - # this part of code should be never reached - raise NotImplementedError() + self._sql_client.execute_many(self._split_fragments(sql)) def _string_contains_ddl_queries(self, sql: str) -> bool: for cmd in DDL_COMMANDS: @@ -111,27 +109,16 @@ def is_sql_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "sql" -class CopyRemoteFileLoadJob(LoadJob, FollowupJob): +class CopyRemoteFileLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, - table: TTableSchema, file_path: str, - sql_client: SqlClientBase[Any], staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._sql_client = sql_client + super().__init__(file_path) + self._job_client: "SqlJobClientBase" = None self._staging_credentials = staging_credentials - - self.execute(table, NewReferenceJob.resolve_reference(file_path)) - - def execute(self, table: TTableSchema, bucket_path: str) -> None: - # implement in child implementations - raise NotImplementedError() - - def state(self) -> TLoadJobState: - # this job is always done - return "completed" + self._bucket_path = ReferenceFollowupJob.resolve_reference(file_path) class SqlJobClientBase(JobClientBase, WithStateSync): @@ -227,19 +214,23 @@ def should_truncate_table_before_load(self, table: TTableSchema) -> bool: and self.config.replace_strategy == "truncate-and-insert" ) - def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + def _create_append_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[FollowupJob]: return [] - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - return [SqlMergeJob.from_table_chain(table_chain, self.sql_client)] + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[FollowupJob]: + return [SqlMergeFollowupJob.from_table_chain(table_chain, self.sql_client)] def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - jobs: List[NewLoadJob] = [] + ) -> List[FollowupJob]: + jobs: List[FollowupJob] = [] if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: jobs.append( - SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) + SqlStagingCopyFollowupJob.from_table_chain( + table_chain, self.sql_client, {"replace": True} + ) ) return jobs @@ -247,7 +238,7 @@ def create_table_chain_completed_followup_jobs( self, table_chain: Sequence[TTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, - ) -> List[NewLoadJob]: + ) -> 
List[FollowupJob]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" jobs = super().create_table_chain_completed_followup_jobs( table_chain, completed_table_chain_jobs @@ -261,28 +252,13 @@ def create_table_chain_completed_followup_jobs( jobs.extend(self._create_replace_followup_jobs(table_chain)) return jobs - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + def create_load_job( + self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" - self._set_query_tags_for_job(load_id, table) if SqlLoadJob.is_sql_job(file_path): - # execute sql load job - return SqlLoadJob(file_path, self.sql_client) - return None - - def restore_file_load(self, file_path: str) -> LoadJob: - """Returns a completed SqlLoadJob or None to let derived classes to handle their specific jobs - - Returns completed jobs as SqlLoadJob is executed atomically in start_file_load so any jobs that should be recreated are already completed. - Obviously the case of asking for jobs that were never created will not be handled. With correctly implemented loader that cannot happen. - - Args: - file_path (str): a path to a job file - - Returns: - LoadJob: A restored job or none - """ - if SqlLoadJob.is_sql_job(file_path): - return EmptyLoadJobWithoutFollowup.from_file_path(file_path, "completed") + # create sql load job + return SqlLoadJob(file_path) return None def complete_load(self, load_id: str) -> None: @@ -678,6 +654,9 @@ def _verify_schema(self) -> None: logger.error(str(exception)) raise exceptions[0] + def prepare_load_job_execution(self, job: RunnableLoadJob) -> None: + self._set_query_tags_for_job(load_id=job._load_id, table=job._load_table) + def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: """Sets query tags in sql_client for a job in package `load_id`, starting for a particular `table`""" from dlt.common.pipeline import current_pipeline diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 9a8f7277b7..41c939f482 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -1,14 +1,22 @@ from abc import ABC, abstractmethod import os import tempfile # noqa: 251 -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List from dlt.common.json import json -from dlt.common.destination.reference import NewLoadJob, FollowupJob, TLoadJobState, LoadJob +from dlt.common.destination.reference import ( + HasFollowupJobs, + TLoadJobState, + RunnableLoadJob, + JobClientBase, + FollowupJob, + LoadJob, +) from dlt.common.storages.load_package import commit_load_package_state from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage from dlt.common.typing import TDataItems +from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, @@ -16,17 +24,26 @@ ) -class EmptyLoadJobWithoutFollowup(LoadJob): - def __init__(self, file_name: str, status: TLoadJobState, exception: str = None) -> None: +class FinalizedLoadJob(LoadJob): + """ + Special Load Job that should never get started and just indicates a job being in a final state. + May also be used to indicate that nothing needs to be done. 
+ """ + + def __init__( + self, file_path: str, status: TLoadJobState = "completed", exception: str = None + ) -> None: self._status = status self._exception = exception - super().__init__(file_name) + self._file_path = file_path + assert self._status in ("completed", "failed", "retry") + super().__init__(file_path) @classmethod def from_file_path( - cls, file_path: str, status: TLoadJobState, message: str = None - ) -> "EmptyLoadJobWithoutFollowup": - return cls(FileStorage.get_file_name_from_file_path(file_path), status, exception=message) + cls, file_path: str, status: TLoadJobState = "completed", message: str = None + ) -> "FinalizedLoadJob": + return cls(file_path, status, exception=message) def state(self) -> TLoadJobState: return self._status @@ -35,101 +52,107 @@ def exception(self) -> str: return self._exception -class EmptyLoadJob(EmptyLoadJobWithoutFollowup, FollowupJob): +class FinalizedLoadJobWithFollowupJobs(FinalizedLoadJob, HasFollowupJobs): pass -class NewLoadJobImpl(EmptyLoadJobWithoutFollowup, NewLoadJob): +class FollowupJobImpl(FollowupJob): + """ + Class to create a new loadjob, not stateful and not runnable + """ + + def __init__(self, file_name: str) -> None: + self._file_path = os.path.join(tempfile.gettempdir(), file_name) + self._parsed_file_name = ParsedLoadJobFileName.parse(file_name) + # we only accept jobs that we can scheduleas new or mark as failed.. + def _save_text_file(self, data: str) -> None: - temp_file = os.path.join(tempfile.gettempdir(), self._file_name) - with open(temp_file, "w", encoding="utf-8") as f: + with open(self._file_path, "w", encoding="utf-8") as f: f.write(data) - self._new_file_path = temp_file def new_file_path(self) -> str: """Path to a newly created temporary job file""" - return self._new_file_path + return self._file_path + def job_id(self) -> str: + """The job id that is derived from the file name and does not changes during job lifecycle""" + return self._parsed_file_name.job_id() -class NewReferenceJob(NewLoadJobImpl): - def __init__( - self, file_name: str, status: TLoadJobState, exception: str = None, remote_path: str = None - ) -> None: - file_name = os.path.splitext(file_name)[0] + ".reference" - super().__init__(file_name, status, exception) - self._remote_path = remote_path - self._save_text_file(remote_path) + +class ReferenceFollowupJob(FollowupJobImpl): + def __init__(self, original_file_name: str, remote_paths: List[str]) -> None: + file_name = os.path.splitext(original_file_name)[0] + "." 
+ "reference" + self._remote_paths = remote_paths + super().__init__(file_name) + self._save_text_file("\n".join(remote_paths)) @staticmethod def is_reference_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "reference" @staticmethod - def resolve_reference(file_path: str) -> str: + def resolve_references(file_path: str) -> List[str]: with open(file_path, "r+", encoding="utf-8") as f: # Reading from a file - return f.read() + return f.read().split("\n") + + @staticmethod + def resolve_reference(file_path: str) -> str: + refs = ReferenceFollowupJob.resolve_references(file_path) + assert len(refs) == 1 + return refs[0] -class DestinationLoadJob(LoadJob, ABC): +class DestinationLoadJob(RunnableLoadJob, ABC): def __init__( self, - table: TTableSchema, file_path: str, config: CustomDestinationClientConfiguration, - schema: Schema, destination_state: Dict[str, int], destination_callable: TDestinationCallable, skipped_columns: List[str], + callable_requires_job_client_args: bool = False, ) -> None: - super().__init__(FileStorage.get_file_name_from_file_path(file_path)) - self._file_path = file_path + super().__init__(file_path) self._config = config - self._table = table - self._schema = schema - # we create pre_resolved callable here self._callable = destination_callable - self._state: TLoadJobState = "running" self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}" - self.skipped_columns = skipped_columns + self._skipped_columns = skipped_columns + self._destination_state = destination_state + self._callable_requires_job_client_args = callable_requires_job_client_args + + def run(self) -> None: + # update filepath, it will be in running jobs now try: if self._config.batch_size == 0: # on batch size zero we only call the callable with the filename self.call_callable_with_items(self._file_path) else: - current_index = destination_state.get(self._storage_id, 0) - for batch in self.run(current_index): + current_index = self._destination_state.get(self._storage_id, 0) + for batch in self.get_batches(current_index): self.call_callable_with_items(batch) current_index += len(batch) - destination_state[self._storage_id] = current_index - - self._state = "completed" - except Exception as e: - self._state = "retry" - raise e + self._destination_state[self._storage_id] = current_index finally: # save progress commit_load_package_state() - @abstractmethod - def run(self, start_index: int) -> Iterable[TDataItems]: - pass - def call_callable_with_items(self, items: TDataItems) -> None: if not items: return # call callable - self._callable(items, self._table) - - def state(self) -> TLoadJobState: - return self._state + if self._callable_requires_job_client_args: + self._callable(items, self._load_table, job_client=self._job_client) # type: ignore + else: + self._callable(items, self._load_table) - def exception(self) -> str: - raise NotImplementedError() + @abstractmethod + def get_batches(self, start_index: int) -> Iterable[TDataItems]: + pass class DestinationParquetLoadJob(DestinationLoadJob): - def run(self, start_index: int) -> Iterable[TDataItems]: + def get_batches(self, start_index: int) -> Iterable[TDataItems]: # stream items from dlt.common.libs.pyarrow import pyarrow @@ -140,7 +163,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: # on record batches we cannot drop columns, we need to # select the ones we want to keep - keep_columns = list(self._table["columns"].keys()) + keep_columns = list(self._load_table["columns"].keys()) 
start_batch = start_index / self._config.batch_size with pyarrow.parquet.ParquetFile(self._file_path) as reader: for record_batch in reader.iter_batches( @@ -153,7 +176,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: class DestinationJsonlLoadJob(DestinationLoadJob): - def run(self, start_index: int) -> Iterable[TDataItems]: + def get_batches(self, start_index: int) -> Iterable[TDataItems]: current_batch: TDataItems = [] # stream items @@ -168,7 +191,7 @@ def run(self, start_index: int) -> Iterable[TDataItems]: start_index -= 1 continue # skip internal columns - for column in self.skipped_columns: + for column in self._skipped_columns: item.pop(column, None) current_batch.append(item) if len(current_batch) == self._config.batch_size: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index e67be049ab..cddae52bb7 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -21,8 +21,9 @@ from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException -from dlt.destinations.job_impl import NewLoadJobImpl +from dlt.destinations.job_impl import FollowupJobImpl from dlt.destinations.sql_client import SqlClientBase +from dlt.common.destination.exceptions import DestinationTransientException class SqlJobParams(TypedDict, total=False): @@ -33,10 +34,19 @@ class SqlJobParams(TypedDict, total=False): DEFAULTS: SqlJobParams = {"replace": False} -class SqlBaseJob(NewLoadJobImpl): - """Sql base job for jobs that rely on the whole tablechain""" +class SqlJobCreationException(DestinationTransientException): + def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: + tables_str = yaml.dump( + table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False + ) + super().__init__( + f"Could not create SQLFollowupJob with exception {str(original_exception)}. Table" + f" chain: {tables_str}" + ) - failed_text: str = "" + +class SqlFollowupJob(FollowupJobImpl): + """Sql base job for jobs that rely on the whole tablechain""" @classmethod def from_table_chain( @@ -44,7 +54,7 @@ def from_table_chain( table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, - ) -> NewLoadJobImpl: + ) -> FollowupJobImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). @@ -54,6 +64,7 @@ def from_table_chain( file_info = ParsedLoadJobFileName( top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" ) + try: # Remove line breaks from multiline statements and write one SQL statement per line in output file # to support clients that need to execute one statement at a time (i.e. 
snowflake) @@ -61,15 +72,12 @@ def from_table_chain( " ".join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params) ] - job = cls(file_info.file_name(), "running") + job = cls(file_info.file_name()) job._save_text_file("\n".join(sql)) - except Exception: - # return failed job - tables_str = yaml.dump( - table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False - ) - job = cls(file_info.file_name(), "failed", pretty_format_exception()) - job._save_text_file("\n".join([cls.failed_text, tables_str])) + except Exception as e: + # raise exception with some context + raise SqlJobCreationException(e, table_chain) from e + return job @classmethod @@ -82,11 +90,9 @@ def generate_sql( pass -class SqlStagingCopyJob(SqlBaseJob): +class SqlStagingCopyFollowupJob(SqlFollowupJob): """Generates a list of sql statements that copy the data from staging dataset into destination dataset.""" - failed_text: str = "Tried to generate a staging copy sql job for the following tables:" - @classmethod def _generate_clone_sql( cls, @@ -141,14 +147,12 @@ def generate_sql( return cls._generate_insert_sql(table_chain, sql_client, params) -class SqlMergeJob(SqlBaseJob): +class SqlMergeFollowupJob(SqlFollowupJob): """ Generates a list of sql statements that merge the data from staging dataset into destination dataset. If no merge keys are discovered, falls back to append. """ - failed_text: str = "Tried to generate a merge sql job for the following tables:" - @classmethod def generate_sql( # type: ignore[return] cls, diff --git a/dlt/load/exceptions.py b/dlt/load/exceptions.py index e85dffd2e9..14d0eb1b23 100644 --- a/dlt/load/exceptions.py +++ b/dlt/load/exceptions.py @@ -5,7 +5,12 @@ ) -class LoadClientJobFailed(DestinationTerminalException): +class LoadClientJobException(Exception): + load_id: str + job_id: str + + +class LoadClientJobFailed(DestinationTerminalException, LoadClientJobException): def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: self.load_id = load_id self.job_id = job_id @@ -16,15 +21,19 @@ def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: ) -class LoadClientJobRetry(DestinationTransientException): - def __init__(self, load_id: str, job_id: str, retry_count: int, max_retry_count: int) -> None: +class LoadClientJobRetry(DestinationTransientException, LoadClientJobException): + def __init__( + self, load_id: str, job_id: str, retry_count: int, max_retry_count: int, retry_message: str + ) -> None: self.load_id = load_id self.job_id = job_id self.retry_count = retry_count self.max_retry_count = max_retry_count + self.retry_message = retry_message super().__init__( f"Job for {job_id} had {retry_count} retries which a multiple of {max_retry_count}." " Exiting retry loop. You can still rerun the load package to retry this job." 
+ f" Last failure message was {retry_message}" ) @@ -50,3 +59,18 @@ def __init__(self, table_name: str, write_disposition: str, file_name: str) -> N f"Loader does not support {write_disposition} in table {table_name} when loading file" f" {file_name}" ) + + +class FollowupJobCreationFailedException(DestinationTransientException): + def __init__(self, job_id: str) -> None: + self.job_id = job_id + super().__init__(f"Failed to create followup job for job with id {job_id}") + + +class TableChainFollowupJobCreationFailedException(DestinationTransientException): + def __init__(self, root_table_name: str) -> None: + self.root_table_name = root_table_name + super().__init__( + "Failed creating table chain followup jobs for table chain with root table" + f" {root_table_name}." + ) diff --git a/dlt/load/load.py b/dlt/load/load.py index 2290d40a1e..34b7e2b5b7 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -18,18 +18,18 @@ from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.logger import pretty_format_exception -from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.container import Container from dlt.common.schema import Schema from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, - FollowupJob, + HasFollowupJobs, JobClientBase, WithStagingDataset, Destination, + RunnableLoadJob, LoadJob, - NewLoadJob, + FollowupJob, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination, @@ -37,10 +37,10 @@ ) from dlt.common.destination.exceptions import ( DestinationTerminalException, - DestinationTransientException, ) +from dlt.common.runtime import signals -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import ( @@ -48,12 +48,16 @@ LoadClientJobRetry, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, + LoadClientJobException, + FollowupJobCreationFailedException, + TableChainFollowupJobCreationFailedException, ) from dlt.load.utils import ( _extend_tables_with_table_chain, get_completed_table_chain, init_client, filter_new_jobs, + get_available_worker_slots, ) @@ -80,6 +84,9 @@ def __init__( self.pool = NullExecutor() self.load_storage: LoadStorage = self.create_storage(is_storage_owner) self._loaded_packages: List[LoadPackageInfo] = [] + self._run_loop_sleep_duration: float = ( + 1.0 # amount of time to sleep between querying completed jobs + ) super().__init__() def create_storage(self, is_storage_owner: bool) -> LoadStorage: @@ -108,10 +115,13 @@ def get_staging_destination_client(self, schema: Schema) -> JobClientBase: return self.staging_destination.client(schema, self.initial_staging_client_config) def is_staging_destination_job(self, file_path: str) -> bool: + file_type = os.path.splitext(file_path)[1][1:] + # for now we know that reference jobs always go do the main destination + if file_type == "reference": + return False return ( self.staging_destination is not None - and os.path.splitext(file_path)[1][1:] - in self.staging_destination.capabilities().supported_loader_file_formats + and file_type in self.staging_destination.capabilities().supported_loader_file_formats ) @contextlib.contextmanager @@ -125,94 +135,150 @@ def maybe_with_staging_dataset( else: yield - @staticmethod - 
@workermethod - def w_spool_job( - self: "Load", file_path: str, load_id: str, schema: Schema - ) -> Optional[LoadJob]: + def submit_job( + self, file_path: str, load_id: str, schema: Schema, restore: bool = False + ) -> LoadJob: job: LoadJob = None + + is_staging_destination_job = self.is_staging_destination_job(file_path) + job_client = self.get_destination_client(schema) + + # if we have a staging destination and the file is not a reference, send to staging + active_job_client = ( + self.get_staging_destination_client(schema) + if is_staging_destination_job + else job_client + ) + try: - is_staging_destination_job = self.is_staging_destination_job(file_path) - job_client = self.get_destination_client(schema) - - # if we have a staging destination and the file is not a reference, send to staging - with ( - self.get_staging_destination_client(schema) - if is_staging_destination_job - else job_client - ) as client: - job_info = ParsedLoadJobFileName.parse(file_path) - if job_info.file_format not in self.load_storage.supported_job_file_formats: - raise LoadClientUnsupportedFileFormats( - job_info.file_format, - self.destination.capabilities().supported_loader_file_formats, - file_path, - ) - logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - table = client.prepare_load_table(job_info.table_name) - if table["write_disposition"] not in ["append", "replace", "merge"]: - raise LoadClientUnsupportedWriteDisposition( - job_info.table_name, table["write_disposition"], file_path - ) + # check file format + job_info = ParsedLoadJobFileName.parse(file_path) + if job_info.file_format not in self.load_storage.supported_job_file_formats: + raise LoadClientUnsupportedFileFormats( + job_info.file_format, + self.destination.capabilities().supported_loader_file_formats, + file_path, + ) + logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - if is_staging_destination_job: - use_staging_dataset = isinstance( - job_client, SupportsStagingDestination - ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( - table - ) - else: - use_staging_dataset = isinstance( - job_client, WithStagingDataset - ) and job_client.should_load_data_to_staging_dataset(table) - - with self.maybe_with_staging_dataset(client, use_staging_dataset): - job = client.start_file_load( - table, - self.load_storage.normalized_packages.storage.make_full_path(file_path), - load_id, - ) - except (DestinationTerminalException, TerminalValueError): - # if job irreversibly cannot be started, mark it as failed - logger.exception(f"Terminal problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) - except (DestinationTransientException, Exception): - # return no job so file stays in new jobs (root) folder - logger.exception(f"Temporary problem when adding job {file_path}") - job = EmptyLoadJob.from_file_path(file_path, "retry", pretty_format_exception()) - if job is None: - raise DestinationTerminalException( - f"Destination could not create a job for file {file_path}. Typically the file" - " extension could not be associated with job type and that indicates an error in" - " the code." 
+ # check write disposition + load_table = active_job_client.prepare_load_table(job_info.table_name) + if load_table["write_disposition"] not in ["append", "replace", "merge"]: + raise LoadClientUnsupportedWriteDisposition( + job_info.table_name, load_table["write_disposition"], file_path + ) + + job = active_job_client.create_load_job( + load_table, + self.load_storage.normalized_packages.storage.make_full_path(file_path), + load_id, + restore=restore, ) - self.load_storage.normalized_packages.start_job(load_id, job.file_name()) + + if job is None: + raise DestinationTerminalException( + f"Destination could not create a job for file {file_path}. Typically the file" + " extension could not be associated with job type and that indicates an error" + " in the code." + ) + except DestinationTerminalException: + job = FinalizedLoadJobWithFollowupJobs.from_file_path( + file_path, "failed", pretty_format_exception() + ) + except Exception: + job = FinalizedLoadJobWithFollowupJobs.from_file_path( + file_path, "retry", pretty_format_exception() + ) + + # move to started jobs in case this is not a restored job + if not restore: + job._file_path = self.load_storage.normalized_packages.start_job( + load_id, job.file_name() + ) + + # only start a thread if this job is runnable + if isinstance(job, RunnableLoadJob): + # determine which dataset to use + if is_staging_destination_job: + use_staging_dataset = isinstance( + job_client, SupportsStagingDestination + ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( + load_table + ) + else: + use_staging_dataset = isinstance( + job_client, WithStagingDataset + ) and job_client.should_load_data_to_staging_dataset(load_table) + + # set job vars + job.set_run_vars(load_id=load_id, schema=schema, load_table=load_table) + + # submit to pool + self.pool.submit(Load.w_run_job, *(id(self), job, is_staging_destination_job, use_staging_dataset, schema)) # type: ignore + + # sanity check: otherwise a job in an actionable state is expected + else: + assert job.state() in ("completed", "failed", "retry") + return job - def spool_new_jobs(self, load_id: str, schema: Schema) -> Tuple[int, List[LoadJob]]: - # use thread based pool as jobs processing is mostly I/O and we do not want to pickle jobs + @staticmethod + @workermethod + def w_run_job( + self: "Load", + job: RunnableLoadJob, + use_staging_client: bool, + use_staging_dataset: bool, + schema: Schema, + ) -> None: + """ + Start a load job in a separate thread + """ + active_job_client = ( + self.get_staging_destination_client(schema) + if use_staging_client + else self.get_destination_client(schema) + ) + with active_job_client as client: + with self.maybe_with_staging_dataset(client, use_staging_dataset): + job.run_managed(active_job_client) + + def start_new_jobs( + self, load_id: str, schema: Schema, running_jobs: Sequence[LoadJob] + ) -> Sequence[LoadJob]: + """ + will retrieve jobs from the new_jobs folder and start as many as there are slots available + """ + caps = self.destination.capabilities( + self.destination.configuration(self.initial_client_config) + ) + + # early exit if no slots available + available_slots = get_available_worker_slots(self.config, caps, running_jobs) + if available_slots <= 0: + return [] + + # get a list of jobs eligible to be started load_files = filter_new_jobs( self.load_storage.list_new_jobs(load_id), - self.destination.capabilities( - self.destination.configuration(self.initial_client_config) - ), + caps, self.config, + running_jobs, + available_slots, ) 
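start_new_jobs above only spools as many new files as there are free worker slots; the helper get_available_worker_slots (added further down in this diff, in dlt/load/utils.py) caps the configured worker count at 1 for the "sequential" strategy and at the destination's max_parallel_load_jobs, then subtracts the jobs that are already running. A standalone sketch of that calculation, with illustrative parameter names rather than the real dlt signatures:

    from typing import Optional

    def available_worker_slots(
        configured_workers: int,
        parallelism_strategy: Optional[str],
        max_parallel_load_jobs: Optional[int],
        running_jobs: int,
    ) -> int:
        # "sequential" forces a single worker, otherwise use the configured pool size
        max_workers = 1 if parallelism_strategy == "sequential" else configured_workers
        # the destination capabilities may cap parallel load jobs further
        if max_parallel_load_jobs:
            max_workers = min(max_workers, max_parallel_load_jobs)
        # never report negative capacity
        return max(0, max_workers - running_jobs)

    assert available_worker_slots(20, None, 8, 5) == 3
    assert available_worker_slots(20, "sequential", None, 1) == 0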
- file_count = len(load_files) - if file_count == 0: - logger.info(f"No new jobs found in {load_id}") - return 0, [] - logger.info(f"Will load {file_count}, creating jobs") - param_chunk = [(id(self), file, load_id, schema) for file in load_files] - # exceptions should not be raised, None as job is a temporary failure - # other jobs should not be affected - jobs = self.pool.map(Load.w_spool_job, *zip(*param_chunk)) - # remove None jobs and check the rest - return file_count, [job for job in jobs if job is not None] - - def retrieve_jobs( - self, client: JobClientBase, load_id: str, staging_client: JobClientBase = None - ) -> Tuple[int, List[LoadJob]]: + + logger.info(f"Will load additional {len(load_files)}, creating jobs") + started_jobs: List[LoadJob] = [] + for file in load_files: + job = self.submit_job(file, load_id, schema) + started_jobs.append(job) + + return started_jobs + + def resume_started_jobs(self, load_id: str, schema: Schema) -> List[LoadJob]: + """ + will check jobs in the started folder and resume them + """ jobs: List[LoadJob] = [] # list all files that were started but not yet completed @@ -220,23 +286,13 @@ def retrieve_jobs( logger.info(f"Found {len(started_jobs)} that are already started and should be continued") if len(started_jobs) == 0: - return 0, jobs + return jobs for file_path in started_jobs: - try: - logger.info(f"Will retrieve {file_path}") - client = staging_client if self.is_staging_destination_job(file_path) else client - job = client.restore_file_load(file_path) - except DestinationTerminalException: - logger.exception(f"Job retrieval for {file_path} failed, job will be terminated") - job = EmptyLoadJob.from_file_path(file_path, "failed", pretty_format_exception()) - # proceed to appending job, do not reraise - except (DestinationTransientException, Exception): - # raise on all temporary exceptions, typically network / server problems - raise + job = self.submit_job(file_path, load_id, schema, restore=True) jobs.append(job) - return len(jobs), jobs + return jobs def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: return [ @@ -246,9 +302,14 @@ def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: def create_followup_jobs( self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema - ) -> List[NewLoadJob]: - jobs: List[NewLoadJob] = [] - if isinstance(starting_job, FollowupJob): + ) -> None: + """ + for jobs marked as having followup jobs, find them all and store them to the new jobs folder + where they will be picked up for execution + """ + + jobs: List[FollowupJob] = [] + if isinstance(starting_job, HasFollowupJobs): # check for merge jobs only for jobs executing on the destination, the staging destination jobs must be excluded # NOTE: we may move that logic to the interface starting_job_file_name = starting_job.file_name() @@ -257,7 +318,7 @@ def create_followup_jobs( top_job_table = get_top_level_table( schema.tables, starting_job.job_file_info().table_name ) - # if all tables of chain completed, create follow up jobs + # if all tables of chain completed, create follow up jobs all_jobs_states = self.load_storage.normalized_packages.list_all_jobs_with_states( load_id ) @@ -265,60 +326,71 @@ def create_followup_jobs( schema, all_jobs_states, top_job_table, starting_job.job_file_info().job_id() ): table_chain_names = [table["name"] for table in table_chain] - # create job infos that contain full path to job table_chain_jobs = [ - self.load_storage.normalized_packages.job_to_job_info(load_id, 
*job_state) + # we mark all jobs as completed, as by the time the followup job runs the starting job will be in this + # folder too + self.load_storage.normalized_packages.job_to_job_info( + load_id, "completed_jobs", job_state[1] + ) for job_state in all_jobs_states if job_state[1].table_name in table_chain_names # job being completed is still in started_jobs and job_state[0] in ("completed_jobs", "started_jobs") ] - if follow_up_jobs := client.create_table_chain_completed_followup_jobs( - table_chain, table_chain_jobs - ): - jobs = jobs + follow_up_jobs - jobs = jobs + starting_job.create_followup_jobs(state) - return jobs + try: + if follow_up_jobs := client.create_table_chain_completed_followup_jobs( + table_chain, table_chain_jobs + ): + jobs = jobs + follow_up_jobs + except Exception as e: + raise TableChainFollowupJobCreationFailedException( + root_table_name=table_chain[0]["name"] + ) from e - def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + try: + jobs = jobs + starting_job.create_followup_jobs(state) + except Exception as e: + raise FollowupJobCreationFailedException(job_id=starting_job.job_id()) from e + + # import all followup jobs to the new jobs folder + for followup_job in jobs: + # save all created jobs + self.load_storage.normalized_packages.import_job( + load_id, followup_job.new_file_path(), job_state="new_jobs" + ) + logger.info( + f"Job {starting_job.job_id()} CREATED a new FOLLOWUP JOB" + f" {followup_job.new_file_path()} placed in new_jobs" + ) + + def complete_jobs( + self, load_id: str, jobs: Sequence[LoadJob], schema: Schema + ) -> Tuple[List[LoadJob], List[LoadJob], Optional[LoadClientJobException]]: """Run periodically in the main thread to collect job execution statuses. After detecting change of status, it commits the job state by moving it to the right folder May create one or more followup jobs that get scheduled as new jobs. 
New jobs are created only in terminal states (completed / failed) """ + # list of jobs still running remaining_jobs: List[LoadJob] = [] - - def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: - for followup_job in followup_jobs: - # running should be moved into "new jobs", other statuses into started - folder: TJobState = ( - "new_jobs" if followup_job.state() == "running" else "started_jobs" - ) - # save all created jobs - self.load_storage.normalized_packages.import_job( - load_id, followup_job.new_file_path(), job_state=folder - ) - logger.info( - f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" - f" {followup_job.new_file_path()} placed in {folder}" - ) - # if followup job is not "running" place it in current queue to be finalized - if not followup_job.state() == "running": - remaining_jobs.append(followup_job) + # list of jobs in final state + finalized_jobs: List[LoadJob] = [] + # if an exception condition was met, return it to the main runner + pending_exception: Optional[LoadClientJobException] = None logger.info(f"Will complete {len(jobs)} for {load_id}") for ii in range(len(jobs)): job = jobs[ii] logger.debug(f"Checking state for job {job.job_id()}") state: TLoadJobState = job.state() - if state == "running": + if state in ("ready", "running"): # ask again logger.debug(f"job {job.job_id()} still running") remaining_jobs.append(job) elif state == "failed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + self.create_followup_jobs(load_id, state, job, schema) # try to get exception message from job failed_message = job.exception() @@ -329,6 +401,14 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: f"Job for {job.job_id()} failed terminally in load {load_id} with message" f" {failed_message}" ) + # schedule exception on job failure + if self.config.raise_on_failed_jobs: + pending_exception = LoadClientJobFailed( + load_id, + job.job_file_info().job_id(), + failed_message, + ) + finalized_jobs.append(job) elif state == "retry": # try to get exception message from job retry_message = job.exception() @@ -337,13 +417,27 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: logger.warning( f"Job for {job.job_id()} retried in load {load_id} with message {retry_message}" ) + # possibly schedule exception on too many retries + if self.config.raise_on_max_retries: + r_c = job.job_file_info().retry_count + 1 + if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: + pending_exception = LoadClientJobRetry( + load_id, + job.job_file_info().job_id(), + r_c, + self.config.raise_on_max_retries, + retry_message=retry_message, + ) elif state == "completed": # create followup jobs - _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + self.create_followup_jobs(load_id, state, job, schema) # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) logger.info(f"Job for {job.job_id()} completed in load {load_id}") + finalized_jobs.append(job) + else: + raise Exception("Incorrect job state") if state in ["failed", "completed"]: self.collector.update("Jobs") @@ -352,7 +446,7 @@ def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: "Jobs", 1, message="WARNING: Some of the jobs failed!", label="Failed" ) - return remaining_jobs + return 
remaining_jobs, finalized_jobs, pending_exception def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) -> None: # do not commit load id for aborted packages @@ -377,6 +471,18 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) + def update_loadpackage_info(self, load_id: str) -> None: + # update counter we only care about the jobs that are scheduled to be loaded + package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) + total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) + no_failed_jobs = len(package_jobs["failed_jobs"]) + no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs + self.collector.update("Jobs", no_completed_jobs, total_jobs) + if no_failed_jobs > 0: + self.collector.update( + "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" + ) + def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) @@ -424,74 +530,58 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: drop_tables=dropped_tables, truncate_tables=truncated_tables, ) - self.load_storage.commit_schema_update(load_id, applied_update) - # initialize staging destination and spool or retrieve unfinished jobs - if self.staging_destination: - with self.get_staging_destination_client(schema) as staging_client: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id, staging_client) - else: - jobs_count, jobs = self.retrieve_jobs(job_client, load_id) - - if not jobs: - # jobs count is a total number of jobs including those that could not be initialized - jobs_count, jobs = self.spool_new_jobs(load_id, schema) - # if there are no existing or new jobs we complete the package - if jobs_count == 0: - self.complete_package(load_id, schema, False) - return - # update counter we only care about the jobs that are scheduled to be loaded - package_jobs = self.load_storage.normalized_packages.get_load_package_jobs(load_id) - total_jobs = reduce(lambda p, c: p + len(c), package_jobs.values(), 0) - no_failed_jobs = len(package_jobs["failed_jobs"]) - no_completed_jobs = len(package_jobs["completed_jobs"]) + no_failed_jobs - self.collector.update("Jobs", no_completed_jobs, total_jobs) - if no_failed_jobs > 0: - self.collector.update( - "Jobs", no_failed_jobs, message="WARNING: Some of the jobs failed!", label="Failed" - ) + # collect all unfinished jobs + running_jobs: List[LoadJob] = self.resume_started_jobs(load_id, schema) + # loop until all jobs are processed + pending_exception: Optional[LoadClientJobException] = None while True: try: - remaining_jobs = self.complete_jobs(load_id, jobs, schema) - if len(remaining_jobs) == 0: - # get package status - package_jobs = self.load_storage.normalized_packages.get_load_package_jobs( - load_id - ) - # possibly raise on failed jobs - if self.config.raise_on_failed_jobs: - if package_jobs["failed_jobs"]: - failed_job = package_jobs["failed_jobs"][0] - raise LoadClientJobFailed( - load_id, - failed_job.job_id(), - self.load_storage.normalized_packages.get_job_failed_message( - load_id, failed_job - ), - ) - # possibly raise on too many retries - if self.config.raise_on_max_retries: - for new_job in package_jobs["new_jobs"]: - r_c = new_job.retry_count - if r_c > 0 and r_c % self.config.raise_on_max_retries == 0: - raise LoadClientJobRetry( - load_id, - new_job.job_id(), - r_c, - 
self.config.raise_on_max_retries,
- )
+ # we continuously spool new jobs and complete finished ones
+ running_jobs, finalized_jobs, new_pending_exception = self.complete_jobs(
+ load_id, running_jobs, schema
+ )
+ # update load package info if any jobs were finalized
+ if finalized_jobs:
+ self.update_loadpackage_info(load_id)
+
+ pending_exception = pending_exception or new_pending_exception
+
+ # do not spool new jobs if there was a signal or an exception was encountered
+ # we inform the users how many jobs remain when shutting down, but only if the count of running jobs
+ # has changed (as determined by finalized jobs)
+ if signals.signal_received():
+ if finalized_jobs:
+ logger.info(
+ f"Signal received, draining running jobs. {len(running_jobs)} to go."
+ )
+ elif pending_exception:
+ if finalized_jobs:
+ logger.info(
+ f"Exception for job {pending_exception.job_id} received, draining"
+ f" running jobs. {len(running_jobs)} to go."
+ )
+ else:
+ running_jobs += self.start_new_jobs(load_id, schema, running_jobs)
+
+ if len(running_jobs) == 0:
+ # if a pending exception was discovered during completion of jobs
+ # we can raise it now
+ if pending_exception:
+ raise pending_exception
 break
- # process remaining jobs again
- jobs = remaining_jobs
 # this will raise on signal
- sleep(1)
+ sleep(self._run_loop_sleep_duration)
 except LoadClientJobFailed:
 # the package is completed and skipped
 self.complete_package(load_id, schema, True)
 raise
+ # no new jobs, load package done
+ self.complete_package(load_id, schema, False)
+
 def run(self, pool: Optional[Executor]) -> TRunMetrics:
 # store pool
 self.pool = pool or NullExecutor()
diff --git a/dlt/load/utils.py b/dlt/load/utils.py
index 67a813f5f2..9750f89d4b 100644
--- a/dlt/load/utils.py
+++ b/dlt/load/utils.py
@@ -12,10 +12,7 @@
 from dlt.common.storages.load_storage import ParsedLoadJobFileName
 from dlt.common.schema import Schema, TSchemaTables
 from dlt.common.schema.typing import TTableSchema
-from dlt.common.destination.reference import (
- JobClientBase,
- WithStagingDataset,
-)
+from dlt.common.destination.reference import JobClientBase, WithStagingDataset, LoadJob
 from dlt.load.configuration import LoaderConfiguration
 from dlt.common.destination import DestinationCapabilitiesContext
@@ -230,10 +227,30 @@ def _extend_tables_with_table_chain(
 return result
+def get_available_worker_slots(
+ config: LoaderConfiguration,
+ capabilities: DestinationCapabilitiesContext,
+ running_jobs: Sequence[LoadJob],
+) -> int:
+ """
+ Returns the number of available worker slots
+ """
+ parallelism_strategy = config.parallelism_strategy or capabilities.loader_parallelism_strategy
+
+ # find real max workers value
+ max_workers = 1 if parallelism_strategy == "sequential" else config.workers
+ if mp := capabilities.max_parallel_load_jobs:
+ max_workers = min(max_workers, mp)
+
+ return max(0, max_workers - len(running_jobs))
+
+
 def filter_new_jobs(
 file_names: Sequence[str],
 capabilities: DestinationCapabilitiesContext,
 config: LoaderConfiguration,
+ running_jobs: Sequence[LoadJob],
+ available_slots: int,
 ) -> Sequence[str]:
 """Filters the list of new jobs to adhere to max_workers and parallellism strategy"""
 """NOTE: in the current setup we only filter based on settings for the final destination"""
@@ -246,24 +263,27 @@ def filter_new_jobs(
 # config can overwrite destination settings, if nothing is set, code below defaults to parallel
 parallelism_strategy = config.parallelism_strategy or capabilities.loader_parallelism_strategy
- # find real max 
workers value - max_workers = 1 if parallelism_strategy == "sequential" else config.workers - if mp := capabilities.max_parallel_load_jobs: - max_workers = min(max_workers, mp) - # regular sequential works on all jobs eligible_jobs = file_names # we must ensure there only is one job per table if parallelism_strategy == "table-sequential": - eligible_jobs = sorted( - eligible_jobs, key=lambda j: ParsedLoadJobFileName.parse(j).table_name - ) - eligible_jobs = [ - next(table_jobs) - for _, table_jobs in groupby( - eligible_jobs, lambda j: ParsedLoadJobFileName.parse(j).table_name - ) - ] + # TODO later: this whole code block is a bit inefficient for long lists of jobs + # better would be to keep a list of loadjobinfos in the loader which we can iterate + + # find table names of all currently running jobs + running_tables = {j._parsed_file_name.table_name for j in running_jobs} + new_jobs: List[str] = [] + + for job in eligible_jobs: + if (table_name := ParsedLoadJobFileName.parse(job).table_name) not in running_tables: + running_tables.add(table_name) + new_jobs.append(job) + # exit loop if we have enough + if len(new_jobs) >= available_slots: + break + + return new_jobs - return eligible_jobs[:max_workers] + else: + return eligible_jobs[:available_slots] diff --git a/tests/cli/test_pipeline_command.py b/tests/cli/test_pipeline_command.py index 1f8e2ff4f3..82d74299f8 100644 --- a/tests/cli/test_pipeline_command.py +++ b/tests/cli/test_pipeline_command.py @@ -196,6 +196,7 @@ def test_pipeline_command_failed_jobs(repo_dir: str, project_files: FileStorage) def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileStorage) -> None: init_command.init_command("chess", "dummy", False, repo_dir) + os.environ["EXCEPTION_PROB"] = "1.0" try: pipeline = dlt.attach(pipeline_name="chess_pipeline") @@ -203,14 +204,22 @@ def test_pipeline_command_drop_partial_loads(repo_dir: str, project_files: FileS except Exception as e: print(e) - # now run the pipeline - os.environ["EXCEPTION_PROB"] = "1.0" - os.environ["FAIL_IN_INIT"] = "False" - os.environ["TIMEOUT"] = "1.0" venv = Venv.restore_current() with pytest.raises(CalledProcessError) as cpe: print(venv.run_script("chess_pipeline.py")) - assert "Dummy job status raised exception" in cpe.value.stdout + assert "PipelineStepFailed" in cpe.value.stdout + + # complete job manually to make a partial load + pipeline = dlt.attach(pipeline_name="chess_pipeline") + load_storage = pipeline._get_load_storage() + load_id = load_storage.normalized_packages.list_packages()[0] + job = load_storage.normalized_packages.list_new_jobs(load_id)[0] + load_storage.normalized_packages.start_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + load_storage.normalized_packages.complete_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) with io.StringIO() as buf, contextlib.redirect_stdout(buf): pipeline_command.pipeline_command("info", "chess_pipeline", None, 1) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index 45bc8d157e..8c4d5a439b 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -21,34 +21,69 @@ clear_destination_state, ) -from tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, + start_loading_files, +) from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage def 
test_is_partially_loaded(load_storage: LoadStorage) -> None: - load_id, file_name = start_loading_file( - load_storage, [{"content": "a"}, {"content": "b"}], start_job=False + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 ) info = load_storage.get_load_package_info(load_id) # all jobs are new assert PackageStorage.is_package_partially_loaded(info) is False - # start job - load_storage.normalized_packages.start_job(load_id, file_name) + # start one job + load_storage.normalized_packages.start_job(load_id, file_names[0]) info = load_storage.get_load_package_info(load_id) - assert PackageStorage.is_package_partially_loaded(info) is True + assert PackageStorage.is_package_partially_loaded(info) is False # complete job - load_storage.normalized_packages.complete_job(load_id, file_name) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is True + # start second job + load_storage.normalized_packages.start_job(load_id, file_names[1]) info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is True + # finish second job, now not partial anymore + load_storage.normalized_packages.complete_job(load_id, file_names[1]) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is False + # must complete package load_storage.complete_load_package(load_id, False) info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is False - # abort package + # abort package (will never be partially loaded) load_id, file_name = start_loading_file(load_storage, [{"content": "a"}, {"content": "b"}]) load_storage.complete_load_package(load_id, True) info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is False + + # abort partially loaded will stay partially loaded + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 + ) + load_storage.normalized_packages.start_job(load_id, file_names[0]) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + load_storage.complete_load_package(load_id, True) + info = load_storage.get_load_package_info(load_id) + assert PackageStorage.is_package_partially_loaded(info) is True + + # failed jobs will also result in partial loads, if one job is completed + load_id, file_names = start_loading_files( + load_storage, [{"content": "a"}, {"content": "b"}], start_job=False, file_count=2 + ) + load_storage.normalized_packages.start_job(load_id, file_names[0]) + load_storage.normalized_packages.complete_job(load_id, file_names[0]) + load_storage.normalized_packages.start_job(load_id, file_names[1]) + load_storage.normalized_packages.fail_job(load_id, file_names[1], "much broken, so bad") + info = load_storage.get_load_package_info(load_id) assert PackageStorage.is_package_partially_loaded(info) is True diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index 49deaff23e..bdcec4ceb2 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -8,7 +8,12 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.load_package import create_load_id -from 
tests.common.storages.utils import start_loading_file, assert_package_info, load_storage +from tests.common.storages.utils import ( + start_loading_file, + assert_package_info, + load_storage, + start_loading_files, +) from tests.utils import write_version, autouse_test_storage diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 1b5a68948b..baac3b7af5 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -157,25 +157,38 @@ def write_temp_job_file( return Path(file_name).name -def start_loading_file( - s: LoadStorage, content: Sequence[StrAny], start_job: bool = True -) -> Tuple[str, str]: +def start_loading_files( + s: LoadStorage, content: Sequence[StrAny], start_job: bool = True, file_count: int = 1 +) -> Tuple[str, List[str]]: load_id = uniq_id() s.new_packages.create_package(load_id) # write test file - item_storage = s.create_item_storage(DataWriter.writer_spec_from_file_format("jsonl", "object")) - file_name = write_temp_job_file( - item_storage, s.storage, load_id, "mock_table", None, uniq_id(), content - ) + file_names: List[str] = [] + for _ in range(0, file_count): + item_storage = s.create_item_storage( + DataWriter.writer_spec_from_file_format("jsonl", "object") + ) + file_name = write_temp_job_file( + item_storage, s.storage, load_id, "mock_table", None, uniq_id(), content + ) + file_names.append(file_name) # write schema and schema update s.new_packages.save_schema(load_id, Schema("mock")) s.new_packages.save_schema_updates(load_id, {}) s.commit_new_load_package(load_id) - assert_package_info(s, load_id, "normalized", "new_jobs") + assert_package_info(s, load_id, "normalized", "new_jobs", jobs_count=file_count) if start_job: - s.normalized_packages.start_job(load_id, file_name) - assert_package_info(s, load_id, "normalized", "started_jobs") - return load_id, file_name + for file_name in file_names: + s.normalized_packages.start_job(load_id, file_name) + assert_package_info(s, load_id, "normalized", "started_jobs") + return load_id, file_names + + +def start_loading_file( + s: LoadStorage, content: Sequence[StrAny], start_job: bool = True +) -> Tuple[str, str]: + load_id, file_names = start_loading_files(s, content, start_job) + return load_id, file_names[0] def assert_package_info( diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index a74ab11860..80bd008730 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -19,7 +19,7 @@ from dlt.common.schema.utils import new_table from dlt.common.storages import FileStorage from dlt.common.utils import digest128, uniq_id, custom_environ - +from dlt.common.destination.reference import RunnableLoadJob from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException @@ -268,30 +268,7 @@ def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: assert client._should_autodetect_schema("event_slot__values") is True -def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: - # non existing job - with pytest.raises(LoadJobNotExistsException): - client.restore_file_load(f"{uniq_id()}.") - - # bad name - with pytest.raises(LoadJobTerminalException): - client.restore_file_load("!!&*aaa") - - user_table_name = prepare_table(client) - - # start a job with non-existing file - with 
pytest.raises(FileNotFoundError): - client.start_file_load( - client.schema.get_table(user_table_name), - f"{uniq_id()}.", - uniq_id(), - ) - - # start a job with invalid name - dest_path = file_storage.save("!!aaaa", b"data") - with pytest.raises(LoadJobTerminalException): - client.start_file_load(client.schema.get_table(user_table_name), dest_path, uniq_id()) - +def test_bigquery_job_resuming(client: BigQueryClient, file_storage: FileStorage) -> None: user_table_name = prepare_table(client) load_json = { "_dlt_id": uniq_id(), @@ -300,14 +277,23 @@ def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) "timestamp": str(pendulum.now()), } job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) + assert job._created_job # type: ignore # start a job from the same file. it should be a fallback to retrieve a job silently - r_job = client.start_file_load( - client.schema.get_table(user_table_name), - file_storage.make_full_path(job.file_name()), - uniq_id(), + r_job = cast( + RunnableLoadJob, + client.create_load_job( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + ), ) + + # job will be automatically found and resumed + r_job.set_run_vars(uniq_id(), client.schema, client.schema.tables[user_table_name]) + r_job.run_managed(client) assert r_job.state() == "completed" + assert r_job._resumed_job # type: ignore @pytest.mark.parametrize("location", ["US", "EU"]) @@ -325,7 +311,7 @@ def test_bigquery_location(location: str, file_storage: FileStorage, client) -> job = expect_load_file(client, file_storage, json.dumps(load_json), user_table_name) # start a job from the same file. it should be a fallback to retrieve a job silently - client.start_file_load( + client.create_load_job( client.schema.get_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id(), diff --git a/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl b/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl new file mode 100644 index 0000000000..8baec57d5c --- /dev/null +++ b/tests/load/cases/loading/event_loop_interrupted.1234.0.jsonl @@ -0,0 +1 @@ +small file that is never read \ No newline at end of file diff --git a/tests/load/cases/loading/event_user.1234.0.jsonl b/tests/load/cases/loading/event_user.1234.0.jsonl new file mode 100644 index 0000000000..8baec57d5c --- /dev/null +++ b/tests/load/cases/loading/event_user.1234.0.jsonl @@ -0,0 +1 @@ +small file that is never read \ No newline at end of file diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index ce15997ed6..bb4153da5c 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -14,11 +14,11 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import LoadJob +from dlt.common.destination.reference import RunnableLoadJob from dlt.common.pendulum import timedelta, __utcnow from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient -from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.load import Load from tests.load.utils import prepare_load_package @@ -34,7 +34,7 @@ def setup_loader(dataset_name: str) -> Load: @contextmanager def perform_load( dataset_name: str, cases: Sequence[str], write_disposition: str = 
"append" -) -> Iterator[Tuple[FilesystemClient, List[LoadJob], str, str]]: +) -> Iterator[Tuple[FilesystemClient, List[RunnableLoadJob], str, str]]: load = setup_loader(dataset_name) load_id, schema = prepare_load_package(load.load_storage, cases, write_disposition) client: FilesystemClient = load.get_destination_client(schema) # type: ignore[assignment] @@ -54,13 +54,13 @@ def perform_load( try: jobs = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) # job execution failed - if isinstance(job, EmptyLoadJob): + if isinstance(job, FinalizedLoadJobWithFollowupJobs): raise RuntimeError(job.exception()) jobs.append(job) - yield client, jobs, root_path, load_id + yield client, jobs, root_path, load_id # type: ignore finally: try: client.drop_storage() diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 7ad571f2aa..8da43799bf 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -19,6 +19,7 @@ from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.load.exceptions import LoadClientJobRetry from tests.cases import arrow_table_all_data_types, table_update_and_row, assert_all_data_types_row from tests.common.utils import load_json_case @@ -242,7 +243,11 @@ def foo(): with pytest.raises(PipelineStepFailed) as pip_ex: pipeline.run(foo()) - assert isinstance(pip_ex.value.__context__, DependencyVersionException) + assert isinstance(pip_ex.value.__context__, LoadClientJobRetry) + assert ( + "`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination" + in pip_ex.value.__context__.retry_message + ) @pytest.mark.essential diff --git a/tests/load/redshift/test_redshift_client.py b/tests/load/redshift/test_redshift_client.py index bb923df673..41287fcd2d 100644 --- a/tests/load/redshift/test_redshift_client.py +++ b/tests/load/redshift/test_redshift_client.py @@ -90,9 +90,10 @@ def test_text_too_long(client: RedshiftClient, file_storage: FileStorage) -> Non # print(len(max_len_str_b)) row_id = uniq_id() insert_values = f"('{row_id}', '{uniq_id()}', '{max_len_str}' , '{str(pendulum.now())}');" - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is psycopg2.errors.StringDataRightTruncation + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is psycopg2.errors.StringDataRightTruncation # type: ignore def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: @@ -107,9 +108,10 @@ def test_wei_value(client: RedshiftClient, file_storage: FileStorage) -> None: f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {10**38});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is psycopg2.errors.InternalError_ + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is psycopg2.errors.InternalError_ # type: ignore def 
test_schema_string_exceeds_max_text_length(client: RedshiftClient) -> None: diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index be917672f1..b55f4ceece 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -1,6 +1,6 @@ import os from concurrent.futures import ThreadPoolExecutor -from time import sleep +from time import sleep, time from unittest import mock import pytest from unittest.mock import patch @@ -10,7 +10,7 @@ from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo, TJobState from dlt.common.storages.load_storage import JobFileFormatUnsupported -from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.destination.reference import RunnableLoadJob, TDestination from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, get_child_tables, @@ -18,13 +18,17 @@ ) from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration -from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations import dummy, filesystem from dlt.destinations.impl.dummy import dummy as dummy_impl from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration from dlt.load import Load -from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry +from dlt.load.exceptions import ( + LoadClientJobFailed, + LoadClientJobRetry, + TableChainFollowupJobCreationFailedException, + FollowupJobCreationFailedException, +) from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain from tests.utils import ( @@ -42,6 +46,8 @@ "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] +SMALL_FILES = ["event_user.1234.0.jsonl", "event_loop_interrupted.1234.0.jsonl"] + REMOTE_FILESYSTEM = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem")) @@ -61,20 +67,21 @@ def test_spool_job_started() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) + assert job.state() == "completed" assert type(job) is dummy_impl.LoadDummyJob - assert job.state() == "running" + # jobs runs, but is not moved yet (loader will do this) assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) jobs.append(job) - # still running - remaining_jobs = load.complete_jobs(load_id, jobs, schema) - assert len(remaining_jobs) == 2 + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) + assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 2 def test_unsupported_writer_type() -> None: @@ -87,6 +94,7 @@ def test_unsupported_writer_type() -> None: def test_unsupported_write_disposition() -> None: + # tests terminal error on retrieving job load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, [NORMALIZED_FILES[0]]) # mock unsupported disposition @@ -96,13 +104,36 @@ def test_unsupported_write_disposition() -> None: with ThreadPoolExecutor() as pool: load.run(pool) # job with unsupported write disp. 
is failed - failed_job = load.load_storage.normalized_packages.list_failed_jobs(load_id)[0] - failed_message = load.load_storage.normalized_packages.get_job_failed_message( + failed_job = load.load_storage.loaded_packages.list_failed_jobs(load_id)[0] + failed_message = load.load_storage.loaded_packages.get_job_failed_message( load_id, ParsedLoadJobFileName.parse(failed_job) ) assert "LoadClientUnsupportedWriteDisposition" in failed_message +def test_big_loadpackages() -> None: + """ + This test guards against changes in the load that exponentially makes the loads slower + """ + + load = setup_loader() + # make the loop faster by basically not sleeping + load._run_loop_sleep_duration = 0.001 + load_id, schema = prepare_load_package(load.load_storage, SMALL_FILES, jobs_per_case=500) + start_time = time() + with ThreadPoolExecutor(max_workers=20) as pool: + load.run(pool) + duration = float(time() - start_time) + + # sanity check + assert duration > 3 + # we want 1000 empty processed jobs to need less than 15 seconds total (locally it runs in 5) + assert duration < 15 + + # we should have 1000 jobs processed + assert len(dummy_impl.JOBS) == 1000 + + def test_get_new_jobs_info() -> None: load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) @@ -158,10 +189,10 @@ def test_spool_job_failed() -> None: load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) - assert type(job) is EmptyLoadJob + job = load.submit_job(f, load_id, schema) + assert type(job) is dummy_impl.LoadDummyJob assert job.state() == "failed" assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( @@ -170,8 +201,9 @@ def test_spool_job_failed() -> None: ) jobs.append(job) # complete files - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 2 for job in jobs: assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( @@ -196,11 +228,10 @@ def test_spool_job_failed() -> None: assert len(package_info.jobs["failed_jobs"]) == 2 -def test_spool_job_failed_exception_init() -> None: +def test_spool_job_failed_terminally_exception_init() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "true" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True)) + load = setup_loader(client_config=DummyClientConfiguration(fail_terminally_in_init=True)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: @@ -215,11 +246,30 @@ def test_spool_job_failed_exception_init() -> None: complete_load.assert_not_called() +def test_spool_job_failed_transiently_exception_init() -> None: + # this config fails job on start + os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" + load = setup_loader(client_config=DummyClientConfiguration(fail_transiently_in_init=True)) + load_id, _ = 
prepare_load_package(load.load_storage, NORMALIZED_FILES) + with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: + with pytest.raises(LoadClientJobRetry) as py_ex: + run_all(load) + assert py_ex.value.load_id == load_id + package_info = load.load_storage.get_load_package_info(load_id) + assert package_info.state == "normalized" + # both failed - we wait till the current loop is completed and then raise + assert len(package_info.jobs["failed_jobs"]) == 0 + assert len(package_info.jobs["started_jobs"]) == 0 + assert len(package_info.jobs["new_jobs"]) == 2 + + # load id was never committed + complete_load.assert_not_called() + + def test_spool_job_failed_exception_complete() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "false" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -237,7 +287,7 @@ def test_spool_job_retry_new() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) assert job.state() == "retry" @@ -248,8 +298,7 @@ def test_spool_job_retry_spool_new() -> None: # call higher level function that returns jobs and counts with ThreadPoolExecutor() as pool: load.pool = pool - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.start_new_jobs(load_id, schema, []) assert len(jobs) == 2 @@ -259,24 +308,26 @@ def test_spool_job_retry_started() -> None: # dummy_impl.CLIENT_CONFIG = DummyClientConfiguration load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) files = load.load_storage.normalized_packages.list_new_jobs(load_id) - jobs: List[LoadJob] = [] + jobs: List[RunnableLoadJob] = [] for f in files: - job = Load.w_spool_job(load, f, load_id, schema) + job = load.submit_job(f, load_id, schema) assert type(job) is dummy_impl.LoadDummyJob - assert job.state() == "running" + assert job.state() == "completed" + # mock job state to make it retry + job.config.retry_prob = 1.0 + job._state = "retry" assert load.load_storage.normalized_packages.storage.has_file( load.load_storage.normalized_packages.get_job_file_path( load_id, PackageStorage.STARTED_JOBS_FOLDER, job.file_name() ) ) - # mock job config to make it retry - job.config.retry_prob = 1.0 jobs.append(job) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 0 - # should retry, that moves jobs into new folder - remaining_jobs = load.complete_jobs(load_id, jobs, schema) + # should retry, that moves jobs into new folder, jobs are not counted as finalized + remaining_jobs, finalized_jobs, _ = load.complete_jobs(load_id, jobs, schema) assert len(remaining_jobs) == 0 + assert len(finalized_jobs) == 0 # clear retry flag dummy_impl.JOBS = {} files = load.load_storage.normalized_packages.list_new_jobs(load_id) @@ -285,9 +336,11 @@ def test_spool_job_retry_started() -> None: for fn in load.load_storage.normalized_packages.list_new_jobs(load_id): # we failed when already running the job so retry count will increase assert ParsedLoadJobFileName.parse(fn).retry_count == 1 + + # this time 
it will pass for f in files: - job = Load.w_spool_job(load, f, load_id, schema) - assert job.state() == "running" + job = load.submit_job(f, load_id, schema) + assert job.state() == "completed" def test_try_retrieve_job() -> None: @@ -301,22 +354,21 @@ def test_try_retrieve_job() -> None: ) # dummy client may retrieve jobs that it created itself, jobs in started folder are unknown # and returned as terminal - with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 - for j in jobs: - assert j.state() == "failed" + jobs = load.resume_started_jobs(load_id, schema) + assert len(jobs) == 2 + for j in jobs: + assert j.state() == "failed" # new load package load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) load.pool = ThreadPoolExecutor() - jobs_count, jobs = load.spool_new_jobs(load_id, schema) - assert jobs_count == 2 + jobs = load.start_new_jobs(load_id, schema, []) # type: ignore + assert len(jobs) == 2 # now jobs are known - with load.destination.client(schema, load.initial_client_config) as c: - job_count, jobs = load.retrieve_jobs(c, load_id) - assert job_count == 2 - for j in jobs: - assert j.state() == "running" + jobs = load.resume_started_jobs(load_id, schema) + assert len(jobs) == 2 + for j in jobs: + assert j.state() == "completed" + assert len(dummy_impl.RETRIED_JOBS) == 2 def test_completed_loop() -> None: @@ -328,7 +380,6 @@ def test_completed_loop() -> None: def test_completed_loop_followup_jobs() -> None: # TODO: until we fix how we create capabilities we must set env - os.environ["CREATE_FOLLOWUP_JOBS"] = "true" load = setup_loader( client_config=DummyClientConfiguration(completed_prob=1.0, create_followup_jobs=True) ) @@ -338,6 +389,95 @@ def test_completed_loop_followup_jobs() -> None: assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 +def test_failing_followup_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_jobs=True, fail_followup_job_creation=True + ) + ) + with pytest.raises(FollowupJobCreationFailedException) as exc: + assert_complete_job(load) + # follow up job errors on main thread + assert "Failed to create followup job" in str(exc) + + # followup job fails, we have both jobs in started folder + load_id = list(dummy_impl.JOBS.values())[1]._load_id + started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) + assert len(started_files) == 2 + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.RETRIED_JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + # now we can retry the same load, it will restart the two jobs and successfully create the followup jobs + load.initial_client_config.fail_followup_job_creation = False # type: ignore + assert_complete_job(load, load_id=load_id) + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 + assert len(dummy_impl.RETRIED_JOBS) == 2 + + +def test_failing_table_chain_followup_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, + create_followup_table_chain_reference_jobs=True, + fail_table_chain_followup_job_creation=True, + ) + ) + with pytest.raises(TableChainFollowupJobCreationFailedException) as exc: + assert_complete_job(load) + # follow up job errors on main thread + assert "Failed creating table chain followup jobs for table chain with root table" in str(exc) + + # 
table chain followup job fails, we have both jobs in started folder + load_id = list(dummy_impl.JOBS.values())[1]._load_id + started_files = load.load_storage.normalized_packages.list_started_jobs(load_id) + assert len(started_files) == 2 + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.RETRIED_JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + # now we can retry the same load, it will restart the two jobs and successfully create the table chain followup jobs + load.initial_client_config.fail_table_chain_followup_job_creation = False # type: ignore + assert_complete_job(load, load_id=load_id) + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS) * 2 + assert len(dummy_impl.RETRIED_JOBS) == 2 + + +def test_failing_sql_table_chain_job() -> None: + """ + Make sure we get a useful exception from a failing sql job + """ + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_table_chain_sql_jobs=True + ), + ) + with pytest.raises(Exception) as exc: + assert_complete_job(load) + + # sql jobs always fail because this is not an sql client, we just make sure the exception is there + assert "Failed creating table chain followup jobs for table chain with root table" in str(exc) + + +def test_successful_table_chain_jobs() -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + completed_prob=1.0, create_followup_table_chain_reference_jobs=True + ), + ) + # we create 10 jobs per case (for two cases) + # and expect two table chain jobs at the end + assert_complete_job(load, jobs_per_case=10) + assert len(dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS) == 2 + assert len(dummy_impl.JOBS) == 22 + + # check that we have 10 references per followup job + for _, job in dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS.items(): + assert len(job._remote_paths) == 10 # type: ignore + + def test_failed_loop() -> None: # ask to delete completed load = setup_loader( @@ -345,21 +485,18 @@ def test_failed_loop() -> None: ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) - # no jobs because fail on init - assert len(dummy_impl.JOBS) == 0 + # two failed jobs + assert len(dummy_impl.JOBS) == 2 + assert list(dummy_impl.JOBS.values())[0].state() == "failed" + assert list(dummy_impl.JOBS.values())[1].state() == "failed" assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 def test_failed_loop_followup_jobs() -> None: - # TODO: until we fix how we create capabilities we must set env - os.environ["CREATE_FOLLOWUP_JOBS"] = "true" - os.environ["FAIL_IN_INIT"] = "false" # ask to delete completed load = setup_loader( delete_completed_jobs=True, - client_config=DummyClientConfiguration( - fail_prob=1.0, fail_in_init=False, create_followup_jobs=True - ), + client_config=DummyClientConfiguration(fail_prob=1.0, create_followup_jobs=True), ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) @@ -381,36 +518,36 @@ def test_retry_on_new_loop() -> None: load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) with ThreadPoolExecutor() as pool: # 1st retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 # 2nd retry - load.run(pool) + with pytest.raises(LoadClientJobRetry): + load.run(pool) files = 
load.load_storage.normalized_packages.list_new_jobs(load_id) assert len(files) == 2 - # jobs will be completed + # package will be completed load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) load.run(pool) - files = load.load_storage.normalized_packages.list_new_jobs(load_id) - assert len(files) == 0 - # complete package - load.run(pool) assert not load.load_storage.normalized_packages.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) ) + sleep(1) # parse the completed job names completed_path = load.load_storage.loaded_packages.get_package_path(load_id) for fn in load.load_storage.loaded_packages.storage.list_folder_files( os.path.join(completed_path, PackageStorage.COMPLETED_JOBS_FOLDER) ): - # we update a retry count in each case - assert ParsedLoadJobFileName.parse(fn).retry_count == 2 + # we update a retry count in each case (5 times for each loop run) + assert ParsedLoadJobFileName.parse(fn).retry_count == 10 def test_retry_exceptions() -> None: load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) prepare_load_package(load.load_storage, NORMALIZED_FILES) + with ThreadPoolExecutor() as pool: # 1st retry with pytest.raises(LoadClientJobRetry) as py_ex: @@ -418,7 +555,6 @@ def test_retry_exceptions() -> None: load.run(pool) # configured to retry 5 times before exception assert py_ex.value.max_retry_count == py_ex.value.retry_count == 5 - # we can do it again with pytest.raises(LoadClientJobRetry) as py_ex: while True: @@ -730,8 +866,13 @@ def test_terminal_exceptions() -> None: raise AssertionError() -def assert_complete_job(load: Load, should_delete_completed: bool = False) -> None: - load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) +def assert_complete_job( + load: Load, should_delete_completed: bool = False, load_id: str = None, jobs_per_case: int = 1 +) -> None: + if not load_id: + load_id, _ = prepare_load_package( + load.load_storage, NORMALIZED_FILES, jobs_per_case=jobs_per_case + ) # will complete all jobs timestamp = "2024-04-05T09:16:59.942779Z" mocked_timestamp = {"state": {"created_at": timestamp}} @@ -744,22 +885,7 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) as complete_load: with ThreadPoolExecutor() as pool: load.run(pool) - # did process schema update - assert load.load_storage.storage.has_file( - os.path.join( - load.load_storage.get_normalized_package_path(load_id), - PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, - ) - ) - # will finalize the whole package - load.run(pool) - # may have followup jobs or staging destination - if ( - load.initial_client_config.create_followup_jobs # type:ignore[attr-defined] - or load.staging_destination - ): - # run the followup jobs - load.run(pool) + # moved to loaded assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) @@ -767,6 +893,15 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No completed_path = load.load_storage.loaded_packages.get_job_state_folder_path( load_id, "completed_jobs" ) + + # should have migrated the schema + assert load.load_storage.storage.has_file( + os.path.join( + load.load_storage.get_loaded_package_path(load_id), + PackageStorage.APPLIED_SCHEMA_UPDATES_FILE_NAME, + ) + ) + if should_delete_completed: # package was deleted assert not load.load_storage.loaded_packages.storage.has_folder(completed_path) @@ -794,14 +929,21 @@ def setup_loader( # reset jobs for a test dummy_impl.JOBS = 
{} dummy_impl.CREATED_FOLLOWUP_JOBS = {} - client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + dummy_impl.RETRIED_JOBS = {} + dummy_impl.CREATED_TABLE_CHAIN_FOLLOWUP_JOBS = {} + + client_config = client_config or DummyClientConfiguration( + loader_file_format="jsonl", completed_prob=1 + ) destination: TDestination = dummy(**client_config) # type: ignore[assignment] # setup staging_system_config = None staging = None if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination - client_config = client_config or DummyClientConfiguration(loader_file_format="reference") # type: ignore[arg-type] + client_config = client_config or DummyClientConfiguration( + loader_file_format="reference", completed_prob=1 + ) staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( dataset_name="dummy" ) diff --git a/tests/load/test_insert_job_client.py b/tests/load/test_insert_job_client.py index 38155a8b09..c40e83e027 100644 --- a/tests/load/test_insert_job_client.py +++ b/tests/load/test_insert_job_client.py @@ -114,24 +114,27 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', NULL);" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TUndefinedColumn + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is TUndefinedColumn # type: ignore # insert null value insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd', NULL);" - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TNotNullViolation + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is TNotNullViolation # type: ignore # insert wrong type insert_sql = "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp)\nVALUES\n" insert_values = ( f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" {client.capabilities.escape_literal(True)});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) is TDatatypeMismatch + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception.dbapi_exception) is TDatatypeMismatch # type: ignore # numeric overflow on bigint insert_sql = ( "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp, metadata__rasa_x_id)\nVALUES\n" @@ -141,9 +144,10 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {2**64//2});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) - assert type(exv.value.dbapi_exception) == TNumericValueOutOfRange + job = expect_load_file( + client, file_storage, 
insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception) == DatabaseTerminalException # type: ignore # numeric overflow on NUMERIC insert_sql = ( "INSERT INTO {}(_dlt_id, _dlt_root_id, sender_id, timestamp," @@ -164,10 +168,13 @@ def test_loading_errors(client: InsertValuesJobClient, file_storage: FileStorage f"('{uniq_id()}', '{uniq_id()}', '90238094809sajlkjxoiewjhduuiuehd'," f" '{str(pendulum.now())}', {above_limit});" ) - with pytest.raises(DatabaseTerminalException) as exv: - expect_load_file(client, file_storage, insert_sql + insert_values, user_table_name) + job = expect_load_file( + client, file_storage, insert_sql + insert_values, user_table_name, "failed" + ) + assert type(job._exception) == DatabaseTerminalException # type: ignore + assert ( - type(exv.value.dbapi_exception) == psycopg2.errors.InternalError_ + type(job._exception.dbapi_exception) == psycopg2.errors.InternalError_ # type: ignore if dtype == "redshift" else TNumericValueOutOfRange ) diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 614eb17da1..fdc0140a56 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -707,7 +707,7 @@ def test_write_dispositions( @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) -def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: +def test_get_resumed_job(client: SqlJobClientBase, file_storage: FileStorage) -> None: if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") user_table_name = prepare_table(client) @@ -723,11 +723,13 @@ def test_retrieve_job(client: SqlJobClientBase, file_storage: FileStorage) -> No job = expect_load_file(client, file_storage, dataset, user_table_name) # now try to retrieve the job # TODO: we should re-create client instance as this call is intended to be run after some disruption ie. 
stopped loader process - r_job = client.restore_file_load(file_storage.make_full_path(job.file_name())) - assert r_job.state() == "completed" - # use just file name to restore - r_job = client.restore_file_load(job.file_name()) - assert r_job.state() == "completed" + r_job = client.create_load_job( + client.schema.get_table(user_table_name), + file_storage.make_full_path(job.file_name()), + uniq_id(), + restore=True, + ) + assert r_job.state() == "ready" @pytest.mark.parametrize( diff --git a/tests/load/test_jobs.py b/tests/load/test_jobs.py new file mode 100644 index 0000000000..69f5fb9ddc --- /dev/null +++ b/tests/load/test_jobs.py @@ -0,0 +1,75 @@ +import pytest + +from dlt.common.destination.reference import RunnableLoadJob +from dlt.common.destination.exceptions import DestinationTerminalException +from dlt.destinations.job_impl import FinalizedLoadJob + + +def test_instantiate_job() -> None: + file_name = "table.1234.0.jsonl" + file_path = "/path/" + file_name + + class SomeJob(RunnableLoadJob): + def run(self) -> None: + pass + + j = SomeJob(file_path) + assert j._file_name == file_name + assert j._file_path == file_path + + # providing only a filename is not allowed + with pytest.raises(AssertionError): + SomeJob(file_name) + + +def test_runnable_job_results() -> None: + file_path = "/table.1234.0.jsonl" + + class MockClient: + def prepare_load_job_execution(self, j: RunnableLoadJob): + pass + + class SuccessfulJob(RunnableLoadJob): + def run(self) -> None: + 5 + 5 + + j: RunnableLoadJob = SuccessfulJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "completed" + + class RandomExceptionJob(RunnableLoadJob): + def run(self) -> None: + raise Exception("Oh no!") + + j = RandomExceptionJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "retry" + assert j.exception() == "Oh no!" + + class TerminalJob(RunnableLoadJob): + def run(self) -> None: + raise DestinationTerminalException("Oh no!") + + j = TerminalJob(file_path) + assert j.state() == "ready" + j.run_managed(MockClient()) # type: ignore + assert j.state() == "failed" + assert j.exception() == "Oh no!" + + +def test_finalized_load_job() -> None: + file_name = "table.1234.0.jsonl" + file_path = "/path/" + file_name + j = FinalizedLoadJob(file_path) + assert j.state() == "completed" + assert not j.exception() + + j = FinalizedLoadJob(file_path, "failed", "oh no!") + assert j.state() == "failed" + assert j.exception() == "oh no!" 
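+ # taken together, the cases above pin down the job state machine: RunnableLoadJob starts in "ready" and run_managed() drives it to "completed" on success, "retry" on a generic exception and "failed" on DestinationTerminalException + # FinalizedLoadJob, by contrast, never runs and is constructed directly in a terminal state, as checked next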
+ + # only actionable / terminal states are allowed + with pytest.raises(AssertionError): + FinalizedLoadJob(file_path, "ready") diff --git a/tests/load/test_parallelism_util.py b/tests/load/test_parallelism_util.py index b8f43d0743..3a7159563d 100644 --- a/tests/load/test_parallelism_util.py +++ b/tests/load/test_parallelism_util.py @@ -3,9 +3,9 @@ NOTE: there are tests in custom destination to check parallelism settings are applied """ -from typing import Tuple +from typing import Tuple, Any, cast -from dlt.load.utils import filter_new_jobs +from dlt.load.utils import filter_new_jobs, get_available_worker_slots from dlt.load.configuration import LoaderConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.utils import uniq_id @@ -21,24 +21,35 @@ def get_caps_conf() -> Tuple[DestinationCapabilitiesContext, LoaderConfiguration return DestinationCapabilitiesContext(), LoaderConfiguration() -def test_max_workers() -> None: - job_names = [create_job_name("t1", i) for i in range(100)] +def test_get_available_worker_slots() -> None: caps, conf = get_caps_conf() - # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + conf.workers = 20 + assert get_available_worker_slots(conf, caps, []) == 20 + + # change workers + conf.workers = 30 + assert get_available_worker_slots(conf, caps, []) == 30 + + # check with existing jobs + assert get_available_worker_slots(conf, caps, cast(Any, range(3))) == 27 + assert get_available_worker_slots(conf, caps, cast(Any, range(50))) == 0 + + # table-sequential will not change anything + caps.loader_parallelism_strategy = "table-sequential" + assert get_available_worker_slots(conf, caps, []) == 30 - # we can change it - conf.workers = 35 - assert len(filter_new_jobs(job_names, caps, conf)) == 35 + # caps with lower value will override + caps.max_parallel_load_jobs = 10 + assert get_available_worker_slots(conf, caps, []) == 10 - # destination may override this - caps.max_parallel_load_jobs = 15 - assert len(filter_new_jobs(job_names, caps, conf)) == 15 + # lower conf workers will override aing + conf.workers = 3 + assert get_available_worker_slots(conf, caps, []) == 3 - # lowest value will prevail - conf.workers = 5 - assert len(filter_new_jobs(job_names, caps, conf)) == 5 + # sequential strategy only allows one + caps.loader_parallelism_strategy = "sequential" + assert get_available_worker_slots(conf, caps, []) == 1 def test_table_sequential_parallelism_strategy() -> None: @@ -51,17 +62,16 @@ def test_table_sequential_parallelism_strategy() -> None: caps, conf = get_caps_conf() # default is 20 - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert len(filter_new_jobs(job_names, caps, conf, [], 20)) == 20 # table sequential will give us 8, one for each table conf.parallelism_strategy = "table-sequential" - filtered = filter_new_jobs(job_names, caps, conf) + filtered = filter_new_jobs(job_names, caps, conf, [], 20) assert len(filtered) == 8 assert len({ParsedLoadJobFileName.parse(j).table_name for j in job_names}) == 8 - # max workers also are still applied - conf.workers = 3 - assert len(filter_new_jobs(job_names, caps, conf)) == 3 + # only free available slots are also applied + assert len(filter_new_jobs(job_names, caps, conf, [], 3)) == 3 def test_strategy_preference() -> None: @@ -72,22 +82,37 @@ def test_strategy_preference() -> None: caps, conf = get_caps_conf() # nothing set will default to parallel - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert ( + 
len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 20 + ) caps.loader_parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 8 + ) caps.loader_parallelism_strategy = "sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 1 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 1 + ) # config may override (will go back to default 20) conf.parallelism_strategy = "parallel" - assert len(filter_new_jobs(job_names, caps, conf)) == 20 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 20 + ) conf.parallelism_strategy = "table-sequential" - assert len(filter_new_jobs(job_names, caps, conf)) == 8 + assert ( + len(filter_new_jobs(job_names, caps, conf, [], get_available_worker_slots(conf, caps, []))) + == 8 + ) def test_no_input() -> None: caps, conf = get_caps_conf() - assert filter_new_jobs([], caps, conf) == [] + assert filter_new_jobs([], caps, conf, [], 50) == [] diff --git a/tests/load/utils.py b/tests/load/utils.py index 4b6c01c916..3f0726fe1b 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -17,6 +17,7 @@ from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, JobClientBase, + RunnableLoadJob, LoadJob, DestinationClientStagingConfiguration, TDestinationReferenceArg, @@ -694,7 +695,12 @@ def expect_load_file( ).file_name() file_storage.save(file_name, query.encode("utf-8")) table = client.prepare_load_table(table_name) - job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) + load_id = uniq_id() + job = client.create_load_job(table, file_storage.make_full_path(file_name), load_id) + + if isinstance(job, RunnableLoadJob): + job.set_run_vars(load_id=load_id, schema=client.schema, load_table=table) + job.run_managed(client) while job.state() == "running": sleep(0.5) assert job.file_name() == file_name @@ -842,18 +848,37 @@ def write_dataset( def prepare_load_package( - load_storage: LoadStorage, cases: Sequence[str], write_disposition: str = "append" + load_storage: LoadStorage, + cases: Sequence[str], + write_disposition: str = "append", + jobs_per_case: int = 1, ) -> Tuple[str, Schema]: + """ + Create a load package with explicitely provided files + job_per_case multiplies the amount of load jobs, for big packages use small files + """ load_id = uniq_id() load_storage.new_packages.create_package(load_id) for case in cases: path = f"./tests/load/cases/loading/{case}" - shutil.copy( - path, - load_storage.new_packages.storage.make_full_path( + for _ in range(jobs_per_case): + new_path = load_storage.new_packages.storage.make_full_path( load_storage.new_packages.get_job_state_folder_path(load_id, "new_jobs") - ), - ) + ) + shutil.copy( + path, + new_path, + ) + if jobs_per_case > 1: + parsed_name = ParsedLoadJobFileName.parse(case) + new_file_name = ParsedLoadJobFileName( + parsed_name.table_name, + ParsedLoadJobFileName.new_file_id(), + 0, + parsed_name.file_format, + ).file_name() + shutil.move(new_path + "/" + case, new_path + "/" + new_file_name) + schema_path = Path("./tests/load/cases/loading/schema.json") # load without migration data = json.loads(schema_path.read_text(encoding="utf8")) diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 
dc2110d2f6..0a249db0fd 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -192,8 +192,8 @@ def test_load_case_sensitive_data(client: WeaviateClient, file_storage: FileStor write_dataset(client, f, [data_clash], table_create) query = f.getvalue().decode() class_name = client.schema.naming.normalize_table_identifier(class_name) - with pytest.raises(PropertyNameConflict): - expect_load_file(client, file_storage, query, class_name) + job = expect_load_file(client, file_storage, query, class_name, "failed") + assert type(job._exception) is PropertyNameConflict # type: ignore def test_load_case_sensitive_data_ci(ci_client: WeaviateClient, file_storage: FileStorage) -> None: diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 7c7dac8e71..0ab1f61d72 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -5,12 +5,14 @@ import logging import os import random +import shutil import threading from time import sleep from typing import Any, List, Tuple, cast from tenacity import retry_if_exception, Retrying, stop_after_attempt import pytest +from dlt.common.storages import FileStorage import dlt from dlt.common import json, pendulum @@ -1759,11 +1761,24 @@ def test_remove_pending_packages() -> None: assert pipeline.has_pending_data is False # partial load os.environ["EXCEPTION_PROB"] = "1.0" - os.environ["FAIL_IN_INIT"] = "False" os.environ["TIMEOUT"] = "1.0" - # should produce partial loads + # will make job go into retry state with pytest.raises(PipelineStepFailed): pipeline.run(airtable_emojis()) + # move job into completed folder manually to simulate partial package + load_storage = pipeline._get_load_storage() + load_id = load_storage.normalized_packages.list_packages()[0] + job = load_storage.normalized_packages.list_new_jobs(load_id)[0] + started_path = load_storage.normalized_packages.start_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + completed_path = load_storage.normalized_packages.complete_job( + load_id, FileStorage.get_file_name_from_file_path(job) + ) + # to test partial loads we need two jobs one completed an one in another state + # to simulate this, we just duplicate the completed job into the started path + shutil.copyfile(completed_path, started_path) + # now "with partial loads" can be tested assert pipeline.has_pending_data pipeline.drop_pending_packages(with_partial_loads=False) assert pipeline.has_pending_data diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index bdb3e3eb22..7122b4a4c6 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -362,12 +362,12 @@ def test_trace_telemetry() -> None: with patch("dlt.common.runtime.sentry.before_send", _mock_sentry_before_send), patch( "dlt.common.runtime.anon_tracker.before_send", _mock_anon_tracker_before_send ): - # os.environ["FAIL_PROB"] = "1.0" # make it complete immediately start_test_telemetry() ANON_TRACKER_SENT_ITEMS.clear() SENTRY_SENT_ITEMS.clear() - # default dummy fails all files + # make dummy fail all files + os.environ["FAIL_PROB"] = "1.0" load_info = dlt.pipeline().run( [1, 2, 3], table_name="data", destination="dummy", dataset_name="data_data" ) @@ -397,6 +397,11 @@ def test_trace_telemetry() -> None: # dummy has empty fingerprint assert event["properties"]["destination_fingerprint"] == "" # we have two failed files (state and data) that should be logged by sentry + # TODO: make this work + 
print(SENTRY_SENT_ITEMS) + for item in SENTRY_SENT_ITEMS: + # print(item) + print(item["logentry"]["message"]) assert len(SENTRY_SENT_ITEMS) == 2 # trace with exception