From 82c363482ae98c7b812115b7553dae7e5880cac5 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 12 Feb 2024 02:32:51 +0100 Subject: [PATCH 01/28] black formatting --- dlt/destinations/impl/athena/athena.py | 16 +++++------ dlt/destinations/impl/bigquery/bigquery.py | 10 ++++--- .../impl/databricks/databricks.py | 28 ++++++++++++------- dlt/destinations/impl/snowflake/snowflake.py | 6 ++-- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 91525d771c..96e7818d57 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -351,7 +351,9 @@ def _from_db_type( return self.type_mapper.from_db_type(hive_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + return ( + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + ) def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool @@ -376,19 +378,15 @@ def _get_table_update_sql( # use qualified table names qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name) if is_iceberg and not generate_alter: - sql.append( - f"""CREATE TABLE {qualified_table_name} + sql.append(f"""CREATE TABLE {qualified_table_name} ({columns}) LOCATION '{location}' - TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""" - ) + TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""") elif not generate_alter: - sql.append( - f"""CREATE EXTERNAL TABLE {qualified_table_name} + sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name} ({columns}) STORED AS PARQUET - LOCATION '{location}';""" - ) + LOCATION '{location}';""") # alter table to add new columns at the end else: sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""") diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 1058b1d2c9..16b5d82c61 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -252,9 +252,9 @@ def _get_table_update_sql( elif (c := partition_list[0])["data_type"] == "date": sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": - sql[ - 0 - ] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + sql[0] = ( + f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" + ) # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. # This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range. 
@@ -272,7 +272,9 @@ def _get_table_update_sql( def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" + return ( + f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" + ) def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: schema_table: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index b5a404302f..07e827cd28 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -166,12 +166,14 @@ def __init__( else: raise LoadJobTerminalException( file_path, - f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and azure buckets are supported", + f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" + " azure buckets are supported", ) else: raise LoadJobTerminalException( file_path, - "Cannot load from local file. Databricks does not support loading from local files. Configure staging with an s3 or azure storage bucket.", + "Cannot load from local file. Databricks does not support loading from local files." + " Configure staging with an s3 or azure storage bucket.", ) # decide on source format, stage_file_path will either be a local file or a bucket path @@ -181,27 +183,33 @@ def __init__( if not config.get("data_writer.disable_compression"): raise LoadJobTerminalException( file_path, - "Databricks loader does not support gzip compressed JSON files. Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", + "Databricks loader does not support gzip compressed JSON files. Please disable" + " compression in the data writer configuration:" + " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) if table_schema_has_type(table, "decimal"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DECIMAL type columns from json files. Switch to parquet format to load decimals.", + "Databricks loader cannot load DECIMAL type columns from json files. Switch to" + " parquet format to load decimals.", ) if table_schema_has_type(table, "binary"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load BINARY type columns from json files. Switch to parquet format to load byte values.", + "Databricks loader cannot load BINARY type columns from json files. Switch to" + " parquet format to load byte values.", ) if table_schema_has_type(table, "complex"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load complex columns (lists and dicts) from json files. Switch to parquet format to load complex types.", + "Databricks loader cannot load complex columns (lists and dicts) from json" + " files. Switch to parquet format to load complex types.", ) if table_schema_has_type(table, "date"): raise LoadJobTerminalException( file_path, - "Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates.", + "Databricks loader cannot load DATE type columns from json files. 
Switch to" + " parquet format to load dates.", ) source_format = "JSON" @@ -311,7 +319,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non def _get_storage_table_query_columns(self) -> List[str]: fields = super()._get_storage_table_query_columns() - fields[ - 1 - ] = "full_data_type" # Override because this is the only way to get data type with precision + fields[1] = ( # Override because this is the only way to get data type with precision + "full_data_type" + ) return fields diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index fb51ab9d36..7fafbf83b7 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -175,15 +175,13 @@ def __init__( f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' " AUTO_COMPRESS = FALSE" ) - client.execute_sql( - f"""COPY INTO {qualified_table_name} + client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILE_FORMAT = {source_format} MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' - """ - ) + """) if stage_file_path and not keep_staged_files: client.execute_sql(f"REMOVE {stage_file_path}") From 97c55127f3c936ee33d7ffc38b8c50a5d634e833 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 12 Feb 2024 02:33:19 +0100 Subject: [PATCH 02/28] remove unused exception --- dlt/extract/exceptions.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index de785865c5..d24b6f5250 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -300,13 +300,6 @@ def __init__(self, resource_name: str, msg: str) -> None: super().__init__(resource_name, f"This resource is not a transformer: {msg}") -class TableNameMissing(DltSourceException): - def __init__(self) -> None: - super().__init__( - """Table name is missing in table template. 
Please provide a string or a function that takes a data item as an argument""" - ) - - class InconsistentTableTemplate(DltSourceException): def __init__(self, reason: str) -> None: msg = f"A set of table hints provided to the resource is inconsistent: {reason}" From 400d84bb91e85a8ad3efe65e97992954be0785e9 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 12 Feb 2024 02:34:11 +0100 Subject: [PATCH 03/28] add initial support for replicate write disposition --- dlt/common/schema/typing.py | 35 ++- dlt/common/schema/utils.py | 17 ++ dlt/destinations/job_client_impl.py | 18 +- dlt/destinations/sql_jobs.py | 118 +++++-- dlt/extract/decorators.py | 11 +- dlt/extract/hints.py | 3 + dlt/load/load.py | 2 +- dlt/pipeline/__init__.py | 2 +- dlt/pipeline/pipeline.py | 8 +- .../pipeline/test_replicate_disposition.py | 287 ++++++++++++++++++ tests/load/weaviate/test_weaviate_client.py | 2 +- 11 files changed, 464 insertions(+), 39 deletions(-) create mode 100644 tests/load/pipeline/test_replicate_disposition.py diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index e1ff17115d..82e6621f60 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -61,7 +61,7 @@ "merge_key", ] """Known hints of a column used to declare hint regexes.""" -TWriteDisposition = Literal["skip", "append", "replace", "merge"] +TWriteDisposition = Literal["skip", "append", "replace", "merge", "replicate"] TTableFormat = Literal["iceberg"] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" @@ -150,6 +150,38 @@ class NormalizerInfo(TypedDict, total=True): new_table: bool +class TCdcOperationMapperStr(TypedDict, total=True): + """ + Dictionary that informs dlt which string literals are used + in the change data to identify inserts, updates, and deletes. + """ + + insert: str + update: str + delete: str + + +class TCdcOperationMapperInt(TypedDict, total=True): + """ + Dictionary that informs dlt which integer literals are used + in the change data to identify inserts, updates, and deletes. 
+ """ + + insert: int + update: int + delete: int + + +class TCdcConfig(TypedDict, total=True): + """Dictionary that informs dlt how change data is organized.""" + + operation_column: str + """Name of the column containing the operation type ("insert", "update", or "delete") for the change record.""" + operation_mapper: Union[TCdcOperationMapperStr, TCdcOperationMapperInt] + sequence_column: str + """Name of the column containing a sequence identifier that can be used to order the change records.""" + + # TypedDict that defines properties of a table @@ -166,6 +198,7 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] + cdc_config: Optional[TCdcConfig] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index dc243f50dd..7f494d333f 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -37,6 +37,7 @@ TTypeDetections, TWriteDisposition, TSchemaContract, + TCdcConfig, ) from dlt.common.schema.exceptions import ( CannotCoerceColumnException, @@ -317,6 +318,19 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: if parent_table_name not in stored_schema["tables"]: raise ParentTableNotFoundException(table_name, parent_table_name) + # check for "replicate" tables that miss a primary key or "cdc_config" + if table.get("write_disposition") == "replicate": + if len(get_columns_names_with_prop(table, "primary_key", True)) == 0: + raise SchemaException( + f'Primary key missing for table "{table_name}" with "replicate" write' + " disposition." + ) + if "cdc_config" not in table: + raise SchemaException( + f'"cdc_config" missing for table "{table_name}" with "replicate" write' + " disposition." 
+ ) + def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: if from_engine == to_engine: @@ -724,6 +738,7 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, + cdc_config: TCdcConfig = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -742,6 +757,8 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format + if cdc_config is not None: + table["cdc_config"] = cdc_config if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index e7dc4bcbe2..6df149ae7c 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -250,7 +250,7 @@ def create_table_chain_completed_followup_jobs( write_disposition = table_chain[0]["write_disposition"] if write_disposition == "append": jobs.extend(self._create_append_followup_jobs(table_chain)) - elif write_disposition == "merge": + elif write_disposition in ("merge", "replicate"): jobs.extend(self._create_merge_followup_jobs(table_chain)) elif write_disposition == "replace": jobs.extend(self._create_replace_followup_jobs(table_chain)) @@ -581,10 +581,24 @@ def with_staging_dataset(self) -> Iterator["SqlJobClientBase"]: self.in_staging_mode = False def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - if table["write_disposition"] == "merge": + if table["write_disposition"] in ("merge", "replicate"): return True elif table["write_disposition"] == "replace" and ( self.config.replace_strategy in ["insert-from-staging", "staging-optimized"] ): return True return False + + def _create_table_update( + self, table_name: str, storage_columns: TTableSchemaColumns + ) -> Sequence[TColumnSchema]: + updates = super()._create_table_update(table_name, storage_columns) + table = self.schema.get_table(table_name) + if "write_disposition" in table and table["write_disposition"] == "replicate": + # operation and sequence columns should only be present in staging table + # not in final table + if not self.in_staging_mode: + op_col = table["cdc_config"]["operation_column"] + seq_col = table["cdc_config"]["sequence_column"] + updates = [d for d in updates if d["name"] not in (op_col, seq_col)] + return updates diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index d0911d0bea..6461b10565 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -6,7 +6,7 @@ from dlt.common.schema.typing import TTableSchema from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.utils import uniq_id +from dlt.common.utils import uniq_id, identity from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase @@ -202,16 +202,19 @@ def gen_delete_temp_table_sql( @classmethod def gen_insert_temp_table_sql( - cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str + cls, + staging_root_table_name: str, + primary_keys: Sequence[str], + unique_column: str, + condition: str = "1 = 1", ) -> Tuple[List[str], str]: temp_table_name = cls._new_temp_table_name("insert") - select_statement = f""" - SELECT {unique_column} - FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY 
(SELECT NULL)) AS _dlt_dedup_rn, {unique_column} - FROM {staging_root_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 - """ + select_statement = cls.gen_select_from_deduplicated_sql( + table_name=staging_root_table_name, + select_columns=[unique_column], + key_columns=primary_keys, + condition=condition, + ) return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name @classmethod @@ -229,6 +232,42 @@ def gen_delete_from_sql( ); """ + @classmethod + def gen_select_from_deduplicated_sql( + cls, + table_name: str, + key_columns: Sequence[str], + escape_identifier: Callable[[str], str] = identity, + columns: Sequence[str] = None, + order_column: str = "(SELECT NULL)", + condition: str = "1 = 1", + select_columns: Sequence[str] = None, + exclude_columns: Sequence[str] = None, + ) -> str: + """Generate SELECT FROM statement that deduplicates records based one or multiple deduplication keys.""" + columns_str = "*" + select_columns_str = "*" + if columns is not None: + columns_str = ", ".join(map(escape_identifier, columns)) + if select_columns is None: + if exclude_columns is None: + exclude_columns = [] + select_columns = [c for c in columns if c not in exclude_columns] + if select_columns is not None: + select_columns_str = ", ".join(map(escape_identifier, select_columns)) + key_columns_str = ", ".join(key_columns) + return f""" + SELECT {select_columns_str} + FROM ( + SELECT + ROW_NUMBER() OVER (partition BY {key_columns_str} ORDER BY {order_column} DESC) AS _dlt_dedup_rn, + {columns_str} + FROM {table_name} + ) AS _dlt_dedup_numbered + WHERE _dlt_dedup_rn = 1 + AND {condition} + """ + @classmethod def _new_temp_table_name(cls, name_prefix: str) -> str: return f"{name_prefix}_{uniq_id()}" @@ -253,6 +292,22 @@ def gen_merge_sql( sql: List[str] = [] root_table = table_chain[0] + escape_identifier = sql_client.capabilities.escape_identifier + escape_literal = sql_client.capabilities.escape_literal + + insert_condition = "1 = 1" + write_disposition = root_table["write_disposition"] + if write_disposition == "replicate": + # define variables specific to "replicate" write disposition + cdc_config = root_table["cdc_config"] + op_col = cdc_config["operation_column"] + seq_col = cdc_config["sequence_column"] + insert_literal = escape_literal(cdc_config["operation_mapper"]["insert"]) + update_literal = escape_literal(cdc_config["operation_mapper"]["update"]) + insert_condition = ( + f"{escape_identifier(op_col)} IN ({insert_literal}, {update_literal})" + ) + # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) with sql_client.with_staging_dataset(staging=True): @@ -260,13 +315,13 @@ def gen_merge_sql( # get merge and primary keys from top level primary_keys = list( map( - sql_client.capabilities.escape_identifier, + escape_identifier, get_columns_names_with_prop(root_table, "primary_key"), ) ) merge_keys = list( map( - sql_client.capabilities.escape_identifier, + escape_identifier, get_columns_names_with_prop(root_table, "merge_key"), ) ) @@ -298,7 +353,7 @@ def gen_merge_sql( " it is not possible to link child tables to it.", ) # get first unique column - unique_column = sql_client.capabilities.escape_identifier(unique_columns[0]) + unique_column = escape_identifier(unique_columns[0]) # create temp table with unique identifier create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql( unique_column, key_table_clauses @@ -339,7 +394,7 @@ def gen_merge_sql( 
create_insert_temp_table_sql, insert_temp_table_name, ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, primary_keys, unique_column + staging_root_table_name, primary_keys, unique_column, insert_condition ) sql.extend(create_insert_temp_table_sql) @@ -348,23 +403,33 @@ def gen_merge_sql( table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = ", ".join( - map( - sql_client.capabilities.escape_identifier, - get_columns_names_with_prop(table, "name"), - ) - ) + columns = get_columns_names_with_prop(table, "name") + if write_disposition == "replicate": + columns = [c for c in columns if c not in (op_col, seq_col)] + column_str = ", ".join(map(escape_identifier, columns)) insert_sql = ( - f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" + f"INSERT INTO {table_name}({column_str}) SELECT {column_str} FROM" + f" {staging_table_name}" ) if len(primary_keys) > 0: if len(table_chain) == 1: - insert_sql = f"""INSERT INTO {table_name}({columns}) - SELECT {columns} FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {columns} - FROM {staging_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1; - """ + select_sql = cls.gen_select_from_deduplicated_sql( + table_name=staging_table_name, + columns=get_columns_names_with_prop(table, "name"), + key_columns=primary_keys, + escape_identifier=escape_identifier, + ) + if write_disposition == "replicate": + select_sql = cls.gen_select_from_deduplicated_sql( + table_name=staging_table_name, + columns=get_columns_names_with_prop(table, "name"), + key_columns=primary_keys, + escape_identifier=escape_identifier, + order_column=seq_col, + condition=insert_condition, + exclude_columns=[op_col, seq_col], + ) + insert_sql = f"""INSERT INTO {table_name}({column_str}) {select_sql};""" else: uniq_column = unique_column if table.get("parent") is None else root_key_column insert_sql += ( @@ -374,6 +439,5 @@ def gen_merge_sql( if insert_sql.strip()[-1] != ";": insert_sql += ";" sql.append(insert_sql) - # -- DELETE FROM {staging_table_name} WHERE 1=1; return sql diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index d86fd04ef4..46e6f71e8d 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -36,6 +36,7 @@ TAnySchemaColumns, TSchemaContract, TTableFormat, + TCdcConfig, ) from dlt.extract.hints import make_hints from dlt.extract.utils import ( @@ -271,6 +272,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -288,6 +290,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... 
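
[editor's note] For context, a minimal sketch of how the cdc_config argument introduced in this commit is meant to be used on a resource. The table, column, and literal names are illustrative only and mirror the fixtures in the test module added later in this patch; this is not part of the diff:

import dlt

# change data configuration as defined by the new TCdcConfig TypedDict
cdc_config = {
    "operation_column": "operation",  # column holding the change type
    "operation_mapper": {"insert": "I", "update": "U", "delete": "D"},
    "sequence_column": "lsn",  # column used to order change records
}

@dlt.resource(
    table_name="customers",  # illustrative table name
    write_disposition="replicate",
    primary_key="id",  # required by validate_stored_schema for "replicate"
    cdc_config=cdc_config,
)
def customer_changes(batch):
    # yields CDC records such as {"id": 1, "val": "foo", "operation": "I", "lsn": 1}
    yield batch

The operation and sequence columns are only materialized in the staging table: the merge job uses them to order and filter change records, and _create_table_update (above) strips them from the final table.
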
@@ -305,6 +308,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: Literal[True] = True, @@ -323,6 +327,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -339,6 +344,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: bool = False, @@ -374,7 +380,7 @@ def resource( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. - write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. @@ -418,6 +424,7 @@ def make_resource( merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, + cdc_config=cdc_config, ) return DltResource.from_data( _data, @@ -643,7 +650,7 @@ def transformer( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. - write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". 
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index ec4bd56021..f8d2b75cd1 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -12,6 +12,7 @@ TAnySchemaColumns, TTableFormat, TSchemaContract, + TCdcConfig, ) from dlt.common.typing import TDataItem from dlt.common.utils import update_dict_nested @@ -61,6 +62,7 @@ def make_hints( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + cdc_config: TTableHintTemplate[TCdcConfig] = None, ) -> TResourceHints: """A convenience function to create resource hints. Accepts both static and dynamic hints based on data. @@ -80,6 +82,7 @@ def make_hints( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore + cdc_config=cdc_config, # type: ignore ) if not table_name: new_template.pop("name") diff --git a/dlt/load/load.py b/dlt/load/load.py index b0b52d61d6..3df79cf653 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -139,7 +139,7 @@ def w_spool_job( ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = client.get_load_table(job_info.table_name) - if table["write_disposition"] not in ["append", "replace", "merge"]: + if table["write_disposition"] not in ["append", "replace", "merge", "replicate"]: raise LoadClientUnsupportedWriteDisposition( job_info.table_name, table["write_disposition"], file_path ) diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 4101e58320..6b7201ecf0 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -238,7 +238,7 @@ def run( * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. - write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. 
Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 3fa8da6aee..7db2cc56f5 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -565,7 +565,7 @@ def run( * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. - write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. 
@@ -1163,9 +1163,9 @@ def _set_context(self, is_active: bool) -> None: # set destination context on activation if self.destination: # inject capabilities context - self._container[ - DestinationCapabilitiesContext - ] = self._get_destination_capabilities() + self._container[DestinationCapabilitiesContext] = ( + self._get_destination_capabilities() + ) else: # remove destination context on deactivation if DestinationCapabilitiesContext in self._container: diff --git a/tests/load/pipeline/test_replicate_disposition.py b/tests/load/pipeline/test_replicate_disposition.py new file mode 100644 index 0000000000..dbae4c1d77 --- /dev/null +++ b/tests/load/pipeline/test_replicate_disposition.py @@ -0,0 +1,287 @@ +import pytest + +import dlt + +from dlt.common.typing import TDataItems +from dlt.common.schema.typing import TCdcConfig +from dlt.common.schema.exceptions import SchemaException +from dlt.extract import DltResource + +from tests.pipeline.utils import assert_load_info +from tests.load.pipeline.utils import load_table_counts, select_data +from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration + + +CDC_CONFIGS = [ + { + "operation_column": "operation", + "operation_mapper": {"insert": "I", "update": "U", "delete": "D"}, + "sequence_column": "lsn", + }, + { + "operation_column": "op", + "operation_mapper": {"insert": 1, "update": 2, "delete": 3}, + "sequence_column": "commit_id", + }, +] + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("cdc_config", CDC_CONFIGS) +def test_replicate_core_functionality( + destination_config: DestinationTestConfiguration, + cdc_config: TCdcConfig, +) -> None: + op_col = cdc_config["operation_column"] + seq_col = cdc_config["sequence_column"] + op_map = cdc_config["operation_mapper"] + i = op_map["insert"] + u = op_map["update"] + d = op_map["delete"] + + # define batches of CDC data + batches_simple = [ + [ + {"id": 1, "val": "foo", op_col: i, seq_col: 1}, + {"id": 2, "val": "bar", op_col: i, seq_col: 2}, + ], + [ + {"id": 1, "val": "foo_new", op_col: i, seq_col: 3}, + {"id": 3, "val": "baz", op_col: i, seq_col: 4}, + ], + [ + {"id": 2, "val": "bar_new", op_col: u, seq_col: 5}, + ], + [ + {"id": 4, "val": "foo", op_col: u, seq_col: 6}, + ], + [ + {"id": 2, op_col: d, seq_col: 7}, + {"id": 2, "val": "bar_new_new", op_col: i, seq_col: 8}, + ], + [ + {"id": 5, "val": "foo", op_col: i, seq_col: 9}, + {"id": 5, "val": "foo_new", op_col: u, seq_col: 10}, + ], + [ + {"id": 6, "val": "foo", op_col: i, seq_col: 11}, + {"id": 6, op_col: d, seq_col: 12}, + ], + [ + {"id": 1, op_col: d, seq_col: 13}, + ], + ] + + table_name = "test_replicate_core_functionality" + + @dlt.resource( + table_name=table_name, + write_disposition="replicate", + primary_key="id", + cdc_config=cdc_config, + ) + def data_resource(batches: TDataItems, batch: int): + yield batches[batch] + + p = destination_config.setup_pipeline("pl_test_replicate_core_functionality", full_refresh=True) + + # insert keys in a new empty table + info = p.run(data_resource(batches_simple, batch=0)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + # insert a key that already exists (unexpected scenario) + info = p.run(data_resource(batches_simple, batch=1)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + + # update a key that already exists + info = p.run(data_resource(batches_simple, batch=2)) + 
assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + + # update a key that doesn't exist yet (unexpected scenario) + info = p.run(data_resource(batches_simple, batch=3)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 4 + + # delete an existing key, then insert it again + info = p.run(data_resource(batches_simple, batch=4)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 4 + + # insert a new key, then update it + info = p.run(data_resource(batches_simple, batch=5)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 5 + + # insert a new key, then delete it + info = p.run(data_resource(batches_simple, batch=6)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 5 + + # delete an existing key + info = p.run(data_resource(batches_simple, batch=7)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 4 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1]} for row in select_data(p, f"SELECT id, val FROM {qual_name}") + ] + expected = [ + {"id": 2, "val": "bar_new_new"}, + {"id": 3, "val": "baz"}, + {"id": 4, "val": "foo"}, + {"id": 5, "val": "foo_new"}, + ] + assert sorted(observed, key=lambda d: d["id"]) == expected + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("cdc_config", [CDC_CONFIGS[0]]) +def test_replicate_complex_data( + destination_config: DestinationTestConfiguration, + cdc_config: TCdcConfig, +) -> None: + op_col = cdc_config["operation_column"] + seq_col = cdc_config["sequence_column"] + op_map = cdc_config["operation_mapper"] + i = op_map["insert"] + u = op_map["update"] + d = op_map["delete"] + + # define batches of CDC data + batches_complex = [ + [ + {"id": 1, "val": ["foo", "bar"], op_col: i, seq_col: 1}, + {"id": 2, "val": ["baz"], op_col: i, seq_col: 2}, + ], + [ + {"id": 1, "val": ["foo", "bar", "baz"], op_col: u, seq_col: 3}, + {"id": 3, "val": ["foo"], op_col: i, seq_col: 4}, + ], + [ + {"id": 1, op_col: d, seq_col: 5}, + {"id": 4, "val": ["foo", "bar"], op_col: i, seq_col: 6}, + ], + ] + + table_name = "test_replicate_complex_data" + + @dlt.resource( + table_name=table_name, + write_disposition="replicate", + primary_key="id", + cdc_config=cdc_config, + ) + def data_resource(batches: TDataItems, batch: int): + yield batches[batch] + + # nesting disabled -> no child table + @dlt.source(max_table_nesting=0) + def data_nesting_disabled(batches: TDataItems, batch: int) -> DltResource: + return data_resource(batches, batch) + + p = destination_config.setup_pipeline("pl_test_replicate_complex_data", full_refresh=True) + info = p.run(data_nesting_disabled(batches_complex, batch=0)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + info = p.run(data_nesting_disabled(batches_complex, batch=1)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + + info = p.run(data_nesting_disabled(batches_complex, batch=2)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + + # nesting enabled -> child table + @dlt.source(max_table_nesting=1, root_key=True) + def data_nesting_enabled(batches: TDataItems, batch: int) -> DltResource: + return data_resource(batches, batch) + + p = 
destination_config.setup_pipeline("pl_test_replicate_complex_data", full_refresh=True) + + info = p.run(data_nesting_enabled(batches_complex, batch=0)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 + + info = p.run(data_nesting_enabled(batches_complex, batch=1)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 5 + + info = p.run(data_nesting_enabled(batches_complex, batch=2)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 3 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 4 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [row[0] for row in select_data(p, f"SELECT id FROM {qual_name}")] + expected = [2, 3, 4] + assert sorted(observed) == expected + qual_name = p.sql_client().make_qualified_table_name(table_name + "__val") + observed = [row[0] for row in select_data(p, f"SELECT value FROM {qual_name}")] + expected = ["bar", "baz", "foo", "foo"] # type: ignore[list-item] + assert sorted(observed) == expected + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("cdc_config", [CDC_CONFIGS[0]]) +def test_replicate_missing_config( + destination_config: DestinationTestConfiguration, + cdc_config: TCdcConfig, +) -> None: + op_col = cdc_config["operation_column"] + seq_col = cdc_config["sequence_column"] + op_map = cdc_config["operation_mapper"] + i = op_map["insert"] + u = op_map["update"] + + # define batches of CDC data + batches_simple = [ + [ + {"id": 1, "val": "foo", op_col: i, seq_col: 1}, + {"id": 2, "val": "bar", op_col: i, seq_col: 2}, + ], + [ + {"id": 1, "val": "foo_new", op_col: u, seq_col: 3}, + {"id": 3, "val": "baz", op_col: i, seq_col: 4}, + ], + ] + + @dlt.resource( + table_name="test_replicate_no_pk", + write_disposition="replicate", + cdc_config=cdc_config, + ) + def data_resource_no_pk(batches: TDataItems, batch: int): + yield batches[batch] + + # a SchemaException should be raised when using "replicate" and no primary key is specified + p = destination_config.setup_pipeline("pl_test_replicate_no_pk", full_refresh=True) + with pytest.raises(SchemaException): + p.run(data_resource_no_pk(batches_simple, batch=0)) + + @dlt.resource( + table_name="test_replicate_no_cdc_config", + write_disposition="replicate", + primary_key="id", + ) + def data_resource_no_cdc_config(batches: TDataItems, batch: int): + yield batches[batch] + + # a SchemaException should be raised when using "replicate" and no "cdc_config" is specified + p = destination_config.setup_pipeline("pl_test_replicate_no_cdc_config", full_refresh=True) + with pytest.raises(SchemaException): + p.run(data_resource_no_pk(batches_simple, batch=0)) diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 48153f7706..d6e337a2dc 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -67,7 +67,7 @@ def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) -@pytest.mark.parametrize("write_disposition", ["append", "replace", "merge"]) +@pytest.mark.parametrize("write_disposition", ["append", "replace", "merge", 
"replicate"]) def test_all_data_types( client: WeaviateClient, write_disposition: TWriteDisposition, file_storage: FileStorage ) -> None: From 24f362e96b4a12593a2b4e6063dae69cfe91dc33 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Wed, 14 Feb 2024 02:39:03 +0100 Subject: [PATCH 04/28] add hard_delete hint and sorted deduplication for merge --- dlt/common/schema/typing.py | 36 +-- dlt/common/schema/utils.py | 24 +- dlt/destinations/job_client_impl.py | 15 +- dlt/destinations/sql_jobs.py | 159 +++++----- dlt/extract/decorators.py | 11 +- dlt/extract/hints.py | 3 - dlt/load/load.py | 13 +- dlt/pipeline/__init__.py | 2 +- dlt/pipeline/pipeline.py | 2 +- tests/.dlt/config.toml | 23 ++ tests/load/pipeline/test_merge_disposition.py | 210 ++++++++++++- .../pipeline/test_replicate_disposition.py | 287 ------------------ tests/load/weaviate/test_weaviate_client.py | 2 +- 13 files changed, 329 insertions(+), 458 deletions(-) delete mode 100644 tests/load/pipeline/test_replicate_disposition.py diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 82e6621f60..65eb0afff1 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -61,7 +61,7 @@ "merge_key", ] """Known hints of a column used to declare hint regexes.""" -TWriteDisposition = Literal["skip", "append", "replace", "merge", "replicate"] +TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg"] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" @@ -112,6 +112,7 @@ class TColumnSchema(TColumnSchemaBase, total=False): root_key: Optional[bool] merge_key: Optional[bool] variant: Optional[bool] + hard_delete: Optional[bool] TTableSchemaColumns = Dict[str, TColumnSchema] @@ -150,38 +151,6 @@ class NormalizerInfo(TypedDict, total=True): new_table: bool -class TCdcOperationMapperStr(TypedDict, total=True): - """ - Dictionary that informs dlt which string literals are used - in the change data to identify inserts, updates, and deletes. - """ - - insert: str - update: str - delete: str - - -class TCdcOperationMapperInt(TypedDict, total=True): - """ - Dictionary that informs dlt which integer literals are used - in the change data to identify inserts, updates, and deletes. 
- """ - - insert: int - update: int - delete: int - - -class TCdcConfig(TypedDict, total=True): - """Dictionary that informs dlt how change data is organized.""" - - operation_column: str - """Name of the column containing the operation type ("insert", "update", or "delete") for the change record.""" - operation_mapper: Union[TCdcOperationMapperStr, TCdcOperationMapperInt] - sequence_column: str - """Name of the column containing a sequence identifier that can be used to order the change records.""" - - # TypedDict that defines properties of a table @@ -198,7 +167,6 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] - cdc_config: Optional[TCdcConfig] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 7f494d333f..b3323a3673 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -37,7 +37,6 @@ TTypeDetections, TWriteDisposition, TSchemaContract, - TCdcConfig, ) from dlt.common.schema.exceptions import ( CannotCoerceColumnException, @@ -318,19 +317,6 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: if parent_table_name not in stored_schema["tables"]: raise ParentTableNotFoundException(table_name, parent_table_name) - # check for "replicate" tables that miss a primary key or "cdc_config" - if table.get("write_disposition") == "replicate": - if len(get_columns_names_with_prop(table, "primary_key", True)) == 0: - raise SchemaException( - f'Primary key missing for table "{table_name}" with "replicate" write' - " disposition." - ) - if "cdc_config" not in table: - raise SchemaException( - f'"cdc_config" missing for table "{table_name}" with "replicate" write' - " disposition." 
- ) - def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: if from_engine == to_engine: @@ -587,6 +573,13 @@ def get_columns_names_with_prop( ] +def has_column_with_prop( + table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False +) -> bool: + """Checks if `table` schema contains column with property `column_prop`.""" + return len(get_columns_names_with_prop(table, column_prop, include_incomplete)) > 0 + + def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTables: aggregated_update: TSchemaTables = {} for schema_update in schema_updates: @@ -738,7 +731,6 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, - cdc_config: TCdcConfig = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -757,8 +749,6 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format - if cdc_config is not None: - table["cdc_config"] = cdc_config if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 6df149ae7c..649ad8c6e2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -35,6 +35,7 @@ ) from dlt.common.storages import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables +from dlt.common.schema.utils import get_columns_names_with_prop, has_column_with_prop from dlt.common.destination.reference import ( StateInfo, StorageSchemaInfo, @@ -250,7 +251,7 @@ def create_table_chain_completed_followup_jobs( write_disposition = table_chain[0]["write_disposition"] if write_disposition == "append": jobs.extend(self._create_append_followup_jobs(table_chain)) - elif write_disposition in ("merge", "replicate"): + elif write_disposition == "merge": jobs.extend(self._create_merge_followup_jobs(table_chain)) elif write_disposition == "replace": jobs.extend(self._create_replace_followup_jobs(table_chain)) @@ -581,7 +582,7 @@ def with_staging_dataset(self) -> Iterator["SqlJobClientBase"]: self.in_staging_mode = False def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - if table["write_disposition"] in ("merge", "replicate"): + if table["write_disposition"] == "merge": return True elif table["write_disposition"] == "replace" and ( self.config.replace_strategy in ["insert-from-staging", "staging-optimized"] @@ -594,11 +595,9 @@ def _create_table_update( ) -> Sequence[TColumnSchema]: updates = super()._create_table_update(table_name, storage_columns) table = self.schema.get_table(table_name) - if "write_disposition" in table and table["write_disposition"] == "replicate": - # operation and sequence columns should only be present in staging table - # not in final table + if has_column_with_prop(table, "hard_delete"): + # hard_delete column should only be present in staging table, not in final table if not self.in_staging_mode: - op_col = table["cdc_config"]["operation_column"] - seq_col = table["cdc_config"]["sequence_column"] - updates = [d for d in updates if d["name"] not in (op_col, seq_col)] + hard_delete_column = get_columns_names_with_prop(table, "hard_delete")[0] + updates = [d for d in updates if d["name"] != hard_delete_column] return updates diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 6461b10565..821dac3183 100644 --- 
a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -4,9 +4,9 @@ from dlt.common.runtime.logger import pretty_format_exception from dlt.common.schema.typing import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop, has_column_with_prop from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.utils import uniq_id, identity +from dlt.common.utils import uniq_id from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase @@ -147,6 +147,8 @@ def generate_sql( First we store the root_keys of root table elements to be deleted in the temp table. Then we use the temp table to delete records from root and all child tables in the destination dataset. At the end we copy the data from the staging dataset into destination dataset. + + If sort and/or hard_delete column hints are provided, records are deleted from the staging dataset before its data is copied to the destination dataset. """ return cls.gen_merge_sql(table_chain, sql_client) @@ -202,19 +204,16 @@ def gen_delete_temp_table_sql( @classmethod def gen_insert_temp_table_sql( - cls, - staging_root_table_name: str, - primary_keys: Sequence[str], - unique_column: str, - condition: str = "1 = 1", + cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str ) -> Tuple[List[str], str]: temp_table_name = cls._new_temp_table_name("insert") - select_statement = cls.gen_select_from_deduplicated_sql( - table_name=staging_root_table_name, - select_columns=[unique_column], - key_columns=primary_keys, - condition=condition, - ) + select_statement = f""" + SELECT {unique_column} + FROM ( + SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {unique_column} + FROM {staging_root_table_name} + ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 + """ return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name @classmethod @@ -232,42 +231,6 @@ def gen_delete_from_sql( ); """ - @classmethod - def gen_select_from_deduplicated_sql( - cls, - table_name: str, - key_columns: Sequence[str], - escape_identifier: Callable[[str], str] = identity, - columns: Sequence[str] = None, - order_column: str = "(SELECT NULL)", - condition: str = "1 = 1", - select_columns: Sequence[str] = None, - exclude_columns: Sequence[str] = None, - ) -> str: - """Generate SELECT FROM statement that deduplicates records based one or multiple deduplication keys.""" - columns_str = "*" - select_columns_str = "*" - if columns is not None: - columns_str = ", ".join(map(escape_identifier, columns)) - if select_columns is None: - if exclude_columns is None: - exclude_columns = [] - select_columns = [c for c in columns if c not in exclude_columns] - if select_columns is not None: - select_columns_str = ", ".join(map(escape_identifier, select_columns)) - key_columns_str = ", ".join(key_columns) - return f""" - SELECT {select_columns_str} - FROM ( - SELECT - ROW_NUMBER() OVER (partition BY {key_columns_str} ORDER BY {order_column} DESC) AS _dlt_dedup_rn, - {columns_str} - FROM {table_name} - ) AS _dlt_dedup_numbered - WHERE _dlt_dedup_rn = 1 - AND {condition} - """ - @classmethod def _new_temp_table_name(cls, name_prefix: str) -> str: return f"{name_prefix}_{uniq_id()}" @@ -291,23 +254,9 @@ def gen_merge_sql( ) -> List[str]: sql: List[str] = [] root_table = 
table_chain[0] - escape_identifier = sql_client.capabilities.escape_identifier escape_literal = sql_client.capabilities.escape_literal - insert_condition = "1 = 1" - write_disposition = root_table["write_disposition"] - if write_disposition == "replicate": - # define variables specific to "replicate" write disposition - cdc_config = root_table["cdc_config"] - op_col = cdc_config["operation_column"] - seq_col = cdc_config["sequence_column"] - insert_literal = escape_literal(cdc_config["operation_mapper"]["insert"]) - update_literal = escape_literal(cdc_config["operation_mapper"]["update"]) - insert_condition = ( - f"{escape_identifier(op_col)} IN ({insert_literal}, {update_literal})" - ) - # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) with sql_client.with_staging_dataset(staging=True): @@ -374,7 +323,7 @@ def gen_merge_sql( f" {table['name']} so it is not possible to refer to top level table" f" {root_table['name']} unique column {unique_column}", ) - root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0]) + root_key_column = escape_identifier(root_key_columns[0]) sql.append( cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column @@ -388,48 +337,80 @@ def gen_merge_sql( ) ) + # remove "non-latest" records from staging table (deduplicate) if a sort column is provided + if len(primary_keys) > 0: + if has_column_with_prop(root_table, "sort"): + sort_column = escape_identifier(get_columns_names_with_prop(root_table, "sort")[0]) + sql.append(f""" + DELETE FROM {staging_root_table_name} + WHERE {sort_column} IN ( + SELECT {sort_column} FROM ( + SELECT {sort_column}, ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY {sort_column} DESC) AS _rn + FROM {staging_root_table_name} + ) AS a + WHERE a._rn > 1 + ); + """) + + # remove deleted records from staging tables if a hard_delete column is provided + if has_column_with_prop(root_table, "hard_delete"): + hard_delete_column = escape_identifier( + get_columns_names_with_prop(root_table, "hard_delete")[0] + ) + # first delete from root staging table + sql.append(f""" + DELETE FROM {staging_root_table_name} + WHERE {hard_delete_column} IS NOT DISTINCT FROM {escape_literal(True)}; + """) + # then delete from child staging tables + for table in table_chain[1:]: + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + sql.append(f""" + DELETE FROM {staging_table_name} + WHERE NOT EXISTS ( + SELECT 1 FROM {staging_root_table_name} AS p + WHERE {staging_table_name}.{root_key_column} = p.{unique_column} + ); + """) + + if len(table_chain) > 1: # create temp table used to deduplicate, only when we have primary keys if primary_keys: ( create_insert_temp_table_sql, insert_temp_table_name, ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, primary_keys, unique_column, insert_condition + staging_root_table_name, primary_keys, unique_column ) sql.extend(create_insert_temp_table_sql) - # insert from staging to dataset, truncate staging table + # insert from staging to dataset for table in table_chain: table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = get_columns_names_with_prop(table, "name") - if write_disposition == "replicate": - columns = [c for c in columns if c not in (op_col, 
seq_col)] - column_str = ", ".join(map(escape_identifier, columns)) + columns = ", ".join( + map( + escape_identifier, + [ + c + for c in get_columns_names_with_prop(table, "name") + if c not in get_columns_names_with_prop(table, "hard_delete") + ], + ) + ) insert_sql = ( - f"INSERT INTO {table_name}({column_str}) SELECT {column_str} FROM" - f" {staging_table_name}" + f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" ) if len(primary_keys) > 0: if len(table_chain) == 1: - select_sql = cls.gen_select_from_deduplicated_sql( - table_name=staging_table_name, - columns=get_columns_names_with_prop(table, "name"), - key_columns=primary_keys, - escape_identifier=escape_identifier, - ) - if write_disposition == "replicate": - select_sql = cls.gen_select_from_deduplicated_sql( - table_name=staging_table_name, - columns=get_columns_names_with_prop(table, "name"), - key_columns=primary_keys, - escape_identifier=escape_identifier, - order_column=seq_col, - condition=insert_condition, - exclude_columns=[op_col, seq_col], - ) - insert_sql = f"""INSERT INTO {table_name}({column_str}) {select_sql};""" + insert_sql = f"""INSERT INTO {table_name}({columns}) + SELECT {columns} FROM ( + SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {columns} + FROM {staging_table_name} + ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1; + """ else: uniq_column = unique_column if table.get("parent") is None else root_key_column insert_sql += ( diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 46e6f71e8d..d86fd04ef4 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -36,7 +36,6 @@ TAnySchemaColumns, TSchemaContract, TTableFormat, - TCdcConfig, ) from dlt.extract.hints import make_hints from dlt.extract.utils import ( @@ -272,7 +271,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -290,7 +288,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... @@ -308,7 +305,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: Literal[True] = True, @@ -327,7 +323,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... 
@@ -344,7 +339,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: bool = False, @@ -380,7 +374,7 @@ def resource( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. - write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. @@ -424,7 +418,6 @@ def make_resource( merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, - cdc_config=cdc_config, ) return DltResource.from_data( _data, @@ -650,7 +643,7 @@ def transformer( table_name (TTableHintTemplate[str], optional): An table name, if different from `name`. This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. - write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. 
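With the `replicate` write disposition and its `cdc_config` hint removed above, change feeds are meant to flow through the plain `merge` disposition combined with the `hard_delete` and `dedup_sort` column hints introduced later in this series. A minimal sketch of that usage, modeled on the test resources added in patch 07; the `duckdb` destination and the `sequence`/`deleted` column names are illustrative assumptions, not part of this diff:

```python
import dlt

# Rows are merged by primary key: "dedup_sort" keeps only the row with the
# highest "sequence" per key, and "hard_delete" drops keys flagged as deleted.
@dlt.resource(
    write_disposition="merge",
    primary_key="id",
    columns={"sequence": {"dedup_sort": True}, "deleted": {"hard_delete": True}},
)
def change_events():
    yield [
        {"id": 1, "val": "foo", "sequence": 1, "deleted": False},
        {"id": 1, "val": "bar", "sequence": 2, "deleted": False},  # kept: highest sequence for id 1
        {"id": 2, "val": "baz", "sequence": 1, "deleted": True},   # dropped: flagged as deleted
    ]

pipeline = dlt.pipeline(pipeline_name="cdc_merge_demo", destination="duckdb")
print(pipeline.run(change_events()))
```

The same combination of hints is exercised by `test_hard_delete_hint` and `test_dedup_sort_hint` further down in the series.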
diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index f8d2b75cd1..ec4bd56021 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -12,7 +12,6 @@ TAnySchemaColumns, TTableFormat, TSchemaContract, - TCdcConfig, ) from dlt.common.typing import TDataItem from dlt.common.utils import update_dict_nested @@ -62,7 +61,6 @@ def make_hints( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - cdc_config: TTableHintTemplate[TCdcConfig] = None, ) -> TResourceHints: """A convenience function to create resource hints. Accepts both static and dynamic hints based on data. @@ -82,7 +80,6 @@ def make_hints( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore - cdc_config=cdc_config, # type: ignore ) if not table_name: new_template.pop("name") diff --git a/dlt/load/load.py b/dlt/load/load.py index 3df79cf653..0ecdd1f13b 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -21,6 +21,7 @@ ) from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TTableSchema, TWriteDisposition +from dlt.common.schema.utils import has_column_with_prop from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, @@ -139,7 +140,7 @@ def w_spool_job( ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") table = client.get_load_table(job_info.table_name) - if table["write_disposition"] not in ["append", "replace", "merge", "replicate"]: + if table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition( job_info.table_name, table["write_disposition"], file_path ) @@ -246,8 +247,14 @@ def get_completed_table_chain( for job in table_jobs ): return None - # if there are no jobs for the table, skip it, unless the write disposition is replace, as we need to create and clear the child tables - if not table_jobs and top_merged_table["write_disposition"] != "replace": + # if there are no jobs for the table, skip it, unless child tables need to be replaced + needs_replacement = False + if top_merged_table["write_disposition"] == "replace" or ( + top_merged_table["write_disposition"] == "merge" + and has_column_with_prop(top_merged_table, "hard_delete") + ): + needs_replacement = True + if not table_jobs and not needs_replacement: continue table_chain.append(table) # there must be at least table diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 6b7201ecf0..4101e58320 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -238,7 +238,7 @@ def run( * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. - write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". 
+ write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 7db2cc56f5..73c8f076d1 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -565,7 +565,7 @@ def run( * `@dlt.resource`: resource contains the full table schema and that includes the table name. `table_name` will override this property. Use with care! * `@dlt.source`: source contains several resources each with a table schema. `table_name` will override all table names within the source and load the data into single table. - write_disposition (Literal["skip", "append", "replace", "merge", "replicate"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". + write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append". Please note that in case of `dlt.resource` the table schema value will be overwritten and in case of `dlt.source`, the values in all resources will be overwritten. columns (Sequence[TColumnSchema], optional): A list of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema. 
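With `replicate` gone from the `run` docstrings as well, the dispositions accepted at run time are `skip`, `append`, `replace`, and `merge`. A short, hedged example of overriding the disposition directly in `Pipeline.run`; the `duckdb` destination and the `items` table name are assumptions, and `primary_key` is assumed to be accepted by `run` alongside `table_name` and `write_disposition`:

```python
import dlt

pipeline = dlt.pipeline(pipeline_name="run_disposition_demo", destination="duckdb")

rows = [
    {"id": 1, "val": "foo"},
    {"id": 2, "val": "bar"},
]

# "merge" deduplicates and merges on the "primary_key"/"merge_key" hints;
# passing it to run() overrides whatever hint the resource carries.
info = pipeline.run(rows, table_name="items", write_disposition="merge", primary_key="id")
print(info)

# loading the same rows again should leave the row count at 2
info = pipeline.run(rows, table_name="items", write_disposition="merge", primary_key="id")
print(info)
```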
diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index 1eac35d306..137f67eab8 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -1,3 +1,26 @@ +ACTIVE_DESTINATIONS = ["postgres"] +#ACTIVE_DESTINATIONS = ["mssql", "duckdb", "postgres"] + +[load] +raise_on_max_retries = 1 + +[destination.mssql] +#merge_strategy = "cdc" + +[destination.mssql.credentials] +database = "dlt_data" +username = "loader" +host = "localhost" +port = 1433 +driver = "ODBC Driver 17 for SQL Server" + +[destination.postgres.credentials] +database = "dlt_data" +username = "loader" +host = "localhost" +port = 5432 +connect_timeout = 15 + [runtime] sentry_dsn="https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 0714ac333d..7e238d7871 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -448,17 +448,18 @@ def duplicates(): counts = load_table_counts(p, "duplicates", "duplicates__child") assert counts["duplicates"] == 1 if destination_config.supports_merge else 2 assert counts["duplicates__child"] == 3 if destination_config.supports_merge else 6 - qual_name = p.sql_client().make_qualified_table_name("duplicates") - select_data(p, f"SELECT * FROM {qual_name}")[0] - @dlt.resource(write_disposition="merge", primary_key=("id", "subkey")) + @dlt.resource(write_disposition="merge", primary_key="id") def duplicates_no_child(): - yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] + yield [ + {"id": 1, "name": "row1"}, + {"id": 1, "name": "row2"}, + ] info = p.run(duplicates_no_child()) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") - assert counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 + assert counts["duplicates_no_child"] == 1 @pytest.mark.parametrize( @@ -488,3 +489,202 @@ def duplicates_no_child(): assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") assert counts["duplicates_no_child"] == 2 + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +@pytest.mark.parametrize("key_type", ["primary_key", "merge_key"]) +def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: + table_name = f"test_hard_delete_hint_{key_type}" + + @dlt.resource( + name=table_name, + write_disposition="merge", + columns={"deleted": {"hard_delete": True}}, + ) + def data_resource(data): + yield data + + if key_type == "primary_key": + data_resource.apply_hints(primary_key="id", merge_key="") + elif key_type == "merge_key": + data_resource.apply_hints(primary_key="", merge_key="id") + + p = destination_config.setup_pipeline(f"test_hard_delete_hint_{key_type}", full_refresh=True) + + # insert two records + data = [ + {"id": 1, "val": "foo", "deleted": False}, + {"id": 2, "val": "bar", "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + # delete one record + data = [ + {"id": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # update one record (None for hard_delete column is treated as "not True") + data = [ + {"id": 2, "val": "baz", "deleted": None}, + ] + info = p.run(data_resource(data)) + 
assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1]} for row in select_data(p, f"SELECT id, val FROM {qual_name}") + ] + expected = [{"id": 2, "val": "baz"}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # insert and delete the same record, record will be inserted if no sort column is provided + data = [ + {"id": 3, "val": "foo", "deleted": None}, + {"id": 3, "deleted": True}, + # {"id": 4, "deleted": True}, + # {"id": 4, "val": "foo", "deleted": None}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + + with p.destination_client(p.default_schema_name) as client: + _, table = client.get_storage_table(table_name) # type: ignore[attr-defined] + column_names = table.keys() + # ensure hard_delete column is not propagated to final table + assert "deleted" not in column_names + + p = destination_config.setup_pipeline( + f"test_hard_delete_hint_{key_type}_complex", full_refresh=True + ) + + # insert two records + data = [ + {"id": 1, "val": ["foo", "bar"], "deleted": False}, + {"id": 2, "val": ["baz"], "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 + + # delete one record, providing only the key + data = [ + {"id": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 1 + + # delete one record, providing the full record + data = [ + {"id": 2, "val": ["foo", "bar"], "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 0 + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_sort_hint(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_sort_hint" + + @dlt.resource( + name=table_name, + write_disposition="merge", + primary_key="id", # sort hints only have effect when a primary key is provided + columns={"sequence": {"sort": True}}, + ) + def data_resource(data): + yield data + + p = destination_config.setup_pipeline("test_sort_hint", full_refresh=True) + + # three records with same primary key + # only record with highest value in sort column is inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1}, + {"id": 1, "val": "baz", "sequence": 3}, + {"id": 1, "val": "bar", "sequence": 2}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "baz", "sequence": 3}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + p = destination_config.setup_pipeline("test_sort_hint_complex", full_refresh=True) + + # three records with same primary 
key + # only record with highest value in sort column is inserted + data = [ + {"id": 1, "val": [1, 2, 3], "sequence": 1}, + {"id": 1, "val": [7, 8, 9], "sequence": 3}, + {"id": 1, "val": [4, 5, 6], "sequence": 2}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 + + # compare observed records with expected records, now for child table + qual_name = p.sql_client().make_qualified_table_name(table_name + "__val") + observed = [row[0] for row in select_data(p, f"SELECT value FROM {qual_name}")] + assert sorted(observed) == [7, 8, 9] # type: ignore[type-var] + + data_resource.apply_hints( + columns={"sequence": {"sort": True}, "deleted": {"hard_delete": True}} + ) + + p = destination_config.setup_pipeline("test_sort_hint_with_hard_delete", full_refresh=True) + + # three records with same primary key + # record with highest value in sort column is a delete, so no record will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1, "deleted": False}, + {"id": 1, "val": "baz", "sequence": 3, "deleted": True}, + {"id": 1, "val": "bar", "sequence": 2, "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + + # three records with same primary key + # record with highest value in sort column is not a delete, so it will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1, "deleted": False}, + {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, + {"id": 1, "val": "baz", "sequence": 3, "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "baz", "sequence": 3}] + assert sorted(observed, key=lambda d: d["id"]) == expected diff --git a/tests/load/pipeline/test_replicate_disposition.py b/tests/load/pipeline/test_replicate_disposition.py deleted file mode 100644 index dbae4c1d77..0000000000 --- a/tests/load/pipeline/test_replicate_disposition.py +++ /dev/null @@ -1,287 +0,0 @@ -import pytest - -import dlt - -from dlt.common.typing import TDataItems -from dlt.common.schema.typing import TCdcConfig -from dlt.common.schema.exceptions import SchemaException -from dlt.extract import DltResource - -from tests.pipeline.utils import assert_load_info -from tests.load.pipeline.utils import load_table_counts, select_data -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - - -CDC_CONFIGS = [ - { - "operation_column": "operation", - "operation_mapper": {"insert": "I", "update": "U", "delete": "D"}, - "sequence_column": "lsn", - }, - { - "operation_column": "op", - "operation_mapper": {"insert": 1, "update": 2, "delete": 3}, - "sequence_column": "commit_id", - }, -] - - -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -@pytest.mark.parametrize("cdc_config", CDC_CONFIGS) -def test_replicate_core_functionality( - destination_config: DestinationTestConfiguration, - cdc_config: TCdcConfig, -) -> None: - op_col = cdc_config["operation_column"] - seq_col = 
cdc_config["sequence_column"] - op_map = cdc_config["operation_mapper"] - i = op_map["insert"] - u = op_map["update"] - d = op_map["delete"] - - # define batches of CDC data - batches_simple = [ - [ - {"id": 1, "val": "foo", op_col: i, seq_col: 1}, - {"id": 2, "val": "bar", op_col: i, seq_col: 2}, - ], - [ - {"id": 1, "val": "foo_new", op_col: i, seq_col: 3}, - {"id": 3, "val": "baz", op_col: i, seq_col: 4}, - ], - [ - {"id": 2, "val": "bar_new", op_col: u, seq_col: 5}, - ], - [ - {"id": 4, "val": "foo", op_col: u, seq_col: 6}, - ], - [ - {"id": 2, op_col: d, seq_col: 7}, - {"id": 2, "val": "bar_new_new", op_col: i, seq_col: 8}, - ], - [ - {"id": 5, "val": "foo", op_col: i, seq_col: 9}, - {"id": 5, "val": "foo_new", op_col: u, seq_col: 10}, - ], - [ - {"id": 6, "val": "foo", op_col: i, seq_col: 11}, - {"id": 6, op_col: d, seq_col: 12}, - ], - [ - {"id": 1, op_col: d, seq_col: 13}, - ], - ] - - table_name = "test_replicate_core_functionality" - - @dlt.resource( - table_name=table_name, - write_disposition="replicate", - primary_key="id", - cdc_config=cdc_config, - ) - def data_resource(batches: TDataItems, batch: int): - yield batches[batch] - - p = destination_config.setup_pipeline("pl_test_replicate_core_functionality", full_refresh=True) - - # insert keys in a new empty table - info = p.run(data_resource(batches_simple, batch=0)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 2 - - # insert a key that already exists (unexpected scenario) - info = p.run(data_resource(batches_simple, batch=1)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - - # update a key that already exists - info = p.run(data_resource(batches_simple, batch=2)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - - # update a key that doesn't exist yet (unexpected scenario) - info = p.run(data_resource(batches_simple, batch=3)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 4 - - # delete an existing key, then insert it again - info = p.run(data_resource(batches_simple, batch=4)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 4 - - # insert a new key, then update it - info = p.run(data_resource(batches_simple, batch=5)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 5 - - # insert a new key, then delete it - info = p.run(data_resource(batches_simple, batch=6)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 5 - - # delete an existing key - info = p.run(data_resource(batches_simple, batch=7)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 4 - - # compare observed records with expected records - qual_name = p.sql_client().make_qualified_table_name(table_name) - observed = [ - {"id": row[0], "val": row[1]} for row in select_data(p, f"SELECT id, val FROM {qual_name}") - ] - expected = [ - {"id": 2, "val": "bar_new_new"}, - {"id": 3, "val": "baz"}, - {"id": 4, "val": "foo"}, - {"id": 5, "val": "foo_new"}, - ] - assert sorted(observed, key=lambda d: d["id"]) == expected - - -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -@pytest.mark.parametrize("cdc_config", [CDC_CONFIGS[0]]) -def test_replicate_complex_data( - destination_config: DestinationTestConfiguration, - cdc_config: TCdcConfig, -) -> None: - op_col = cdc_config["operation_column"] - seq_col = 
cdc_config["sequence_column"] - op_map = cdc_config["operation_mapper"] - i = op_map["insert"] - u = op_map["update"] - d = op_map["delete"] - - # define batches of CDC data - batches_complex = [ - [ - {"id": 1, "val": ["foo", "bar"], op_col: i, seq_col: 1}, - {"id": 2, "val": ["baz"], op_col: i, seq_col: 2}, - ], - [ - {"id": 1, "val": ["foo", "bar", "baz"], op_col: u, seq_col: 3}, - {"id": 3, "val": ["foo"], op_col: i, seq_col: 4}, - ], - [ - {"id": 1, op_col: d, seq_col: 5}, - {"id": 4, "val": ["foo", "bar"], op_col: i, seq_col: 6}, - ], - ] - - table_name = "test_replicate_complex_data" - - @dlt.resource( - table_name=table_name, - write_disposition="replicate", - primary_key="id", - cdc_config=cdc_config, - ) - def data_resource(batches: TDataItems, batch: int): - yield batches[batch] - - # nesting disabled -> no child table - @dlt.source(max_table_nesting=0) - def data_nesting_disabled(batches: TDataItems, batch: int) -> DltResource: - return data_resource(batches, batch) - - p = destination_config.setup_pipeline("pl_test_replicate_complex_data", full_refresh=True) - info = p.run(data_nesting_disabled(batches_complex, batch=0)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 2 - - info = p.run(data_nesting_disabled(batches_complex, batch=1)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - - info = p.run(data_nesting_disabled(batches_complex, batch=2)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - - # nesting enabled -> child table - @dlt.source(max_table_nesting=1, root_key=True) - def data_nesting_enabled(batches: TDataItems, batch: int) -> DltResource: - return data_resource(batches, batch) - - p = destination_config.setup_pipeline("pl_test_replicate_complex_data", full_refresh=True) - - info = p.run(data_nesting_enabled(batches_complex, batch=0)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 2 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 - - info = p.run(data_nesting_enabled(batches_complex, batch=1)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 5 - - info = p.run(data_nesting_enabled(batches_complex, batch=2)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 3 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 4 - - # compare observed records with expected records - qual_name = p.sql_client().make_qualified_table_name(table_name) - observed = [row[0] for row in select_data(p, f"SELECT id FROM {qual_name}")] - expected = [2, 3, 4] - assert sorted(observed) == expected - qual_name = p.sql_client().make_qualified_table_name(table_name + "__val") - observed = [row[0] for row in select_data(p, f"SELECT value FROM {qual_name}")] - expected = ["bar", "baz", "foo", "foo"] # type: ignore[list-item] - assert sorted(observed) == expected - - -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -@pytest.mark.parametrize("cdc_config", [CDC_CONFIGS[0]]) -def test_replicate_missing_config( - destination_config: DestinationTestConfiguration, - cdc_config: TCdcConfig, -) -> None: - op_col = cdc_config["operation_column"] - seq_col = cdc_config["sequence_column"] - op_map = cdc_config["operation_mapper"] - i = op_map["insert"] - u = op_map["update"] - - # define batches of CDC 
data - batches_simple = [ - [ - {"id": 1, "val": "foo", op_col: i, seq_col: 1}, - {"id": 2, "val": "bar", op_col: i, seq_col: 2}, - ], - [ - {"id": 1, "val": "foo_new", op_col: u, seq_col: 3}, - {"id": 3, "val": "baz", op_col: i, seq_col: 4}, - ], - ] - - @dlt.resource( - table_name="test_replicate_no_pk", - write_disposition="replicate", - cdc_config=cdc_config, - ) - def data_resource_no_pk(batches: TDataItems, batch: int): - yield batches[batch] - - # a SchemaException should be raised when using "replicate" and no primary key is specified - p = destination_config.setup_pipeline("pl_test_replicate_no_pk", full_refresh=True) - with pytest.raises(SchemaException): - p.run(data_resource_no_pk(batches_simple, batch=0)) - - @dlt.resource( - table_name="test_replicate_no_cdc_config", - write_disposition="replicate", - primary_key="id", - ) - def data_resource_no_cdc_config(batches: TDataItems, batch: int): - yield batches[batch] - - # a SchemaException should be raised when using "replicate" and no "cdc_config" is specified - p = destination_config.setup_pipeline("pl_test_replicate_no_cdc_config", full_refresh=True) - with pytest.raises(SchemaException): - p.run(data_resource_no_pk(batches_simple, batch=0)) diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index d6e337a2dc..48153f7706 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -67,7 +67,7 @@ def file_storage() -> FileStorage: return FileStorage(TEST_STORAGE_ROOT, file_type="b", makedirs=True) -@pytest.mark.parametrize("write_disposition", ["append", "replace", "merge", "replicate"]) +@pytest.mark.parametrize("write_disposition", ["append", "replace", "merge"]) def test_all_data_types( client: WeaviateClient, write_disposition: TWriteDisposition, file_storage: FileStorage ) -> None: From f3a487843605bca859d38cd8186757abe719e937 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Wed, 14 Feb 2024 02:44:38 +0100 Subject: [PATCH 05/28] undo config change --- tests/.dlt/config.toml | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index 137f67eab8..1eac35d306 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -1,26 +1,3 @@ -ACTIVE_DESTINATIONS = ["postgres"] -#ACTIVE_DESTINATIONS = ["mssql", "duckdb", "postgres"] - -[load] -raise_on_max_retries = 1 - -[destination.mssql] -#merge_strategy = "cdc" - -[destination.mssql.credentials] -database = "dlt_data" -username = "loader" -host = "localhost" -port = 1433 -driver = "ODBC Driver 17 for SQL Server" - -[destination.postgres.credentials] -database = "dlt_data" -username = "loader" -host = "localhost" -port = 5432 -connect_timeout = 15 - [runtime] sentry_dsn="https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752" From deb816fa04920b2c517585f5d665dd4ab973f138 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Wed, 14 Feb 2024 02:53:24 +0100 Subject: [PATCH 06/28] undo unintentional changes --- tests/load/pipeline/test_merge_disposition.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 7e238d7871..4a7ad0f522 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -448,18 +448,17 @@ def duplicates(): counts = load_table_counts(p, "duplicates", "duplicates__child") assert 
counts["duplicates"] == 1 if destination_config.supports_merge else 2 assert counts["duplicates__child"] == 3 if destination_config.supports_merge else 6 + qual_name = p.sql_client().make_qualified_table_name("duplicates") + select_data(p, f"SELECT * FROM {qual_name}")[0] - @dlt.resource(write_disposition="merge", primary_key="id") + @dlt.resource(write_disposition="merge", primary_key=("id", "subkey")) def duplicates_no_child(): - yield [ - {"id": 1, "name": "row1"}, - {"id": 1, "name": "row2"}, - ] + yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] info = p.run(duplicates_no_child()) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") - assert counts["duplicates_no_child"] == 1 + assert counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 @pytest.mark.parametrize( From 4a38d56a156e6ad58d8d651051ef80eb06f91d16 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 16 Feb 2024 00:15:05 +0100 Subject: [PATCH 07/28] refactor hard_delete handling and introduce dedup_sort hint --- dlt/common/destination/reference.py | 58 ++++- dlt/common/schema/typing.py | 1 + dlt/common/schema/utils.py | 10 + dlt/destinations/job_client_impl.py | 13 -- dlt/destinations/sql_jobs.py | 202 ++++++++++-------- tests/load/pipeline/test_merge_disposition.py | 164 ++++++++++---- 6 files changed, 299 insertions(+), 149 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 1c28dffa8c..a2bcea0d56 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -20,7 +20,6 @@ Generic, Final, ) -from contextlib import contextmanager import datetime # noqa: 251 from copy import deepcopy import inspect @@ -32,10 +31,15 @@ UnknownDestinationModule, ) from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.schema.exceptions import InvalidDatasetName -from dlt.common.schema.utils import get_write_disposition, get_table_format -from dlt.common.configuration import configspec, with_config, resolve_configuration, known_sections +from dlt.common.schema.exceptions import SchemaException +from dlt.common.schema.utils import ( + get_write_disposition, + get_table_format, + get_columns_names_with_prop, + has_column_with_prop, + get_first_column_name_with_prop, +) +from dlt.common.configuration import configspec, resolve_configuration, known_sections from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -43,7 +47,6 @@ from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.utils import get_module_name from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults @@ -345,6 +348,49 @@ def _verify_schema(self) -> None: table_name, self.capabilities.max_identifier_length, ) + if has_column_with_prop(table, "hard_delete"): + if len(get_columns_names_with_prop(table, "hard_delete")) > 1: + raise SchemaException( + f'Found multiple "hard_delete" column hints for table "{table_name}" in' + f' schema "{self.schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "hard_delete"))}.' 
+ ) + if table.get("write_disposition") in ("replace", "append"): + logger.warning( + f"""The "hard_delete" column hint for column "{get_first_column_name_with_prop(table, 'hard_delete')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "hard_delete" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if has_column_with_prop(table, "dedup_sort"): + if len(get_columns_names_with_prop(table, "dedup_sort")) > 1: + raise SchemaException( + f'Found multiple "dedup_sort" column hints for table "{table_name}" in' + f' schema "{self.schema.name}" while only one is allowed:' + f' {", ".join(get_columns_names_with_prop(table, "dedup_sort"))}.' + ) + if table.get("write_disposition") in ("replace", "append"): + logger.warning( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when using' + ' the "merge" write disposition.' + ) + if table.get("write_disposition") == "merge" and not has_column_with_prop( + table, "primary_key" + ): + logger.warning( + f"""The "dedup_sort" column hint for column "{get_first_column_name_with_prop(table, 'dedup_sort')}" """ + f'in table "{table_name}" with write disposition' + f' "{table.get("write_disposition")}"' + f' in schema "{self.schema.name}" will be ignored.' + ' The "dedup_sort" column hint is only applied when a' + " primary key has been specified." + ) for column_name, column in dict(table["columns"]).items(): if len(column_name) > self.capabilities.max_column_identifier_length: raise IdentifierTooLongException( diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 65eb0afff1..da6c46526e 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -113,6 +113,7 @@ class TColumnSchema(TColumnSchemaBase, total=False): merge_key: Optional[bool] variant: Optional[bool] hard_delete: Optional[bool] + dedup_sort: Optional[bool] TTableSchemaColumns = Dict[str, TColumnSchema] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index b3323a3673..7ce82d7c88 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -573,6 +573,16 @@ def get_columns_names_with_prop( ] +def get_first_column_name_with_prop( + table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False +) -> Optional[str]: + """Returns name of first column in `table` schema with property `column_prop` or None if no such column exists.""" + column_names = get_columns_names_with_prop(table, column_prop, include_incomplete) + if len(column_names) > 0: + return column_names[0] + return None + + def has_column_with_prop( table: TTableSchema, column_prop: Union[TColumnProp, str], include_incomplete: bool = False ) -> bool: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 649ad8c6e2..e7dc4bcbe2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -35,7 +35,6 @@ ) from dlt.common.storages import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables -from dlt.common.schema.utils import get_columns_names_with_prop, has_column_with_prop from dlt.common.destination.reference import ( StateInfo, 
StorageSchemaInfo, @@ -589,15 +588,3 @@ def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: ): return True return False - - def _create_table_update( - self, table_name: str, storage_columns: TTableSchemaColumns - ) -> Sequence[TColumnSchema]: - updates = super()._create_table_update(table_name, storage_columns) - table = self.schema.get_table(table_name) - if has_column_with_prop(table, "hard_delete"): - # hard_delete column should only be present in staging table, not in final table - if not self.in_staging_mode: - hard_delete_column = get_columns_names_with_prop(table, "hard_delete")[0] - updates = [d for d in updates if d["name"] != hard_delete_column] - return updates diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 821dac3183..2e62af955c 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,10 +1,11 @@ from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional +from copy import copy import yaml from dlt.common.runtime.logger import pretty_format_exception from dlt.common.schema.typing import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop, has_column_with_prop +from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import uniq_id from dlt.destinations.exceptions import MergeDispositionException @@ -202,19 +203,58 @@ def gen_delete_temp_table_sql( sql.append(f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};") return sql, temp_table_name + @classmethod + def gen_select_from_dedup_sql( + cls, + table_name: str, + primary_keys: Sequence[str], + columns: Sequence[str], + sort_column: str = None, + condition: str = None, + condition_columns: Sequence[str] = None, + ) -> str: + """ + Returns SELECT FROM statement where the FROM clause represents a deduplicated version of the `table_name` table. + Expects column names provided in arguments to be escaped identifiers. 
+ """ + if sort_column is None: + order_by = "(SELECT NULL)" + else: + order_by = f"{sort_column} DESC" + if condition is None: + condition = "1 = 1" + col_str = ", ".join(columns) + inner_col_str = copy(col_str) + if condition_columns is not None: + inner_col_str += ", " + ", ".join(condition_columns) + return f""" + SELECT {col_str} + FROM ( + SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY {order_by}) AS _dlt_dedup_rn, {inner_col_str} + FROM {table_name} + ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 AND {condition} + """ + @classmethod def gen_insert_temp_table_sql( - cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str + cls, + staging_root_table_name: str, + primary_keys: Sequence[str], + unique_column: str, + sort_column: str = None, + condition: str = None, + condition_columns: Sequence[str] = None, ) -> Tuple[List[str], str]: temp_table_name = cls._new_temp_table_name("insert") - select_statement = f""" - SELECT {unique_column} - FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {unique_column} - FROM {staging_root_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 - """ - return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name + select_sql = cls.gen_select_from_dedup_sql( + staging_root_table_name, + primary_keys, + [unique_column], + sort_column, + condition, + condition_columns, + ) + return [cls._to_temp_table(select_sql, temp_table_name)], temp_table_name @classmethod def gen_delete_from_sql( @@ -254,8 +294,9 @@ def gen_merge_sql( ) -> List[str]: sql: List[str] = [] root_table = table_chain[0] - escape_identifier = sql_client.capabilities.escape_identifier - escape_literal = sql_client.capabilities.escape_literal + + escape_id = sql_client.capabilities.escape_identifier + escape_lit = sql_client.capabilities.escape_literal # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) @@ -264,13 +305,13 @@ def gen_merge_sql( # get merge and primary keys from top level primary_keys = list( map( - escape_identifier, + escape_id, get_columns_names_with_prop(root_table, "primary_key"), ) ) merge_keys = list( map( - escape_identifier, + escape_id, get_columns_names_with_prop(root_table, "merge_key"), ) ) @@ -278,7 +319,6 @@ def gen_merge_sql( unique_column: str = None root_key_column: str = None - insert_temp_table_name: str = None if len(table_chain) == 1: key_table_clauses = cls.gen_key_table_clauses( @@ -302,7 +342,7 @@ def gen_merge_sql( " it is not possible to link child tables to it.", ) # get first unique column - unique_column = escape_identifier(unique_columns[0]) + unique_column = escape_id(unique_columns[0]) # create temp table with unique identifier create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql( unique_column, key_table_clauses @@ -323,7 +363,7 @@ def gen_merge_sql( f" {table['name']} so it is not possible to refer to top level table" f" {root_table['name']} unique column {unique_column}", ) - root_key_column = escape_identifier(root_key_columns[0]) + root_key_column = escape_id(root_key_columns[0]) sql.append( cls.gen_delete_from_sql( table_name, root_key_column, delete_temp_table_name, unique_column @@ -337,88 +377,68 @@ def gen_merge_sql( ) ) - # remove "non-latest" records from staging table (deduplicate) if a sort column is provided - if len(primary_keys) > 0: - if has_column_with_prop(root_table, "sort"): - 
sort_column = escape_identifier(get_columns_names_with_prop(root_table, "sort")[0]) - sql.append(f""" - DELETE FROM {staging_root_table_name} - WHERE {sort_column} IN ( - SELECT {sort_column} FROM ( - SELECT {sort_column}, ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY {sort_column} DESC) AS _rn - FROM {staging_root_table_name} - ) AS a - WHERE a._rn > 1 - ); - """) - - # remove deleted records from staging tables if a hard_delete column is provided - if has_column_with_prop(root_table, "hard_delete"): - hard_delete_column = escape_identifier( - get_columns_names_with_prop(root_table, "hard_delete")[0] - ) - # first delete from root staging table - sql.append(f""" - DELETE FROM {staging_root_table_name} - WHERE {hard_delete_column} IS NOT DISTINCT FROM {escape_literal(True)}; - """) - # then delete from child staging tables - for table in table_chain[1:]: - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) - sql.append(f""" - DELETE FROM {staging_table_name} - WHERE NOT EXISTS ( - SELECT 1 FROM {staging_root_table_name} AS p - WHERE {staging_table_name}.{root_key_column} = p.{unique_column} - ); - """) - - if len(table_chain) > 1: - # create temp table used to deduplicate, only when we have primary keys - if primary_keys: - ( - create_insert_temp_table_sql, - insert_temp_table_name, - ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, primary_keys, unique_column + # get name of column with hard_delete hint, if specified + not_deleted_cond: str = None + hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") + if hard_delete_col is not None: + if root_table["columns"][hard_delete_col]["data_type"] == "bool": + # only True values indicate a delete for boolean column + not_deleted_cond = ( + f"{escape_id(hard_delete_col)} IS DISTINCT FROM {escape_lit(True)}" ) - sql.extend(create_insert_temp_table_sql) + else: + # any value indicates a delete for non-boolean columns + not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" + + # get name of column with dedup_sort hint, if specified + dedup_sort_col = get_first_column_name_with_prop(root_table, "dedup_sort") + + # create temp table used to deduplicate, only when we have primary keys and child tables + insert_temp_table_name: str = None + if len(primary_keys) > 0 and len(table_chain) > 1: + condition_colummns = [hard_delete_col] if not_deleted_cond is not None else None + ( + create_insert_temp_table_sql, + insert_temp_table_name, + ) = cls.gen_insert_temp_table_sql( + staging_root_table_name, + primary_keys, + unique_column, + dedup_sort_col, + not_deleted_cond, + condition_colummns, + ) + sql.extend(create_insert_temp_table_sql) # insert from staging to dataset for table in table_chain: table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = ", ".join( - map( - escape_identifier, - [ - c - for c in get_columns_names_with_prop(table, "name") - if c not in get_columns_names_with_prop(table, "hard_delete") - ], - ) - ) - insert_sql = ( - f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}" - ) - if len(primary_keys) > 0: - if len(table_chain) == 1: - insert_sql = f"""INSERT INTO {table_name}({columns}) - SELECT {columns} FROM ( - SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {columns} 
- FROM {staging_table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1; - """ - else: - uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_sql += ( - f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_name});" + columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) + col_str = ", ".join(columns) + insert_cond = "1 = 1" + if hard_delete_col is not None: + # exlude deleted records from INSERT statement + insert_cond = copy(not_deleted_cond) + if table.get("parent") is not None: + insert_cond = ( + f"{root_key_column} IN (SELECT {root_key_column}" + f" FROM {staging_root_table_name} WHERE {not_deleted_cond});" ) - if insert_sql.strip()[-1] != ";": - insert_sql += ";" - sql.append(insert_sql) + if len(primary_keys) > 0 and len(table_chain) > 1: + # deduplicate using temp table + uniq_column = unique_column if table.get("parent") is None else root_key_column + insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" + + # select_sql = f"SELECT {columns} FROM {staging_table_name} WHERE {insert_cond}" + select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" + if len(primary_keys) > 0 and len(table_chain) == 1: + # without child tables we deduplicate inside the query instead of using a temp table + select_sql = cls.gen_select_from_dedup_sql( + staging_table_name, primary_keys, columns, dedup_sort_col, insert_cond + ) + sql.append(f"INSERT INTO {table_name}({col_str}) {select_sql};") return sql diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 4a7ad0f522..ee6585763e 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -1,6 +1,5 @@ from copy import copy import pytest -import itertools import random from typing import List import pytest @@ -11,10 +10,11 @@ from dlt.common import json, pendulum from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext -from dlt.common.typing import AnyFun, StrAny +from dlt.common.typing import StrAny from dlt.common.utils import digest128 from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first +from dlt.pipeline.exceptions import PipelineStepFailed from tests.pipeline.utils import assert_load_info from tests.load.pipeline.utils import load_table_counts, select_data @@ -495,7 +495,7 @@ def duplicates_no_child(): ) @pytest.mark.parametrize("key_type", ["primary_key", "merge_key"]) def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: - table_name = f"test_hard_delete_hint_{key_type}" + table_name = "test_hard_delete_hint" @dlt.resource( name=table_name, @@ -510,7 +510,7 @@ def data_resource(data): elif key_type == "merge_key": data_resource.apply_hints(primary_key="", merge_key="id") - p = destination_config.setup_pipeline(f"test_hard_delete_hint_{key_type}", full_refresh=True) + p = destination_config.setup_pipeline(f"abstract_{key_type}", full_refresh=True) # insert two records data = [ @@ -540,77 +540,150 @@ def data_resource(data): # compare observed records with expected records qual_name = p.sql_client().make_qualified_table_name(table_name) observed = [ - {"id": row[0], "val": row[1]} for row in select_data(p, f"SELECT id, val FROM {qual_name}") + {"id": row[0], "val": row[1], "deleted": row[2]} + for row in select_data(p, f"SELECT id, val, deleted FROM {qual_name}") ] - expected = 
[{"id": 2, "val": "baz"}] + expected = [{"id": 2, "val": "baz", "deleted": None}] assert sorted(observed, key=lambda d: d["id"]) == expected - # insert and delete the same record, record will be inserted if no sort column is provided + table_name = "test_hard_delete_hint_complex" + data_resource.apply_hints(table_name=table_name) + + # insert two records with childs and grandchilds data = [ - {"id": 3, "val": "foo", "deleted": None}, - {"id": 3, "deleted": True}, - # {"id": 4, "deleted": True}, - # {"id": 4, "val": "foo", "deleted": None}, + { + "id": 1, + "child_1": ["foo", "bar"], + "child_2": [ + {"grandchild_1": ["foo", "bar"], "grandchild_2": True}, + {"grandchild_1": ["bar", "baz"], "grandchild_2": False}, + ], + "deleted": False, + }, + { + "id": 2, + "child_1": ["baz"], + "child_2": [{"grandchild_1": ["baz"], "grandchild_2": True}], + "deleted": False, + }, ] info = p.run(data_resource(data)) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 3 + assert load_table_counts(p, table_name + "__child_2")[table_name + "__child_2"] == 3 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 5 + ) - with p.destination_client(p.default_schema_name) as client: - _, table = client.get_storage_table(table_name) # type: ignore[attr-defined] - column_names = table.keys() - # ensure hard_delete column is not propagated to final table - assert "deleted" not in column_names + # delete first record + data = [ + {"id": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 1 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 1 + ) - p = destination_config.setup_pipeline( - f"test_hard_delete_hint_{key_type}_complex", full_refresh=True + # delete second record + data = [ + {"id": 2, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 0 + assert ( + load_table_counts(p, table_name + "__child_2__grandchild_1")[ + table_name + "__child_2__grandchild_1" + ] + == 0 + ) + + +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_hard_delete_hint_non_bool" + + @dlt.resource( + name=table_name, + write_disposition="merge", + primary_key="id", + columns={ + "deleted_timestamp": {"data_type": "timestamp", "nullable": True, "hard_delete": True} + }, ) + def data_resource(data): + yield data + + p = destination_config.setup_pipeline("abstract", full_refresh=True) # insert two records data = [ - {"id": 1, "val": ["foo", "bar"], "deleted": False}, - {"id": 2, "val": ["baz"], "deleted": False}, + {"id": 1, "val": "foo", "deleted_timestamp": None}, + {"id": 2, "val": "bar", "deleted_timestamp": None}, ] info = p.run(data_resource(data)) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 - # delete one record, providing only the key + # 
delete one record data = [ - {"id": 1, "deleted": True}, + {"id": 1, "deleted_timestamp": "2024-02-15T17:16:53Z"}, ] info = p.run(data_resource(data)) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 1 - # delete one record, providing the full record - data = [ - {"id": 2, "val": ["foo", "bar"], "deleted": True}, + # compare observed records with expected records + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "deleted_timestamp": row[2]} + for row in select_data(p, f"SELECT id, val, deleted_timestamp FROM {qual_name}") ] - info = p.run(data_resource(data)) - assert_load_info(info) - assert load_table_counts(p, table_name)[table_name] == 0 - assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 0 + expected = [{"id": 2, "val": "bar", "deleted_timestamp": None}] + assert sorted(observed, key=lambda d: d["id"]) == expected + + # test if exception is raised when more than one "hard_delete" column hints are provided + @dlt.resource( + name="test_hard_delete_hint_too_many_hints", + write_disposition="merge", + columns={"deleted_1": {"hard_delete": True}, "deleted_2": {"hard_delete": True}}, + ) + def r(): + yield {"id": 1, "val": "foo", "deleted_1": True, "deleted_2": False} + + with pytest.raises(PipelineStepFailed): + info = p.run(r()) @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name ) -def test_sort_hint(destination_config: DestinationTestConfiguration) -> None: - table_name = "test_sort_hint" +def test_dedup_sort_hint(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_dedup_sort_hint" @dlt.resource( name=table_name, write_disposition="merge", primary_key="id", # sort hints only have effect when a primary key is provided - columns={"sequence": {"sort": True}}, + columns={"sequence": {"dedup_sort": True}}, ) def data_resource(data): yield data - p = destination_config.setup_pipeline("test_sort_hint", full_refresh=True) + p = destination_config.setup_pipeline("abstract", full_refresh=True) # three records with same primary key # only record with highest value in sort column is inserted @@ -632,7 +705,8 @@ def data_resource(data): expected = [{"id": 1, "val": "baz", "sequence": 3}] assert sorted(observed, key=lambda d: d["id"]) == expected - p = destination_config.setup_pipeline("test_sort_hint_complex", full_refresh=True) + table_name = "test_dedup_sort_hint_complex" + data_resource.apply_hints(table_name=table_name) # three records with same primary key # only record with highest value in sort column is inserted @@ -651,12 +725,12 @@ def data_resource(data): observed = [row[0] for row in select_data(p, f"SELECT value FROM {qual_name}")] assert sorted(observed) == [7, 8, 9] # type: ignore[type-var] + table_name = "test_dedup_sort_hint_with_hard_delete" data_resource.apply_hints( - columns={"sequence": {"sort": True}, "deleted": {"hard_delete": True}} + table_name=table_name, + columns={"sequence": {"dedup_sort": True}, "deleted": {"hard_delete": True}}, ) - p = destination_config.setup_pipeline("test_sort_hint_with_hard_delete", full_refresh=True) - # three records with same primary key # record with highest value in sort column is a delete, so no record will be inserted data = [ @@ -687,3 +761,15 @@ def data_resource(data): ] expected = [{"id": 1, "val": "baz", "sequence": 3}] assert sorted(observed, 
key=lambda d: d["id"]) == expected + + # test if exception is raised when more than one "dedup_sort" column hints are provided + @dlt.resource( + name="test_dedup_sort_hint_too_many_hints", + write_disposition="merge", + columns={"dedup_sort_1": {"dedup_sort": True}, "dedup_sort_2": {"dedup_sort": True}}, + ) + def r(): + yield {"id": 1, "val": "foo", "dedup_sort_1": 1, "dedup_sort_2": 5} + + with pytest.raises(PipelineStepFailed): + info = p.run(r()) From 0d1c97735153e64515c1ef308176edce6c03003c Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 16 Feb 2024 00:20:36 +0100 Subject: [PATCH 08/28] update docstring --- dlt/destinations/sql_jobs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 2e62af955c..8d8e692db2 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -149,7 +149,8 @@ def generate_sql( First we store the root_keys of root table elements to be deleted in the temp table. Then we use the temp table to delete records from root and all child tables in the destination dataset. At the end we copy the data from the staging dataset into destination dataset. - If sort and/or hard_delete column hints are provided, records are deleted from the staging dataset before its data is copied to the destination dataset. + If a hard_delete column is specified, records flagged as deleted will be excluded from the copy into the destination dataset. + If a dedup_sort column is specified in conjunction with a primary key, records will be sorted before deduplication, so the "latest" record remains. """ return cls.gen_merge_sql(table_chain, sql_client) From 474d8bc9e632f99784af78dd98dfa6e21f004da7 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 16 Feb 2024 15:29:20 +0100 Subject: [PATCH 09/28] replace dialect-specific SQL --- dlt/destinations/sql_jobs.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 8d8e692db2..1b51104cba 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -382,14 +382,11 @@ def gen_merge_sql( not_deleted_cond: str = None hard_delete_col = get_first_column_name_with_prop(root_table, "hard_delete") if hard_delete_col is not None: + # any value indicates a delete for non-boolean columns + not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" if root_table["columns"][hard_delete_col]["data_type"] == "bool": - # only True values indicate a delete for boolean column - not_deleted_cond = ( - f"{escape_id(hard_delete_col)} IS DISTINCT FROM {escape_lit(True)}" - ) - else: - # any value indicates a delete for non-boolean columns - not_deleted_cond = f"{escape_id(hard_delete_col)} IS NULL" + # only True values indicate a delete for boolean columns + not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" # get name of column with dedup_sort hint, if specified dedup_sort_col = get_first_column_name_with_prop(root_table, "dedup_sort") From 568ef2659d3e40d39e78dc0300b2f1e532ef5db1 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 16 Feb 2024 17:25:10 +0100 Subject: [PATCH 10/28] add parentheses to ensure proper clause evaluation order --- dlt/destinations/sql_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 1b51104cba..b7a4f6766d 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -233,7 
+233,7 @@ def gen_select_from_dedup_sql( FROM ( SELECT ROW_NUMBER() OVER (partition BY {", ".join(primary_keys)} ORDER BY {order_by}) AS _dlt_dedup_rn, {inner_col_str} FROM {table_name} - ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 AND {condition} + ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 AND ({condition}) """ @classmethod From 81ea42691aa5100a489a7b1a13cbb5d2c6e8a0a9 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 16 Feb 2024 23:42:44 +0100 Subject: [PATCH 11/28] add escape defaults and temp tables for non-primary key case --- dlt/destinations/sql_jobs.py | 78 +++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index b7a4f6766d..f1e4c12f71 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -8,6 +8,7 @@ from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import uniq_id +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase @@ -247,14 +248,19 @@ def gen_insert_temp_table_sql( condition_columns: Sequence[str] = None, ) -> Tuple[List[str], str]: temp_table_name = cls._new_temp_table_name("insert") - select_sql = cls.gen_select_from_dedup_sql( - staging_root_table_name, - primary_keys, - [unique_column], - sort_column, - condition, - condition_columns, - ) + if len(primary_keys) > 0: + # deduplicate + select_sql = cls.gen_select_from_dedup_sql( + staging_root_table_name, + primary_keys, + [unique_column], + sort_column, + condition, + condition_columns, + ) + else: + # don't deduplicate + select_sql = f"SELECT {unique_column} FROM {staging_root_table_name} WHERE {condition}" return [cls._to_temp_table(select_sql, temp_table_name)], temp_table_name @classmethod @@ -298,6 +304,10 @@ def gen_merge_sql( escape_id = sql_client.capabilities.escape_identifier escape_lit = sql_client.capabilities.escape_literal + if escape_id is None: + escape_id = DestinationCapabilitiesContext.generic_capabilities().escape_identifier + if escape_lit is None: + escape_lit = DestinationCapabilitiesContext.generic_capabilities().escape_literal # get top level table full identifiers root_table_name = sql_client.make_qualified_table_name(root_table["name"]) @@ -391,46 +401,40 @@ def gen_merge_sql( # get name of column with dedup_sort hint, if specified dedup_sort_col = get_first_column_name_with_prop(root_table, "dedup_sort") - # create temp table used to deduplicate, only when we have primary keys and child tables insert_temp_table_name: str = None - if len(primary_keys) > 0 and len(table_chain) > 1: - condition_colummns = [hard_delete_col] if not_deleted_cond is not None else None - ( - create_insert_temp_table_sql, - insert_temp_table_name, - ) = cls.gen_insert_temp_table_sql( - staging_root_table_name, - primary_keys, - unique_column, - dedup_sort_col, - not_deleted_cond, - condition_colummns, - ) - sql.extend(create_insert_temp_table_sql) + if len(table_chain) > 1: + if len(primary_keys) > 0 or (len(primary_keys) == 0 and hard_delete_col is not None): + condition_colummns = [hard_delete_col] if not_deleted_cond is not None else None + ( + create_insert_temp_table_sql, + insert_temp_table_name, + ) = cls.gen_insert_temp_table_sql( + 
staging_root_table_name, + primary_keys, + unique_column, + dedup_sort_col, + not_deleted_cond, + condition_colummns, + ) + sql.extend(create_insert_temp_table_sql) # insert from staging to dataset for table in table_chain: table_name = sql_client.make_qualified_table_name(table["name"]) with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) - col_str = ", ".join(columns) - insert_cond = "1 = 1" - if hard_delete_col is not None: - # exlude deleted records from INSERT statement - insert_cond = copy(not_deleted_cond) - if table.get("parent") is not None: - insert_cond = ( - f"{root_key_column} IN (SELECT {root_key_column}" - f" FROM {staging_root_table_name} WHERE {not_deleted_cond});" - ) - if len(primary_keys) > 0 and len(table_chain) > 1: - # deduplicate using temp table + insert_cond = copy(not_deleted_cond) if hard_delete_col is not None else "1 = 1" + if (len(primary_keys) > 0 and len(table_chain) > 1) or ( + len(primary_keys) == 0 + and table.get("parent") is not None # child table + and hard_delete_col is not None + ): uniq_column = unique_column if table.get("parent") is None else root_key_column insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" - # select_sql = f"SELECT {columns} FROM {staging_table_name} WHERE {insert_cond}" + columns = list(map(escape_id, get_columns_names_with_prop(table, "name"))) + col_str = ", ".join(columns) select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" if len(primary_keys) > 0 and len(table_chain) == 1: # without child tables we deduplicate inside the query instead of using a temp table From a04a238e721994d0e01b3d4f943febc22e693345 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 17 Feb 2024 11:35:05 +0100 Subject: [PATCH 12/28] exclude destinations that don't support merge from test --- tests/load/pipeline/test_merge_disposition.py | 12 +++++++++--- tests/load/utils.py | 5 +++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index ee6585763e..888ebe3d97 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -491,7 +491,9 @@ def duplicates_no_child(): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) @pytest.mark.parametrize("key_type", ["primary_key", "merge_key"]) def test_hard_delete_hint(destination_config: DestinationTestConfiguration, key_type: str) -> None: @@ -611,7 +613,9 @@ def data_resource(data): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) def test_hard_delete_hint_config(destination_config: DestinationTestConfiguration) -> None: table_name = "test_hard_delete_hint_non_bool" @@ -669,7 +673,9 @@ def r(): @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) def test_dedup_sort_hint(destination_config: 
DestinationTestConfiguration) -> None: table_name = "test_dedup_sort_hint" diff --git a/tests/load/utils.py b/tests/load/utils.py index 31a187e13e..b20fdd8145 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -153,6 +153,7 @@ def destinations_configs( subset: Sequence[str] = (), exclude: Sequence[str] = (), file_format: Optional[TLoaderFileFormat] = None, + supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: # sanity check @@ -372,6 +373,10 @@ def destinations_configs( destination_configs = [ conf for conf in destination_configs if conf.file_format == file_format ] + if supports_merge is not None: + destination_configs = [ + conf for conf in destination_configs if conf.supports_merge == supports_merge + ] if supports_dbt is not None: destination_configs = [ conf for conf in destination_configs if conf.supports_dbt == supports_dbt From 8ac0f9cc0efdb523531e47b1d7a1cc32e99938dd Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 20 Feb 2024 10:19:56 +0100 Subject: [PATCH 13/28] correct typo --- dlt/destinations/sql_jobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index f1e4c12f71..13d9f7c784 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -404,7 +404,7 @@ def gen_merge_sql( insert_temp_table_name: str = None if len(table_chain) > 1: if len(primary_keys) > 0 or (len(primary_keys) == 0 and hard_delete_col is not None): - condition_colummns = [hard_delete_col] if not_deleted_cond is not None else None + condition_columns = [hard_delete_col] if not_deleted_cond is not None else None ( create_insert_temp_table_sql, insert_temp_table_name, @@ -414,7 +414,7 @@ def gen_merge_sql( unique_column, dedup_sort_col, not_deleted_cond, - condition_colummns, + condition_columns, ) sql.extend(create_insert_temp_table_sql) From ec115e9726d01bdebcf94592660175ba0d472b5b Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 20 Feb 2024 12:20:10 +0100 Subject: [PATCH 14/28] extend docstring --- dlt/destinations/sql_jobs.py | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 13d9f7c784..fd02c1a7d7 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -215,9 +215,41 @@ def gen_select_from_dedup_sql( condition: str = None, condition_columns: Sequence[str] = None, ) -> str: - """ - Returns SELECT FROM statement where the FROM clause represents a deduplicated version of the `table_name` table. + """Returns SELECT FROM SQL statement. + + The FROM clause in the SQL statement represents a deduplicated version + of the `table_name` table. + Expects column names provided in arguments to be escaped identifiers. + + Args: + table_name: Name of the table that is selected from. + primary_keys: A sequence of column names representing the primary + key of the table. Is used to deduplicate the table. + columns: Sequence of column names that will be selected from + the table. + sort_column: Name of a column to sort the records by within a + primary key. Values in the column are sorted in descending order, + so the record with the highest value in `sort_column` remains + after deduplication. No sorting is done if a None value is provided, + leading to arbitrary deduplication. + condition: String used as a WHERE clause in the SQL statement to + filter records. 
The names of all columns that are used in the + condition must be provided in the `condition_columns` argument. + No filtering is done (aside from the deduplication) if a None value + is provided. + condition_columns: Sequence of names of columns used in the `condition` + argument. These column names will be selected in the inner subquery + to make them accessible to the outer WHERE clause. This argument + should only be used in combination with the `condition` argument. + + Returns: + A string representing a SELECT FROM SQL statement where the FROM + clause represents a deduplicated version of the `table_name` table. + + The returned value is used in two ways: + 1) To select the values for an INSERT INTO statement. + 2) To select the values for a temporary table used for inserts. """ if sort_column is None: order_by = "(SELECT NULL)" From a1afeb858131825a0ba610cb225868fc3525b884 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 20 Feb 2024 12:35:07 +0100 Subject: [PATCH 15/28] remove redundant copies for immutable strings --- dlt/destinations/sql_jobs.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index fd02c1a7d7..f86b82085a 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,5 +1,4 @@ from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional -from copy import copy import yaml from dlt.common.runtime.logger import pretty_format_exception @@ -234,10 +233,10 @@ def gen_select_from_dedup_sql( after deduplication. No sorting is done if a None value is provided, leading to arbitrary deduplication. condition: String used as a WHERE clause in the SQL statement to - filter records. The names of all columns that are used in the - condition must be provided in the `condition_columns` argument. - No filtering is done (aside from the deduplication) if a None value - is provided. + filter records. The name of any column that is used in the + condition but is not part of `columns` must be provided in the + `condition_columns` argument. No filtering is done (aside from the + deduplication) if a None value is provided. condition_columns: Sequence of names of columns used in the `condition` argument. These column names will be selected in the inner subquery to make them accessible to the outer WHERE clause. 
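For concreteness, an approximate rendering of the statement this template produces, using illustrative identifiers: primary key id, selected column _dlt_id, dedup_sort ("sequence", "desc"), and a condition built from a boolean hard-delete column deleted. Exact identifier quoting and literal escaping depend on the destination capabilities.

    SELECT _dlt_id
    FROM (
        SELECT ROW_NUMBER() OVER (partition BY id ORDER BY sequence DESC) AS _dlt_dedup_rn, _dlt_id, deleted
        FROM staging_dataset.my_table
    ) AS _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1 AND (deleted IS NULL OR deleted = FALSE)
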
This argument @@ -258,7 +257,7 @@ def gen_select_from_dedup_sql( if condition is None: condition = "1 = 1" col_str = ", ".join(columns) - inner_col_str = copy(col_str) + inner_col_str = col_str if condition_columns is not None: inner_col_str += ", " + ", ".join(condition_columns) return f""" @@ -456,7 +455,7 @@ def gen_merge_sql( with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) - insert_cond = copy(not_deleted_cond) if hard_delete_col is not None else "1 = 1" + insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" if (len(primary_keys) > 0 and len(table_chain) > 1) or ( len(primary_keys) == 0 and table.get("parent") is not None # child table From f07205dd9a3c15abe952d5d2fb474a22ca274677 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 20 Feb 2024 12:43:35 +0100 Subject: [PATCH 16/28] simplify boolean logic --- dlt/destinations/sql_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index f86b82085a..4458b56175 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -434,7 +434,7 @@ def gen_merge_sql( insert_temp_table_name: str = None if len(table_chain) > 1: - if len(primary_keys) > 0 or (len(primary_keys) == 0 and hard_delete_col is not None): + if len(primary_keys) > 0 or hard_delete_col is not None: condition_columns = [hard_delete_col] if not_deleted_cond is not None else None ( create_insert_temp_table_sql, From a64580d3a6a988f2395ce52b35cb94eed7f5c4af Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 20 Feb 2024 13:46:52 +0100 Subject: [PATCH 17/28] add more test cases for hard_delete and dedup_sort hints --- tests/load/pipeline/test_merge_disposition.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 888ebe3d97..03fccfa8ee 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -548,6 +548,28 @@ def data_resource(data): expected = [{"id": 2, "val": "baz", "deleted": None}] assert sorted(observed, key=lambda d: d["id"]) == expected + # insert two records with same key + data = [ + {"id": 3, "val": "foo", "deleted": False}, + {"id": 3, "val": "bar", "deleted": False}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + counts = load_table_counts(p, table_name)[table_name] + if key_type == "primary_key": + assert counts == 2 + elif key_type == "merge_key": + assert counts == 3 + + # delete one key, resulting in one (primary key) or two (merge key) deleted records + data = [ + {"id": 3, "val": "foo", "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + counts = load_table_counts(p, table_name)[table_name] + assert load_table_counts(p, table_name)[table_name] == 1 + table_name = "test_hard_delete_hint_complex" data_resource.apply_hints(table_name=table_name) @@ -768,6 +790,29 @@ def data_resource(data): expected = [{"id": 1, "val": "baz", "sequence": 3}] assert sorted(observed, key=lambda d: d["id"]) == expected + # additional tests with two records, run only on duckdb to limit test load + if destination_config.destination == "duckdb": + # two records with same primary key + # record with highest value in sort column is a delete + # existing record is deleted and no record will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 1}, + 
{"id": 1, "val": "bar", "sequence": 2, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 0 + + # two records with same primary key + # record with highest value in sort column is not a delete, so it will be inserted + data = [ + {"id": 1, "val": "foo", "sequence": 2}, + {"id": 1, "val": "bar", "sequence": 1, "deleted": True}, + ] + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + # test if exception is raised when more than one "dedup_sort" column hints are provided @dlt.resource( name="test_dedup_sort_hint_too_many_hints", From 3308549602f4053ca2f855d90ac7048fb794fc3e Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Wed, 21 Feb 2024 02:36:03 +0100 Subject: [PATCH 18/28] refactor table chain resolution --- dlt/common/utils.py | 8 +++ dlt/load/load.py | 67 ++++++++++--------- tests/load/pipeline/test_merge_disposition.py | 28 ++++++++ 3 files changed, 70 insertions(+), 33 deletions(-) diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 72fee608a8..49a425780b 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -575,3 +575,11 @@ def get_exception_trace_chain( elif exc.__context__: return get_exception_trace_chain(exc.__context__, traces, seen) return traces + + +def order_deduped(lst: List[Any]) -> List[Any]: + """Returns deduplicated list preserving order of input elements. + + Only works for lists with hashable elements. + """ + return list(dict.fromkeys(lst)) diff --git a/dlt/load/load.py b/dlt/load/load.py index 0ecdd1f13b..a7764d164b 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -6,6 +6,7 @@ import os from dlt.common import sleep, logger +from dlt.common.utils import order_deduped from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo @@ -233,12 +234,13 @@ def get_completed_table_chain( Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage """ - # returns ordered list of tables from parent to child leaf tables table_chain: List[TTableSchema] = [] - # make sure all the jobs for the table chain is completed - for table in get_child_tables(schema.tables, top_merged_table["name"]): + # returns ordered list of tables from parent to child leaf tables + for table_name in self._get_table_chain_tables_with_filter( + schema, [top_merged_table["name"]] + ): table_jobs = self.load_storage.normalized_packages.list_jobs_for_table( - load_id, table["name"] + load_id, table_name ) # all jobs must be completed in order for merge to be created if any( @@ -247,17 +249,8 @@ def get_completed_table_chain( for job in table_jobs ): return None - # if there are no jobs for the table, skip it, unless child tables need to be replaced - needs_replacement = False - if top_merged_table["write_disposition"] == "replace" or ( - top_merged_table["write_disposition"] == "merge" - and has_column_with_prop(top_merged_table, "hard_delete") - ): - needs_replacement = True - if not table_jobs and not needs_replacement: - continue - table_chain.append(table) - # there must be at least table + table_chain.append(schema.tables[table_name]) + # there must be at least one table assert len(table_chain) > 0 return table_chain @@ -362,27 +355,35 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) 
f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) - @staticmethod def _get_table_chain_tables_with_filter( - schema: Schema, f: Callable[[TTableSchema], bool], tables_with_jobs: Iterable[str] - ) -> Set[str]: - """Get all jobs for tables with given write disposition and resolve the table chain""" - result: Set[str] = set() + self, + schema: Schema, + tables_with_jobs: Iterable[str], + f: Callable[[TTableSchema], bool] = lambda t: True, + ) -> List[str]: + """Get all jobs for tables with given write disposition and resolve the table chain. + + Returns a list of table names ordered by ancestry so the child tables are always after their parents. + """ + result: List[str] = [] for table_name in tables_with_jobs: top_job_table = get_top_level_table(schema.tables, table_name) if not f(top_job_table): continue + # for replace and merge write dispositions we should include tables + # without jobs in the table chain, because child tables may need + # processing due to changes in the root table + skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") for table in get_child_tables(schema.tables, top_job_table["name"]): - # only add tables for tables that have jobs unless the disposition is replace - # TODO: this is a (formerly used) hack to make test_merge_on_keys_in_schema, - # we should change that test - if ( - not table["name"] in tables_with_jobs - and top_job_table["write_disposition"] != "replace" - ): - continue - result.add(table["name"]) - return result + with self.get_destination_client(schema) as job_client: + table_has_job = table["name"] in tables_with_jobs + table_exists = True + if hasattr(job_client, "get_storage_table"): + table_exists = job_client.get_storage_table(table["name"])[0] + if (not table_has_job and skip_jobless_table) or not table_exists: + continue + result.append(table["name"]) + return order_deduped(result) @staticmethod def _init_dataset_and_update_schema( @@ -425,7 +426,7 @@ def _init_client( # update the default dataset truncate_tables = self._get_table_chain_tables_with_filter( - schema, truncate_filter, tables_with_jobs + schema, tables_with_jobs, truncate_filter ) applied_update = self._init_dataset_and_update_schema( job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables @@ -434,13 +435,13 @@ def _init_client( # update the staging dataset if client supports this if isinstance(job_client, WithStagingDataset): if staging_tables := self._get_table_chain_tables_with_filter( - schema, truncate_staging_filter, tables_with_jobs + schema, tables_with_jobs, truncate_staging_filter ): with job_client.with_staging_dataset(): self._init_dataset_and_update_schema( job_client, expected_update, - staging_tables | {schema.version_table_name}, + order_deduped(staging_tables + [schema.version_table_name]), staging_tables, staging_info=True, ) diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 03fccfa8ee..e6466bb4b5 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -490,6 +490,34 @@ def duplicates_no_child(): assert counts["duplicates_no_child"] == 2 +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, +) +def test_complex_column_missing(destination_config: DestinationTestConfiguration) -> None: + table_name = "test_complex_column_missing" + + @dlt.resource(name=table_name, 
write_disposition="merge", primary_key="id") + def r(data): + yield data + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + + data = [{"id": 1, "simple": "foo", "complex": [1, 2, 3]}] + info = p.run(r(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 3 + + # complex column is missing, previously inserted records should be deleted from child table + data = [{"id": 1, "simple": "bar"}] + info = p.run(r(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 0 + + @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, supports_merge=True), From 189c2fb7f58465b7c5c9694dc98b5559e0d4ed3f Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 22 Feb 2024 01:42:37 +0100 Subject: [PATCH 19/28] marks tables that seen data in normalizer, skips empty jobs if never seen data --- dlt/normalize/items_normalizers.py | 24 +++++++++++++++--------- dlt/normalize/normalize.py | 28 ++++++++++++++++++---------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 2167250036..077bfca990 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -196,15 +196,21 @@ def __call__( schema_updates.append(partial_update) logger.debug(f"Processed {line_no} lines from file {extracted_items_file}") if line is None and root_table_name in self.schema.tables: - self.load_storage.write_empty_items_file( - self.load_id, - self.schema.name, - root_table_name, - self.schema.get_table_columns(root_table_name), - ) - logger.debug( - f"No lines in file {extracted_items_file}, written empty load job file" - ) + # write only if table seen data before + root_table = self.schema.tables[root_table_name] + if ( + "x-normalizer" in root_table + and root_table["x-normalizer"].get("first-seen", None) is not None # type: ignore[typeddict-item] + ): + self.load_storage.write_empty_items_file( + self.load_id, + self.schema.name, + root_table_name, + self.schema.get_table_columns(root_table_name), + ) + logger.debug( + f"No lines in file {extracted_items_file}, written empty load job file" + ) return schema_updates diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 0a3c6784c7..52f3bd5a74 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -289,9 +289,23 @@ def spool_files( ) -> None: # process files in parallel or in single thread, depending on map_f schema_updates, writer_metrics = map_f(schema, load_id, files) - # remove normalizer specific info - for table in schema.tables.values(): - table.pop("x-normalizer", None) # type: ignore[typeddict-item] + # compute metrics + job_metrics = {ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics} + table_metrics: Dict[str, DataWriterMetrics] = { + table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) + for table_name, metrics in itertools.groupby( + job_metrics.items(), lambda pair: pair[0].table_name + ) + } + # update normalizer specific info + for table_name in table_metrics: + table = schema.tables[table_name] + x_normalizer = table.setdefault("x-normalizer", {}) # type: ignore[typeddict-item] + # drop evolve once for all tables that seen data + x_normalizer.pop("evolve-columns-once", 
None) + # mark that table have seen data only if there was data + if table_metrics[table_name].items_count > 0: + x_normalizer["first-seen"] = load_id logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) @@ -312,19 +326,13 @@ def spool_files( self.normalize_storage.extracted_packages.delete_package(load_id) # log and update metrics logger.info(f"Extracted package {load_id} processed") - job_metrics = {ParsedLoadJobFileName.parse(m.file_path): m for m in writer_metrics} self._step_info_complete_load_id( load_id, { "started_at": None, "finished_at": None, "job_metrics": {job.job_id(): metrics for job, metrics in job_metrics.items()}, - "table_metrics": { - table_name: sum(map(lambda pair: pair[1], metrics), EMPTY_DATA_WRITER_METRICS) - for table_name, metrics in itertools.groupby( - job_metrics.items(), lambda pair: pair[0].table_name - ) - }, + "table_metrics": table_metrics, }, ) From a649b0ed8c7d0dd53ce87b500ad5cb5bec708b79 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 22 Feb 2024 01:43:43 +0100 Subject: [PATCH 20/28] ignores tables that didn't seen data when loading, tests edge cases --- dlt/common/schema/schema.py | 3 +- dlt/destinations/impl/athena/athena.py | 4 +- dlt/load/load.py | 80 ++++++++++++---- tests/load/pipeline/test_pipelines.py | 123 ++++++++++++++++++++++++- 4 files changed, 187 insertions(+), 23 deletions(-) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index ccfc038085..676fdaa996 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -390,6 +390,7 @@ def resolve_contract_settings_for_table( return Schema.expand_schema_contract_settings(settings) def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchema: + """Adds or merges `partial_table` into the schema. Identifiers are not normalized""" table_name = partial_table["name"] parent_table_name = partial_table.get("parent") # check if parent table present @@ -414,7 +415,7 @@ def update_table(self, partial_table: TPartialTableSchema) -> TPartialTableSchem return partial_table def update_schema(self, schema: "Schema") -> None: - """Updates this schema from an incoming schema""" + """Updates this schema from an incoming schema. 
Normalizes identifiers after updating normalizers.""" # update all tables for table in schema.tables.values(): self.update_table(table) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 96e7818d57..cb7579f027 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -453,7 +453,9 @@ def should_load_data_to_staging_dataset_on_staging_destination( self, table: TTableSchema ) -> bool: """iceberg table data goes into staging on staging destination""" - return self._is_iceberg_table(self.get_load_table(table["name"])) + if self._is_iceberg_table(self.get_load_table(table["name"])): + return True + return super().should_load_data_to_staging_dataset_on_staging_destination(table) def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: table = super().get_load_table(table_name, staging) diff --git a/dlt/load/load.py b/dlt/load/load.py index a7764d164b..d7efdad60b 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -374,14 +374,14 @@ def _get_table_chain_tables_with_filter( # without jobs in the table chain, because child tables may need # processing due to changes in the root table skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") - for table in get_child_tables(schema.tables, top_job_table["name"]): - with self.get_destination_client(schema) as job_client: - table_has_job = table["name"] in tables_with_jobs - table_exists = True - if hasattr(job_client, "get_storage_table"): - table_exists = job_client.get_storage_table(table["name"])[0] - if (not table_has_job and skip_jobless_table) or not table_exists: - continue + # use only complete tables to infer table chains + data_tables = { + t["name"]: t for t in schema.data_tables(include_incomplete=False) + [top_job_table] + } + for table in get_child_tables(data_tables, top_job_table["name"]): + table_has_job = table["name"] in tables_with_jobs + if not table_has_job and skip_jobless_table: + continue result.append(table["name"]) return order_deduped(result) @@ -419,30 +419,72 @@ def _init_client( expected_update: TSchemaTables, load_id: str, truncate_filter: Callable[[TTableSchema], bool], - truncate_staging_filter: Callable[[TTableSchema], bool], + load_staging_filter: Callable[[TTableSchema], bool], ) -> TSchemaTables: - tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id)) - dlt_tables = set(t["name"] for t in schema.dlt_tables()) + """Initializes destination storage including staging dataset if supported - # update the default dataset - truncate_tables = self._get_table_chain_tables_with_filter( - schema, tables_with_jobs, truncate_filter + Will initialize and migrate schema in destination dataset and staging dataset. + + Args: + job_client (JobClientBase): Instance of destination client + schema (Schema): The schema as in load package + expected_update (TSchemaTables): Schema update as in load package. 
Always present even if empty + load_id (str): Package load id + truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in destination dataset should be truncated + load_staging_filter (Callable[[TTableSchema], bool]): A filter which tell which table in the staging dataset may be loaded into + + Returns: + TSchemaTables: Actual migrations done at destination + """ + # get dlt/internal tables + dlt_tables = set(schema.dlt_table_names()) + # tables without data + tables_no_data = set( + table["name"] + for table in schema.data_tables() + if table.get("x-normalizer", {}).get("first-seen", None) is None # type: ignore[attr-defined] + ) + # get all tables that actually have load jobs with data + tables_with_jobs = ( + set(job.table_name for job in self.get_new_jobs_info(load_id)) - tables_no_data ) + + # get tables to truncate by extending tables with jobs with all their child tables + truncate_tables = set( + self._get_table_chain_tables_with_filter(schema, tables_with_jobs, truncate_filter) + ) + # must be a subset + assert (tables_with_jobs | dlt_tables).issuperset(truncate_tables) + applied_update = self._init_dataset_and_update_schema( job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables ) # update the staging dataset if client supports this if isinstance(job_client, WithStagingDataset): - if staging_tables := self._get_table_chain_tables_with_filter( - schema, tables_with_jobs, truncate_staging_filter - ): + # get staging tables (all data tables that are eligible) + staging_tables = set( + self._get_table_chain_tables_with_filter( + schema, tables_with_jobs, load_staging_filter + ) + ) + # truncate all tables + staging_truncate_tables = set( + self._get_table_chain_tables_with_filter( + schema, tables_with_jobs, load_staging_filter + ) + ) + # must be a subset + assert staging_tables.issuperset(staging_truncate_tables) + assert tables_with_jobs.issuperset(staging_tables) + + if staging_tables: with job_client.with_staging_dataset(): self._init_dataset_and_update_schema( job_client, expected_update, - order_deduped(staging_tables + [schema.version_table_name]), - staging_tables, + staging_tables | {schema.version_table_name}, # keep only schema version + staging_truncate_tables, staging_info=True, ) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index af61dc5d3d..60eee18ed2 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -13,6 +13,7 @@ from dlt.common.schema.typing import VERSION_TABLE_NAME from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id +from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.extract.exceptions import ResourceNameMissing from dlt.extract import DltSource from dlt.pipeline.exceptions import ( @@ -22,9 +23,10 @@ ) from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.exceptions import DestinationHasFailedJobs +from tests.load.pipeline.test_replace_disposition import REPLACE_STRATEGIES -from tests.utils import TEST_STORAGE_ROOT, preserve_environ -from tests.pipeline.utils import assert_load_info +from tests.utils import TEST_STORAGE_ROOT, data_to_item_format, preserve_environ +from tests.pipeline.utils import assert_data_table_counts, assert_load_info from tests.load.utils import ( TABLE_ROW_ALL_DATA_TYPES, TABLE_UPDATE_COLUMNS_SCHEMA, @@ -849,6 +851,123 @@ def some_source(): ) +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + 
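Restating the data-seen filter above in isolation: a table only takes part in truncation and staging initialization when the normalizer has stamped it with the x-normalizer / first-seen marker. A standalone sketch with invented table entries:

    from typing import Any, Dict, List, Set

    def tables_without_data(data_tables: List[Dict[str, Any]]) -> Set[str]:
        # tables the normalizer never saw data for are excluded from truncation/loading
        return {
            t["name"]
            for t in data_tables
            if t.get("x-normalizer", {}).get("first-seen") is None
        }

    tables = [
        {"name": "table_1", "x-normalizer": {"first-seen": "1708600000.123"}},
        {"name": "table_3"},  # declared in the schema but never produced data
    ]
    assert tables_without_data(tables) == {"table_3"}
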
local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True + ), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) +def test_pipeline_upfront_tables_two_loads( + destination_config: DestinationTestConfiguration, replace_strategy: str +) -> None: + if not destination_config.supports_merge and replace_strategy != "truncate-and-insert": + pytest.skip( + f"Destination {destination_config.name} does not support merge and thus" + f" {replace_strategy}" + ) + + # use staging tables for replace + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy + + pipeline = destination_config.setup_pipeline( + "test_pipeline_upfront_tables_two_loads", + dataset_name="test_pipeline_upfront_tables_two_loads", + full_refresh=True, + ) + + @dlt.source + def two_tables(): + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="merge", + ) + def table_1(): + yield {"id": 1} + + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="merge", + ) + def table_2(): + yield data_to_item_format("arrow", [{"id": 2}]) + + @dlt.resource( + columns=[{"name": "id", "data_type": "bigint", "nullable": True}], + write_disposition="replace", + ) + def table_3(make_data=False): + if not make_data: + return + yield {"id": 3} + + return table_1, table_2, table_3 + + # discover schema + schema = two_tables().discover_schema() + # print(schema.to_pretty_yaml()) + + # now we use this schema but load just one resource + source = two_tables() + # push state, table 3 not created + load_info_1 = pipeline.run(source.table_3, schema=schema) + assert_load_info(load_info_1) + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + assert "x-normalizer" not in pipeline.default_schema.tables["table_3"] + assert ( + pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] + == load_info_1.loads_ids[0] + ) + + # load with one empty job, table 3 not created + load_info = pipeline.run(source.table_3) + assert_load_info(load_info) + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + # print(pipeline.default_schema.to_pretty_yaml()) + + load_info_2 = pipeline.run([source.table_1, source.table_3]) + assert_load_info(load_info_2) + # 1 record in table 1 + assert pipeline.last_trace.last_normalize_info.row_counts["table_1"] == 1 + assert "table_3" not in pipeline.last_trace.last_normalize_info.row_counts + assert "table_2" not in pipeline.last_trace.last_normalize_info.row_counts + # only table_1 got created + assert load_table_counts(pipeline, "table_1") == {"table_1": 1} + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_2") + with pytest.raises(DatabaseUndefinedRelation): + load_table_counts(pipeline, "table_3") + + # v4 = pipeline.default_schema.to_pretty_yaml() + # print(v4) + + # now load the second one. 
for arrow format the schema will not update because + # in that case normalizer does not add dlt specific fields, changes are not detected + # and schema is not updated because the hash didn't change + # also we make the replace resource to load its 1 record + load_info_3 = pipeline.run([source.table_3(make_data=True), source.table_2]) + assert_data_table_counts(pipeline, {"table_1": 1, "table_2": 1, "table_3": 1}) + # v5 = pipeline.default_schema.to_pretty_yaml() + # print(v5) + + # check if seen data is market correctly + assert ( + pipeline.default_schema.tables["table_3"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] + == load_info_3.loads_ids[0] + ) + assert ( + pipeline.default_schema.tables["table_2"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] + == load_info_3.loads_ids[0] + ) + assert ( + pipeline.default_schema.tables["table_1"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] + == load_info_2.loads_ids[0] + ) + + def simple_nested_pipeline( destination_config: DestinationTestConfiguration, dataset_name: str, full_refresh: bool ) -> Tuple[dlt.Pipeline, Callable[[], DltSource]]: From 4b3c59b48c92617ba0e358488f0d829d8434b130 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 22 Feb 2024 19:10:34 +0100 Subject: [PATCH 21/28] add sort order configuration option --- dlt/common/schema/typing.py | 5 ++- dlt/common/schema/utils.py | 17 ++++++++ dlt/destinations/sql_jobs.py | 29 ++++++++------ tests/load/pipeline/test_merge_disposition.py | 40 ++++++++++++++++--- 4 files changed, 71 insertions(+), 20 deletions(-) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index da6c46526e..e57c1ede47 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -46,6 +46,7 @@ "unique", "merge_key", "root_key", + "dedup_sort", ] """Known properties and hints of the column""" # TODO: merge TColumnHint with TColumnProp @@ -59,6 +60,7 @@ "unique", "root_key", "merge_key", + "dedup_sort", ] """Known hints of a column used to declare hint regexes.""" TWriteDisposition = Literal["skip", "append", "replace", "merge"] @@ -69,6 +71,7 @@ TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" +TSortOrder = Literal["asc", "desc"] COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) COLUMN_HINTS: Set[TColumnHint] = set( @@ -113,7 +116,7 @@ class TColumnSchema(TColumnSchemaBase, total=False): merge_key: Optional[bool] variant: Optional[bool] hard_delete: Optional[bool] - dedup_sort: Optional[bool] + dedup_sort: Optional[TSortOrder] TTableSchemaColumns = Dict[str, TColumnSchema] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 299785f3c0..f633773383 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -37,6 +37,7 @@ TTypeDetections, TWriteDisposition, TSchemaContract, + TSortOrder, ) from dlt.common.schema.exceptions import ( CannotCoerceColumnException, @@ -590,6 +591,22 @@ def has_column_with_prop( return len(get_columns_names_with_prop(table, column_prop, include_incomplete)) > 0 +def get_dedup_sort_tuple( + table: TTableSchema, include_incomplete: bool = False +) -> Optional[Tuple[str, TSortOrder]]: + """Returns tuple with dedup sort information. + + First element is the sort column name, second element is the sort order. + + Returns None if "dedup_sort" hint was not provided. 
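A small usage sketch for get_dedup_sort_tuple, assuming this patch is applied; the table schema fragment is invented but follows the TTableSchema shape used throughout dlt:

    from dlt.common.schema.utils import get_dedup_sort_tuple

    table = {
        "name": "my_table",
        "columns": {
            "id": {"name": "id", "data_type": "bigint", "primary_key": True},
            "sequence": {"name": "sequence", "data_type": "bigint", "dedup_sort": "desc"},
        },
    }
    assert get_dedup_sort_tuple(table) == ("sequence", "desc")
    assert get_dedup_sort_tuple({"name": "t", "columns": {}}) is None
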
+ """ + dedup_sort_col = get_first_column_name_with_prop(table, "dedup_sort", include_incomplete) + if dedup_sort_col is None: + return None + dedup_sort_order = table["columns"][dedup_sort_col]["dedup_sort"] + return (dedup_sort_col, dedup_sort_order) + + def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTables: aggregated_update: TSchemaTables = {} for schema_update in schema_updates: diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 4458b56175..215bcf9fe5 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -3,8 +3,12 @@ import yaml from dlt.common.runtime.logger import pretty_format_exception -from dlt.common.schema.typing import TTableSchema -from dlt.common.schema.utils import get_columns_names_with_prop, get_first_column_name_with_prop +from dlt.common.schema.typing import TTableSchema, TSortOrder +from dlt.common.schema.utils import ( + get_columns_names_with_prop, + get_first_column_name_with_prop, + get_dedup_sort_tuple, +) from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import uniq_id from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -210,7 +214,7 @@ def gen_select_from_dedup_sql( table_name: str, primary_keys: Sequence[str], columns: Sequence[str], - sort_column: str = None, + dedup_sort: Tuple[str, TSortOrder] = None, condition: str = None, condition_columns: Sequence[str] = None, ) -> str: @@ -250,10 +254,9 @@ def gen_select_from_dedup_sql( 1) To select the values for an INSERT INTO statement. 2) To select the values for a temporary table used for inserts. """ - if sort_column is None: - order_by = "(SELECT NULL)" - else: - order_by = f"{sort_column} DESC" + order_by = "(SELECT NULL)" + if dedup_sort is not None: + order_by = f"{dedup_sort[0]} {dedup_sort[1].upper()}" if condition is None: condition = "1 = 1" col_str = ", ".join(columns) @@ -274,7 +277,7 @@ def gen_insert_temp_table_sql( staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str, - sort_column: str = None, + dedup_sort: Tuple[str, TSortOrder] = None, condition: str = None, condition_columns: Sequence[str] = None, ) -> Tuple[List[str], str]: @@ -285,7 +288,7 @@ def gen_insert_temp_table_sql( staging_root_table_name, primary_keys, [unique_column], - sort_column, + dedup_sort, condition, condition_columns, ) @@ -429,8 +432,8 @@ def gen_merge_sql( # only True values indicate a delete for boolean columns not_deleted_cond += f" OR {escape_id(hard_delete_col)} = {escape_lit(False)}" - # get name of column with dedup_sort hint, if specified - dedup_sort_col = get_first_column_name_with_prop(root_table, "dedup_sort") + # get dedup sort information + dedup_sort = get_dedup_sort_tuple(root_table) insert_temp_table_name: str = None if len(table_chain) > 1: @@ -443,7 +446,7 @@ def gen_merge_sql( staging_root_table_name, primary_keys, unique_column, - dedup_sort_col, + dedup_sort, not_deleted_cond, condition_columns, ) @@ -470,7 +473,7 @@ def gen_merge_sql( if len(primary_keys) > 0 and len(table_chain) == 1: # without child tables we deduplicate inside the query instead of using a temp table select_sql = cls.gen_select_from_dedup_sql( - staging_table_name, primary_keys, columns, dedup_sort_col, insert_cond + staging_table_name, primary_keys, columns, dedup_sort, insert_cond ) sql.append(f"INSERT INTO {table_name}({col_str}) {select_sql};") diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 
e6466bb4b5..a209e727ae 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -734,7 +734,7 @@ def test_dedup_sort_hint(destination_config: DestinationTestConfiguration) -> No name=table_name, write_disposition="merge", primary_key="id", # sort hints only have effect when a primary key is provided - columns={"sequence": {"dedup_sort": True}}, + columns={"sequence": {"dedup_sort": "desc"}}, ) def data_resource(data): yield data @@ -742,7 +742,6 @@ def data_resource(data): p = destination_config.setup_pipeline("abstract", full_refresh=True) # three records with same primary key - # only record with highest value in sort column is inserted data = [ {"id": 1, "val": "foo", "sequence": 1}, {"id": 1, "val": "baz", "sequence": 3}, @@ -753,6 +752,7 @@ def data_resource(data): assert load_table_counts(p, table_name)[table_name] == 1 # compare observed records with expected records + # record with highest value in sort column is inserted (because "desc") qual_name = p.sql_client().make_qualified_table_name(table_name) observed = [ {"id": row[0], "val": row[1], "sequence": row[2]} @@ -761,8 +761,28 @@ def data_resource(data): expected = [{"id": 1, "val": "baz", "sequence": 3}] assert sorted(observed, key=lambda d: d["id"]) == expected + # now test "asc" sorting + data_resource.apply_hints(columns={"sequence": {"dedup_sort": "asc"}}) + + info = p.run(data_resource(data)) + assert_load_info(info) + assert load_table_counts(p, table_name)[table_name] == 1 + + # compare observed records with expected records + # record with highest lowest in sort column is inserted (because "asc") + qual_name = p.sql_client().make_qualified_table_name(table_name) + observed = [ + {"id": row[0], "val": row[1], "sequence": row[2]} + for row in select_data(p, f"SELECT id, val, sequence FROM {qual_name}") + ] + expected = [{"id": 1, "val": "foo", "sequence": 1}] + assert sorted(observed, key=lambda d: d["id"]) == expected + table_name = "test_dedup_sort_hint_complex" - data_resource.apply_hints(table_name=table_name) + data_resource.apply_hints( + table_name=table_name, + columns={"sequence": {"dedup_sort": "desc"}}, + ) # three records with same primary key # only record with highest value in sort column is inserted @@ -784,7 +804,7 @@ def data_resource(data): table_name = "test_dedup_sort_hint_with_hard_delete" data_resource.apply_hints( table_name=table_name, - columns={"sequence": {"dedup_sort": True}, "deleted": {"hard_delete": True}}, + columns={"sequence": {"dedup_sort": "desc"}, "deleted": {"hard_delete": True}}, ) # three records with same primary key @@ -841,14 +861,22 @@ def data_resource(data): assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 - # test if exception is raised when more than one "dedup_sort" column hints are provided + # test if exception is raised for invalid column schema's @dlt.resource( name="test_dedup_sort_hint_too_many_hints", write_disposition="merge", - columns={"dedup_sort_1": {"dedup_sort": True}, "dedup_sort_2": {"dedup_sort": True}}, + columns={"dedup_sort_1": {"dedup_sort": "this_is_invalid"}}, # type: ignore[call-overload] ) def r(): yield {"id": 1, "val": "foo", "dedup_sort_1": 1, "dedup_sort_2": 5} + # invalid value for "dedup_sort" hint + with pytest.raises(PipelineStepFailed): + info = p.run(r()) + + # more than one "dedup_sort" column hints are provided + r.apply_hints( + columns={"dedup_sort_1": {"dedup_sort": "desc"}, "dedup_sort_2": {"dedup_sort": "desc"}} + ) with 
pytest.raises(PipelineStepFailed): info = p.run(r()) From c984c4e6e92cfe58c4f0ae7055698c44de8359a5 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 22 Feb 2024 20:32:04 +0100 Subject: [PATCH 22/28] bumps schema engine to v9, adds migrations --- dlt/common/schema/migrations.py | 128 +++++ dlt/common/schema/schema.py | 3 +- dlt/common/schema/typing.py | 2 +- dlt/common/schema/utils.py | 116 +---- .../cases/schemas/eth/ethereum_schema_v9.yml | 476 ++++++++++++++++++ tests/common/schema/test_schema.py | 23 +- tests/common/schema/test_versioning.py | 23 +- tests/common/storages/test_schema_storage.py | 16 +- tests/extract/test_decorators.py | 4 +- 9 files changed, 651 insertions(+), 140 deletions(-) create mode 100644 dlt/common/schema/migrations.py create mode 100644 tests/common/cases/schemas/eth/ethereum_schema_v9.yml diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py new file mode 100644 index 0000000000..9b206d61a6 --- /dev/null +++ b/dlt/common/schema/migrations.py @@ -0,0 +1,128 @@ +from typing import Dict, List, cast + +from dlt.common.data_types import TDataType +from dlt.common.normalizers import explicit_normalizers +from dlt.common.typing import DictStrAny +from dlt.common.schema.typing import ( + LOADS_TABLE_NAME, + VERSION_TABLE_NAME, + TSimpleRegex, + TStoredSchema, + TTableSchemaColumns, + TColumnHint, +) +from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException + +from dlt.common.normalizers.utils import import_normalizers +from dlt.common.schema.utils import new_table, version_table, load_table + + +def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: + if from_engine == to_engine: + return cast(TStoredSchema, schema_dict) + + if from_engine == 1 and to_engine > 1: + schema_dict["includes"] = [] + schema_dict["excludes"] = [] + from_engine = 2 + if from_engine == 2 and to_engine > 2: + # current version of the schema + current = cast(TStoredSchema, schema_dict) + # add default normalizers and root hash propagation + current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) + current["normalizers"]["json"]["config"] = { + "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} + } + # move settings, convert strings to simple regexes + d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) + for h_k, h_l in d_h.items(): + d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) + p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) + p_t = {TSimpleRegex("re:" + k): v for k, v in p_t.items()} + + current["settings"] = { + "default_hints": d_h, + "preferred_types": p_t, + } + # repackage tables + old_tables: Dict[str, TTableSchemaColumns] = schema_dict.pop("tables") + current["tables"] = {} + for name, columns in old_tables.items(): + # find last path separator + parent = name + # go back in a loop to find existing parent + while True: + idx = parent.rfind("__") + if idx > 0: + parent = parent[:idx] + if parent not in old_tables: + continue + else: + parent = None + break + nt = new_table(name, parent) + nt["columns"] = columns + current["tables"][name] = nt + # assign exclude and include to tables + + def migrate_filters(group: str, filters: List[str]) -> None: + # existing filter were always defined at the root table. 
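To see the migration entry point end to end, a minimal sketch of upgrading a stored schema dict from engine 8 to 9 (the 8 -> 9 step appears further down in this function). The dict below contains only the keys that step touches and assumes this patch series is applied:

    from dlt.common.schema.migrations import migrate_schema

    stored_schema = {
        "name": "example",
        "engine_version": 8,
        "tables": {
            "my_table": {"name": "my_table", "columns": {"_dlt_id": {"name": "_dlt_id"}}},
            "my_stub": {"name": "my_stub", "columns": {}},
        },
    }
    upgraded = migrate_schema(stored_schema, 8, 9)
    assert upgraded["engine_version"] == 9
    # tables that already carry dlt system columns are marked as having seen data
    assert upgraded["tables"]["my_table"]["x-normalizer"] == {"seen-data": True}
    assert "x-normalizer" not in upgraded["tables"]["my_stub"]
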
find this table and move filters + for f in filters: + # skip initial ^ + root = f[1 : f.find("__")] + path = f[f.find("__") + 2 :] + t = current["tables"].get(root) + if t is None: + # must add new table to hold filters + t = new_table(root) + current["tables"][root] = t + t.setdefault("filters", {}).setdefault(group, []).append("re:^" + path) # type: ignore + + excludes = schema_dict.pop("excludes", []) + migrate_filters("excludes", excludes) + includes = schema_dict.pop("includes", []) + migrate_filters("includes", includes) + + # upgraded + from_engine = 3 + if from_engine == 3 and to_engine > 3: + # set empty version hash to pass validation, in engine 4 this hash is mandatory + schema_dict.setdefault("version_hash", "") + from_engine = 4 + if from_engine == 4 and to_engine > 4: + # replace schema versions table + schema_dict["tables"][VERSION_TABLE_NAME] = version_table() + schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + from_engine = 5 + if from_engine == 5 and to_engine > 5: + # replace loads table + schema_dict["tables"][LOADS_TABLE_NAME] = load_table() + from_engine = 6 + if from_engine == 6 and to_engine > 6: + # migrate from sealed properties to schema evolution settings + schema_dict["settings"].pop("schema_sealed", None) + schema_dict["settings"]["schema_contract"] = {} + for table in schema_dict["tables"].values(): + table.pop("table_sealed", None) + if not table.get("parent"): + table["schema_contract"] = {} + from_engine = 7 + if from_engine == 7 and to_engine > 7: + schema_dict["previous_hashes"] = [] + from_engine = 8 + if from_engine == 8 and to_engine > 8: + # add "seen-data" to all tables with _dlt_id, this will handle packages + # that are being loaded + for table in schema_dict["tables"].values(): + if "_dlt_id" in table["columns"]: + x_normalizer = table.setdefault("x-normalizer", {}) + x_normalizer["seen-data"] = True + from_engine = 9 + + schema_dict["engine_version"] = from_engine + if from_engine != to_engine: + raise SchemaEngineNoUpgradePathException( + schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine + ) + + return cast(TStoredSchema, schema_dict) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 676fdaa996..b73e45d489 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -2,6 +2,7 @@ from copy import copy, deepcopy from typing import ClassVar, Dict, List, Mapping, Optional, Sequence, Tuple, Any, cast, Literal from dlt.common import json +from dlt.common.schema.migrations import migrate_schema from dlt.common.utils import extend_list_deduplicated from dlt.common.typing import ( @@ -103,7 +104,7 @@ def __init__(self, name: str, normalizers: TNormalizersConfig = None) -> None: @classmethod def from_dict(cls, d: DictStrAny, bump_version: bool = True) -> "Schema": # upgrade engine if needed - stored_schema = utils.migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) + stored_schema = migrate_schema(d, d["engine_version"], cls.ENGINE_VERSION) # verify schema utils.validate_stored_schema(stored_schema) # add defaults diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index da6c46526e..35e3f75d4f 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -26,7 +26,7 @@ # current version of schema engine -SCHEMA_ENGINE_VERSION = 8 +SCHEMA_ENGINE_VERSION = 9 # dlt tables VERSION_TABLE_NAME = "_dlt_version" diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 299785f3c0..428b96e3b7 100644 --- 
a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -8,7 +8,6 @@ from dlt.common import json from dlt.common.data_types import TDataType from dlt.common.exceptions import DictValidationException -from dlt.common.normalizers import explicit_normalizers from dlt.common.normalizers.naming import NamingConvention from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCase from dlt.common.typing import DictStrAny, REPattern @@ -27,7 +26,6 @@ TSimpleRegex, TStoredSchema, TTableSchema, - TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, @@ -41,16 +39,10 @@ from dlt.common.schema.exceptions import ( CannotCoerceColumnException, ParentTableNotFoundException, - SchemaEngineNoUpgradePathException, - SchemaException, TablePropertiesConflictException, InvalidSchemaName, - UnknownTableException, ) -from dlt.common.normalizers.utils import import_normalizers -from dlt.common.schema.typing import TAnySchemaColumns - RE_NON_ALPHANUMERIC_UNDERSCORE = re.compile(r"[^a-zA-Z\d_]") DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" @@ -318,109 +310,6 @@ def validate_stored_schema(stored_schema: TStoredSchema) -> None: raise ParentTableNotFoundException(table_name, parent_table_name) -def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> TStoredSchema: - if from_engine == to_engine: - return cast(TStoredSchema, schema_dict) - - if from_engine == 1 and to_engine > 1: - schema_dict["includes"] = [] - schema_dict["excludes"] = [] - from_engine = 2 - if from_engine == 2 and to_engine > 2: - # current version of the schema - current = cast(TStoredSchema, schema_dict) - # add default normalizers and root hash propagation - current["normalizers"], _, _ = import_normalizers(explicit_normalizers()) - current["normalizers"]["json"]["config"] = { - "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} - } - # move settings, convert strings to simple regexes - d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) - for h_k, h_l in d_h.items(): - d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) - p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) - p_t = {TSimpleRegex("re:" + k): v for k, v in p_t.items()} - - current["settings"] = { - "default_hints": d_h, - "preferred_types": p_t, - } - # repackage tables - old_tables: Dict[str, TTableSchemaColumns] = schema_dict.pop("tables") - current["tables"] = {} - for name, columns in old_tables.items(): - # find last path separator - parent = name - # go back in a loop to find existing parent - while True: - idx = parent.rfind("__") - if idx > 0: - parent = parent[:idx] - if parent not in old_tables: - continue - else: - parent = None - break - nt = new_table(name, parent) - nt["columns"] = columns - current["tables"][name] = nt - # assign exclude and include to tables - - def migrate_filters(group: str, filters: List[str]) -> None: - # existing filter were always defined at the root table. 
find this table and move filters - for f in filters: - # skip initial ^ - root = f[1 : f.find("__")] - path = f[f.find("__") + 2 :] - t = current["tables"].get(root) - if t is None: - # must add new table to hold filters - t = new_table(root) - current["tables"][root] = t - t.setdefault("filters", {}).setdefault(group, []).append("re:^" + path) # type: ignore - - excludes = schema_dict.pop("excludes", []) - migrate_filters("excludes", excludes) - includes = schema_dict.pop("includes", []) - migrate_filters("includes", includes) - - # upgraded - from_engine = 3 - if from_engine == 3 and to_engine > 3: - # set empty version hash to pass validation, in engine 4 this hash is mandatory - schema_dict.setdefault("version_hash", "") - from_engine = 4 - if from_engine == 4 and to_engine > 4: - # replace schema versions table - schema_dict["tables"][VERSION_TABLE_NAME] = version_table() - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() - from_engine = 5 - if from_engine == 5 and to_engine > 5: - # replace loads table - schema_dict["tables"][LOADS_TABLE_NAME] = load_table() - from_engine = 6 - if from_engine == 6 and to_engine > 6: - # migrate from sealed properties to schema evolution settings - schema_dict["settings"].pop("schema_sealed", None) - schema_dict["settings"]["schema_contract"] = {} - for table in schema_dict["tables"].values(): - table.pop("table_sealed", None) - if not table.get("parent"): - table["schema_contract"] = {} - from_engine = 7 - if from_engine == 7 and to_engine > 7: - schema_dict["previous_hashes"] = [] - from_engine = 8 - - schema_dict["engine_version"] = from_engine - if from_engine != to_engine: - raise SchemaEngineNoUpgradePathException( - schema_dict["name"], schema_dict["engine_version"], from_engine, to_engine - ) - - return cast(TStoredSchema, schema_dict) - - def autodetect_sc_type(detection_fs: Sequence[TTypeDetections], t: Type[Any], v: Any) -> TDataType: if detection_fs: for detection_fn in detection_fs: @@ -555,6 +444,11 @@ def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TPa return diff_table +def has_table_seen_data(table: TTableSchema) -> bool: + """Checks if normalizer has seen data coming to the table.""" + return "x-normalizer" in table and table["x-normalizer"].get("seen-data", None) is True # type: ignore[typeddict-item] + + def hint_to_column_prop(h: TColumnHint) -> TColumnProp: if h == "not_null": return "nullable" diff --git a/tests/common/cases/schemas/eth/ethereum_schema_v9.yml b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml new file mode 100644 index 0000000000..c56ff85a9f --- /dev/null +++ b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml @@ -0,0 +1,476 @@ +version: 17 +version_hash: PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4= +engine_version: 9 +name: ethereum +tables: + _dlt_loads: + columns: + load_id: + nullable: false + data_type: text + name: load_id + schema_name: + nullable: true + data_type: text + name: schema_name + status: + nullable: false + data_type: bigint + name: status + inserted_at: + nullable: false + data_type: timestamp + name: inserted_at + schema_version_hash: + nullable: true + data_type: text + name: schema_version_hash + write_disposition: skip + description: Created by DLT. 
Tracks completed loads + schema_contract: {} + name: _dlt_loads + resource: _dlt_loads + _dlt_version: + columns: + version: + nullable: false + data_type: bigint + name: version + engine_version: + nullable: false + data_type: bigint + name: engine_version + inserted_at: + nullable: false + data_type: timestamp + name: inserted_at + schema_name: + nullable: false + data_type: text + name: schema_name + version_hash: + nullable: false + data_type: text + name: version_hash + schema: + nullable: false + data_type: text + name: schema + write_disposition: skip + description: Created by DLT. Tracks schema updates + schema_contract: {} + name: _dlt_version + resource: _dlt_version + blocks: + description: Ethereum blocks + x-annotation: this will be preserved on save + write_disposition: append + filters: + includes: [] + excludes: [] + columns: + _dlt_load_id: + nullable: false + description: load id coming from the extractor + data_type: text + name: _dlt_load_id + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + number: + nullable: false + primary_key: true + data_type: bigint + name: number + parent_hash: + nullable: true + data_type: text + name: parent_hash + hash: + nullable: false + cluster: true + unique: true + data_type: text + name: hash + base_fee_per_gas: + nullable: false + data_type: wei + name: base_fee_per_gas + difficulty: + nullable: false + data_type: wei + name: difficulty + extra_data: + nullable: true + data_type: text + name: extra_data + gas_limit: + nullable: false + data_type: bigint + name: gas_limit + gas_used: + nullable: false + data_type: bigint + name: gas_used + logs_bloom: + nullable: true + data_type: binary + name: logs_bloom + miner: + nullable: true + data_type: text + name: miner + mix_hash: + nullable: true + data_type: text + name: mix_hash + nonce: + nullable: true + data_type: text + name: nonce + receipts_root: + nullable: true + data_type: text + name: receipts_root + sha3_uncles: + nullable: true + data_type: text + name: sha3_uncles + size: + nullable: true + data_type: bigint + name: size + state_root: + nullable: false + data_type: text + name: state_root + timestamp: + nullable: false + unique: true + sort: true + data_type: timestamp + name: timestamp + total_difficulty: + nullable: true + data_type: wei + name: total_difficulty + transactions_root: + nullable: false + data_type: text + name: transactions_root + schema_contract: {} + name: blocks + resource: blocks + x-normalizer: + seen-data: true + blocks__transactions: + parent: blocks + columns: + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + block_number: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: block_number + transaction_index: + nullable: false + primary_key: true + data_type: bigint + name: transaction_index + hash: + nullable: false + unique: true + data_type: text + name: hash + block_hash: + nullable: false + cluster: true + data_type: text + name: block_hash + block_timestamp: + nullable: false + sort: true + data_type: timestamp + name: block_timestamp + chain_id: + nullable: true + data_type: text + name: chain_id + from: + nullable: true + data_type: text + name: from + gas: + nullable: true + data_type: bigint + name: gas + gas_price: + nullable: true + data_type: bigint + name: gas_price + input: + nullable: true + data_type: text + name: input + max_fee_per_gas: + nullable: true + data_type: wei + name: max_fee_per_gas + max_priority_fee_per_gas: + nullable: true + 
data_type: wei + name: max_priority_fee_per_gas + nonce: + nullable: true + data_type: bigint + name: nonce + r: + nullable: true + data_type: text + name: r + s: + nullable: true + data_type: text + name: s + status: + nullable: true + data_type: bigint + name: status + to: + nullable: true + data_type: text + name: to + type: + nullable: true + data_type: text + name: type + v: + nullable: true + data_type: bigint + name: v + value: + nullable: false + data_type: wei + name: value + eth_value: + nullable: true + data_type: decimal + name: eth_value + name: blocks__transactions + x-normalizer: + seen-data: true + blocks__transactions__logs: + parent: blocks__transactions + columns: + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + address: + nullable: false + data_type: text + name: address + block_timestamp: + nullable: false + sort: true + data_type: timestamp + name: block_timestamp + block_hash: + nullable: false + cluster: true + data_type: text + name: block_hash + block_number: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: block_number + transaction_index: + nullable: false + primary_key: true + foreign_key: true + data_type: bigint + name: transaction_index + log_index: + nullable: false + primary_key: true + data_type: bigint + name: log_index + data: + nullable: true + data_type: text + name: data + removed: + nullable: true + data_type: bool + name: removed + transaction_hash: + nullable: false + data_type: text + name: transaction_hash + name: blocks__transactions__logs + x-normalizer: + seen-data: true + blocks__transactions__logs__topics: + parent: blocks__transactions__logs + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__transactions__logs__topics + x-normalizer: + seen-data: true + blocks__transactions__access_list: + parent: blocks__transactions + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + address: + nullable: true + data_type: text + name: address + name: blocks__transactions__access_list + x-normalizer: + seen-data: true + blocks__transactions__access_list__storage_keys: + parent: blocks__transactions__access_list + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + _dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__transactions__access_list__storage_keys + x-normalizer: + seen-data: true + blocks__uncles: + parent: blocks + columns: + _dlt_parent_id: + nullable: false + foreign_key: true + data_type: text + name: _dlt_parent_id + _dlt_list_idx: + nullable: false + data_type: bigint + name: _dlt_list_idx + 
_dlt_id: + nullable: false + unique: true + data_type: text + name: _dlt_id + _dlt_root_id: + nullable: false + root_key: true + data_type: text + name: _dlt_root_id + value: + nullable: true + data_type: text + name: value + name: blocks__uncles + x-normalizer: + seen-data: true +settings: + default_hints: + foreign_key: + - _dlt_parent_id + not_null: + - re:^_dlt_id$ + - _dlt_root_id + - _dlt_parent_id + - _dlt_list_idx + unique: + - _dlt_id + cluster: + - block_hash + partition: + - block_timestamp + root_key: + - _dlt_root_id + preferred_types: + timestamp: timestamp + block_timestamp: timestamp + schema_contract: {} +normalizers: + names: dlt.common.normalizers.names.snake_case + json: + module: dlt.common.normalizers.json.relational + config: + generate_dlt_id: true + propagation: + root: + _dlt_id: _dlt_root_id + tables: + blocks: + timestamp: block_timestamp + hash: block_hash +previous_hashes: +- C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE= +- yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE= + diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 54892eeae5..ba817b946f 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -6,6 +6,7 @@ from dlt.common import pendulum from dlt.common.configuration import resolve_configuration from dlt.common.configuration.container import Container +from dlt.common.schema.migrations import migrate_schema from dlt.common.storages import SchemaStorageConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.exceptions import DictValidationException @@ -308,7 +309,7 @@ def test_upgrade_engine_v1_schema() -> None: # ensure engine v1 assert schema_dict["engine_version"] == 1 # schema_dict will be updated to new engine version - utils.migrate_schema(schema_dict, from_engine=1, to_engine=2) + migrate_schema(schema_dict, from_engine=1, to_engine=2) assert schema_dict["engine_version"] == 2 # we have 27 tables assert len(schema_dict["tables"]) == 27 @@ -316,40 +317,46 @@ def test_upgrade_engine_v1_schema() -> None: # upgrade schema eng 2 -> 4 schema_dict = load_json_case("schemas/ev2/event.schema") assert schema_dict["engine_version"] == 2 - upgraded = utils.migrate_schema(schema_dict, from_engine=2, to_engine=4) + upgraded = migrate_schema(schema_dict, from_engine=2, to_engine=4) assert upgraded["engine_version"] == 4 # upgrade 1 -> 4 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=4) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=4) assert upgraded["engine_version"] == 4 # upgrade 1 -> 6 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=6) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=6) assert upgraded["engine_version"] == 6 # upgrade 1 -> 7 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=7) + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=7) assert upgraded["engine_version"] == 7 # upgrade 1 -> 8 schema_dict = load_json_case("schemas/ev1/event.schema") assert schema_dict["engine_version"] == 1 - upgraded = utils.migrate_schema(schema_dict, from_engine=1, to_engine=8) + upgraded = migrate_schema(schema_dict, 
from_engine=1, to_engine=8) assert upgraded["engine_version"] == 8 + # upgrade 1 -> 9 + schema_dict = load_json_case("schemas/ev1/event.schema") + assert schema_dict["engine_version"] == 1 + upgraded = migrate_schema(schema_dict, from_engine=1, to_engine=9) + assert upgraded["engine_version"] == 9 + def test_unknown_engine_upgrade() -> None: schema_dict: TStoredSchema = load_json_case("schemas/ev1/event.schema") # there's no path to migrate 3 -> 2 schema_dict["engine_version"] = 3 with pytest.raises(SchemaEngineNoUpgradePathException): - utils.migrate_schema(schema_dict, 3, 2) # type: ignore[arg-type] + migrate_schema(schema_dict, 3, 2) # type: ignore[arg-type] def test_preserve_column_order(schema: Schema, schema_storage: SchemaStorage) -> None: @@ -693,7 +700,7 @@ def assert_new_schema_values(schema: Schema) -> None: assert schema.stored_version == 1 assert schema.stored_version_hash is not None assert schema.version_hash is not None - assert schema.ENGINE_VERSION == 8 + assert schema.ENGINE_VERSION == 9 assert schema._stored_previous_hashes == [] assert len(schema.settings["default_hints"]) > 0 # check settings diff --git a/tests/common/schema/test_versioning.py b/tests/common/schema/test_versioning.py index 5b794f51ee..dde05001e8 100644 --- a/tests/common/schema/test_versioning.py +++ b/tests/common/schema/test_versioning.py @@ -84,10 +84,10 @@ def test_infer_column_bumps_version() -> None: def test_preserve_version_on_load() -> None: - eth_v8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") - version = eth_v8["version"] - version_hash = eth_v8["version_hash"] - schema = Schema.from_dict(eth_v8) # type: ignore[arg-type] + eth_v9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + version = eth_v9["version"] + version_hash = eth_v9["version_hash"] + schema = Schema.from_dict(eth_v9) # type: ignore[arg-type] # version should not be bumped assert version_hash == schema._stored_version_hash assert version_hash == schema.version_hash @@ -126,13 +126,18 @@ def test_version_preserve_on_reload(remove_defaults: bool) -> None: def test_create_ancestry() -> None: - eth_v8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") - schema = Schema.from_dict(eth_v8) # type: ignore[arg-type] - assert schema._stored_previous_hashes == ["yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE="] + eth_v9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") + schema = Schema.from_dict(eth_v9) # type: ignore[arg-type] + + expected_previous_hashes = [ + "C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE=", + "yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE=", + ] + hash_count = len(expected_previous_hashes) + assert schema._stored_previous_hashes == expected_previous_hashes version = schema._stored_version # modify save and load schema 15 times and check ancestry - expected_previous_hashes = ["yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE="] for i in range(1, 15): # keep expected previous_hashes expected_previous_hashes.insert(0, schema._stored_version_hash) @@ -148,6 +153,6 @@ def test_create_ancestry() -> None: assert schema._stored_version == version + i # we never have more than 10 previous_hashes - assert len(schema._stored_previous_hashes) == i + 1 if i + 1 <= 10 else 10 + assert len(schema._stored_previous_hashes) == i + hash_count if i + hash_count <= 10 else 10 assert len(schema._stored_previous_hashes) == 10 diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index c72fa75927..0e04554649 100644 --- 
a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -4,9 +4,9 @@ import yaml from dlt.common import json +from dlt.common.normalizers import explicit_normalizers from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TStoredSchema -from dlt.common.schema.utils import explicit_normalizers from dlt.common.storages.exceptions import ( InStorageSchemaModified, SchemaNotFoundError, @@ -24,7 +24,7 @@ load_yml_case, yml_case_path, COMMON_TEST_CASES_PATH, - IMPORTED_VERSION_HASH_ETH_V8, + IMPORTED_VERSION_HASH_ETH_V9, ) @@ -227,10 +227,10 @@ def test_save_store_schema_over_import(ie_storage: SchemaStorage) -> None: ie_storage.save_schema(schema) assert schema.version_hash == schema_hash # we linked schema to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # load schema and make sure our new schema is here schema = ie_storage.load_schema("ethereum") - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 assert schema._stored_version_hash == schema_hash assert schema.version_hash == schema_hash assert schema.previous_hashes == [] @@ -247,7 +247,7 @@ def test_save_store_schema_over_import_sync(synced_storage: SchemaStorage) -> No schema = Schema("ethereum") schema_hash = schema.version_hash synced_storage.save_schema(schema) - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # import schema is overwritten fs = FileStorage(synced_storage.config.import_schema_path) exported_name = synced_storage._file_name_in_store("ethereum", "yaml") @@ -327,13 +327,13 @@ def prepare_import_folder(storage: SchemaStorage) -> None: def assert_schema_imported(synced_storage: SchemaStorage, storage: SchemaStorage) -> Schema: prepare_import_folder(synced_storage) - eth_V8: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v8") + eth_V9: TStoredSchema = load_yml_case("schemas/eth/ethereum_schema_v9") schema = synced_storage.load_schema("ethereum") # is linked to imported schema - schema._imported_version_hash = eth_V8["version_hash"] + schema._imported_version_hash = eth_V9["version_hash"] # also was saved in storage assert synced_storage.has_schema("ethereum") # and has link to imported schema s well (load without import) schema = storage.load_schema("ethereum") - assert schema._imported_version_hash == eth_V8["version_hash"] + assert schema._imported_version_hash == eth_V9["version_hash"] return schema diff --git a/tests/extract/test_decorators.py b/tests/extract/test_decorators.py index 3c15bf37f5..a76cbd0cfd 100644 --- a/tests/extract/test_decorators.py +++ b/tests/extract/test_decorators.py @@ -39,7 +39,7 @@ ) from dlt.extract.typing import TableNameMeta -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V8 +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9 def test_none_returning_source() -> None: @@ -84,7 +84,7 @@ def test_load_schema_for_callable() -> None: schema = s.schema assert schema.name == "ethereum" == s.name # the schema in the associated file has this hash - assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema.stored_version_hash == IMPORTED_VERSION_HASH_ETH_V9 def test_unbound_parametrized_transformer() -> None: From 935748a7aec2f359ce91ce7c4054dc1351bd923a Mon Sep 17 00:00:00 2001 From: Marcin 
Rudolf Date: Thu, 22 Feb 2024 20:32:30 +0100 Subject: [PATCH 23/28] filters tables without data properly in load --- dlt/load/load.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/dlt/load/load.py b/dlt/load/load.py index d7efdad60b..f223cd0409 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -10,7 +10,7 @@ from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo -from dlt.common.schema.utils import get_child_tables, get_top_level_table +from dlt.common.schema.utils import get_child_tables, get_top_level_table, has_table_seen_data from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR @@ -359,7 +359,7 @@ def _get_table_chain_tables_with_filter( self, schema: Schema, tables_with_jobs: Iterable[str], - f: Callable[[TTableSchema], bool] = lambda t: True, + exclude_tables: Callable[[TTableSchema], bool] = lambda t: True, ) -> List[str]: """Get all jobs for tables with given write disposition and resolve the table chain. @@ -368,7 +368,7 @@ def _get_table_chain_tables_with_filter( result: List[str] = [] for table_name in tables_with_jobs: top_job_table = get_top_level_table(schema.tables, table_name) - if not f(top_job_table): + if not exclude_tables(top_job_table): continue # for replace and merge write dispositions we should include tables # without jobs in the table chain, because child tables may need @@ -379,6 +379,10 @@ def _get_table_chain_tables_with_filter( t["name"]: t for t in schema.data_tables(include_incomplete=False) + [top_job_table] } for table in get_child_tables(data_tables, top_job_table["name"]): + # table that never seen data are skipped as they will not be created + if not has_table_seen_data(table): + continue + # if there's no job for the table and we are in append then skip table_has_job = table["name"] in tables_with_jobs if not table_has_job and skip_jobless_table: continue @@ -440,9 +444,7 @@ def _init_client( dlt_tables = set(schema.dlt_table_names()) # tables without data tables_no_data = set( - table["name"] - for table in schema.data_tables() - if table.get("x-normalizer", {}).get("first-seen", None) is None # type: ignore[attr-defined] + table["name"] for table in schema.data_tables() if not has_table_seen_data(table) ) # get all tables that actually have load jobs with data tables_with_jobs = ( @@ -453,8 +455,6 @@ def _init_client( truncate_tables = set( self._get_table_chain_tables_with_filter(schema, tables_with_jobs, truncate_filter) ) - # must be a subset - assert (tables_with_jobs | dlt_tables).issuperset(truncate_tables) applied_update = self._init_dataset_and_update_schema( job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables @@ -468,15 +468,6 @@ def _init_client( schema, tables_with_jobs, load_staging_filter ) ) - # truncate all tables - staging_truncate_tables = set( - self._get_table_chain_tables_with_filter( - schema, tables_with_jobs, load_staging_filter - ) - ) - # must be a subset - assert staging_tables.issuperset(staging_truncate_tables) - assert tables_with_jobs.issuperset(staging_tables) if staging_tables: with job_client.with_staging_dataset(): @@ -484,7 +475,7 @@ def _init_client( job_client, expected_update, staging_tables | 
{schema.version_table_name}, # keep only schema version - staging_truncate_tables, + staging_tables, # all eligible tables must be also truncated staging_info=True, ) From d1255566b831b31e586b83cd90f20055bf7b0bd0 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 22 Feb 2024 20:33:13 +0100 Subject: [PATCH 24/28] converts seen-data to boolean, fixes tests --- dlt/normalize/items_normalizers.py | 6 ++---- dlt/normalize/normalize.py | 7 +++++-- tests/common/utils.py | 2 +- tests/load/pipeline/test_merge_disposition.py | 6 ++++++ tests/load/pipeline/test_pipelines.py | 19 ++++++++++--------- .../load/pipeline/test_replace_disposition.py | 8 +++++--- tests/load/pipeline/test_restore_state.py | 4 ++-- tests/load/pipeline/utils.py | 2 ++ tests/pipeline/test_dlt_versions.py | 6 +++--- 9 files changed, 36 insertions(+), 24 deletions(-) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 077bfca990..56d38a5a64 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -6,6 +6,7 @@ from dlt.common.json import custom_pua_decode, may_have_pua from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict +from dlt.common.schema.utils import has_table_seen_data from dlt.common.storages import ( NormalizeStorage, LoadStorage, @@ -198,10 +199,7 @@ def __call__( if line is None and root_table_name in self.schema.tables: # write only if table seen data before root_table = self.schema.tables[root_table_name] - if ( - "x-normalizer" in root_table - and root_table["x-normalizer"].get("first-seen", None) is not None # type: ignore[typeddict-item] - ): + if has_table_seen_data(root_table): self.load_storage.write_empty_items_file( self.load_id, self.schema.name, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 52f3bd5a74..d360a1c7c4 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -304,8 +304,11 @@ def spool_files( # drop evolve once for all tables that seen data x_normalizer.pop("evolve-columns-once", None) # mark that table have seen data only if there was data - if table_metrics[table_name].items_count > 0: - x_normalizer["first-seen"] = load_id + if table_metrics[table_name].items_count > 0 and "seen-data" not in x_normalizer: + logger.info( + f"Table {table_name} has seen data for a first time with load id {load_id}" + ) + x_normalizer["seen-data"] = True logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) diff --git a/tests/common/utils.py b/tests/common/utils.py index 0235d18bbe..a234937e56 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -16,7 +16,7 @@ COMMON_TEST_CASES_PATH = "./tests/common/cases/" # for import schema tests, change when upgrading the schema version -IMPORTED_VERSION_HASH_ETH_V8 = "C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE=" +IMPORTED_VERSION_HASH_ETH_V9 = "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" # test sentry DSN TEST_SENTRY_DSN = ( "https://797678dd0af64b96937435326c7d30c1@o1061158.ingest.sentry.io/4504306172821504" diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index e6466bb4b5..8431f489e6 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -10,6 +10,7 @@ from dlt.common import json, pendulum from dlt.common.configuration.container import Container from dlt.common.pipeline import 
StateInjectableContext +from dlt.common.schema.utils import has_table_seen_data from dlt.common.typing import StrAny from dlt.common.utils import digest128 from dlt.extract import DltResource @@ -34,6 +35,11 @@ def test_merge_on_keys_in_schema(destination_config: DestinationTestConfiguratio with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) + # make block uncles unseen to trigger filtering loader in loader for child tables + if has_table_seen_data(schema.tables["blocks__uncles"]): + del schema.tables["blocks__uncles"]["x-normalizer"] # type: ignore[typeddict-item] + assert not has_table_seen_data(schema.tables["blocks__uncles"]) + with open( "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", "r", diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 60eee18ed2..cef1a6936d 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -23,7 +23,6 @@ ) from dlt.common.schema.exceptions import CannotCoerceColumnException from dlt.common.exceptions import DestinationHasFailedJobs -from tests.load.pipeline.test_replace_disposition import REPLACE_STRATEGIES from tests.utils import TEST_STORAGE_ROOT, data_to_item_format, preserve_environ from tests.pipeline.utils import assert_data_table_counts, assert_load_info @@ -39,6 +38,7 @@ assert_table, load_table_counts, select_data, + REPLACE_STRATEGIES, ) from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @@ -917,8 +917,8 @@ def table_3(make_data=False): load_table_counts(pipeline, "table_3") assert "x-normalizer" not in pipeline.default_schema.tables["table_3"] assert ( - pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] - == load_info_1.loads_ids[0] + pipeline.default_schema.tables["_dlt_pipeline_state"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True ) # load with one empty job, table 3 not created @@ -949,22 +949,23 @@ def table_3(make_data=False): # and schema is not updated because the hash didn't change # also we make the replace resource to load its 1 record load_info_3 = pipeline.run([source.table_3(make_data=True), source.table_2]) + assert_load_info(load_info_3) assert_data_table_counts(pipeline, {"table_1": 1, "table_2": 1, "table_3": 1}) # v5 = pipeline.default_schema.to_pretty_yaml() # print(v5) # check if seen data is market correctly assert ( - pipeline.default_schema.tables["table_3"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] - == load_info_3.loads_ids[0] + pipeline.default_schema.tables["table_3"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True ) assert ( - pipeline.default_schema.tables["table_2"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] - == load_info_3.loads_ids[0] + pipeline.default_schema.tables["table_2"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True ) assert ( - pipeline.default_schema.tables["table_1"]["x-normalizer"]["first-seen"] # type: ignore[typeddict-item] - == load_info_2.loads_ids[0] + pipeline.default_schema.tables["table_1"]["x-normalizer"]["seen-data"] # type: ignore[typeddict-item] + is True ) diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index c6db91efff..a69d4440dc 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ 
b/tests/load/pipeline/test_replace_disposition.py @@ -9,9 +9,11 @@ load_table_counts, load_tables_to_dicts, ) -from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration - -REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] +from tests.load.pipeline.utils import ( + destinations_configs, + DestinationTestConfiguration, + REPLACE_STRATEGIES, +) @pytest.mark.parametrize( diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index 381068f1e1..73c651688d 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -18,7 +18,7 @@ from tests.utils import TEST_STORAGE_ROOT from tests.cases import JSON_TYPED_DICT, JSON_TYPED_DICT_DECODED -from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V8, yml_case_path as common_yml_case_path +from tests.common.utils import IMPORTED_VERSION_HASH_ETH_V9, yml_case_path as common_yml_case_path from tests.common.configuration.utils import environment from tests.load.pipeline.utils import assert_query_data, drop_active_pipeline_data from tests.load.utils import ( @@ -469,7 +469,7 @@ def test_restore_schemas_while_import_schemas_exist( assert normalized_annotations in schema.tables # check if attached to import schema - assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V8 + assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9 # extract some data with restored pipeline p.run(["C", "D", "E"], table_name="blacklist") assert normalized_labels in schema.tables diff --git a/tests/load/pipeline/utils.py b/tests/load/pipeline/utils.py index 17360e76fd..54c6231dcc 100644 --- a/tests/load/pipeline/utils.py +++ b/tests/load/pipeline/utils.py @@ -21,6 +21,8 @@ if TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient +REPLACE_STRATEGIES = ["truncate-and-insert", "insert-from-staging", "staging-optimized"] + @pytest.fixture(autouse=True) def drop_pipeline(request) -> Iterator[None]: diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 5cf1857dfa..cec7562d60 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -99,7 +99,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: f".dlt/pipelines/{GITHUB_PIPELINE_NAME}/schemas/github.schema.json" ) ) - assert github_schema["engine_version"] == 8 + assert github_schema["engine_version"] == 9 assert "schema_version_hash" in github_schema["tables"][LOADS_TABLE_NAME]["columns"] # load state state_dict = json.loads( @@ -149,7 +149,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: pipeline.sync_destination() # print(pipeline.working_dir) # we have updated schema - assert pipeline.default_schema.ENGINE_VERSION == 8 + assert pipeline.default_schema.ENGINE_VERSION == 9 # make sure that schema hash retrieved from the destination is exactly the same as the schema hash that was in storage before the schema was wiped assert pipeline.default_schema.stored_version_hash == github_schema["version_hash"] @@ -204,7 +204,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: ) pipeline = pipeline.drop() pipeline.sync_destination() - assert pipeline.default_schema.ENGINE_VERSION == 8 + assert pipeline.default_schema.ENGINE_VERSION == 9 # schema version does not match `dlt.attach` does not update to the right schema by itself assert pipeline.default_schema.stored_version_hash != 
github_schema["version_hash"] # state has hash From af0b34438f325869b8df43b3540c9f857acaadb9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Thu, 22 Feb 2024 23:31:49 +0100 Subject: [PATCH 25/28] disables filesystem tests config due to merge present --- tests/load/pipeline/test_pipelines.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index cef1a6936d..a483cbee1a 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -853,9 +853,7 @@ def some_source(): @pytest.mark.parametrize( "destination_config", - destinations_configs( - local_filesystem_configs=True, default_staging_configs=True, default_sql_configs=True - ), + destinations_configs(default_staging_configs=True, default_sql_configs=True), ids=lambda x: x.name, ) @pytest.mark.parametrize("replace_strategy", REPLACE_STRATEGIES) From 262018b7122c48387d87eb402bad1f86ce74876e Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 23 Feb 2024 00:46:09 +0100 Subject: [PATCH 26/28] add docs for hard_delete and dedup_sort column hints --- .../docs/general-usage/incremental-loading.md | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 09b8ca7a96..7e4021214e 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -77,6 +77,17 @@ You can use compound primary keys: ... ``` +By default, `primary_key` deduplication is arbitrary. You can pass the `dedup_sort` column hint with a value of `desc` or `asc` to influence which record remains after deduplication. Using `desc`, the records sharing the same `primary_key` are sorted in descending order before deduplication, making sure the record with the highest value for the column with the `dedup_sort` hint remains. `asc` has the opposite behavior. + +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"created_at": {"dedup_sort": "desc"}} # select "latest" record +) +... +``` + Example below merges on a column `batch_day` that holds the day for which given record is valid. Merge keys also can be compound: @@ -113,6 +124,78 @@ def github_repo_events(last_created_at = dlt.sources.incremental("created_at", " yield from _get_rest_pages("events") ``` +### Delete records +The `hard_delete` column hint can be used to delete records from the destination dataset. The behavior of the delete mechanism depends on the data type of the column marked with the hint: +1) `bool` type: only `True` leads to a delete—`None` and `False` values are disregarded +2) other types: each `not None` value leads to a delete + +Each record in the destination table with the same `primary_key` or `merge_key` as a record in the source dataset that's marked as a delete will be deleted. + +Deletes are propagated to any child table that might exist. For each record that gets deleted in the root table, all corresponding records in the child table(s) will also be deleted. Records in parent and child tables are linked through the `root key` that is explained in the next section. 
+ +#### Example: with primary key and boolean delete column +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"deleted_flag": {"hard_delete": True}} +) +def resource(): + # this will insert a record (assuming a record with id = 1 does not yet exist) + yield {"id": 1, "val": "foo", "deleted_flag": False} + + # this will update the record + yield {"id": 1, "val": "bar", "deleted_flag": None} + + # this will delete the record + yield {"id": 1, "val": "foo", "deleted_flag": True} + + # similarly, this would have also deleted the record + # only the key and the column marked with the "hard_delete" hint suffice to delete records + yield {"id": 1, "deleted_flag": True} +... +``` + +#### Example: with merge key and non-boolean delete column +```python +@dlt.resource( + merge_key="id", + write_disposition="merge", + columns={"deleted_at_ts": {"hard_delete": True}}} +def resource(): + # this will insert two records + yield [ + {"id": 1, "val": "foo", "deleted_at_ts": None}, + {"id": 1, "val": "bar", "deleted_at_ts": None} + ] + + # this will delete two records + yield {"id": 1, "val": "foo", "deleted_at_ts": "2024-02-22T12:34:56Z"} +... +``` + +#### Example: with primary key and "dedup_sort" hint +```python +@dlt.resource( + primary_key="id", + write_disposition="merge", + columns={"deleted_flag": {"hard_delete": True}, "lsn": {"dedup_sort": "desc"}} +def resource(): + # this will insert one record (the one with lsn = 3) + yield [ + {"id": 1, "val": "foo", "lsn": 1, "deleted_flag": None}, + {"id": 1, "val": "baz", "lsn": 3, "deleted_flag": None}, + {"id": 1, "val": "bar", "lsn": 2, "deleted_flag": True} + ] + + # this will insert nothing, because the "latest" record is a delete + yield [ + {"id": 2, "val": "foo", "lsn": 1, "deleted_flag": False}, + {"id": 2, "lsn": 2, "deleted_flag": True} + ] +... +``` + ### Forcing root key propagation Merge write disposition requires that the `_dlt_id` of top level table is propagated to child From 44a9ff2782ea7ea7f497350c11be387070ed9cb6 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 23 Feb 2024 10:40:08 +0100 Subject: [PATCH 27/28] fixes extending table chains in load --- dlt/load/load.py | 87 +++++++++++-------- .../athena_iceberg/test_athena_iceberg.py | 2 +- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/dlt/load/load.py b/dlt/load/load.py index f223cd0409..998bf6a580 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -231,26 +231,36 @@ def get_completed_table_chain( being_completed_job_id: str = None, ) -> List[TTableSchema]: """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed - + For append and merge write disposition, tables without jobs will be included, providing they have seen data (and were created in the destination) Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage """ - table_chain: List[TTableSchema] = [] # returns ordered list of tables from parent to child leaf tables - for table_name in self._get_table_chain_tables_with_filter( - schema, [top_merged_table["name"]] - ): + table_chain: List[TTableSchema] = [] + # allow for jobless tables for those write disposition + skip_jobless_table = top_merged_table["write_disposition"] not in ("replace", "merge") + + # make sure all the jobs for the table chain is completed + for table in get_child_tables(schema.tables, top_merged_table["name"]): table_jobs = self.load_storage.normalized_packages.list_jobs_for_table( - load_id, table_name + load_id, table["name"] ) - # all jobs must be completed in order for merge to be created - if any( - job.state not in ("failed_jobs", "completed_jobs") - and job.job_file_info.job_id() != being_completed_job_id - for job in table_jobs - ): - return None - table_chain.append(schema.tables[table_name]) - # there must be at least one table + # skip tables that never seen data + if not has_table_seen_data(table): + assert len(table_jobs) == 0, f"Tables that never seen data cannot have jobs {table}" + continue + # skip jobless tables + if len(table_jobs) == 0 and skip_jobless_table: + continue + else: + # all jobs must be completed in order for merge to be created + if any( + job.state not in ("failed_jobs", "completed_jobs") + and job.job_file_info.job_id() != being_completed_job_id + for job in table_jobs + ): + return None + table_chain.append(table) + # there must be at least table assert len(table_chain) > 0 return table_chain @@ -355,39 +365,42 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) - def _get_table_chain_tables_with_filter( - self, + @staticmethod + def _extend_tables_with_table_chain( schema: Schema, + tables: Iterable[str], tables_with_jobs: Iterable[str], - exclude_tables: Callable[[TTableSchema], bool] = lambda t: True, - ) -> List[str]: - """Get all jobs for tables with given write disposition and resolve the table chain. + include_table_filter: Callable[[TTableSchema], bool] = lambda t: True, + ) -> Iterable[str]: + """Extend 'tables` with all their children and filter out tables that do not have jobs (in `tables_with_jobs`), + haven't seen data or are not included by `include_table_filter`. + Note that for top tables with replace and merge, the filter for tables that do not have jobs - Returns a list of table names ordered by ancestry so the child tables are always after their parents. + Returns an unordered set of table names and their child tables """ - result: List[str] = [] - for table_name in tables_with_jobs: + result: Set[str] = set() + for table_name in tables: top_job_table = get_top_level_table(schema.tables, table_name) - if not exclude_tables(top_job_table): + # NOTE: this will ie. 
eliminate all non iceberg tables on ATHENA destination from staging (only iceberg needs that) + if not include_table_filter(top_job_table): continue # for replace and merge write dispositions we should include tables # without jobs in the table chain, because child tables may need # processing due to changes in the root table skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") - # use only complete tables to infer table chains - data_tables = { - t["name"]: t for t in schema.data_tables(include_incomplete=False) + [top_job_table] - } - for table in get_child_tables(data_tables, top_job_table["name"]): + for table in get_child_tables(schema.tables, top_job_table["name"]): + table_has_job = table["name"] in tables_with_jobs # table that never seen data are skipped as they will not be created if not has_table_seen_data(table): + assert ( + not table_has_job + ), f"Tables that never seen data cannot have jobs {table}" continue # if there's no job for the table and we are in append then skip - table_has_job = table["name"] in tables_with_jobs if not table_has_job and skip_jobless_table: continue - result.append(table["name"]) - return order_deduped(result) + result.add(table["name"]) + return result @staticmethod def _init_dataset_and_update_schema( @@ -442,7 +455,7 @@ def _init_client( """ # get dlt/internal tables dlt_tables = set(schema.dlt_table_names()) - # tables without data + # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below) tables_no_data = set( table["name"] for table in schema.data_tables() if not has_table_seen_data(table) ) @@ -453,7 +466,9 @@ def _init_client( # get tables to truncate by extending tables with jobs with all their child tables truncate_tables = set( - self._get_table_chain_tables_with_filter(schema, tables_with_jobs, truncate_filter) + self._extend_tables_with_table_chain( + schema, tables_with_jobs, tables_with_jobs, truncate_filter + ) ) applied_update = self._init_dataset_and_update_schema( @@ -464,8 +479,8 @@ def _init_client( if isinstance(job_client, WithStagingDataset): # get staging tables (all data tables that are eligible) staging_tables = set( - self._get_table_chain_tables_with_filter( - schema, tables_with_jobs, load_staging_filter + self._extend_tables_with_table_chain( + schema, tables_with_jobs, tables_with_jobs, load_staging_filter ) ) diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index 0b18f22639..6804b98427 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -27,7 +27,7 @@ def test_iceberg() -> None: os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://dlt-ci-test-bucket" pipeline = dlt.pipeline( - pipeline_name="aaaaathena-iceberg", + pipeline_name="athena-iceberg", destination="athena", staging="filesystem", full_refresh=True, From 9921b8993aa855840fc35f4102e52ec82e296465 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 24 Feb 2024 01:14:02 +0100 Subject: [PATCH 28/28] refactors load and adds unit tests with dummy --- Makefile | 2 +- dlt/common/destination/reference.py | 9 +- dlt/common/schema/utils.py | 14 + dlt/common/storages/load_package.py | 15 +- dlt/destinations/impl/athena/athena.py | 24 +- dlt/destinations/impl/bigquery/bigquery.py | 63 ++- dlt/destinations/impl/dummy/__init__.py | 9 +- dlt/destinations/impl/dummy/configuration.py | 3 + dlt/destinations/impl/dummy/dummy.py | 52 ++- 
.../impl/filesystem/filesystem.py | 6 +- dlt/destinations/impl/synapse/synapse.py | 35 +- dlt/destinations/job_client_impl.py | 6 +- dlt/load/load.py | 242 ++--------- dlt/load/utils.py | 186 +++++++++ .../bigquery/test_bigquery_table_builder.py | 11 +- .../load/duckdb/test_duckdb_table_builder.py | 11 +- tests/load/mssql/test_mssql_table_builder.py | 11 +- .../postgres/test_postgres_table_builder.py | 12 +- .../redshift/test_redshift_table_builder.py | 11 +- .../snowflake/test_snowflake_table_builder.py | 12 +- .../synapse/test_synapse_table_builder.py | 15 +- tests/load/test_dummy_client.py | 392 +++++++++++++++++- tests/load/utils.py | 11 +- 23 files changed, 785 insertions(+), 367 deletions(-) create mode 100644 dlt/load/utils.py diff --git a/Makefile b/Makefile index bd425b0e42..8da28717c0 100644 --- a/Makefile +++ b/Makefile @@ -74,7 +74,7 @@ test-load-local: DESTINATION__POSTGRES__CREDENTIALS=postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__DUCKDB__CREDENTIALS=duckdb:///_storage/test_quack.duckdb poetry run pytest tests -k '(postgres or duckdb)' test-common: - poetry run pytest tests/common tests/normalize tests/extract tests/pipeline tests/reflection tests/sources tests/cli/common + poetry run pytest tests/common tests/normalize tests/extract tests/pipeline tests/reflection tests/sources tests/cli/common tests/load/test_dummy_client.py tests/libs tests/destinations reset-test-storage: -rm -r _storage diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index a2bcea0d56..5e698347e5 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -255,7 +255,8 @@ def new_file_path(self) -> str: class FollowupJob: """Adds a trait that allows to create a followup job""" - def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + """Return list of new jobs. `final_state` is state to which this job transits""" return [] @@ -407,9 +408,9 @@ def _verify_schema(self) -> None: " column manually in code ie. as a merge key?" 
) - def get_load_table(self, table_name: str, prepare_for_staging: bool = False) -> TTableSchema: - if table_name not in self.schema.tables: - return None + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: try: # make a copy of the schema so modifications do not affect the original document table = deepcopy(self.schema.tables[table_name]) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 685bdeda67..835fe4279e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -546,6 +546,20 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) +def fill_hints_from_parent_and_clone_table( + tables: TSchemaTables, table: TTableSchema +) -> TTableSchema: + """Takes write disposition and table format from parent tables if not present""" + # make a copy of the schema so modifications do not affect the original document + table = deepcopy(table) + # add write disposition if not specified - in child tables + if "write_disposition" not in table: + table["write_disposition"] = get_write_disposition(tables, table["name"]) + if "table_format" not in table: + table["table_format"] = get_table_format(tables, table["name"]) + return table + + def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: """Checks if `table` schema contains column with type _typ""" return any(c.get("data_type") == _typ for c in table["columns"].values()) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 2860364cd0..01f3923455 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -8,6 +8,7 @@ from typing import ( ClassVar, Dict, + Iterable, List, NamedTuple, Literal, @@ -245,9 +246,7 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]: ) def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: - return [ - job for job in self.list_all_jobs(load_id) if job.job_file_info.table_name == table_name - ] + return self.filter_jobs_for_table(self.list_all_jobs(load_id), table_name) def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: info = self.get_load_package_info(load_id) @@ -448,6 +447,10 @@ def _read_job_file_info(self, state: TJobState, file: str, now: DateTime = None) failed_message, ) + # + # Utils + # + def _move_job( self, load_id: str, @@ -503,3 +506,9 @@ def is_package_partially_loaded(package_info: LoadPackageInfo) -> bool: @staticmethod def _job_elapsed_time_seconds(file_path: str, now_ts: float = None) -> float: return (now_ts or pendulum.now().timestamp()) - os.path.getmtime(file_path) + + @staticmethod + def filter_jobs_for_table( + all_jobs: Iterable[LoadJobInfo], table_name: str + ) -> Sequence[LoadJobInfo]: + return [job for job in all_jobs if job.job_file_info.table_name == table_name] diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index cb7579f027..c2dae7a350 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -365,7 +365,7 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables - table = self.get_load_table(table_name, self.in_staging_mode) + table = self.prepare_load_table(table_name, self.in_staging_mode) is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" columns 
= ", ".join( [self._get_column_def_sql(c, table.get("table_format")) for c in new_columns] @@ -405,13 +405,13 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = ( DoNothingFollowupJob(file_path) - if self._is_iceberg_table(self.get_load_table(table["name"])) + if self._is_iceberg_table(self.prepare_load_table(table["name"])) else DoNothingJob(file_path) ) return job def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False}) ] @@ -420,7 +420,7 @@ def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> L def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: - if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): return [ SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True}) ] @@ -436,15 +436,15 @@ def _is_iceberg_table(self, table: TTableSchema) -> bool: def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: # all iceberg tables need staging - if self._is_iceberg_table(self.get_load_table(table["name"])): + if self._is_iceberg_table(self.prepare_load_table(table["name"])): return True return super().should_load_data_to_staging_dataset(table) def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: # on athena we only truncate replace tables that are not iceberg - table = self.get_load_table(table["name"]) + table = self.prepare_load_table(table["name"]) if table["write_disposition"] == "replace" and not self._is_iceberg_table( - self.get_load_table(table["name"]) + self.prepare_load_table(table["name"]) ): return True return False @@ -453,15 +453,17 @@ def should_load_data_to_staging_dataset_on_staging_destination( self, table: TTableSchema ) -> bool: """iceberg table data goes into staging on staging destination""" - if self._is_iceberg_table(self.get_load_table(table["name"])): + if self._is_iceberg_table(self.prepare_load_table(table["name"])): return True return super().should_load_data_to_staging_dataset_on_staging_destination(table) - def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: - table = super().get_load_table(table_name, staging) + def prepare_load_table( + self, table_name: str, prepare_for_staging: bool = False + ) -> TTableSchema: + table = super().prepare_load_table(table_name, prepare_for_staging) if self.config.force_iceberg: table["table_format"] = "iceberg" - if staging and table.get("table_format", None) == "iceberg": + if prepare_for_staging and table.get("table_format", None) == "iceberg": table.pop("table_format") return table diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index d4756d8978..d4261a1636 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -248,7 +248,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table: Optional[TTableSchema] = self.get_load_table(table_name) + table: 
Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -285,46 +285,39 @@ def _get_table_update_sql( sql[0] += "\nCLUSTER BY " + ", ".join(cluster_list) # Table options. - if table: - table_options: DictStrAny = { - "description": ( - f"'{table.get(TABLE_DESCRIPTION_HINT)}'" - if table.get(TABLE_DESCRIPTION_HINT) - else None - ), - "expiration_timestamp": ( - f"TIMESTAMP '{table.get(TABLE_EXPIRATION_HINT)}'" - if table.get(TABLE_EXPIRATION_HINT) - else None - ), - } - if not any(table_options.values()): - return sql - - if generate_alter: - raise NotImplementedError("Update table options not yet implemented.") - else: - sql[0] += ( - "\nOPTIONS (" - + ", ".join( - [ - f"{key}={value}" - for key, value in table_options.items() - if value is not None - ] - ) - + ")" + table_options: DictStrAny = { + "description": ( + f"'{table.get(TABLE_DESCRIPTION_HINT)}'" + if table.get(TABLE_DESCRIPTION_HINT) + else None + ), + "expiration_timestamp": ( + f"TIMESTAMP '{table.get(TABLE_EXPIRATION_HINT)}'" + if table.get(TABLE_EXPIRATION_HINT) + else None + ), + } + if not any(table_options.values()): + return sql + + if generate_alter: + raise NotImplementedError("Update table options not yet implemented.") + else: + sql[0] += ( + "\nOPTIONS (" + + ", ".join( + [f"{key}={value}" for key, value in table_options.items() if value is not None] ) + + ")" + ) return sql - def get_load_table( + def prepare_load_table( self, table_name: str, prepare_for_staging: bool = False ) -> Optional[TTableSchema]: - table = super().get_load_table(table_name, prepare_for_staging) - if table is None: - return None - elif table_name in self.schema.data_table_names(): + table = super().prepare_load_table(table_name, prepare_for_staging) + if table_name in self.schema.data_table_names(): if TABLE_DESCRIPTION_HINT not in table: table[TABLE_DESCRIPTION_HINT] = ( # type: ignore[name-defined, typeddict-unknown-key, unused-ignore] get_inherited_table_hint( diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index a3152b8d77..37b2e77c8a 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -1,6 +1,8 @@ +from typing import List from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration @@ -18,11 +20,14 @@ def _configure(config: DummyClientConfiguration = config.value) -> DummyClientCo def capabilities() -> DestinationCapabilitiesContext: config = _configure() + additional_formats: List[TLoaderFileFormat] = ( + ["reference"] if config.create_followup_jobs else [] + ) caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = config.loader_file_format - caps.supported_loader_file_formats = [config.loader_file_format] + caps.supported_loader_file_formats = additional_formats + [config.loader_file_format] caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [config.loader_file_format] + caps.supported_staging_file_formats = additional_formats + [config.loader_file_format] caps.max_identifier_length = 127 caps.max_column_identifier_length = 127 caps.max_query_length = 8 * 1024 * 1024 
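A note on the hunks above: get_load_table becomes prepare_load_table and the hint inheritance moves into dlt.common.schema.utils.fill_hints_from_parent_and_clone_table so that job clients and the new load utilities can share it. Below is a minimal standalone sketch of that inheritance behaviour, not dlt code: it uses plain dicts instead of TSchemaTables, a toy inherited() helper in place of get_write_disposition/get_table_format, and hypothetical table names.

from copy import deepcopy
from typing import Any, Dict

Tables = Dict[str, Dict[str, Any]]

def toy_fill_hints(tables: Tables, table: Dict[str, Any]) -> Dict[str, Any]:
    # clone first so the stored schema document is never mutated
    table = deepcopy(table)

    def inherited(name: str, hint: str) -> Any:
        # walk up the parent chain until the hint is found
        node = tables[name]
        while hint not in node and node.get("parent"):
            node = tables[node["parent"]]
        return node.get(hint)

    if "write_disposition" not in table:
        table["write_disposition"] = inherited(table["name"], "write_disposition")
    if "table_format" not in table:
        table["table_format"] = inherited(table["name"], "table_format")
    return table

tables: Tables = {
    "event_user": {"name": "event_user", "write_disposition": "merge", "table_format": "iceberg"},
    "event_user__entities": {"name": "event_user__entities", "parent": "event_user"},
}
child = toy_fill_hints(tables, tables["event_user__entities"])
assert child["write_disposition"] == "merge"
assert child["table_format"] == "iceberg"

The deepcopy matters because callers receive a clone they may freely adjust (for example Athena popping table_format for staging tables in the hunk above) without touching the original schema.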
diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index 82dc797126..cce0dfa8ed 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -26,6 +26,8 @@ class DummyClientConfiguration(DestinationClientConfiguration): """probability of exception when checking job status""" timeout: float = 10.0 fail_in_init: bool = True + # new jobs workflows + create_followup_jobs: bool = False credentials: DummyClientCredentials = None @@ -43,6 +45,7 @@ def __init__( exception_prob: float = None, timeout: float = None, fail_in_init: bool = None, + create_followup_jobs: bool = None, destination_name: str = None, environment: str = None, ) -> None: ... diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 367db11e82..c46e329819 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -1,7 +1,18 @@ +from contextlib import contextmanager import random from copy import copy from types import TracebackType -from typing import ClassVar, Dict, Optional, Sequence, Type, Iterable, List +from typing import ( + ClassVar, + ContextManager, + Dict, + Iterator, + Optional, + Sequence, + Type, + Iterable, + List, +) from dlt.common import pendulum from dlt.common.schema import Schema, TTableSchema, TSchemaTables @@ -15,6 +26,7 @@ TLoadJobState, LoadJob, JobClientBase, + WithStagingDataset, ) from dlt.destinations.exceptions import ( @@ -26,9 +38,10 @@ from dlt.destinations.impl.dummy import capabilities from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration +from dlt.destinations.job_impl import NewReferenceJob -class LoadDummyJob(LoadJob, FollowupJob): +class LoadDummyBaseJob(LoadJob): def __init__(self, file_name: str, config: DummyClientConfiguration) -> None: self.config = copy(config) self._status: TLoadJobState = "running" @@ -79,16 +92,29 @@ def retry(self) -> None: self._status = "retry" -JOBS: Dict[str, LoadDummyJob] = {} +class LoadDummyJob(LoadDummyBaseJob, FollowupJob): + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + if self.config.create_followup_jobs and final_state == "completed": + new_job = NewReferenceJob( + file_name=self.file_name(), status="running", remote_path=self._file_name + ) + CREATED_FOLLOWUP_JOBS[new_job.job_id()] = new_job + return [new_job] + return [] + + +JOBS: Dict[str, LoadDummyBaseJob] = {} +CREATED_FOLLOWUP_JOBS: Dict[str, NewLoadJob] = {} -class DummyClient(JobClientBase, SupportsStagingDestination): +class DummyClient(JobClientBase, SupportsStagingDestination, WithStagingDataset): """dummy client storing jobs in memory""" capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DummyClientConfiguration) -> None: super().__init__(schema, config) + self.in_staging_context = False self.config: DummyClientConfiguration = config def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: @@ -138,6 +164,17 @@ def create_table_chain_completed_followup_jobs( def complete_load(self, load_id: str) -> None: pass + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return super().should_load_data_to_staging_dataset(table) + + @contextmanager + def with_staging_dataset(self) -> Iterator[JobClientBase]: + try: + self.in_staging_context = True + yield self + finally: + self.in_staging_context = False + def __enter__(self) -> "DummyClient": return self 
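The dummy destination above now distinguishes LoadDummyBaseJob from LoadDummyJob: only the latter implements FollowupJob, and on completion it registers a NewReferenceJob in CREATED_FOLLOWUP_JOBS so the tests can assert on it. A simplified, self-contained illustration of that followup-on-completion pattern follows; the Toy* classes are stand-ins invented for this sketch, not dlt's LoadJob/FollowupJob interfaces.

from typing import List

class ToyJob:
    """Stand-in for a load job that never schedules follow-up work."""
    def __init__(self, job_id: str, state: str = "running") -> None:
        self.job_id = job_id
        self.state = state

    def create_followup_jobs(self, final_state: str) -> List["ToyJob"]:
        return []

class ToyFollowupJob(ToyJob):
    """Stand-in for a job that emits a reference job once it completes."""
    def create_followup_jobs(self, final_state: str) -> List[ToyJob]:
        if final_state == "completed":
            # the follow-up points back at the file produced by this job
            return [ToyJob(f"reference.{self.job_id}")]
        return []

def complete(job: ToyJob, final_state: str) -> List[ToyJob]:
    # mimic the loader loop: collect follow-ups first, then commit the terminal state
    followups = job.create_followup_jobs(final_state)
    job.state = final_state
    return followups

assert complete(ToyFollowupJob("event_user.0001.jsonl"), "completed")[0].job_id == "reference.event_user.0001.jsonl"
assert complete(ToyFollowupJob("event_user.0002.jsonl"), "failed") == []

In the real loader (see the _schedule_followup_jobs changes further down) a follow-up in the "running" state is imported back into the package as a new job, which is why the reference jobs show up as extra entries in dummy_impl.JOBS in the tests.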
@@ -146,5 +183,8 @@ def __exit__( ) -> None: pass - def _create_job(self, job_id: str) -> LoadDummyJob: - return LoadDummyJob(job_id, config=self.config) + def _create_job(self, job_id: str) -> LoadDummyBaseJob: + if NewReferenceJob.is_reference_job(job_id): + return LoadDummyBaseJob(job_id, config=self.config) + else: + return LoadDummyJob(job_id, config=self.config) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 5885f8a1ec..33a597f915 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -77,9 +77,9 @@ def exception(self) -> str: class FollowupFilesystemJob(FollowupJob, LoadFilesystemJob): - def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: - jobs = super().create_followup_jobs(next_state) - if next_state == "completed": + def create_followup_jobs(self, final_state: TLoadJobState) -> List[NewLoadJob]: + jobs = super().create_followup_jobs(final_state) + if final_state == "completed": ref_job = NewReferenceJob( file_name=self.file_name(), status="running", remote_path=self.make_remote_path() ) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 33e6194602..457e128ba0 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -70,23 +70,18 @@ def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table = self.get_load_table(table_name, staging=self.in_staging_mode) - if table is None: - table_index_type = self.config.default_table_index_type + table = self.prepare_load_table(table_name, staging=self.in_staging_mode) + table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) + if self.in_staging_mode: + final_table = self.prepare_load_table(table_name, staging=False) + final_table_index_type = cast(TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT)) else: - table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) - if self.in_staging_mode: - final_table = self.get_load_table(table_name, staging=False) - final_table_index_type = cast( - TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT) - ) - else: - final_table_index_type = table_index_type - if final_table_index_type == "clustered_columnstore_index": - # Even if the staging table has index type "heap", we still adjust - # the column data types to prevent errors when writing into the - # final table that has index type "clustered_columnstore_index". - new_columns = self._get_columstore_valid_columns(new_columns) + final_table_index_type = table_index_type + if final_table_index_type == "clustered_columnstore_index": + # Even if the staging table has index type "heap", we still adjust + # the column data types to prevent errors when writing into the + # final table that has index type "clustered_columnstore_index". 
+ new_columns = self._get_columstore_valid_columns(new_columns) _sql_result = SqlJobClientBase._get_table_update_sql( self, table_name, new_columns, generate_alter @@ -135,10 +130,8 @@ def _create_replace_followup_jobs( return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) - def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: - table = super().get_load_table(table_name, staging) - if table is None: - return None + def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + table = super().prepare_load_table(table_name, staging) if staging and self.config.replace_strategy == "insert-from-staging": # Staging tables should always be heap tables, because "when you are # temporarily landing data in dedicated SQL pool, you may find that @@ -153,7 +146,7 @@ def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema # index for faster query performance." table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables - elif table_name in self.schema.data_table_names(): + else: if TABLE_INDEX_TYPE_HINT not in table: # If present in parent table, fetch hint from there. table[TABLE_INDEX_TYPE_HINT] = get_inherited_table_hint( # type: ignore[typeddict-unknown-key] diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 4333509b2c..7896fa2cc4 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -427,7 +427,7 @@ def _build_schema_update_sql( sql += ";" sql_updates.append(sql) # create a schema update for particular table - partial_table = copy(self.get_load_table(table_name)) + partial_table = copy(self.prepare_load_table(table_name)) # keep only new columns partial_table["columns"] = {c["name"]: c for c in new_columns} schema_update[table_name] = partial_table @@ -445,8 +445,8 @@ def _get_table_update_sql( ) -> List[str]: # build sql canonical_name = self.sql_client.make_qualified_table_name(table_name) - table = self.get_load_table(table_name) - table_format = table.get("table_format") if table else None + table = self.prepare_load_table(table_name) + table_format = table.get("table_format") sql_result: List[str] = [] if not generate_alter: # build CREATE diff --git a/dlt/load/load.py b/dlt/load/load.py index 998bf6a580..050e7bce67 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -1,16 +1,15 @@ import contextlib from functools import reduce import datetime # noqa: 251 -from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Callable +from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable from concurrent.futures import Executor import os from dlt.common import sleep, logger -from dlt.common.utils import order_deduped from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo -from dlt.common.schema.utils import get_child_tables, get_top_level_table, has_table_seen_data +from dlt.common.schema.utils import get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, 
NULL_COLLECTOR @@ -21,8 +20,6 @@ DestinationTransientException, ) from dlt.common.schema import Schema, TSchemaTables -from dlt.common.schema.typing import TTableSchema, TWriteDisposition -from dlt.common.schema.utils import has_column_with_prop from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, @@ -47,6 +44,7 @@ LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats, ) +from dlt.load.utils import get_completed_table_chain, init_client class Load(Runnable[Executor], WithStepInfo[LoadMetrics, LoadInfo]): @@ -140,7 +138,7 @@ def w_spool_job( file_path, ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - table = client.get_load_table(job_info.table_name) + table = client.prepare_load_table(job_info.table_name) if table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition( job_info.table_name, table["write_disposition"], file_path @@ -223,47 +221,6 @@ def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: for job_file in self.load_storage.list_new_jobs(load_id) ] - def get_completed_table_chain( - self, - load_id: str, - schema: Schema, - top_merged_table: TTableSchema, - being_completed_job_id: str = None, - ) -> List[TTableSchema]: - """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. None is returned if there's any job that is not completed - For append and merge write disposition, tables without jobs will be included, providing they have seen data (and were created in the destination) - Optionally `being_completed_job_id` can be passed that is considered to be completed before job itself moves in storage - """ - # returns ordered list of tables from parent to child leaf tables - table_chain: List[TTableSchema] = [] - # allow for jobless tables for those write disposition - skip_jobless_table = top_merged_table["write_disposition"] not in ("replace", "merge") - - # make sure all the jobs for the table chain is completed - for table in get_child_tables(schema.tables, top_merged_table["name"]): - table_jobs = self.load_storage.normalized_packages.list_jobs_for_table( - load_id, table["name"] - ) - # skip tables that never seen data - if not has_table_seen_data(table): - assert len(table_jobs) == 0, f"Tables that never seen data cannot have jobs {table}" - continue - # skip jobless tables - if len(table_jobs) == 0 and skip_jobless_table: - continue - else: - # all jobs must be completed in order for merge to be created - if any( - job.state not in ("failed_jobs", "completed_jobs") - and job.job_file_info.job_id() != being_completed_job_id - for job in table_jobs - ): - return None - table_chain.append(table) - # there must be at least table - assert len(table_chain) > 0 - return table_chain - def create_followup_jobs( self, load_id: str, state: TLoadJobState, starting_job: LoadJob, schema: Schema ) -> List[NewLoadJob]: @@ -278,8 +235,9 @@ def create_followup_jobs( schema.tables, starting_job.job_file_info().table_name ) # if all tables of chain completed, create follow up jobs - if table_chain := self.get_completed_table_chain( - load_id, schema, top_job_table, starting_job.job_file_info().job_id() + all_jobs = self.load_storage.normalized_packages.list_all_jobs(load_id) + if table_chain := get_completed_table_chain( + schema, all_jobs, top_job_table, starting_job.job_file_info().job_id() ): if follow_up_jobs := 
client.create_table_chain_completed_followup_jobs( table_chain @@ -289,7 +247,32 @@ def create_followup_jobs( return jobs def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> List[LoadJob]: + """Run periodically in the main thread to collect job execution statuses. + + After detecting change of status, it commits the job state by moving it to the right folder + May create one or more followup jobs that get scheduled as new jobs. New jobs are created + only in terminal states (completed / failed) + """ remaining_jobs: List[LoadJob] = [] + + def _schedule_followup_jobs(followup_jobs: Iterable[NewLoadJob]) -> None: + for followup_job in followup_jobs: + # running should be moved into "new jobs", other statuses into started + folder: TJobState = ( + "new_jobs" if followup_job.state() == "running" else "started_jobs" + ) + # save all created jobs + self.load_storage.normalized_packages.import_job( + load_id, followup_job.new_file_path(), job_state=folder + ) + logger.info( + f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" + f" {followup_job.new_file_path()} placed in {folder}" + ) + # if followup job is not "running" place it in current queue to be finalized + if not followup_job.state() == "running": + remaining_jobs.append(followup_job) + logger.info(f"Will complete {len(jobs)} for {load_id}") for ii in range(len(jobs)): job = jobs[ii] @@ -300,6 +283,9 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li logger.debug(f"job {job.job_id()} still running") remaining_jobs.append(job) elif state == "failed": + # create followup jobs + _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) + # try to get exception message from job failed_message = job.exception() self.load_storage.normalized_packages.fail_job( @@ -319,23 +305,7 @@ def complete_jobs(self, load_id: str, jobs: List[LoadJob], schema: Schema) -> Li ) elif state == "completed": # create followup jobs - followup_jobs = self.create_followup_jobs(load_id, state, job, schema) - for followup_job in followup_jobs: - # running should be moved into "new jobs", other statuses into started - folder: TJobState = ( - "new_jobs" if followup_job.state() == "running" else "started_jobs" - ) - # save all created jobs - self.load_storage.normalized_packages.import_job( - load_id, followup_job.new_file_path(), job_state=folder - ) - logger.info( - f"Job {job.job_id()} CREATED a new FOLLOWUP JOB" - f" {followup_job.new_file_path()} placed in {folder}" - ) - # if followup job is not "running" place it in current queue to be finalized - if not followup_job.state() == "running": - remaining_jobs.append(followup_job) + _schedule_followup_jobs(self.create_followup_jobs(load_id, state, job, schema)) # move to completed folder after followup jobs are created # in case of exception when creating followup job, the loader will retry operation and try to complete again self.load_storage.normalized_packages.complete_job(load_id, job.file_name()) @@ -365,147 +335,17 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) f"All jobs completed, archiving package {load_id} with aborted set to {aborted}" ) - @staticmethod - def _extend_tables_with_table_chain( - schema: Schema, - tables: Iterable[str], - tables_with_jobs: Iterable[str], - include_table_filter: Callable[[TTableSchema], bool] = lambda t: True, - ) -> Iterable[str]: - """Extend 'tables` with all their children and filter out tables that do not have jobs (in `tables_with_jobs`), - haven't seen data or 
are not included by `include_table_filter`. - Note that for top tables with replace and merge, the filter for tables that do not have jobs - - Returns an unordered set of table names and their child tables - """ - result: Set[str] = set() - for table_name in tables: - top_job_table = get_top_level_table(schema.tables, table_name) - # NOTE: this will ie. eliminate all non iceberg tables on ATHENA destination from staging (only iceberg needs that) - if not include_table_filter(top_job_table): - continue - # for replace and merge write dispositions we should include tables - # without jobs in the table chain, because child tables may need - # processing due to changes in the root table - skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") - for table in get_child_tables(schema.tables, top_job_table["name"]): - table_has_job = table["name"] in tables_with_jobs - # table that never seen data are skipped as they will not be created - if not has_table_seen_data(table): - assert ( - not table_has_job - ), f"Tables that never seen data cannot have jobs {table}" - continue - # if there's no job for the table and we are in append then skip - if not table_has_job and skip_jobless_table: - continue - result.add(table["name"]) - return result - - @staticmethod - def _init_dataset_and_update_schema( - job_client: JobClientBase, - expected_update: TSchemaTables, - update_tables: Iterable[str], - truncate_tables: Iterable[str] = None, - staging_info: bool = False, - ) -> TSchemaTables: - staging_text = "for staging dataset" if staging_info else "" - logger.info( - f"Client for {job_client.config.destination_type} will start initialize storage" - f" {staging_text}" - ) - job_client.initialize_storage() - logger.info( - f"Client for {job_client.config.destination_type} will update schema to package schema" - f" {staging_text}" - ) - applied_update = job_client.update_stored_schema( - only_tables=update_tables, expected_update=expected_update - ) - logger.info( - f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" - ) - job_client.initialize_storage(truncate_tables=truncate_tables) - return applied_update - - def _init_client( - self, - job_client: JobClientBase, - schema: Schema, - expected_update: TSchemaTables, - load_id: str, - truncate_filter: Callable[[TTableSchema], bool], - load_staging_filter: Callable[[TTableSchema], bool], - ) -> TSchemaTables: - """Initializes destination storage including staging dataset if supported - - Will initialize and migrate schema in destination dataset and staging dataset. - - Args: - job_client (JobClientBase): Instance of destination client - schema (Schema): The schema as in load package - expected_update (TSchemaTables): Schema update as in load package. 
Always present even if empty - load_id (str): Package load id - truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in destination dataset should be truncated - load_staging_filter (Callable[[TTableSchema], bool]): A filter which tell which table in the staging dataset may be loaded into - - Returns: - TSchemaTables: Actual migrations done at destination - """ - # get dlt/internal tables - dlt_tables = set(schema.dlt_table_names()) - # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below) - tables_no_data = set( - table["name"] for table in schema.data_tables() if not has_table_seen_data(table) - ) - # get all tables that actually have load jobs with data - tables_with_jobs = ( - set(job.table_name for job in self.get_new_jobs_info(load_id)) - tables_no_data - ) - - # get tables to truncate by extending tables with jobs with all their child tables - truncate_tables = set( - self._extend_tables_with_table_chain( - schema, tables_with_jobs, tables_with_jobs, truncate_filter - ) - ) - - applied_update = self._init_dataset_and_update_schema( - job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables - ) - - # update the staging dataset if client supports this - if isinstance(job_client, WithStagingDataset): - # get staging tables (all data tables that are eligible) - staging_tables = set( - self._extend_tables_with_table_chain( - schema, tables_with_jobs, tables_with_jobs, load_staging_filter - ) - ) - - if staging_tables: - with job_client.with_staging_dataset(): - self._init_dataset_and_update_schema( - job_client, - expected_update, - staging_tables | {schema.version_table_name}, # keep only schema version - staging_tables, # all eligible tables must be also truncated - staging_info=True, - ) - - return applied_update - def load_single_package(self, load_id: str, schema: Schema) -> None: + new_jobs = self.get_new_jobs_info(load_id) # initialize analytical storage ie. 
create dataset required by passed schema with self.get_destination_client(schema) as job_client: if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: # init job client - applied_update = self._init_client( + applied_update = init_client( job_client, schema, + new_jobs, expected_update, - load_id, job_client.should_truncate_table_before_load, ( job_client.should_load_data_to_staging_dataset @@ -521,11 +361,11 @@ def load_single_package(self, load_id: str, schema: Schema) -> None: " implement SupportsStagingDestination" ) with self.get_staging_destination_client(schema) as staging_client: - self._init_client( + init_client( staging_client, schema, + new_jobs, expected_update, - load_id, job_client.should_truncate_table_before_load_on_staging_destination, job_client.should_load_data_to_staging_dataset_on_staging_destination, ) diff --git a/dlt/load/utils.py b/dlt/load/utils.py new file mode 100644 index 0000000000..067ae33613 --- /dev/null +++ b/dlt/load/utils.py @@ -0,0 +1,186 @@ +from typing import List, Set, Iterable, Callable + +from dlt.common import logger +from dlt.common.storages.load_package import LoadJobInfo, PackageStorage +from dlt.common.schema.utils import ( + fill_hints_from_parent_and_clone_table, + get_child_tables, + get_top_level_table, + has_table_seen_data, +) +from dlt.common.storages.load_storage import ParsedLoadJobFileName +from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema.typing import TTableSchema +from dlt.common.destination.reference import ( + JobClientBase, + WithStagingDataset, +) + + +def get_completed_table_chain( + schema: Schema, + all_jobs: Iterable[LoadJobInfo], + top_merged_table: TTableSchema, + being_completed_job_id: str = None, +) -> List[TTableSchema]: + """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed + For append and merge write dispositions, tables without jobs will be included, provided they have seen data (and were created in the destination) + Optionally `being_completed_job_id` can be passed, which is considered completed before the job itself moves in storage + """ + # returns ordered list of tables from parent to child leaf tables + table_chain: List[TTableSchema] = [] + # allow for jobless tables for those write dispositions + skip_jobless_table = top_merged_table["write_disposition"] not in ("replace", "merge") + + # make sure all the jobs for the table chain are completed + for table in map( + lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), + get_child_tables(schema.tables, top_merged_table["name"]), + ): + table_jobs = PackageStorage.filter_jobs_for_table(all_jobs, table["name"]) + # skip tables that have never seen data + if not has_table_seen_data(table): + assert len(table_jobs) == 0, f"Tables that never seen data cannot have jobs {table}" + continue + # skip jobless tables + if len(table_jobs) == 0 and skip_jobless_table: + continue + else: + # all jobs must be completed in order for merge to be created + if any( + job.state not in ("failed_jobs", "completed_jobs") + and job.job_file_info.job_id() != being_completed_job_id + for job in table_jobs + ): + return None + table_chain.append(table) + # there must be at least one table + assert len(table_chain) > 0 + return table_chain + + + def init_client( + job_client: JobClientBase, + schema: Schema, + new_jobs: Iterable[ParsedLoadJobFileName], + expected_update: TSchemaTables, + truncate_filter: Callable[[TTableSchema], bool], + load_staging_filter: Callable[[TTableSchema], bool], + ) -> TSchemaTables: + """Initializes destination storage including staging dataset if supported + + Will initialize and migrate schema in destination dataset and staging dataset. + + Args: + job_client (JobClientBase): Instance of destination client + schema (Schema): The schema as in load package + new_jobs (Iterable[ParsedLoadJobFileName]): List of new jobs + expected_update (TSchemaTables): Schema update as in load package. 
Always present even if empty + truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which tables in the destination dataset should be truncated + load_staging_filter (Callable[[TTableSchema], bool]): A filter that tells which tables in the staging dataset may be loaded into + + Returns: + TSchemaTables: Actual migrations done at destination + """ + # get dlt/internal tables + dlt_tables = set(schema.dlt_table_names()) + # tables without data (TODO: normalizer removes such jobs, write tests and remove the line below) + tables_no_data = set( + table["name"] for table in schema.data_tables() if not has_table_seen_data(table) + ) + # get all tables that actually have load jobs with data + tables_with_jobs = set(job.table_name for job in new_jobs) - tables_no_data + + # get tables to truncate by extending tables with jobs with all their child tables + truncate_tables = set( + _extend_tables_with_table_chain(schema, tables_with_jobs, tables_with_jobs, truncate_filter) + ) + + applied_update = _init_dataset_and_update_schema( + job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables + ) + + # update the staging dataset if client supports this + if isinstance(job_client, WithStagingDataset): + # get staging tables (all data tables that are eligible) + staging_tables = set( + _extend_tables_with_table_chain( + schema, tables_with_jobs, tables_with_jobs, load_staging_filter + ) + ) + + if staging_tables: + with job_client.with_staging_dataset(): + _init_dataset_and_update_schema( + job_client, + expected_update, + staging_tables | {schema.version_table_name}, # keep only schema version + staging_tables, # all eligible tables must also be truncated + staging_info=True, + ) + + return applied_update + + + def _init_dataset_and_update_schema( + job_client: JobClientBase, + expected_update: TSchemaTables, + update_tables: Iterable[str], + truncate_tables: Iterable[str] = None, + staging_info: bool = False, + ) -> TSchemaTables: + staging_text = "for staging dataset" if staging_info else "" + logger.info( + f"Client for {job_client.config.destination_type} will start initialize storage" + f" {staging_text}" + ) + job_client.initialize_storage() + logger.info( + f"Client for {job_client.config.destination_type} will update schema to package schema" + f" {staging_text}" + ) + applied_update = job_client.update_stored_schema( + only_tables=update_tables, expected_update=expected_update + ) + logger.info( + f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" + ) + job_client.initialize_storage(truncate_tables=truncate_tables) + return applied_update + + + def _extend_tables_with_table_chain( + schema: Schema, + tables: Iterable[str], + tables_with_jobs: Iterable[str], + include_table_filter: Callable[[TTableSchema], bool] = lambda t: True, + ) -> Iterable[str]: + """Extend `tables` with all their children and filter out tables that do not have jobs (in `tables_with_jobs`), + haven't seen data or are not included by `include_table_filter`. 
+ Note that for top tables with replace and merge write dispositions, the filter for tables that do not have jobs is not applied + + Returns an unordered set of table names and their child tables + """ + result: Set[str] = set() + for table_name in tables: + top_job_table = get_top_level_table(schema.tables, table_name) + # for replace and merge write dispositions we should include tables + # without jobs in the table chain, because child tables may need + # processing due to changes in the root table + skip_jobless_table = top_job_table["write_disposition"] not in ("replace", "merge") + for table in map( + lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), + get_child_tables(schema.tables, top_job_table["name"]), + ): + chain_table_name = table["name"] + table_has_job = chain_table_name in tables_with_jobs + # tables that have never seen data are skipped as they will not be created + # also filter out tables rejected by `include_table_filter` + # NOTE: this will e.g. eliminate all non-iceberg tables on the ATHENA destination from staging (only iceberg needs that) + if not has_table_seen_data(table) or not include_table_filter(table): + continue + # if there's no job for the table and we are in append then skip + if not table_has_job and skip_jobless_table: + continue + result.add(chain_table_name) + return result diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 63b54726c5..a223de9b26 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -32,12 +32,7 @@ DestinationTestConfiguration, drop_active_pipeline_data, ) -from tests.load.utils import TABLE_UPDATE, sequence_generator - - -@pytest.fixture -def schema() -> Schema: - return Schema("event") +from tests.load.utils import TABLE_UPDATE, sequence_generator, empty_schema def test_configuration() -> None: @@ -56,13 +51,13 @@ def test_configuration() -> None: @pytest.fixture -def gcp_client(schema: Schema) -> BigQueryClient: +def gcp_client(empty_schema: Schema) -> BigQueryClient: # return a client without opening connection creds = GcpServiceAccountCredentialsWithoutDefaults() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker return BigQueryClient( - schema, + empty_schema, BigQueryClientConfiguration( dataset_name=f"test_{uniq_id()}", credentials=creds # type: ignore[arg-type] ), diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 0e6f799047..9b12e04f77 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -8,18 +8,13 @@ from dlt.destinations.impl.duckdb.duck import DuckDbClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> DuckDbClient: +def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient(schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) + return DuckDbClient(empty_schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) def test_create_table(client: DuckDbClient) -> None: diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index 039ce99113..75f46e8905 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ 
b/tests/load/mssql/test_mssql_table_builder.py @@ -9,19 +9,14 @@ from dlt.destinations.impl.mssql.mssql import MsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> MsSqlClient: +def client(empty_schema: Schema) -> MsSqlClient: # return client without opening connection return MsSqlClient( - schema, + empty_schema, MsSqlClientConfiguration(dataset_name="test_" + uniq_id(), credentials=MsSqlCredentials()), ) diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 68e6702b75..fde9d82cf7 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -2,6 +2,7 @@ from copy import deepcopy import sqlfluff +from dlt.common.schema.utils import new_table from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -11,19 +12,14 @@ PostgresCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> PostgresClient: +def client(empty_schema: Schema) -> PostgresClient: # return client without opening connection return PostgresClient( - schema, + empty_schema, PostgresClientConfiguration( dataset_name="test_" + uniq_id(), credentials=PostgresCredentials() ), diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index 4ad66b6f6b..5307be3e73 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -12,19 +12,14 @@ RedshiftCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> RedshiftClient: +def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection return RedshiftClient( - schema, + empty_schema, RedshiftClientConfiguration( dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials() ), diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index e6eaf26c89..1e80a61f1c 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -12,20 +12,16 @@ ) from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def snowflake_client(schema: Schema) -> SnowflakeClient: +def snowflake_client(empty_schema: Schema) -> SnowflakeClient: # return client without opening connection creds = SnowflakeCredentials() return SnowflakeClient( - schema, SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds) + empty_schema, + SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds), ) diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 4719a8d003..871ceecf96 
100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -13,7 +13,7 @@ SynapseCredentials, ) -from tests.load.utils import TABLE_UPDATE +from tests.load.utils import TABLE_UPDATE, empty_schema from dlt.destinations.impl.synapse.synapse import ( HINT_TO_SYNAPSE_ATTR, TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, @@ -21,15 +21,10 @@ @pytest.fixture -def schema() -> Schema: - return Schema("event") - - -@pytest.fixture -def client(schema: Schema) -> SynapseClient: +def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( - schema, + empty_schema, SynapseClientConfiguration( dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() ), @@ -39,10 +34,10 @@ def client(schema: Schema) -> SynapseClient: @pytest.fixture -def client_with_indexes_enabled(schema: Schema) -> SynapseClient: +def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( - schema, + empty_schema, SynapseClientConfiguration( dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True ), diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 7436023f03..d7884abcf0 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -6,18 +6,26 @@ from typing import List from dlt.common.exceptions import TerminalException, TerminalValueError +from dlt.common.schema.typing import TWriteDisposition from dlt.common.storages import FileStorage, LoadStorage, PackageStorage, ParsedLoadJobFileName +from dlt.common.storages.load_package import LoadJobInfo from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.schema.utils import ( + fill_hints_from_parent_and_clone_table, + get_child_tables, + get_top_level_table, +) -from dlt.load import Load +from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.job_impl import EmptyLoadJob - -from dlt.destinations import dummy +from dlt.destinations import dummy, filesystem from dlt.destinations.impl.dummy import dummy as dummy_impl from dlt.destinations.impl.dummy.configuration import DummyClientConfiguration + +from dlt.load import Load from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry -from dlt.common.schema.utils import get_top_level_table +from dlt.load.utils import get_completed_table_chain, init_client, _extend_tables_with_table_chain from tests.utils import ( clean_test_storage, @@ -26,7 +34,7 @@ preserve_environ, ) from tests.load.utils import prepare_load_package -from tests.utils import skip_if_not_active +from tests.utils import skip_if_not_active, TEST_STORAGE_ROOT skip_if_not_active("dummy") @@ -35,6 +43,8 @@ "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl", ] +REMOTE_FILESYSTEM = os.path.abspath(os.path.join(TEST_STORAGE_ROOT, "_remote_filesystem")) + @pytest.fixture(autouse=True) def storage() -> FileStorage: @@ -110,14 +120,19 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load = setup_loader() load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) + # update tables so we have all possible hints + for table_name, table in schema.tables.items(): + schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + top_job_table = 
get_top_level_table(schema.tables, "event_user") - assert load.get_completed_table_chain(load_id, schema, top_job_table) is None + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, top_job_table) is None # fake being completed assert ( len( - load.get_completed_table_chain( - load_id, + get_completed_table_chain( schema, + all_jobs, top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.jsonl", ) @@ -129,15 +144,17 @@ def test_get_completed_table_chain_single_job_per_table() -> None: load.load_storage.normalized_packages.start_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) is None + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) is None load.load_storage.normalized_packages.complete_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) - assert load.get_completed_table_chain(load_id, schema, loop_top_job_table) == [ + all_jobs = load.load_storage.normalized_packages.list_all_jobs(load_id) + assert get_completed_table_chain(schema, all_jobs, loop_top_job_table) == [ schema.get_table("event_loop_interrupted") ] - assert load.get_completed_table_chain( - load_id, schema, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" + assert get_completed_table_chain( + schema, all_jobs, loop_top_job_table, "event_user.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) == [schema.get_table("event_loop_interrupted")] @@ -188,7 +205,7 @@ def test_spool_job_failed_exception_init() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "true" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=True)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with patch.object(dummy_impl.DummyClient, "complete_load") as complete_load: with pytest.raises(LoadClientJobFailed) as py_ex: @@ -207,7 +224,7 @@ def test_spool_job_failed_exception_complete() -> None: # this config fails job on start os.environ["LOAD__RAISE_ON_FAILED_JOBS"] = "true" os.environ["FAIL_IN_INIT"] = "false" - load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0)) + load = setup_loader(client_config=DummyClientConfiguration(fail_prob=1.0, fail_in_init=False)) load_id, _ = prepare_load_package(load.load_storage, NORMALIZED_FILES) with pytest.raises(LoadClientJobFailed) as py_ex: run_all(load) @@ -310,6 +327,20 @@ def test_try_retrieve_job() -> None: def test_completed_loop() -> None: load = setup_loader(client_config=DummyClientConfiguration(completed_prob=1.0)) assert_complete_job(load) + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + +def test_completed_loop_followup_jobs() -> None: + # TODO: until we fix how we create capabilities we must set env + os.environ["CREATE_FOLLOWUP_JOBS"] = "true" + load = setup_loader( + client_config=DummyClientConfiguration(completed_prob=1.0, create_followup_jobs=True) + ) + assert_complete_job(load) + # for each JOB there's REFERENCE JOB + assert len(dummy_impl.JOBS) == 2 * 2 + assert len(dummy_impl.JOBS) == len(dummy_impl.CREATED_FOLLOWUP_JOBS) * 2 def test_failed_loop() -> None: @@ -319,6 +350,27 @@ def test_failed_loop() -> 
None: ) # actually not deleted because one of the jobs failed assert_complete_job(load, should_delete_completed=False) + # no jobs because fail on init + assert len(dummy_impl.JOBS) == 0 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + +def test_failed_loop_followup_jobs() -> None: + # TODO: until we fix how we create capabilities we must set env + os.environ["CREATE_FOLLOWUP_JOBS"] = "true" + os.environ["FAIL_IN_INIT"] = "false" + # ask to delete completed + load = setup_loader( + delete_completed_jobs=True, + client_config=DummyClientConfiguration( + fail_prob=1.0, fail_in_init=False, create_followup_jobs=True + ), + ) + # actually not deleted because one of the jobs failed + assert_complete_job(load, should_delete_completed=False) + # followup jobs were not started + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 def test_completed_loop_with_delete_completed() -> None: @@ -409,6 +461,286 @@ def test_wrong_writer_type() -> None: assert exv.value.load_id == load_id +def test_extend_table_chain() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + # only event user table (no other jobs) + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == {"event_user"} + # add child jobs + tables = _extend_tables_with_table_chain( + schema, ["event_user"], ["event_user", "event_user__parse_data__entities"] + ) + assert tables == {"event_user", "event_user__parse_data__entities"} + user_chain = {name for name in schema.data_table_names() if name.startswith("event_user__")} | { + "event_user" + } + # change event user to merge/replace to get full table chain + for w_d in ["merge", "replace"]: + schema.tables["event_user"]["write_disposition"] = w_d # type:ignore[typeddict-item] + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == user_chain + # no jobs for bot + assert _extend_tables_with_table_chain(schema, ["event_bot"], ["event_user"]) == set() + # skip unseen tables + del schema.tables["event_user__parse_data__entities"][ # type:ignore[typeddict-item] + "x-normalizer" + ] + entities_chain = { + name + for name in schema.data_table_names() + if name.startswith("event_user__parse_data__entities") + } + tables = _extend_tables_with_table_chain(schema, ["event_user"], ["event_user"]) + assert tables == user_chain - {"event_user__parse_data__entities"} + # exclude the whole chain + tables = _extend_tables_with_table_chain( + schema, ["event_user"], ["event_user"], lambda table: table["name"] not in entities_chain + ) + assert tables == user_chain - entities_chain + # ask for tables that are not top + tables = _extend_tables_with_table_chain(schema, ["event_user__parse_data__entities"], []) + # user chain but without entities (not seen data) + assert tables == user_chain - {"event_user__parse_data__entities"} + # go to append and ask only for entities chain + schema.tables["event_user"]["write_disposition"] = "append" + tables = _extend_tables_with_table_chain( + schema, ["event_user__parse_data__entities"], entities_chain + ) + # without entities (not seen data) + assert tables == entities_chain - {"event_user__parse_data__entities"} + + # add multiple chains + bot_jobs = {"event_bot", "event_bot__data__buttons"} + tables = _extend_tables_with_table_chain( + schema, ["event_user__parse_data__entities", "event_bot"], entities_chain | bot_jobs + ) + assert tables 
== (entities_chain | bot_jobs) - {"event_user__parse_data__entities"} + + +def test_get_completed_table_chain_cases() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + + # update tables so we have all possible hints + for table_name, table in schema.tables.items(): + schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) + + # child completed, parent not + event_user = schema.get_table("event_user") + event_user_entities = schema.get_table("event_user__parse_data__entities") + event_user_job = LoadJobInfo( + "started_jobs", + "path", + 0, + None, + 0, + ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl"), + None, + ) + event_user_entities_job = LoadJobInfo( + "completed_jobs", + "path", + 0, + None, + 0, + ParsedLoadJobFileName( + "event_user__parse_data__entities", "event_user__parse_data__entities_id", 0, "jsonl" + ), + None, + ) + chain = get_completed_table_chain(schema, [event_user_job, event_user_entities_job], event_user) + assert chain is None + + # parent just got completed + chain = get_completed_table_chain( + schema, + [event_user_job, event_user_entities_job], + event_user, + event_user_job.job_file_info.job_id(), + ) + # full chain + assert chain == [event_user, event_user_entities] + + # parent failed, child completed + chain = get_completed_table_chain( + schema, [event_user_job._replace(state="failed_jobs"), event_user_entities_job], event_user + ) + assert chain == [event_user, event_user_entities] + + # both failed + chain = get_completed_table_chain( + schema, + [ + event_user_job._replace(state="failed_jobs"), + event_user_entities_job._replace(state="failed_jobs"), + ], + event_user, + ) + assert chain == [event_user, event_user_entities] + + # merge and replace do not require whole chain to be in jobs + user_chain = get_child_tables(schema.tables, "event_user") + for w_d in ["merge", "replace"]: + event_user["write_disposition"] = w_d # type:ignore[typeddict-item] + + chain = get_completed_table_chain( + schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + ) + assert chain == user_chain + + # but if child is present and incomplete... 
+ chain = get_completed_table_chain( + schema, + [event_user_job, event_user_entities_job._replace(state="new_jobs")], + event_user, + event_user_job.job_file_info.job_id(), + ) + # noting is returned + assert chain is None + + # skip unseen + deep_child = schema.tables[ + "event_user__parse_data__response_selector__default__response__response_templates" + ] + del deep_child["x-normalizer"] # type:ignore[typeddict-item] + chain = get_completed_table_chain( + schema, [event_user_job], event_user, event_user_job.job_file_info.job_id() + ) + user_chain.remove(deep_child) + assert chain == user_chain + + +def test_init_client_truncate_tables() -> None: + load = setup_loader() + _, schema = prepare_load_package( + load.load_storage, ["event_user.b1d32c6660b242aaabbf3fc27245b7e6.0.insert_values"] + ) + + nothing_ = lambda _: False + all_ = lambda _: True + + event_user = ParsedLoadJobFileName("event_user", "event_user_id", 0, "jsonl") + event_bot = ParsedLoadJobFileName("event_bot", "event_bot_id", 0, "jsonl") + + with patch.object(dummy_impl.DummyClient, "initialize_storage") as initialize_storage: + with patch.object(dummy_impl.DummyClient, "update_stored_schema") as update_stored_schema: + with load.get_destination_client(schema) as client: + init_client(client, schema, [], {}, nothing_, nothing_) + # we do not allow for any staging dataset tables + assert update_stored_schema.call_count == 1 + assert update_stored_schema.call_args[1]["only_tables"] == { + "_dlt_loads", + "_dlt_version", + } + assert initialize_storage.call_count == 2 + # initialize storage is called twice, we deselected all tables to truncate + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + + # now we want all tables to be truncated but not on staging + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user], {}, all_, nothing_) + assert update_stored_schema.call_count == 1 + assert "event_user" in update_stored_schema.call_args[1]["only_tables"] + assert initialize_storage.call_count == 2 + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == {"event_user"} + + # now we push all to stage + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user, event_bot], {}, nothing_, all_) + assert update_stored_schema.call_count == 2 + # first call main dataset + assert {"event_user", "event_bot"} <= set( + update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + # second one staging dataset + assert {"event_user", "event_bot"} <= set( + update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) + assert initialize_storage.call_count == 4 + assert initialize_storage.call_args_list[0].args == () + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert initialize_storage.call_args_list[2].args == () + # all tables that will be used on staging must be truncated + assert initialize_storage.call_args_list[3].kwargs["truncate_tables"] == { + "event_user", + "event_bot", + } + + replace_ = lambda table: table["write_disposition"] == "replace" + merge_ = lambda table: table["write_disposition"] == "merge" + + # set event_bot chain to merge + bot_chain = get_child_tables(schema.tables, "event_bot") + for w_d in 
["merge", "replace"]: + initialize_storage.reset_mock() + update_stored_schema.reset_mock() + for bot in bot_chain: + bot["write_disposition"] = w_d # type:ignore[typeddict-item] + # merge goes to staging, replace goes to truncate + with load.get_destination_client(schema) as client: + init_client(client, schema, [event_user, event_bot], {}, replace_, merge_) + + if w_d == "merge": + # we use staging dataset + assert update_stored_schema.call_count == 2 + # 4 tables to update in main dataset + assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 + assert ( + "event_user" in update_stored_schema.call_args_list[0].kwargs["only_tables"] + ) + # full bot table chain + dlt version but no user + assert len( + update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) == 1 + len(bot_chain) + assert ( + "event_user" + not in update_stored_schema.call_args_list[1].kwargs["only_tables"] + ) + + assert initialize_storage.call_count == 4 + assert initialize_storage.call_args_list[1].kwargs["truncate_tables"] == set() + assert initialize_storage.call_args_list[3].kwargs[ + "truncate_tables" + ] == update_stored_schema.call_args_list[1].kwargs["only_tables"] - { + "_dlt_version" + } + + if w_d == "replace": + assert update_stored_schema.call_count == 1 + assert initialize_storage.call_count == 2 + # we truncate the whole bot chain but not user (which is append) + assert len( + initialize_storage.call_args_list[1].kwargs["truncate_tables"] + ) == len(bot_chain) + # migrate only tables for which we have jobs + assert len(update_stored_schema.call_args_list[0].kwargs["only_tables"]) == 4 + # print(initialize_storage.call_args_list) + # print(update_stored_schema.call_args_list) + + +def test_dummy_staging_filesystem() -> None: + load = setup_loader( + client_config=DummyClientConfiguration(completed_prob=1.0), filesystem_staging=True + ) + assert_complete_job(load) + # two reference jobs + assert len(dummy_impl.JOBS) == 2 + assert len(dummy_impl.CREATED_FOLLOWUP_JOBS) == 0 + + def test_terminal_exceptions() -> None: try: raise TerminalValueError("a") @@ -433,6 +765,13 @@ def assert_complete_job(load: Load, should_delete_completed: bool = False) -> No ) # will finalize the whole package load.run(pool) + # may have followup jobs or staging destination + if ( + load.initial_client_config.create_followup_jobs # type:ignore[attr-defined] + or load.staging_destination + ): + # run the followup jobs + load.run(pool) # moved to loaded assert not load.load_storage.storage.has_folder( load.load_storage.get_normalized_package_path(load_id) @@ -460,15 +799,32 @@ def run_all(load: Load) -> None: def setup_loader( - delete_completed_jobs: bool = False, client_config: DummyClientConfiguration = None + delete_completed_jobs: bool = False, + client_config: DummyClientConfiguration = None, + filesystem_staging: bool = False, ) -> Load: # reset jobs for a test dummy_impl.JOBS = {} - destination: TDestination = dummy() # type: ignore[assignment] + dummy_impl.CREATED_FOLLOWUP_JOBS = {} client_config = client_config or DummyClientConfiguration(loader_file_format="jsonl") + destination: TDestination = dummy(**client_config) # type: ignore[assignment] + # setup + staging_system_config = None + staging = None + if filesystem_staging: + # do not accept jsonl to not conflict with filesystem destination + client_config = client_config or DummyClientConfiguration(loader_file_format="reference") + staging_system_config = FilesystemDestinationClientConfiguration(dataset_name="dummy") + 
staging_system_config.as_staging = True + os.makedirs(REMOTE_FILESYSTEM) + staging = filesystem(bucket_url=REMOTE_FILESYSTEM) # patch destination to provide client_config # destination.client = lambda schema: dummy_impl.DummyClient(schema, client_config) - # setup loader with TEST_DICT_CONFIG_PROVIDER().values({"delete_completed_jobs": delete_completed_jobs}): - return Load(destination, initial_client_config=client_config) + return Load( + destination, + initial_client_config=client_config, + staging_destination=staging, # type: ignore[arg-type] + initial_staging_client_config=staging_system_config, + ) diff --git a/tests/load/utils.py b/tests/load/utils.py index 80e4af6fc6..50dca88248 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -1,3 +1,4 @@ +import pytest import contextlib import codecs import os @@ -396,6 +397,14 @@ def destinations_configs( return destination_configs +@pytest.fixture +def empty_schema() -> Schema: + schema = Schema("event") + table = new_table("event_test_table") + schema.update_table(table) + return schema + + def get_normalized_dataset_name(client: JobClientBase) -> str: if isinstance(client.config, DestinationClientDwhConfiguration): return client.config.normalize_dataset_name(client.schema) @@ -425,7 +434,7 @@ def expect_load_file( client.capabilities.preferred_loader_file_format, ).file_name() file_storage.save(file_name, query.encode("utf-8")) - table = client.get_load_table(table_name) + table = client.prepare_load_table(table_name) job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) while job.state() == "running": sleep(0.5)
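The loader tests above exercise init_client with plain lambda filters (replace_ and merge_). The snippet below is a toy model of what those two filters select, using hypothetical table names; it deliberately ignores child-table chains, dlt system tables and tables that have not seen data, all of which the real init_client and _extend_tables_with_table_chain handle.

from typing import Dict, Set

def plan_init(dispositions: Dict[str, str], tables_with_jobs: Set[str]) -> Dict[str, Set[str]]:
    # replace tables are truncated in the main dataset, merge tables are routed through staging
    truncate_tables = {t for t in tables_with_jobs if dispositions.get(t) == "replace"}
    staging_tables = {t for t in tables_with_jobs if dispositions.get(t) == "merge"}
    return {"truncate_tables": truncate_tables, "staging_tables": staging_tables}

plan = plan_init(
    {"event_user": "append", "event_bot": "merge", "event_metrics": "replace"},
    {"event_user", "event_bot", "event_metrics"},
)
assert plan["truncate_tables"] == {"event_metrics"}
assert plan["staging_tables"] == {"event_bot"}

With these filters, replace tables are truncated in the main dataset while merge tables are additionally created and truncated in the staging dataset, which is what the call-count assertions in test_init_client_truncate_tables verify against the mocked DummyClient.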