Skip to content

Commit

Permalink
select fallback w_d for merge in loader
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Mar 26, 2024
1 parent 70c4135 commit ef96969
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 12 deletions.
6 changes: 3 additions & 3 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.exceptions import SchemaException
from dlt.common.schema.utils import (
get_write_disposition,
ensure_write_disposition,
get_table_format,
get_columns_names_with_prop,
has_column_with_prop,
Expand Down Expand Up @@ -305,6 +305,7 @@ def restore_file_load(self, file_path: str) -> LoadJob:
pass

def should_truncate_table_before_load(self, table: TTableSchema) -> bool:
table = self.prepare_load_table(table["name"])
return table["write_disposition"] == "replace"

def create_table_chain_completed_followup_jobs(
Expand Down Expand Up @@ -420,8 +421,7 @@ def prepare_load_table(
# make a copy of the schema so modifications do not affect the original document
table = deepcopy(self.schema.tables[table_name])
# add write disposition if not specified - in child tables
if "write_disposition" not in table:
table["write_disposition"] = get_write_disposition(self.schema.tables, table_name)
ensure_write_disposition(self.schema.tables, table)
if "table_format" not in table:
table["table_format"] = get_table_format(self.schema.tables, table_name)
return table
Expand Down
23 changes: 23 additions & 0 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,12 @@ def get_inherited_table_hint(
)


def get_root_table(tables: TSchemaTables, table: TTableSchema) -> TTableSchema:
while table.get("parent"):
table = tables[table["parent"]]
return table


def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDisposition:
"""Returns table hint of a table if present. If not, looks up into parent table"""
return cast(
Expand All @@ -555,6 +561,23 @@ def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDispo
)


def ensure_write_disposition(tables: TSchemaTables, table: TTableSchema) -> None:
"""
Ensures the table has inherited the correct write disposition
Also falls back to append for tables that are declared as merge but
do not have merge keys
"""
root_table = get_root_table(tables, table)
w_d = root_table["write_disposition"]
if (
w_d == "merge"
and (not get_columns_names_with_prop(root_table, "primary_key"))
and (not get_columns_names_with_prop(root_table, "merge_key"))
):
w_d = "append"
table["write_disposition"] = w_d


def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat:
return cast(
TTableFormat, get_inherited_table_hint(tables, table_name, "table_format", allow_none=True)
Expand Down
4 changes: 3 additions & 1 deletion dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def maybe_ddl_transaction(self) -> Iterator[None]:
yield

def should_truncate_table_before_load(self, table: TTableSchema) -> bool:
table = self.prepare_load_table(table["name"])
return (
table["write_disposition"] == "replace"
and self.config.replace_strategy == "truncate-and-insert"
Expand All @@ -240,7 +241,8 @@ def create_table_chain_completed_followup_jobs(
) -> List[NewLoadJob]:
"""Creates a list of followup jobs for merge write disposition and staging replace strategies"""
jobs = super().create_table_chain_completed_followup_jobs(table_chain)
write_disposition = table_chain[0]["write_disposition"]
root_table = self.prepare_load_table(table_chain[0]["name"])
write_disposition = root_table["write_disposition"]
if write_disposition == "append":
jobs.extend(self._create_append_followup_jobs(table_chain))
elif write_disposition == "merge":
Expand Down
10 changes: 6 additions & 4 deletions dlt/destinations/sql_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,12 @@ def gen_merge_sql(
)
)

has_merge_keys = len(primary_keys) == 0 and len(merge_keys)
if not has_merge_keys:
raise DatabaseTransientException("Could not find primary or merge keys, aborting merge job.")

if not primary_keys and not merge_keys:
# NOTE: this should never happen, the loader should select append for each tables that does not have
# the required keys
raise DatabaseTransientException(
"Could not find primary or merge keys, aborting merge job."
)

key_clauses = cls._gen_key_table_clauses(primary_keys, merge_keys)

Expand Down
9 changes: 5 additions & 4 deletions tests/load/pipeline/test_merge_disposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) -
info = p.run(github_data, loader_file_format=destination_config.file_format)
assert_load_info(info)
github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
# only ten rows remains. merge falls back to replace when no keys are specified
assert github_1_counts["issues"] == 10 if destination_config.supports_merge else 100 - 45
# ten new rows are added. merge falls back to append when no keys are specified
assert github_1_counts["issues"] == (100 - 45) + 10


@pytest.mark.parametrize(
Expand All @@ -288,14 +288,15 @@ def test_merge_keys_non_existing_columns(destination_config: DestinationTestConf
if not destination_config.supports_merge:
return

# all the keys are invalid so the merge falls back to replace
# all the keys are invalid so the merge falls back to append
github_data = github()
github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",))
github_data.load_issues.add_filter(take_first(1))
info = p.run(github_data, loader_file_format=destination_config.file_format)
assert_load_info(info)
github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()])
assert github_2_counts["issues"] == 1
# we have one more than before
assert github_2_counts["issues"] == (100 - 45) + 1
with p._sql_job_client(p.default_schema) as job_c:
_, table_schema = job_c.get_storage_table("issues")
assert "url" in table_schema
Expand Down

0 comments on commit ef96969

Please sign in to comment.