From 88408d6385c2b4ac06f6245d026dfd394da23c24 Mon Sep 17 00:00:00 2001
From: Steinthor Palsson
Date: Mon, 15 Jan 2024 21:13:53 -0500
Subject: [PATCH] Databricks merge disposition support

---
 .../impl/databricks/databricks.py          | 24 ++++++++---
 .../impl/databricks/sql_client.py          | 26 +++---------
 dlt/destinations/sql_jobs.py               | 46 ++++++++++++++------
 tests/load/pipeline/test_stage_loading.py  |  1 +
 4 files changed, 59 insertions(+), 38 deletions(-)

diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py
index d8425f5b57..28bbd0dd2f 100644
--- a/dlt/destinations/impl/databricks/databricks.py
+++ b/dlt/destinations/impl/databricks/databricks.py
@@ -224,10 +224,21 @@ def generate_sql(
         return sql
 
 
-# class DatabricksMergeJob(SqlMergeJob):
-#     @classmethod
-#     def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
-#         return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};"
+class DatabricksMergeJob(SqlMergeJob):
+    @classmethod
+    def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
+        return f"CREATE TEMPORARY VIEW {temp_table_name} AS {select_sql};"
+
+    @classmethod
+    def gen_delete_from_sql(
+        cls, table_name: str, column_name: str, temp_table_name: str, temp_table_column: str
+    ) -> str:
+        # Databricks does not support subqueries in DELETE FROM clauses, so we use a MERGE statement instead
+        return f"""MERGE INTO {table_name}
+        USING {temp_table_name}
+        ON {table_name}.{column_name} = {temp_table_name}.{temp_table_column}
+        WHEN MATCHED THEN DELETE;
+        """
 
 
 class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination):
@@ -260,8 +271,8 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")
 
-    # def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
-    #     return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)
+    def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
+        return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)
 
     # def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
     #     return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)
@@ -274,6 +285,7 @@ def _make_add_column_sql(
 
     # def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
     #     return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)
+
     def _create_replace_followup_jobs(
         self, table_chain: Sequence[TTableSchema]
     ) -> List[NewLoadJob]:
diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py
index b2dbfcdd05..a19cd24811 100644
--- a/dlt/destinations/impl/databricks/sql_client.py
+++ b/dlt/destinations/impl/databricks/sql_client.py
@@ -59,20 +59,18 @@ def close_connection(self) -> None:
 
     @contextmanager
     def begin_transaction(self) -> Iterator[DBTransaction]:
-        logger.warning(
-            "NotImplemented: Databricks does not support transactions. Each SQL statement is"
-            " auto-committed separately."
-        )
+        # Databricks does not support transactions
         yield self
 
     @raise_database_error
     def commit_transaction(self) -> None:
-        logger.warning("NotImplemented: commit")
+        # Databricks does not support transactions
         pass
 
     @raise_database_error
     def rollback_transaction(self) -> None:
-        logger.warning("NotImplemented: rollback")
+        # Databricks does not support transactions
+        pass
 
     @property
     def native_connection(self) -> "DatabricksSqlConnection":
@@ -127,16 +125,8 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB
         else:
             db_args = None
         with self._conn.cursor() as curr:
-            try:
-                curr.execute(query, db_args)
-                yield DatabricksCursorImpl(curr)  # type: ignore[abstract]
-            except databricks_lib.Error as outer:
-                try:
-                    self._reset_connection()
-                except databricks_lib.Error:
-                    self.close_connection()
-                    self.open_connection()
-                raise outer
+            curr.execute(query, db_args)
+            yield DatabricksCursorImpl(curr)  # type: ignore[abstract]
 
     def fully_qualified_dataset_name(self, escape: bool = True) -> str:
         if escape:
@@ -147,10 +137,6 @@ def fully_qualified_dataset_name(self, escape: bool = True) -> str:
             dataset_name = self.dataset_name
         return f"{catalog}.{dataset_name}"
 
-    def _reset_connection(self) -> None:
-        self.close_connection()
-        self.open_connection()
-
     @staticmethod
     def _make_database_exception(ex: Exception) -> Exception:
         if isinstance(ex, databricks_lib.ServerOperationError):
diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py
index d97a098669..899947313d 100644
--- a/dlt/destinations/sql_jobs.py
+++ b/dlt/destinations/sql_jobs.py
@@ -186,6 +186,21 @@ def gen_insert_temp_table_sql(
         """
         return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name
 
+    @classmethod
+    def gen_delete_from_sql(
+        cls,
+        table_name: str,
+        unique_column: str,
+        delete_temp_table_name: str,
+        temp_table_column: str,
+    ) -> str:
+        """Generate a DELETE FROM statement deleting the records found in the deletes temp table."""
+        return f"""DELETE FROM {table_name}
+            WHERE {unique_column} IN (
+                SELECT * FROM {delete_temp_table_name}
+            );
+        """
+
     @classmethod
     def _new_temp_table_name(cls, name_prefix: str) -> str:
         return f"{name_prefix}_{uniq_id()}"
@@ -261,12 +276,9 @@ def gen_merge_sql(
                 unique_column, key_table_clauses
             )
             sql.extend(create_delete_temp_table_sql)
-            # delete top table
-            sql.append(
-                f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM"
-                f" {delete_temp_table_name});"
-            )
-            # delete other tables
+
+            # delete from child tables first. This is important for Databricks, which
+            # does not support temporary tables but uses temporary views instead
             for table in table_chain[1:]:
                 table_name = sql_client.make_qualified_table_name(table["name"])
                 root_key_columns = get_columns_names_with_prop(table, "root_key")
@@ -281,15 +293,25 @@
                 )
                 root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0])
                 sql.append(
-                    f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM"
-                    f" {delete_temp_table_name});"
+                    cls.gen_delete_from_sql(
+                        table_name, root_key_column, delete_temp_table_name, unique_column
+                    )
                 )
+
+            # delete from top table now that child tables have been processed
+            sql.append(
+                cls.gen_delete_from_sql(
+                    root_table_name, unique_column, delete_temp_table_name, unique_column
+                )
+            )
+
         # create temp table used to deduplicate, only when we have primary keys
         if primary_keys:
-            create_insert_temp_table_sql, insert_temp_table_name = (
-                cls.gen_insert_temp_table_sql(
-                    staging_root_table_name, primary_keys, unique_column
-                )
+            (
+                create_insert_temp_table_sql,
+                insert_temp_table_name,
+            ) = cls.gen_insert_temp_table_sql(
+                staging_root_table_name, primary_keys, unique_column
             )
             sql.extend(create_insert_temp_table_sql)
diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py
index de4a7f4c3b..bba589b444 100644
--- a/tests/load/pipeline/test_stage_loading.py
+++ b/tests/load/pipeline/test_stage_loading.py
@@ -158,6 +158,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
     if destination_config.destination in (
         "redshift",
         "athena",
+        "databricks",
     ) and destination_config.file_format in ("parquet", "jsonl"):
         # Redshift copy doesn't support TIME column
         exclude_types.append("time")
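
Note for reviewers: the standalone sketch below (not part of the patch) contrasts the SQL produced by the base gen_delete_from_sql with the Databricks override. Databricks rejects subqueries inside DELETE FROM, so the override expresses the same delete as MERGE ... WHEN MATCHED THEN DELETE; the same limitation is why _to_temp_table emits CREATE TEMPORARY VIEW, since Databricks has no temporary tables. All table and column names here are hypothetical stand-ins.

# delete_strategy_sketch.py -- illustration only; names are hypothetical
def default_delete_from_sql(
    table_name: str, unique_column: str, delete_temp_table_name: str
) -> str:
    # Base SqlMergeJob strategy: delete every row whose key appears in the
    # temp table, using an IN subquery.
    return (
        f"DELETE FROM {table_name}"
        f" WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_name});"
    )


def databricks_delete_from_sql(
    table_name: str, column_name: str, temp_table_name: str, temp_table_column: str
) -> str:
    # Databricks strategy: a MERGE against the temp view deletes the same
    # rows without using a subquery.
    return (
        f"MERGE INTO {table_name} USING {temp_table_name}"
        f" ON {table_name}.{column_name} = {temp_table_name}.{temp_table_column}"
        f" WHEN MATCHED THEN DELETE;"
    )


if __name__ == "__main__":
    print(default_delete_from_sql("my_table__items", "_dlt_root_id", "delete_tmp_view"))
    print(
        databricks_delete_from_sql(
            "my_table__items", "_dlt_root_id", "delete_tmp_view", "_dlt_id"
        )
    )

Both forms remove the same row set, which is why the patch can route all deletes in gen_merge_sql through the new gen_delete_from_sql hook and let each destination pick its dialect.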