Commit

update

sh-rp committed Oct 12, 2023
1 parent baa5e44 commit 701f18a
Showing 9 changed files with 83 additions and 50 deletions.
13 changes: 11 additions & 2 deletions dlt/common/destination/reference.py
@@ -247,7 +247,7 @@ def restore_file_load(self, file_path: str) -> LoadJob:
"""Finds and restores already started loading job identified by `file_path` if destination supports it."""
pass

def table_needs_truncating(self, table: TTableSchema) -> bool:
def should_truncate_table_before_load(self, table: TTableSchema) -> bool:
return table["write_disposition"] == "replace"

def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
@@ -324,14 +324,23 @@ class WithStagingDataset(ABC):
"""Adds capability to use staging dataset and request it from the loader"""

@abstractmethod
def table_needs_staging_dataset(self, table: TTableSchema) -> bool:
def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool:
return False

@abstractmethod
def with_staging_dataset(self)-> ContextManager["JobClientBase"]:
"""Executes job client methods on staging dataset"""
return self # type: ignore

class SupportsStagingDestination():
"""Adds capability to support a staging destination for the load"""

def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTableSchema) -> bool:
return False

def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool:
# the default is to truncate the tables on the staging destination...
return True

TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str]

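For orientation, the hooks introduced or renamed in reference.py can be summarized in a small, dependency-free sketch. The class name and the plain dict standing in for dlt's TTableSchema are hypothetical; only the method names and the default return values are taken from the diff above.

from typing import Any, Dict

TableSchemaLike = Dict[str, Any]  # stand-in for dlt's TTableSchema


class HypotheticalDestinationClient:
    """Illustrative only: mirrors the defaults shown in reference.py above."""

    # JobClientBase default: truncate replace tables before loading
    def should_truncate_table_before_load(self, table: TableSchemaLike) -> bool:
        return table["write_disposition"] == "replace"

    # WithStagingDataset default: nothing goes to the staging dataset
    def should_load_data_to_staging_dataset(self, table: TableSchemaLike) -> bool:
        return False

    # SupportsStagingDestination defaults: truncate tables on the staging
    # destination, but load into its regular dataset
    def should_truncate_table_before_load_on_staging_destination(self, table: TableSchemaLike) -> bool:
        return True

    def should_load_data_to_staging_dataset_on_staging_destination(self, table: TableSchemaLike) -> bool:
        return False


if __name__ == "__main__":
    client = HypotheticalDestinationClient()
    events = {"name": "events", "write_disposition": "replace"}
    print(client.should_truncate_table_before_load(events))   # True
    print(client.should_load_data_to_staging_dataset(events))  # False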
19 changes: 15 additions & 4 deletions dlt/destinations/athena/athena.py
@@ -20,7 +20,7 @@
from dlt.common.schema.utils import table_schema_has_type, get_table_format
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import LoadJob, FollowupJob
from dlt.common.destination.reference import TLoadJobState, NewLoadJob
from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination
from dlt.common.storages import FileStorage
from dlt.common.data_writers.escape import escape_bigquery_identifier
from dlt.destinations.sql_jobs import SqlStagingCopyJob
@@ -286,7 +286,7 @@ def has_dataset(self) -> bool:
return len(rows) > 0


class AthenaClient(SqlJobClientWithStaging):
class AthenaClient(SqlJobClientWithStaging, SupportsStagingDestination):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

@@ -376,11 +376,22 @@ def _is_iceberg_table(self, table: TTableSchema) -> bool:
table_format = get_table_format(self.schema.tables, table["name"])
return table_format == "iceberg" or self.config.force_iceberg

def table_needs_staging_dataset(self, table: TTableSchema) -> bool:
def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool:
# all iceberg tables need staging
if self._is_iceberg_table(table):
return True
return super().table_needs_staging_dataset(table)
return super().should_load_data_to_staging_dataset(table)

def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool:
# on athena we only truncate replace tables that are not iceberg
table = self.get_load_table(table["name"])
if table["write_disposition"] == "replace" and not self._is_iceberg_table(table):
return True
return False

def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTableSchema) -> bool:
"""iceberg table data goes into staging on staging destination"""
return self._is_iceberg_table(table)

def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
table = super().get_load_table(table_name, staging)
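To make the Athena-specific routing above easier to follow, here is a hedged paraphrase of the overrides as pure functions. The is_iceberg flag stands in for AthenaClient._is_iceberg_table, default_from_sql_client for the SqlJobClientWithStaging default, and the dict for dlt's TTableSchema; the decisions themselves follow the diff.

from typing import Any, Dict


def load_to_staging_dataset(is_iceberg: bool, default_from_sql_client: bool) -> bool:
    # all iceberg tables need the staging dataset; otherwise defer to the SQL client default
    return True if is_iceberg else default_from_sql_client


def truncate_on_staging_destination(table: Dict[str, Any], is_iceberg: bool) -> bool:
    # only non-iceberg replace tables are truncated on the staging destination
    return table["write_disposition"] == "replace" and not is_iceberg


def load_to_staging_dataset_on_staging_destination(is_iceberg: bool) -> bool:
    # iceberg table data goes into the staging dataset of the staging destination
    return is_iceberg


print(truncate_on_staging_destination({"write_disposition": "replace"}, is_iceberg=True))   # False
print(truncate_on_staging_destination({"write_disposition": "replace"}, is_iceberg=False))  # True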
4 changes: 2 additions & 2 deletions dlt/destinations/bigquery/bigquery.py
@@ -7,7 +7,7 @@

from dlt.common import json, logger
from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob
from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, SupportsStagingDestination
from dlt.common.data_types import TDataType
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
@@ -151,7 +151,7 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient
sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
return sql

class BigQueryClient(SqlJobClientWithStaging):
class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

7 changes: 2 additions & 5 deletions dlt/destinations/filesystem/filesystem.py
@@ -177,8 +177,5 @@ def __enter__(self) -> "FilesystemClient":
def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None:
pass

def table_needs_staging_dataset(self, table: TTableSchema) -> bool:
# not so nice, how to do it better, collect this info from the main destination as before?
if table.get("table_format") == "iceberg" or (self.config.force_iceberg is True):
return True
return super().table_needs_staging_dataset(table)
def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool:
return False
4 changes: 2 additions & 2 deletions dlt/destinations/job_client_impl.py
@@ -140,7 +140,7 @@ def maybe_ddl_transaction(self) -> Iterator[None]:
else:
yield

def table_needs_truncating(self, table: TTableSchema) -> bool:
def should_truncate_table_before_load(self, table: TTableSchema) -> bool:
return table["write_disposition"] == "replace" and self.config.replace_strategy == "truncate-and-insert"

def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
@@ -440,7 +440,7 @@ def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
finally:
self.in_staging_mode = False

def table_needs_staging_dataset(self, table: TTableSchema) -> bool:
def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool:
if table["write_disposition"] == "merge":
return True
elif table["write_disposition"] == "replace" and (self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]):
4 changes: 2 additions & 2 deletions dlt/destinations/redshift/redshift.py
@@ -14,7 +14,7 @@
from typing import ClassVar, Dict, List, Optional, Sequence, Any

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration
from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration, SupportsStagingDestination
from dlt.common.data_types import TDataType
from dlt.common.schema import TColumnSchema, TColumnHint, Schema
from dlt.common.schema.typing import TTableSchema, TColumnType
@@ -187,7 +187,7 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st
return SqlMergeJob.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses, for_delete)


class RedshiftClient(InsertValuesJobClient):
class RedshiftClient(InsertValuesJobClient, SupportsStagingDestination):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

8 changes: 4 additions & 4 deletions dlt/destinations/snowflake/snowflake.py
@@ -1,9 +1,9 @@
from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any
from typing import ClassVar, Optional, Sequence, Tuple, List, Any
from urllib.parse import urlparse, urlunparse

from dlt.common.destination import DestinationCapabilitiesContext
from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration
from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentials, AzureCredentialsWithoutDefaults
from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration, SupportsStagingDestination
from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults
from dlt.common.data_types import TDataType
from dlt.common.storages.file_storage import FileStorage
from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
@@ -169,7 +169,7 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient
return sql


class SnowflakeClient(SqlJobClientWithStaging):
class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination):
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None:
68 changes: 42 additions & 26 deletions dlt/load/load.py
@@ -20,7 +20,7 @@
from dlt.common.schema import Schema, TSchemaTables
from dlt.common.schema.typing import TTableSchema, TWriteDisposition
from dlt.common.storages import LoadStorage
from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration
from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination

from dlt.destinations.job_impl import EmptyLoadJob

@@ -78,9 +78,9 @@ def is_staging_destination_job(self, file_path: str) -> bool:
return self.staging_destination is not None and os.path.splitext(file_path)[1][1:] in self.staging_destination.capabilities().supported_loader_file_formats

@contextlib.contextmanager
def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSchema) -> Iterator[None]:
def maybe_with_staging_dataset(self, job_client: JobClientBase, use_staging: bool) -> Iterator[None]:
"""Executes job client methods in context of staging dataset if `table` has `write_disposition` that requires it"""
if isinstance(job_client, WithStagingDataset) and job_client.table_needs_staging_dataset(table):
if isinstance(job_client, WithStagingDataset) and use_staging:
with job_client.with_staging_dataset():
yield
else:
@@ -91,18 +91,27 @@
def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> Optional[LoadJob]:
job: LoadJob = None
try:
is_staging_destination_job = self.is_staging_destination_job(file_path)
job_client = self.get_destination_client(schema)
staging_client = self.get_staging_destination_client(schema)

# if we have a staging destination and the file is not a reference, send to staging
job_client = self.get_staging_destination_client(schema) if self.is_staging_destination_job(file_path) else self.get_destination_client(schema)
with job_client as job_client:
with (staging_client if is_staging_destination_job else job_client) as client:
job_info = self.load_storage.parse_job_file_name(file_path)
if job_info.file_format not in self.load_storage.supported_file_formats:
raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path)
logger.info(f"Will load file {file_path} with table name {job_info.table_name}")
table = job_client.get_load_table(job_info.table_name)
table = client.get_load_table(job_info.table_name)
if table["write_disposition"] not in ["append", "replace", "merge"]:
raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path)
with self.maybe_with_staging_dataset(job_client, table):
job = job_client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id)

if is_staging_destination_job:
use_staging_dataset = isinstance(job_client, SupportsStagingDestination) and job_client.should_load_data_to_staging_dataset_on_staging_destination(table)
else:
use_staging_dataset = isinstance(job_client, WithStagingDataset) and job_client.should_load_data_to_staging_dataset(table)

with self.maybe_with_staging_dataset(client, use_staging_dataset):
job = client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id)
except (DestinationTerminalException, TerminalValueError):
# if job irreversibly cannot be started, mark it as failed
logger.exception(f"Terminal problem when adding job {file_path}")
@@ -272,7 +281,7 @@ def _get_table_chain_tables_with_filter(schema: Schema, f: Callable[[TTableSchem
return result

@staticmethod
def _init_client_and_update_schema(job_client: JobClientBase, expected_update: TSchemaTables, update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False) -> TSchemaTables:
def _init_dataset_and_update_schema(job_client: JobClientBase, expected_update: TSchemaTables, update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False) -> TSchemaTables:
staging_text = "for staging dataset" if staging_info else ""
logger.info(f"Client for {job_client.config.destination_name} will start initialize storage {staging_text}")
job_client.initialize_storage()
@@ -282,31 +291,38 @@ def _init_client_and_update_schema(job_client: JobClientBase, expected_update: T
job_client.initialize_storage(truncate_tables=truncate_tables)
return applied_update


def _init_client(self, job_client: JobClientBase, schema: Schema, expected_update: TSchemaTables, load_id: str, truncate_filter: Callable[[TTableSchema], bool], truncate_staging_filter: Callable[[TTableSchema], bool]) -> TSchemaTables:

tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id))
dlt_tables = set(t["name"] for t in schema.dlt_tables())

# update the default dataset
truncate_tables = self._get_table_chain_tables_with_filter(schema, truncate_filter, tables_with_jobs)
applied_update = self._init_dataset_and_update_schema(job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables)

# update the staging dataset if client supports this
if isinstance(job_client, WithStagingDataset):
if staging_tables := self._get_table_chain_tables_with_filter(schema, truncate_staging_filter, tables_with_jobs):
with job_client.with_staging_dataset():
self._init_dataset_and_update_schema(job_client, expected_update, staging_tables | {schema.version_table_name}, staging_tables, staging_info=True)

return applied_update


def load_single_package(self, load_id: str, schema: Schema) -> None:
# initialize analytical storage ie. create dataset required by passed schema
with self.get_destination_client(schema) as job_client:

if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:

tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id))
dlt_tables = set(t["name"] for t in schema.dlt_tables())

# update the default dataset
truncate_tables = self._get_table_chain_tables_with_filter(schema, job_client.table_needs_truncating, tables_with_jobs)
applied_update = self._init_client_and_update_schema(job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables)

# update the staging dataset if client supports this
if isinstance(job_client, WithStagingDataset):
if staging_tables := self._get_table_chain_tables_with_filter(schema, job_client.table_needs_staging_dataset, tables_with_jobs):
with job_client.with_staging_dataset():
self._init_client_and_update_schema(job_client, expected_update, staging_tables | {schema.version_table_name}, staging_tables, staging_info=True)
# init job client
applied_update = self._init_client(job_client, schema, expected_update, load_id, job_client.should_truncate_table_before_load, job_client.should_load_data_to_staging_dataset if isinstance(job_client, WithStagingDataset) else None)

# only update tables that are present in the load package
if self.staging_destination and isinstance(job_client, WithStagingDataset):
# init staging client
if self.staging_destination and isinstance(job_client, SupportsStagingDestination):
with self.get_staging_destination_client(schema) as staging_client:
# truncate all the tables in staging that are requested by the job client (TODO: make this better...)
truncate_tables = self._get_table_chain_tables_with_filter(schema, job_client.table_needs_staging_dataset, tables_with_jobs)
self._init_client_and_update_schema(staging_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables)
self._init_client(staging_client, schema, expected_update, load_id, job_client.should_truncate_table_before_load_on_staging_destination, job_client.should_load_data_to_staging_dataset_on_staging_destination)

self.load_storage.commit_schema_update(load_id, applied_update)

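The net effect of the w_spool_job changes above can be sketched as a small helper, hypothetical and not part of the commit, that picks both the client a job runs against and whether the staging dataset is used. The hook names and isinstance checks are the ones from the diff, assuming dlt at this commit is importable.

from dlt.common.destination.reference import (
    JobClientBase,
    SupportsStagingDestination,
    WithStagingDataset,
)
from dlt.common.schema.typing import TTableSchema


def choose_client_and_dataset(
    job_client: JobClientBase,
    staging_client: JobClientBase,
    is_staging_destination_job: bool,
    table: TTableSchema,
):
    # jobs whose file format is handled by the staging destination run on the
    # staging client, everything else on the main destination client
    client = staging_client if is_staging_destination_job else job_client
    if is_staging_destination_job:
        # note: the *main* destination client decides whether data should land in
        # the staging dataset of the staging destination
        use_staging_dataset = isinstance(job_client, SupportsStagingDestination) and \
            job_client.should_load_data_to_staging_dataset_on_staging_destination(table)
    else:
        use_staging_dataset = isinstance(job_client, WithStagingDataset) and \
            job_client.should_load_data_to_staging_dataset(table)
    return client, use_staging_dataset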
6 changes: 3 additions & 3 deletions tests/load/test_job_client.py
@@ -428,14 +428,14 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: TWrite
client.schema.bump_version()
client.update_stored_schema()

if client.table_needs_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined]
if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined]
with client.with_staging_dataset(): # type: ignore[attr-defined]
# create staging for merge dataset
client.initialize_storage()
client.update_stored_schema()

with client.sql_client.with_staging_dataset(
client.table_needs_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined]
client.should_load_data_to_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined]
):
canonical_name = client.sql_client.make_qualified_table_name(table_name)
# write row
@@ -493,7 +493,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: TWriteD
with io.BytesIO() as f:
write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA)
query = f.getvalue().decode()
if client.table_needs_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined]
if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined]
# load to staging dataset on merge
with client.with_staging_dataset(): # type: ignore[attr-defined]
expect_load_file(client, file_storage, query, t)