diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index eb74c143b1..786cdc0b77 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -1,6 +1,7 @@ from typing import Optional, Sequence, List, Dict, Set from urllib.parse import urlparse, urlunparse +from dlt.common import logger from dlt.common.data_writers.configuration import CsvFormatConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( @@ -15,13 +16,15 @@ AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults, ) +from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages.configuration import FilesystemConfiguration, ensure_canonical_az_url from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TColumnHint -from dlt.common.schema.typing import TColumnType +from dlt.common.schema.typing import TColumnType, TTableSchema from dlt.common.storages.fsspec_filesystem import AZURE_BLOB_STORAGE_PROTOCOLS, S3_PROTOCOLS from dlt.common.typing import TLoaderFileFormat +from dlt.common.utils import uniq_id from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.exceptions import LoadJobTerminalException @@ -29,7 +32,7 @@ from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import ReferenceFollowupJobRequest -SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE", "primary_key": "PRIMARY KEY"} +SUPPORTED_HINTS: Dict[TColumnHint, str] = {"unique": "UNIQUE"} class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -267,60 +270,32 @@ def _make_add_column_sql( "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] - def _get_existing_constraints(self, table_name: str) -> Set[str]: - query = f""" - SELECT constraint_name - FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS - WHERE TABLE_NAME = '{table_name.upper()}' - """ - - if self.sql_client.catalog_name: - query += f" AND CONSTRAINT_CATALOG = '{self.sql_client.catalog_name}'" - - with self.sql_client.open_connection() as conn: - cursors = conn.execute_string(query) - existing_names = set() - for cursor in cursors: - for row in cursor: - existing_names.add(row[0]) - return existing_names - - def _get_constraints_statement( - self, table_name: str, columns: Sequence[TColumnSchema], existing_constraints: Set[str] - ) -> List[str]: - statements = [] - pk_constraint_name = f"PK_{table_name.upper()}" - uq_constraint_name = f"UQ_{table_name.upper()}" - qualified_name = self.sql_client.make_qualified_table_name(table_name) - - pk_columns = [col["name"] for col in columns if col.get("primary_key")] - unique_columns = [col["name"] for col in columns if col.get("unique")] - - # Drop existing PK/UQ constraints if found - if pk_constraint_name in existing_constraints: - statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {pk_constraint_name}") - if uq_constraint_name in existing_constraints: - statements.append(f"ALTER TABLE {qualified_name} DROP CONSTRAINT {uq_constraint_name}") - - # Add PK constraint if pk_columns exist - if pk_columns: - quoted_pk_cols = ", ".join(f'"{col}"' for col in pk_columns) - statements.append( - f"ALTER TABLE {qualified_name} " - f"ADD CONSTRAINT {pk_constraint_name} " - f"PRIMARY KEY ({quoted_pk_cols})" - ) - - # Add UNIQUE constraint if unique_columns exist - if unique_columns: - quoted_uq_cols = ", ".join(f'"{col}"' for col in unique_columns) - statements.append( - f"ALTER TABLE {qualified_name} " - f"ADD CONSTRAINT {uq_constraint_name} " - f"UNIQUE ({quoted_uq_cols})" - ) - - return statements + def _get_constraints_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> str: + # "primary_key": "PRIMARY KEY" + if self.config.create_indexes: + partial: TTableSchema = { + "name": table_name, + "columns": {c["name"]: c for c in new_columns}, + } + # Add PK constraint if pk_columns exist + pk_columns = get_columns_names_with_prop(partial, "primary_key") + if pk_columns: + if generate_alter: + logger.warning( + f"PRIMARY KEY on {table_name} constraint cannot be added in ALTER TABLE and" + " is ignored" + ) + else: + pk_constraint_name = list( + self._norm_and_escape_columns(f"PK_{table_name}_{uniq_id(4)}") + )[0] + quoted_pk_cols = ", ".join( + self.sql_client.escape_column_name(col) for col in pk_columns + ) + return f",\nCONSTRAINT {pk_constraint_name} PRIMARY KEY ({quoted_pk_cols})" + return "" def _get_table_update_sql( self, @@ -338,13 +313,6 @@ def _get_table_update_sql( if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" - if self.active_hints: - existing_constraints = self._get_existing_constraints(table_name) - statements = self._get_constraints_statement( - table_name, new_columns, existing_constraints - ) - sql.extend(statements) - return sql def _from_db_type( @@ -352,11 +320,5 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: - name = self.sql_client.escape_column_name(c["name"]) - return ( - f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" - ) - def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/tests/load/snowflake/test_snowflake_client.py b/tests/load/snowflake/test_snowflake_client.py index aebf514b56..674e01ba31 100644 --- a/tests/load/snowflake/test_snowflake_client.py +++ b/tests/load/snowflake/test_snowflake_client.py @@ -1,14 +1,17 @@ +from copy import deepcopy import os from typing import Iterator from pytest_mock import MockerFixture import pytest -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.common.schema.schema import Schema +from dlt.destinations.impl.snowflake.snowflake import SUPPORTED_HINTS, SnowflakeClient from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.destinations.sql_client import TJobQueryTags -from tests.load.utils import yield_client_with_storage +from tests.cases import TABLE_UPDATE +from tests.load.utils import yield_client_with_storage, empty_schema # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -32,6 +35,39 @@ def client() -> Iterator[SqlJobClientBase]: yield from yield_client_with_storage("snowflake") +def test_create_table_with_hints(client: SnowflakeClient, empty_schema: Schema) -> None: + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + client.config.create_indexes = True + client.active_hints = SUPPORTED_HINTS + client.schema = empty_schema + + mod_update[0]["primary_key"] = True + mod_update[5]["primary_key"] = True + + mod_update[0]["sort"] = True + mod_update[4]["parent_key"] = True + + # unique constraints are always single columns + mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, False)) + + print(sql) + client.sql_client.execute_sql(sql) + + # generate alter table + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True + + sql = ";".join(client._get_table_update_sql("event_test_table", mod_update, True)) + + print(sql) + client.sql_client.execute_sql(sql) + + def test_query_tag(client: SnowflakeClient, mocker: MockerFixture): assert client.config.query_tag == QUERY_TAG # make sure we generate proper query diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index e2bd27d18a..43d4395188 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -6,7 +6,7 @@ from dlt.common.utils import uniq_id from dlt.common.schema import Schema, utils from dlt.destinations import snowflake -from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient, SUPPORTED_HINTS from dlt.destinations.impl.snowflake.configuration import ( SnowflakeClientConfiguration, SnowflakeCredentials, @@ -66,59 +66,63 @@ def test_create_table(snowflake_client: SnowflakeClient) -> None: assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql def test_create_table_with_hints(snowflake_client: SnowflakeClient) -> None: - mod_update = deepcopy(TABLE_UPDATE) + mod_update = deepcopy(TABLE_UPDATE[:11]) + # mock hints + snowflake_client.config.create_indexes = True + snowflake_client.active_hints = SUPPORTED_HINTS mod_update[0]["primary_key"] = True + mod_update[5]["primary_key"] = True + mod_update[0]["sort"] = True + + # unique constraints are always single columns mod_update[1]["unique"] = True + mod_update[7]["unique"] = True + mod_update[4]["parent_key"] = True sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, False)) assert sql.strip().startswith("CREATE TABLE") assert "EVENT_TEST_TABLE" in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT UNIQUE NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql - assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql - assert '"COL10" DATE NOT NULL' in sql + assert '"COL8" NUMBER(38,0) UNIQUE' in sql + assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL10" DATE NOT NULL' in sql - # same thing with indexes - snowflake_client = snowflake().client( - snowflake_client.schema, - SnowflakeClientConfiguration(create_indexes=True)._bind_dataset_name( - dataset_name="test_" + uniq_id() - ), - ) - sql_statements = snowflake_client._get_table_update_sql("event_test_table", mod_update, False) + # PRIMARY KEY constraint + assert 'CONSTRAINT "PK_EVENT_TEST_TABLE_' in sql + assert 'PRIMARY KEY ("COL1", "COL6")' in sql - for stmt in sql_statements: - sqlfluff.parse(stmt) + # generate alter + mod_update = deepcopy(TABLE_UPDATE[11:]) + mod_update[0]["primary_key"] = True + mod_update[1]["unique"] = True - assert any( - 'ADD CONSTRAINT PK_EVENT_TEST_TABLE PRIMARY KEY ("col1")' in stmt for stmt in sql_statements - ) - assert any( - 'ADD CONSTRAINT UQ_EVENT_TEST_TABLE UNIQUE ("col2")' in stmt for stmt in sql_statements - ) + sql = ";".join(snowflake_client._get_table_update_sql("event_test_table", mod_update, True)) + # PK constraint ignored for alter + assert "PRIMARY KEY" not in sql + assert '"COL2_NULL" FLOAT UNIQUE' in sql def test_alter_table(snowflake_client: SnowflakeClient) -> None: @@ -133,15 +137,15 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: assert sql.count("ALTER TABLE") == 1 assert sql.count("ADD COLUMN") == 1 assert '"EVENT_TEST_TABLE"' in sql - assert '"COL1" NUMBER(19,0) NOT NULL' in sql - assert '"COL2" FLOAT NOT NULL' in sql - assert '"COL3" BOOLEAN NOT NULL' in sql - assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql + assert '"COL1" NUMBER(19,0) NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL3" BOOLEAN NOT NULL' in sql + assert '"COL4" TIMESTAMP_TZ NOT NULL' in sql assert '"COL5" VARCHAR' in sql - assert '"COL6" NUMBER(38,9) NOT NULL' in sql + assert '"COL6" NUMBER(38,9) NOT NULL' in sql assert '"COL7" BINARY' in sql assert '"COL8" NUMBER(38,0)' in sql - assert '"COL9" VARIANT NOT NULL' in sql + assert '"COL9" VARIANT NOT NULL' in sql assert '"COL10" DATE' in sql mod_table = deepcopy(TABLE_UPDATE) @@ -149,7 +153,7 @@ def test_alter_table(snowflake_client: SnowflakeClient) -> None: sql = snowflake_client._get_table_update_sql("event_test_table", mod_table, True)[0] assert '"COL1"' not in sql - assert '"COL2" FLOAT NOT NULL' in sql + assert '"COL2" FLOAT NOT NULL' in sql def test_create_table_case_sensitive(cs_client: SnowflakeClient) -> None: