From 7bc2163ff001b9a3299827e1d3ddf0da021f36d6 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 18 Jan 2024 01:18:56 +0100 Subject: [PATCH] Synapse destination initial commit --- .../workflows/test_destination_synapse.yml | 22 ++- dlt/common/data_writers/escape.py | 10 +- dlt/common/data_writers/writers.py | 24 +++- dlt/common/destination/capabilities.py | 1 + dlt/destinations/__init__.py | 2 + dlt/destinations/impl/mssql/configuration.py | 31 +++-- dlt/destinations/impl/mssql/sql_client.py | 7 +- dlt/destinations/impl/synapse/README.md | 58 ++++++++ dlt/destinations/impl/synapse/__init__.py | 46 +++++++ .../impl/synapse/configuration.py | 38 +++++ dlt/destinations/impl/synapse/factory.py | 51 +++++++ dlt/destinations/impl/synapse/sql_client.py | 28 ++++ dlt/destinations/impl/synapse/synapse.py | 99 +++++++++++++ dlt/destinations/insert_job_client.py | 18 ++- dlt/helpers/dbt/profiles.yml | 18 ++- poetry.lock | 3 +- pyproject.toml | 1 + tests/load/mssql/test_mssql_credentials.py | 69 +++++++--- tests/load/mssql/test_mssql_table_builder.py | 3 +- tests/load/pipeline/test_dbt_helper.py | 7 +- .../load/pipeline/test_replace_disposition.py | 6 + tests/load/synapse/__init__.py | 3 + .../synapse/test_synapse_configuration.py | 46 +++++++ .../synapse/test_synapse_table_builder.py | 130 ++++++++++++++++++ tests/load/test_job_client.py | 3 +- tests/load/test_sql_client.py | 12 +- tests/load/utils.py | 7 +- tests/utils.py | 1 + 28 files changed, 672 insertions(+), 72 deletions(-) create mode 100644 dlt/destinations/impl/synapse/README.md create mode 100644 dlt/destinations/impl/synapse/__init__.py create mode 100644 dlt/destinations/impl/synapse/configuration.py create mode 100644 dlt/destinations/impl/synapse/factory.py create mode 100644 dlt/destinations/impl/synapse/sql_client.py create mode 100644 dlt/destinations/impl/synapse/synapse.py create mode 100644 tests/load/synapse/__init__.py create mode 100644 tests/load/synapse/test_synapse_configuration.py create mode 100644 tests/load/synapse/test_synapse_table_builder.py diff --git a/.github/workflows/test_destination_synapse.yml b/.github/workflows/test_destination_synapse.yml index 83800fa789..ecd890d32a 100644 --- a/.github/workflows/test_destination_synapse.yml +++ b/.github/workflows/test_destination_synapse.yml @@ -5,7 +5,6 @@ on: branches: - master - devel - workflow_dispatch: env: @@ -18,19 +17,14 @@ env: ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" jobs: - - build: - runs-on: ubuntu-latest - - steps: - - name: Check source branch name - run: | - if [[ "${{ github.head_ref }}" != "synapse" ]]; then - exit 1 - fi + get_docs_changes: + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork }} run_loader: name: Tests Synapse loader + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' strategy: fail-fast: false matrix: @@ -69,17 +63,17 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp - name: Install dependencies - run: poetry install --no-interaction -E synapse -E s3 -E gs -E az --with sentry-sdk --with pipeline + run: poetry install --no-interaction -E synapse -E parquet --with sentry-sdk --with pipeline - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os != 'Windows' name: Run tests Linux/MAC - run: | - poetry 
run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os == 'Windows' name: Run tests Windows shell: cmd diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 5bf8f29ccb..b56a0d8f19 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -98,8 +98,14 @@ def escape_mssql_literal(v: Any) -> Any: json.dumps(v), prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE ) if isinstance(v, bytes): - base_64_string = base64.b64encode(v).decode("ascii") - return f"""CAST('' AS XML).value('xs:base64Binary("{base_64_string}")', 'VARBINARY(MAX)')""" + # 8000 is the max value for n in VARBINARY(n) + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/binary-and-varbinary-transact-sql + if len(v) <= 8000: + n = len(v) + else: + n = "MAX" + return f"CONVERT(VARBINARY({n}), '{v.hex()}', 2)" + if isinstance(v, bool): return str(int(v)) if v is None: diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 0f9ff09259..0f3640da1e 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -175,18 +175,29 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: # do not write INSERT INTO command, this must be added together with table name by the loader self._f.write("INSERT INTO {}(") self._f.write(",".join(map(self._caps.escape_identifier, headers))) - self._f.write(")\nVALUES\n") + if self._caps.insert_values_writer_type == "default": + self._f.write(")\nVALUES\n") + elif self._caps.insert_values_writer_type == "select_union": + self._f.write(")\n") def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) - def write_row(row: StrAny) -> None: + def write_row(row: StrAny, last_row: bool = False) -> None: output = ["NULL"] * len(self._headers_lookup) for n, v in row.items(): output[self._headers_lookup[n]] = self._caps.escape_literal(v) - self._f.write("(") - self._f.write(",".join(output)) - self._f.write(")") + if self._caps.insert_values_writer_type == "default": + self._f.write("(") + self._f.write(",".join(output)) + self._f.write(")") + if not last_row: + self._f.write(",\n") + elif self._caps.insert_values_writer_type == "select_union": + self._f.write("SELECT ") + self._f.write(",".join(output)) + if not last_row: + self._f.write("\nUNION ALL\n") # if next chunk add separator if self._chunks_written > 0: @@ -195,10 +206,9 @@ def write_row(row: StrAny) -> None: # write rows for row in rows[:-1]: write_row(row) - self._f.write(",\n") # write last row without separator so we can write footer eventually - write_row(rows[-1]) + write_row(rows[-1], last_row=True) self._chunks_written += 1 def write_footer(self) -> None: diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 2596b2bf99..08c7a31388 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -52,6 +52,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): schema_supports_numeric_precision: bool = True timestamp_precision: int = 6 max_rows_per_insert: Optional[int] = None + insert_values_writer_type: str = "default" # do not allow to create default value, destination caps must be always explicitly inserted into container can_create_default: ClassVar[bool] = False diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 980c4ce7f2..775778cd4a 100644 --- 
a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,6 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate +from dlt.destinations.impl.synapse.factory import synapse __all__ = [ @@ -25,4 +26,5 @@ "qdrant", "motherduck", "weaviate", + "synapse", ] diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index f33aca4b82..f00998cfb2 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,4 +1,4 @@ -from typing import Final, ClassVar, Any, List, Optional, TYPE_CHECKING +from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING from sqlalchemy.engine import URL from dlt.common.configuration import configspec @@ -10,9 +10,6 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -SUPPORTED_DRIVERS = ["ODBC Driver 18 for SQL Server", "ODBC Driver 17 for SQL Server"] - - @configspec class MsSqlCredentials(ConnectionStringCredentials): drivername: Final[str] = "mssql" # type: ignore @@ -24,22 +21,27 @@ class MsSqlCredentials(ConnectionStringCredentials): __config_gen_annotations__: ClassVar[List[str]] = ["port", "connect_timeout"] + SUPPORTED_DRIVERS: ClassVar[List[str]] = [ + "ODBC Driver 18 for SQL Server", + "ODBC Driver 17 for SQL Server", + ] + def parse_native_representation(self, native_value: Any) -> None: # TODO: Support ODBC connection string or sqlalchemy URL super().parse_native_representation(native_value) if self.query is not None: self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive. - if "driver" in self.query and self.query.get("driver") not in SUPPORTED_DRIVERS: - raise SystemConfigurationException( - f"""The specified driver "{self.query.get('driver')}" is not supported.""" - f" Choose one of the supported drivers: {', '.join(SUPPORTED_DRIVERS)}." - ) self.driver = self.query.get("driver", self.driver) self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) if not self.is_partial(): self.resolve() def on_resolved(self) -> None: + if self.driver not in self.SUPPORTED_DRIVERS: + raise SystemConfigurationException( + f"""The specified driver "{self.driver}" is not supported.""" + f" Choose one of the supported drivers: {', '.join(self.SUPPORTED_DRIVERS)}." + ) self.database = self.database.lower() def to_url(self) -> URL: @@ -55,20 +57,21 @@ def on_partial(self) -> None: def _get_driver(self) -> str: if self.driver: return self.driver + # Pick a default driver if available import pyodbc available_drivers = pyodbc.drivers() - for d in SUPPORTED_DRIVERS: + for d in self.SUPPORTED_DRIVERS: if d in available_drivers: return d docs_url = "https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16" raise SystemConfigurationException( f"No supported ODBC driver found for MS SQL Server. See {docs_url} for information on" - f" how to install the '{SUPPORTED_DRIVERS[0]}' on your platform." + f" how to install the '{self.SUPPORTED_DRIVERS[0]}' on your platform." 
) - def to_odbc_dsn(self) -> str: + def _get_odbc_dsn_dict(self) -> Dict[str, Any]: params = { "DRIVER": self.driver, "SERVER": f"{self.host},{self.port}", @@ -78,6 +81,10 @@ def to_odbc_dsn(self) -> str: } if self.query is not None: params.update({k.upper(): v for k, v in self.query.items()}) + return params + + def to_odbc_dsn(self) -> str: + params = self._get_odbc_dsn_dict() return ";".join([f"{k}={v}" for k, v in params.items()]) diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 427518feeb..2ddd56350e 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -106,8 +106,8 @@ def drop_dataset(self) -> None: ) table_names = [row[0] for row in rows] self.drop_tables(*table_names) - - self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + # Drop schema + self._drop_schema() def _drop_views(self, *tables: str) -> None: if not tables: @@ -117,6 +117,9 @@ def _drop_views(self, *tables: str) -> None: ] self.execute_fragments(statements) + def _drop_schema(self) -> None: + self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: diff --git a/dlt/destinations/impl/synapse/README.md b/dlt/destinations/impl/synapse/README.md new file mode 100644 index 0000000000..b133faf67a --- /dev/null +++ b/dlt/destinations/impl/synapse/README.md @@ -0,0 +1,58 @@ +# Set up loader user +Execute the following SQL statements to set up the [loader](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql/data-loading-best-practices#create-a-loading-user) user: +```sql +-- on master database + +CREATE LOGIN loader WITH PASSWORD = 'YOUR_LOADER_PASSWORD_HERE'; +``` + +```sql +-- on minipool database + +CREATE USER loader FOR LOGIN loader; + +-- DDL permissions +GRANT CREATE TABLE ON DATABASE :: minipool TO loader; +GRANT CREATE VIEW ON DATABASE :: minipool TO loader; + +-- DML permissions +GRANT SELECT ON DATABASE :: minipool TO loader; +GRANT INSERT ON DATABASE :: minipool TO loader; +GRANT ADMINISTER DATABASE BULK OPERATIONS TO loader; +``` + +```sql +-- https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-workload-isolation + +CREATE WORKLOAD GROUP DataLoads +WITH ( + MIN_PERCENTAGE_RESOURCE = 0 + ,CAP_PERCENTAGE_RESOURCE = 50 + ,REQUEST_MIN_RESOURCE_GRANT_PERCENT = 25 +); + +CREATE WORKLOAD CLASSIFIER [wgcELTLogin] +WITH ( + WORKLOAD_GROUP = 'DataLoads' + ,MEMBERNAME = 'loader' +); +``` + +# config.toml +```toml +[destination.synapse.credentials] +database = "minipool" +username = "loader" +host = "dlt-synapse-ci.sql.azuresynapse.net" +port = 1433 +driver = "ODBC Driver 18 for SQL Server" + +[destination.synapse] +create_indexes = false +``` + +# secrets.toml +```toml +[destination.synapse.credentials] +password = "YOUR_LOADER_PASSWORD_HERE" +``` \ No newline at end of file diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py new file mode 100644 index 0000000000..175b011186 --- /dev/null +++ b/dlt/destinations/impl/synapse/__init__.py @@ -0,0 +1,46 @@ +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.wei import EVM_DECIMAL_PRECISION + + +def 
capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + + caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 + + caps.escape_identifier = escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + + # Synapse has a max precision of 38 + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries + caps.max_query_length = 65536 * 4096 + caps.is_max_query_length_in_bytes = True + + # nvarchar(max) can store 2 GB + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- + caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions + caps.supports_transactions = True + caps.supports_ddl_transactions = False + + # datetimeoffset can store 7 digits for fractional seconds + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 + caps.timestamp_precision = 7 + + return caps diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py new file mode 100644 index 0000000000..0596cc2c46 --- /dev/null +++ b/dlt/destinations/impl/synapse/configuration.py @@ -0,0 +1,38 @@ +from typing import Final, Any, List, Dict, Optional, ClassVar + +from dlt.common.configuration import configspec + +from dlt.destinations.impl.mssql.configuration import ( + MsSqlCredentials, + MsSqlClientConfiguration, +) +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials + + +@configspec +class SynapseCredentials(MsSqlCredentials): + drivername: Final[str] = "synapse" # type: ignore + + # LongAsMax keyword got introduced in ODBC Driver 18 for SQL Server. + SUPPORTED_DRIVERS: ClassVar[List[str]] = ["ODBC Driver 18 for SQL Server"] + + def _get_odbc_dsn_dict(self) -> Dict[str, Any]: + params = super()._get_odbc_dsn_dict() + # Long types (text, ntext, image) are not supported on Synapse. + # Convert to max types using LongAsMax keyword. + # https://stackoverflow.com/a/57926224 + params["LONGASMAX"] = "yes" + return params + + +@configspec +class SynapseClientConfiguration(MsSqlClientConfiguration): + destination_type: Final[str] = "synapse" # type: ignore + credentials: SynapseCredentials + + # Determines if `primary_key` and `unique` column hints are applied. + # Set to False by default because the PRIMARY KEY and UNIQUE constraints + # are tricky in Synapse: they are NOT ENFORCED and can lead to innacurate + # results if the user does not ensure all column values are unique. 
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints + create_indexes: bool = False diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py new file mode 100644 index 0000000000..fa7facc0ca --- /dev/null +++ b/dlt/destinations/impl/synapse/factory.py @@ -0,0 +1,51 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.synapse import capabilities + +from dlt.destinations.impl.synapse.configuration import ( + SynapseCredentials, + SynapseClientConfiguration, +) + +if t.TYPE_CHECKING: + from dlt.destinations.impl.synapse.synapse import SynapseClient + + +class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): + spec = SynapseClientConfiguration + + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities() + + @property + def client_class(self) -> t.Type["SynapseClient"]: + from dlt.destinations.impl.synapse.synapse import SynapseClient + + return SynapseClient + + def __init__( + self, + credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, + create_indexes: bool = False, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + """Configure the Synapse destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the Synapse dedicated pool. Can be an instance of `SynapseCredentials` or + a connection string in the format `synapse://user:password@host:port/database` + create_indexes: Should unique indexes be created, defaults to False + **kwargs: Additional arguments passed to the destination config + """ + super().__init__( + credentials=credentials, + create_indexes=create_indexes, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/synapse/sql_client.py b/dlt/destinations/impl/synapse/sql_client.py new file mode 100644 index 0000000000..089c58e57c --- /dev/null +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -0,0 +1,28 @@ +from typing import ClassVar +from contextlib import suppress + +from dlt.common.destination import DestinationCapabilitiesContext + +from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.destinations.impl.synapse import capabilities +from dlt.destinations.impl.synapse.configuration import SynapseCredentials + +from dlt.destinations.exceptions import DatabaseUndefinedRelation + + +class SynapseSqlClient(PyOdbcMsSqlClient): + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def drop_tables(self, *tables: str) -> None: + if not tables: + return + # Synapse does not support DROP TABLE IF EXISTS. + # Workaround: use DROP TABLE and suppress non-existence errors. + statements = [f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables] + with suppress(DatabaseUndefinedRelation): + self.execute_fragments(statements) + + def _drop_schema(self) -> None: + # Synapse does not support DROP SCHEMA IF EXISTS. 
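+        # drop_dataset() drops all tables in the schema first, so the schema is empty by the time this runs.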
+ self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py new file mode 100644 index 0000000000..18d1fa81d4 --- /dev/null +++ b/dlt/destinations/impl/synapse/synapse.py @@ -0,0 +1,99 @@ +from typing import ClassVar, Sequence, List, Dict, Any, Optional +from copy import deepcopy + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import SupportsStagingDestination, NewLoadJob + +from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint +from dlt.common.schema.typing import TTableSchemaColumns + +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams +from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.insert_job_client import InsertValuesJobClient +from dlt.destinations.job_client_impl import SqlJobClientBase + +from dlt.destinations.impl.mssql.mssql import MsSqlTypeMapper, MsSqlClient, HINT_TO_MSSQL_ATTR + +from dlt.destinations.impl.synapse import capabilities +from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient +from dlt.destinations.impl.synapse.configuration import SynapseClientConfiguration + + +HINT_TO_SYNAPSE_ATTR: Dict[TColumnHint, str] = { + "primary_key": "PRIMARY KEY NONCLUSTERED NOT ENFORCED", + "unique": "UNIQUE NOT ENFORCED", +} + + +class SynapseClient(MsSqlClient, SupportsStagingDestination): + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: + sql_client = SynapseSqlClient(config.normalize_dataset_name(schema), config.credentials) + InsertValuesJobClient.__init__(self, schema, config, sql_client) + self.config: SynapseClientConfiguration = config + self.sql_client = sql_client + self.type_mapper = MsSqlTypeMapper(self.capabilities) + + self.active_hints = deepcopy(HINT_TO_SYNAPSE_ATTR) + if not self.config.create_indexes: + self.active_hints.pop("primary_key", None) + self.active_hints.pop("unique", None) + + def _get_table_update_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> List[str]: + _sql_result = SqlJobClientBase._get_table_update_sql( + self, table_name, new_columns, generate_alter + ) + if not generate_alter: + # Append WITH clause to create heap table instead of default + # columnstore table. Heap tables are a more robust choice, because + # columnstore tables do not support varchar(max), nvarchar(max), + # and varbinary(max). 
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index + sql_result = [_sql_result[0] + "\n WITH ( HEAP );"] + else: + sql_result = _sql_result + return sql_result + + def _create_replace_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) + + +class SynapseStagingCopyJob(SqlStagingCopyJob): + @classmethod + def generate_sql( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: Optional[SqlJobParams] = None, + ) -> List[str]: + sql: List[str] = [] + for table in table_chain: + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name = sql_client.make_qualified_table_name(table["name"]) + # drop destination table + sql.append(f"DROP TABLE {table_name};") + # moving staging table to destination schema + sql.append( + f"ALTER SCHEMA {sql_client.fully_qualified_dataset_name()} TRANSFER" + f" {staging_table_name};" + ) + # recreate staging table + # In some cases, when multiple instances of this CTAS query are + # executed concurrently, Synapse suspends the queries and hangs. + # This can be prevented by setting the env var LOAD__WORKERS = "1". + sql.append( + f"CREATE TABLE {staging_table_name}" + " WITH ( DISTRIBUTION = ROUND_ROBIN, HEAP )" # distribution must be explicitly specified with CTAS + f" AS SELECT * FROM {table_name}" + " WHERE 1 = 0;" # no data, table structure only + ) + + return sql diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 678ba43bcc..776176078e 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -36,9 +36,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # the procedure below will split the inserts into max_query_length // 2 packs with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: header = f.readline() - values_mark = f.readline() - # properly formatted file has a values marker at the beginning - assert values_mark == "VALUES\n" + if self._sql_client.capabilities.insert_values_writer_type == "default": + # properly formatted file has a values marker at the beginning + values_mark = f.readline() + assert values_mark == "VALUES\n" max_rows = self._sql_client.capabilities.max_rows_per_insert @@ -67,7 +68,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # Chunk by max_rows - 1 for simplicity because one more row may be added for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) - insert_sql.extend([header.format(qualified_table_name), values_mark]) + insert_sql.append(header.format(qualified_table_name)) + if self._sql_client.capabilities.insert_values_writer_type == "default": + insert_sql.append(values_mark) if processed == len_rows: # On the last chunk we need to add the extra row read insert_sql.append("".join(chunk) + until_nl) @@ -76,7 +79,12 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st insert_sql.append("".join(chunk).strip()[:-1] + ";\n") else: # otherwise write all content in a single INSERT INTO - insert_sql.extend([header.format(qualified_table_name), values_mark, content]) + if 
self._sql_client.capabilities.insert_values_writer_type == "default": + insert_sql.extend( + [header.format(qualified_table_name), values_mark, content] + ) + elif self._sql_client.capabilities.insert_values_writer_type == "select_union": + insert_sql.extend([header.format(qualified_table_name), content]) if until_nl: insert_sql.append(until_nl) diff --git a/dlt/helpers/dbt/profiles.yml b/dlt/helpers/dbt/profiles.yml index 2414222cbd..7031f5de2c 100644 --- a/dlt/helpers/dbt/profiles.yml +++ b/dlt/helpers/dbt/profiles.yml @@ -141,4 +141,20 @@ athena: schema: "{{ var('destination_dataset_name', var('source_dataset_name')) }}" database: "{{ env_var('DLT__AWS_DATA_CATALOG') }}" # aws_profile_name: "{{ env_var('DLT__CREDENTIALS__PROFILE_NAME', '') }}" - work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" \ No newline at end of file + work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" + + +# commented out because dbt for Synapse isn't currently properly supported. +# Leave config here for potential future use. +# synapse: +# target: analytics +# outputs: +# analytics: +# type: synapse +# driver: "{{ env_var('DLT__CREDENTIALS__DRIVER') }}" +# server: "{{ env_var('DLT__CREDENTIALS__HOST') }}" +# port: "{{ env_var('DLT__CREDENTIALS__PORT') | as_number }}" +# database: "{{ env_var('DLT__CREDENTIALS__DATABASE') }}" +# schema: "{{ var('destination_dataset_name', var('source_dataset_name')) }}" +# user: "{{ env_var('DLT__CREDENTIALS__USERNAME') }}" +# password: "{{ env_var('DLT__CREDENTIALS__PASSWORD') }}" \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index c5da40c604..4d079fc44d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8466,9 +8466,10 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] +synapse = ["pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "cf751b2e1e9c66efde0a11774b5204e3206a14fd04ba4c79b2d37e38db5367ad" +content-hash = "26c595a857f17a5cbdb348f165c267d8910412325be4e522d0e91224c7fec588" diff --git a/pyproject.toml b/pyproject.toml index 6436ec23a7..d9d5858674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ cli = ["pipdeptree", "cron-descriptor"] athena = ["pyathena", "pyarrow", "s3fs", "botocore"] weaviate = ["weaviate-client"] mssql = ["pyodbc"] +synapse = ["pyodbc"] qdrant = ["qdrant-client"] [tool.poetry.scripts] diff --git a/tests/load/mssql/test_mssql_credentials.py b/tests/load/mssql/test_mssql_credentials.py index 0098d228f1..0e38791f22 100644 --- a/tests/load/mssql/test_mssql_credentials.py +++ b/tests/load/mssql/test_mssql_credentials.py @@ -1,18 +1,35 @@ import pyodbc import pytest -from dlt.common.configuration import resolve_configuration +from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException from dlt.common.exceptions import SystemConfigurationException -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, SUPPORTED_DRIVERS +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -def test_parse_native_representation_unsupported_driver_specified() -> None: +def test_mssql_credentials_defaults() -> None: + creds = MsSqlCredentials() + assert creds.port == 1433 + assert creds.connect_timeout == 15 + assert MsSqlCredentials.__config_gen_annotations__ == ["port", "connect_timeout"] + # port should be optional + resolve_configuration(creds, explicit_value="mssql://loader:loader@localhost/dlt_data") 
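+    # the default port 1433 must survive resolution from a connection string that omits it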
+ assert creds.port == 1433 + + +def test_parse_native_representation() -> None: # Case: unsupported driver specified. with pytest.raises(SystemConfigurationException): resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=foo" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+13+for+SQL+Server" + ) + ) + # Case: password not specified. + with pytest.raises(ConfigFieldMissingException): + resolve_configuration( + MsSqlCredentials( + "mssql://test_user@sql.example.com/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) @@ -21,33 +38,49 @@ def test_to_odbc_dsn_supported_driver_specified() -> None: # Case: supported driver specified — ODBC Driver 18 for SQL Server. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result == { "DRIVER": "ODBC Driver 18 for SQL Server", - "SERVER": "sql.example.com,12345", + "SERVER": "sql.example.com,1433", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } # Case: supported driver specified — ODBC Driver 17 for SQL Server. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result == { "DRIVER": "ODBC Driver 17 for SQL Server", + "SERVER": "sql.example.com,1433", + "DATABASE": "test_db", + "UID": "test_user", + "PWD": "test_pwd", + } + + # Case: port and supported driver specified. + creds = resolve_configuration( + MsSqlCredentials( + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result == { + "DRIVER": "ODBC Driver 18 for SQL Server", "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } @@ -55,7 +88,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: # Case: arbitrary query keys (and supported driver) specified. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?FOO=a&BAR=b&DRIVER=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?FOO=a&BAR=b&DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() @@ -65,7 +98,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", "FOO": "a", "BAR": "b", } @@ -73,7 +106,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: # Case: arbitrary capitalization. 
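+    # query keys are lower-cased on parse and upper-cased again when the DSN is built, so mixed case still works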
creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?FOO=a&bar=b&Driver=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?FOO=a&bar=b&Driver=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() @@ -83,30 +116,30 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", "FOO": "a", "BAR": "b", } -available_drivers = [d for d in pyodbc.drivers() if d in SUPPORTED_DRIVERS] +available_drivers = [d for d in pyodbc.drivers() if d in MsSqlCredentials.SUPPORTED_DRIVERS] @pytest.mark.skipif(not available_drivers, reason="no supported driver available") def test_to_odbc_dsn_driver_not_specified() -> None: # Case: driver not specified, but supported driver is available. creds = resolve_configuration( - MsSqlCredentials("mssql://test_user:test_password@sql.example.com:12345/test_db") + MsSqlCredentials("mssql://test_user:test_pwd@sql.example.com/test_db") ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result in [ { "DRIVER": d, - "SERVER": "sql.example.com,12345", + "SERVER": "sql.example.com,1433", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } - for d in SUPPORTED_DRIVERS + for d in MsSqlCredentials.SUPPORTED_DRIVERS ] diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index f7e0ce53ff..039ce99113 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -1,11 +1,10 @@ import pytest -from copy import deepcopy import sqlfluff from dlt.common.utils import uniq_id from dlt.common.schema import Schema -pytest.importorskip("dlt.destinations.mssql.mssql", reason="MSSQL ODBC driver not installed") +pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed") from dlt.destinations.impl.mssql.mssql import MsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 11f59d5276..e919409311 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -37,6 +37,8 @@ def test_run_jaffle_package( pytest.skip( "dbt-athena requires database to be created and we don't do it in case of Jaffle" ) + if not destination_config.supports_dbt: + pytest.skip("dbt is not supported for this destination configuration") pipeline = destination_config.setup_pipeline("jaffle_jaffle", full_refresh=True) # get runner, pass the env from fixture dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv) @@ -55,9 +57,10 @@ def test_run_jaffle_package( assert all(r.status == "pass" for r in tests) # get and display dataframe with customers - customers = select_data(pipeline, "SELECT * FROM customers") + qual_name = pipeline.sql_client().make_qualified_table_name + customers = select_data(pipeline, f"SELECT * FROM {qual_name('customers')}") assert len(customers) == 100 - orders = select_data(pipeline, "SELECT * FROM orders") + orders = select_data(pipeline, f"SELECT * FROM {qual_name('orders')}") assert len(orders) == 99 diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 
c6db91efff..1dde56a6b1 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -264,6 +264,12 @@ def test_replace_table_clearing( # use staging tables for replace os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy + if destination_config.destination == "synapse" and replace_strategy == "staging-optimized": + # The "staging-optimized" replace strategy makes Synapse suspend the CTAS + # queries used to recreate the staging table, and hang, when the number + # of load workers is greater than 1. + os.environ["LOAD__WORKERS"] = "1" + pipeline = destination_config.setup_pipeline( "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True ) diff --git a/tests/load/synapse/__init__.py b/tests/load/synapse/__init__.py new file mode 100644 index 0000000000..34119d38cb --- /dev/null +++ b/tests/load/synapse/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("synapse") diff --git a/tests/load/synapse/test_synapse_configuration.py b/tests/load/synapse/test_synapse_configuration.py new file mode 100644 index 0000000000..4055cbab38 --- /dev/null +++ b/tests/load/synapse/test_synapse_configuration.py @@ -0,0 +1,46 @@ +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.exceptions import SystemConfigurationException + +from dlt.destinations.impl.synapse.configuration import ( + SynapseClientConfiguration, + SynapseCredentials, +) + + +def test_synapse_configuration() -> None: + # By default, unique indexes should not be created. + assert SynapseClientConfiguration().create_indexes is False + + +def test_parse_native_representation() -> None: + # Case: unsupported driver specified. + with pytest.raises(SystemConfigurationException): + resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" + ) + ) + + +def test_to_odbc_dsn_longasmax() -> None: + # Case: LONGASMAX not specified in query (this is the expected scenario). + creds = resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result["LONGASMAX"] == "yes" + + # Case: LONGASMAX specified in query; specified value should be overridden. 
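+    # SynapseCredentials._get_odbc_dsn_dict sets LONGASMAX after merging the query params, so the user-supplied value is overwritten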
+ creds = resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server&LONGASMAX=no" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result["LONGASMAX"] == "yes" diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py new file mode 100644 index 0000000000..f58a7d5883 --- /dev/null +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -0,0 +1,130 @@ +import os +import pytest +import sqlfluff +from copy import deepcopy +from sqlfluff.api.simple import APIParsingError + +from dlt.common.utils import uniq_id +from dlt.common.schema import Schema, TColumnHint + +from dlt.destinations.impl.synapse.synapse import SynapseClient +from dlt.destinations.impl.synapse.configuration import ( + SynapseClientConfiguration, + SynapseCredentials, +) + +from tests.load.utils import TABLE_UPDATE +from dlt.destinations.impl.synapse.synapse import HINT_TO_SYNAPSE_ATTR + + +@pytest.fixture +def schema() -> Schema: + return Schema("event") + + +@pytest.fixture +def client(schema: Schema) -> SynapseClient: + # return client without opening connection + client = SynapseClient( + schema, + SynapseClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() + ), + ) + assert client.config.create_indexes is False + return client + + +@pytest.fixture +def client_with_indexes_enabled(schema: Schema) -> SynapseClient: + # return client without opening connection + client = SynapseClient( + schema, + SynapseClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True + ), + ) + assert client.config.create_indexes is True + return client + + +def test_create_table(client: SynapseClient) -> None: + # non existing table + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert "event_test_table" in sql + assert '"col1" bigint NOT NULL' in sql + assert '"col2" float NOT NULL' in sql + assert '"col3" bit NOT NULL' in sql + assert '"col4" datetimeoffset NOT NULL' in sql + assert '"col5" nvarchar(max) NOT NULL' in sql + assert '"col6" decimal(38,9) NOT NULL' in sql + assert '"col7" varbinary(max) NOT NULL' in sql + assert '"col8" decimal(38,0)' in sql + assert '"col9" nvarchar(max) NOT NULL' in sql + assert '"col10" date NOT NULL' in sql + assert '"col11" time NOT NULL' in sql + assert '"col1_precision" smallint NOT NULL' in sql + assert '"col4_precision" datetimeoffset(3) NOT NULL' in sql + assert '"col5_precision" nvarchar(25)' in sql + assert '"col6_precision" decimal(6,2) NOT NULL' in sql + assert '"col7_precision" varbinary(19)' in sql + assert '"col11_precision" time(3) NOT NULL' in sql + assert "WITH ( HEAP )" in sql + + +def test_alter_table(client: SynapseClient) -> None: + # existing table has no columns + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] + sqlfluff.parse(sql, dialect="tsql") + canonical_name = client.sql_client.make_qualified_table_name("event_test_table") + assert sql.count(f"ALTER TABLE {canonical_name}\nADD") == 1 + assert "event_test_table" in sql + assert '"col1" bigint NOT NULL' in sql + assert '"col2" float NOT NULL' in sql + assert '"col3" bit NOT NULL' in sql + assert '"col4" datetimeoffset NOT NULL' in sql + assert '"col5" nvarchar(max) NOT NULL' in sql + assert '"col6" decimal(38,9) NOT NULL' in sql 
+ assert '"col7" varbinary(max) NOT NULL' in sql + assert '"col8" decimal(38,0)' in sql + assert '"col9" nvarchar(max) NOT NULL' in sql + assert '"col10" date NOT NULL' in sql + assert '"col11" time NOT NULL' in sql + assert '"col1_precision" smallint NOT NULL' in sql + assert '"col4_precision" datetimeoffset(3) NOT NULL' in sql + assert '"col5_precision" nvarchar(25)' in sql + assert '"col6_precision" decimal(6,2) NOT NULL' in sql + assert '"col7_precision" varbinary(19)' in sql + assert '"col11_precision" time(3) NOT NULL' in sql + assert "WITH ( HEAP )" not in sql + + +@pytest.mark.parametrize("hint", ["primary_key", "unique"]) +def test_create_table_with_column_hint( + client: SynapseClient, client_with_indexes_enabled: SynapseClient, hint: TColumnHint +) -> None: + attr = HINT_TO_SYNAPSE_ATTR[hint] + + # Case: table without hint. + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert f" {attr} " not in sql + + # Case: table with hint, but client does not have indexes enabled. + mod_update = deepcopy(TABLE_UPDATE) + mod_update[0][hint] = True # type: ignore[typeddict-unknown-key] + sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert f" {attr} " not in sql + + # Case: table with hint, client has indexes enabled. + sql = client_with_indexes_enabled._get_table_update_sql("event_test_table", mod_update, False)[ + 0 + ] + # We expect an error because "PRIMARY KEY NONCLUSTERED NOT ENFORCED" and + # "UNIQUE NOT ENFORCED" are invalid in the generic "tsql" dialect. + # They are however valid in the Synapse variant of the dialect. + with pytest.raises(APIParsingError): + sqlfluff.parse(sql, dialect="tsql") + assert f'"col1" bigint {attr} NOT NULL' in sql diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 153504bf4a..b8d2e31e3f 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -387,7 +387,8 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: "time", ): continue - if client.config.destination_type == "mssql" and c["data_type"] in ("complex"): + # mssql and synapse have no native data type for the complex type. 
+ if client.config.destination_type in ("mssql", "synapse") and c["data_type"] in ("complex"): continue assert c["data_type"] == expected_c["data_type"] diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 96f0db09bb..4bdf08e23c 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -38,7 +38,7 @@ def client(request) -> Iterator[SqlJobClientBase]: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, exclude=["mssql"]), + destinations_configs(default_sql_configs=True, exclude=["mssql", "synapse"]), indirect=True, ids=lambda x: x.name, ) @@ -263,9 +263,15 @@ def test_execute_df(client: SqlJobClientBase) -> None: client.update_stored_schema() table_name = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) - insert_query = ",".join([f"({idx})" for idx in range(0, total_records)]) - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES {insert_query};") + if client.capabilities.insert_values_writer_type == "default": + insert_query = ",".join([f"({idx})" for idx in range(0, total_records)]) + sql_stmt = f"INSERT INTO {f_q_table_name} VALUES {insert_query};" + elif client.capabilities.insert_values_writer_type == "select_union": + insert_query = " UNION ALL ".join([f"SELECT {idx}" for idx in range(0, total_records)]) + sql_stmt = f"INSERT INTO {f_q_table_name} {insert_query};" + + client.sql_client.execute_sql(sql_stmt) with client.sql_client.execute_query( f"SELECT * FROM {f_q_table_name} ORDER BY col ASC" ) as curr: diff --git a/tests/load/utils.py b/tests/load/utils.py index 6811ca59a6..55445e0b95 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -163,7 +163,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination != "athena" + if destination not in ("athena", "synapse") ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet") @@ -190,6 +190,10 @@ def destinations_configs( extra_info="iceberg", ) ] + # dbt for Synapse has some complications and I couldn't get it to pass all tests. + destination_configs += [ + DestinationTestConfiguration(destination="synapse", supports_dbt=False) + ] if default_vector_configs: # for now only weaviate @@ -465,7 +469,6 @@ def yield_client_with_storage( ) as client: client.initialize_storage() yield client - # print(dataset_name) client.sql_client.drop_dataset() if isinstance(client, WithStagingDataset): with client.with_staging_dataset(): diff --git a/tests/utils.py b/tests/utils.py index cf172f9733..211f87874d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,6 +45,7 @@ "motherduck", "mssql", "qdrant", + "synapse", } NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS
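
For reviewers who want to smoke-test the new destination locally, a minimal pipeline sketch is shown below. The workspace host, database, and password are placeholders; the factory arguments follow the `synapse()` signature added in this patch, and in practice the password would come from `secrets.toml` as described in the new README.

```python
import dlt
from dlt.destinations import synapse

# Placeholder connection string; replace <password> and <workspace> with real values
# or configure them via config.toml / secrets.toml as shown in the README above.
dest = synapse(
    credentials=(
        "synapse://loader:<password>@<workspace>.sql.azuresynapse.net/minipool"
        "?driver=ODBC+Driver+18+for+SQL+Server"
    ),
    create_indexes=False,  # default: PRIMARY KEY / UNIQUE hints are not enforced by Synapse
)

pipeline = dlt.pipeline(
    pipeline_name="synapse_smoke_test",
    destination=dest,
    dataset_name="demo_data",
)

# Loads a tiny table through the insert_values file format using the new "select_union" writer.
load_info = pipeline.run(
    [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}],
    table_name="items",
)
print(load_info)
```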