Merge pull request #240 from dlt-hub/rfix/adds-stage-dedup

adds deduplication of staging dataset during merge

rudolfix authored Apr 9, 2023
2 parents d1076bd + d0a4c69 commit 1e7a8e6
Showing 33 changed files with 33,417 additions and 2,119 deletions.
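The core of the change is in dlt/destinations/sql_merge_job.py: when a merged table defines primary keys, rows in the staging dataset are now deduplicated with a ROW_NUMBER() window function before being inserted into the destination tables. A minimal sketch of the emitted SQL pattern, with hypothetical table and column names:

    # illustrative only: the deduplication pattern the merge job now emits,
    # shown for a hypothetical staging table staging.events with primary key id
    dedup_sql = """
    WITH _dlt_dedup_numbered AS (
        SELECT ROW_NUMBER() OVER (PARTITION BY id ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, *
        FROM staging.events
    )
    SELECT * FROM _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1
    """

ORDER BY (SELECT NULL) keeps the surviving row per key arbitrary while satisfying engines that require an ORDER BY inside ROW_NUMBER().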
16 changes: 11 additions & 5 deletions dlt/common/configuration/inject.py
@@ -30,7 +30,7 @@ def with_config(
     sections: Tuple[str, ...] = (),
     sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
     auto_pipeline_section: bool = False,
-    only_kw: bool = False
+    include_defaults: bool = True
 ) -> TFun:
     ...

@@ -43,7 +43,7 @@ def with_config(
     sections: Tuple[str, ...] = (),
     sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
     auto_pipeline_section: bool = False,
-    only_kw: bool = False
+    include_defaults: bool = True
 ) -> Callable[[TFun], TFun]:
     ...

@@ -55,17 +55,20 @@ def with_config(
     sections: Tuple[str, ...] = (),
     sections_merge_style: ConfigSectionContext.TMergeFunc = ConfigSectionContext.prefer_incoming,
     auto_pipeline_section: bool = False,
-    only_kw: bool = False
+    include_defaults: bool = True
 ) -> Callable[[TFun], TFun]:
     """Injects values into decorated function arguments following the specification in `spec` or by deriving one from the function's signature.
+    The synthesized spec contains the arguments marked with `dlt.secrets.value` and `dlt.config.value`, which are required to be injected at runtime.
+    Optionally (and by default) arguments with default values are included in the spec as well.
 
     Args:
         func (Optional[AnyFun], optional): A function with arguments to be injected. Defaults to None.
         spec (Type[BaseConfiguration], optional): A specification of injectable arguments. Defaults to None.
         sections (Tuple[str, ...], optional): A set of config sections in which to look for argument values. Defaults to ().
         prefer_existing_sections (bool, optional): When joining an existing section context, the existing context will be preferred to the one in `sections`. Defaults to False.
         auto_pipeline_section (bool, optional): If True, a top-level pipeline section will be added if a `pipeline_name` argument is present. Defaults to False.
-        only_kw (bool, optional): If True and `spec` is not provided, one is synthesized from keyword-only arguments, ignoring any others. Defaults to False.
+        include_defaults (bool, optional): If True, arguments with default values are included in the synthesized spec. If False, only the required arguments marked with `dlt.secrets.value` and `dlt.config.value` are included.
 
     Returns:
         Callable[[TFun], TFun]: A decorated function
@@ -84,10 +87,13 @@ def decorator(f: TFun) -> TFun:
         section_context = ConfigSectionContext(sections=sections, merge_style=sections_merge_style)
 
         if spec is None:
-            SPEC = spec_from_signature(f, sig, only_kw)
+            SPEC = spec_from_signature(f, sig, include_defaults)
         else:
             SPEC = spec
 
+        if SPEC is None:
+            return f
+
         for p in sig.parameters.values():
             # for all positional parameters that do not have default value, set default
             # if hasattr(SPEC, p.name) and p.default == Parameter.empty:
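A hedged sketch of what the new flag means at the call site. @with_config and the dlt.config.value / dlt.secrets.value markers are dlt's public API; the decorated function below is hypothetical:

    import dlt
    from dlt.common.configuration.inject import with_config

    # with include_defaults=False the synthesized spec keeps only the arguments
    # marked with dlt.config.value / dlt.secrets.value, so `retries` below does
    # not become a configurable field; and when no argument carries a marker at
    # all, spec_from_signature returns None and the function is returned undecorated
    @with_config(include_defaults=False)
    def read_table(table: str = dlt.config.value, api_key: str = dlt.secrets.value, retries: int = 3):
        ...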
34 changes: 21 additions & 13 deletions dlt/common/configuration/resolve.py
@@ -54,6 +54,22 @@ def inject_section(section_context: ConfigSectionContext, merge_existing: bool =
 
     return container.injectable_context(section_context)
 
 
+def _maybe_parse_native_value(config: TConfiguration, explicit_value: Any, embedded_sections: Tuple[str, ...]) -> Any:
+    # use initial value to resolve the whole configuration. if explicit value is a mapping it will be applied field by field later
+    if explicit_value and (not isinstance(explicit_value, C_Mapping) or isinstance(explicit_value, BaseConfiguration)):
+        # print(f"TRYING TO PARSE NATIVE from {explicit_value}")
+        try:
+            config.parse_native_representation(explicit_value)
+            # print("----ok")
+        except ValueError as v_err:
+            # provide generic exception
+            raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err)
+        except NotImplementedError:
+            pass
+        # explicit value was consumed
+        explicit_value = None
+    return explicit_value
+
 
 def _resolve_configuration(
     config: TConfiguration,
@@ -69,18 +85,7 @@
     config.__exception__ = None
     try:
         try:
-            # use initial value to resolve the whole configuration. if explicit value is a mapping it will be applied field by field later
-            if explicit_value and not isinstance(explicit_value, C_Mapping):
-                try:
-                    config.parse_native_representation(explicit_value)
-                except ValueError as v_err:
-                    # provide generic exception
-                    raise InvalidNativeValue(type(config), type(explicit_value), embedded_sections, v_err)
-                except NotImplementedError:
-                    pass
-                # explicit value was consumed
-                explicit_value = None
-
+            explicit_value = _maybe_parse_native_value(config, explicit_value, embedded_sections)
             # if native representation didn't fully resolve the config, we try to resolve field by field
             if not config.is_resolved():
                 _resolve_config_fields(config, explicit_value, explicit_sections, embedded_sections, accept_partial)
@@ -142,7 +147,6 @@ def _resolve_config_fields(
                 # return first resolved config from a union
                 try:
                     current_value, traces = _resolve_config_field(key, alt_spec, default_value, explicit_value, config, config.__section__, explicit_sections, embedded_sections, accept_partial)
-                    print(current_value)
                     break
                 except ConfigFieldMissingException as cfm_ex:
                     # add traces from unresolved union spec
@@ -199,6 +203,7 @@ def _resolve_config_field(
         if isinstance(value, BaseConfiguration):
             # if resolved value is instance of configuration (typically returned by context provider)
             embedded_config = value
+            default_value = None
             value = None
         elif isinstance(default_value, BaseConfiguration):
             # if default value was instance of configuration, use it
@@ -208,7 +213,10 @@
             embedded_config = inner_hint()
 
         if embedded_config.is_resolved():
+            # print(f"{embedded_config} IS RESOLVED with VALUE {value}")
             # injected context will be resolved
+            if value is not None:
+                _maybe_parse_native_value(embedded_config, value, embedded_sections + (key,))
             value = embedded_config
         else:
             # only config with sections may look for initial values
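For context, a minimal sketch of a spec that supports a native representation; the class and its fields are hypothetical, but parse_native_representation is the hook _maybe_parse_native_value calls: a ValueError surfaces as InvalidNativeValue, while NotImplementedError means "no native form" and resolution falls through to field-by-field:

    from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec

    @configspec
    class ServerConfiguration(BaseConfiguration):
        host: str = None
        port: int = 443

        def parse_native_representation(self, native_value) -> None:
            # accept "host:port" strings as the native form
            if not isinstance(native_value, str):
                raise ValueError(native_value)
            host, _, port = native_value.partition(":")
            self.host = host
            if port:
                self.port = int(port)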
2 changes: 1 addition & 1 deletion dlt/common/configuration/specs/base_configuration.py
@@ -110,7 +110,7 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]:
     # get all attributes without corresponding annotations
     for att_name, att_value in cls.__dict__.items():
         # skip callables, dunder names, class variables and some special names
-        if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")) and not isinstance(att_value, (staticmethod, classmethod)):
+        if not callable(att_value) and not att_name.startswith(("__", "_abc_impl")) and not isinstance(att_value, (staticmethod, classmethod, property)):
             if att_name not in cls.__annotations__:
                 raise ConfigFieldMissingTypeHintException(att_name, cls)
             hint = cls.__annotations__[att_name]
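A small illustration of what the added property exclusion enables; the spec below is hypothetical:

    from dlt.common.configuration.specs.base_configuration import BaseConfiguration, configspec

    @configspec
    class TokenConfiguration(BaseConfiguration):
        token: str = None

        @property
        def is_set(self) -> bool:
            # properties are now skipped like staticmethod/classmethod, so this
            # attribute no longer raises ConfigFieldMissingTypeHintException
            return self.token is not None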
2 changes: 1 addition & 1 deletion dlt/common/destination/reference.py
@@ -174,7 +174,7 @@ def _verify_schema(self) -> None:
         * Removes and warns on (unbound) incomplete columns
         """
 
-        for table in self.schema.all_tables():
+        for table in self.schema.data_tables():
             table_name = table["name"]
             if len(table_name) > self.capabilities.max_identifier_length:
                 raise IdentifierTooLongException(self.config.destination_name, "table", table_name, self.capabilities.max_identifier_length)
19 changes: 11 additions & 8 deletions dlt/common/reflection/spec.py
@@ -23,7 +23,7 @@ def _first_up(s: str) -> str:
     return "".join(map(_first_up, _SLEEPING_CAT_SPLIT.findall(func_name))) + "Configuration"
 
 
-def spec_from_signature(f: AnyFun, sig: Signature, kw_only: bool = False) -> Type[BaseConfiguration]:
+def spec_from_signature(f: AnyFun, sig: Signature, include_defaults: bool = True) -> Type[BaseConfiguration]:
     name = _get_spec_name_from_f(f)
     module = inspect.getmodule(f)
@@ -59,16 +59,16 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType:
     annotations: Dict[str, Any] = {}
 
     for p in sig.parameters.values():
-        # skip *args and **kwargs, skip typical method params and if kw_only flag is set: accept KEYWORD ONLY args
-        if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"] and \
-                (kw_only and p.kind == Parameter.KEYWORD_ONLY or not kw_only):
+        # skip *args and **kwargs, skip typical method params
+        if p.kind not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL) and p.name not in ["self", "cls"]:
             field_type = AnyType if p.annotation == Parameter.empty else p.annotation
             # only valid hints and parameters with defaults are eligible
             if is_valid_hint(field_type) and p.default != Parameter.empty:
                 # try to get type from default
                 if field_type is AnyType and p.default is not None:
                     field_type = type(p.default)
                 # make type optional if explicit None is provided as default
+                type_from_literal: AnyType = None
                 if p.default is None:
                     # check if the defaults were attributes of the form .config.value or .secrets.value
                     type_from_literal = dlt_config_literal_to_type(p.name)
@@ -88,11 +88,14 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType:
                     # keep type mandatory if config.value
                     # print(f"Param {p.name} is REQUIRED: config literal")
                     pass
-                # set annotations
-                annotations[p.name] = field_type
-                # set field with default value
-                fields[p.name] = p.default
+                if include_defaults or type_from_literal is not None:
+                    # set annotations
+                    annotations[p.name] = field_type
+                    # set field with default value
+                    fields[p.name] = p.default
 
+    if not fields:
+        return None
 
     # new type goes to the module where sig was declared
     fields["__module__"] = module.__name__
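A sketch of the new synthesis behavior; `load` and its arguments are hypothetical, while spec_from_signature and the value markers are the real APIs changed above:

    import inspect
    import dlt
    from dlt.common.reflection.spec import spec_from_signature

    def load(table: str = dlt.config.value, limit: int = 100):
        ...

    sig = inspect.signature(load)
    # include_defaults=True (the default): the spec gets `table` (required,
    # injected at runtime) and `limit` (optional, default 100)
    # include_defaults=False: the spec gets only `table`; if no argument were
    # marked with dlt.config.value / dlt.secrets.value, it would return None
    SPEC = spec_from_signature(load, sig, include_defaults=False)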
6 changes: 3 additions & 3 deletions dlt/common/schema/schema.py
@@ -265,9 +265,9 @@ def get_table_columns(self, table_name: str, only_complete: bool = False) -> TTa
         else:
             return self._schema_tables[table_name]["columns"]
 
-    def all_tables(self, with_dlt_tables: bool = False) -> List[TTableSchema]:
-        """Gets list of all tables, with or without dlt tables"""
-        return [t for t in self._schema_tables.values() if not t["name"].startswith("_dlt") or with_dlt_tables]
+    def data_tables(self) -> List[TTableSchema]:
+        """Gets list of all tables that hold loaded data. Excludes dlt tables. Excludes incomplete tables (i.e. without columns)"""
+        return [t for t in self._schema_tables.values() if not t["name"].startswith("_dlt") and len(t["columns"]) > 0]
 
     def dlt_tables(self) -> List[TTableSchema]:
         """Gets dlt tables"""
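Illustrative usage, assuming a resolved Schema instance named `schema`:

    # data_tables() now also drops incomplete tables (no columns yet),
    # in addition to excluding the internal _dlt_* tables
    user_tables = [t["name"] for t in schema.data_tables()]
    internal_tables = [t["name"] for t in schema.dlt_tables()]
    assert not any(name.startswith("_dlt") for name in user_tables)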
67 changes: 49 additions & 18 deletions dlt/destinations/sql_merge_job.py
@@ -41,14 +41,14 @@ def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlCl
         return job
 
     @classmethod
-    def _gen_key_table_clauses(cls, primary_keys: Sequence[str], merge_keys: Sequence[str], escape_identifier: Callable[[str], str]) -> List[str]:
+    def _gen_key_table_clauses(cls, primary_keys: Sequence[str], merge_keys: Sequence[str]) -> List[str]:
         """Generate sql clauses to select rows to delete via merge and primary key. Return select all clause if no keys defined."""
         clauses: List[str] = []
         if primary_keys or merge_keys:
             if primary_keys:
-                clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in map(escape_identifier, primary_keys)]))
+                clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in primary_keys]))
             if merge_keys:
-                clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in map(escape_identifier, merge_keys)]))
+                clauses.append(" AND ".join(["%s.%s = %s.%s" % ("{d}", c, "{s}", c) for c in merge_keys]))
         return clauses or ["1=1"]
 
     @classmethod
@@ -60,18 +60,30 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st
         return [f"FROM {root_table_name} WHERE EXISTS (SELECT 1 FROM {staging_root_table_name} WHERE {' OR '.join([c.format(d=root_table_name,s=staging_root_table_name) for c in key_clauses])})"]
 
     @classmethod
-    def gen_temp_table_sql(cls, unique_column: str, key_table_clauses: Sequence[str]) -> Tuple[List[str], str]:
-        """Generate sql that creates the temp table and inserts `unique_column` from root table for all records to delete. May return several statements.
+    def gen_delete_temp_table_sql(cls, unique_column: str, key_table_clauses: Sequence[str]) -> Tuple[List[str], str]:
+        """Generate sql that creates the delete temp table and inserts `unique_column` from root table for all records to delete. May return several statements.
         Returns the temp table name, for cases where special names are required, like SQLServer.
         """
         sql: List[str] = []
-        temp_table_name = f"test_{uniq_id()}"
+        temp_table_name = f"delete_{uniq_id()}"
         sql.append(f"CREATE TEMP TABLE {temp_table_name} AS SELECT {unique_column} {key_table_clauses[0]};")
         for clause in key_table_clauses[1:]:
             sql.append(f"INSERT INTO {temp_table_name} SELECT {unique_column} {clause};")
         return sql, temp_table_name
 
+    @classmethod
+    def gen_insert_temp_table_sql(cls, staging_root_table_name: str, primary_keys: Sequence[str], unique_column: str) -> Tuple[List[str], str]:
+        sql: List[str] = []
+        temp_table_name = f"insert_{uniq_id()}"
+        sql.append(f"""CREATE TEMP TABLE {temp_table_name} AS
+            WITH _dlt_dedup_numbered AS (
+                SELECT ROW_NUMBER() OVER (PARTITION BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {unique_column}
+                FROM {staging_root_table_name}
+            )
+            SELECT {unique_column} FROM _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1;""")
+        return sql, temp_table_name
+
     @classmethod
     def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
         sql: List[str] = []
@@ -82,11 +94,14 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien
         with sql_client.with_staging_dataset(staging=True):
             staging_root_table_name = sql_client.make_qualified_table_name(root_table["name"])
         # get merge and primary keys from top level
-        primary_keys = get_columns_names_with_prop(root_table, "primary_key")
-        merge_keys = get_columns_names_with_prop(root_table, "merge_key")
-        key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys, sql_client.capabilities.escape_identifier)
+        primary_keys = list(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(root_table, "primary_key")))
+        merge_keys = list(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(root_table, "merge_key")))
+        key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)
         key_table_clauses = cls.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses)
-        # select_overlapped =
+        unique_column: str = None
+        root_key_column: str = None
+        insert_temp_table_sql: str = None
 
         if len(table_chain) == 1:
             # if no child tables, just delete data from top table
             for clause in key_table_clauses:
@@ -104,10 +119,10 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien
             # get first unique column
             unique_column = sql_client.capabilities.escape_identifier(unique_columns[0])
             # create temp table with unique identifier
-            create_table_sql, temp_table_name = cls.gen_temp_table_sql(unique_column, key_table_clauses)
-            sql.extend(create_table_sql)
+            create_delete_temp_table_sql, delete_temp_table_sql = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses)
+            sql.extend(create_delete_temp_table_sql)
             # delete top table
-            sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {temp_table_name});")
+            sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_sql});")
             # delete other tables
             for table in table_chain[1:]:
                 table_name = sql_client.make_qualified_table_name(table["name"])
@@ -120,18 +135,34 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien
                         f"There is no root foreign key (ie _dlt_root_id) in child table {table['name']} so it is not possible to refer to top level table {root_table['name']} unique column {unique_column}"
                     )
                 root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0])
-                sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {temp_table_name});")
+                sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_sql});")
+            # create temp table used to deduplicate, only when we have primary keys
+            if primary_keys:
+                create_insert_temp_table_sql, insert_temp_table_sql = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column)
+                sql.extend(create_insert_temp_table_sql)
 
         # insert from staging to dataset, truncate staging table
         for table in table_chain:
             table_name = sql_client.make_qualified_table_name(table["name"])
             with sql_client.with_staging_dataset(staging=True):
                 staging_table_name = sql_client.make_qualified_table_name(table["name"])
             columns = ", ".join(map(sql_client.capabilities.escape_identifier, table["columns"].keys()))
-            sql.append(
-                f"""INSERT INTO {table_name}({columns})
-                SELECT {columns} FROM {staging_table_name};
-                """)
+            insert_sql = f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name}"
+            if len(primary_keys) > 0:
+                if len(table_chain) == 1:
+                    insert_sql = f"""INSERT INTO {table_name}({columns})
+                        WITH _dlt_dedup_numbered AS (
+                            SELECT ROW_NUMBER() OVER (PARTITION BY {", ".join(primary_keys)} ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, {columns}
+                            FROM {staging_table_name}
+                        )
+                        SELECT {columns} FROM _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1;"""
+                else:
+                    uniq_column = unique_column if table.get("parent") is None else root_key_column
+                    insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_sql});"
+
+            if insert_sql[-1].strip() != ";":
+                insert_sql += ";"
+            sql.append(insert_sql)
             # -- DELETE FROM {staging_table_name} WHERE 1=1;
 
         return sql
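A hedged sketch of what the new deduplication helper emits, assuming the enclosing class is named SqlMergeJob and using hypothetical names (staging table staging.events, primary key id, unique column _dlt_id):

    sql, temp_table = SqlMergeJob.gen_insert_temp_table_sql("staging.events", ["id"], "_dlt_id")
    # sql[0] renders roughly as:
    #   CREATE TEMP TABLE insert_<uniq_id> AS
    #       WITH _dlt_dedup_numbered AS (
    #           SELECT ROW_NUMBER() OVER (PARTITION BY id ORDER BY (SELECT NULL)) AS _dlt_dedup_rn, _dlt_id
    #           FROM staging.events
    #       )
    #       SELECT _dlt_id FROM _dlt_dedup_numbered WHERE _dlt_dedup_rn = 1;
    # for multi-table chains, root and child rows are then inserted with
    # WHERE ... IN (SELECT * FROM insert_<uniq_id>), so exactly one row per
    # primary key survives the merge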