From ee71b1787814991d9e0d8314de581095557000e4 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:51:57 +0100 Subject: [PATCH] refactor snowflake constraints support --- dlt/destinations/impl/snowflake/snowflake.py | 71 +++++++++++++++++-- .../snowflake/test_snowflake_table_builder.py | 57 ++++++++------- 2 files changed, 96 insertions(+), 32 deletions(-) diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index c6220fd65e..da6597ecd4 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, List, Dict +from typing import Optional, Sequence, List, Dict, Set from urllib.parse import urlparse, urlunparse from dlt.common.data_writers.configuration import CsvFormatConfiguration @@ -267,6 +267,61 @@ def _make_add_column_sql( "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] + def _get_existing_constraints(self, table_name: str) -> Set[str]: + query = f""" + SELECT constraint_name + FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS + WHERE TABLE_NAME = '{table_name.upper()}' + """ + + if self.sql_client.catalog_name: + query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'" + + with self.sql_client.open_connection() as conn: + cursors = conn.execute_string(query) + existing_names = set() + for cursor in cursors: + for row in cursor: + existing_names.add(row[0]) + return existing_names + + def _get_constraints_statement( + self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str] + ) -> List[str]: + statements = [] + pk_constraint_name = f"PK_{table_name.upper()}" + uq_constraint_name = f"UQ_{table_name.upper()}" + qualified_name = self.sql_client.make_qualified_table_name(table_name) + + pk_columns = [col["name"] for col in columns if col.get("primary_key")] + unique_columns = [col["name"] for col in columns if col.get("unique")] + + # Drop existing PK/UQ constraints if found + if pk_constraint_name in existing_constraints: + statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}") + if uq_constraint_name in existing_constraints: + statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}") + + # Add PK constraint if pk_columns exist + if pk_columns: + quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns) + statements.append( + f"ALTER TABLE {qualified_name} " + f"ADD CONSTRAINT {pk_constraint_name} " + f"PRIMARY KEY ({quoted_pk_cols})" + ) + + # Add UNIQUE constraint if unique_columns exist + if unique_columns: + quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns) + statements.append( + f"ALTER TABLE {qualified_name} " + f"ADD CONSTRAINT {uq_constraint_name} " + f"UNIQUE ({quoted_uq_cols})" + ) + + return statements + def _get_table_update_sql( self, table_name: str, @@ -283,6 +338,13 @@ def _get_table_update_sql( if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" + if self.active_hints: + existing_constraints = self._get_existing_constraints(table_name) + statements = self._get_constraints_statement( + table_name, new_columns, existing_constraints + ) + sql.extend(statements) + return sql def _from_db_type( @@ -291,14 +353,9 @@ def _from_db_type( return self.type_mapper.from_destination_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - hints_str = " ".join( - self.active_hints.get(h, "") - for h in self.active_hints.keys() - if c.get(h, False) is True - ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index bf55fe9dc6..e2bd27d18a 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -66,16 +66,16 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: @@ -90,16 +90,16 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql # same thing with indexes snowflake_client = snowflake().client( @@ -108,10 +108,17 @@ def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: dataset_name="test_" + uniq_id() ), ) - sql = snowflake_client._get_table_update_sql("event_test_table", mod_update, False)[0] - sqlfluff.parse(sql) - assert '"COL1" NUMBER(19,0) PRIMARY KEY NOT NULL' in sql - assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False) + + for stmt in sql_statements: + sqlfluff.parse(stmt) + + assert any( + 'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements + ) + assert any( + 'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements + ) def test_alter_table(snowflake_client: SnowflakeClient) -> None: @@ -126,15 +133,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -142,7 +149,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: