From f627a0fa153a3ef42449bab70f912e3f34a815b2 Mon Sep 17 00:00:00 2001 From: Dave Date: Thu, 5 Oct 2023 17:05:40 +0200 Subject: [PATCH] fix merge jobs for iceberg --- dlt/destinations/athena/athena.py | 5 ----- dlt/destinations/sql_jobs.py | 21 ++++++++++++++------- tests/load/utils.py | 2 -- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index 730bcc6b9e..ed43f5b502 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -372,11 +372,6 @@ def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTabl """Creates a list of followup jobs for merge write disposition and staging replace strategies""" jobs = super().create_table_chain_completed_followup_jobs(table_chain) - # append job if there is a merge TODO: add proper iceberg merge job - write_disposition = table_chain[0]["write_disposition"] - if write_disposition == "merge": - jobs.append(self._create_staging_copy_job(table_chain, False)) - # add some additional jobs write_disposition = table_chain[0]["write_disposition"] if write_disposition == "append": diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index b601cb4813..794db1b00a 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -143,7 +143,7 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: Returns: sql statement that inserts data from selects into temp table """ - return f"CREATE TEMP TABLE {temp_table_name} AS {select_sql};" + return f"CREATE TABLE {temp_table_name} AS {select_sql};" @classmethod def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: @@ -161,7 +161,8 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien unique_column: str = None root_key_column: str = None - insert_temp_table_sql: str = None + insert_temp_table_name: str = None + delete_temp_table_name: str = None if len(table_chain) == 1: @@ -183,10 +184,10 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien # get first unique column unique_column = sql_client.capabilities.escape_identifier(unique_columns[0]) # create temp table with unique identifier - create_delete_temp_table_sql, delete_temp_table_sql = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) + create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) sql.extend(create_delete_temp_table_sql) # delete top table - sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_name});") # delete other tables for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) @@ -199,10 +200,10 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien f"There is no root foreign key (ie _dlt_root_id) in child table {table['name']} so it is not possible to refer to top level table {root_table['name']} unique column {unique_column}" ) root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0]) - sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_name});") # create temp table used to deduplicate, only when we have primary keys if primary_keys: - create_insert_temp_table_sql, insert_temp_table_sql = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column) + create_insert_temp_table_sql, insert_temp_table_name = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column) sql.extend(create_insert_temp_table_sql) # insert from staging to dataset, truncate staging table @@ -222,11 +223,17 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien """ else: uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_sql});" + insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_name});" if insert_sql.strip()[-1] != ";": insert_sql += ";" sql.append(insert_sql) # -- DELETE FROM {staging_table_name} WHERE 1=1; + # clean up + if insert_temp_table_name: + sql.append(f"DROP TABLE {insert_temp_table_name};") + if delete_temp_table_name: + sql.append(f"DROP TABLE {delete_temp_table_name};") + return sql diff --git a/tests/load/utils.py b/tests/load/utils.py index a0cab6ff73..471c8cc59f 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -145,8 +145,6 @@ def destinations_configs( for bucket in ALL_BUCKETS: destination_configs += [DestinationTestConfiguration(destination="filesystem", bucket_url=bucket, extra_info=bucket)] - # destination_configs = [DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, iceberg_bucket_url=AWS_BUCKET + "/iceberg", supports_merge=False, extra_info="iceberg")] - # filter out non active destinations destination_configs = [conf for conf in destination_configs if conf.destination in ACTIVE_DESTINATIONS]