From 5c016aea72efefa8df6fed2387f9ca2d7d477dd3 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Sun, 29 Sep 2024 18:35:59 -0400 Subject: [PATCH] Sqlalchemy scd2 --- dlt/destinations/impl/sqlalchemy/factory.py | 13 +- dlt/destinations/impl/sqlalchemy/load_jobs.py | 15 +- dlt/destinations/impl/sqlalchemy/merge_job.py | 155 +++++++++++++++--- .../impl/sqlalchemy/sqlalchemy_job_client.py | 6 +- .../dlt-ecosystem/destinations/sqlalchemy.md | 3 +- tests/load/pipeline/test_scd2.py | 47 ++++-- 6 files changed, 193 insertions(+), 46 deletions(-) diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py index 9235c4f143..bf05c42f08 100644 --- a/dlt/destinations/impl/sqlalchemy/factory.py +++ b/dlt/destinations/impl/sqlalchemy/factory.py @@ -1,5 +1,6 @@ import typing as t +from dlt.common import pendulum from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.destination.capabilities import DataTypeMapper from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE @@ -9,6 +10,7 @@ SqlalchemyCredentials, SqlalchemyClientConfiguration, ) +from dlt.common.data_writers.escape import format_datetime_literal SqlalchemyTypeMapper: t.Type[DataTypeMapper] @@ -24,6 +26,13 @@ from sqlalchemy.engine import Engine +def _format_mysql_datetime_literal( + v: pendulum.DateTime, precision: int = 6, no_tz: bool = False +) -> str: + # Format without timezone to prevent tz conversion in SELECT + return format_datetime_literal(v, precision, no_tz=True) + + class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]): spec = SqlalchemyClientConfiguration @@ -50,7 +59,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supports_multiple_statements = False caps.type_mapper = SqlalchemyTypeMapper caps.supported_replace_strategies = ["truncate-and-insert", "insert-from-staging"] - caps.supported_merge_strategies = ["delete-insert"] + caps.supported_merge_strategies = ["delete-insert", "scd2"] return caps @@ -68,6 +77,8 @@ def adjust_capabilities( caps.max_identifier_length = dialect.max_identifier_length caps.max_column_identifier_length = dialect.max_identifier_length caps.supports_native_boolean = dialect.supports_native_boolean + if dialect.name == "mysql": + caps.format_datetime_literal = _format_mysql_datetime_literal return caps diff --git a/dlt/destinations/impl/sqlalchemy/load_jobs.py b/dlt/destinations/impl/sqlalchemy/load_jobs.py index d3f0afaefe..3cfd6bd910 100644 --- a/dlt/destinations/impl/sqlalchemy/load_jobs.py +++ b/dlt/destinations/impl/sqlalchemy/load_jobs.py @@ -3,7 +3,6 @@ import sqlalchemy as sa -from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.destination.reference import ( RunnableLoadJob, HasFollowupJobs, @@ -11,12 +10,10 @@ ) from dlt.common.storages import FileStorage from dlt.common.json import json, PY_DATETIME_DECODERS -from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams, SqlMergeFollowupJob +from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient -from dlt.destinations.impl.sqlalchemy.merge_job import ( - SqlalchemyMergeFollowupJob as SqlalchemyMergeFollowupJob, -) +from dlt.destinations.impl.sqlalchemy.merge_job import SqlalchemyMergeFollowupJob if TYPE_CHECKING: from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient @@ -138,3 +135,11 @@ def generate_sql( statements.append(stmt) return statements + + +__all__ = [ + "SqlalchemyJsonLInsertJob", + "SqlalchemyParquetInsertJob", + "SqlalchemyStagingCopyJob", + "SqlalchemyMergeFollowupJob", +] diff --git a/dlt/destinations/impl/sqlalchemy/merge_job.py b/dlt/destinations/impl/sqlalchemy/merge_job.py index f378920104..8fae053bcd 100644 --- a/dlt/destinations/impl/sqlalchemy/merge_job.py +++ b/dlt/destinations/impl/sqlalchemy/merge_job.py @@ -1,30 +1,34 @@ -from typing import Sequence, Tuple, Optional, List +from typing import Sequence, Tuple, Optional, List, Union +import operator import sqlalchemy as sa -from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlJobParams -from dlt.common.destination.reference import PreparedTableSchema +from dlt.destinations.sql_jobs import SqlMergeFollowupJob +from dlt.common.destination.reference import PreparedTableSchema, DestinationCapabilitiesContext from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient from dlt.common.schema.utils import ( get_columns_names_with_prop, get_dedup_sort_tuple, get_first_column_name_with_prop, is_nested_table, + get_validity_column_names, + get_active_record_timestamp, ) +from dlt.common.time import ensure_pendulum_datetime +from dlt.common.storages.load_package import load_package as current_load_package class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob): """Uses SQLAlchemy to generate merge SQL statements. Result is equivalent to the SQL generated by `SqlMergeFollowupJob` - except we use concrete tables instead of temporary tables. + except for delete-insert we use concrete tables instead of temporary tables. """ @classmethod - def generate_sql( + def gen_merge_sql( cls, table_chain: Sequence[PreparedTableSchema], sql_client: SqlalchemyClient, # type: ignore[override] - params: Optional[SqlJobParams] = None, ) -> List[str]: root_table = table_chain[0] @@ -123,8 +127,8 @@ def generate_sql( ) staging_row_key_col = staging_root_table_obj.c[row_key_col_name] - # Create the insert "temporary" table (but use a concrete table) + # Create the insert "temporary" table (but use a concrete table) insert_temp_table = sa.Table( "insert_" + root_table_obj.name, temp_metadata, @@ -149,8 +153,6 @@ def generate_sql( order_dir_func = sa.asc if condition_columns: inner_cols += condition_columns - # inner_cols = condition_columns if condition_columns else [staging_row_key_col] - # breakpoint() inner_select = sa.select( sa.func.row_number() @@ -159,13 +161,12 @@ def generate_sql( order_by=order_dir_func(order_by_col), ) .label("_dlt_dedup_rn"), - *inner_cols + *inner_cols, ).subquery() - select_for_temp_insert = sa.select( - # *[c for c in inner_select.c if c.name != "_dlt_dedup_rn"] - inner_select.c[row_key_col_name] - ).where(inner_select.c._dlt_dedup_rn == 1) + select_for_temp_insert = sa.select(inner_select.c[row_key_col_name]).where( + inner_select.c._dlt_dedup_rn == 1 + ) hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( root_table, inner_select, @@ -193,7 +194,6 @@ def generate_sql( staging_table_obj = table_obj.to_metadata( sql_client.metadata, schema=sql_client.staging_dataset_name ) - # insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true() select_sql = staging_table_obj.select() if (primary_key_names and len(table_chain) > 1) or ( @@ -248,12 +248,6 @@ def generate_sql( if hard_delete_col_name is not None: select_sql = select_sql.where(not_delete_cond) - # hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond( - # root_table, inner_select, invert=True - # ) - # insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true() - # if insert_cond is not None: - # select_sql = select_sql.where(insert_cond) insert_statement = table_obj.insert().from_select( [col.name for col in table_obj.columns], select_sql @@ -326,3 +320,122 @@ def _generate_key_table_clauses( return sa.or_(*clauses) # type: ignore[no-any-return] else: return sa.true() # type: ignore[no-any-return] + + @classmethod + def _gen_concat_sqla( + cls, columns: Sequence[sa.Column] + ) -> Union[sa.sql.elements.BinaryExpression, sa.Column]: + # Use col1 + col2 + col3 ... to generate a dialect specific concat expression + result = columns[0] + if len(columns) == 1: + return result + # Cast because CONCAT is only generated for string columns + result = sa.cast(result, sa.String) + for col in columns[1:]: + result = operator.add(result, sa.cast(col, sa.String)) + return result + + @classmethod + def gen_scd2_sql( + cls, + table_chain: Sequence[PreparedTableSchema], + sql_client: SqlalchemyClient, # type: ignore[override] + ) -> List[str]: + sqla_statements = [] + root_table = table_chain[0] + root_table_obj = sql_client.get_existing_table(root_table["name"]) + staging_root_table_obj = root_table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + + from_, to = get_validity_column_names(root_table) + hash_ = get_first_column_name_with_prop(root_table, "x-row-version") + + caps = sql_client.capabilities + + format_datetime_literal = caps.format_datetime_literal + if format_datetime_literal is None: + format_datetime_literal = ( + DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal + ) + + boundary_ts = ensure_pendulum_datetime( + root_table.get("x-boundary-timestamp", current_load_package()["state"]["created_at"]) # type: ignore[arg-type] + ) + + boundary_literal = format_datetime_literal(boundary_ts, caps.timestamp_precision) + + active_record_timestamp = get_active_record_timestamp(root_table) + + update_statement = ( + root_table_obj.update() + .values({to: sa.text(boundary_literal)}) + .where(root_table_obj.c[hash_].notin_(sa.select([staging_root_table_obj.c[hash_]]))) + ) + + if active_record_timestamp is None: + active_record_literal = None + root_is_active_clause = root_table_obj.c[to].is_(None) + else: + active_record_literal = format_datetime_literal( + active_record_timestamp, caps.timestamp_precision + ) + root_is_active_clause = root_table_obj.c[to] == sa.text(active_record_literal) + + update_statement = update_statement.where(root_is_active_clause) + + merge_keys = get_columns_names_with_prop(root_table, "merge_key") + if merge_keys: + root_merge_key_cols = [root_table_obj.c[key] for key in merge_keys] + staging_merge_key_cols = [staging_root_table_obj.c[key] for key in merge_keys] + + update_statement = update_statement.where( + cls._gen_concat_sqla(root_merge_key_cols).in_( + sa.select(cls._gen_concat_sqla(staging_merge_key_cols)) + ) + ) + + sqla_statements.append(update_statement) + + insert_statement = root_table_obj.insert().from_select( + [col.name for col in root_table_obj.columns], + sa.select( + sa.literal(boundary_literal.strip("'")).label(from_), + sa.literal( + active_record_literal.strip("'") if active_record_literal is not None else None + ).label(to), + *[c for c in staging_root_table_obj.columns if c.name not in [from_, to]], + ).where( + staging_root_table_obj.c[hash_].notin_( + sa.select([root_table_obj.c[hash_]]).where(root_is_active_clause) + ) + ), + ) + sqla_statements.append(insert_statement) + + nested_tables = table_chain[1:] + for table in nested_tables: + row_key_column = cls._get_root_key_col(table_chain, sql_client, table) + + table_obj = sql_client.get_existing_table(table["name"]) + staging_table_obj = table_obj.to_metadata( + sql_client.metadata, schema=sql_client.staging_dataset_name + ) + + insert_statement = table_obj.insert().from_select( + [col.name for col in table_obj.columns], + staging_table_obj.select().where( + staging_table_obj.c[row_key_column].notin_( + sa.select(table_obj.c[row_key_column]) + ) + ), + ) + sqla_statements.append(insert_statement) + + return [ + x + ";" if not x.endswith(";") else x + for x in ( + str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True})) + for stmt in sqla_statements + ) + ] diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py index 5819b3e4b5..c5a6442d8a 100644 --- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py +++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py @@ -18,7 +18,11 @@ from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables from dlt.common.schema.typing import TColumnType, TTableSchemaColumns -from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers, is_complete_column +from dlt.common.schema.utils import ( + pipeline_state_table, + normalize_table_identifiers, + is_complete_column, +) from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration diff --git a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md index b9014e0564..9f33c02337 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md +++ b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md @@ -135,8 +135,7 @@ The following write dispositions are supported: - `append` - `replace` with `truncate-and-insert` and `insert-from-staging` replace strategies. `staging-optimized` falls back to `insert-from-staging`. - -The `merge` disposition is not supported and falls back to `append`. +- `merge` with `delete-insert` and `scd2` merge strategies. ## Data loading diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index 3e08b792ed..2a5b9ed296 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -52,13 +52,22 @@ def strip_timezone(ts: TAnyDateTime) -> pendulum.DateTime: def get_table( - pipeline: dlt.Pipeline, table_name: str, sort_column: str = None, include_root_id: bool = True + pipeline: dlt.Pipeline, + table_name: str, + sort_column: str = None, + include_root_id: bool = True, + ts_columns: Optional[List[str]] = None, ) -> List[Dict[str, Any]]: """Returns destination table contents as list of dictionaries.""" + ts_columns = ts_columns or [] table = [ { - k: strip_timezone(v) if isinstance(v, datetime) else v + k: ( + strip_timezone(v) + if isinstance(v, datetime) or (k in ts_columns and v is not None) + else v + ) for k, v in r.items() if not k.startswith("_dlt") or k in DEFAULT_VALIDITY_COLUMN_NAMES @@ -128,7 +137,7 @@ def r(data): # assert load results ts_1 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ { from_: ts_1, to: None, @@ -153,7 +162,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ { from_: ts_1, to: None, @@ -178,7 +187,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_3 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"}, {from_: ts_1, to: ts_2, "nk": 1, "c1": "foo", "c2__nc1": "foo"}, { @@ -198,7 +207,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_4 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c2__nc1") == [ + assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [ {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"}, { from_: ts_4, @@ -242,7 +251,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_1 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c1") == [ + assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: None, "nk": 1, "c1": "foo"}, ] @@ -261,7 +270,7 @@ def r(data): info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) - assert get_table(p, "dim_test", "c1") == [ + assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, # updated {FROM: ts_2, TO: None, "nk": 1, "c1": "foo_updated"}, # new @@ -289,7 +298,7 @@ def r(data): ts_3 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"}, {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, @@ -315,7 +324,7 @@ def r(data): ts_4 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"}, # updated {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"}, @@ -336,7 +345,7 @@ def r(data): ts_5 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( - get_table(p, "dim_test"), + get_table(p, "dim_test", ts_columns=[FROM, TO]), [ {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"}, {FROM: ts_5, TO: None, "nk": 3, "c1": "baz"}, # new @@ -502,7 +511,7 @@ def r(data): {**{FROM: ts_3, TO: None}, **r1_no_child}, {**{FROM: ts_1, TO: None}, **r2_no_child}, ] - assert_records_as_set(get_table(p, "dim_test"), expected) + assert_records_as_set(get_table(p, "dim_test", ts_columns=[FROM, TO]), expected) # assert child records expected = [ @@ -739,7 +748,10 @@ def dim_test(data): assert load_table_counts(p, "dim_test")["dim_test"] == 3 ts3 = get_load_package_created_at(p, info) # natural key 1 should now have two records (one retired, one active) - actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")] + actual = [ + {k: v for k, v in row.items() if k in ("nk", TO)} + for row in get_table(p, "dim_test", ts_columns=[FROM, TO]) + ] expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: None}, {"nk": 2, TO: None}] assert_records_as_set(actual, expected) # type: ignore[arg-type] @@ -753,7 +765,10 @@ def dim_test(data): assert load_table_counts(p, "dim_test")["dim_test"] == 4 ts4 = get_load_package_created_at(p, info) # natural key 1 should now have three records (two retired, one active) - actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")] + actual = [ + {k: v for k, v in row.items() if k in ("nk", TO)} + for row in get_table(p, "dim_test", ts_columns=[FROM, TO]) + ] expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: ts4}, {"nk": 1, TO: None}, {"nk": 2, TO: None}] assert_records_as_set(actual, expected) # type: ignore[arg-type] @@ -805,7 +820,7 @@ def dim_test_compound(data): # "Doe" should now have two records (one retired, one active) actual = [ {k: v for k, v in row.items() if k in ("first_name", "last_name", TO)} - for row in get_table(p, "dim_test_compound") + for row in get_table(p, "dim_test_compound", ts_columns=[FROM, TO]) ] expected = [ {"first_name": first_name, "last_name": "Doe", TO: ts3}, @@ -869,7 +884,7 @@ def dim_test(data): ts2 = get_load_package_created_at(p, info) actual = [ {k: v for k, v in row.items() if k in ("date", "name", TO)} - for row in get_table(p, "dim_test") + for row in get_table(p, "dim_test", ts_columns=[TO]) ] expected = [ {"date": "2024-01-01", "name": "a", TO: None},