Sqlalchemy scd2
steinitzu committed Oct 1, 2024
1 parent cabff08 commit 5c016ae
Showing 6 changed files with 193 additions and 46 deletions.
13 changes: 12 additions & 1 deletion dlt/destinations/impl/sqlalchemy/factory.py
@@ -1,5 +1,6 @@
import typing as t

from dlt.common import pendulum
from dlt.common.destination import Destination, DestinationCapabilitiesContext
from dlt.common.destination.capabilities import DataTypeMapper
from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
@@ -9,6 +10,7 @@
SqlalchemyCredentials,
SqlalchemyClientConfiguration,
)
from dlt.common.data_writers.escape import format_datetime_literal

SqlalchemyTypeMapper: t.Type[DataTypeMapper]

@@ -24,6 +26,13 @@
from sqlalchemy.engine import Engine


def _format_mysql_datetime_literal(
v: pendulum.DateTime, precision: int = 6, no_tz: bool = False
) -> str:
# Format without timezone to prevent tz conversion in SELECT
return format_datetime_literal(v, precision, no_tz=True)
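
A quick sketch (illustrative values) of what this override changes: MySQL's DATETIME carries no timezone, so a literal with a UTC offset can be converted by the server when read back; rendering the literal tz-naive avoids the shift.

```python
from dlt.common import pendulum
from dlt.common.data_writers.escape import format_datetime_literal

ts = pendulum.datetime(2024, 10, 1, 12, 30, tz="UTC")
# the generic formatter keeps the UTC offset in the literal
print(format_datetime_literal(ts, 6))
# the MySQL override forces no_tz=True, dropping the offset
print(format_datetime_literal(ts, 6, no_tz=True))
```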


class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]):
spec = SqlalchemyClientConfiguration

@@ -50,7 +59,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
caps.supports_multiple_statements = False
caps.type_mapper = SqlalchemyTypeMapper
caps.supported_replace_strategies = ["truncate-and-insert", "insert-from-staging"]
caps.supported_merge_strategies = ["delete-insert"]
caps.supported_merge_strategies = ["delete-insert", "scd2"]

return caps

@@ -68,6 +77,8 @@ def adjust_capabilities(
caps.max_identifier_length = dialect.max_identifier_length
caps.max_column_identifier_length = dialect.max_identifier_length
caps.supports_native_boolean = dialect.supports_native_boolean
if dialect.name == "mysql":
caps.format_datetime_literal = _format_mysql_datetime_literal

return caps

15 changes: 10 additions & 5 deletions dlt/destinations/impl/sqlalchemy/load_jobs.py
@@ -3,20 +3,17 @@

import sqlalchemy as sa

from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.destination.reference import (
RunnableLoadJob,
HasFollowupJobs,
PreparedTableSchema,
)
from dlt.common.storages import FileStorage
from dlt.common.json import json, PY_DATETIME_DECODERS
from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams, SqlMergeFollowupJob
from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams

from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
from dlt.destinations.impl.sqlalchemy.merge_job import (
SqlalchemyMergeFollowupJob as SqlalchemyMergeFollowupJob,
)
from dlt.destinations.impl.sqlalchemy.merge_job import SqlalchemyMergeFollowupJob

if TYPE_CHECKING:
from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
@@ -138,3 +135,11 @@ def generate_sql(
statements.append(stmt)

return statements


__all__ = [
"SqlalchemyJsonLInsertJob",
"SqlalchemyParquetInsertJob",
"SqlalchemyStagingCopyJob",
"SqlalchemyMergeFollowupJob",
]
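
The explicit `__all__` preserves the module's public surface now that the merge job lives in `merge_job.py`: the `import ... as ...` re-export above is dropped, but `from dlt.destinations.impl.sqlalchemy.load_jobs import SqlalchemyMergeFollowupJob` keeps working.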
155 changes: 134 additions & 21 deletions dlt/destinations/impl/sqlalchemy/merge_job.py
@@ -1,30 +1,34 @@
from typing import Sequence, Tuple, Optional, List
from typing import Sequence, Tuple, Optional, List, Union
import operator

import sqlalchemy as sa

from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlJobParams
from dlt.common.destination.reference import PreparedTableSchema
from dlt.destinations.sql_jobs import SqlMergeFollowupJob
from dlt.common.destination.reference import PreparedTableSchema, DestinationCapabilitiesContext
from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
from dlt.common.schema.utils import (
get_columns_names_with_prop,
get_dedup_sort_tuple,
get_first_column_name_with_prop,
is_nested_table,
get_validity_column_names,
get_active_record_timestamp,
)
from dlt.common.time import ensure_pendulum_datetime
from dlt.common.storages.load_package import load_package as current_load_package


class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob):
"""Uses SQLAlchemy to generate merge SQL statements.
Result is equivalent to the SQL generated by `SqlMergeFollowupJob`
except we use concrete tables instead of temporary tables.
except that for delete-insert we use concrete tables instead of temporary tables.
"""

@classmethod
def generate_sql(
def gen_merge_sql(
cls,
table_chain: Sequence[PreparedTableSchema],
sql_client: SqlalchemyClient, # type: ignore[override]
params: Optional[SqlJobParams] = None,
) -> List[str]:
root_table = table_chain[0]

@@ -123,8 +127,8 @@ def generate_sql(
)

staging_row_key_col = staging_root_table_obj.c[row_key_col_name]
# Create the insert "temporary" table (but use a concrete table)

# Create the insert "temporary" table (but use a concrete table)
insert_temp_table = sa.Table(
"insert_" + root_table_obj.name,
temp_metadata,
@@ -149,8 +153,6 @@ def generate_sql(
order_dir_func = sa.asc
if condition_columns:
inner_cols += condition_columns
# inner_cols = condition_columns if condition_columns else [staging_row_key_col]
# breakpoint()

inner_select = sa.select(
sa.func.row_number()
@@ -159,13 +161,12 @@ def generate_sql(
order_by=order_dir_func(order_by_col),
)
.label("_dlt_dedup_rn"),
*inner_cols
*inner_cols,
).subquery()

select_for_temp_insert = sa.select(
# *[c for c in inner_select.c if c.name != "_dlt_dedup_rn"]
inner_select.c[row_key_col_name]
).where(inner_select.c._dlt_dedup_rn == 1)
select_for_temp_insert = sa.select(inner_select.c[row_key_col_name]).where(
inner_select.c._dlt_dedup_rn == 1
)
hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
root_table,
inner_select,
@@ -193,7 +194,6 @@ def generate_sql(
staging_table_obj = table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)
# insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true()
select_sql = staging_table_obj.select()

if (primary_key_names and len(table_chain) > 1) or (
@@ -248,12 +248,6 @@ def generate_sql(

if hard_delete_col_name is not None:
select_sql = select_sql.where(not_delete_cond)
# hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
# root_table, inner_select, invert=True
# )
# insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true()
# if insert_cond is not None:
# select_sql = select_sql.where(insert_cond)

insert_statement = table_obj.insert().from_select(
[col.name for col in table_obj.columns], select_sql
@@ -326,3 +320,122 @@ def _generate_key_table_clauses(
return sa.or_(*clauses) # type: ignore[no-any-return]
else:
return sa.true() # type: ignore[no-any-return]

@classmethod
def _gen_concat_sqla(
cls, columns: Sequence[sa.Column]
) -> Union[sa.sql.elements.BinaryExpression, sa.Column]:
# Use col1 + col2 + col3 ... to generate a dialect specific concat expression
result = columns[0]
if len(columns) == 1:
return result
# Cast because CONCAT is only generated for string columns
result = sa.cast(result, sa.String)
for col in columns[1:]:
result = operator.add(result, sa.cast(col, sa.String))
return result
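
A standalone sketch (not part of the commit) of why building concatenation with `+` over String-cast columns is dialect-portable: SQLAlchemy renders the string concatenation operator as `||` on most backends and as `CONCAT(...)` on MySQL, which is what the merge-key matching below relies on.

```python
import sqlalchemy as sa
from sqlalchemy.dialects import mysql, postgresql

a = sa.column("a", sa.Integer)
b = sa.column("b", sa.Integer)
expr = sa.cast(a, sa.String) + sa.cast(b, sa.String)

# roughly: CAST(a AS VARCHAR) || CAST(b AS VARCHAR)
print(expr.compile(dialect=postgresql.dialect()))
# roughly: concat(CAST(a AS CHAR), CAST(b AS CHAR))
print(expr.compile(dialect=mysql.dialect()))
```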

@classmethod
def gen_scd2_sql(
cls,
table_chain: Sequence[PreparedTableSchema],
sql_client: SqlalchemyClient, # type: ignore[override]
) -> List[str]:
sqla_statements = []
root_table = table_chain[0]
root_table_obj = sql_client.get_existing_table(root_table["name"])
staging_root_table_obj = root_table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)

from_, to = get_validity_column_names(root_table)
hash_ = get_first_column_name_with_prop(root_table, "x-row-version")

caps = sql_client.capabilities

format_datetime_literal = caps.format_datetime_literal
if format_datetime_literal is None:
format_datetime_literal = (
DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal
)

boundary_ts = ensure_pendulum_datetime(
root_table.get("x-boundary-timestamp", current_load_package()["state"]["created_at"]) # type: ignore[arg-type]
)

boundary_literal = format_datetime_literal(boundary_ts, caps.timestamp_precision)

active_record_timestamp = get_active_record_timestamp(root_table)

update_statement = (
root_table_obj.update()
.values({to: sa.text(boundary_literal)})
.where(root_table_obj.c[hash_].notin_(sa.select([staging_root_table_obj.c[hash_]])))
)

if active_record_timestamp is None:
active_record_literal = None
root_is_active_clause = root_table_obj.c[to].is_(None)
else:
active_record_literal = format_datetime_literal(
active_record_timestamp, caps.timestamp_precision
)
root_is_active_clause = root_table_obj.c[to] == sa.text(active_record_literal)

update_statement = update_statement.where(root_is_active_clause)

merge_keys = get_columns_names_with_prop(root_table, "merge_key")
if merge_keys:
root_merge_key_cols = [root_table_obj.c[key] for key in merge_keys]
staging_merge_key_cols = [staging_root_table_obj.c[key] for key in merge_keys]

update_statement = update_statement.where(
cls._gen_concat_sqla(root_merge_key_cols).in_(
sa.select(cls._gen_concat_sqla(staging_merge_key_cols))
)
)

sqla_statements.append(update_statement)

insert_statement = root_table_obj.insert().from_select(
[col.name for col in root_table_obj.columns],
sa.select(
sa.literal(boundary_literal.strip("'")).label(from_),
sa.literal(
active_record_literal.strip("'") if active_record_literal is not None else None
).label(to),
*[c for c in staging_root_table_obj.columns if c.name not in [from_, to]],
).where(
staging_root_table_obj.c[hash_].notin_(
sa.select([root_table_obj.c[hash_]]).where(root_is_active_clause)
)
),
)
sqla_statements.append(insert_statement)
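
A worked illustration (hypothetical rows, dlt's default validity column names) of what this UPDATE/INSERT pair does to a changed record:

```python
# before the load: one active version of id=1 (open validity window)
existing = {"id": 1, "name": "alice",
            "_dlt_valid_from": "2024-09-01T00:00:00", "_dlt_valid_to": None}
# staging holds a changed row for id=1, so its x-row-version hash differs
staged = {"id": 1, "name": "alicia"}

boundary = "2024-10-01T12:00:00"  # load package creation time by default
# 1. update_statement closes the old version at the boundary timestamp
existing["_dlt_valid_to"] = boundary
# 2. insert_statement opens the new version at the same boundary,
#    keeping the validity windows contiguous
new_row = {**staged, "_dlt_valid_from": boundary, "_dlt_valid_to": None}
# rows whose hash also appears in staging are unchanged and left untouched
```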

nested_tables = table_chain[1:]
for table in nested_tables:
row_key_column = cls._get_root_key_col(table_chain, sql_client, table)

table_obj = sql_client.get_existing_table(table["name"])
staging_table_obj = table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)

insert_statement = table_obj.insert().from_select(
[col.name for col in table_obj.columns],
staging_table_obj.select().where(
staging_table_obj.c[row_key_column].notin_(
sa.select(table_obj.c[row_key_column])
)
),
)
sqla_statements.append(insert_statement)

return [
x + ";" if not x.endswith(";") else x
for x in (
str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True}))
for stmt in sqla_statements
)
]
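
A minimal sketch (in-memory SQLite, assumed table) of this final compilation step: `literal_binds` inlines parameter values so the follow-up job can return plain, self-contained SQL strings.

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
t = sa.table("t", sa.column("id"), sa.column("name"))
stmt = t.update().values(name="x").where(t.c.id == 1)

# bind parameters are rendered as literals, e.g.:
# UPDATE t SET name='x' WHERE t.id = 1
print(str(stmt.compile(engine, compile_kwargs={"literal_binds": True})))
```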
6 changes: 5 additions & 1 deletion dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -18,7 +18,11 @@
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables
from dlt.common.schema.typing import TColumnType, TTableSchemaColumns
from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers, is_complete_column
from dlt.common.schema.utils import (
pipeline_state_table,
normalize_table_identifiers,
is_complete_column,
)
from dlt.destinations.exceptions import DatabaseUndefinedRelation
from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration
3 changes: 1 addition & 2 deletions docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md
@@ -135,8 +135,7 @@ The following write dispositions are supported:

- `append`
- `replace` with `truncate-and-insert` and `insert-from-staging` replace strategies. `staging-optimized` falls back to `insert-from-staging`.

The `merge` disposition is not supported and falls back to `append`.
- `merge` with `delete-insert` and `scd2` merge strategies (see the sketch below).
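
A minimal usage sketch (pipeline name, connection string, and columns are illustrative) of selecting the new `scd2` strategy:

```py
import dlt

@dlt.resource(write_disposition={"disposition": "merge", "strategy": "scd2"})
def customers():
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]

pipeline = dlt.pipeline(
    pipeline_name="scd2_demo",
    destination=dlt.destinations.sqlalchemy("sqlite:///demo.db"),
    dataset_name="main",
)
pipeline.run(customers())
```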

## Data loading
