Skip to content

Commit

Permalink
Refactor Snowflake constraints support: emit PRIMARY KEY/UNIQUE hints as separate ALTER TABLE ADD/DROP CONSTRAINT statements instead of inline column hints
Browse files Browse the repository at this point in the history
  • Loading branch information
donotpush committed Dec 16, 2024
1 parent 96003f4 commit ee71b17
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 32 deletions.
71 changes: 64 additions & 7 deletions dlt/destinations/impl/snowflake/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Sequence, List, Dict
from typing import Optional, Sequence, List, Dict, Set
from urllib.parse import urlparse, urlunparse

from dlt.common.data_writers.configuration import CsvFormatConfiguration
Expand Down Expand Up @@ -267,6 +267,61 @@ def _make_add_column_sql(
"ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)
]

def _get_existing_constraints(self, table_name: str) -> Set[str]:
query = f"""
SELECT constraint_name
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS
WHERE TABLE_NAME = '{table_name.upper()}'
"""

if self.sql_client.catalog_name:
query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'"

with self.sql_client.open_connection() as conn:
cursors = conn.execute_string(query)
existing_names = set()
for cursor in cursors:
for row in cursor:
existing_names.add(row[0])
return existing_names

def _get_constraints_statement(
self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str]
) -> List[str]:
statements = []
pk_constraint_name = f"PK_{table_name.upper()}"
uq_constraint_name = f"UQ_{table_name.upper()}"
qualified_name = self.sql_client.make_qualified_table_name(table_name)

pk_columns = [col["name"] for col in columns if col.get("primary_key")]
unique_columns = [col["name"] for col in columns if col.get("unique")]

# Drop existing PK/UQ constraints if found
if pk_constraint_name in existing_constraints:
statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}")
if uq_constraint_name in existing_constraints:
statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}")

# Add PK constraint if pk_columns exist
if pk_columns:
quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns)
statements.append(
f"ALTER TABLE {qualified_name} "
f"ADD CONSTRAINT {pk_constraint_name} "
f"PRIMARY KEY ({quoted_pk_cols})"
)

# Add UNIQUE constraint if unique_columns exist
if unique_columns:
quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns)
statements.append(
f"ALTER TABLE {qualified_name} "
f"ADD CONSTRAINT {uq_constraint_name} "
f"UNIQUE ({quoted_uq_cols})"
)

return statements

def _get_table_update_sql(
self,
table_name: str,
Expand All @@ -283,6 +338,13 @@ def _get_table_update_sql(
if cluster_list:
sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")"

if self.active_hints:
existing_constraints = self._get_existing_constraints(table_name)
statements = self._get_constraints_statement(
table_name, new_columns, existing_constraints
)
sql.extend(statements)

return sql

def _from_db_type(
Expand All @@ -291,14 +353,9 @@ def _from_db_type(
return self.type_mapper.from_destination_type(bq_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str:
hints_str = " ".join(
self.active_hints.get(h, "")
for h in self.active_hints.keys()
if c.get(h, False) is True
)
column_name = self.sql_client.escape_column_name(c["name"])
return (
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}"
f"{column_name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}"
)

def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool:
Expand Down
57 changes: 32 additions & 25 deletions tests/load/snowflake/test_snowflake_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,16 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None:

assert sql.strip().startswith("CREATE TABLE")
assert "EVENT_TEST_TABLE" in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL5" VARCHAR' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL7" BINARY' in sql
assert '"COL8" NUMBER(38,0)' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL10" DATE NOT NULL' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL10" DATE NOT NULL' in sql


def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
Expand All @@ -90,16 +90,16 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:

assert sql.strip().startswith("CREATE TABLE")
assert "EVENT_TEST_TABLE" in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL5" VARCHAR' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL7" BINARY' in sql
assert '"COL8" NUMBER(38,0)' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL10" DATE NOT NULL' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL10" DATE NOT NULL' in sql

# same thing with indexes
snowflake_client = snowflake().client(
Expand All @@ -108,10 +108,17 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None:
dataset_name="test_" + uniq_id()
),
)
sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0]
sqlfluff.parse(sql)
assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql
assert '"COL2" FLOAT UNIQUE NOT NULL' in sql
sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)

for stmt in sql_statements:
sqlfluff.parse(stmt)

assert any(
'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements
)
assert any(
'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements
)


def test_alter_table(snowflake_client: SnowflakeClient) -> None:
Expand All @@ -126,23 +133,23 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None:
assert sql.count("ALTER TABLE") == 1
assert sql.count("ADD COLUMN") == 1
assert '"EVENT_TEST_TABLE"' in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL1" NUMBER(19,0) NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL3" BOOLEAN NOT NULL' in sql
assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql
assert '"COL5" VARCHAR' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL6" NUMBER(38,9) NOT NULL' in sql
assert '"COL7" BINARY' in sql
assert '"COL8" NUMBER(38,0)' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL9" VARIANT NOT NULL' in sql
assert '"COL10" DATE' in sql

mod_table = deepcopy(TABLE_UPDATE)
mod_table.pop(0)
sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0]

assert '"COL1"' not in sql
assert '"COL2" FLOAT NOT NULL' in sql
assert '"COL2" FLOAT NOT NULL' in sql


def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None:
Expand Down

0 comments on commit ee71b17

Please sign in to comment.