From 88f27229b5bbc10c5cb74caa20bf2d1509fc89c2 Mon Sep 17 00:00:00 2001
From: Jorrit Sandbrink <47451109+jorritsandbrink@users.noreply.github.com>
Date: Sat, 24 Feb 2024 10:08:15 +0100
Subject: [PATCH] Introduce `hard_delete` and `dedup_sort` column hints for
 `merge` (#960)

* black formatting
* remove unused exception
* add initial support for replicate write disposition
* add hard_delete hint and sorted deduplication for merge
* undo config change
* undo unintentional changes
* refactor hard_delete handling and introduce dedup_sort hint
* update docstring
* replace dialect-specific SQL
* add parentheses to ensure proper clause evaluation order
* add escape defaults and temp tables for non-primary key case
* exclude destinations that don't support merge from test
* correct typo
* extend docstring
* remove redundant copies for immutable strings
* simplify boolean logic
* add more test cases for hard_delete and dedup_sort hints
* refactor table chain resolution
* marks tables that have seen data in normalizer, skips empty jobs if never seen data
* ignores tables that didn't see data when loading, tests edge cases
* add sort order configuration option
* bumps schema engine to v9, adds migrations
* filters tables without data properly in load
* converts seen-data to boolean, fixes tests
* disables filesystem tests config due to merge present
* add docs for hard_delete and dedup_sort column hints
* fixes extending table chains in load
* refactors load and adds unit tests with dummy

---------

Co-authored-by: Jorrit Sandbrink
Co-authored-by: Marcin Rudolf
---
 Makefile                                      |   2 +-
 dlt/common/destination/reference.py           |  67 ++-
 dlt/common/schema/migrations.py               | 128 +++++
 dlt/common/schema/schema.py                   |   6 +-
 dlt/common/schema/typing.py                   |   7 +-
 dlt/common/schema/utils.py                    | 164 ++----
 dlt/common/storages/load_package.py           |  15 +-
 dlt/common/utils.py                           |   8 +
 dlt/destinations/impl/athena/athena.py        |  26 +-
 dlt/destinations/impl/bigquery/bigquery.py    |  68 ++-
 dlt/destinations/impl/dummy/__init__.py       |   9 +-
 dlt/destinations/impl/dummy/configuration.py  |   3 +
 dlt/destinations/impl/dummy/dummy.py          |  52 +-
 .../impl/filesystem/filesystem.py             |   6 +-
 dlt/destinations/impl/synapse/synapse.py      |  35 +-
 dlt/destinations/job_client_impl.py           |   6 +-
 dlt/destinations/sql_jobs.py                  | 193 +++++--
 dlt/extract/exceptions.py                     |   7 -
 dlt/load/load.py                              | 186 ++-----
 dlt/load/utils.py                             | 186 +++++++
 dlt/normalize/items_normalizers.py            |  22 +-
 dlt/normalize/normalize.py                    |  31 +-
 .../docs/general-usage/incremental-loading.md |  83 +++
 .../cases/schemas/eth/ethereum_schema_v9.yml  | 476 ++++++++++++++++++
 tests/common/schema/test_schema.py            |  23 +-
 tests/common/schema/test_versioning.py        |  23 +-
 tests/common/storages/test_schema_storage.py  |  16 +-
 tests/common/utils.py                         |   2 +-
 tests/extract/test_decorators.py              |   4 +-
 .../athena_iceberg/test_athena_iceberg.py     |   2 +-
 .../bigquery/test_bigquery_table_builder.py   |  11 +-
 .../load/duckdb/test_duckdb_table_builder.py  |  11 +-
 tests/load/mssql/test_mssql_table_builder.py  |  11 +-
 tests/load/pipeline/test_merge_disposition.py | 402 ++++++++++++++-
 tests/load/pipeline/test_pipelines.py         | 122 ++++-
 .../load/pipeline/test_replace_disposition.py |   8 +-
 tests/load/pipeline/test_restore_state.py     |   4 +-
 tests/load/pipeline/utils.py                  |   2 +
 .../postgres/test_postgres_table_builder.py   |  12 +-
 .../redshift/test_redshift_table_builder.py   |  11 +-
 .../snowflake/test_snowflake_table_builder.py |  12 +-
 .../synapse/test_synapse_table_builder.py     |  15 +-
 tests/load/test_dummy_client.py               | 392 ++++++++++++++-
 tests/load/utils.py                           |  16 +-
tests/pipeline/test_dlt_versions.py | 6 +- 45 files changed, 2345 insertions(+), 546 deletions(-) create mode 100644 dlt/common/schema/migrations.py create mode 100644 dlt/load/utils.py create mode 100644 tests/common/cases/schemas/eth/ethereum_schema_v9.yml diff --git a/Makefile b/Makefile index bd425b0e42..8da28717c0 100644 --- a/Makefile +++ b/Makefile @@ -74,7 +74,7 @@ test-load-local: DESTINATION__POSTGRES__CREDENTIALS=postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__DUCKDB__CREDENTIALS=duckdb:///_storage/test_quack.duckdb poetry run pytest tests -k '(postgres or duckdb)' test-common: - poetry run pytest tests/common tests/normalize tests/extract tests/pipeline tests/reflection tests/sources tests/cli/common + poetry run pytest tests/common tests/normalize tests/extract tests/pipeline tests/reflection tests/sources tests/cli/common tests/load/test_dummy_client.py tests/libs tests/destinations reset-test-storage: -rm -r _storage diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 1c28dffa8c..5e698347e5 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -20,7 +20,6 @@ Generic, Final, ) -from contextlib import contextmanager import datetime # noqa: 251 from copy import deepcopy import inspect @@ -32,10 +31,15 @@ UnknownDestinationModule, ) from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.schema.exceptions import InvalidDatasetName -from dlt.common.schema.utils import get_write_disposition, get_table_format -from dlt.common.configuration import configspec, with_config, resolve_configuration, known_sections +from dlt.common.schema.exceptions import SchemaException +from dlt.common.schema.utils import ( + get_write_disposition, + get_table_format, + get_columns_names_with_prop, + has_column_with_prop, + get_first_column_name_with_prop, +) +from dlt.common.configuration import configspec, resolve_configuration, known_sections from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -43,7 +47,6 @@ from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.utils import get_module_name from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults @@ -252,7 +255,8 @@ def new_file_path(self) -> str: class FollowupJob: """Adds a trait that allows to create a followup job""" - def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + """Return list of new jobs. `final_state` is state to which this job transits""" return [] @@ -345,6 +349,49 @@ def _verify_schema(self) -> None: table_name, self.capabilities.max_identifier_length, ) + if has_column_with_prop(table, "hard_delete"): + if len(get_columns_names_with_prop(table, "hard_delete")) > 1: + raise SchemaException( + f'Found multiple "hard_delete" column hints for table "{table_name}" in' + f' schema "{self.schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.' 
+ ) + if table.get("write_disposition") in ("replace", "append"): + logger.warning( + f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "hard_delete" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if has_column_with_prop(table, "dedup_sort"): + if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: + raise SchemaException( + f'Found multiple "dedup_sort" column hints for table "{table_name}" in' + f' schema "{self.schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.' + ) + if table.get("write_disposition") in ("replace", "append"): + logger.warning( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if table.get("write_disposition") == "merge" and not has_column_with_prop( + table, "primary_key" + ): + logger.warning( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when a' + " primary key has been specified." + ) for column_name, column in dict(table["columns"]).items(): if len(column_name) > self.capabilities.max_column_identifier_length: raise IdentifierTooLongException( @@ -361,9 +408,9 @@ def _verify_schema(self) -> None: " column manually in code ie. as a merge key?" 
) - def get_load_table(self, table_name: str, prepare_for_staging: bool = False) -> TTableSchema: - if table_name not in self.schema.tables: - return None + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: try: # make a copy of the schema so modifications do not affect the original document table = deepcopy(self.schema.tables[table_name]) diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py new file mode 100644 index 0000000000..9b206d61a6 --- /dev/null +++ b/dlt/common/schema/migrations.py @@ -0,0 +1,128 @@ +from typing import Dict, List, cast + +from dlt.common.data_types import TDataType +from dlt.common.normalizers import explicit_normalizers +from dlt.common.typing import DictStrAny +from dlt.common.schema.typing import ( + LOADS_TABLE_NAME, + VERSION_TABLE_NAME, + TSimpleRegex, + TStoredSchema, + TTableSchemaColumns, + TColumnHint, +) +from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException + +from dlt.common.normalizers.utils import import_normalizers +from dlt.common.schema.utils import new_table, version_table, load_table + + +def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: + if from_engine == to_engine: + return cast(TStoredSchema, schema_dict) + + if from_engine == 1 and to_engine > 1: + schema_dict["includes"] = [] + schema_dict["excludes"] = [] + from_engine = 2 + if from_engine == 2 and to_engine > 2: + # current version of the schema + current = cast(TStoredSchema, schema_dict) + # add default normalizers and root hash propagation + current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) + current["normalizers"]["json"]["config"] = { + "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} + } + # move settings, convert strings to simple regexes + d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) + for h_k, h_l in d_h.items(): + d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) + p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) + p_t = {TSimpleRegex("re:" + k): v for k, v in p_t.items()} + + current["settings"] = { + "default_hints": d_h, + "preferred_types": p_t, + } + # repackage tables + old_tables: Dict[str, TTableSchemaColumns] = schema_dict.pop("tables") + current["tables"] = {} + for name, columns in old_tables.items(): + # find last path separator + parent = name + # go back in a loop to find existing parent + while True: + idx = parent.rfind("__") + if idx > 0: + parent = parent[:idx] + if parent not in old_tables: + continue + else: + parent = None + break + nt = new_table(name, parent) + nt["columns"] = columns + current["tables"][name] = nt + # assign exclude and include to tables + + def migrate_filters(group: str, filters: List[str]) -> None: + # existing filter were always defined at the root table. 
find this table and move filters + for f in filters: + # skip initial ^ + root = f[1 : f.find("__")] + path = f[f.find("__") + 2 :] + t = current["tables"].get(root) + if t is None: + # must add new table to hold filters + t = new_table(root) + current["tables"][root] = t + t.setdefault("filters", {}).setdefault(group, []).append("re:^" + path) # type: ignore + + excludes = schema_dict.pop("excludes", []) + migrate_filters("excludes", excludes) + includes = schema_dict.pop("includes", []) + migrate_filters("includes", includes) + + # upgraded + from_engine = 3 + if from_engine == 3 and to_engine > 3: + # set empty version hash to pass validation, in engine 4 this hash is mandatory + schema_dict.setdefault("version_hash", "") + from_engine = 4 + if from_engine == 4 and to_engine > 4: + # replace schema versions table + schema_dict["tables"][VERSION_TABLE_NAME] = version_table() + schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + from_engine = 5 + if from_engine == 5 and to_engine > 5: + # replace loads table + schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + from_engine = 6 + if from_engine == 6 and to_engine > 6: + # migrate from sealed properties to schema evolution settings + schema_dict["settings"].pop("schema_sealed", None) + schema_dict["settings"]["schema_contract"] = {} + for table in schema_dict["tables"].values(): + table.pop("table_sealed", None) + if not table.get("parent"): + table["schema_contract"] = {} + from_engine = 7 + if from_engine == 7 and to_engine > 7: + schema_dict["previous_hashes"] = [] + from_engine = 8 + if from_engine == 8 and to_engine > 8: + # add "seen-data" to all tables with _dlt_id, this will handle packages + # that are being loaded + for table in schema_dict["tables"].values(): + if "_dlt_id" in table["columns"]: + x_normalizer = table.setdefault("x-normalizer", {}) + x_normalizer["seen-data"] = True + from_engine = 9 + + schema_dict["engine_version"] = from_engine + if from_engine != to_engine: + raise SchemaEngineNoUpgradePathException( + schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine + ) + + return cast(TStoredSchema, schema_dict) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index ccfc038085..b73e45d489 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -2,6 +2,7 @@ from copy import copy, deepcopy from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast, Literal from dlt.common import json +from dlt.common.schema.migrations import migrate_schema from dlt.common.utils import extend_list_deduplicated from dlt.common.typing import ( @@ -103,7 +104,7 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: @classmethod def from_dict(cls, d: DictStrAny, bump_version: bool = True) -> "Schema": # upgrade engine if needed - stored_schema = utils.migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) + stored_schema = migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) # verify schema utils.validate_stored_schema(stored_schema) # add defaults @@ -390,6 +391,7 @@ def resolve_contract_settings_for_table( return Schema.expand_schema_contract_settings(settings) def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchema: + """Adds or merges `partial_table` into the schema. 
Identifiers are not normalized""" table_name = partial_table["name"] parent_table_name = partial_table.get("parent") # check if parent table present @@ -414,7 +416,7 @@ def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchem return partial_table def update_schema(self, schema: "Schema") -> None: - """Updates this schema from an incoming schema""" + """Updates this schema from an incoming schema. Normalizes identifiers after updating normalizers.""" # update all tables for table in schema.tables.values(): self.update_table(table) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index e1ff17115d..fcabeb409a 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -26,7 +26,7 @@ # current version of schema engine -SCHEMA_ENGINE_VERSION = 8 +SCHEMA_ENGINE_VERSION = 9 # dlt tables VERSION_TABLE_NAME = "_dlt_version" @@ -46,6 +46,7 @@ "unique", "merge_key", "root_key", + "dedup_sort", ] """Known properties and hints of the column""" # TODO: merge TColumnHint with TColumnProp @@ -59,6 +60,7 @@ "unique", "root_key", "merge_key", + "dedup_sort", ] """Known hints of a column used to declare hint regexes.""" TWriteDisposition = Literal["skip", "append", "replace", "merge"] @@ -69,6 +71,7 @@ TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" +TSortOrder = Literal["asc", "desc"] COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( @@ -112,6 +115,8 @@ class TColumnSchema(TColumnSchemaBase, total=False): root_key: Optional[bool] merge_key: Optional[bool] variant: Optional[bool] + hard_delete: Optional[bool] + dedup_sort: Optional[TSortOrder] TTableSchemaColumns = Dict[str, TColumnSchema] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index de1e113bc3..835fe4279e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -8,7 +8,6 @@ from dlt.common import json from dlt.common.data_types import TDataType from dlt.common.exceptions import DictValidationException -from dlt.common.normalizers import explicit_normalizers from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCase from dlt.common.typing import DictStrAny, REPattern @@ -27,7 +26,6 @@ TSimpleRegex, TStoredSchema, TTableSchema, - TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, @@ -37,20 +35,15 @@ TTypeDetections, TWriteDisposition, TSchemaContract, + TSortOrder, ) from dlt.common.schema.exceptions import ( CannotCoerceColumnException, ParentTableNotFoundException, - SchemaEngineNoUpgradePathException, - SchemaException, TablePropertiesConflictException, InvalidSchemaName, - UnknownTableException, ) -from dlt.common.normalizers.utils import import_normalizers -from dlt.common.schema.typing import TAnySchemaColumns - RE_NON_ALPHANUMERIC_UNDERSCORE = re.compile(r"[^a-zA-Z\d_]") DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" @@ -318,109 +311,6 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: raise ParentTableNotFoundException(table_name, parent_table_name) -def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: - if from_engine == to_engine: - return cast(TStoredSchema, schema_dict) - - if from_engine == 1 and to_engine > 1: - schema_dict["includes"] = [] - schema_dict["excludes"] = [] - from_engine = 2 - if from_engine 
== 2 and to_engine > 2: - # current version of the schema - current = cast(TStoredSchema, schema_dict) - # add default normalizers and root hash propagation - current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) - current["normalizers"]["json"]["config"] = { - "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} - } - # move settings, convert strings to simple regexes - d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) - for h_k, h_l in d_h.items(): - d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) - p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) - p_t = {TSimpleRegex("re:" + k): v for k, v in p_t.items()} - - current["settings"] = { - "default_hints": d_h, - "preferred_types": p_t, - } - # repackage tables - old_tables: Dict[str, TTableSchemaColumns] = schema_dict.pop("tables") - current["tables"] = {} - for name, columns in old_tables.items(): - # find last path separator - parent = name - # go back in a loop to find existing parent - while True: - idx = parent.rfind("__") - if idx > 0: - parent = parent[:idx] - if parent not in old_tables: - continue - else: - parent = None - break - nt = new_table(name, parent) - nt["columns"] = columns - current["tables"][name] = nt - # assign exclude and include to tables - - def migrate_filters(group: str, filters: List[str]) -> None: - # existing filter were always defined at the root table. find this table and move filters - for f in filters: - # skip initial ^ - root = f[1 : f.find("__")] - path = f[f.find("__") + 2 :] - t = current["tables"].get(root) - if t is None: - # must add new table to hold filters - t = new_table(root) - current["tables"][root] = t - t.setdefault("filters", {}).setdefault(group, []).append("re:^" + path) # type: ignore - - excludes = schema_dict.pop("excludes", []) - migrate_filters("excludes", excludes) - includes = schema_dict.pop("includes", []) - migrate_filters("includes", includes) - - # upgraded - from_engine = 3 - if from_engine == 3 and to_engine > 3: - # set empty version hash to pass validation, in engine 4 this hash is mandatory - schema_dict.setdefault("version_hash", "") - from_engine = 4 - if from_engine == 4 and to_engine > 4: - # replace schema versions table - schema_dict["tables"][VERSION_TABLE_NAME] = version_table() - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() - from_engine = 5 - if from_engine == 5 and to_engine > 5: - # replace loads table - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() - from_engine = 6 - if from_engine == 6 and to_engine > 6: - # migrate from sealed properties to schema evolution settings - schema_dict["settings"].pop("schema_sealed", None) - schema_dict["settings"]["schema_contract"] = {} - for table in schema_dict["tables"].values(): - table.pop("table_sealed", None) - if not table.get("parent"): - table["schema_contract"] = {} - from_engine = 7 - if from_engine == 7 and to_engine > 7: - schema_dict["previous_hashes"] = [] - from_engine = 8 - - schema_dict["engine_version"] = from_engine - if from_engine != to_engine: - raise SchemaEngineNoUpgradePathException( - schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine - ) - - return cast(TStoredSchema, schema_dict) - - def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any) -> TDataType: if detection_fs: for detection_fn in detection_fs: @@ -555,6 +445,11 @@ def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TPa return diff_table +def 
has_table_seen_data(table: TTableSchema) -> bool: + """Checks if normalizer has seen data coming to the table.""" + return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True # type: ignore[typeddict-item] + + def hint_to_column_prop(h: TColumnHint) -> TColumnProp: if h == "not_null": return "nullable" @@ -573,6 +468,39 @@ def get_columns_names_with_prop( ] +def get_first_column_name_with_prop( + table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False +) -> Optional[str]: + """Returns name of first column in `table` schema with property `column_prop` or None if no such column exists.""" + column_names = get_columns_names_with_prop(table, column_prop, include_incomplete) + if len(column_names) > 0: + return column_names[0] + return None + + +def has_column_with_prop( + table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False +) -> bool: + """Checks if `table` schema contains column with property `column_prop`.""" + return len(get_columns_names_with_prop(table, column_prop, include_incomplete)) > 0 + + +def get_dedup_sort_tuple( + table: TTableSchema, include_incomplete: bool = False +) -> Optional[Tuple[str, TSortOrder]]: + """Returns tuple with dedup sort information. + + First element is the sort column name, second element is the sort order. + + Returns None if "dedup_sort" hint was not provided. + """ + dedup_sort_col = get_first_column_name_with_prop(table, "dedup_sort", include_incomplete) + if dedup_sort_col is None: + return None + dedup_sort_order = table["columns"][dedup_sort_col]["dedup_sort"] + return (dedup_sort_col, dedup_sort_order) + + def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTables: aggregated_update: TSchemaTables = {} for schema_update in schema_updates: @@ -618,6 +546,20 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) +def fill_hints_from_parent_and_clone_table( + tables: TSchemaTables, table: TTableSchema +) -> TTableSchema: + """Takes write disposition and table format from parent tables if not present""" + # make a copy of the schema so modifications do not affect the original document + table = deepcopy(table) + # add write disposition if not specified - in child tables + if "write_disposition" not in table: + table["write_disposition"] = get_write_disposition(tables, table["name"]) + if "table_format" not in table: + table["table_format"] = get_table_format(tables, table["name"]) + return table + + def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: """Checks if `table` schema contains column with type _typ""" return any(c.get("data_type") == _typ for c in table["columns"].values()) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 2860364cd0..01f3923455 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -8,6 +8,7 @@ from typing import ( ClassVar, Dict, + Iterable, List, NamedTuple, Literal, @@ -245,9 +246,7 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]: ) def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: - return [ - job for job in self.list_all_jobs(load_id) if job.job_file_info.table_name == table_name - ] + return self.filter_jobs_for_table(self.list_all_jobs(load_id), table_name) def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: info = self.get_load_package_info(load_id) @@ -448,6 +447,10 @@ def _read_job_file_info(self, 
state: TJobState, file: str, now: DateTime = None) failed_message, ) + # + # Utils + # + def _move_job( self, load_id: str, @@ -503,3 +506,9 @@ def is_package_partially_loaded(package_info: LoadPackageInfo) -> bool: @staticmethod def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: return (now_ts or pendulum.now().timestamp()) - os.path.getmtime(file_path) + + @staticmethod + def filter_jobs_for_table( + all_jobs: Iterable[LoadJobInfo], table_name: str + ) -> Sequence[LoadJobInfo]: + return [job for job in all_jobs if job.job_file_info.table_name == table_name] diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 72fee608a8..49a425780b 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -575,3 +575,11 @@ def get_exception_trace_chain( elif exc.__context__: return get_exception_trace_chain(exc.__context__, traces, seen) return traces + + +def order_deduped(lst: List[Any]) -> List[Any]: + """Returns deduplicated list preserving order of input elements. + + Only works for lists with hashable elements. + """ + return list(dict.fromkeys(lst)) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 96e7818d57..c2dae7a350 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -365,7 +365,7 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables - table = self.get_load_table(table_name, self.in_staging_mode) + table = self.prepare_load_table(table_name, self.in_staging_mode) is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" columns = ", ".join( [self._get_column_def_sql(c, table.get("table_format")) for c in new_columns] @@ -405,13 +405,13 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = ( DoNothingFollowupJob(file_path) - if self._is_iceberg_table(self.get_load_table(table["name"])) + if self._is_iceberg_table(self.prepare_load_table(table["name"])) else DoNothingJob(file_path) ) return job def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) ] @@ -420,7 +420,7 @@ def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> L def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: - if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) ] @@ -436,15 +436,15 @@ def _is_iceberg_table(self, table: TTableSchema) -> bool: def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: # all iceberg tables need staging - if self._is_iceberg_table(self.get_load_table(table["name"])): + if self._is_iceberg_table(self.prepare_load_table(table["name"])): return True return super().should_load_data_to_staging_dataset(table) def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: # on athena we only truncate replace tables that are not iceberg - table = 
self.get_load_table(table["name"]) + table = self.prepare_load_table(table["name"]) if table["write_disposition"] == "replace" and not self._is_iceberg_table( - self.get_load_table(table["name"]) + self.prepare_load_table(table["name"]) ): return True return False @@ -453,13 +453,17 @@ def should_load_data_to_staging_dataset_on_staging_destination( self, table: TTableSchema ) -> bool: """iceberg table data goes into staging on staging destination""" - return self._is_iceberg_table(self.get_load_table(table["name"])) + if self._is_iceberg_table(self.prepare_load_table(table["name"])): + return True + return super().should_load_data_to_staging_dataset_on_staging_destination(table) - def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: - table = super().get_load_table(table_name, staging) + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: + table = super().prepare_load_table(table_name, prepare_for_staging) if self.config.force_iceberg: table["table_format"] = "iceberg" - if staging and table.get("table_format", None) == "iceberg": + if prepare_for_staging and table.get("table_format", None) == "iceberg": table.pop("table_format") return table diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 9fc94b6ee3..d4261a1636 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -248,7 +248,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table: Optional[TTableSchema] = self.get_load_table(table_name) + table: Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -263,8 +263,9 @@ def _get_table_update_sql( elif (c := partition_list[0])["data_type"] == "date": sql[0] += f"\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": - sql[0] += f"\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" - + sql[0] = ( + f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + ) # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range. @@ -284,46 +285,39 @@ def _get_table_update_sql( sql[0] += "\nCLUSTER BY " + ", ".join(cluster_list) # Table options. 
- if table: - table_options: DictStrAny = { - "description": ( - f"'{table.get(TABLE_DESCRIPTION_HINT)}'" - if table.get(TABLE_DESCRIPTION_HINT) - else None - ), - "expiration_timestamp": ( - f"TIMESTAMP '{table.get(TABLE_EXPIRATION_HINT)}'" - if table.get(TABLE_EXPIRATION_HINT) - else None - ), - } - if not any(table_options.values()): - return sql - - if generate_alter: - raise NotImplementedError("Update table options not yet implemented.") - else: - sql[0] += ( - "\nOPTIONS (" - + ", ".join( - [ - f"{key}={value}" - for key, value in table_options.items() - if value is not None - ] - ) - + ")" + table_options: DictStrAny = { + "description": ( + f"'{table.get(TABLE_DESCRIPTION_HINT)}'" + if table.get(TABLE_DESCRIPTION_HINT) + else None + ), + "expiration_timestamp": ( + f"TIMESTAMP '{table.get(TABLE_EXPIRATION_HINT)}'" + if table.get(TABLE_EXPIRATION_HINT) + else None + ), + } + if not any(table_options.values()): + return sql + + if generate_alter: + raise NotImplementedError("Update table options not yet implemented.") + else: + sql[0] += ( + "\nOPTIONS (" + + ", ".join( + [f"{key}={value}" for key, value in table_options.items() if value is not None] ) + + ")" + ) return sql - def get_load_table( + def prepare_load_table( self, table_name: str, prepare_for_staging: bool = False ) -> Optional[TTableSchema]: - table = super().get_load_table(table_name, prepare_for_staging) - if table is None: - return None - elif table_name in self.schema.data_table_names(): + table = super().prepare_load_table(table_name, prepare_for_staging) + if table_name in self.schema.data_table_names(): if TABLE_DESCRIPTION_HINT not in table: table[TABLE_DESCRIPTION_HINT] = ( # type: ignore[name-defined, typeddict-unknown-key, unused-ignore] get_inherited_table_hint( diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index a3152b8d77..37b2e77c8a 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -1,6 +1,8 @@ +from typing import List from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration @@ -18,11 +20,14 @@ def _configure(config: DummyClientConfiguration = config.value) -> DummyClientCo def capabilities() -> DestinationCapabilitiesContext: config = _configure() + additional_formats: List[TLoaderFileFormat] = ( + ["reference"] if config.create_followup_jobs else [] + ) caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = config.loader_file_format - caps.supported_loader_file_formats = [config.loader_file_format] + caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [config.loader_file_format] + caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] caps.max_identifier_length = 127 caps.max_column_identifier_length = 127 caps.max_query_length = 8 * 1024 * 1024 diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index 82dc797126..cce0dfa8ed 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -26,6 +26,8 @@ class DummyClientConfiguration(DestinationClientConfiguration): 
"""probability of exception when checking job status""" timeout: float = 10.0 fail_in_init: bool = True + # new jobs workflows + create_followup_jobs: bool = False credentials: DummyClientCredentials = None @@ -43,6 +45,7 @@ def __init__( exception_prob: float = None, timeout: float = None, fail_in_init: bool = None, + create_followup_jobs: bool = None, destination_name: str = None, environment: str = None, ) -> None: ... diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 367db11e82..c46e329819 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -1,7 +1,18 @@ +from contextlib import contextmanager import random from copy import copy from types import TracebackType -from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List +from typing import ( + ClassVar, + ContextManager, + Dict, + Iterator, + Optional, + Sequence, + Type, + Iterable, + List, +) from dlt.common import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -15,6 +26,7 @@ TLoadJobState, LoadJob, JobClientBase, + WithStagingDataset, ) from dlt.destinations.exceptions import ( @@ -26,9 +38,10 @@ from dlt.destinations.impl.dummy import capabilities from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration +from dlt.destinations.job_impl import NewReferenceJob -class LoadDummyJob(LoadJob, FollowupJob): +class LoadDummyBaseJob(LoadJob): def __init__(self, file_name: str, config: DummyClientConfiguration) -> None: self.config = copy(config) self._status: TLoadJobState = "running" @@ -79,16 +92,29 @@ def retry(self) -> None: self._status = "retry" -JOBS: Dict[str, LoadDummyJob] = {} +class LoadDummyJob(LoadDummyBaseJob, FollowupJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + if self.config.create_followup_jobs and final_state == "completed": + new_job = NewReferenceJob( + file_name=self.file_name(), status="running", remote_path=self._file_name + ) + CREATED_FOLLOWUP_JOBS[new_job.job_id()] = new_job + return [new_job] + return [] + + +JOBS: Dict[str, LoadDummyBaseJob] = {} +CREATED_FOLLOWUP_JOBS: Dict[str, NewLoadJob] = {} -class DummyClient(JobClientBase, SupportsStagingDestination): +class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): """dummy client storing jobs in memory""" capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: super().__init__(schema, config) + self.in_staging_context = False self.config: DummyClientConfiguration = config def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: @@ -138,6 +164,17 @@ def create_table_chain_completed_followup_jobs( def complete_load(self, load_id: str) -> None: pass + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return super().should_load_data_to_staging_dataset(table) + + @contextmanager + def with_staging_dataset(self) -> Iterator[JobClientBase]: + try: + self.in_staging_context = True + yield self + finally: + self.in_staging_context = False + def __enter__(self) -> "DummyClient": return self @@ -146,5 +183,8 @@ def __exit__( ) -> None: pass - def _create_job(self, job_id: str) -> LoadDummyJob: - return LoadDummyJob(job_id, config=self.config) + def _create_job(self, job_id: str) -> LoadDummyBaseJob: + if NewReferenceJob.is_reference_job(job_id): + return LoadDummyBaseJob(job_id, config=self.config) + else: + 
return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 5885f8a1ec..33a597f915 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -77,9 +77,9 @@ def exception(self) -> str: class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): - def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: - jobs = super().create_followup_jobs(next_state) - if next_state == "completed": + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + jobs = super().create_followup_jobs(final_state) + if final_state == "completed": ref_job = NewReferenceJob( file_name=self.file_name(), status="running", remote_path=self.make_remote_path() ) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 33e6194602..457e128ba0 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -70,23 +70,18 @@ def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table = self.get_load_table(table_name, staging=self.in_staging_mode) - if table is None: - table_index_type = self.config.default_table_index_type + table = self.prepare_load_table(table_name, staging=self.in_staging_mode) + table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) + if self.in_staging_mode: + final_table = self.prepare_load_table(table_name, staging=False) + final_table_index_type = cast(TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT)) else: - table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) - if self.in_staging_mode: - final_table = self.get_load_table(table_name, staging=False) - final_table_index_type = cast( - TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT) - ) - else: - final_table_index_type = table_index_type - if final_table_index_type == "clustered_columnstore_index": - # Even if the staging table has index type "heap", we still adjust - # the column data types to prevent errors when writing into the - # final table that has index type "clustered_columnstore_index". - new_columns = self._get_columstore_valid_columns(new_columns) + final_table_index_type = table_index_type + if final_table_index_type == "clustered_columnstore_index": + # Even if the staging table has index type "heap", we still adjust + # the column data types to prevent errors when writing into the + # final table that has index type "clustered_columnstore_index". 
+ new_columns = self._get_columstore_valid_columns(new_columns) _sql_result = SqlJobClientBase._get_table_update_sql( self, table_name, new_columns, generate_alter @@ -135,10 +130,8 @@ def _create_replace_followup_jobs( return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) - def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: - table = super().get_load_table(table_name, staging) - if table is None: - return None + def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + table = super().prepare_load_table(table_name, staging) if staging and self.config.replace_strategy == "insert-from-staging": # Staging tables should always be heap tables, because "when you are # temporarily landing data in dedicated SQL pool, you may find that @@ -153,7 +146,7 @@ def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema # index for faster query performance." table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables - elif table_name in self.schema.data_table_names(): + else: if TABLE_INDEX_TYPE_HINT not in table: # If present in parent table, fetch hint from there. table[TABLE_INDEX_TYPE_HINT] = get_inherited_table_hint( # type: ignore[typeddict-unknown-key] diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 4333509b2c..7896fa2cc4 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -427,7 +427,7 @@ def _build_schema_update_sql( sql += ";" sql_updates.append(sql) # create a schema update for particular table - partial_table = copy(self.get_load_table(table_name)) + partial_table = copy(self.prepare_load_table(table_name)) # keep only new columns partial_table["columns"] = {c["name"]: c for c in new_columns} schema_update[table_name] = partial_table @@ -445,8 +445,8 @@ def _get_table_update_sql( ) -> List[str]: # build sql canonical_name = self.sql_client.make_qualified_table_name(table_name) - table = self.get_load_table(table_name) - table_format = table.get("table_format") if table else None + table = self.prepare_load_table(table_name) + table_format = table.get("table_format") sql_result: List[str] = [] if not generate_alter: # build CREATE diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index d0911d0bea..215bcf9fe5 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -3,10 +3,15 @@ import yaml from dlt.common.runtime.logger import pretty_format_exception -from dlt.common.schema.typing import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.typing import TTableSchema, TSortOrder +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + get_first_column_name_with_prop, + get_dedup_sort_tuple, +) from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import uniq_id +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase @@ -147,6 +152,9 @@ def generate_sql( First we store the root_keys of root table elements to be deleted in the temp table. 
Then we use the temp table to delete records from root and all child tables in the destination dataset. At the end we copy the data from the staging dataset into destination dataset. + + If a hard_delete column is specified, records flagged as deleted will be excluded from the copy into the destination dataset. + If a dedup_sort column is specified in conjunction with a primary key, records will be sorted before deduplication, so the "latest" record remains. """ return cls.gen_merge_sql(table_chain, sql_client) @@ -200,19 +208,94 @@ def gen_delete_temp_table_sql( sql.append(f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};") return sql, temp_table_name + @classmethod + def gen_select_from_dedup_sql( + cls, + table_name: str, + primary_keys: Sequence[str], + columns: Sequence[str], + dedup_sort: Tuple[str, TSortOrder] = None, + condition: str = None, + condition_columns: Sequence[str] = None, + ) -> str: + """Returns SELECT FROM SQL statement. + + The FROM clause in the SQL statement represents a deduplicated version + of the `table_name` table. + + Expects column names provided in arguments to be escaped identifiers. + + Args: + table_name: Name of the table that is selected from. + primary_keys: A sequence of column names representing the primary + key of the table. Is used to deduplicate the table. + columns: Sequence of column names that will be selected from + the table. + sort_column: Name of a column to sort the records by within a + primary key. Values in the column are sorted in descending order, + so the record with the highest value in `sort_column` remains + after deduplication. No sorting is done if a None value is provided, + leading to arbitrary deduplication. + condition: String used as a WHERE clause in the SQL statement to + filter records. The name of any column that is used in the + condition but is not part of `columns` must be provided in the + `condition_columns` argument. No filtering is done (aside from the + deduplication) if a None value is provided. + condition_columns: Sequence of names of columns used in the `condition` + argument. These column names will be selected in the inner subquery + to make them accessible to the outer WHERE clause. This argument + should only be used in combination with the `condition` argument. + + Returns: + A string representing a SELECT FROM SQL statement where the FROM + clause represents a deduplicated version of the `table_name` table. + + The returned value is used in two ways: + 1) To select the values for an INSERT INTO statement. + 2) To select the values for a temporary table used for inserts. 
+ """ + order_by = "(SELECT NULL)" + if dedup_sort is not None: + order_by = f"{dedup_sort[0]} {dedup_sort[1].upper()}" + if condition is None: + condition = "1 = 1" + col_str = ", ".join(columns) + inner_col_str = col_str + if condition_columns is not None: + inner_col_str += ", " + ", ".join(condition_columns) + return f""" + SELECT {col_str} + FROM ( + SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY {order_by}) AS _dlt_dedup_rn, {inner_col_str} + FROM {table_name} + ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 AND ({condition}) + """ + @classmethod def gen_insert_temp_table_sql( - cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str + cls, + staging_root_table_name: str, + primary_keys: Sequence[str], + unique_column: str, + dedup_sort: Tuple[str, TSortOrder] = None, + condition: str = None, + condition_columns: Sequence[str] = None, ) -> Tuple[List[str], str]: temp_table_name = cls._new_temp_table_name("insert") - select_statement = f""" - SELECT {unique_column} - FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {unique_column} - FROM {staging_root_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 - """ - return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name + if len(primary_keys) > 0: + # deduplicate + select_sql = cls.gen_select_from_dedup_sql( + staging_root_table_name, + primary_keys, + [unique_column], + dedup_sort, + condition, + condition_columns, + ) + else: + # don't deduplicate + select_sql = f"SELECT {unique_column} FROM {staging_root_table_name} WHERE {condition}" + return [cls._to_temp_table(select_sql, temp_table_name)], temp_table_name @classmethod def gen_delete_from_sql( @@ -253,6 +336,13 @@ def gen_merge_sql( sql: List[str] = [] root_table = table_chain[0] + escape_id = sql_client.capabilities.escape_identifier + escape_lit = sql_client.capabilities.escape_literal + if escape_id is None: + escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier + if escape_lit is None: + escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal + # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) with sql_client.with_staging_dataset(staging=True): @@ -260,13 +350,13 @@ def gen_merge_sql( # get merge and primary keys from top level primary_keys = list( map( - sql_client.capabilities.escape_identifier, + escape_id, get_columns_names_with_prop(root_table, "primary_key"), ) ) merge_keys = list( map( - sql_client.capabilities.escape_identifier, + escape_id, get_columns_names_with_prop(root_table, "merge_key"), ) ) @@ -274,7 +364,6 @@ def gen_merge_sql( unique_column: str = None root_key_column: str = None - insert_temp_table_name: str = None if len(table_chain) == 1: key_table_clauses = cls.gen_key_table_clauses( @@ -298,7 +387,7 @@ def gen_merge_sql( " it is not possible to link child tables to it.", ) # get first unique column - unique_column = sql_client.capabilities.escape_identifier(unique_columns[0]) + unique_column = escape_id(unique_columns[0]) # create temp table with unique identifier create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql( unique_column, key_table_clauses @@ -319,7 +408,7 @@ def gen_merge_sql( f" {table['name']} so it is not possible to refer to top level table" f" {root_table['name']} unique column {unique_column}", ) - root_key_column = 
sql_client.capabilities.escape_identifier(root_key_columns[0]) + root_key_column = escape_id(root_key_columns[0]) sql.append( cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column @@ -333,47 +422,59 @@ def gen_merge_sql( ) ) - # create temp table used to deduplicate, only when we have primary keys - if primary_keys: + # get name of column with hard_delete hint, if specified + not_deleted_cond: str = None + hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") + if hard_delete_col is not None: + # any value indicates a delete for non-boolean columns + not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" + if root_table["columns"][hard_delete_col]["data_type"] == "bool": + # only True values indicate a delete for boolean columns + not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" + + # get dedup sort information + dedup_sort = get_dedup_sort_tuple(root_table) + + insert_temp_table_name: str = None + if len(table_chain) > 1: + if len(primary_keys) > 0 or hard_delete_col is not None: + condition_columns = [hard_delete_col] if not_deleted_cond is not None else None ( create_insert_temp_table_sql, insert_temp_table_name, ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, primary_keys, unique_column + staging_root_table_name, + primary_keys, + unique_column, + dedup_sort, + not_deleted_cond, + condition_columns, ) sql.extend(create_insert_temp_table_sql) - # insert from staging to dataset, truncate staging table + # insert from staging to dataset for table in table_chain: table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = ", ".join( - map( - sql_client.capabilities.escape_identifier, - get_columns_names_with_prop(table, "name"), - ) - ) - insert_sql = ( - f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" - ) - if len(primary_keys) > 0: - if len(table_chain) == 1: - insert_sql = f"""INSERT INTO {table_name}({columns}) - SELECT {columns} FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {columns} - FROM {staging_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1; - """ - else: - uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_sql += ( - f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_name});" - ) - if insert_sql.strip()[-1] != ";": - insert_sql += ";" - sql.append(insert_sql) - # -- DELETE FROM {staging_table_name} WHERE 1=1; + insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" + if (len(primary_keys) > 0 and len(table_chain) > 1) or ( + len(primary_keys) == 0 + and table.get("parent") is not None # child table + and hard_delete_col is not None + ): + uniq_column = unique_column if table.get("parent") is None else root_key_column + insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" + + columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) + col_str = ", ".join(columns) + select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" + if len(primary_keys) > 0 and len(table_chain) == 1: + # without child tables we deduplicate inside the query instead of using a temp table + select_sql = cls.gen_select_from_dedup_sql( + staging_table_name, primary_keys, columns, dedup_sort, insert_cond + ) + 
sql.append(f"INSERT INTO {table_name}({col_str}) {select_sql};") return sql diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index de785865c5..d24b6f5250 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -300,13 +300,6 @@ def __init__(self, resource_name: str, msg: str) -> None: super().__init__(resource_name, f"This resource is not a transformer: {msg}") -class TableNameMissing(DltSourceException): - def __init__(self) -> None: - super().__init__( - """Table name is missing in table template. Please provide a string or a function that takes a data item as an argument""" - ) - - class InconsistentTableTemplate(DltSourceException): def __init__(self, reason: str) -> None: msg = f"A set of table hints provided to the resource is inconsistent: {reason}" diff --git a/dlt/load/load.py b/dlt/load/load.py index b0b52d61d6..050e7bce67 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,7 +1,7 @@ import contextlib from functools import reduce import datetime # noqa: 251 -from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Callable +from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable from concurrent.futures import Executor import os @@ -9,7 +9,7 @@ from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo -from dlt.common.schema.utils import get_child_tables, get_top_level_table +from dlt.common.schema.utils import get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR @@ -20,7 +20,6 @@ DestinationTransientException, ) from dlt.common.schema import Schema, TSchemaTables -from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, @@ -45,6 +44,7 @@ LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, ) +from dlt.load.utils import get_completed_table_chain, init_client class Load(Runnable[Executor], WithStepInfo[LoadMetrics, LoadInfo]): @@ -138,7 +138,7 @@ def w_spool_job( file_path, ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - table = client.get_load_table(job_info.table_name) + table = client.prepare_load_table(job_info.table_name) if table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition( job_info.table_name, table["write_disposition"], file_path @@ -221,39 +221,6 @@ def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: for job_file in self.load_storage.list_new_jobs(load_id) ] - def get_completed_table_chain( - self, - load_id: str, - schema: Schema, - top_merged_table: TTableSchema, - being_completed_job_id: str = None, - ) -> List[TTableSchema]: - """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed - - Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage - """ - # returns ordered list of tables from parent to child leaf tables - table_chain: List[TTableSchema] = [] - # make sure all the jobs for the table chain is completed - for table in get_child_tables(schema.tables, top_merged_table["name"]): - table_jobs = self.load_storage.normalized_packages.list_jobs_for_table( - load_id, table["name"] - ) - # all jobs must be completed in order for merge to be created - if any( - job.state not in ("failed_jobs", "completed_jobs") - and job.job_file_info.job_id() != being_completed_job_id - for job in table_jobs - ): - return None - # if there are no jobs for the table, skip it, unless the write disposition is replace, as we need to create and clear the child tables - if not table_jobs and top_merged_table["write_disposition"] != "replace": - continue - table_chain.append(table) - # there must be at least table - assert len(table_chain) > 0 - return table_chain - def create_followup_jobs( self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema ) -> List[NewLoadJob]: @@ -268,8 +235,9 @@ def create_followup_jobs( schema.tables, starting_job.job_file_info().table_name ) # if all tables of chain completed, create follow up jobs - if table_chain := self.get_completed_table_chain( - load_id, schema, top_job_table, starting_job.job_file_info().job_id() + all_jobs = self.load_storage.normalized_packages.list_all_jobs(load_id) + if table_chain := get_completed_table_chain( + schema, all_jobs, top_job_table, starting_job.job_file_info().job_id() ): if follow_up_jobs := client.create_table_chain_completed_followup_jobs( table_chain @@ -279,7 +247,32 @@ def create_followup_jobs( return jobs def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + """Run periodically in the main thread to collect job execution statuses. + + After detecting change of status, it commits the job state by moving it to the right folder + May create one or more followup jobs that get scheduled as new jobs. 
New jobs are created + only in terminal states (completed / failed) + """ remaining_jobs: List[LoadJob] = [] + + def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: + for followup_job in followup_jobs: + # running should be moved into "new jobs", other statuses into started + folder: TJobState = ( + "new_jobs" if followup_job.state() == "running" else "started_jobs" + ) + # save all created jobs + self.load_storage.normalized_packages.import_job( + load_id, followup_job.new_file_path(), job_state=folder + ) + logger.info( + f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" + f" {followup_job.new_file_path()} placed in {folder}" + ) + # if followup job is not "running" place it in current queue to be finalized + if not followup_job.state() == "running": + remaining_jobs.append(followup_job) + logger.info(f"Will complete {len(jobs)} for {load_id}") for ii in range(len(jobs)): job = jobs[ii] @@ -290,6 +283,9 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li logger.debug(f"job {job.job_id()} still running") remaining_jobs.append(job) elif state == "failed": + # create followup jobs + _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + # try to get exception message from job failed_message = job.exception() self.load_storage.normalized_packages.fail_job( @@ -309,23 +305,7 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li ) elif state == "completed": # create followup jobs - followup_jobs = self.create_followup_jobs(load_id, state, job, schema) - for followup_job in followup_jobs: - # running should be moved into "new jobs", other statuses into started - folder: TJobState = ( - "new_jobs" if followup_job.state() == "running" else "started_jobs" - ) - # save all created jobs - self.load_storage.normalized_packages.import_job( - load_id, followup_job.new_file_path(), job_state=folder - ) - logger.info( - f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" - f" {followup_job.new_file_path()} placed in {folder}" - ) - # if followup job is not "running" place it in current queue to be finalized - if not followup_job.state() == "running": - remaining_jobs.append(followup_job) + _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) @@ -355,101 +335,17 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) - @staticmethod - def _get_table_chain_tables_with_filter( - schema: Schema, f: Callable[[TTableSchema], bool], tables_with_jobs: Iterable[str] - ) -> Set[str]: - """Get all jobs for tables with given write disposition and resolve the table chain""" - result: Set[str] = set() - for table_name in tables_with_jobs: - top_job_table = get_top_level_table(schema.tables, table_name) - if not f(top_job_table): - continue - for table in get_child_tables(schema.tables, top_job_table["name"]): - # only add tables for tables that have jobs unless the disposition is replace - # TODO: this is a (formerly used) hack to make test_merge_on_keys_in_schema, - # we should change that test - if ( - not table["name"] in tables_with_jobs - and top_job_table["write_disposition"] != "replace" - ): - continue - result.add(table["name"]) 
- return result - - @staticmethod - def _init_dataset_and_update_schema( - job_client: JobClientBase, - expected_update: TSchemaTables, - update_tables: Iterable[str], - truncate_tables: Iterable[str] = None, - staging_info: bool = False, - ) -> TSchemaTables: - staging_text = "for staging dataset" if staging_info else "" - logger.info( - f"Client for {job_client.config.destination_type} will start initialize storage" - f" {staging_text}" - ) - job_client.initialize_storage() - logger.info( - f"Client for {job_client.config.destination_type} will update schema to package schema" - f" {staging_text}" - ) - applied_update = job_client.update_stored_schema( - only_tables=update_tables, expected_update=expected_update - ) - logger.info( - f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" - ) - job_client.initialize_storage(truncate_tables=truncate_tables) - return applied_update - - def _init_client( - self, - job_client: JobClientBase, - schema: Schema, - expected_update: TSchemaTables, - load_id: str, - truncate_filter: Callable[[TTableSchema], bool], - truncate_staging_filter: Callable[[TTableSchema], bool], - ) -> TSchemaTables: - tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id)) - dlt_tables = set(t["name"] for t in schema.dlt_tables()) - - # update the default dataset - truncate_tables = self._get_table_chain_tables_with_filter( - schema, truncate_filter, tables_with_jobs - ) - applied_update = self._init_dataset_and_update_schema( - job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables - ) - - # update the staging dataset if client supports this - if isinstance(job_client, WithStagingDataset): - if staging_tables := self._get_table_chain_tables_with_filter( - schema, truncate_staging_filter, tables_with_jobs - ): - with job_client.with_staging_dataset(): - self._init_dataset_and_update_schema( - job_client, - expected_update, - staging_tables | {schema.version_table_name}, - staging_tables, - staging_info=True, - ) - - return applied_update - def load_single_package(self, load_id: str, schema: Schema) -> None: + new_jobs = self.get_new_jobs_info(load_id) # initialize analytical storage ie. 
create dataset required by passed schema with self.get_destination_client(schema) as job_client: if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: # init job client - applied_update = self._init_client( + applied_update = init_client( job_client, schema, + new_jobs, expected_update, - load_id, job_client.should_truncate_table_before_load, ( job_client.should_load_data_to_staging_dataset @@ -465,11 +361,11 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: " implement SupportsStagingDestination" ) with self.get_staging_destination_client(schema) as staging_client: - self._init_client( + init_client( staging_client, schema, + new_jobs, expected_update, - load_id, job_client.should_truncate_table_before_load_on_staging_destination, job_client.should_load_data_to_staging_dataset_on_staging_destination, ) diff --git a/dlt/load/utils.py b/dlt/load/utils.py new file mode 100644 index 0000000000..067ae33613 --- /dev/null +++ b/dlt/load/utils.py @@ -0,0 +1,186 @@ +from typing import List, Set, Iterable, Callable + +from dlt.common import logger +from dlt.common.storages.load_package import LoadJobInfo, PackageStorage +from dlt.common.schema.utils import ( + fill_hints_from_parent_and_clone_table, + get_child_tables, + get_top_level_table, + has_table_seen_data, +) +from dlt.common.storages.load_storage import ParsedLoadJobFileName +from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema.typing import TTableSchema +from dlt.common.destination.reference import ( + JobClientBase, + WithStagingDataset, +) + + +def get_completed_table_chain( + schema: Schema, + all_jobs: Iterable[LoadJobInfo], + top_merged_table: TTableSchema, + being_completed_job_id: str = None, +) -> List[TTableSchema]: + """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
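A minimal, plain-Python illustration of that gate (toy dicts rather than dlt's job objects): the chain only qualifies for followup jobs once every job touching it is in a terminal state, with the job currently being completed optionally counted as terminal.

```python
TERMINAL_STATES = {"completed_jobs", "failed_jobs"}


def chain_is_complete(job_states: dict, being_completed_job_id: str = None) -> bool:
    """True when all jobs are terminal or are the job whose completion is being processed."""
    return all(
        state in TERMINAL_STATES or job_id == being_completed_job_id
        for job_id, state in job_states.items()
    )


# a still-running sibling job blocks the merge followup jobs for the whole chain
assert not chain_is_complete({"items.a1": "completed_jobs", "items__tags.b2": "started_jobs"})
# unless that sibling is the very job whose completion is being processed right now
assert chain_is_complete(
    {"items.a1": "completed_jobs", "items__tags.b2": "started_jobs"},
    being_completed_job_id="items__tags.b2",
)
```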
None is returned if there's any job that is not completed + For append and merge write disposition, tables without jobs will be included, providing they have seen data (and were created in the destination) + Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage + """ + # returns ordered list of tables from parent to child leaf tables + table_chain: List[TTableSchema] = [] + # allow for jobless tables for those write disposition + skip_jobless_table = top_merged_table["write_disposition"] not in ("replace", "merge") + + # make sure all the jobs for the table chain is completed + for table in map( + lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), + get_child_tables(schema.tables, top_merged_table["name"]), + ): + table_jobs = PackageStorage.filter_jobs_for_table(all_jobs, table["name"]) + # skip tables that never seen data + if not has_table_seen_data(table): + assert len(table_jobs) == 0, f"Tables that never seen data cannot have jobs {table}" + continue + # skip jobless tables + if len(table_jobs) == 0 and skip_jobless_table: + continue + else: + # all jobs must be completed in order for merge to be created + if any( + job.state not in ("failed_jobs", "completed_jobs") + and job.job_file_info.job_id() != being_completed_job_id + for job in table_jobs + ): + return None + table_chain.append(table) + # there must be at least table + assert len(table_chain) > 0 + return table_chain + + +def init_client( + job_client: JobClientBase, + schema: Schema, + new_jobs: Iterable[ParsedLoadJobFileName], + expected_update: TSchemaTables, + truncate_filter: Callable[[TTableSchema], bool], + load_staging_filter: Callable[[TTableSchema], bool], +) -> TSchemaTables: + """Initializes destination storage including staging dataset if supported + + Will initialize and migrate schema in destination dataset and staging dataset. + + Args: + job_client (JobClientBase): Instance of destination client + schema (Schema): The schema as in load package + new_jobs (Iterable[LoadJobInfo]): List of new jobs + expected_update (TSchemaTables): Schema update as in load package. 
Always present even if empty + truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in destination dataset should be truncated + load_staging_filter (Callable[[TTableSchema], bool]): A filter which tell which table in the staging dataset may be loaded into + + Returns: + TSchemaTables: Actual migrations done at destination + """ + # get dlt/internal tables + dlt_tables = set(schema.dlt_table_names()) + # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below) + tables_no_data = set( + table["name"] for table in schema.data_tables() if not has_table_seen_data(table) + ) + # get all tables that actually have load jobs with data + tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data + + # get tables to truncate by extending tables with jobs with all their child tables + truncate_tables = set( + _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter) + ) + + applied_update = _init_dataset_and_update_schema( + job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables + ) + + # update the staging dataset if client supports this + if isinstance(job_client, WithStagingDataset): + # get staging tables (all data tables that are eligible) + staging_tables = set( + _extend_tables_with_table_chain( + schema, tables_with_jobs, tables_with_jobs, load_staging_filter + ) + ) + + if staging_tables: + with job_client.with_staging_dataset(): + _init_dataset_and_update_schema( + job_client, + expected_update, + staging_tables | {schema.version_table_name}, # keep only schema version + staging_tables, # all eligible tables must be also truncated + staging_info=True, + ) + + return applied_update + + +def _init_dataset_and_update_schema( + job_client: JobClientBase, + expected_update: TSchemaTables, + update_tables: Iterable[str], + truncate_tables: Iterable[str] = None, + staging_info: bool = False, +) -> TSchemaTables: + staging_text = "for staging dataset" if staging_info else "" + logger.info( + f"Client for {job_client.config.destination_type} will start initialize storage" + f" {staging_text}" + ) + job_client.initialize_storage() + logger.info( + f"Client for {job_client.config.destination_type} will update schema to package schema" + f" {staging_text}" + ) + applied_update = job_client.update_stored_schema( + only_tables=update_tables, expected_update=expected_update + ) + logger.info( + f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" + ) + job_client.initialize_storage(truncate_tables=truncate_tables) + return applied_update + + +def _extend_tables_with_table_chain( + schema: Schema, + tables: Iterable[str], + tables_with_jobs: Iterable[str], + include_table_filter: Callable[[TTableSchema], bool] = lambda t: True, +) -> Iterable[str]: + """Extend 'tables` with all their children and filter out tables that do not have jobs (in `tables_with_jobs`), + haven't seen data or are not included by `include_table_filter`. 
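A toy illustration of the intent, using plain lists and sets rather than dlt schema objects: for a root table with `replace` or `merge` disposition, child tables stay in the chain even when they produced no job in this package, because processing the root may still require truncating or deleting from its children.

```python
# hypothetical chain for a root table "customers" with nested data
chain = ["customers", "customers__orders", "customers__orders__items"]
tables_with_jobs = {"customers"}  # only the root received data in this load package


def extend_chain(chain, tables_with_jobs, write_disposition):
    # jobless child tables are only skipped for plain append-style loads
    skip_jobless = write_disposition not in ("replace", "merge")
    return {t for t in chain if t in tables_with_jobs or not skip_jobless}


assert extend_chain(chain, tables_with_jobs, "merge") == set(chain)
assert extend_chain(chain, tables_with_jobs, "append") == {"customers"}
```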
+ Note that for top tables with replace and merge, the filter for tables that do not have jobs + + Returns an unordered set of table names and their child tables + """ + result: Set[str] = set() + for table_name in tables: + top_job_table = get_top_level_table(schema.tables, table_name) + # for replace and merge write dispositions we should include tables + # without jobs in the table chain, because child tables may need + # processing due to changes in the root table + skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") + for table in map( + lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), + get_child_tables(schema.tables, top_job_table["name"]), + ): + chain_table_name = table["name"] + table_has_job = chain_table_name in tables_with_jobs + # table that never seen data are skipped as they will not be created + # also filter out tables + # NOTE: this will ie. eliminate all non iceberg tables on ATHENA destination from staging (only iceberg needs that) + if not has_table_seen_data(table) or not include_table_filter(table): + continue + # if there's no job for the table and we are in append then skip + if not table_has_job and skip_jobless_table: + continue + result.add(chain_table_name) + return result diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 2167250036..56d38a5a64 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -6,6 +6,7 @@ from dlt.common.json import custom_pua_decode, may_have_pua from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict +from dlt.common.schema.utils import has_table_seen_data from dlt.common.storages import ( NormalizeStorage, LoadStorage, @@ -196,15 +197,18 @@ def __call__( schema_updates.append(partial_update) logger.debug(f"Processed {line_no} lines from file {extracted_items_file}") if line is None and root_table_name in self.schema.tables: - self.load_storage.write_empty_items_file( - self.load_id, - self.schema.name, - root_table_name, - self.schema.get_table_columns(root_table_name), - ) - logger.debug( - f"No lines in file {extracted_items_file}, written empty load job file" - ) + # write only if table seen data before + root_table = self.schema.tables[root_table_name] + if has_table_seen_data(root_table): + self.load_storage.write_empty_items_file( + self.load_id, + self.schema.name, + root_table_name, + self.schema.get_table_columns(root_table_name), + ) + logger.debug( + f"No lines in file {extracted_items_file}, written empty load job file" + ) return schema_updates diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 0a3c6784c7..d360a1c7c4 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -289,9 +289,26 @@ def spool_files( ) -> None: # process files in parallel or in single thread, depending on map_f schema_updates, writer_metrics = map_f(schema, load_id, files) - # remove normalizer specific info - for table in schema.tables.values(): - table.pop("x-normalizer", None) # type: ignore[typeddict-item] + # compute metrics + job_metrics = {ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics} + table_metrics: Dict[str, DataWriterMetrics] = { + table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) + for table_name, metrics in itertools.groupby( + job_metrics.items(), lambda pair: pair[0].table_name + ) + } + # update normalizer specific info + for 
table_name in table_metrics: + table = schema.tables[table_name] + x_normalizer = table.setdefault("x-normalizer", {}) # type: ignore[typeddict-item] + # drop evolve once for all tables that seen data + x_normalizer.pop("evolve-columns-once", None) + # mark that table have seen data only if there was data + if table_metrics[table_name].items_count > 0 and "seen-data" not in x_normalizer: + logger.info( + f"Table {table_name} has seen data for a first time with load id {load_id}" + ) + x_normalizer["seen-data"] = True logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) @@ -312,19 +329,13 @@ def spool_files( self.normalize_storage.extracted_packages.delete_package(load_id) # log and update metrics logger.info(f"Extracted package {load_id} processed") - job_metrics = {ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics} self._step_info_complete_load_id( load_id, { "started_at": None, "finished_at": None, "job_metrics": {job.job_id(): metrics for job, metrics in job_metrics.items()}, - "table_metrics": { - table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) - for table_name, metrics in itertools.groupby( - job_metrics.items(), lambda pair: pair[0].table_name - ) - }, + "table_metrics": table_metrics, }, ) diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 09b8ca7a96..7e4021214e 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -77,6 +77,17 @@ You can use compound primary keys: ... ``` +By default, `primary_key` deduplication is arbitrary. You can pass the `dedup_sort` column hint with a value of `desc` or `asc` to influence which record remains after deduplication. Using `desc`, the records sharing the same `primary_key` are sorted in descending order before deduplication, making sure the record with the highest value for the column with the `dedup_sort` hint remains. `asc` has the opposite behavior. + +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"created_at": {"dedup_sort": "desc"}} # select "latest" record +) +... +``` + Example below merges on a column `batch_day` that holds the day for which given record is valid. Merge keys also can be compound: @@ -113,6 +124,78 @@ def github_repo_events(last_created_at = dlt.sources.incremental("created_at", " yield from _get_rest_pages("events") ``` +### Delete records +The `hard_delete` column hint can be used to delete records from the destination dataset. The behavior of the delete mechanism depends on the data type of the column marked with the hint: +1) `bool` type: only `True` leads to a delete—`None` and `False` values are disregarded +2) other types: each `not None` value leads to a delete + +Each record in the destination table with the same `primary_key` or `merge_key` as a record in the source dataset that's marked as a delete will be deleted. + +Deletes are propagated to any child table that might exist. For each record that gets deleted in the root table, all corresponding records in the child table(s) will also be deleted. Records in parent and child tables are linked through the `root key` that is explained in the next section. 
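For instance, deleting a parent record also removes the rows that were loaded into its child tables. A sketch, assuming `import dlt` and a configured `pipeline`; the table and column names are illustrative:

```python
@dlt.resource(
    primary_key="id",
    write_disposition="merge",
    columns={"deleted_flag": {"hard_delete": True}}
)
def orders(data):
    yield data

# first run: creates one row in "orders" and three rows in the "orders__items" child table
pipeline.run(orders({"id": 1, "items": ["a", "b", "c"], "deleted_flag": False}))

# second run: deletes the "orders" row and the three linked "orders__items" rows
pipeline.run(orders({"id": 1, "deleted_flag": True}))
```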
+ +#### Example: with primary key and boolean delete column +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"deleted_flag": {"hard_delete": True}} +) +def resource(): + # this will insert a record (assuming a record with id = 1 does not yet exist) + yield {"id": 1, "val": "foo", "deleted_flag": False} + + # this will update the record + yield {"id": 1, "val": "bar", "deleted_flag": None} + + # this will delete the record + yield {"id": 1, "val": "foo", "deleted_flag": True} + + # similarly, this would have also deleted the record + # only the key and the column marked with the "hard_delete" hint suffice to delete records + yield {"id": 1, "deleted_flag": True} +... +``` + +#### Example: with merge key and non-boolean delete column +```python +@dlt.resource( + merge_key="id", + write_disposition="merge", + columns={"deleted_at_ts": {"hard_delete": True}}} +def resource(): + # this will insert two records + yield [ + {"id": 1, "val": "foo", "deleted_at_ts": None}, + {"id": 1, "val": "bar", "deleted_at_ts": None} + ] + + # this will delete two records + yield {"id": 1, "val": "foo", "deleted_at_ts": "2024-02-22T12:34:56Z"} +... +``` + +#### Example: with primary key and "dedup_sort" hint +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"deleted_flag": {"hard_delete": True}, "lsn": {"dedup_sort": "desc"}} +def resource(): + # this will insert one record (the one with lsn = 3) + yield [ + {"id": 1, "val": "foo", "lsn": 1, "deleted_flag": None}, + {"id": 1, "val": "baz", "lsn": 3, "deleted_flag": None}, + {"id": 1, "val": "bar", "lsn": 2, "deleted_flag": True} + ] + + # this will insert nothing, because the "latest" record is a delete + yield [ + {"id": 2, "val": "foo", "lsn": 1, "deleted_flag": False}, + {"id": 2, "lsn": 2, "deleted_flag": True} + ] +... +``` + ### Forcing root key propagation Merge write disposition requires that the `_dlt_id` of top level table is propagated to child diff --git a/tests/common/cases/schemas/eth/ethereum_schema_v9.yml b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml new file mode 100644 index 0000000000..c56ff85a9f --- /dev/null +++ b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml @@ -0,0 +1,476 @@ +version: 17 +version_hash: PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4= +engine_version: 9 +name: ethereum +tables: + _dlt_loads: + columns: + load_id: + nullable: false + data_type: text + name: load_id + schema_name: + nullable: true + data_type: text + name: schema_name + status: + nullable: false + data_type: bigint + name: status + inserted_at: + nullable: false + data_type: timestamp + name: inserted_at + schema_version_hash: + nullable: true + data_type: text + name: schema_version_hash + write_disposition: skip + description: Created by DLT. Tracks completed loads + schema_contract: {} + name: _dlt_loads + resource: _dlt_loads + _dlt_version: + columns: + version: + nullable: false + data_type: bigint + name: version + engine_version: + nullable: false + data_type: bigint + name: engine_version + inserted_at: + nullable: false + data_type: timestamp + name: inserted_at + schema_name: + nullable: false + data_type: text + name: schema_name + version_hash: + nullable: false + data_type: text + name: version_hash + schema: + nullable: false + data_type: text + name: schema + write_disposition: skip + description: Created by DLT. 
Tracks schema updates + schema_contract: {} + name: _dlt_version + resource: _dlt_version + blocks: + description: Ethereum blocks + x-annotation: this will be preserved on save + write_disposition: append + filters: + includes: [] + excludes: [] + columns: + _dlt_load_id: + nullable: false + description: load id coming from the extractor + data_type: text + name: _dlt_load_id + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + number: + nullable: false + primary_key: true + data_type: bigint + name: number + parent_hash: + nullable: true + data_type: text + name: parent_hash + hash: + nullable: false + cluster: true + unique: true + data_type: text + name: hash + base_fee_per_gas: + nullable: false + data_type: wei + name: base_fee_per_gas + difficulty: + nullable: false + data_type: wei + name: difficulty + extra_data: + nullable: true + data_type: text + name: extra_data + gas_limit: + nullable: false + data_type: bigint + name: gas_limit + gas_used: + nullable: false + data_type: bigint + name: gas_used + logs_bloom: + nullable: true + data_type: binary + name: logs_bloom + miner: + nullable: true + data_type: text + name: miner + mix_hash: + nullable: true + data_type: text + name: mix_hash + nonce: + nullable: true + data_type: text + name: nonce + receipts_root: + nullable: true + data_type: text + name: receipts_root + sha3_uncles: + nullable: true + data_type: text + name: sha3_uncles + size: + nullable: true + data_type: bigint + name: size + state_root: + nullable: false + data_type: text + name: state_root + timestamp: + nullable: false + unique: true + sort: true + data_type: timestamp + name: timestamp + total_difficulty: + nullable: true + data_type: wei + name: total_difficulty + transactions_root: + nullable: false + data_type: text + name: transactions_root + schema_contract: {} + name: blocks + resource: blocks + x-normalizer: + seen-data: true + blocks__transactions: + parent: blocks + columns: + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + block_number: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: block_number + transaction_index: + nullable: false + primary_key: true + data_type: bigint + name: transaction_index + hash: + nullable: false + unique: true + data_type: text + name: hash + block_hash: + nullable: false + cluster: true + data_type: text + name: block_hash + block_timestamp: + nullable: false + sort: true + data_type: timestamp + name: block_timestamp + chain_id: + nullable: true + data_type: text + name: chain_id + from: + nullable: true + data_type: text + name: from + gas: + nullable: true + data_type: bigint + name: gas + gas_price: + nullable: true + data_type: bigint + name: gas_price + input: + nullable: true + data_type: text + name: input + max_fee_per_gas: + nullable: true + data_type: wei + name: max_fee_per_gas + max_priority_fee_per_gas: + nullable: true + data_type: wei + name: max_priority_fee_per_gas + nonce: + nullable: true + data_type: bigint + name: nonce + r: + nullable: true + data_type: text + name: r + s: + nullable: true + data_type: text + name: s + status: + nullable: true + data_type: bigint + name: status + to: + nullable: true + data_type: text + name: to + type: + nullable: true + data_type: text + name: type + v: + nullable: true + data_type: bigint + name: v + value: + nullable: false + data_type: wei + name: value + eth_value: + nullable: true + data_type: decimal + name: eth_value + name: blocks__transactions + 
x-normalizer: + seen-data: true + blocks__transactions__logs: + parent: blocks__transactions + columns: + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + address: + nullable: false + data_type: text + name: address + block_timestamp: + nullable: false + sort: true + data_type: timestamp + name: block_timestamp + block_hash: + nullable: false + cluster: true + data_type: text + name: block_hash + block_number: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: block_number + transaction_index: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: transaction_index + log_index: + nullable: false + primary_key: true + data_type: bigint + name: log_index + data: + nullable: true + data_type: text + name: data + removed: + nullable: true + data_type: bool + name: removed + transaction_hash: + nullable: false + data_type: text + name: transaction_hash + name: blocks__transactions__logs + x-normalizer: + seen-data: true + blocks__transactions__logs__topics: + parent: blocks__transactions__logs + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__transactions__logs__topics + x-normalizer: + seen-data: true + blocks__transactions__access_list: + parent: blocks__transactions + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + address: + nullable: true + data_type: text + name: address + name: blocks__transactions__access_list + x-normalizer: + seen-data: true + blocks__transactions__access_list__storage_keys: + parent: blocks__transactions__access_list + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__transactions__access_list__storage_keys + x-normalizer: + seen-data: true + blocks__uncles: + parent: blocks + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__uncles + x-normalizer: + seen-data: true +settings: + default_hints: + foreign_key: + - _dlt_parent_id + not_null: + - re:^_dlt_id$ + - _dlt_root_id + - _dlt_parent_id + - _dlt_list_idx + unique: + - _dlt_id + cluster: + - block_hash + partition: + - block_timestamp + root_key: + - _dlt_root_id + preferred_types: + timestamp: timestamp + 
block_timestamp: timestamp + schema_contract: {} +normalizers: + names: dlt.common.normalizers.names.snake_case + json: + module: dlt.common.normalizers.json.relational + config: + generate_dlt_id: true + propagation: + root: + _dlt_id: _dlt_root_id + tables: + blocks: + timestamp: block_timestamp + hash: block_hash +previous_hashes: +- C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE= +- yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE= + diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 54892eeae5..ba817b946f 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -6,6 +6,7 @@ from dlt.common import pendulum from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container +from dlt.common.schema.migrations import migrate_schema from dlt.common.storages import SchemaStorageConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import DictValidationException @@ -308,7 +309,7 @@ def test_upgrade_engine_v1_schema() -> None: # ensure engine v1 assert schema_dict["engine_version"] == 1 # schema_dict will be updated to new engine version - utils.migrate_schema(schema_dict, from_engine=1, to_engine=2) + migrate_schema(schema_dict, from_engine=1, to_engine=2) assert schema_dict["engine_version"] == 2 # we have 27 tables assert len(schema_dict["tables"]) == 27 @@ -316,40 +317,46 @@ def test_upgrade_engine_v1_schema() -> None: # upgrade schema eng 2 -> 4 schema_dict = load_json_case("schemas/ev2/event.schema") assert schema_dict["engine_version"] == 2 - upgraded = utils.migrate_schema(schema_dict, from_engine=2, to_engine=4) + upgraded = migrate_schema(schema_dict, from_engine=2, to_engine=4) assert upgraded["engine_version"] == 4 # upgrade 1 -> 4 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=4) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=4) assert upgraded["engine_version"] == 4 # upgrade 1 -> 6 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=6) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=6) assert upgraded["engine_version"] == 6 # upgrade 1 -> 7 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=7) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=7) assert upgraded["engine_version"] == 7 # upgrade 1 -> 8 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=8) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=8) assert upgraded["engine_version"] == 8 + # upgrade 1 -> 9 + schema_dict = load_json_case("schemas/ev1/event.schema") + assert schema_dict["engine_version"] == 1 + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=9) + assert upgraded["engine_version"] == 9 + def test_unknown_engine_upgrade() -> None: schema_dict: TStoredSchema = load_json_case("schemas/ev1/event.schema") # there's no path to migrate 3 -> 2 schema_dict["engine_version"] = 3 with pytest.raises(SchemaEngineNoUpgradePathException): - utils.migrate_schema(schema_dict, 3, 2) # type: 
ignore[arg-type] + migrate_schema(schema_dict, 3, 2) # type: ignore[arg-type] def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> None: @@ -693,7 +700,7 @@ def assert_new_schema_values(schema: Schema) -> None: assert schema.stored_version == 1 assert schema.stored_version_hash is not None assert schema.version_hash is not None - assert schema.ENGINE_VERSION == 8 + assert schema.ENGINE_VERSION == 9 assert schema._stored_previous_hashes == [] assert len(schema.settings["default_hints"]) > 0 # check settings diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index 5b794f51ee..dde05001e8 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -84,10 +84,10 @@ def test_infer_column_bumps_version() -> None: def test_preserve_version_on_load() -> None: - eth_v8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") - version = eth_v8["version"] - version_hash = eth_v8["version_hash"] - schema = Schema.from_dict(eth_v8) # type: ignore[arg-type] + eth_v9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + version = eth_v9["version"] + version_hash = eth_v9["version_hash"] + schema = Schema.from_dict(eth_v9) # type: ignore[arg-type] # version should not be bumped assert version_hash == schema._stored_version_hash assert version_hash == schema.version_hash @@ -126,13 +126,18 @@ def test_version_preserve_on_reload(remove_defaults: bool) -> None: def test_create_ancestry() -> None: - eth_v8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") - schema = Schema.from_dict(eth_v8) # type: ignore[arg-type] - assert schema._stored_previous_hashes == ["yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE="] + eth_v9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + schema = Schema.from_dict(eth_v9) # type: ignore[arg-type] + + expected_previous_hashes = [ + "C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE=", + "yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE=", + ] + hash_count = len(expected_previous_hashes) + assert schema._stored_previous_hashes == expected_previous_hashes version = schema._stored_version # modify save and load schema 15 times and check ancestry - expected_previous_hashes = ["yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE="] for i in range(1, 15): # keep expected previous_hashes expected_previous_hashes.insert(0, schema._stored_version_hash) @@ -148,6 +153,6 @@ def test_create_ancestry() -> None: assert schema._stored_version == version + i # we never have more than 10 previous_hashes - assert len(schema._stored_previous_hashes) == i + 1 if i + 1 <= 10 else 10 + assert len(schema._stored_previous_hashes) == i + hash_count if i + hash_count <= 10 else 10 assert len(schema._stored_previous_hashes) == 10 diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index c72fa75927..0e04554649 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -4,9 +4,9 @@ import yaml from dlt.common import json +from dlt.common.normalizers import explicit_normalizers from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema -from dlt.common.schema.utils import explicit_normalizers from dlt.common.storages.exceptions import ( InStorageSchemaModified, SchemaNotFoundError, @@ -24,7 +24,7 @@ load_yml_case, yml_case_path, COMMON_TEST_CASES_PATH, - IMPORTED_VERSION_HASH_ETH_V8, + IMPORTED_VERSION_HASH_ETH_V9, ) @@ -227,10 
+227,10 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: ie_storage.save_schema(schema) assert schema.version_hash == schema_hash # we linked schema to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # load schema and make sure our new schema is here schema = ie_storage.load_schema("ethereum") - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 assert schema._stored_version_hash == schema_hash assert schema.version_hash == schema_hash assert schema.previous_hashes == [] @@ -247,7 +247,7 @@ def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> No schema = Schema("ethereum") schema_hash = schema.version_hash synced_storage.save_schema(schema) - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # import schema is overwritten fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") @@ -327,13 +327,13 @@ def prepare_import_folder(storage: SchemaStorage) -> None: def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: prepare_import_folder(synced_storage) - eth_V8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") + eth_V9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") schema = synced_storage.load_schema("ethereum") # is linked to imported schema - schema._imported_version_hash = eth_V8["version_hash"] + schema._imported_version_hash = eth_V9["version_hash"] # also was saved in storage assert synced_storage.has_schema("ethereum") # and has link to imported schema s well (load without import) schema = storage.load_schema("ethereum") - assert schema._imported_version_hash == eth_V8["version_hash"] + assert schema._imported_version_hash == eth_V9["version_hash"] return schema diff --git a/tests/common/utils.py b/tests/common/utils.py index 0235d18bbe..a234937e56 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -16,7 +16,7 @@ COMMON_TEST_CASES_PATH = "./tests/common/cases/" # for import schema tests, change when upgrading the schema version -IMPORTED_VERSION_HASH_ETH_V8 = "C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE=" +IMPORTED_VERSION_HASH_ETH_V9 = "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" # test sentry DSN TEST_SENTRY_DSN = ( "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index 3c15bf37f5..a76cbd0cfd 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -39,7 +39,7 @@ ) from dlt.extract.typing import TableNameMeta -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V8 +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9 def test_none_returning_source() -> None: @@ -84,7 +84,7 @@ def test_load_schema_for_callable() -> None: schema = s.schema assert schema.name == "ethereum" == s.name # the schema in the associated file has this hash - assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V9 def test_unbound_parametrized_transformer() -> None: diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index 
0b18f22639..6804b98427 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -27,7 +27,7 @@ def test_iceberg() -> None: os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://dlt-ci-test-bucket" pipeline = dlt.pipeline( - pipeline_name="aaaaathena-iceberg", + pipeline_name="athena-iceberg", destination="athena", staging="filesystem", full_refresh=True, diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 63b54726c5..a223de9b26 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -32,12 +32,7 @@ DestinationTestConfiguration, drop_active_pipeline_data, ) -from tests.load.utils import TABLE_UPDATE, sequence_generator - - -@pytest.fixture -def schema() -> Schema: - return Schema("event") +from tests.load.utils import TABLE_UPDATE, sequence_generator, empty_schema def test_configuration() -> None: @@ -56,13 +51,13 @@ def test_configuration() -> None: @pytest.fixture -def gcp_client(schema: Schema) -> BigQueryClient: +def gcp_client(empty_schema: Schema) -> BigQueryClient: # return a client without opening connection creds = GcpServiceAccountCredentialsWithoutDefaults() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker return BigQueryClient( - schema, + empty_schema, BigQueryClientConfiguration( dataset_name=f"test_{uniq_id()}", credentials=creds # type: ignore[arg-type] ), diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 0e6f799047..9b12e04f77 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -8,18 +8,13 @@ from dlt.destinations.impl.duckdb.duck import DuckDbClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> DuckDbClient: +def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient(schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) + return DuckDbClient(empty_schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) def test_create_table(client: DuckDbClient) -> None: diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index 039ce99113..75f46e8905 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -9,19 +9,14 @@ from dlt.destinations.impl.mssql.mssql import MsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> MsSqlClient: +def client(empty_schema: Schema) -> MsSqlClient: # return client without opening connection return MsSqlClient( - schema, + empty_schema, MsSqlClientConfiguration(dataset_name="test_" + uniq_id(), credentials=MsSqlCredentials()), ) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 0714ac333d..78d405ae65 100644 --- 
a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -1,6 +1,5 @@ from copy import copy import pytest -import itertools import random from typing import List import pytest @@ -11,10 +10,12 @@ from dlt.common import json, pendulum from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext -from dlt.common.typing import AnyFun, StrAny +from dlt.common.schema.utils import has_table_seen_data +from dlt.common.typing import StrAny from dlt.common.utils import digest128 from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first +from dlt.pipeline.exceptions import PipelineStepFailed from tests.pipeline.utils import assert_load_info from tests.load.pipeline.utils import load_table_counts, select_data @@ -34,6 +35,11 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) + # make block uncles unseen to trigger filtering loader in loader for child tables + if has_table_seen_data(schema.tables["blocks__uncles"]): + del schema.tables["blocks__uncles"]["x-normalizer"] # type: ignore[typeddict-item] + assert not has_table_seen_data(schema.tables["blocks__uncles"]) + with open( "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", "r", @@ -488,3 +494,395 @@ def duplicates_no_child(): assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") assert counts["duplicates_no_child"] == 2 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +def test_complex_column_missing(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_complex_column_missing" + + @dlt.resource(name=table_name, write_disposition="merge", primary_key="id") + def r(data): + yield data + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + + data = [{"id": 1, "simple": "foo", "complex": [1, 2, 3]}] + info = p.run(r(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 3 + + # complex column is missing, previously inserted records should be deleted from child table + data = [{"id": 1, "simple": "bar"}] + info = p.run(r(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 0 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("key_type", ["primary_key", "merge_key"]) +def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: + table_name = "test_hard_delete_hint" + + @dlt.resource( + name=table_name, + write_disposition="merge", + columns={"deleted": {"hard_delete": True}}, + ) + def data_resource(data): + yield data + + if key_type == "primary_key": + data_resource.apply_hints(primary_key="id", merge_key="") + elif key_type == "merge_key": + data_resource.apply_hints(primary_key="", merge_key="id") + + p = destination_config.setup_pipeline(f"abstract_{key_type}", full_refresh=True) + + # insert two records + data = [ + {"id": 1, "val": 
"foo", "deleted": False}, + {"id": 2, "val": "bar", "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + # delete one record + data = [ + {"id": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # update one record (None for hard_delete column is treated as "not True") + data = [ + {"id": 2, "val": "baz", "deleted": None}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "deleted": row[2]} + for row in select_data(p, f"SELECT id, val, deleted FROM {qual_name}") + ] + expected = [{"id": 2, "val": "baz", "deleted": None}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # insert two records with same key + data = [ + {"id": 3, "val": "foo", "deleted": False}, + {"id": 3, "val": "bar", "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + counts = load_table_counts(p, table_name)[table_name] + if key_type == "primary_key": + assert counts == 2 + elif key_type == "merge_key": + assert counts == 3 + + # delete one key, resulting in one (primary key) or two (merge key) deleted records + data = [ + {"id": 3, "val": "foo", "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + counts = load_table_counts(p, table_name)[table_name] + assert load_table_counts(p, table_name)[table_name] == 1 + + table_name = "test_hard_delete_hint_complex" + data_resource.apply_hints(table_name=table_name) + + # insert two records with childs and grandchilds + data = [ + { + "id": 1, + "child_1": ["foo", "bar"], + "child_2": [ + {"grandchild_1": ["foo", "bar"], "grandchild_2": True}, + {"grandchild_1": ["bar", "baz"], "grandchild_2": False}, + ], + "deleted": False, + }, + { + "id": 2, + "child_1": ["baz"], + "child_2": [{"grandchild_1": ["baz"], "grandchild_2": True}], + "deleted": False, + }, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 3 + assert load_table_counts(p, table_name + "__child_2")[table_name + "__child_2"] == 3 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 5 + ) + + # delete first record + data = [ + {"id": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 1 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 1 + ) + + # delete second record + data = [ + {"id": 2, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 0 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 0 + ) + + +@pytest.mark.parametrize( + "destination_config", + 
destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_hard_delete_hint_non_bool" + + @dlt.resource( + name=table_name, + write_disposition="merge", + primary_key="id", + columns={ + "deleted_timestamp": {"data_type": "timestamp", "nullable": True, "hard_delete": True} + }, + ) + def data_resource(data): + yield data + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + + # insert two records + data = [ + {"id": 1, "val": "foo", "deleted_timestamp": None}, + {"id": 2, "val": "bar", "deleted_timestamp": None}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + # delete one record + data = [ + {"id": 1, "deleted_timestamp": "2024-02-15T17:16:53Z"}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "deleted_timestamp": row[2]} + for row in select_data(p, f"SELECT id, val, deleted_timestamp FROM {qual_name}") + ] + expected = [{"id": 2, "val": "bar", "deleted_timestamp": None}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # test if exception is raised when more than one "hard_delete" column hints are provided + @dlt.resource( + name="test_hard_delete_hint_too_many_hints", + write_disposition="merge", + columns={"deleted_1": {"hard_delete": True}, "deleted_2": {"hard_delete": True}}, + ) + def r(): + yield {"id": 1, "val": "foo", "deleted_1": True, "deleted_2": False} + + with pytest.raises(PipelineStepFailed): + info = p.run(r()) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +def test_dedup_sort_hint(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_dedup_sort_hint" + + @dlt.resource( + name=table_name, + write_disposition="merge", + primary_key="id", # sort hints only have effect when a primary key is provided + columns={"sequence": {"dedup_sort": "desc"}}, + ) + def data_resource(data): + yield data + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + + # three records with same primary key + data = [ + {"id": 1, "val": "foo", "sequence": 1}, + {"id": 1, "val": "baz", "sequence": 3}, + {"id": 1, "val": "bar", "sequence": 2}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + # record with highest value in sort column is inserted (because "desc") + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "baz", "sequence": 3}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # now test "asc" sorting + data_resource.apply_hints(columns={"sequence": {"dedup_sort": "asc"}}) + + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + # record with highest lowest in sort column is inserted (because 
"asc") + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "foo", "sequence": 1}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + table_name = "test_dedup_sort_hint_complex" + data_resource.apply_hints( + table_name=table_name, + columns={"sequence": {"dedup_sort": "desc"}}, + ) + + # three records with same primary key + # only record with highest value in sort column is inserted + data = [ + {"id": 1, "val": [1, 2, 3], "sequence": 1}, + {"id": 1, "val": [7, 8, 9], "sequence": 3}, + {"id": 1, "val": [4, 5, 6], "sequence": 2}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 + + # compare observed records with expected records, now for child table + qual_name = p.sql_client().make_qualified_table_name(table_name + "__val") + observed = [row[0] for row in select_data(p, f"SELECT value FROM {qual_name}")] + assert sorted(observed) == [7, 8, 9] # type: ignore[type-var] + + table_name = "test_dedup_sort_hint_with_hard_delete" + data_resource.apply_hints( + table_name=table_name, + columns={"sequence": {"dedup_sort": "desc"}, "deleted": {"hard_delete": True}}, + ) + + # three records with same primary key + # record with highest value in sort column is a delete, so no record will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1, "deleted": False}, + {"id": 1, "val": "baz", "sequence": 3, "deleted": True}, + {"id": 1, "val": "bar", "sequence": 2, "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + + # three records with same primary key + # record with highest value in sort column is not a delete, so it will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1, "deleted": False}, + {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, + {"id": 1, "val": "baz", "sequence": 3, "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "baz", "sequence": 3}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # additional tests with two records, run only on duckdb to limit test load + if destination_config.destination == "duckdb": + # two records with same primary key + # record with highest value in sort column is a delete + # existing record is deleted and no record will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1}, + {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + + # two records with same primary key + # record with highest value in sort column is not a delete, so it will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 2}, + {"id": 1, "val": "bar", "sequence": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert 
load_table_counts(p, table_name)[table_name] == 1 + + # test if exception is raised for invalid column schema's + @dlt.resource( + name="test_dedup_sort_hint_too_many_hints", + write_disposition="merge", + columns={"dedup_sort_1": {"dedup_sort": "this_is_invalid"}}, # type: ignore[call-overload] + ) + def r(): + yield {"id": 1, "val": "foo", "dedup_sort_1": 1, "dedup_sort_2": 5} + + # invalid value for "dedup_sort" hint + with pytest.raises(PipelineStepFailed): + info = p.run(r()) + + # more than one "dedup_sort" column hints are provided + r.apply_hints( + columns={"dedup_sort_1": {"dedup_sort": "desc"}, "dedup_sort_2": {"dedup_sort": "desc"}} + ) + with pytest.raises(PipelineStepFailed): + info = p.run(r()) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index af61dc5d3d..a483cbee1a 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -13,6 +13,7 @@ from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.extract.exceptions import ResourceNameMissing from dlt.extract import DltSource from dlt.pipeline.exceptions import ( @@ -23,8 +24,8 @@ from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.exceptions import DestinationHasFailedJobs -from tests.utils import TEST_STORAGE_ROOT, preserve_environ -from tests.pipeline.utils import assert_load_info +from tests.utils import TEST_STORAGE_ROOT, data_to_item_format, preserve_environ +from tests.pipeline.utils import assert_data_table_counts, assert_load_info from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, @@ -37,6 +38,7 @@ assert_table, load_table_counts, select_data, + REPLACE_STRATEGIES, ) from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @@ -849,6 +851,122 @@ def some_source(): ) +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_staging_configs=True, default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) +def test_pipeline_upfront_tables_two_loads( + destination_config: DestinationTestConfiguration, replace_strategy: str +) -> None: + if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": + pytest.skip( + f"Destination {destination_config.name} does not support merge and thus" + f" {replace_strategy}" + ) + + # use staging tables for replace + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy + + pipeline = destination_config.setup_pipeline( + "test_pipeline_upfront_tables_two_loads", + dataset_name="test_pipeline_upfront_tables_two_loads", + full_refresh=True, + ) + + @dlt.source + def two_tables(): + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="merge", + ) + def table_1(): + yield {"id": 1} + + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="merge", + ) + def table_2(): + yield data_to_item_format("arrow", [{"id": 2}]) + + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="replace", + ) + def table_3(make_data=False): + if not make_data: + return + yield {"id": 3} + + return table_1, table_2, table_3 + + # discover schema + schema = two_tables().discover_schema() + # 
print(schema.to_pretty_yaml()) + + # now we use this schema but load just one resource + source = two_tables() + # push state, table 3 not created + load_info_1 = pipeline.run(source.table_3, schema=schema) + assert_load_info(load_info_1) + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + assert "x-normalizer" not in pipeline.default_schema.tables["table_3"] + assert ( + pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True + ) + + # load with one empty job, table 3 not created + load_info = pipeline.run(source.table_3) + assert_load_info(load_info) + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + # print(pipeline.default_schema.to_pretty_yaml()) + + load_info_2 = pipeline.run([source.table_1, source.table_3]) + assert_load_info(load_info_2) + # 1 record in table 1 + assert pipeline.last_trace.last_normalize_info.row_counts["table_1"] == 1 + assert "table_3" not in pipeline.last_trace.last_normalize_info.row_counts + assert "table_2" not in pipeline.last_trace.last_normalize_info.row_counts + # only table_1 got created + assert load_table_counts(pipeline, "table_1") == {"table_1": 1} + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_2") + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + + # v4 = pipeline.default_schema.to_pretty_yaml() + # print(v4) + + # now load the second one. for arrow format the schema will not update because + # in that case normalizer does not add dlt specific fields, changes are not detected + # and schema is not updated because the hash didn't change + # also we make the replace resource to load its 1 record + load_info_3 = pipeline.run([source.table_3(make_data=True), source.table_2]) + assert_load_info(load_info_3) + assert_data_table_counts(pipeline, {"table_1": 1, "table_2": 1, "table_3": 1}) + # v5 = pipeline.default_schema.to_pretty_yaml() + # print(v5) + + # check if seen data is market correctly + assert ( + pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True + ) + assert ( + pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True + ) + assert ( + pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True + ) + + def simple_nested_pipeline( destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool ) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index c6db91efff..a69d4440dc 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -9,9 +9,11 @@ load_table_counts, load_tables_to_dicts, ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - -REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] +from tests.load.pipeline.utils import ( + destinations_configs, + DestinationTestConfiguration, + REPLACE_STRATEGIES, +) @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 381068f1e1..73c651688d 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -18,7 +18,7 @@ from 
tests.utils import TEST_STORAGE_ROOT from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_DECODED -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V8, yml_case_path as common_yml_case_path +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9, yml_case_path as common_yml_case_path from tests.common.configuration.utils import environment from tests.load.pipeline.utils import assert_query_data, drop_active_pipeline_data from tests.load.utils import ( @@ -469,7 +469,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_annotations in schema.tables # check if attached to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # extract some data with restored pipeline p.run(["C", "D", "E"], table_name="blacklist") assert normalized_labels in schema.tables diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 17360e76fd..54c6231dcc 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -21,6 +21,8 @@ if TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] + @pytest.fixture(autouse=True) def drop_pipeline(request) -> Iterator[None]: diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 68e6702b75..fde9d82cf7 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -2,6 +2,7 @@ from copy import deepcopy import sqlfluff +from dlt.common.schema.utils import new_table from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -11,19 +12,14 @@ PostgresCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> PostgresClient: +def client(empty_schema: Schema) -> PostgresClient: # return client without opening connection return PostgresClient( - schema, + empty_schema, PostgresClientConfiguration( dataset_name="test_" + uniq_id(), credentials=PostgresCredentials() ), diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index e280dd94e0..c6981e5553 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -12,19 +12,14 @@ RedshiftCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> RedshiftClient: +def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection return RedshiftClient( - schema, + empty_schema, RedshiftClientConfiguration( dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials() ), diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index e6eaf26c89..1e80a61f1c 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -12,20 +12,16 @@ ) from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate -from tests.load.utils import TABLE_UPDATE +from tests.load.utils 
import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def snowflake_client(schema: Schema) -> SnowflakeClient: +def snowflake_client(empty_schema: Schema) -> SnowflakeClient: # return client without opening connection creds = SnowflakeCredentials() return SnowflakeClient( - schema, SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds) + empty_schema, + SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds), ) diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 4719a8d003..871ceecf96 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -13,7 +13,7 @@ SynapseCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema from dlt.destinations.impl.synapse.synapse import ( HINT_TO_SYNAPSE_ATTR, TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, @@ -21,15 +21,10 @@ @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> SynapseClient: +def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( - schema, + empty_schema, SynapseClientConfiguration( dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() ), @@ -39,10 +34,10 @@ def client(schema: Schema) -> SynapseClient: @pytest.fixture -def client_with_indexes_enabled(schema: Schema) -> SynapseClient: +def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( - schema, + empty_schema, SynapseClientConfiguration( dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True ), diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 7436023f03..d7884abcf0 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -6,18 +6,26 @@ from typing import List from dlt.common.exceptions import TerminalException, TerminalValueError +from dlt.common.schema.typing import TWriteDisposition from dlt.common.storages import FileStorage, LoadStorage, PackageStorage, ParsedLoadJobFileName +from dlt.common.storages.load_package import LoadJobInfo from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.schema.utils import ( + fill_hints_from_parent_and_clone_table, + get_child_tables, + get_top_level_table, +) -from dlt.load import Load +from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.job_impl import EmptyLoadJob - -from dlt.destinations import dummy +from dlt.destinations import dummy, filesystem from dlt.destinations.impl.dummy import dummy as dummy_impl from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + +from dlt.load import Load from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry -from dlt.common.schema.utils import get_top_level_table +from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain from tests.utils import ( clean_test_storage, @@ -26,7 +34,7 @@ preserve_environ, ) from tests.load.utils import prepare_load_package -from tests.utils import skip_if_not_active +from tests.utils import skip_if_not_active, 
TEST_STORAGE_ROOT skip_if_not_active("dummy") @@ -35,6 +43,8 @@ "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] +REMOTE_FILESYSTEM = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem")) + @pytest.fixture(autouse=True) def storage() -> FileStorage: @@ -110,14 +120,19 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) + # update tables so we have all possible hints + for table_name, table in schema.tables.items(): + schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + top_job_table = get_top_level_table(schema.tables, "event_user") - assert load.get_completed_table_chain(load_id, schema, top_job_table) is None + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, top_job_table) is None # fake being completed assert ( len( - load.get_completed_table_chain( - load_id, + get_completed_table_chain( schema, + all_jobs, top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.jsonl", ) @@ -129,15 +144,17 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load.load_storage.normalized_packages.start_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) is None + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) is None load.load_storage.normalized_packages.complete_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) == [ + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) == [ schema.get_table("event_loop_interrupted") ] - assert load.get_completed_table_chain( - load_id, schema, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" + assert get_completed_table_chain( + schema, all_jobs, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) == [schema.get_table("event_loop_interrupted")] @@ -188,7 +205,7 @@ def test_spool_job_failed_exception_init() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "true" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: @@ -207,7 +224,7 @@ def test_spool_job_failed_exception_complete() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "false" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -310,6 +327,20 @@ def test_try_retrieve_job() -> None: def test_completed_loop() -> None: load = 
setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) assert_complete_job(load) + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + +def test_completed_loop_followup_jobs() -> None: + # TODO: until we fix how we create capabilities we must set env + os.environ["CREATE_FOLLOWUP_JOBS"] = "true" + load = setup_loader( + client_config=DummyClientConfiguration(completed_prob=1.0, create_followup_jobs=True) + ) + assert_complete_job(load) + # for each JOB there's REFERENCE JOB + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 def test_failed_loop() -> None: @@ -319,6 +350,27 @@ def test_failed_loop() -> None: ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) + # no jobs because fail on init + assert len(dummy_impl.JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + +def test_failed_loop_followup_jobs() -> None: + # TODO: until we fix how we create capabilities we must set env + os.environ["CREATE_FOLLOWUP_JOBS"] = "true" + os.environ["FAIL_IN_INIT"] = "false" + # ask to delete completed + load = setup_loader( + delete_completed_jobs=True, + client_config=DummyClientConfiguration( + fail_prob=1.0, fail_in_init=False, create_followup_jobs=True + ), + ) + # actually not deleted because one of the jobs failed + assert_complete_job(load, should_delete_completed=False) + # followup jobs were not started + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 def test_completed_loop_with_delete_completed() -> None: @@ -409,6 +461,286 @@ def test_wrong_writer_type() -> None: assert exv.value.load_id == load_id +def test_extend_table_chain() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + # only event user table (no other jobs) + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == {"event_user"} + # add child jobs + tables = _extend_tables_with_table_chain( + schema, ["event_user"], ["event_user", "event_user__parse_data__entities"] + ) + assert tables == {"event_user", "event_user__parse_data__entities"} + user_chain = {name for name in schema.data_table_names() if name.startswith("event_user__")} | { + "event_user" + } + # change event user to merge/replace to get full table chain + for w_d in ["merge", "replace"]: + schema.tables["event_user"]["write_disposition"] = w_d # type:ignore[typeddict-item] + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == user_chain + # no jobs for bot + assert _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() + # skip unseen tables + del schema.tables["event_user__parse_data__entities"][ # type:ignore[typeddict-item] + "x-normalizer" + ] + entities_chain = { + name + for name in schema.data_table_names() + if name.startswith("event_user__parse_data__entities") + } + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == user_chain - {"event_user__parse_data__entities"} + # exclude the whole chain + tables = _extend_tables_with_table_chain( + schema, ["event_user"], ["event_user"], lambda table: table["name"] not in entities_chain + ) + assert tables == user_chain - entities_chain + # ask for tables that are not top + tables = _extend_tables_with_table_chain(schema, 
["event_user__parse_data__entities"], []) + # user chain but without entities (not seen data) + assert tables == user_chain - {"event_user__parse_data__entities"} + # go to append and ask only for entities chain + schema.tables["event_user"]["write_disposition"] = "append" + tables = _extend_tables_with_table_chain( + schema, ["event_user__parse_data__entities"], entities_chain + ) + # without entities (not seen data) + assert tables == entities_chain - {"event_user__parse_data__entities"} + + # add multiple chains + bot_jobs = {"event_bot", "event_bot__data__buttons"} + tables = _extend_tables_with_table_chain( + schema, ["event_user__parse_data__entities", "event_bot"], entities_chain | bot_jobs + ) + assert tables == (entities_chain | bot_jobs) - {"event_user__parse_data__entities"} + + +def test_get_completed_table_chain_cases() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + + # update tables so we have all possible hints + for table_name, table in schema.tables.items(): + schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + + # child completed, parent not + event_user = schema.get_table("event_user") + event_user_entities = schema.get_table("event_user__parse_data__entities") + event_user_job = LoadJobInfo( + "started_jobs", + "path", + 0, + None, + 0, + ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl"), + None, + ) + event_user_entities_job = LoadJobInfo( + "completed_jobs", + "path", + 0, + None, + 0, + ParsedLoadJobFileName( + "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" + ), + None, + ) + chain = get_completed_table_chain(schema, [event_user_job, event_user_entities_job], event_user) + assert chain is None + + # parent just got completed + chain = get_completed_table_chain( + schema, + [event_user_job, event_user_entities_job], + event_user, + event_user_job.job_file_info.job_id(), + ) + # full chain + assert chain == [event_user, event_user_entities] + + # parent failed, child completed + chain = get_completed_table_chain( + schema, [event_user_job._replace(state="failed_jobs"), event_user_entities_job], event_user + ) + assert chain == [event_user, event_user_entities] + + # both failed + chain = get_completed_table_chain( + schema, + [ + event_user_job._replace(state="failed_jobs"), + event_user_entities_job._replace(state="failed_jobs"), + ], + event_user, + ) + assert chain == [event_user, event_user_entities] + + # merge and replace do not require whole chain to be in jobs + user_chain = get_child_tables(schema.tables, "event_user") + for w_d in ["merge", "replace"]: + event_user["write_disposition"] = w_d # type:ignore[typeddict-item] + + chain = get_completed_table_chain( + schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + ) + assert chain == user_chain + + # but if child is present and incomplete... 
+ chain = get_completed_table_chain( + schema, + [event_user_job, event_user_entities_job._replace(state="new_jobs")], + event_user, + event_user_job.job_file_info.job_id(), + ) + # noting is returned + assert chain is None + + # skip unseen + deep_child = schema.tables[ + "event_user__parse_data__response_selector__default__response__response_templates" + ] + del deep_child["x-normalizer"] # type:ignore[typeddict-item] + chain = get_completed_table_chain( + schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + ) + user_chain.remove(deep_child) + assert chain == user_chain + + +def test_init_client_truncate_tables() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + + nothing_ = lambda _: False + all_ = lambda _: True + + event_user = ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl") + event_bot = ParsedLoadJobFileName("event_bot", "event_bot_id", 0, "jsonl") + + with patch.object(dummy_impl.DummyClient, "initialize_storage") as initialize_storage: + with patch.object(dummy_impl.DummyClient, "update_stored_schema") as update_stored_schema: + with load.get_destination_client(schema) as client: + init_client(client, schema, [], {}, nothing_, nothing_) + # we do not allow for any staging dataset tables + assert update_stored_schema.call_count == 1 + assert update_stored_schema.call_args[1]["only_tables"] == { + "_dlt_loads", + "_dlt_version", + } + assert initialize_storage.call_count == 2 + # initialize storage is called twice, we deselected all tables to truncate + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + + # now we want all tables to be truncated but not on staging + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user], {}, all_, nothing_) + assert update_stored_schema.call_count == 1 + assert "event_user" in update_stored_schema.call_args[1]["only_tables"] + assert initialize_storage.call_count == 2 + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == {"event_user"} + + # now we push all to stage + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user, event_bot], {}, nothing_, all_) + assert update_stored_schema.call_count == 2 + # first call main dataset + assert {"event_user", "event_bot"} <= set( + update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + # second one staging dataset + assert {"event_user", "event_bot"} <= set( + update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) + assert initialize_storage.call_count == 4 + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert initialize_storage.call_args_list[2].args == () + # all tables that will be used on staging must be truncated + assert initialize_storage.call_args_list[3].kwargs["truncate_tables"] == { + "event_user", + "event_bot", + } + + replace_ = lambda table: table["write_disposition"] == "replace" + merge_ = lambda table: table["write_disposition"] == "merge" + + # set event_bot chain to merge + bot_chain = get_child_tables(schema.tables, "event_bot") + for w_d in 
["merge", "replace"]: + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + for bot in bot_chain: + bot["write_disposition"] = w_d # type:ignore[typeddict-item] + # merge goes to staging, replace goes to truncate + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user, event_bot], {}, replace_, merge_) + + if w_d == "merge": + # we use staging dataset + assert update_stored_schema.call_count == 2 + # 4 tables to update in main dataset + assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 + assert ( + "event_user" in update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + # full bot table chain + dlt version but no user + assert len( + update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) == 1 + len(bot_chain) + assert ( + "event_user" + not in update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) + + assert initialize_storage.call_count == 4 + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert initialize_storage.call_args_list[3].kwargs[ + "truncate_tables" + ] == update_stored_schema.call_args_list[1].kwargs["only_tables"] - { + "_dlt_version" + } + + if w_d == "replace": + assert update_stored_schema.call_count == 1 + assert initialize_storage.call_count == 2 + # we truncate the whole bot chain but not user (which is append) + assert len( + initialize_storage.call_args_list[1].kwargs["truncate_tables"] + ) == len(bot_chain) + # migrate only tables for which we have jobs + assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 + # print(initialize_storage.call_args_list) + # print(update_stored_schema.call_args_list) + + +def test_dummy_staging_filesystem() -> None: + load = setup_loader( + client_config=DummyClientConfiguration(completed_prob=1.0), filesystem_staging=True + ) + assert_complete_job(load) + # two reference jobs + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + def test_terminal_exceptions() -> None: try: raise TerminalValueError("a") @@ -433,6 +765,13 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) # will finalize the whole package load.run(pool) + # may have followup jobs or staging destination + if ( + load.initial_client_config.create_followup_jobs # type:ignore[attr-defined] + or load.staging_destination + ): + # run the followup jobs + load.run(pool) # moved to loaded assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) @@ -460,15 +799,32 @@ def run_all(load: Load) -> None: def setup_loader( - delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None + delete_completed_jobs: bool = False, + client_config: DummyClientConfiguration = None, + filesystem_staging: bool = False, ) -> Load: # reset jobs for a test dummy_impl.JOBS = {} - destination: TDestination = dummy() # type: ignore[assignment] + dummy_impl.CREATED_FOLLOWUP_JOBS = {} client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + destination: TDestination = dummy(**client_config) # type: ignore[assignment] + # setup + staging_system_config = None + staging = None + if filesystem_staging: + # do not accept jsonl to not conflict with filesystem destination + client_config = client_config or DummyClientConfiguration(loader_file_format="reference") + staging_system_config = FilesystemDestinationClientConfiguration(dataset_name="dummy") + 
staging_system_config.as_staging = True + os.makedirs(REMOTE_FILESYSTEM) + staging = filesystem(bucket_url=REMOTE_FILESYSTEM) # patch destination to provide client_config # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config) - # setup loader with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}): - return Load(destination, initial_client_config=client_config) + return Load( + destination, + initial_client_config=client_config, + staging_destination=staging, # type: ignore[arg-type] + initial_staging_client_config=staging_system_config, + ) diff --git a/tests/load/utils.py b/tests/load/utils.py index a571e4b640..50dca88248 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,3 +1,4 @@ +import pytest import contextlib import codecs import os @@ -159,6 +160,7 @@ def destinations_configs( subset: Sequence[str] = (), exclude: Sequence[str] = (), file_format: Optional[TLoaderFileFormat] = None, + supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: # sanity check @@ -378,6 +380,10 @@ def destinations_configs( destination_configs = [ conf for conf in destination_configs if conf.file_format == file_format ] + if supports_merge is not None: + destination_configs = [ + conf for conf in destination_configs if conf.supports_merge == supports_merge + ] if supports_dbt is not None: destination_configs = [ conf for conf in destination_configs if conf.supports_dbt == supports_dbt @@ -391,6 +397,14 @@ def destinations_configs( return destination_configs +@pytest.fixture +def empty_schema() -> Schema: + schema = Schema("event") + table = new_table("event_test_table") + schema.update_table(table) + return schema + + def get_normalized_dataset_name(client: JobClientBase) -> str: if isinstance(client.config, DestinationClientDwhConfiguration): return client.config.normalize_dataset_name(client.schema) @@ -420,7 +434,7 @@ def expect_load_file( client.capabilities.preferred_loader_file_format, ).file_name() file_storage.save(file_name, query.encode("utf-8")) - table = client.get_load_table(table_name) + table = client.prepare_load_table(table_name) job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) while job.state() == "running": sleep(0.5) diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 5cf1857dfa..cec7562d60 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -99,7 +99,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" ) ) - assert github_schema["engine_version"] == 8 + assert github_schema["engine_version"] == 9 assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"] # load state state_dict = json.loads( @@ -149,7 +149,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: pipeline.sync_destination() # print(pipeline.working_dir) # we have updated schema - assert pipeline.default_schema.ENGINE_VERSION == 8 + assert pipeline.default_schema.ENGINE_VERSION == 9 # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped assert pipeline.default_schema.stored_version_hash == github_schema["version_hash"] @@ -204,7 +204,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) pipeline = pipeline.drop() 
pipeline.sync_destination() - assert pipeline.default_schema.ENGINE_VERSION == 8 + assert pipeline.default_schema.ENGINE_VERSION == 9 # schema version does not match `dlt.attach` does not update to the right schema by itself assert pipeline.default_schema.stored_version_hash != github_schema["version_hash"] # state has hash
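
The new hints exercised by the tests above are declared directly on a resource. Below is a minimal usage sketch that mirrors those tests; the pipeline name, destination and dataset are illustrative placeholders and not part of this change.

import dlt

# "hard_delete": a truthy value in this column deletes the row (and its child table rows)
# at the destination instead of upserting it; None counts as "not deleted".
# "dedup_sort": when several staged rows share a primary key, keep the one that sorts first
# by this column ("desc" keeps the highest value, "asc" the lowest).
@dlt.resource(
    name="users",
    write_disposition="merge",
    primary_key="id",
    columns={
        "deleted": {"hard_delete": True},
        "sequence": {"dedup_sort": "desc"},
    },
)
def users(rows):
    yield rows


if __name__ == "__main__":
    pipeline = dlt.pipeline(
        pipeline_name="merge_hints_demo", destination="duckdb", dataset_name="demo"
    )
    pipeline.run(
        users(
            [
                {"id": 1, "val": "foo", "sequence": 1, "deleted": False},
                {"id": 1, "val": "bar", "sequence": 2, "deleted": False},  # wins dedup (higher sequence)
                {"id": 2, "val": "baz", "sequence": 1, "deleted": True},  # removes key 2 if present
            ]
        )
    )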
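For intuition only, the outcome the dedup/hard-delete tests assert can be modeled in memory. This sketch is an assumption-level illustration, not the actual implementation, which generates SQL merge jobs with staging tables (see dlt/destinations/sql_jobs.py in this change); the function name and signature are hypothetical.

from typing import Any, Dict, List


def expected_merge_result(
    existing: List[Dict[str, Any]],
    batch: List[Dict[str, Any]],
    key: str,
    sort_col: str,
    descending: bool,
    delete_col: str,
) -> List[Dict[str, Any]]:
    # keep one row per key from the batch: the first row after sorting by the dedup column
    winners: Dict[Any, Dict[str, Any]] = {}
    for row in sorted(batch, key=lambda r: r[sort_col], reverse=descending):
        winners.setdefault(row[key], row)
    # existing rows whose key is not touched by the batch stay as they are
    result = [row for row in existing if row[key] not in winners]
    # insert winners unless flagged as deletes (None is treated as "not deleted")
    result += [row for row in winners.values() if not row.get(delete_col)]
    return result


batch = [
    {"id": 1, "val": "foo", "sequence": 1, "deleted": False},
    {"id": 1, "val": "baz", "sequence": 3, "deleted": True},
    {"id": 1, "val": "bar", "sequence": 2, "deleted": False},
]
# the highest "sequence" wins the dedup and it is a delete, so nothing remains for key 1,
# matching the zero row count asserted in test_dedup_sort_hint_with_hard_delete
assert expected_merge_result([], batch, "id", "sequence", True, "deleted") == []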