From 32687c3fc6c20c4b45d012660dfb6c46346b6647 Mon Sep 17 00:00:00 2001 From: Dave Date: Tue, 30 Jan 2024 21:02:27 +0100 Subject: [PATCH 1/3] scd2 init --- dlt/common/normalizers/json/relational.py | 18 +++++++++--- dlt/common/schema/schema.py | 9 ++++++ dlt/common/schema/typing.py | 2 +- docs/examples/chess_production/chess.py | 12 ++++---- docs/examples/connector_x_arrow/load_arrow.py | 2 ++ docs/examples/google_sheets/google_sheets.py | 5 +++- docs/examples/incremental_loading/zendesk.py | 8 ++--- docs/examples/nested_data/nested_data.py | 2 ++ .../pdf_to_weaviate/pdf_to_weaviate.py | 5 +++- docs/examples/qdrant_zendesk/qdrant.py | 9 +++--- docs/examples/transformers/pokemon.py | 4 ++- tests/load/pipeline/test_scd2_disposition.py | 29 +++++++++++++++++++ 12 files changed, 82 insertions(+), 23 deletions(-) create mode 100644 tests/load/pipeline/test_scd2_disposition.py diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index e33bf2ab35..9fd6f96072 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any -from dlt.common.data_types.typing import TDataType +from typing import Dict, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any + +from dlt.common.normalizers.utils import generate_dlt_id, DLT_ID_LENGTH_BYTES from dlt.common.normalizers.exceptions import InvalidJsonNormalizer from dlt.common.normalizers.typing import TJSONNormalizer -from dlt.common.normalizers.utils import generate_dlt_id, DLT_ID_LENGTH_BYTES - +from dlt.common import json from dlt.common.typing import DictStrAny, DictStrStr, TDataItem, StrAny from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnSchema, TColumnName, TSimpleRegex @@ -21,6 +21,7 @@ class TDataItemRow(TypedDict, total=False): _dlt_id: str # unique id of current row + _dlt_hash: Optional[str] # hash of the row class TDataItemRowRoot(TDataItemRow, total=False): @@ -160,6 +161,11 @@ def _add_row_id( row["_dlt_id"] = row_id return row_id + def _add_row_hash(self, table: str, row: TDataItemRow) -> str: + row_hash = digest128(json.dumps(row, sort_keys=True)) + row["_dlt_hash"] = row_hash + return row_hash + def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> StrAny: extend: DictStrAny = {} @@ -232,6 +238,10 @@ def _normalize_row( row_id = flattened_row.get("_dlt_id", None) if not row_id: row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) + # add row hash (TODO: only add when needed, either via column hint or do it when scd2 wd is used) + row_hash = flattened_row.get("_dlt_hash", None) + if not row_hash: + row_hash = self._add_row_hash(table, flattened_row) # find fields to propagate to child tables in config extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index e95699b91e..5852f9a90e 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -240,6 +240,15 @@ def coerce_row( updated_table_partial["columns"] = {} updated_table_partial["columns"][new_col_name] = new_col_def + # insert columns defs for scd2 (TODO: where to do this properly, maybe in a step after the normalization?) 
+ for col in ["_dlt_valid_from", "_dlt_valid_until"]: + if col not in table["columns"].keys(): + updated_table_partial["columns"][col] = { + "name": col, + "data_type": "timestamp", + "nullable": True, + } + return new_row, updated_table_partial def apply_schema_contract( diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a27cbe4bb..ef0967dd01 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -60,7 +60,7 @@ "merge_key", ] """Known hints of a column used to declare hint regexes.""" -TWriteDisposition = Literal["skip", "append", "replace", "merge"] +TWriteDisposition = Literal["skip", "append", "replace", "merge", "scd2"] TTableFormat = Literal["iceberg"] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" diff --git a/docs/examples/chess_production/chess.py b/docs/examples/chess_production/chess.py index 2e85805781..f7c5849e57 100644 --- a/docs/examples/chess_production/chess.py +++ b/docs/examples/chess_production/chess.py @@ -6,6 +6,7 @@ from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client + @dlt.source def chess( chess_url: str = dlt.config.value, @@ -59,6 +60,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: MAX_PLAYERS = 5 + def load_data_with_retry(pipeline, data): try: for attempt in Retrying( @@ -68,9 +70,7 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info( - f"Running the pipeline, attempt={attempt.retry_state.attempt_number}" - ) + logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") load_info = pipeline.run(data) logger.info(str(load_info)) @@ -92,9 +92,7 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info( - f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}" - ) + logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] @@ -152,4 +150,4 @@ def load_data_with_retry(pipeline, data): ) # get data for a few famous players data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) - load_data_with_retry(pipeline, data) \ No newline at end of file + load_data_with_retry(pipeline, data) diff --git a/docs/examples/connector_x_arrow/load_arrow.py b/docs/examples/connector_x_arrow/load_arrow.py index 24ba2acb0e..307e657514 100644 --- a/docs/examples/connector_x_arrow/load_arrow.py +++ b/docs/examples/connector_x_arrow/load_arrow.py @@ -3,6 +3,7 @@ import dlt from dlt.sources.credentials import ConnectionStringCredentials + def read_sql_x( conn_str: ConnectionStringCredentials = dlt.secrets.value, query: str = dlt.config.value, @@ -14,6 +15,7 @@ def read_sql_x( protocol="binary", ) + def genome_resource(): # create genome resource with merge on `upid` primary key genome = dlt.resource( diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: 
Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. 
@@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = _parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index c17beff6a8..97b9a98b11 100644 --- 
a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -46,6 +47,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -54,4 +56,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/tests/load/pipeline/test_scd2_disposition.py b/tests/load/pipeline/test_scd2_disposition.py new file mode 100644 index 0000000000..ecf9f3578f --- /dev/null +++ b/tests/load/pipeline/test_scd2_disposition.py @@ -0,0 +1,29 @@ +import pytest, dlt + +from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_simple_scd2_load(destination_config: DestinationTestConfiguration) -> None: + @dlt.resource(name="items", write_disposition="scd2", primary_key="id") + def load_items(): + yield from [{ + "id": 1, + "name": "one", + }, + { + "id": 2, + "name": "two", + }, + { + "id": 3, + "name": "three", + }] + p = destination_config.setup_pipeline("test", full_refresh=True) + p.run(load_items()) + print(p.default_schema.to_pretty_yaml()) + assert False \ No newline at end of file From 8dca5b27388ade8271ae47c70aa5253758691924 Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 31 Jan 2024 14:26:25 +0100 Subject: [PATCH 2/3] finalize first scd2 experiment --- dlt/common/destination/reference.py | 3 + dlt/common/normalizers/json/relational.py | 18 ++-- dlt/destinations/job_client_impl.py | 9 +- dlt/destinations/sql_jobs.py | 81 +++++++++++++++-- tests/load/pipeline/test_scd2_disposition.py | 93 ++++++++++++++++---- 5 files changed, 172 insertions(+), 32 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 1c28dffa8c..78c8e5fc87 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -48,6 +48,7 @@ TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] +TLoaderMergeStrategy = Literal["merge", "scd2"] TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") @@ -111,6 +112,8 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): """name of default schema to be used to name effective dataset to load data to""" replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" """How to handle replace disposition for this destination, can be classic or staging""" + merge_strategy: TLoaderMergeStrategy = "merge" + """How to handle merging, can be classic merge on primary key or merge keys, or scd2""" def normalize_dataset_name(self, schema: Schema) -> str: """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. 
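To make the new `merge_strategy` option added to `DestinationClientDwhConfiguration` above concrete: destination configuration values in dlt can be supplied through environment variables, which is also how the test later in this patch switches the merge follow-up jobs to scd2. The snippet below is only an illustrative sketch and not part of the patch; the pipeline name and the duckdb destination are placeholders, and the resource simply mirrors the shape used in the test.

import os
import dlt

# select the scd2 merge strategy for the destination (the default stays "merge")
os.environ["DESTINATION__MERGE_STRATEGY"] = "scd2"

# resources keep the regular merge write disposition; the strategy is a destination setting
@dlt.resource(name="items", write_disposition="merge")
def items():
    yield {"id": 1, "name": "one"}
    yield {"id": 2, "name": "two"}

pipeline = dlt.pipeline(pipeline_name="scd2_demo", destination="duckdb")
print(pipeline.run(items()))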
diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 9fd6f96072..d652afc610 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -149,6 +149,14 @@ def _extend_row(extend: DictStrAny, row: TDataItemRow) -> None: def _add_row_id( self, table: str, row: TDataItemRow, parent_row_id: str, pos: int, _r_lvl: int ) -> str: + # sometimes the row id needs to be a content hash; hardcoded here for now + cleaned_row = {k: v for k, v in row.items() if not k.startswith("_dlt")} + row_hash = digest128(json.dumps(cleaned_row, sort_keys=True)) + row["_dlt_id"] = row_hash + if _r_lvl > 0: + DataItemNormalizer._link_row(cast(TDataItemRowChild, row), parent_row_id, pos) + return row_hash + # row_id is always random, no matter if primary_key is present or not row_id = generate_dlt_id() if _r_lvl > 0: @@ -161,11 +169,6 @@ def _add_row_id( row["_dlt_id"] = row_id return row_id - def _add_row_hash(self, table: str, row: TDataItemRow) -> str: - row_hash = digest128(json.dumps(row, sort_keys=True)) - row["_dlt_hash"] = row_hash - return row_hash - def _get_propagated_values(self, table: str, row: TDataItemRow, _r_lvl: int) -> StrAny: extend: DictStrAny = {} @@ -234,14 +237,11 @@ def _normalize_row( flattened_row, lists = self._flatten(table, dict_row, _r_lvl) # always extend row DataItemNormalizer._extend_row(extend, flattened_row) + # infer record hash or keep an existing _dlt_id if present row_id = flattened_row.get("_dlt_id", None) if not row_id: row_id = self._add_row_id(table, flattened_row, parent_row_id, pos, _r_lvl) - # add row hash (TODO: only add when needed, either via column hint or do it when scd2 wd is used) - row_hash = flattened_row.get("_dlt_hash", None) - if not row_hash: - row_hash = self._add_row_hash(table, flattened_row) # find fields to propagate to child tables in config extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index ac68cfea8a..5816e6b96d 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -224,7 +224,14 @@ def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> L return [] def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - return [SqlMergeJob.from_table_chain(table_chain, self.sql_client)] + now = pendulum.now() + return [ + SqlMergeJob.from_table_chain( + table_chain, + self.sql_client, + {"merge_strategy": self.config.merge_strategy, "validity_date": now}, + ) + ] def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index d97a098669..58be6dd6ae 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,4 +1,5 @@ from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional +from typing_extensions import NotRequired import yaml from dlt.common.runtime.logger import pretty_format_exception @@ -10,13 +11,17 @@ from dlt.destinations.exceptions import MergeDispositionException from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase +from dlt.common.destination.reference import TLoaderMergeStrategy +from datetime import datetime # noqa: I251 class SqlJobParams(TypedDict): - replace: Optional[bool] + replace: NotRequired[bool] + merge_strategy: NotRequired[TLoaderMergeStrategy] + 
validity_date: NotRequired[datetime] -DEFAULTS: SqlJobParams = {"replace": False} +DEFAULTS: SqlJobParams = {"replace": False, "merge_strategy": "merge", "validity_date": None} class SqlBaseJob(NewLoadJobImpl): @@ -35,7 +40,7 @@ def from_table_chain( The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). """ - params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) # type: ignore + params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) top_table = table_chain[0] file_info = ParsedLoadJobFileName( top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" @@ -120,11 +125,73 @@ def generate_sql( First we store the root_keys of root table elements to be deleted in the temp table. Then we use the temp table to delete records from root and all child tables in the destination dataset. At the end we copy the data from the staging dataset into destination dataset. """ - return cls.gen_merge_sql(table_chain, sql_client) + if params["merge_strategy"] == "scd2": + return cls.gen_scd2_sql(table_chain, sql_client, params) + else: + return cls.gen_merge_sql(table_chain, sql_client) + + @classmethod + def gen_scd2_sql( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: Optional[SqlJobParams] = None, + ) -> List[str]: + sql: List[str] = [] + + validity_date = params["validity_date"] + hash_clause = cls._gen_key_table_clauses(["_dlt_id"], []) + + for table in table_chain: + table_name = sql_client.make_qualified_table_name(table["name"]) + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + + # we need to remember the original valid from dates in all tables, so copy those into the staging dataset + sql.append( + f"UPDATE {staging_table_name} SET _dlt_valid_from = (SELECT" + f" {table_name}._dlt_valid_from FROM {table_name} WHERE" + f" {staging_table_name}._dlt_id = {table_name}._dlt_id);" + ) + + # delete all rows that will be updated from all tables + key_table_clauses = cls.gen_key_table_clauses( + table_name, staging_table_name, hash_clause, for_delete=True + ) + for clause in key_table_clauses: + sql.append(f"DELETE {clause};") + + # only expired rows are now left in the main dataset, so set their valid until column + sql.append( + f"UPDATE {table_name} SET _dlt_valid_until = '{validity_date}' WHERE" + " _dlt_valid_until IS NULL;" + ) + + # copy all new rows from staging + columns = ", ".join( + map( + sql_client.capabilities.escape_identifier, + get_columns_names_with_prop(table, "name"), + ) + ) + sql.append( + f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};" + ) + + # make sure all newly inserted rows have a valid_from timestamp + sql.append( + f"UPDATE {table_name} SET _dlt_valid_from = '{validity_date}' WHERE" + " _dlt_valid_from IS NULL;" + ) + + return sql @classmethod def _gen_key_table_clauses( - cls, primary_keys: Sequence[str], merge_keys: Sequence[str] + cls, + primary_keys: Sequence[str], + merge_keys: Sequence[str], ) -> List[str]: """Generate sql clauses to select rows to delete via merge and primary key. 
Return select all clause if no keys defined.""" clauses: List[str] = [] @@ -205,7 +272,9 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @classmethod def gen_merge_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], ) -> List[str]: sql: List[str] = [] root_table = table_chain[0] diff --git a/tests/load/pipeline/test_scd2_disposition.py b/tests/load/pipeline/test_scd2_disposition.py index ecf9f3578f..afff644f31 100644 --- a/tests/load/pipeline/test_scd2_disposition.py +++ b/tests/load/pipeline/test_scd2_disposition.py @@ -1,4 +1,4 @@ -import pytest, dlt +import pytest, dlt, os from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration @@ -9,21 +9,82 @@ ids=lambda x: x.name, ) def test_simple_scd2_load(destination_config: DestinationTestConfiguration) -> None: - @dlt.resource(name="items", write_disposition="scd2", primary_key="id") + # use scd2 + os.environ["DESTINATION__MERGE_STRATEGY"] = "scd2" + + @dlt.resource(name="items", write_disposition="merge") def load_items(): - yield from [{ - "id": 1, - "name": "one", - }, - { - "id": 2, - "name": "two", - }, - { - "id": 3, - "name": "three", - }] + yield from [ + { + "id": 1, + "name": "one", + }, + { + "id": 2, + "name": "two", + "children": [ + { + "id_of": 2, + "name": "child2", + } + ], + }, + { + "id": 3, + "name": "three", + "children": [ + { + "id_of": 3, + "name": "child3", + } + ], + }, + ] + p = destination_config.setup_pipeline("test", full_refresh=True) p.run(load_items()) - print(p.default_schema.to_pretty_yaml()) - assert False \ No newline at end of file + + # new version of item 1 + # item 3 deleted + @dlt.resource(name="items", write_disposition="merge") + def load_items_2(): + yield from [ + { + "id": 1, + "name": "one_new", + }, + { + "id": 2, + "name": "two", + "children": [ + { + "id_of": 2, + "name": "child2_new", + } + ], + }, + ] + + p.run(load_items_2()) + + with p.sql_client() as c: + with c.execute_query( + "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'items' ORDER" + " BY ORDINAL_POSITION" + ) as cur: + print(cur.fetchall()) + with c.execute_query("SELECT * FROM items") as cur: + for row in cur.fetchall(): + print(row) + + with p.sql_client() as c: + with c.execute_query( + "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME =" + " 'items__children' ORDER BY ORDINAL_POSITION" + ) as cur: + print(cur.fetchall()) + with c.execute_query("SELECT * FROM items__children") as cur: + for row in cur.fetchall(): + print(row) + # print(p.default_schema.to_pretty_yaml()) + assert False From c2968c8866db56208b834c9ba01ce4d276266b58 Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 31 Jan 2024 15:38:31 +0100 Subject: [PATCH 3/3] some cleanup and proper asserts for first test --- dlt/common/destination/reference.py | 2 + dlt/common/schema/typing.py | 2 +- dlt/destinations/job_client_impl.py | 2 +- tests/load/pipeline/test_scd2_disposition.py | 95 +++++++++++++++----- 4 files changed, 77 insertions(+), 24 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 78c8e5fc87..58c8a6fc76 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -114,6 +114,8 @@ class DestinationClientDwhConfiguration(DestinationClientConfiguration): """How to handle replace disposition for this destination, can be classic or staging""" merge_strategy: 
TLoaderMergeStrategy = "merge" """How to handle merging, can be classic merge on primary key or merge keys, or scd2""" + load_timestamp: str = None + """Configurable timestamp for load strategies that record validity dates""" def normalize_dataset_name(self, schema: Schema) -> str: """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index ef0967dd01..9a27cbe4bb 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -60,7 +60,7 @@ "merge_key", ] """Known hints of a column used to declare hint regexes.""" -TWriteDisposition = Literal["skip", "append", "replace", "merge", "scd2"] +TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg"] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 5816e6b96d..c2d5bb8eb2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -224,7 +224,7 @@ def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> L return [] def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: - now = pendulum.now() + now = self.config.load_timestamp or pendulum.now().to_iso8601_string() return [ SqlMergeJob.from_table_chain( table_chain, diff --git a/tests/load/pipeline/test_scd2_disposition.py b/tests/load/pipeline/test_scd2_disposition.py index afff644f31..38c1841527 100644 --- a/tests/load/pipeline/test_scd2_disposition.py +++ b/tests/load/pipeline/test_scd2_disposition.py @@ -1,6 +1,9 @@ -import pytest, dlt, os +import pytest, dlt, os, pendulum from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import ( + load_tables_to_dicts, +) @pytest.mark.parametrize( @@ -10,7 +13,9 @@ ) def test_simple_scd2_load(destination_config: DestinationTestConfiguration) -> None: # use scd2 + first_load = pendulum.now() os.environ["DESTINATION__MERGE_STRATEGY"] = "scd2" + os.environ["DESTINATION__LOAD_TIMESTAMP"] = first_load.to_iso8601_string() @dlt.resource(name="items", write_disposition="merge") def load_items(): @@ -46,6 +51,7 @@ def load_items(): # new version of item 1 # item 3 deleted + # item 2 has a new child @dlt.resource(name="items", write_disposition="merge") def load_items_2(): yield from [ @@ -65,26 +71,71 @@ def load_items_2(): }, ] + second_load = pendulum.now() + os.environ["DESTINATION__LOAD_TIMESTAMP"] = second_load.to_iso8601_string() + p.run(load_items_2()) - with p.sql_client() as c: - with c.execute_query( - "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'items' ORDER" - " BY ORDINAL_POSITION" - ) as cur: - print(cur.fetchall()) - with c.execute_query("SELECT * FROM items") as cur: - for row in cur.fetchall(): - print(row) - - with p.sql_client() as c: - with c.execute_query( - "SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME =" - " 'items__children' ORDER BY ORDINAL_POSITION" - ) as cur: - print(cur.fetchall()) - with c.execute_query("SELECT * FROM items__children") as cur: - for row in cur.fetchall(): - print(row) - # print(p.default_schema.to_pretty_yaml()) - assert False + tables = load_tables_to_dicts(p, "items", "items__children") + # we should have 4 items in total 
(3 from the first load and an update of item 1 from the second load) + assert len(tables["items"]) == 4 + # 2 should be active (1 and 2) + active_items = [item for item in tables["items"] if item["_dlt_valid_until"] is None] + inactive_items = [item for item in tables["items"] if item["_dlt_valid_until"] is not None] + active_items.sort(key=lambda i: i["id"]) + inactive_items.sort(key=lambda i: i["id"]) + assert len(active_items) == 2 + + # changed in the second load + assert active_items[0]["id"] == 1 + assert active_items[0]["name"] == "one_new" + assert active_items[0]["_dlt_valid_from"] == second_load + + # did not change in the second load + assert active_items[1]["id"] == 2 + assert active_items[1]["name"] == "two" + assert active_items[1]["_dlt_valid_from"] == first_load + + # was valid between first and second load + assert inactive_items[0]["id"] == 1 + assert inactive_items[0]["name"] == "one" + assert inactive_items[0]["_dlt_valid_from"] == first_load + assert inactive_items[0]["_dlt_valid_until"] == second_load + + # was valid between first and second load + assert inactive_items[1]["id"] == 3 + assert inactive_items[1]["name"] == "three" + assert inactive_items[1]["_dlt_valid_from"] == first_load + assert inactive_items[1]["_dlt_valid_until"] == second_load + + # child tables + assert len(tables["items__children"]) == 3 + active_child_items = [ + item for item in tables["items__children"] if item["_dlt_valid_until"] is None + ] + inactive_child_items = [ + item for item in tables["items__children"] if item["_dlt_valid_until"] is not None + ] + active_child_items.sort(key=lambda i: i["id_of"]) + inactive_child_items.sort(key=lambda i: i["id_of"]) + + assert len(active_child_items) == 1 + + # the one active child item should be linked to the right parent, was created during the second load + assert active_child_items[0]["id_of"] == 2 + assert active_child_items[0]["name"] == "child2_new" + assert active_child_items[0]["_dlt_parent_id"] == active_items[1]["_dlt_id"] + assert active_child_items[0]["_dlt_valid_from"] == second_load + + # check inactive child items + assert inactive_child_items[0]["id_of"] == 2 + assert inactive_child_items[0]["name"] == "child2" + assert inactive_child_items[0]["_dlt_parent_id"] == active_items[1]["_dlt_id"] + assert inactive_child_items[0]["_dlt_valid_from"] == first_load + assert inactive_child_items[0]["_dlt_valid_until"] == second_load + + assert inactive_child_items[1]["id_of"] == 3 + assert inactive_child_items[1]["name"] == "child3" + assert inactive_child_items[1]["_dlt_parent_id"] == inactive_items[1]["_dlt_id"] + assert inactive_child_items[1]["_dlt_valid_from"] == first_load + assert inactive_child_items[1]["_dlt_valid_until"] == second_load
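A note on why the scd2 merge job above can match staging and destination rows on `_dlt_id` alone: patch 2 changes the relational normalizer so that `_dlt_id` becomes a deterministic content hash of the row instead of a random id. The following standalone sketch only restates that idea; the helper name `content_row_id` is invented for illustration, and the import paths are assumed to be usable outside the normalizer, mirroring what relational.py imports in this patch.

from dlt.common import json
from dlt.common.utils import digest128

def content_row_id(row: dict) -> str:
    # drop dlt bookkeeping fields so the hash reflects only business data
    cleaned = {k: v for k, v in row.items() if not k.startswith("_dlt")}
    return digest128(json.dumps(cleaned, sort_keys=True))

# an unchanged row keeps the same _dlt_id across loads and is left untouched,
# while any change in content yields a new id, retiring the previous row version
assert content_row_id({"id": 1, "name": "one"}) == content_row_id({"id": 1, "name": "one"})
assert content_row_id({"id": 1, "name": "one"}) != content_row_id({"id": 1, "name": "one_new"})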