Databricks merge disposition support
steinitzu committed Jan 16, 2024
1 parent ac71c67 commit 88408d6
Showing 4 changed files with 62 additions and 38 deletions.
27 changes: 21 additions & 6 deletions dlt/destinations/impl/databricks/databricks.py
@@ -224,10 +224,21 @@ def generate_sql(
return sql


# class DatabricksMergeJob(SqlMergeJob):
# @classmethod
# def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
# return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};"
class DatabricksMergeJob(SqlMergeJob):
@classmethod
def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str:
return f"CREATE TEMPORARY VIEW {temp_table_name} AS {select_sql};"

@classmethod
def gen_delete_from_sql(
cls, table_name: str, column_name: str, temp_table_name: str, temp_table_column: str
) -> str:
# Databricks does not support subqueries in DELETE FROM statements, so we use a MERGE statement instead
return f"""MERGE INTO {table_name}
USING {temp_table_name}
ON {table_name}.{column_name} = {temp_table_name}.{temp_table_column}
WHEN MATCHED THEN DELETE;
"""


class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination):
@@ -260,8 +271,11 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

# def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
# return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)
def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)

def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)]

# def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
# return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)
@@ -274,6 +288,7 @@ def _make_add_column_sql(

# def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
# return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)

def _create_replace_followup_jobs(
self, table_chain: Sequence[TTableSchema]
) -> List[NewLoadJob]:
26 changes: 6 additions & 20 deletions dlt/destinations/impl/databricks/sql_client.py
@@ -59,20 +59,18 @@ def close_connection(self) -> None:

@contextmanager
def begin_transaction(self) -> Iterator[DBTransaction]:
logger.warning(
"NotImplemented: Databricks does not support transactions. Each SQL statement is"
" auto-committed separately."
)
# Databricks does not support transactions
yield self

@raise_database_error
def commit_transaction(self) -> None:
logger.warning("NotImplemented: commit")
# Databricks does not support transactions
pass

@raise_database_error
def rollback_transaction(self) -> None:
logger.warning("NotImplemented: rollback")
# Databricks does not support transactions
pass
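A hedged usage sketch of what this means for callers; `client` stands in for a configured Databricks sql client, and `execute_sql` is the statement helper inherited from the shared base sql client:

```python
# Sketch under the assumptions above: the transaction context manager still
# works syntactically, but provides no atomicity on Databricks.
with client.begin_transaction():
    client.execute_sql("DELETE FROM dataset.event WHERE _dlt_load_id = '123'")
    # the DELETE above is already committed here; rollback_transaction()
    # would be a no-op and cannot undo it
```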

@property
def native_connection(self) -> "DatabricksSqlConnection":
@@ -127,16 +125,8 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB
else:
db_args = None
with self._conn.cursor() as curr:
try:
curr.execute(query, db_args)
yield DatabricksCursorImpl(curr) # type: ignore[abstract]
except databricks_lib.Error as outer:
try:
self._reset_connection()
except databricks_lib.Error:
self.close_connection()
self.open_connection()
raise outer
curr.execute(query, db_args)
yield DatabricksCursorImpl(curr) # type: ignore[abstract]
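For context, a hedged sketch of how the simplified context manager is consumed; the query and the `client` instance are invented for the example:

```python
# Sketch: execute_query now runs the statement directly on the open
# connection and yields the wrapped cursor; the former reconnect-and-retry
# wrapper around databricks_lib.Error is gone.
with client.execute_query("SELECT * FROM dataset.event LIMIT 10") as cursor:
    rows = cursor.fetchall()  # standard DB-API style read on DatabricksCursorImpl
```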

def fully_qualified_dataset_name(self, escape: bool = True) -> str:
if escape:
@@ -147,10 +137,6 @@ def fully_qualified_dataset_name(self, escape: bool = True) -> str:
dataset_name = self.dataset_name
return f"{catalog}.{dataset_name}"

def _reset_connection(self) -> None:
self.close_connection()
self.open_connection()

@staticmethod
def _make_database_exception(ex: Exception) -> Exception:
if isinstance(ex, databricks_lib.ServerOperationError):
46 changes: 34 additions & 12 deletions dlt/destinations/sql_jobs.py
@@ -186,6 +186,21 @@ def gen_insert_temp_table_sql(
"""
return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name

@classmethod
def gen_delete_from_sql(
cls,
table_name: str,
unique_column: str,
delete_temp_table_name: str,
temp_table_column: str,
) -> str:
"""Generate DELETE FROM statement deleting the records found in the deletes temp table."""
return f"""DELETE FROM {table_name}
WHERE {unique_column} IN (
SELECT * FROM {delete_temp_table_name}
);
"""

@classmethod
def _new_temp_table_name(cls, name_prefix: str) -> str:
return f"{name_prefix}_{uniq_id()}"
@@ -261,12 +276,9 @@ def gen_merge_sql(
unique_column, key_table_clauses
)
sql.extend(create_delete_temp_table_sql)
# delete top table
sql.append(
f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM"
f" {delete_temp_table_name});"
)
# delete other tables

# delete from child tables first. This is important for Databricks, which does not support temporary
# tables and uses temporary views instead: the delete view reads live from the root table, so the
# root rows must stay in place until all child deletes have run (see the sketch at the end of this hunk)
for table in table_chain[1:]:
table_name = sql_client.make_qualified_table_name(table["name"])
root_key_columns = get_columns_names_with_prop(table, "root_key")
Expand All @@ -281,15 +293,25 @@ def gen_merge_sql(
)
root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0])
sql.append(
f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM"
f" {delete_temp_table_name});"
cls.gen_delete_from_sql(
table_name, root_key_column, delete_temp_table_name, unique_column
)
)

# delete from top table now that child tables have been processed
sql.append(
cls.gen_delete_from_sql(
root_table_name, unique_column, delete_temp_table_name, unique_column
)
)

# create temp table used to deduplicate, only when we have primary keys
if primary_keys:
create_insert_temp_table_sql, insert_temp_table_name = (
cls.gen_insert_temp_table_sql(
staging_root_table_name, primary_keys, unique_column
)
(
create_insert_temp_table_sql,
insert_temp_table_name,
) = cls.gen_insert_temp_table_sql(
staging_root_table_name, primary_keys, unique_column
)
sql.extend(create_insert_temp_table_sql)
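To make the new ordering concrete, a rough sketch (invented identifiers, statements abbreviated) of the sequence gen_merge_sql now emits for a root table with one child table on Databricks:

```python
# 1. CREATE TEMPORARY VIEW delete_event_abc AS
#        SELECT _dlt_id FROM dataset.event WHERE ...  -- keys read from the root table
# 2. MERGE INTO dataset.event__items USING delete_event_abc ... WHEN MATCHED THEN DELETE;
# 3. MERGE INTO dataset.event USING delete_event_abc ... WHEN MATCHED THEN DELETE;
# The view in step 1 is re-evaluated against dataset.event on every use, so
# the root delete (3) must run after every child delete (2), or the view
# would yield no rows for the child tables.
```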

1 change: 1 addition & 0 deletions tests/load/pipeline/test_stage_loading.py
@@ -158,6 +158,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
if destination_config.destination in (
"redshift",
"athena",
"databricks",
) and destination_config.file_format in ("parquet", "jsonl"):
# these destinations don't support TIME columns when loading from parquet/jsonl
exclude_types.append("time")
