From 0924bc53b2831954633c1ee86fc82238988f1c69 Mon Sep 17 00:00:00 2001
From: Dave
Date: Mon, 9 Oct 2023 18:01:59 +0200
Subject: [PATCH] new iceberg approach

---
 dlt/common/destination/reference.py        |  15 +--
 dlt/common/schema/exceptions.py            |   5 +
 dlt/common/schema/typing.py                |   2 +
 dlt/common/schema/utils.py                 |  46 ++++++--
 dlt/common/storages/load_storage.py        |   5 +-
 dlt/destinations/athena/athena.py          |  25 ++---
 dlt/destinations/bigquery/bigquery.py      |   5 +-
 dlt/destinations/exceptions.py             |   6 -
 dlt/destinations/filesystem/filesystem.py  |  29 +++--
 dlt/destinations/job_client_impl.py        |  20 ++--
 dlt/extract/decorators.py                  |   7 +-
 dlt/extract/schema.py                      |   5 +-
 dlt/load/load.py                           | 103 +++++++-----------
 .../athena_iceberg/test_athena_iceberg.py  |  39 +++++--
 tests/load/utils.py                        |   4 +-
 tests/utils.py                             |   1 +
 16 files changed, 178 insertions(+), 139 deletions(-)

diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index f4fbe4df76..5fea462159 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -10,6 +10,7 @@
 from dlt.common.schema import Schema, TTableSchema, TSchemaTables
 from dlt.common.schema.typing import TWriteDisposition
 from dlt.common.schema.exceptions import InvalidDatasetName
+from dlt.common.schema.utils import get_load_table
 from dlt.common.configuration import configspec
 from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
 from dlt.common.configuration.accessors import config
@@ -244,13 +245,8 @@ def restore_file_load(self, file_path: str) -> LoadJob:
         """Finds and restores already started loading job identified by `file_path` if destination supports it."""
         pass
 
-    def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]:
-        # in the base job, all replace strategies are treated the same, see filesystem for example
-        return ["replace"]
-
-    def get_truncate_staging_destination_table_dispositions(self) -> List[TWriteDisposition]:
-        # some clients need to additionally be able to get the staging destination to truncate tables
-        return []
+    def table_needs_truncating(self, table: TTableSchema) -> bool:
+        return table["write_disposition"] == "replace"
 
     def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         """Creates a list of followup jobs that should be executed after a table chain is completed"""
@@ -313,9 +309,8 @@ class WithStagingDataset(ABC):
     """Adds capability to use staging dataset and request it from the loader"""
 
     @abstractmethod
-    def get_stage_dispositions(self) -> List[TWriteDisposition]:
-        """Returns a list of write dispositions that require staging dataset"""
-        return []
+    def table_needs_staging(self, table: TTableSchema) -> bool:
+        return False
 
     @abstractmethod
     def with_staging_dataset(self)-> ContextManager["JobClientBase"]:
diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py
index 2245a77b61..5f638a111d 100644
--- a/dlt/common/schema/exceptions.py
+++ b/dlt/common/schema/exceptions.py
@@ -69,3 +69,8 @@ def __init__(self, schema_name: str, init_engine: int, from_engine: int, to_engi
         self.from_engine = from_engine
         self.to_engine = to_engine
         super().__init__(f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}, stopped at {from_engine}")
+
+class UnknownTableException(SchemaException):
+    def __init__(self, table_name: str) -> None:
+        self.table_name = table_name
+        super().__init__(f"Trying to access unknown table {table_name}.")
\ No newline at end of file
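For orientation, a minimal sketch (not part of the patch) of the per-table API introduced above: instead of returning lists of write dispositions, a client now answers the truncate/staging question for each table. The table dict below is a hypothetical example.

    # illustrative only -- mirrors the default implementation added to JobClientBase above
    from dlt.common.schema.typing import TTableSchema

    def table_needs_truncating(table: TTableSchema) -> bool:
        # base behaviour: only tables loaded with the "replace" disposition get truncated
        return table["write_disposition"] == "replace"

    example_table: TTableSchema = {"name": "items", "write_disposition": "replace", "columns": {}}
    assert table_needs_truncating(example_table)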
diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py
index ae24691e2d..2cc057560c 100644
--- a/dlt/common/schema/typing.py
+++ b/dlt/common/schema/typing.py
@@ -24,6 +24,7 @@
 TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique", "root_key", "merge_key"]
 """Known hints of a column used to declare hint regexes."""
 TWriteDisposition = Literal["skip", "append", "replace", "merge"]
+TTableFormat = Literal["iceberg"]
 TTypeDetections = Literal["timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double"]
 TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]]
 TColumnNames = Union[str, Sequence[str]]
@@ -86,6 +87,7 @@ class TTableSchema(TypedDict, total=False):
     filters: Optional[TRowFilters]
     columns: TTableSchemaColumns
     resource: Optional[str]
+    table_format: Optional[TTableFormat]
 
 
 class TPartialTableSchema(TTableSchema):
diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py
index d3c0b31cc0..93f0913550 100644
--- a/dlt/common/schema/utils.py
+++ b/dlt/common/schema/utils.py
@@ -15,10 +15,10 @@
 from dlt.common.validation import TCustomValidator, validate_dict, validate_dict_ignoring_xkeys
 from dlt.common.schema import detections
 from dlt.common.schema.typing import (COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, TColumnName, TPartialTableSchema, TSchemaTables, TSchemaUpdate,
-    TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp,
+    TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TTableFormat,
     TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition)
 from dlt.common.schema.exceptions import (CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException,
-    TablePropertiesConflictException, InvalidSchemaName)
+    TablePropertiesConflictException, InvalidSchemaName, UnknownTableException)
 from dlt.common.normalizers.utils import import_normalizers
 from dlt.common.schema.typing import TAnySchemaColumns
 
@@ -493,18 +493,29 @@ def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTabl
     return aggregated_update
 
 
-def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDisposition:
-    """Returns write disposition of a table if present. If not, looks up into parent table"""
+def get_inherited_table_hint(tables: TSchemaTables, table_name: str, table_hint_name: str, allow_none: bool = False) -> Any:
     table = tables[table_name]
-    w_d = table.get("write_disposition")
-    if w_d:
-        return w_d
+    hint = table.get(table_hint_name)
+    if hint:
+        return hint
 
     parent = table.get("parent")
     if parent:
-        return get_write_disposition(tables, parent)
+        return get_inherited_table_hint(tables, parent, table_hint_name, allow_none)
+
+    if allow_none:
+        return None
+
+    raise ValueError(f"No table hint '{table_hint_name}' found in the chain of tables for '{table_name}'.")
+
+
+def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDisposition:
+    """Returns the write disposition of a table if present. If not, looks it up in the parent table"""
+    return get_inherited_table_hint(tables, table_name, "write_disposition", allow_none=False)
+
 
-    raise ValueError(f"No write disposition found in the chain of tables for '{table_name}'.")
+def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat:
+    return get_inherited_table_hint(tables, table_name, "table_format", allow_none=True)
 
 
 def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool:
@@ -525,6 +536,18 @@ def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema:
         return get_top_level_table(tables, parent)
     return table
 
+def get_load_table(tables: TSchemaTables, table_name: str) -> TTableSchema:
+    try:
+        # make a copy of the schema so modifications do not affect the original document
+        table = copy(tables[table_name])
+        # add write disposition if not specified - in child tables
+        if "write_disposition" not in table:
+            table["write_disposition"] = get_write_disposition(tables, table_name)
+        if "table_format" not in table:
+            table["table_format"] = get_table_format(tables, table_name)
+        return table
+    except KeyError:
+        raise UnknownTableException(table_name)
 
 def get_child_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]:
     """Get child tables for table name and return a list of tables ordered by ancestry so the child tables are always after their parents"""
@@ -637,7 +660,8 @@ def new_table(
     write_disposition: TWriteDisposition = None,
     columns: Sequence[TColumnSchema] = None,
     validate_schema: bool = False,
-    resource: str = None
+    resource: str = None,
+    table_format: TTableFormat = None
 ) -> TTableSchema:
 
     table: TTableSchema = {
@@ -652,6 +676,8 @@
         # set write disposition only for root tables
         table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION
         table["resource"] = resource or table_name
+        if table_format:
+            table["table_format"] = table_format
     if validate_schema:
         validate_dict_ignoring_xkeys(
             spec=TColumnSchema,
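The helpers above make `table_format` behave like `write_disposition`: child tables inherit the hint from their parent. A small sketch (not part of the patch) of what `get_load_table` resolves for a child table; the table names are made up.

    # illustrative only -- hint inheritance through the parent chain
    from dlt.common.schema.utils import new_table, get_load_table

    tables = {
        "items": new_table("items", write_disposition="append", table_format="iceberg"),
        "items__sub_items": new_table("items__sub_items", parent_table_name="items"),
    }
    child = get_load_table(tables, "items__sub_items")
    assert child["write_disposition"] == "append"  # inherited from "items"
    assert child["table_format"] == "iceberg"      # inherited from "items"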
diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py
index 95170ac46c..8e8a0ac5a8 100644
--- a/dlt/common/storages/load_storage.py
+++ b/dlt/common/storages/load_storage.py
@@ -237,8 +237,11 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]:
         return self.storage.list_folder_files(self._get_job_folder_path(load_id, LoadStorage.FAILED_JOBS_FOLDER))
 
     def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]:
+        return [job for job in self.list_all_jobs(load_id) if job.job_file_info.table_name == table_name]
+
+    def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]:
         info = self.get_load_package_info(load_id)
-        return [job for job in flatten_list_or_items(iter(info.jobs.values())) if job.job_file_info.table_name == table_name]  # type: ignore
+        return [job for job in flatten_list_or_items(iter(info.jobs.values()))]  # type: ignore
 
     def list_completed_failed_jobs(self, load_id: str) -> Sequence[str]:
         return self.storage.list_folder_files(self._get_job_folder_completed_path(load_id, LoadStorage.FAILED_JOBS_FOLDER))
diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py
index 792617badd..c9d6f3abb7 100644
--- a/dlt/destinations/athena/athena.py
+++ b/dlt/destinations/athena/athena.py
@@ -17,7 +17,7 @@
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, Schema
 from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
-from dlt.common.schema.utils import table_schema_has_type
+from dlt.common.schema.utils import table_schema_has_type, get_table_format
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.destination.reference import LoadJob, FollowupJob
 from dlt.common.destination.reference import TLoadJobState, NewLoadJob
@@ -325,12 +325,13 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc
 
         # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries
         # or if we are in iceberg mode, we create iceberg tables for all tables
-        is_iceberg = self.schema.tables[table_name].get("write_disposition", None) == "skip"
+        is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or (self._is_iceberg_table(self.schema.tables[table_name]) and not self.in_staging_mode)
         columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])
 
         # this will fail if the table prefix is not properly defined
         table_prefix = self.table_prefix_layout.format(table_name=table_name)
         location = f"{bucket}/{dataset}/{table_prefix}"
+
         # use qualified table names
         qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name)
         if is_iceberg and not generate_alter:
@@ -372,18 +373,14 @@ def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) ->
         return super()._create_replace_followup_jobs(table_chain)
 
     def _is_iceberg_table(self, table: TTableSchema) -> bool:
-        return False
-
-    def get_stage_dispositions(self) -> List[TWriteDisposition]:
-        # in iceberg mode, we always use staging tables
-        # if self.iceberg_mode:
-        #     return ["append", "replace", "merge"]
-        return super().get_stage_dispositions()
-
-    def get_truncate_staging_destination_table_dispositions(self) -> List[TWriteDisposition]:
-        # if self.iceberg_mode:
-        #     return ["append", "replace", "merge"]
-        return []
+        table_format = get_table_format(self.schema.tables, table["name"])
+        return table_format == "iceberg"
+
+    def table_needs_staging(self, table: TTableSchema) -> bool:
+        # all iceberg tables need staging
+        if self._is_iceberg_table(table):
+            return True
+        return super().table_needs_staging(table)
 
     @staticmethod
     def is_dbapi_exception(ex: Exception) -> bool:
diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py
index a5aa0cc703..eceb2ed57a 100644
--- a/dlt/destinations/bigquery/bigquery.py
+++ b/dlt/destinations/bigquery/bigquery.py
@@ -12,9 +12,10 @@
 from dlt.common.storages.file_storage import FileStorage
 from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns
 from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.exceptions import UnknownTableException
 
 from dlt.destinations.job_client_impl import SqlJobClientWithStaging
-from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, LoadJobUnknownTableException
+from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException
 
 from dlt.destinations.bigquery import capabilities
 from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration
@@ -220,7 +221,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
                 reason = BigQuerySqlClient._get_reason_from_errors(gace)
                 if reason == "notFound":
                     # google.api_core.exceptions.NotFound: 404 - table not found
-                    raise LoadJobUnknownTableException(table["name"], file_path)
+                    raise UnknownTableException(table["name"])
                 elif reason == "duplicate":
                     # google.api_core.exceptions.Conflict: 409 PUT - already exists
                     return self.restore_file_load(file_path)
diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py
index f0fe32f950..5c20f081f1 100644
--- a/dlt/destinations/exceptions.py
+++ b/dlt/destinations/exceptions.py
@@ -63,12 +63,6 @@ def __init__(self, file_path: str, message: str) -> None:
         super().__init__(f"Job with id/file name {file_path} encountered unrecoverable problem: {message}")
 
 
-class LoadJobUnknownTableException(DestinationTerminalException):
-    def __init__(self, table_name: str, file_name: str) -> None:
-        self.table_name = table_name
-        super().__init__(f"Client does not know table {table_name} for load file {file_name}")
-
-
 class LoadJobInvalidStateTransitionException(DestinationTerminalException):
     def __init__(self, from_state: TLoadJobState, to_state: TLoadJobState) -> None:
         self.from_state = from_state
diff --git a/dlt/destinations/filesystem/filesystem.py b/dlt/destinations/filesystem/filesystem.py
index 3691c6417b..6ad5954496 100644
--- a/dlt/destinations/filesystem/filesystem.py
+++ b/dlt/destinations/filesystem/filesystem.py
@@ -1,14 +1,15 @@
 import posixpath
 import os
 from types import TracebackType
-from typing import ClassVar, List, Type, Iterable, Set
+from typing import ClassVar, List, Type, Iterable, Set, Iterator
 from fsspec import AbstractFileSystem
+from contextlib import contextmanager
 
 from dlt.common import logger
 from dlt.common.schema import Schema, TSchemaTables, TTableSchema
 from dlt.common.storages import FileStorage, LoadStorage, filesystem_from_config
 from dlt.common.destination import DestinationCapabilitiesContext
-from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob
+from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob, WithStagingDataset
 
 from dlt.destinations.job_impl import EmptyLoadJob
 from dlt.destinations.filesystem import capabilities
@@ -68,7 +69,7 @@ def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]:
         return jobs
 
 
-class FilesystemClient(JobClientBase):
+class FilesystemClient(JobClientBase, WithStagingDataset):
     """filesystem client storing jobs in memory"""
 
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
@@ -82,16 +83,22 @@ def __init__(self, schema: Schema, config: FilesystemDestinationClientConfigurat
         # verify files layout. we need {table_name} and only allow {schema_name} before it, otherwise tables
         # cannot be replaced and we cannot initialize folders consistently
         self.table_prefix_layout = path_utils.get_table_prefix_layout(config.layout)
-
-    @property
-    def dataset_path(self) -> str:
-        ds_path = posixpath.join(self.fs_path, self.config.normalize_dataset_name(self.schema))
-        return ds_path
+        self.dataset_path = posixpath.join(self.fs_path, self.config.normalize_dataset_name(self.schema))
 
     def drop_storage(self) -> None:
         if self.is_storage_initialized():
             self.fs_client.rm(self.dataset_path, recursive=True)
 
+    @contextmanager
+    def with_staging_dataset(self) -> Iterator["FilesystemClient"]:
+        current_dataset_path = self.dataset_path
+        try:
+            self.dataset_path = posixpath.join(self.fs_path, self.config.normalize_dataset_name(self.schema)) + "_staging"
+            yield self
+        finally:
+            # restore previous dataset name
+            self.dataset_path = current_dataset_path
+
     def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None:
         # clean up existing files for tables selected for truncating
         if truncate_tables and self.fs_client.isdir(self.dataset_path):
@@ -169,3 +176,9 @@ def __enter__(self) -> "FilesystemClient":
 
     def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None:
         pass
+
+    def table_needs_staging(self, table: TTableSchema) -> bool:
+        # not so nice, how to do it better, collect this info from the main destination as before?
+        if table.get("table_format") == "iceberg":
+            return True
+        return super().table_needs_staging(table)
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index 91f7d15bee..ec3afced94 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -140,10 +140,8 @@ def maybe_ddl_transaction(self) -> Iterator[None]:
         else:
             yield
 
-    def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]:
-        if self.config.replace_strategy == "truncate-and-insert":
-            return ["replace"]
-        return []
+    def table_needs_truncating(self, table: TTableSchema) -> bool:
+        return table["write_disposition"] == "replace" and self.config.replace_strategy == "truncate-and-insert"
 
     def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         return []
@@ -442,10 +440,10 @@ def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
         finally:
             self.in_staging_mode = False
 
-    def get_stage_dispositions(self) -> List[TWriteDisposition]:
-        """Returns a list of dispositions that require staging tables to be populated"""
-        dispositions: List[TWriteDisposition] = ["merge"]
-        # if we have anything but the truncate-and-insert replace strategy, we need staging tables
-        if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]:
-            dispositions.append("replace")
-        return dispositions
+    def table_needs_staging(self, table: TTableSchema) -> bool:
+        if table["write_disposition"] == "merge":
+            return True
+        elif table["write_disposition"] == "replace" and (self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]):
+            return True
+        return False
+
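With the changes above the filesystem destination can serve a staging dataset simply by switching its path suffix. A sketch (not part of the patch) of the intended behaviour; `fs_client` is an assumed, already configured FilesystemClient instance.

    # illustrative only -- the staging dataset is the regular dataset path plus "_staging"
    print(fs_client.dataset_path)          # e.g. "<bucket>/my_dataset"
    with fs_client.with_staging_dataset() as staging:
        print(staging.dataset_path)        # e.g. "<bucket>/my_dataset_staging"
    print(fs_client.dataset_path)          # restored once the context exits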
diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py
index 84dfcb83f9..e122ad10cd 100644
--- a/dlt/extract/decorators.py
+++ b/dlt/extract/decorators.py
@@ -14,7 +14,7 @@
 from dlt.common.pipeline import PipelineContext
 from dlt.common.source import _SOURCES, SourceInfo
 from dlt.common.schema.schema import Schema
-from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns
+from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns, TTableFormat
 from dlt.extract.utils import ensure_table_schema_columns_hint
 from dlt.common.storages.exceptions import SchemaNotFoundError
 from dlt.common.storages.schema_storage import SchemaStorage
@@ -200,6 +200,7 @@ def resource(
     columns: TTableHintTemplate[TAnySchemaColumns] = None,
     primary_key: TTableHintTemplate[TColumnNames] = None,
     merge_key: TTableHintTemplate[TColumnNames] = None,
+    table_format: TTableHintTemplate[TTableFormat] = None,
     selected: bool = True,
     spec: Type[BaseConfiguration] = None
 ) -> DltResource:
@@ -215,6 +216,7 @@ def resource(
     columns: TTableHintTemplate[TAnySchemaColumns] = None,
     primary_key: TTableHintTemplate[TColumnNames] = None,
     merge_key: TTableHintTemplate[TColumnNames] = None,
+    table_format: TTableHintTemplate[TTableFormat] = None,
     selected: bool = True,
     spec: Type[BaseConfiguration] = None
 ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]:
@@ -230,6 +232,7 @@ def resource(
     columns: TTableHintTemplate[TAnySchemaColumns] = None,
     primary_key: TTableHintTemplate[TColumnNames] = None,
     merge_key: TTableHintTemplate[TColumnNames] = None,
+    table_format: TTableHintTemplate[TTableFormat] = None,
     selected: bool = True,
     spec: Type[BaseConfiguration] = None
 ) -> DltResource:
@@ -245,6 +248,7 @@ def resource(
     columns: TTableHintTemplate[TAnySchemaColumns] = None,
     primary_key: TTableHintTemplate[TColumnNames] = None,
     merge_key: TTableHintTemplate[TColumnNames] = None,
+    table_format: TTableHintTemplate[TTableFormat] = None,
     selected: bool = True,
     spec: Type[BaseConfiguration] = None,
     data_from: TUnboundDltResource = None,
@@ -313,6 +317,7 @@ def make_resource(_name: str, _section: str, _data: Any, incremental: Incrementa
             columns=columns,
             primary_key=primary_key,
             merge_key=merge_key,
+            table_format=table_format
         )
         return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, data_from), incremental=incremental)
 
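The new `table_format` argument is exposed on all `@dlt.resource` overloads above. Minimal usage sketch (it mirrors the test added later in this patch):

    import dlt

    @dlt.resource(name="items_iceberg", write_disposition="append", table_format="iceberg")
    def items_iceberg():
        yield {"id": 1, "name": "one"}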
diff --git a/dlt/extract/schema.py b/dlt/extract/schema.py
index 80e9f6f32f..524bfabc0b 100644
--- a/dlt/extract/schema.py
+++ b/dlt/extract/schema.py
@@ -3,7 +3,7 @@
 from typing import List, TypedDict, cast, Any
 
 from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION, merge_columns, new_column, new_table
-from dlt.common.schema.typing import TColumnNames, TColumnProp, TColumnSchema, TPartialTableSchema, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns
+from dlt.common.schema.typing import TColumnNames, TColumnProp, TColumnSchema, TPartialTableSchema, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns, TTableFormat
 from dlt.common.typing import TDataItem
 from dlt.common.utils import update_dict_nested
 from dlt.common.validation import validate_dict_ignoring_xkeys
@@ -211,6 +211,7 @@ def new_table_template(
         columns: TTableHintTemplate[TAnySchemaColumns] = None,
         primary_key: TTableHintTemplate[TColumnNames] = None,
         merge_key: TTableHintTemplate[TColumnNames] = None,
+        table_format: TTableHintTemplate[TTableFormat] = None
     ) -> TTableSchemaTemplate:
         if not table_name:
             raise TableNameMissing()
@@ -224,7 +225,7 @@ def new_table_template(
             validator = None
         # create a table schema template where hints can be functions taking TDataItem
         new_template: TTableSchemaTemplate = new_table(
-            table_name, parent_table_name, write_disposition=write_disposition, columns=columns  # type: ignore
+            table_name, parent_table_name, write_disposition=write_disposition, columns=columns, table_format=table_format  # type: ignore
         )
         if primary_key:
             new_template["primary_key"] = primary_key
diff --git a/dlt/load/load.py b/dlt/load/load.py
index 488a0ce4f2..ddce9bf8e9 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -2,7 +2,7 @@
 from copy import copy
 from functools import reduce
 import datetime  # noqa: 251
-from typing import Dict, List, Optional, Tuple, Set, Iterator
+from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Callable
 from multiprocessing.pool import ThreadPool
 import os
 
@@ -10,20 +10,19 @@
 from dlt.common.configuration import with_config, known_sections
 from dlt.common.configuration.accessors import config
 from dlt.common.pipeline import LoadInfo, SupportsPipeline
-from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_write_disposition
+from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_load_table
 from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
 from dlt.common.typing import StrAny
 from dlt.common.runners import TRunMetrics, Runnable, workermethod
 from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
 from dlt.common.runtime.logger import pretty_format_exception
 from dlt.common.exceptions import TerminalValueError, DestinationTerminalException, DestinationTransientException
-from dlt.common.schema import Schema
+from dlt.common.schema import Schema, TSchemaTables
 from dlt.common.schema.typing import TTableSchema, TWriteDisposition
 from dlt.common.storages import LoadStorage
 from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration
 
 from dlt.destinations.job_impl import EmptyLoadJob
-from dlt.destinations.exceptions import LoadJobUnknownTableException
 
 from dlt.load.configuration import LoaderConfiguration
 from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats
@@ -69,19 +68,6 @@ def create_storage(self, is_storage_owner: bool) -> LoadStorage:
         )
         return load_storage
 
-    @staticmethod
-    def get_load_table(schema: Schema, file_name: str) -> TTableSchema:
-        table_name = LoadStorage.parse_job_file_name(file_name).table_name
-        try:
-            # make a copy of the schema so modifications do not affect the original document
-            table = copy(schema.get_table(table_name))
-            # add write disposition if not specified - in child tables
-            if "write_disposition" not in table:
-                table["write_disposition"] = get_write_disposition(schema.tables, table_name)
-            return table
-        except KeyError:
-            raise LoadJobUnknownTableException(table_name, file_name)
-
     def get_destination_client(self, schema: Schema) -> JobClientBase:
         return self.destination.client(schema, self.initial_client_config)
 
@@ -94,7 +80,7 @@ def is_staging_destination_job(self, file_path: str) -> bool:
     @contextlib.contextmanager
     def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSchema) -> Iterator[None]:
         """Executes job client methods in context of staging dataset if `table` has `write_disposition` that requires it"""
-        if isinstance(job_client, WithStagingDataset) and table["write_disposition"] in job_client.get_stage_dispositions():
+        if isinstance(job_client, WithStagingDataset) and job_client.table_needs_staging(table):
             with job_client.with_staging_dataset():
                 yield
         else:
@@ -112,7 +98,7 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O
             if job_info.file_format not in self.load_storage.supported_file_formats:
                 raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path)
             logger.info(f"Will load file {file_path} with table name {job_info.table_name}")
-            table = self.get_load_table(schema, file_path)
+            table = get_load_table(schema.tables, job_info.table_name)
             if table["write_disposition"] not in ["append", "replace", "merge"]:
                 raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path)
             with self.maybe_with_staging_dataset(job_client, table):
@@ -173,13 +159,8 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str, staging_client: Job
 
         return len(jobs), jobs
 
-    def get_new_jobs_info(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition] = None) -> List[ParsedLoadJobFileName]:
-        jobs_info: List[ParsedLoadJobFileName] = []
-        new_job_files = self.load_storage.list_new_jobs(load_id)
-        for job_file in new_job_files:
-            if dispositions is None or self.get_load_table(schema, job_file)["write_disposition"] in dispositions:
-                jobs_info.append(LoadStorage.parse_job_file_name(job_file))
-        return jobs_info
+    def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]:
+        return [LoadStorage.parse_job_file_name(job_file) for job_file in self.load_storage.list_new_jobs(load_id)]
 
     def get_completed_table_chain(self, load_id: str, schema: Schema, top_merged_table: TTableSchema, being_completed_job_id: str = None) -> List[TTableSchema]:
         """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. None is returned if there's any job that is not completed
@@ -210,7 +191,7 @@ def create_followup_jobs(self, load_id: str, state: TLoadJobState, starting_job:
         starting_job_file_name = starting_job.file_name()
         if state == "completed" and not self.is_staging_destination_job(starting_job_file_name):
             client = self.destination.client(schema, self.initial_client_config)
-            top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, starting_job_file_name)["name"])
+            top_job_table = get_top_level_table(schema.tables, starting_job.job_file_info().table_name)
             # if all tables of chain completed, create follow up jobs
             if table_chain := self.get_completed_table_chain(load_id, schema, top_job_table, starting_job.job_file_info().job_id()):
                 if follow_up_jobs := client.create_table_chain_completed_followup_jobs(table_chain):
@@ -278,56 +259,56 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False)
         self.load_storage.complete_load_package(load_id, aborted)
         logger.info(f"All jobs completed, archiving package {load_id} with aborted set to {aborted}")
 
-    def get_table_chain_tables_for_write_disposition(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition]) -> Set[str]:
+    @staticmethod
+    def _get_table_chain_tables_with_filter(schema: Schema, filter: Callable, tables_with_jobs: Iterable[str]) -> Set[str]:
         """Get all jobs for tables with given write disposition and resolve the table chain"""
         result: Set[str] = set()
-        table_jobs = self.get_new_jobs_info(load_id, schema, dispositions)
-        for job in table_jobs:
-            top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, job.job_id())["name"])
-            table_chain = get_child_tables(schema.tables, top_job_table["name"])
-            for table in table_chain:
-                existing_jobs = self.load_storage.list_jobs_for_table(load_id, table["name"])
-                # only add tables for tables that have jobs unless the disposition is replace
-                if not existing_jobs and top_job_table["write_disposition"] != "replace":
-                    continue
+        for table_name in tables_with_jobs:
+            top_job_table = get_top_level_table(schema.tables, table_name)
+            if not filter(top_job_table):
+                continue
+            for table in get_child_tables(schema.tables, top_job_table["name"]):
                 result.add(table["name"])
         return result
 
+    @staticmethod
+    def _init_client_and_update_schema(job_client: JobClientBase, expected_update: TSchemaTables, update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False) -> TSchemaTables:
+        staging_text = "for staging dataset" if staging_info else ""
+        logger.info(f"Client for {job_client.config.destination_name} will start initializing storage {staging_text}")
+        job_client.initialize_storage()
+        logger.info(f"Client for {job_client.config.destination_name} will update schema to package schema {staging_text}")
+        applied_update = job_client.update_stored_schema(only_tables=update_tables, expected_update=expected_update)
+        logger.info(f"Client for {job_client.config.destination_name} will truncate tables {staging_text}")
+        job_client.initialize_storage(truncate_tables=truncate_tables)
+        return applied_update
+
     def load_single_package(self, load_id: str, schema: Schema) -> None:
         # initialize analytical storage ie. create dataset required by passed schema
-        job_client: JobClientBase
         with self.get_destination_client(schema) as job_client:
-            expected_update = self.load_storage.begin_schema_update(load_id)
-            if expected_update is not None:
-                # update the default dataset
-                logger.info(f"Client for {job_client.config.destination_name} will start initialize storage")
-                job_client.initialize_storage()
-                logger.info(f"Client for {job_client.config.destination_name} will update schema to package schema")
-                all_jobs = self.get_new_jobs_info(load_id, schema)
-                all_tables = set(job.table_name for job in all_jobs)
+
+            if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None:
+
+                tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id))
                 dlt_tables = set(t["name"] for t in schema.dlt_tables())
+
+                # update the default dataset
+                truncate_tables = self._get_table_chain_tables_with_filter(schema, job_client.table_needs_truncating, tables_with_jobs)
+                applied_update = self._init_client_and_update_schema(job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables)
+
                 # only update tables that are present in the load package
-                applied_update = job_client.update_stored_schema(only_tables=all_tables | dlt_tables, expected_update=expected_update)
-                truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_truncate_destination_table_dispositions())
-                job_client.initialize_storage(truncate_tables=truncate_tables)
-                # initialize staging storage if needed
                 if self.staging_destination:
                     with self.get_staging_destination_client(schema) as staging_client:
-                        truncate_dispositions = staging_client.get_truncate_destination_table_dispositions()
-                        truncate_dispositions.extend(job_client.get_truncate_staging_destination_table_dispositions())
-                        truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, truncate_dispositions)
-                        staging_client.initialize_storage(truncate_tables)
+                        truncate_tables = self._get_table_chain_tables_with_filter(schema, job_client.table_needs_truncating, tables_with_jobs)
+                        self._init_client_and_update_schema(staging_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables)
+
                 # update the staging dataset if client supports this
                 if isinstance(job_client, WithStagingDataset):
-                    if staging_tables := self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_stage_dispositions()):
+                    if staging_tables := self._get_table_chain_tables_with_filter(schema, job_client.table_needs_staging, tables_with_jobs):
                         with job_client.with_staging_dataset():
-                            logger.info(f"Client for {job_client.config.destination_name} will start initialize STAGING storage")
-                            job_client.initialize_storage()
-                            logger.info(f"Client for {job_client.config.destination_name} will UPDATE STAGING SCHEMA to package schema")
-                            job_client.update_stored_schema(only_tables=staging_tables | {schema.version_table_name}, expected_update=expected_update)
-                            logger.info(f"Client for {job_client.config.destination_name} will TRUNCATE STAGING TABLES: {staging_tables}")
-                            job_client.initialize_storage(truncate_tables=staging_tables)
+                            self._init_client_and_update_schema(job_client, expected_update, staging_tables | {schema.version_table_name}, staging_tables, staging_info=True)
+
             self.load_storage.commit_schema_update(load_id, applied_update)
+
         # initialize staging destination and spool or retrieve unfinished jobs
         if self.staging_destination:
             with self.get_staging_destination_client(schema) as staging_client:
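A condensed sketch (not part of the patch) of how the refactored pieces of `load_single_package` fit together; `schema`, `job_client` and `tables_with_jobs` are assumed to be in scope as in the hunk above.

    # illustrative only -- same helper, different per-table predicate
    truncate_tables = Load._get_table_chain_tables_with_filter(
        schema, job_client.table_needs_truncating, tables_with_jobs
    )
    staging_tables = Load._get_table_chain_tables_with_filter(
        schema, job_client.table_needs_staging, tables_with_jobs
    )
    # both sets are then passed to _init_client_and_update_schema, which creates,
    # updates and truncates exactly the tables each client needs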
diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py
index e1650549cc..72772b0e2d 100644
--- a/tests/load/athena_iceberg/test_athena_iceberg.py
+++ b/tests/load/athena_iceberg/test_athena_iceberg.py
@@ -14,14 +14,20 @@
 from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration
 from tests.utils import skip_if_not_active
 
+from dlt.destinations.exceptions import DatabaseTerminalException
+
 skip_if_not_active("athena")
 
 
 def test_iceberg() -> None:
+    """
+    We write two tables, one with the iceberg flag, one without. We expect the iceberg table and its subtables to accept update commands
+    and the other table to reject them.
+    """
 
     os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = "s3://dlt-ci-test-bucket"
 
-    pipeline = dlt.pipeline(pipeline_name="aaathena-iceberg", destination="athena", staging="filesystem", full_refresh=True)
+    pipeline = dlt.pipeline(pipeline_name="aaaaathena-iceberg", destination="athena", staging="filesystem", full_refresh=True)
 
     def items() -> Iterator[Any]:
         yield {
@@ -40,22 +46,33 @@ def items() -> Iterator[Any]:
     def items_normal():
         yield from items()
 
-    @dlt.resource(name="items_iceberg", write_disposition="append")
+    @dlt.resource(name="items_iceberg", write_disposition="append", table_format="iceberg")
     def items_iceberg():
         yield from items()
 
     print(pipeline.run([items_normal, items_iceberg]))
 
-    return
-
     # see if we have athena tables with items
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
-    assert table_counts["items"] == 1
-    assert table_counts["items__sub_items"] == 2
+    assert table_counts["items_normal"] == 1
+    assert table_counts["items_normal__sub_items"] == 2
     assert table_counts["_dlt_loads"] == 1
 
-    pipeline.run(items)
-    table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
-    assert table_counts["items"] == 2
-    assert table_counts["items__sub_items"] == 4
-    assert table_counts["_dlt_loads"] == 2
\ No newline at end of file
+    assert table_counts["items_iceberg"] == 1
+    assert table_counts["items_iceberg__sub_items"] == 2
+
+    with pipeline.sql_client() as client:
+        client.execute_sql("SELECT * FROM items_normal")
+
+        # modifying regular athena table will fail
+        with pytest.raises(DatabaseTerminalException) as dbex:
+            client.execute_sql("UPDATE items_normal SET name='new name'")
+        assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex)
+        with pytest.raises(DatabaseTerminalException) as dbex:
+            client.execute_sql("UPDATE items_normal__sub_items SET name='super new name'")
+        assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex)
+
+        # modifying iceberg table will succeed
+        client.execute_sql("UPDATE items_iceberg SET name='new name'")
+        client.execute_sql("UPDATE items_iceberg__sub_items SET name='super new name'")
+
diff --git a/tests/load/utils.py b/tests/load/utils.py
index a615b696e3..9fd4f033b7 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -17,7 +17,7 @@
 from dlt.common.data_writers import DataWriter
 from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema
 from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration
-from dlt.common.schema.utils import new_table
+from dlt.common.schema.utils import new_table, get_load_table
 from dlt.common.storages.load_storage import ParsedLoadJobFileName, LoadStorage
 from dlt.common.typing import StrAny
 from dlt.common.utils import uniq_id
@@ -170,7 +170,7 @@ def load_table(name: str) -> Dict[str, TTableSchemaColumns]:
 def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: str, table_name: str, status = "completed") -> LoadJob:
     file_name = ParsedLoadJobFileName(table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format).job_id()
     file_storage.save(file_name, query.encode("utf-8"))
-    table = Load.get_load_table(client.schema, file_name)
+    table = get_load_table(client.schema.tables, table_name)
     job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id())
     while job.state() == "running":
         sleep(0.5)
diff --git a/tests/utils.py b/tests/utils.py
index 2d675f514a..7321049c9d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -34,6 +34,7 @@
 
 # filter out active destinations for current tests
 ACTIVE_DESTINATIONS = set(dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS)
+# ACTIVE_DESTINATIONS = {"duckdb"}
 
 ACTIVE_SQL_DESTINATIONS = SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS)
 ACTIVE_NON_SQL_DESTINATIONS = NON_SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS)
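Finally, an end-to-end sketch (not part of the patch) condensing the test above: one regular and one iceberg table loaded through Athena with filesystem staging. Credentials and the bucket URL are assumed to be configured separately.

    import dlt

    @dlt.resource(write_disposition="append")
    def items_normal():
        yield {"id": 1, "name": "one", "sub_items": [{"id": 101}]}

    @dlt.resource(write_disposition="append", table_format="iceberg")
    def items_iceberg():
        yield {"id": 1, "name": "one", "sub_items": [{"id": 101}]}

    pipeline = dlt.pipeline(destination="athena", staging="filesystem", full_refresh=True)
    pipeline.run([items_normal, items_iceberg])
    # UPDATE statements now succeed on items_iceberg and items_iceberg__sub_items,
    # while the plain Hive-backed items_normal tables reject them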