From 7bc2163ff001b9a3299827e1d3ddf0da021f36d6 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 18 Jan 2024 01:18:56 +0100 Subject: [PATCH 01/84] Synapse destination initial commit --- .../workflows/test_destination_synapse.yml | 22 ++- dlt/common/data_writers/escape.py | 10 +- dlt/common/data_writers/writers.py | 24 +++- dlt/common/destination/capabilities.py | 1 + dlt/destinations/__init__.py | 2 + dlt/destinations/impl/mssql/configuration.py | 31 +++-- dlt/destinations/impl/mssql/sql_client.py | 7 +- dlt/destinations/impl/synapse/README.md | 58 ++++++++ dlt/destinations/impl/synapse/__init__.py | 46 +++++++ .../impl/synapse/configuration.py | 38 +++++ dlt/destinations/impl/synapse/factory.py | 51 +++++++ dlt/destinations/impl/synapse/sql_client.py | 28 ++++ dlt/destinations/impl/synapse/synapse.py | 99 +++++++++++++ dlt/destinations/insert_job_client.py | 18 ++- dlt/helpers/dbt/profiles.yml | 18 ++- poetry.lock | 3 +- pyproject.toml | 1 + tests/load/mssql/test_mssql_credentials.py | 69 +++++++--- tests/load/mssql/test_mssql_table_builder.py | 3 +- tests/load/pipeline/test_dbt_helper.py | 7 +- .../load/pipeline/test_replace_disposition.py | 6 + tests/load/synapse/__init__.py | 3 + .../synapse/test_synapse_configuration.py | 46 +++++++ .../synapse/test_synapse_table_builder.py | 130 ++++++++++++++++++ tests/load/test_job_client.py | 3 +- tests/load/test_sql_client.py | 12 +- tests/load/utils.py | 7 +- tests/utils.py | 1 + 28 files changed, 672 insertions(+), 72 deletions(-) create mode 100644 dlt/destinations/impl/synapse/README.md create mode 100644 dlt/destinations/impl/synapse/__init__.py create mode 100644 dlt/destinations/impl/synapse/configuration.py create mode 100644 dlt/destinations/impl/synapse/factory.py create mode 100644 dlt/destinations/impl/synapse/sql_client.py create mode 100644 dlt/destinations/impl/synapse/synapse.py create mode 100644 tests/load/synapse/__init__.py create mode 100644 tests/load/synapse/test_synapse_configuration.py create mode 100644 tests/load/synapse/test_synapse_table_builder.py diff --git a/.github/workflows/test_destination_synapse.yml b/.github/workflows/test_destination_synapse.yml index 83800fa789..ecd890d32a 100644 --- a/.github/workflows/test_destination_synapse.yml +++ b/.github/workflows/test_destination_synapse.yml @@ -5,7 +5,6 @@ on: branches: - master - devel - workflow_dispatch: env: @@ -18,19 +17,14 @@ env: ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" jobs: - - build: - runs-on: ubuntu-latest - - steps: - - name: Check source branch name - run: | - if [[ "${{ github.head_ref }}" != "synapse" ]]; then - exit 1 - fi + get_docs_changes: + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork }} run_loader: name: Tests Synapse loader + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' strategy: fail-fast: false matrix: @@ -69,17 +63,17 @@ jobs: key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp - name: Install dependencies - run: poetry install --no-interaction -E synapse -E s3 -E gs -E az --with sentry-sdk --with pipeline + run: poetry install --no-interaction -E synapse -E parquet --with sentry-sdk --with pipeline - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os != 'Windows' name: Run tests Linux/MAC - run: | - 
poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os == 'Windows' name: Run tests Windows shell: cmd diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 5bf8f29ccb..b56a0d8f19 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -98,8 +98,14 @@ def escape_mssql_literal(v: Any) -> Any: json.dumps(v), prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE ) if isinstance(v, bytes): - base_64_string = base64.b64encode(v).decode("ascii") - return f"""CAST('' AS XML).value('xs:base64Binary("{base_64_string}")', 'VARBINARY(MAX)')""" + # 8000 is the max value for n in VARBINARY(n) + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/binary-and-varbinary-transact-sql + if len(v) <= 8000: + n = len(v) + else: + n = "MAX" + return f"CONVERT(VARBINARY({n}), '{v.hex()}', 2)" + if isinstance(v, bool): return str(int(v)) if v is None: diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 0f9ff09259..0f3640da1e 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -175,18 +175,29 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: # do not write INSERT INTO command, this must be added together with table name by the loader self._f.write("INSERT INTO {}(") self._f.write(",".join(map(self._caps.escape_identifier, headers))) - self._f.write(")\nVALUES\n") + if self._caps.insert_values_writer_type == "default": + self._f.write(")\nVALUES\n") + elif self._caps.insert_values_writer_type == "select_union": + self._f.write(")\n") def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) - def write_row(row: StrAny) -> None: + def write_row(row: StrAny, last_row: bool = False) -> None: output = ["NULL"] * len(self._headers_lookup) for n, v in row.items(): output[self._headers_lookup[n]] = self._caps.escape_literal(v) - self._f.write("(") - self._f.write(",".join(output)) - self._f.write(")") + if self._caps.insert_values_writer_type == "default": + self._f.write("(") + self._f.write(",".join(output)) + self._f.write(")") + if not last_row: + self._f.write(",\n") + elif self._caps.insert_values_writer_type == "select_union": + self._f.write("SELECT ") + self._f.write(",".join(output)) + if not last_row: + self._f.write("\nUNION ALL\n") # if next chunk add separator if self._chunks_written > 0: @@ -195,10 +206,9 @@ def write_row(row: StrAny) -> None: # write rows for row in rows[:-1]: write_row(row) - self._f.write(",\n") # write last row without separator so we can write footer eventually - write_row(rows[-1]) + write_row(rows[-1], last_row=True) self._chunks_written += 1 def write_footer(self) -> None: diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 2596b2bf99..08c7a31388 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -52,6 +52,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): schema_supports_numeric_precision: bool = True timestamp_precision: int = 6 max_rows_per_insert: Optional[int] = None + insert_values_writer_type: str = "default" # do not allow to create default value, destination caps must be always explicitly inserted into container can_create_default: ClassVar[bool] = False diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 980c4ce7f2..775778cd4a 100644 --- 
a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,6 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate +from dlt.destinations.impl.synapse.factory import synapse __all__ = [ @@ -25,4 +26,5 @@ "qdrant", "motherduck", "weaviate", + "synapse", ] diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index f33aca4b82..f00998cfb2 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,4 +1,4 @@ -from typing import Final, ClassVar, Any, List, Optional, TYPE_CHECKING +from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING from sqlalchemy.engine import URL from dlt.common.configuration import configspec @@ -10,9 +10,6 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -SUPPORTED_DRIVERS = ["ODBC Driver 18 for SQL Server", "ODBC Driver 17 for SQL Server"] - - @configspec class MsSqlCredentials(ConnectionStringCredentials): drivername: Final[str] = "mssql" # type: ignore @@ -24,22 +21,27 @@ class MsSqlCredentials(ConnectionStringCredentials): __config_gen_annotations__: ClassVar[List[str]] = ["port", "connect_timeout"] + SUPPORTED_DRIVERS: ClassVar[List[str]] = [ + "ODBC Driver 18 for SQL Server", + "ODBC Driver 17 for SQL Server", + ] + def parse_native_representation(self, native_value: Any) -> None: # TODO: Support ODBC connection string or sqlalchemy URL super().parse_native_representation(native_value) if self.query is not None: self.query = {k.lower(): v for k, v in self.query.items()} # Make case-insensitive. - if "driver" in self.query and self.query.get("driver") not in SUPPORTED_DRIVERS: - raise SystemConfigurationException( - f"""The specified driver "{self.query.get('driver')}" is not supported.""" - f" Choose one of the supported drivers: {', '.join(SUPPORTED_DRIVERS)}." - ) self.driver = self.query.get("driver", self.driver) self.connect_timeout = int(self.query.get("connect_timeout", self.connect_timeout)) if not self.is_partial(): self.resolve() def on_resolved(self) -> None: + if self.driver not in self.SUPPORTED_DRIVERS: + raise SystemConfigurationException( + f"""The specified driver "{self.driver}" is not supported.""" + f" Choose one of the supported drivers: {', '.join(self.SUPPORTED_DRIVERS)}." + ) self.database = self.database.lower() def to_url(self) -> URL: @@ -55,20 +57,21 @@ def on_partial(self) -> None: def _get_driver(self) -> str: if self.driver: return self.driver + # Pick a default driver if available import pyodbc available_drivers = pyodbc.drivers() - for d in SUPPORTED_DRIVERS: + for d in self.SUPPORTED_DRIVERS: if d in available_drivers: return d docs_url = "https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16" raise SystemConfigurationException( f"No supported ODBC driver found for MS SQL Server. See {docs_url} for information on" - f" how to install the '{SUPPORTED_DRIVERS[0]}' on your platform." + f" how to install the '{self.SUPPORTED_DRIVERS[0]}' on your platform." 
) - def to_odbc_dsn(self) -> str: + def _get_odbc_dsn_dict(self) -> Dict[str, Any]: params = { "DRIVER": self.driver, "SERVER": f"{self.host},{self.port}", @@ -78,6 +81,10 @@ def to_odbc_dsn(self) -> str: } if self.query is not None: params.update({k.upper(): v for k, v in self.query.items()}) + return params + + def to_odbc_dsn(self) -> str: + params = self._get_odbc_dsn_dict() return ";".join([f"{k}={v}" for k, v in params.items()]) diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 427518feeb..2ddd56350e 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -106,8 +106,8 @@ def drop_dataset(self) -> None: ) table_names = [row[0] for row in rows] self.drop_tables(*table_names) - - self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + # Drop schema + self._drop_schema() def _drop_views(self, *tables: str) -> None: if not tables: @@ -117,6 +117,9 @@ def _drop_views(self, *tables: str) -> None: ] self.execute_fragments(statements) + def _drop_schema(self) -> None: + self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: diff --git a/dlt/destinations/impl/synapse/README.md b/dlt/destinations/impl/synapse/README.md new file mode 100644 index 0000000000..b133faf67a --- /dev/null +++ b/dlt/destinations/impl/synapse/README.md @@ -0,0 +1,58 @@ +# Set up loader user +Execute the following SQL statements to set up the [loader](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql/data-loading-best-practices#create-a-loading-user) user: +```sql +-- on master database + +CREATE LOGIN loader WITH PASSWORD = 'YOUR_LOADER_PASSWORD_HERE'; +``` + +```sql +-- on minipool database + +CREATE USER loader FOR LOGIN loader; + +-- DDL permissions +GRANT CREATE TABLE ON DATABASE :: minipool TO loader; +GRANT CREATE VIEW ON DATABASE :: minipool TO loader; + +-- DML permissions +GRANT SELECT ON DATABASE :: minipool TO loader; +GRANT INSERT ON DATABASE :: minipool TO loader; +GRANT ADMINISTER DATABASE BULK OPERATIONS TO loader; +``` + +```sql +-- https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-workload-isolation + +CREATE WORKLOAD GROUP DataLoads +WITH ( + MIN_PERCENTAGE_RESOURCE = 0 + ,CAP_PERCENTAGE_RESOURCE = 50 + ,REQUEST_MIN_RESOURCE_GRANT_PERCENT = 25 +); + +CREATE WORKLOAD CLASSIFIER [wgcELTLogin] +WITH ( + WORKLOAD_GROUP = 'DataLoads' + ,MEMBERNAME = 'loader' +); +``` + +# config.toml +```toml +[destination.synapse.credentials] +database = "minipool" +username = "loader" +host = "dlt-synapse-ci.sql.azuresynapse.net" +port = 1433 +driver = "ODBC Driver 18 for SQL Server" + +[destination.synapse] +create_indexes = false +``` + +# secrets.toml +```toml +[destination.synapse.credentials] +password = "YOUR_LOADER_PASSWORD_HERE" +``` \ No newline at end of file diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py new file mode 100644 index 0000000000..175b011186 --- /dev/null +++ b/dlt/destinations/impl/synapse/__init__.py @@ -0,0 +1,46 @@ +from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.wei import EVM_DECIMAL_PRECISION + + +def 
capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + + caps.preferred_loader_file_format = "insert_values" + caps.supported_loader_file_formats = ["insert_values"] + caps.preferred_staging_file_format = None + caps.supported_staging_file_formats = [] + + caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 + + caps.escape_identifier = escape_postgres_identifier + caps.escape_literal = escape_mssql_literal + + # Synapse has a max precision of 38 + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#DataTypes + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + + # https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-azure-sql-data-warehouse?view=aps-pdw-2016-au7#LimitationsRestrictions + caps.max_identifier_length = 128 + caps.max_column_identifier_length = 128 + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-service-capacity-limits#queries + caps.max_query_length = 65536 * 4096 + caps.is_max_query_length_in_bytes = True + + # nvarchar(max) can store 2 GB + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/nchar-and-nvarchar-transact-sql?view=sql-server-ver16#nvarchar---n--max-- + caps.max_text_data_type_length = 2 * 1024 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-develop-transactions + caps.supports_transactions = True + caps.supports_ddl_transactions = False + + # datetimeoffset can store 7 digits for fractional seconds + # https://learn.microsoft.com/en-us/sql/t-sql/data-types/datetimeoffset-transact-sql?view=sql-server-ver16 + caps.timestamp_precision = 7 + + return caps diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py new file mode 100644 index 0000000000..0596cc2c46 --- /dev/null +++ b/dlt/destinations/impl/synapse/configuration.py @@ -0,0 +1,38 @@ +from typing import Final, Any, List, Dict, Optional, ClassVar + +from dlt.common.configuration import configspec + +from dlt.destinations.impl.mssql.configuration import ( + MsSqlCredentials, + MsSqlClientConfiguration, +) +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials + + +@configspec +class SynapseCredentials(MsSqlCredentials): + drivername: Final[str] = "synapse" # type: ignore + + # LongAsMax keyword got introduced in ODBC Driver 18 for SQL Server. + SUPPORTED_DRIVERS: ClassVar[List[str]] = ["ODBC Driver 18 for SQL Server"] + + def _get_odbc_dsn_dict(self) -> Dict[str, Any]: + params = super()._get_odbc_dsn_dict() + # Long types (text, ntext, image) are not supported on Synapse. + # Convert to max types using LongAsMax keyword. + # https://stackoverflow.com/a/57926224 + params["LONGASMAX"] = "yes" + return params + + +@configspec +class SynapseClientConfiguration(MsSqlClientConfiguration): + destination_type: Final[str] = "synapse" # type: ignore + credentials: SynapseCredentials + + # Determines if `primary_key` and `unique` column hints are applied. + # Set to False by default because the PRIMARY KEY and UNIQUE constraints + # are tricky in Synapse: they are NOT ENFORCED and can lead to innacurate + # results if the user does not ensure all column values are unique. 
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints + create_indexes: bool = False diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py new file mode 100644 index 0000000000..fa7facc0ca --- /dev/null +++ b/dlt/destinations/impl/synapse/factory.py @@ -0,0 +1,51 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.impl.synapse import capabilities + +from dlt.destinations.impl.synapse.configuration import ( + SynapseCredentials, + SynapseClientConfiguration, +) + +if t.TYPE_CHECKING: + from dlt.destinations.impl.synapse.synapse import SynapseClient + + +class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): + spec = SynapseClientConfiguration + + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities() + + @property + def client_class(self) -> t.Type["SynapseClient"]: + from dlt.destinations.impl.synapse.synapse import SynapseClient + + return SynapseClient + + def __init__( + self, + credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, + create_indexes: bool = False, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + """Configure the Synapse destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the Synapse dedicated pool. Can be an instance of `SynapseCredentials` or + a connection string in the format `synapse://user:password@host:port/database` + create_indexes: Should unique indexes be created, defaults to False + **kwargs: Additional arguments passed to the destination config + """ + super().__init__( + credentials=credentials, + create_indexes=create_indexes, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/impl/synapse/sql_client.py b/dlt/destinations/impl/synapse/sql_client.py new file mode 100644 index 0000000000..089c58e57c --- /dev/null +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -0,0 +1,28 @@ +from typing import ClassVar +from contextlib import suppress + +from dlt.common.destination import DestinationCapabilitiesContext + +from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.destinations.impl.synapse import capabilities +from dlt.destinations.impl.synapse.configuration import SynapseCredentials + +from dlt.destinations.exceptions import DatabaseUndefinedRelation + + +class SynapseSqlClient(PyOdbcMsSqlClient): + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def drop_tables(self, *tables: str) -> None: + if not tables: + return + # Synapse does not support DROP TABLE IF EXISTS. + # Workaround: use DROP TABLE and suppress non-existence errors. + statements = [f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables] + with suppress(DatabaseUndefinedRelation): + self.execute_fragments(statements) + + def _drop_schema(self) -> None: + # Synapse does not support DROP SCHEMA IF EXISTS. 
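+        # drop_dataset() drops the tables in this schema before calling this method, so a plain DROP SCHEMA suffices.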
+ self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py new file mode 100644 index 0000000000..18d1fa81d4 --- /dev/null +++ b/dlt/destinations/impl/synapse/synapse.py @@ -0,0 +1,99 @@ +from typing import ClassVar, Sequence, List, Dict, Any, Optional +from copy import deepcopy + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import SupportsStagingDestination, NewLoadJob + +from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint +from dlt.common.schema.typing import TTableSchemaColumns + +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams +from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.insert_job_client import InsertValuesJobClient +from dlt.destinations.job_client_impl import SqlJobClientBase + +from dlt.destinations.impl.mssql.mssql import MsSqlTypeMapper, MsSqlClient, HINT_TO_MSSQL_ATTR + +from dlt.destinations.impl.synapse import capabilities +from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient +from dlt.destinations.impl.synapse.configuration import SynapseClientConfiguration + + +HINT_TO_SYNAPSE_ATTR: Dict[TColumnHint, str] = { + "primary_key": "PRIMARY KEY NONCLUSTERED NOT ENFORCED", + "unique": "UNIQUE NOT ENFORCED", +} + + +class SynapseClient(MsSqlClient, SupportsStagingDestination): + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: + sql_client = SynapseSqlClient(config.normalize_dataset_name(schema), config.credentials) + InsertValuesJobClient.__init__(self, schema, config, sql_client) + self.config: SynapseClientConfiguration = config + self.sql_client = sql_client + self.type_mapper = MsSqlTypeMapper(self.capabilities) + + self.active_hints = deepcopy(HINT_TO_SYNAPSE_ATTR) + if not self.config.create_indexes: + self.active_hints.pop("primary_key", None) + self.active_hints.pop("unique", None) + + def _get_table_update_sql( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> List[str]: + _sql_result = SqlJobClientBase._get_table_update_sql( + self, table_name, new_columns, generate_alter + ) + if not generate_alter: + # Append WITH clause to create heap table instead of default + # columnstore table. Heap tables are a more robust choice, because + # columnstore tables do not support varchar(max), nvarchar(max), + # and varbinary(max). 
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index + sql_result = [_sql_result[0] + "\n WITH ( HEAP );"] + else: + sql_result = _sql_result + return sql_result + + def _create_replace_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) + + +class SynapseStagingCopyJob(SqlStagingCopyJob): + @classmethod + def generate_sql( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: Optional[SqlJobParams] = None, + ) -> List[str]: + sql: List[str] = [] + for table in table_chain: + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name = sql_client.make_qualified_table_name(table["name"]) + # drop destination table + sql.append(f"DROP TABLE {table_name};") + # moving staging table to destination schema + sql.append( + f"ALTER SCHEMA {sql_client.fully_qualified_dataset_name()} TRANSFER" + f" {staging_table_name};" + ) + # recreate staging table + # In some cases, when multiple instances of this CTAS query are + # executed concurrently, Synapse suspends the queries and hangs. + # This can be prevented by setting the env var LOAD__WORKERS = "1". + sql.append( + f"CREATE TABLE {staging_table_name}" + " WITH ( DISTRIBUTION = ROUND_ROBIN, HEAP )" # distribution must be explicitly specified with CTAS + f" AS SELECT * FROM {table_name}" + " WHERE 1 = 0;" # no data, table structure only + ) + + return sql diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 678ba43bcc..776176078e 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -36,9 +36,10 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # the procedure below will split the inserts into max_query_length // 2 packs with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: header = f.readline() - values_mark = f.readline() - # properly formatted file has a values marker at the beginning - assert values_mark == "VALUES\n" + if self._sql_client.capabilities.insert_values_writer_type == "default": + # properly formatted file has a values marker at the beginning + values_mark = f.readline() + assert values_mark == "VALUES\n" max_rows = self._sql_client.capabilities.max_rows_per_insert @@ -67,7 +68,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st # Chunk by max_rows - 1 for simplicity because one more row may be added for chunk in chunks(values_rows, max_rows - 1): processed += len(chunk) - insert_sql.extend([header.format(qualified_table_name), values_mark]) + insert_sql.append(header.format(qualified_table_name)) + if self._sql_client.capabilities.insert_values_writer_type == "default": + insert_sql.append(values_mark) if processed == len_rows: # On the last chunk we need to add the extra row read insert_sql.append("".join(chunk) + until_nl) @@ -76,7 +79,12 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st insert_sql.append("".join(chunk).strip()[:-1] + ";\n") else: # otherwise write all content in a single INSERT INTO - insert_sql.extend([header.format(qualified_table_name), values_mark, content]) + if 
self._sql_client.capabilities.insert_values_writer_type == "default": + insert_sql.extend( + [header.format(qualified_table_name), values_mark, content] + ) + elif self._sql_client.capabilities.insert_values_writer_type == "select_union": + insert_sql.extend([header.format(qualified_table_name), content]) if until_nl: insert_sql.append(until_nl) diff --git a/dlt/helpers/dbt/profiles.yml b/dlt/helpers/dbt/profiles.yml index 2414222cbd..7031f5de2c 100644 --- a/dlt/helpers/dbt/profiles.yml +++ b/dlt/helpers/dbt/profiles.yml @@ -141,4 +141,20 @@ athena: schema: "{{ var('destination_dataset_name', var('source_dataset_name')) }}" database: "{{ env_var('DLT__AWS_DATA_CATALOG') }}" # aws_profile_name: "{{ env_var('DLT__CREDENTIALS__PROFILE_NAME', '') }}" - work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" \ No newline at end of file + work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" + + +# commented out because dbt for Synapse isn't currently properly supported. +# Leave config here for potential future use. +# synapse: +# target: analytics +# outputs: +# analytics: +# type: synapse +# driver: "{{ env_var('DLT__CREDENTIALS__DRIVER') }}" +# server: "{{ env_var('DLT__CREDENTIALS__HOST') }}" +# port: "{{ env_var('DLT__CREDENTIALS__PORT') | as_number }}" +# database: "{{ env_var('DLT__CREDENTIALS__DATABASE') }}" +# schema: "{{ var('destination_dataset_name', var('source_dataset_name')) }}" +# user: "{{ env_var('DLT__CREDENTIALS__USERNAME') }}" +# password: "{{ env_var('DLT__CREDENTIALS__PASSWORD') }}" \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index c5da40c604..4d079fc44d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8466,9 +8466,10 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] +synapse = ["pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "cf751b2e1e9c66efde0a11774b5204e3206a14fd04ba4c79b2d37e38db5367ad" +content-hash = "26c595a857f17a5cbdb348f165c267d8910412325be4e522d0e91224c7fec588" diff --git a/pyproject.toml b/pyproject.toml index 6436ec23a7..d9d5858674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ cli = ["pipdeptree", "cron-descriptor"] athena = ["pyathena", "pyarrow", "s3fs", "botocore"] weaviate = ["weaviate-client"] mssql = ["pyodbc"] +synapse = ["pyodbc"] qdrant = ["qdrant-client"] [tool.poetry.scripts] diff --git a/tests/load/mssql/test_mssql_credentials.py b/tests/load/mssql/test_mssql_credentials.py index 0098d228f1..0e38791f22 100644 --- a/tests/load/mssql/test_mssql_credentials.py +++ b/tests/load/mssql/test_mssql_credentials.py @@ -1,18 +1,35 @@ import pyodbc import pytest -from dlt.common.configuration import resolve_configuration +from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException from dlt.common.exceptions import SystemConfigurationException -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, SUPPORTED_DRIVERS +from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -def test_parse_native_representation_unsupported_driver_specified() -> None: +def test_mssql_credentials_defaults() -> None: + creds = MsSqlCredentials() + assert creds.port == 1433 + assert creds.connect_timeout == 15 + assert MsSqlCredentials.__config_gen_annotations__ == ["port", "connect_timeout"] + # port should be optional + resolve_configuration(creds, explicit_value="mssql://loader:loader@localhost/dlt_data") 
+ assert creds.port == 1433 + + +def test_parse_native_representation() -> None: # Case: unsupported driver specified. with pytest.raises(SystemConfigurationException): resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=foo" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+13+for+SQL+Server" + ) + ) + # Case: password not specified. + with pytest.raises(ConfigFieldMissingException): + resolve_configuration( + MsSqlCredentials( + "mssql://test_user@sql.example.com/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) @@ -21,33 +38,49 @@ def test_to_odbc_dsn_supported_driver_specified() -> None: # Case: supported driver specified — ODBC Driver 18 for SQL Server. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result == { "DRIVER": "ODBC Driver 18 for SQL Server", - "SERVER": "sql.example.com,12345", + "SERVER": "sql.example.com,1433", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } # Case: supported driver specified — ODBC Driver 17 for SQL Server. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result == { "DRIVER": "ODBC Driver 17 for SQL Server", + "SERVER": "sql.example.com,1433", + "DATABASE": "test_db", + "UID": "test_user", + "PWD": "test_pwd", + } + + # Case: port and supported driver specified. + creds = resolve_configuration( + MsSqlCredentials( + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result == { + "DRIVER": "ODBC Driver 18 for SQL Server", "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } @@ -55,7 +88,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: # Case: arbitrary query keys (and supported driver) specified. creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?FOO=a&BAR=b&DRIVER=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?FOO=a&BAR=b&DRIVER=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() @@ -65,7 +98,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", "FOO": "a", "BAR": "b", } @@ -73,7 +106,7 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: # Case: arbitrary capitalization. 
creds = resolve_configuration( MsSqlCredentials( - "mssql://test_user:test_password@sql.example.com:12345/test_db?FOO=a&bar=b&Driver=ODBC+Driver+18+for+SQL+Server" + "mssql://test_user:test_pwd@sql.example.com:12345/test_db?FOO=a&bar=b&Driver=ODBC+Driver+18+for+SQL+Server" ) ) dsn = creds.to_odbc_dsn() @@ -83,30 +116,30 @@ def test_to_odbc_dsn_arbitrary_keys_specified() -> None: "SERVER": "sql.example.com,12345", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", "FOO": "a", "BAR": "b", } -available_drivers = [d for d in pyodbc.drivers() if d in SUPPORTED_DRIVERS] +available_drivers = [d for d in pyodbc.drivers() if d in MsSqlCredentials.SUPPORTED_DRIVERS] @pytest.mark.skipif(not available_drivers, reason="no supported driver available") def test_to_odbc_dsn_driver_not_specified() -> None: # Case: driver not specified, but supported driver is available. creds = resolve_configuration( - MsSqlCredentials("mssql://test_user:test_password@sql.example.com:12345/test_db") + MsSqlCredentials("mssql://test_user:test_pwd@sql.example.com/test_db") ) dsn = creds.to_odbc_dsn() result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} assert result in [ { "DRIVER": d, - "SERVER": "sql.example.com,12345", + "SERVER": "sql.example.com,1433", "DATABASE": "test_db", "UID": "test_user", - "PWD": "test_password", + "PWD": "test_pwd", } - for d in SUPPORTED_DRIVERS + for d in MsSqlCredentials.SUPPORTED_DRIVERS ] diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index f7e0ce53ff..039ce99113 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -1,11 +1,10 @@ import pytest -from copy import deepcopy import sqlfluff from dlt.common.utils import uniq_id from dlt.common.schema import Schema -pytest.importorskip("dlt.destinations.mssql.mssql", reason="MSSQL ODBC driver not installed") +pytest.importorskip("dlt.destinations.impl.mssql.mssql", reason="MSSQL ODBC driver not installed") from dlt.destinations.impl.mssql.mssql import MsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration, MsSqlCredentials diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 11f59d5276..e919409311 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -37,6 +37,8 @@ def test_run_jaffle_package( pytest.skip( "dbt-athena requires database to be created and we don't do it in case of Jaffle" ) + if not destination_config.supports_dbt: + pytest.skip("dbt is not supported for this destination configuration") pipeline = destination_config.setup_pipeline("jaffle_jaffle", full_refresh=True) # get runner, pass the env from fixture dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv) @@ -55,9 +57,10 @@ def test_run_jaffle_package( assert all(r.status == "pass" for r in tests) # get and display dataframe with customers - customers = select_data(pipeline, "SELECT * FROM customers") + qual_name = pipeline.sql_client().make_qualified_table_name + customers = select_data(pipeline, f"SELECT * FROM {qual_name('customers')}") assert len(customers) == 100 - orders = select_data(pipeline, "SELECT * FROM orders") + orders = select_data(pipeline, f"SELECT * FROM {qual_name('orders')}") assert len(orders) == 99 diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 
c6db91efff..1dde56a6b1 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -264,6 +264,12 @@ def test_replace_table_clearing( # use staging tables for replace os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy + if destination_config.destination == "synapse" and replace_strategy == "staging-optimized": + # The "staging-optimized" replace strategy makes Synapse suspend the CTAS + # queries used to recreate the staging table, and hang, when the number + # of load workers is greater than 1. + os.environ["LOAD__WORKERS"] = "1" + pipeline = destination_config.setup_pipeline( "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True ) diff --git a/tests/load/synapse/__init__.py b/tests/load/synapse/__init__.py new file mode 100644 index 0000000000..34119d38cb --- /dev/null +++ b/tests/load/synapse/__init__.py @@ -0,0 +1,3 @@ +from tests.utils import skip_if_not_active + +skip_if_not_active("synapse") diff --git a/tests/load/synapse/test_synapse_configuration.py b/tests/load/synapse/test_synapse_configuration.py new file mode 100644 index 0000000000..4055cbab38 --- /dev/null +++ b/tests/load/synapse/test_synapse_configuration.py @@ -0,0 +1,46 @@ +import pytest + +from dlt.common.configuration import resolve_configuration +from dlt.common.exceptions import SystemConfigurationException + +from dlt.destinations.impl.synapse.configuration import ( + SynapseClientConfiguration, + SynapseCredentials, +) + + +def test_synapse_configuration() -> None: + # By default, unique indexes should not be created. + assert SynapseClientConfiguration().create_indexes is False + + +def test_parse_native_representation() -> None: + # Case: unsupported driver specified. + with pytest.raises(SystemConfigurationException): + resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+17+for+SQL+Server" + ) + ) + + +def test_to_odbc_dsn_longasmax() -> None: + # Case: LONGASMAX not specified in query (this is the expected scenario). + creds = resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result["LONGASMAX"] == "yes" + + # Case: LONGASMAX specified in query; specified value should be overridden. 
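+    # SynapseCredentials always sets LONGASMAX to "yes", because the long types (text, ntext, image) are not supported on Synapse.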
+ creds = resolve_configuration( + SynapseCredentials( + "synapse://test_user:test_pwd@test.sql.azuresynapse.net/test_db?DRIVER=ODBC+Driver+18+for+SQL+Server&LONGASMAX=no" + ) + ) + dsn = creds.to_odbc_dsn() + result = {k: v for k, v in (param.split("=") for param in dsn.split(";"))} + assert result["LONGASMAX"] == "yes" diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py new file mode 100644 index 0000000000..f58a7d5883 --- /dev/null +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -0,0 +1,130 @@ +import os +import pytest +import sqlfluff +from copy import deepcopy +from sqlfluff.api.simple import APIParsingError + +from dlt.common.utils import uniq_id +from dlt.common.schema import Schema, TColumnHint + +from dlt.destinations.impl.synapse.synapse import SynapseClient +from dlt.destinations.impl.synapse.configuration import ( + SynapseClientConfiguration, + SynapseCredentials, +) + +from tests.load.utils import TABLE_UPDATE +from dlt.destinations.impl.synapse.synapse import HINT_TO_SYNAPSE_ATTR + + +@pytest.fixture +def schema() -> Schema: + return Schema("event") + + +@pytest.fixture +def client(schema: Schema) -> SynapseClient: + # return client without opening connection + client = SynapseClient( + schema, + SynapseClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() + ), + ) + assert client.config.create_indexes is False + return client + + +@pytest.fixture +def client_with_indexes_enabled(schema: Schema) -> SynapseClient: + # return client without opening connection + client = SynapseClient( + schema, + SynapseClientConfiguration( + dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True + ), + ) + assert client.config.create_indexes is True + return client + + +def test_create_table(client: SynapseClient) -> None: + # non existing table + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert "event_test_table" in sql + assert '"col1" bigint NOT NULL' in sql + assert '"col2" float NOT NULL' in sql + assert '"col3" bit NOT NULL' in sql + assert '"col4" datetimeoffset NOT NULL' in sql + assert '"col5" nvarchar(max) NOT NULL' in sql + assert '"col6" decimal(38,9) NOT NULL' in sql + assert '"col7" varbinary(max) NOT NULL' in sql + assert '"col8" decimal(38,0)' in sql + assert '"col9" nvarchar(max) NOT NULL' in sql + assert '"col10" date NOT NULL' in sql + assert '"col11" time NOT NULL' in sql + assert '"col1_precision" smallint NOT NULL' in sql + assert '"col4_precision" datetimeoffset(3) NOT NULL' in sql + assert '"col5_precision" nvarchar(25)' in sql + assert '"col6_precision" decimal(6,2) NOT NULL' in sql + assert '"col7_precision" varbinary(19)' in sql + assert '"col11_precision" time(3) NOT NULL' in sql + assert "WITH ( HEAP )" in sql + + +def test_alter_table(client: SynapseClient) -> None: + # existing table has no columns + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] + sqlfluff.parse(sql, dialect="tsql") + canonical_name = client.sql_client.make_qualified_table_name("event_test_table") + assert sql.count(f"ALTER TABLE {canonical_name}\nADD") == 1 + assert "event_test_table" in sql + assert '"col1" bigint NOT NULL' in sql + assert '"col2" float NOT NULL' in sql + assert '"col3" bit NOT NULL' in sql + assert '"col4" datetimeoffset NOT NULL' in sql + assert '"col5" nvarchar(max) NOT NULL' in sql + assert '"col6" decimal(38,9) NOT NULL' in sql 
+ assert '"col7" varbinary(max) NOT NULL' in sql + assert '"col8" decimal(38,0)' in sql + assert '"col9" nvarchar(max) NOT NULL' in sql + assert '"col10" date NOT NULL' in sql + assert '"col11" time NOT NULL' in sql + assert '"col1_precision" smallint NOT NULL' in sql + assert '"col4_precision" datetimeoffset(3) NOT NULL' in sql + assert '"col5_precision" nvarchar(25)' in sql + assert '"col6_precision" decimal(6,2) NOT NULL' in sql + assert '"col7_precision" varbinary(19)' in sql + assert '"col11_precision" time(3) NOT NULL' in sql + assert "WITH ( HEAP )" not in sql + + +@pytest.mark.parametrize("hint", ["primary_key", "unique"]) +def test_create_table_with_column_hint( + client: SynapseClient, client_with_indexes_enabled: SynapseClient, hint: TColumnHint +) -> None: + attr = HINT_TO_SYNAPSE_ATTR[hint] + + # Case: table without hint. + sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert f" {attr} " not in sql + + # Case: table with hint, but client does not have indexes enabled. + mod_update = deepcopy(TABLE_UPDATE) + mod_update[0][hint] = True # type: ignore[typeddict-unknown-key] + sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] + sqlfluff.parse(sql, dialect="tsql") + assert f" {attr} " not in sql + + # Case: table with hint, client has indexes enabled. + sql = client_with_indexes_enabled._get_table_update_sql("event_test_table", mod_update, False)[ + 0 + ] + # We expect an error because "PRIMARY KEY NONCLUSTERED NOT ENFORCED" and + # "UNIQUE NOT ENFORCED" are invalid in the generic "tsql" dialect. + # They are however valid in the Synapse variant of the dialect. + with pytest.raises(APIParsingError): + sqlfluff.parse(sql, dialect="tsql") + assert f'"col1" bigint {attr} NOT NULL' in sql diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 153504bf4a..b8d2e31e3f 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -387,7 +387,8 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: "time", ): continue - if client.config.destination_type == "mssql" and c["data_type"] in ("complex"): + # mssql and synapse have no native data type for the complex type. 
+ if client.config.destination_type in ("mssql", "synapse") and c["data_type"] in ("complex"): continue assert c["data_type"] == expected_c["data_type"] diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 96f0db09bb..4bdf08e23c 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -38,7 +38,7 @@ def client(request) -> Iterator[SqlJobClientBase]: @pytest.mark.parametrize( "client", - destinations_configs(default_sql_configs=True, exclude=["mssql"]), + destinations_configs(default_sql_configs=True, exclude=["mssql", "synapse"]), indirect=True, ids=lambda x: x.name, ) @@ -263,9 +263,15 @@ def test_execute_df(client: SqlJobClientBase) -> None: client.update_stored_schema() table_name = prepare_temp_table(client) f_q_table_name = client.sql_client.make_qualified_table_name(table_name) - insert_query = ",".join([f"({idx})" for idx in range(0, total_records)]) - client.sql_client.execute_sql(f"INSERT INTO {f_q_table_name} VALUES {insert_query};") + if client.capabilities.insert_values_writer_type == "default": + insert_query = ",".join([f"({idx})" for idx in range(0, total_records)]) + sql_stmt = f"INSERT INTO {f_q_table_name} VALUES {insert_query};" + elif client.capabilities.insert_values_writer_type == "select_union": + insert_query = " UNION ALL ".join([f"SELECT {idx}" for idx in range(0, total_records)]) + sql_stmt = f"INSERT INTO {f_q_table_name} {insert_query};" + + client.sql_client.execute_sql(sql_stmt) with client.sql_client.execute_query( f"SELECT * FROM {f_q_table_name} ORDER BY col ASC" ) as curr: diff --git a/tests/load/utils.py b/tests/load/utils.py index 6811ca59a6..55445e0b95 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -163,7 +163,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination != "athena" + if destination not in ("athena", "synapse") ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet") @@ -190,6 +190,10 @@ def destinations_configs( extra_info="iceberg", ) ] + # dbt for Synapse has some complications and I couldn't get it to pass all tests. 
+ destination_configs += [ + DestinationTestConfiguration(destination="synapse", supports_dbt=False) + ] if default_vector_configs: # for now only weaviate @@ -465,7 +469,6 @@ def yield_client_with_storage( ) as client: client.initialize_storage() yield client - # print(dataset_name) client.sql_client.drop_dataset() if isinstance(client, WithStagingDataset): with client.with_staging_dataset(): diff --git a/tests/utils.py b/tests/utils.py index cf172f9733..211f87874d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,6 +45,7 @@ "motherduck", "mssql", "qdrant", + "synapse", } NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS From 05b05305ee46108261789ed25442aec518b1cca6 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 18 Jan 2024 16:44:34 +0100 Subject: [PATCH 02/84] make var type consistent --- dlt/common/data_writers/escape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index b56a0d8f19..20932fec6c 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -101,7 +101,7 @@ def escape_mssql_literal(v: Any) -> Any: # 8000 is the max value for n in VARBINARY(n) # https://learn.microsoft.com/en-us/sql/t-sql/data-types/binary-and-varbinary-transact-sql if len(v) <= 8000: - n = len(v) + n = str(len(v)) else: n = "MAX" return f"CONVERT(VARBINARY({n}), '{v.hex()}', 2)" From dc7619ad6f778b55cefaa09a3d3ef194ae5bc07a Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 18 Jan 2024 17:12:32 +0100 Subject: [PATCH 03/84] simplify client init logic --- dlt/destinations/impl/synapse/synapse.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 18d1fa81d4..0ad959f7ab 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -12,7 +12,7 @@ from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.destinations.impl.mssql.mssql import MsSqlTypeMapper, MsSqlClient, HINT_TO_MSSQL_ATTR +from dlt.destinations.impl.mssql.mssql import MsSqlClient from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient @@ -29,11 +29,11 @@ class SynapseClient(MsSqlClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: - sql_client = SynapseSqlClient(config.normalize_dataset_name(schema), config.credentials) - InsertValuesJobClient.__init__(self, schema, config, sql_client) + super().__init__(schema, config) self.config: SynapseClientConfiguration = config - self.sql_client = sql_client - self.type_mapper = MsSqlTypeMapper(self.capabilities) + self.sql_client = SynapseSqlClient( + config.normalize_dataset_name(schema), config.credentials + ) self.active_hints = deepcopy(HINT_TO_SYNAPSE_ATTR) if not self.config.create_indexes: From 702dd28032fd6a1e36214d34131373afbbed03ba Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sun, 21 Jan 2024 01:34:48 +0100 Subject: [PATCH 04/84] add support for table index type configuration --- dlt/common/data_writers/escape.py | 6 +- dlt/common/destination/reference.py | 4 +- dlt/common/schema/schema.py | 8 + 
dlt/common/schema/typing.py | 3 + dlt/common/schema/utils.py | 12 ++ dlt/destinations/impl/mssql/mssql.py | 2 + .../impl/synapse/configuration.py | 9 +- dlt/destinations/impl/synapse/factory.py | 7 + dlt/destinations/impl/synapse/synapse.py | 96 ++++++++++-- dlt/extract/decorators.py | 7 + dlt/extract/hints.py | 3 + tests/load/pipeline/test_table_indexing.py | 140 ++++++++++++++++++ .../synapse/test_synapse_table_builder.py | 13 +- 13 files changed, 292 insertions(+), 18 deletions(-) create mode 100644 tests/load/pipeline/test_table_indexing.py diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 20932fec6c..1de584de2e 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -98,9 +98,9 @@ def escape_mssql_literal(v: Any) -> Any: json.dumps(v), prefix="N'", escape_dict=MS_SQL_ESCAPE_DICT, escape_re=MS_SQL_ESCAPE_RE ) if isinstance(v, bytes): - # 8000 is the max value for n in VARBINARY(n) - # https://learn.microsoft.com/en-us/sql/t-sql/data-types/binary-and-varbinary-transact-sql - if len(v) <= 8000: + from dlt.destinations.impl.mssql.mssql import VARBINARY_MAX_N + + if len(v) <= VARBINARY_MAX_N: n = str(len(v)) else: n = "MAX" diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 1c28dffa8c..59f13b30b9 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -34,7 +34,7 @@ from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TWriteDisposition from dlt.common.schema.exceptions import InvalidDatasetName -from dlt.common.schema.utils import get_write_disposition, get_table_format +from dlt.common.schema.utils import get_write_disposition, get_table_format, get_table_index_type from dlt.common.configuration import configspec, with_config, resolve_configuration, known_sections from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config @@ -372,6 +372,8 @@ def get_load_table(self, table_name: str, prepare_for_staging: bool = False) -> table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) + if "table_index_type" not in table: + table["table_index_type"] = get_table_index_type(self.schema.tables, table_name) return table except KeyError: raise UnknownTableException(table_name) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index e95699b91e..ccfc038085 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -546,12 +546,20 @@ def data_tables(self, include_incomplete: bool = False) -> List[TTableSchema]: ) ] + def data_table_names(self) -> List[str]: + """Returns list of table table names. 
Excludes dlt table names.""" + return [t["name"] for t in self.data_tables()] + def dlt_tables(self) -> List[TTableSchema]: """Gets dlt tables""" return [ t for t in self._schema_tables.values() if t["name"].startswith(self._dlt_tables_prefix) ] + def dlt_table_names(self) -> List[str]: + """Returns list of dlt table names.""" + return [t["name"] for t in self.dlt_tables()] + def get_preferred_type(self, col_name: str) -> Optional[TDataType]: return next((m[1] for m in self._compiled_preferred_types if m[0].search(col_name)), None) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 9a27cbe4bb..351d666553 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -62,6 +62,8 @@ """Known hints of a column used to declare hint regexes.""" TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg"] +TTableIndexType = Literal["heap", "clustered_columnstore_index"] +"Table index type. Currently only used for Synapse destination." TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" ] @@ -165,6 +167,7 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] + table_index_type: Optional[TTableIndexType] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index dc243f50dd..5ea244148e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -32,6 +32,7 @@ TColumnSchema, TColumnProp, TTableFormat, + TTableIndexType, TColumnHint, TTypeDetectionFunc, TTypeDetections, @@ -618,6 +619,14 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) +def get_table_index_type(tables: TSchemaTables, table_name: str) -> TTableIndexType: + """Returns table index type of a table if present. 
If not, looks up into parent table.""" + return cast( + TTableIndexType, + get_inherited_table_hint(tables, table_name, "table_index_type", allow_none=True), + ) + + def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: """Checks if `table` schema contains column with type _typ""" return any(c.get("data_type") == _typ for c in table["columns"].values()) @@ -724,6 +733,7 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, + table_index_type: TTableIndexType = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -742,6 +752,8 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format + if table_index_type is not None: + table["table_index_type"] = table_index_type if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index e97389f185..b6af345e36 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -20,6 +20,8 @@ HINT_TO_MSSQL_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} +VARCHAR_MAX_N: int = 4000 +VARBINARY_MAX_N: int = 8000 class MsSqlTypeMapper(TypeMapper): diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index 0596cc2c46..966997b5a2 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,6 +1,7 @@ from typing import Final, Any, List, Dict, Optional, ClassVar from dlt.common.configuration import configspec +from dlt.common.schema.typing import TTableIndexType from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -30,9 +31,15 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): destination_type: Final[str] = "synapse" # type: ignore credentials: SynapseCredentials + # While Synapse uses CLUSTERED COLUMNSTORE INDEX tables by default, we use + # HEAP tables (no indexing) by default. HEAP is a more robust choice, because + # columnstore tables do not support varchar(max), nvarchar(max), and varbinary(max). + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index + default_table_index_type: Optional[TTableIndexType] = "heap" + # Determines if `primary_key` and `unique` column hints are applied. # Set to False by default because the PRIMARY KEY and UNIQUE constraints # are tricky in Synapse: they are NOT ENFORCED and can lead to innacurate # results if the user does not ensure all column values are unique. 
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints - create_indexes: bool = False + create_indexes: Optional[bool] = False diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index fa7facc0ca..6bdf2946b6 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -1,6 +1,7 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.schema.typing import TTableIndexType from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.configuration import ( @@ -27,6 +28,7 @@ def client_class(self) -> t.Type["SynapseClient"]: def __init__( self, credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, + default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, @@ -39,11 +41,16 @@ def __init__( Args: credentials: Credentials to connect to the Synapse dedicated pool. Can be an instance of `SynapseCredentials` or a connection string in the format `synapse://user:password@host:port/database` + default_table_index_type: Table index type that is used if no + table index type is specified on the resource. This setting only + applies to data tables, dlt system tables are not affected + (they always have "heap" as table index type). create_indexes: Should unique indexes be created, defaults to False **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, + default_table_index_type=default_table_index_type, create_indexes=create_indexes, destination_name=destination_name, environment=environment, diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 0ad959f7ab..e01e851d83 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -1,18 +1,24 @@ -from typing import ClassVar, Sequence, List, Dict, Any, Optional +from typing import ClassVar, Sequence, List, Dict, Any, Optional, cast from copy import deepcopy +from textwrap import dedent from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import SupportsStagingDestination, NewLoadJob from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint -from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.schema.typing import TTableSchemaColumns, TTableIndexType from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_client_impl import SqlJobClientBase -from dlt.destinations.impl.mssql.mssql import MsSqlClient +from dlt.destinations.impl.mssql.mssql import ( + MsSqlTypeMapper, + MsSqlClient, + VARCHAR_MAX_N, + VARBINARY_MAX_N, +) from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient @@ -23,9 +29,13 @@ "primary_key": "PRIMARY KEY NONCLUSTERED NOT ENFORCED", "unique": "UNIQUE NOT ENFORCED", } +TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR: Dict[TTableIndexType, str] = { + "heap": "HEAP", + "clustered_columnstore_index": "CLUSTERED COLUMNSTORE INDEX", +} -class SynapseClient(MsSqlClient, SupportsStagingDestination): +class SynapseClient(MsSqlClient): 
capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: @@ -43,20 +53,54 @@ def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: + table = self.get_load_table(table_name) + if table is None: + table_index_type = self.config.default_table_index_type + else: + table_index_type = table.get("table_index_type") + if table_index_type == "clustered_columnstore_index": + new_columns = self._get_columstore_valid_columns(new_columns) + _sql_result = SqlJobClientBase._get_table_update_sql( self, table_name, new_columns, generate_alter ) if not generate_alter: - # Append WITH clause to create heap table instead of default - # columnstore table. Heap tables are a more robust choice, because - # columnstore tables do not support varchar(max), nvarchar(max), - # and varbinary(max). - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index - sql_result = [_sql_result[0] + "\n WITH ( HEAP );"] + table_index_type_attr = TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR[table_index_type] + sql_result = [_sql_result[0] + f"\n WITH ( {table_index_type_attr} );"] else: sql_result = _sql_result return sql_result + def _get_columstore_valid_columns( + self, columns: Sequence[TColumnSchema] + ) -> Sequence[TColumnSchema]: + return [self._get_columstore_valid_column(c) for c in columns] + + def _get_columstore_valid_column(self, c: TColumnSchema) -> TColumnSchema: + """ + Returns TColumnSchema that maps to a Synapse data type that can participate in a columnstore index. + + varchar(max), nvarchar(max), and varbinary(max) are replaced with + varchar(n), nvarchar(n), and varbinary(n), respectively, where + n equals the user-specified precision, or the maximum allowed + value if the user did not specify a precision. + """ + varchar_source_types = [ + sct + for sct, dbt in MsSqlTypeMapper.sct_to_unbound_dbt.items() + if dbt in ("varchar(max)", "nvarchar(max)") + ] + varbinary_source_types = [ + sct + for sct, dbt in MsSqlTypeMapper.sct_to_unbound_dbt.items() + if dbt == "varbinary(max)" + ] + if c["data_type"] in varchar_source_types and "precision" not in c: + return {**c, **{"precision": VARCHAR_MAX_N}} + elif c["data_type"] in varbinary_source_types and "precision" not in c: + return {**c, **{"precision": VARBINARY_MAX_N}} + return c + def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: @@ -64,6 +108,38 @@ def _create_replace_followup_jobs( return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] return super()._create_replace_followup_jobs(table_chain) + def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + table = super().get_load_table(table_name, staging) + if table is None: + return None + if table_name in self.schema.dlt_table_names(): + # dlt tables should always be heap tables, regardless of the user + # configuration. Why? "For small lookup tables, less than 60 million rows, + # consider using HEAP or clustered index for faster query performance." 
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables + table["table_index_type"] = "heap" + if table["table_index_type"] is None: + table["table_index_type"] = self.config.default_table_index_type + return table + + def get_storage_table_index_type(self, table_name: str) -> TTableIndexType: + """Returns table index type of table in storage destination.""" + with self.sql_client as sql_client: + schema_name = sql_client.fully_qualified_dataset_name(escape=False) + sql = dedent(f""" + SELECT + CASE i.type_desc + WHEN 'HEAP' THEN 'heap' + WHEN 'CLUSTERED COLUMNSTORE' THEN 'clustered_columnstore_index' + END AS table_index_type + FROM sys.indexes i + INNER JOIN sys.tables t ON t.object_id = i.object_id + INNER JOIN sys.schemas s ON s.schema_id = t.schema_id + WHERE s.name = '{schema_name}' AND t.name = '{table_name}' + """) + table_index_type = sql_client.execute_sql(sql)[0][0] + return cast(TTableIndexType, table_index_type) + class SynapseStagingCopyJob(SqlStagingCopyJob): @classmethod diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index cf7426e683..573d3d3ad0 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -36,6 +36,7 @@ TAnySchemaColumns, TSchemaContract, TTableFormat, + TTableIndexType, ) from dlt.extract.utils import ( ensure_table_schema_columns_hint, @@ -256,6 +257,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -273,6 +275,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... @@ -290,6 +293,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: Literal[True] = True, @@ -308,6 +312,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... 
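For orientation, the overloads above and the hunk below wire the new `table_index_type` argument into `@dlt.resource` (a later commit in this series moves the hint behind a Synapse-specific adapter). A minimal sketch of how a resource could opt into a columnstore table at this stage, assuming the synapse destination from this patch series is installed and its credentials are configured separately (e.g. in secrets.toml); the resource body, pipeline name, and dataset name are illustrative only:

import dlt

@dlt.resource(table_index_type="clustered_columnstore_index")
def products():
    # toy rows; any iterable of dicts works
    yield from [{"id": 1, "name": "widget"}, {"id": 2, "name": "gadget"}]

pipeline = dlt.pipeline(
    pipeline_name="synapse_index_demo",  # hypothetical name
    destination="synapse",
    dataset_name="demo",  # hypothetical dataset
)
pipeline.run(products)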
@@ -324,6 +329,7 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: bool = False, @@ -403,6 +409,7 @@ def make_resource( merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, + table_index_type=table_index_type, ) return DltResource.from_data( _data, diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 437dbbc6bd..36354eb0da 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -12,6 +12,7 @@ TWriteDisposition, TAnySchemaColumns, TTableFormat, + TTableIndexType, TSchemaContract, ) from dlt.common.typing import TDataItem @@ -274,6 +275,7 @@ def new_table_template( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, + table_index_type: TTableHintTemplate[TTableIndexType] = None, ) -> TResourceHints: validator, schema_contract = create_item_validator(columns, schema_contract) clean_columns = columns @@ -289,6 +291,7 @@ def new_table_template( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore + table_index_type=table_index_type, # type: ignore ) if not table_name: new_template.pop("name") diff --git a/tests/load/pipeline/test_table_indexing.py b/tests/load/pipeline/test_table_indexing.py new file mode 100644 index 0000000000..5f62cddfee --- /dev/null +++ b/tests/load/pipeline/test_table_indexing.py @@ -0,0 +1,140 @@ +import os +import pytest +from typing import Iterator, List, Any, Union +from textwrap import dedent + +import dlt +from dlt.common.schema import TColumnSchema +from dlt.common.schema.typing import TTableIndexType, TSchemaTables +from dlt.common.schema.utils import get_table_index_type + +from dlt.destinations.sql_client import SqlClientBase + +from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES +from tests.load.pipeline.utils import ( + destinations_configs, + DestinationTestConfiguration, +) + + +TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID = [ + ("heap", None), + # For "clustered_columnstore_index" tables, different code paths exist + # when no column schema is specified versus when a column schema is + # specified, so we test both. + ("clustered_columnstore_index", None), + ("clustered_columnstore_index", TABLE_UPDATE), +] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["synapse"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize( + "table_index_type,column_schema", TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID +) +def test_default_table_index_type_configuration( + destination_config: DestinationTestConfiguration, + table_index_type: TTableIndexType, + column_schema: Union[List[TColumnSchema], None], +) -> None: + # Configure default_table_index_type. 
+ os.environ["DESTINATION__SYNAPSE__DEFAULT_TABLE_INDEX_TYPE"] = table_index_type + + @dlt.resource( + name="items_without_table_index_type_specified", + write_disposition="append", + columns=column_schema, + ) + def items_without_table_index_type_specified() -> Iterator[Any]: + yield TABLE_ROW_ALL_DATA_TYPES + + pipeline = destination_config.setup_pipeline( + f"test_default_table_index_type_{table_index_type}", + full_refresh=True, + ) + job_client = pipeline.destination_client() + # Assert configuration value gets properly propagated to job client configuration. + assert job_client.config.default_table_index_type == table_index_type # type: ignore[attr-defined] + + # Run the pipeline and create the tables. + pipeline.run(items_without_table_index_type_specified) + + # For all tables, assert the applied index type equals the expected index type. + # Child tables, if any, inherit the index type of their parent. + tables = pipeline.default_schema.tables + for table_name in tables: + applied_table_index_type = job_client.get_storage_table_index_type(table_name) # type: ignore[attr-defined] + if table_name in pipeline.default_schema.data_table_names(): + # For data tables, the applied table index type should be the default value. + assert applied_table_index_type == job_client.config.default_table_index_type # type: ignore[attr-defined] + elif table_name in pipeline.default_schema.dlt_table_names(): + # For dlt tables, the applied table index type should always be "heap". + assert applied_table_index_type == "heap" + + # Test overriding the default_table_index_type from a resource configuration. + if job_client.config.default_table_index_type == "heap": # type: ignore[attr-defined] + + @dlt.resource( + name="items_with_table_index_type_specified", + write_disposition="append", + table_index_type="clustered_columnstore_index", + columns=column_schema, + ) + def items_with_table_index_type_specified() -> Iterator[Any]: + yield TABLE_ROW_ALL_DATA_TYPES + + pipeline.run(items_with_table_index_type_specified) + applied_table_index_type = job_client.get_storage_table_index_type( # type: ignore[attr-defined] + "items_with_table_index_type_specified" + ) + # While the default is "heap", the applied index type should be "clustered_columnstore_index" + # because it was provided as argument to the resource. + assert applied_table_index_type == "clustered_columnstore_index" + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["synapse"]), + ids=lambda x: x.name, +) +@pytest.mark.parametrize( + "table_index_type,column_schema", TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID +) +def test_resource_table_index_type_configuration( + destination_config: DestinationTestConfiguration, + table_index_type: TTableIndexType, + column_schema: Union[List[TColumnSchema], None], +) -> None: + @dlt.resource( + name="items_with_table_index_type_specified", + write_disposition="append", + table_index_type=table_index_type, + columns=column_schema, + ) + def items_with_table_index_type_specified() -> Iterator[Any]: + yield TABLE_ROW_ALL_DATA_TYPES + + pipeline = destination_config.setup_pipeline( + f"test_table_index_type_{table_index_type}", + full_refresh=True, + ) + + # Run the pipeline and create the tables. + pipeline.run(items_with_table_index_type_specified) + + # For all tables, assert the applied index type equals the expected index type. + # Child tables, if any, inherit the index type of their parent. 
+ job_client = pipeline.destination_client() + tables = pipeline.default_schema.tables + for table_name in tables: + applied_table_index_type = job_client.get_storage_table_index_type(table_name) # type: ignore[attr-defined] + if table_name in pipeline.default_schema.data_table_names(): + # For data tables, the applied table index type should be the type + # configured in the resource. + assert applied_table_index_type == table_index_type + elif table_name in pipeline.default_schema.dlt_table_names(): + # For dlt tables, the applied table index type should always be "heap". + assert applied_table_index_type == "heap" diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index f58a7d5883..4719a8d003 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -14,7 +14,10 @@ ) from tests.load.utils import TABLE_UPDATE -from dlt.destinations.impl.synapse.synapse import HINT_TO_SYNAPSE_ATTR +from dlt.destinations.impl.synapse.synapse import ( + HINT_TO_SYNAPSE_ATTR, + TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR, +) @pytest.fixture @@ -70,7 +73,9 @@ def test_create_table(client: SynapseClient) -> None: assert '"col6_precision" decimal(6,2) NOT NULL' in sql assert '"col7_precision" varbinary(19)' in sql assert '"col11_precision" time(3) NOT NULL' in sql - assert "WITH ( HEAP )" in sql + table_index_type = client.config.default_table_index_type + table_index_type_attr = TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR[table_index_type] + assert f"WITH ( {table_index_type_attr} )" in sql def test_alter_table(client: SynapseClient) -> None: @@ -97,7 +102,9 @@ def test_alter_table(client: SynapseClient) -> None: assert '"col6_precision" decimal(6,2) NOT NULL' in sql assert '"col7_precision" varbinary(19)' in sql assert '"col11_precision" time(3) NOT NULL' in sql - assert "WITH ( HEAP )" not in sql + table_index_type = client.config.default_table_index_type + table_index_type_attr = TABLE_INDEX_TYPE_TO_SYNAPSE_ATTR[table_index_type] + assert f"WITH ( {table_index_type_attr} )" not in sql @pytest.mark.parametrize("hint", ["primary_key", "unique"]) From db73162fef46c98c73ea00daba686d53211c6f81 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 23 Jan 2024 14:59:10 +0100 Subject: [PATCH 05/84] add load concurrency handling and warning --- .../impl/synapse/configuration.py | 61 ++++++++++++++++++- dlt/destinations/impl/synapse/factory.py | 10 +-- dlt/pipeline/pipeline.py | 9 ++- .../load/pipeline/test_replace_disposition.py | 10 ++- 4 files changed, 76 insertions(+), 14 deletions(-) diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index 966997b5a2..b5eec82e9e 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,7 +1,8 @@ from typing import Final, Any, List, Dict, Optional, ClassVar from dlt.common.configuration import configspec -from dlt.common.schema.typing import TTableIndexType +from dlt.common.schema.typing import TTableIndexType, TWriteDisposition +from dlt.common import logger from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -36,10 +37,66 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): # columnstore tables do not support varchar(max), nvarchar(max), and varbinary(max). 
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index default_table_index_type: Optional[TTableIndexType] = "heap" + """ + Table index type that is used if no table index type is specified on the resource. + This only affects data tables, dlt system tables ignore this setting and + are always created as "heap" tables. + """ - # Determines if `primary_key` and `unique` column hints are applied. # Set to False by default because the PRIMARY KEY and UNIQUE constraints # are tricky in Synapse: they are NOT ENFORCED and can lead to innacurate # results if the user does not ensure all column values are unique. # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints create_indexes: Optional[bool] = False + """Whether `primary_key` and `unique` column hints are applied.""" + + # Concurrency is disabled by overriding the configured number of workers to 1 at runtime. + auto_disable_concurrency: Optional[bool] = True + """Whether concurrency is automatically disabled in cases where it might cause issues.""" + + __config_gen_annotations__: ClassVar[List[str]] = [ + "default_table_index_type", + "create_indexes", + "auto_disable_concurrency", + ] + + def get_load_workers(self, write_disposition: TWriteDisposition, workers: int) -> int: + if ( + write_disposition == "replace" + and self.replace_strategy == "staging-optimized" + and workers > 1 + ): + print("auto_disable_concurrency:", self.auto_disable_concurrency) + warning_msg_shared = ( + 'Data is being loaded into Synapse with write disposition "replace"' + ' and replace strategy "staging-optimized", while the number of' + f" load workers ({workers}) > 1. This configuration is problematic" + " in some cases, because Synapse does not always handle concurrency well" + " with the CTAS queries that are used behind the scenes to implement" + ' the "staging-optimized" strategy.' + ) + if self.auto_disable_concurrency: + logger.warning( + warning_msg_shared + + " The number of load workers will be automatically adjusted" + " and set to 1 to eliminate concurrency and prevent potential" + " issues. 
If you don't want this to happen, set the" + " DESTINATION__SYNAPSE__AUTO_DISABLE_CONCURRENCY environment" + ' variable to "false", or add the following to your config TOML:' + "\n\n[destination.synapse]\nauto_disable_concurrency = false\n" + ) + workers = 1 # adjust workers + else: + logger.warning( + warning_msg_shared + + " If you experience your pipeline gets stuck and doesn't finish," + " try reducing the number of load workers by exporting the LOAD__WORKERS" + " environment variable or by setting it in your config TOML:" + "\n\n[load]\nworkers = 1 # a value of 1 disables all concurrency," + " but perhaps a higher value also works\n\n" + "Alternatively, you can set the DESTINATION__SYNAPSE__AUTO_DISABLE_CONCURRENCY" + ' environment variable to "true", or add the following to your config TOML' + " to automatically disable concurrency where needed:" + "\n\n[destination.synapse]\nauto_disable_concurrency = true\n" + ) + return workers diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 6bdf2946b6..f77d8c11c2 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -30,6 +30,7 @@ def __init__( credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, + auto_disable_concurrency: t.Optional[bool] = True, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -41,17 +42,16 @@ def __init__( Args: credentials: Credentials to connect to the Synapse dedicated pool. Can be an instance of `SynapseCredentials` or a connection string in the format `synapse://user:password@host:port/database` - default_table_index_type: Table index type that is used if no - table index type is specified on the resource. This setting only - applies to data tables, dlt system tables are not affected - (they always have "heap" as table index type). - create_indexes: Should unique indexes be created, defaults to False + default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object. + create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object. + auto_disable_concurrency: Maps directly to the auto_disable_concurrency attribute of the SynapseClientConfiguration object. 
**kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, default_table_index_type=default_table_index_type, create_indexes=create_indexes, + auto_disable_concurrency=auto_disable_concurrency, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 73c8f076d1..44a2cbdfdb 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -45,7 +45,7 @@ TAnySchemaColumns, TSchemaContract, ) -from dlt.common.schema.utils import normalize_schema_name +from dlt.common.schema.utils import normalize_schema_name, get_write_disposition from dlt.common.storages.exceptions import LoadPackageNotFound from dlt.common.typing import DictStrStr, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner @@ -483,6 +483,13 @@ def load( # make sure that destination is set and client is importable and can be instantiated client, staging_client = self._get_destination_clients(self.default_schema) + # for synapse we might need to adjust the number of load workers + if self.destination.destination_name == "synapse": + write_disposition = get_write_disposition( + self.default_schema.tables, self.default_schema.data_table_names()[0] + ) + workers = client.config.get_load_workers(write_disposition, workers) # type: ignore[attr-defined] + # create default loader config and the loader load_config = LoaderConfiguration( workers=workers, diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 1dde56a6b1..65d3646f2d 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -264,16 +264,14 @@ def test_replace_table_clearing( # use staging tables for replace os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy - if destination_config.destination == "synapse" and replace_strategy == "staging-optimized": - # The "staging-optimized" replace strategy makes Synapse suspend the CTAS - # queries used to recreate the staging table, and hang, when the number - # of load workers is greater than 1. 
- os.environ["LOAD__WORKERS"] = "1" - pipeline = destination_config.setup_pipeline( "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True ) + if destination_config.destination == "synapse" and replace_strategy == "staging-optimized": + # this case requires load concurrency to be disabled (else the test gets stuck) + assert pipeline.destination_client().config.auto_disable_concurrency is True # type: ignore[attr-defined] + @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") def items_with_subitems(): data = { From 75be2ce54ccb486679ca1b177551c3097a2f3908 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 23 Jan 2024 20:26:06 +0100 Subject: [PATCH 06/84] rewrite naive code to prevent IndexError --- dlt/destinations/impl/synapse/configuration.py | 14 +++++++++----- dlt/pipeline/pipeline.py | 7 ++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index b5eec82e9e..119c55ad7a 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,8 +1,9 @@ from typing import Final, Any, List, Dict, Optional, ClassVar -from dlt.common.configuration import configspec -from dlt.common.schema.typing import TTableIndexType, TWriteDisposition from dlt.common import logger +from dlt.common.configuration import configspec +from dlt.common.schema.typing import TTableIndexType, TSchemaTables +from dlt.common.schema.utils import get_write_disposition from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -60,13 +61,16 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): "auto_disable_concurrency", ] - def get_load_workers(self, write_disposition: TWriteDisposition, workers: int) -> int: + def get_load_workers(self, tables: TSchemaTables, workers: int) -> int: + """Returns the adjusted number of load workers to prevent concurrency issues.""" + + write_dispositions = [get_write_disposition(tables, table_name) for table_name in tables] + n_replace_dispositions = len([d for d in write_dispositions if d == "replace"]) if ( - write_disposition == "replace" + n_replace_dispositions > 1 and self.replace_strategy == "staging-optimized" and workers > 1 ): - print("auto_disable_concurrency:", self.auto_disable_concurrency) warning_msg_shared = ( 'Data is being loaded into Synapse with write disposition "replace"' ' and replace strategy "staging-optimized", while the number of' diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 44a2cbdfdb..3a0a8f3931 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -45,7 +45,7 @@ TAnySchemaColumns, TSchemaContract, ) -from dlt.common.schema.utils import normalize_schema_name, get_write_disposition +from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound from dlt.common.typing import DictStrStr, TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner @@ -485,10 +485,7 @@ def load( # for synapse we might need to adjust the number of load workers if self.destination.destination_name == "synapse": - write_disposition = get_write_disposition( - self.default_schema.tables, self.default_schema.data_table_names()[0] - ) - workers = client.config.get_load_workers(write_disposition, workers) # type: ignore[attr-defined] + workers = 
client.config.get_load_workers(self.default_schema.tables, workers) # type: ignore[attr-defined] # create default loader config and the loader load_config = LoaderConfiguration( From 014543aa5adb7669adead1cbda39cb21268c9070 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 25 Jan 2024 19:56:21 +0100 Subject: [PATCH 07/84] add support for staged Parquet loading --- dlt/destinations/impl/synapse/__init__.py | 4 +- .../impl/synapse/configuration.py | 8 +- dlt/destinations/impl/synapse/factory.py | 5 +- dlt/destinations/impl/synapse/synapse.py | 115 +++++++++++++++++- poetry.lock | 4 +- pyproject.toml | 2 +- tests/load/pipeline/test_pipelines.py | 17 +-- tests/load/pipeline/test_stage_loading.py | 35 +++++- tests/load/utils.py | 17 +++ 9 files changed, 182 insertions(+), 25 deletions(-) diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index 175b011186..639d8a598f 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -9,8 +9,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps.preferred_loader_file_format = "insert_values" caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = None - caps.supported_staging_file_formats = [] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet"] caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index 119c55ad7a..34b227a2ac 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -48,17 +48,21 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): # are tricky in Synapse: they are NOT ENFORCED and can lead to innacurate # results if the user does not ensure all column values are unique. # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints - create_indexes: Optional[bool] = False + create_indexes: bool = False """Whether `primary_key` and `unique` column hints are applied.""" # Concurrency is disabled by overriding the configured number of workers to 1 at runtime. 
- auto_disable_concurrency: Optional[bool] = True + auto_disable_concurrency: bool = True """Whether concurrency is automatically disabled in cases where it might cause issues.""" + staging_use_msi: bool = False + """Whether the managed identity of the Synapse workspace is used to authorize access to the staging Storage Account.""" + __config_gen_annotations__: ClassVar[List[str]] = [ "default_table_index_type", "create_indexes", "auto_disable_concurrency", + "staging_use_msi", ] def get_load_workers(self, tables: TSchemaTables, workers: int) -> int: diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index f77d8c11c2..3d951f3d4a 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -30,7 +30,8 @@ def __init__( credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, - auto_disable_concurrency: t.Optional[bool] = True, + auto_disable_concurrency: bool = True, + staging_use_msi: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -45,6 +46,7 @@ def __init__( default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object. create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object. auto_disable_concurrency: Maps directly to the auto_disable_concurrency attribute of the SynapseClientConfiguration object. + auto_disable_concurrency: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object. **kwargs: Additional arguments passed to the destination config """ super().__init__( @@ -52,6 +54,7 @@ def __init__( default_table_index_type=default_table_index_type, create_indexes=create_indexes, auto_disable_concurrency=auto_disable_concurrency, + staging_use_msi=staging_use_msi, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index e01e851d83..c29c0df3f5 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -1,17 +1,28 @@ +import os from typing import ClassVar, Sequence, List, Dict, Any, Optional, cast from copy import deepcopy from textwrap import dedent +from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import SupportsStagingDestination, NewLoadJob +from dlt.common.destination.reference import ( + SupportsStagingDestination, + NewLoadJob, + CredentialsConfiguration, +) from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint +from dlt.common.schema.utils import table_schema_has_type from dlt.common.schema.typing import TTableSchemaColumns, TTableIndexType +from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults + +from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.destinations.job_client_impl import SqlJobClientBase, LoadJob, CopyRemoteFileLoadJob +from dlt.destinations.exceptions import LoadJobTerminalException from 
dlt.destinations.impl.mssql.mssql import ( MsSqlTypeMapper, @@ -35,7 +46,7 @@ } -class SynapseClient(MsSqlClient): +class SynapseClient(MsSqlClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: @@ -140,6 +151,21 @@ def get_storage_table_index_type(self, table_name: str) -> TTableIndexType: table_index_type = sql_client.execute_sql(sql)[0][0] return cast(TTableIndexType, table_index_type) + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + job = super().start_file_load(table, file_path, load_id) + if not job: + assert NewReferenceJob.is_reference_job( + file_path + ), "Synapse must use staging to load files" + job = SynapseCopyFileLoadJob( + table, + file_path, + self.sql_client, + cast(AzureCredentialsWithoutDefaults, self.config.staging_config.credentials), + self.config.staging_use_msi, + ) + return job + class SynapseStagingCopyJob(SqlStagingCopyJob): @classmethod @@ -173,3 +199,86 @@ def generate_sql( ) return sql + + +class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob): + def __init__( + self, + table: TTableSchema, + file_path: str, + sql_client: SqlClientBase[Any], + staging_credentials: Optional[AzureCredentialsWithoutDefaults] = None, + staging_use_msi: bool = False, + ) -> None: + self.staging_use_msi = staging_use_msi + super().__init__(table, file_path, sql_client, staging_credentials) + + def execute(self, table: TTableSchema, bucket_path: str) -> None: + # get format + ext = os.path.splitext(bucket_path)[1][1:] + if ext == "parquet": + if table_schema_has_type(table, "time"): + # Synapse interprets Parquet TIME columns as bigint, resulting in + # an incompatibility error. + raise LoadJobTerminalException( + self.file_name(), + "Synapse cannot load TIME columns from Parquet files. Switch to direct INSERT" + " file format or convert `datetime.time` objects in your data to `str` or" + " `datetime.datetime`", + ) + file_type = "PARQUET" + + # dlt-generated DDL statements will still create the table, but + # enabling AUTO_CREATE_TABLE prevents a MalformedInputException. + auto_create_table = "ON" + else: + raise ValueError(f"Unsupported file type {ext} for Synapse.") + + staging_credentials = self._staging_credentials + assert staging_credentials is not None + assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + azure_storage_account_name = staging_credentials.azure_storage_account_name + https_path = self._get_https_path(bucket_path, azure_storage_account_name) + table_name = table["name"] + + if self.staging_use_msi: + credential = "IDENTITY = 'Managed Identity'" + else: + sas_token = staging_credentials.azure_storage_sas_token + credential = f"IDENTITY = 'Shared Access Signature', SECRET = '{sas_token}'" + + # Copy data from staging file into Synapse table. + with self._sql_client.begin_transaction(): + dataset_name = self._sql_client.dataset_name + sql = dedent(f""" + COPY INTO [{dataset_name}].[{table_name}] + FROM '{https_path}' + WITH ( + FILE_TYPE = '{file_type}', + CREDENTIAL = ({credential}), + AUTO_CREATE_TABLE = '{auto_create_table}' + ) + """) + self._sql_client.execute_sql(sql) + + def exception(self) -> str: + # this part of code should be never reached + raise NotImplementedError() + + def _get_https_path(self, bucket_path: str, storage_account_name: str) -> str: + """ + Converts a path in the form of az:/// to + https://.blob.core.windows.net// + as required by Synapse. 
+ """ + bucket_url = urlparse(bucket_path) + # "blob" endpoint has better performance than "dfs" endoint + # https://learn.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql?view=azure-sqldw-latest#external-locations + endpoint = "blob" + _path = "/" + bucket_url.netloc + bucket_url.path + https_url = bucket_url._replace( + scheme="https", + netloc=f"{storage_account_name}.{endpoint}.core.windows.net", + path=_path, + ) + return urlunparse(https_url) diff --git a/poetry.lock b/poetry.lock index 4d079fc44d..400bcb61e2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8466,10 +8466,10 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] -synapse = ["pyodbc"] +synapse = ["adlfs", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "26c595a857f17a5cbdb348f165c267d8910412325be4e522d0e91224c7fec588" +content-hash = "75a5f533e9456898ad0157b699d76d9c5a1abf8f4cd04ed7be2235ae3198e16c" diff --git a/pyproject.toml b/pyproject.toml index d9d5858674..f6ae77b593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ cli = ["pipdeptree", "cron-descriptor"] athena = ["pyathena", "pyarrow", "s3fs", "botocore"] weaviate = ["weaviate-client"] mssql = ["pyodbc"] -synapse = ["pyodbc"] +synapse = ["pyodbc", "adlfs"] qdrant = ["qdrant-client"] [tool.poetry.scripts] diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index d170fd553b..304f1a0d2f 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -788,7 +788,7 @@ def other_data(): column_schemas["col11_precision"]["precision"] = 0 # drop TIME from databases not supporting it via parquet - if destination_config.destination in ["redshift", "athena"]: + if destination_config.destination in ["redshift", "athena", "synapse"]: data_types.pop("col11") data_types.pop("col11_null") data_types.pop("col11_precision") @@ -827,15 +827,16 @@ def some_source(): assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs with pipeline.sql_client() as sql_client: + qual_name = sql_client.make_qualified_table_name assert [ - row[0] for row in sql_client.execute_sql("SELECT * FROM other_data ORDER BY 1") + row[0] + for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('other_data')} ORDER BY 1") ] == [1, 2, 3, 4, 5] - assert [row[0] for row in sql_client.execute_sql("SELECT * FROM some_data ORDER BY 1")] == [ - 1, - 2, - 3, - ] - db_rows = sql_client.execute_sql("SELECT * FROM data_types") + assert [ + row[0] + for row in sql_client.execute_sql(f"SELECT * FROM {qual_name('some_data')} ORDER BY 1") + ] == [1, 2, 3] + db_rows = sql_client.execute_sql(f"SELECT * FROM {qual_name('data_types')}") assert len(db_rows) == 10 db_row = list(db_rows[0]) # "snowflake" and "bigquery" do not parse JSON form parquet string so double parse diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index de4a7f4c3b..ca27cf4b05 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -94,7 +94,13 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check item of first row in db with pipeline.sql_client() as sql_client: - rows = sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1") + if destination_config.destination in ["mssql", "synapse"]: + qual_name = 
sql_client.make_qualified_table_name + rows = sql_client.execute_sql( + f"SELECT TOP 1 url FROM {qual_name('issues')} WHERE id = 388089021" + ) + else: + rows = sql_client.execute_sql("SELECT url FROM issues WHERE id = 388089021 LIMIT 1") assert rows[0][0] == "https://api.github.com/repos/duckdb/duckdb/issues/71" if destination_config.supports_merge: @@ -109,10 +115,23 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: # check changes where merged in with pipeline.sql_client() as sql_client: - rows = sql_client.execute_sql("SELECT number FROM issues WHERE id = 1232152492 LIMIT 1") - assert rows[0][0] == 105 - rows = sql_client.execute_sql("SELECT number FROM issues WHERE id = 1142699354 LIMIT 1") - assert rows[0][0] == 300 + if destination_config.destination in ["mssql", "synapse"]: + qual_name = sql_client.make_qualified_table_name + rows_1 = sql_client.execute_sql( + f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1232152492" + ) + rows_2 = sql_client.execute_sql( + f"SELECT TOP 1 number FROM {qual_name('issues')} WHERE id = 1142699354" + ) + else: + rows_1 = sql_client.execute_sql( + "SELECT number FROM issues WHERE id = 1232152492 LIMIT 1" + ) + rows_2 = sql_client.execute_sql( + "SELECT number FROM issues WHERE id = 1142699354 LIMIT 1" + ) + assert rows_1[0][0] == 105 + assert rows_2[0][0] == 300 # test append info = pipeline.run( @@ -161,6 +180,9 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") + if destination_config.destination == "synapse" and destination_config.file_format == "parquet": + # TIME columns are not supported for staged parquet loads into Synapse + exclude_types.append("time") if destination_config.destination == "redshift" and destination_config.file_format in ( "parquet", "jsonl", @@ -199,7 +221,8 @@ def my_source(): assert_load_info(info) with pipeline.sql_client() as sql_client: - db_rows = sql_client.execute_sql("SELECT * FROM data_types") + qual_name = sql_client.make_qualified_table_name + db_rows = sql_client.execute_sql(f"SELECT * FROM {qual_name('data_types')}") assert len(db_rows) == 10 db_row = list(db_rows[0]) # parquet is not really good at inserting json, best we get are strings in JSON columns diff --git a/tests/load/utils.py b/tests/load/utils.py index 55445e0b95..207e32209f 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -95,6 +95,7 @@ class DestinationTestConfiguration: bucket_url: Optional[str] = None stage_name: Optional[str] = None staging_iam_role: Optional[str] = None + staging_use_msi: bool = False extra_info: Optional[str] = None supports_merge: bool = True # TODO: take it from client base class force_iceberg: bool = False @@ -118,6 +119,7 @@ def setup(self) -> None: os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = self.bucket_url or "" os.environ["DESTINATION__STAGE_NAME"] = self.stage_name or "" os.environ["DESTINATION__STAGING_IAM_ROLE"] = self.staging_iam_role or "" + os.environ["DESTINATION__STAGING_USE_MSI"] = str(self.staging_use_msi) or "" os.environ["DESTINATION__FORCE_ICEBERG"] = str(self.force_iceberg) or "" """For the filesystem destinations we disable compression to make analyzing the result easier""" @@ -254,6 +256,21 @@ def destinations_configs( bucket_url=AZ_BUCKET, extra_info="az-authorization", ), + DestinationTestConfiguration( + destination="synapse", + staging="filesystem", + 
file_format="parquet", + bucket_url=AZ_BUCKET, + extra_info="az-authorization", + ), + DestinationTestConfiguration( + destination="synapse", + staging="filesystem", + file_format="parquet", + bucket_url=AZ_BUCKET, + staging_use_msi=True, + extra_info="az-managed-identity", + ), ] if all_staging_configs: From 7868ca6bfd54ff691e8e84384a65c7b9c55a00f4 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 26 Jan 2024 21:36:42 +0100 Subject: [PATCH 08/84] made table index type logic Synapse specific through destination adapter --- dlt/common/destination/reference.py | 4 +- dlt/common/schema/typing.py | 3 -- dlt/common/schema/utils.py | 12 ----- dlt/destinations/adapters.py | 3 +- .../impl/qdrant/qdrant_adapter.py | 11 +--- dlt/destinations/impl/synapse/__init__.py | 2 + .../impl/synapse/configuration.py | 4 +- dlt/destinations/impl/synapse/factory.py | 4 +- dlt/destinations/impl/synapse/synapse.py | 23 ++++++--- .../impl/synapse/synapse_adapter.py | 50 +++++++++++++++++++ .../impl/weaviate/weaviate_adapter.py | 11 +--- dlt/destinations/utils.py | 16 ++++++ dlt/extract/decorators.py | 7 --- dlt/extract/hints.py | 3 -- .../test_table_indexing.py | 46 ++++++++--------- 15 files changed, 117 insertions(+), 82 deletions(-) create mode 100644 dlt/destinations/impl/synapse/synapse_adapter.py create mode 100644 dlt/destinations/utils.py rename tests/load/{pipeline => synapse}/test_table_indexing.py (81%) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 59f13b30b9..1c28dffa8c 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -34,7 +34,7 @@ from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TWriteDisposition from dlt.common.schema.exceptions import InvalidDatasetName -from dlt.common.schema.utils import get_write_disposition, get_table_format, get_table_index_type +from dlt.common.schema.utils import get_write_disposition, get_table_format from dlt.common.configuration import configspec, with_config, resolve_configuration, known_sections from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config @@ -372,8 +372,6 @@ def get_load_table(self, table_name: str, prepare_for_staging: bool = False) -> table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) if "table_format" not in table: table["table_format"] = get_table_format(self.schema.tables, table_name) - if "table_index_type" not in table: - table["table_index_type"] = get_table_index_type(self.schema.tables, table_name) return table except KeyError: raise UnknownTableException(table_name) diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 351d666553..9a27cbe4bb 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -62,8 +62,6 @@ """Known hints of a column used to declare hint regexes.""" TWriteDisposition = Literal["skip", "append", "replace", "merge"] TTableFormat = Literal["iceberg"] -TTableIndexType = Literal["heap", "clustered_columnstore_index"] -"Table index type. Currently only used for Synapse destination." 
TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" ] @@ -167,7 +165,6 @@ class TTableSchema(TypedDict, total=False): columns: TTableSchemaColumns resource: Optional[str] table_format: Optional[TTableFormat] - table_index_type: Optional[TTableIndexType] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 5ea244148e..dc243f50dd 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -32,7 +32,6 @@ TColumnSchema, TColumnProp, TTableFormat, - TTableIndexType, TColumnHint, TTypeDetectionFunc, TTypeDetections, @@ -619,14 +618,6 @@ def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: ) -def get_table_index_type(tables: TSchemaTables, table_name: str) -> TTableIndexType: - """Returns table index type of a table if present. If not, looks up into parent table.""" - return cast( - TTableIndexType, - get_inherited_table_hint(tables, table_name, "table_index_type", allow_none=True), - ) - - def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: """Checks if `table` schema contains column with type _typ""" return any(c.get("data_type") == _typ for c in table["columns"].values()) @@ -733,7 +724,6 @@ def new_table( resource: str = None, schema_contract: TSchemaContract = None, table_format: TTableFormat = None, - table_index_type: TTableIndexType = None, ) -> TTableSchema: table: TTableSchema = { "name": table_name, @@ -752,8 +742,6 @@ def new_table( table["schema_contract"] = schema_contract if table_format: table["table_format"] = table_format - if table_index_type is not None: - table["table_index_type"] = table_index_type if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/dlt/destinations/adapters.py b/dlt/destinations/adapters.py index b8f12599dc..22c98d4f5a 100644 --- a/dlt/destinations/adapters.py +++ b/dlt/destinations/adapters.py @@ -2,5 +2,6 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from dlt.destinations.impl.qdrant import qdrant_adapter +from dlt.destinations.impl.synapse import synapse_adapter -__all__ = ["weaviate_adapter", "qdrant_adapter"] +__all__ = ["weaviate_adapter", "qdrant_adapter", "synapse_adapter"] diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index 243cbd6c5b..215d87a920 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -2,6 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource +from dlt.destinations.utils import ensure_resource VECTORIZE_HINT = "x-qdrant-embed" @@ -31,15 +32,7 @@ def qdrant_adapter( >>> qdrant_adapter(data, embed="description") [DltResource with hints applied] """ - # wrap `data` in a resource if not an instance already - resource: DltResource - if not isinstance(data, DltResource): - resource_name: str = None - if not hasattr(data, "__name__"): - resource_name = "content" - resource = make_resource(data, name=resource_name) - else: - resource = data + resource = ensure_resource(data) column_hints: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/synapse/__init__.py b/dlt/destinations/impl/synapse/__init__.py index 639d8a598f..53dbabc090 100644 --- a/dlt/destinations/impl/synapse/__init__.py +++ b/dlt/destinations/impl/synapse/__init__.py @@ -3,6 +3,8 @@ from dlt.common.arithmetics import 
DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.wei import EVM_DECIMAL_PRECISION +from dlt.destinations.impl.synapse.synapse_adapter import synapse_adapter + def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index 34b227a2ac..cc0e40114b 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -2,7 +2,7 @@ from dlt.common import logger from dlt.common.configuration import configspec -from dlt.common.schema.typing import TTableIndexType, TSchemaTables +from dlt.common.schema.typing import TSchemaTables from dlt.common.schema.utils import get_write_disposition from dlt.destinations.impl.mssql.configuration import ( @@ -11,6 +11,8 @@ ) from dlt.destinations.impl.mssql.configuration import MsSqlCredentials +from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType + @configspec class SynapseCredentials(MsSqlCredentials): diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 3d951f3d4a..0ac58001ca 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -1,13 +1,13 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.common.schema.typing import TTableIndexType -from dlt.destinations.impl.synapse import capabilities +from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.configuration import ( SynapseCredentials, SynapseClientConfiguration, ) +from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType if t.TYPE_CHECKING: from dlt.destinations.impl.synapse.synapse import SynapseClient diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index c29c0df3f5..d34fef1ab4 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -12,8 +12,8 @@ ) from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint -from dlt.common.schema.utils import table_schema_has_type -from dlt.common.schema.typing import TTableSchemaColumns, TTableIndexType +from dlt.common.schema.utils import table_schema_has_type, get_inherited_table_hint +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration.specs import AzureCredentialsWithoutDefaults @@ -34,6 +34,10 @@ from dlt.destinations.impl.synapse import capabilities from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient from dlt.destinations.impl.synapse.configuration import SynapseClientConfiguration +from dlt.destinations.impl.synapse.synapse_adapter import ( + TABLE_INDEX_TYPE_HINT, + TTableIndexType, +) HINT_TO_SYNAPSE_ATTR: Dict[TColumnHint, str] = { @@ -68,7 +72,7 @@ def _get_table_update_sql( if table is None: table_index_type = self.config.default_table_index_type else: - table_index_type = table.get("table_index_type") + table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) if table_index_type == "clustered_columnstore_index": new_columns = self._get_columstore_valid_columns(new_columns) @@ -128,9 +132,16 @@ def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema # configuration. Why? "For small lookup tables, less than 60 million rows, # consider using HEAP or clustered index for faster query performance." 
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables - table["table_index_type"] = "heap" - if table["table_index_type"] is None: - table["table_index_type"] = self.config.default_table_index_type + table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] + elif table_name in self.schema.data_table_names(): + if TABLE_INDEX_TYPE_HINT not in table: + # If present in parent table, fetch hint from there. + table[TABLE_INDEX_TYPE_HINT] = get_inherited_table_hint( # type: ignore[typeddict-unknown-key] + self.schema.tables, table_name, TABLE_INDEX_TYPE_HINT, allow_none=True + ) + if table[TABLE_INDEX_TYPE_HINT] is None: # type: ignore[typeddict-item] + # Hint still not defined, fall back to default. + table[TABLE_INDEX_TYPE_HINT] = self.config.default_table_index_type # type: ignore[typeddict-unknown-key] return table def get_storage_table_index_type(self, table_name: str) -> TTableIndexType: diff --git a/dlt/destinations/impl/synapse/synapse_adapter.py b/dlt/destinations/impl/synapse/synapse_adapter.py new file mode 100644 index 0000000000..f135dd967a --- /dev/null +++ b/dlt/destinations/impl/synapse/synapse_adapter.py @@ -0,0 +1,50 @@ +from typing import Any, Literal, Set, get_args, Final + +from dlt.extract import DltResource, resource as make_resource +from dlt.extract.typing import TTableHintTemplate +from dlt.extract.hints import TResourceHints +from dlt.destinations.utils import ensure_resource + +TTableIndexType = Literal["heap", "clustered_columnstore_index"] +""" +Table [index type](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index) used when creating the Synapse table. +This regards indexes specified at the table level, not the column level. +""" +TABLE_INDEX_TYPES: Set[TTableIndexType] = set(get_args(TTableIndexType)) + +TABLE_INDEX_TYPE_HINT: Literal["x-table-index-type"] = "x-table-index-type" + + +def synapse_adapter(data: Any, table_index_type: TTableIndexType = None) -> DltResource: + """Prepares data for the Synapse destination by specifying which table index + type should be used. + + Args: + data (Any): The data to be transformed. It can be raw data or an instance + of DltResource. If raw data, the function wraps it into a DltResource + object. + table_index_type (TTableIndexType, optional): The table index type used when creating + the Synapse table. + + Returns: + DltResource: A resource with applied Synapse-specific hints. + + Raises: + ValueError: If input for `table_index_type` is invalid. + + Examples: + >>> data = [{"name": "Anush", "description": "Integrations Hacker"}] + >>> synapse_adapter(data, table_index_type="clustered_columnstore_index") + [DltResource with hints applied] + """ + resource = ensure_resource(data) + + if table_index_type is not None: + if table_index_type not in TABLE_INDEX_TYPES: + allowed_types = ", ".join(TABLE_INDEX_TYPES) + raise ValueError( + f"Table index type {table_index_type} is invalid. Allowed table index" + f" types are: {allowed_types}." 
+ ) + resource._hints[TABLE_INDEX_TYPE_HINT] = table_index_type # type: ignore[typeddict-unknown-key] + return resource diff --git a/dlt/destinations/impl/weaviate/weaviate_adapter.py b/dlt/destinations/impl/weaviate/weaviate_adapter.py index 2d5161d9e9..a290ac65b4 100644 --- a/dlt/destinations/impl/weaviate/weaviate_adapter.py +++ b/dlt/destinations/impl/weaviate/weaviate_adapter.py @@ -2,6 +2,7 @@ from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns from dlt.extract import DltResource, resource as make_resource +from dlt.destinations.utils import ensure_resource TTokenizationTMethod = Literal["word", "lowercase", "whitespace", "field"] TOKENIZATION_METHODS: Set[TTokenizationTMethod] = set(get_args(TTokenizationTMethod)) @@ -53,15 +54,7 @@ def weaviate_adapter( >>> weaviate_adapter(data, vectorize="description", tokenization={"description": "word"}) [DltResource with hints applied] """ - # wrap `data` in a resource if not an instance already - resource: DltResource - if not isinstance(data, DltResource): - resource_name: str = None - if not hasattr(data, "__name__"): - resource_name = "content" - resource = make_resource(data, name=resource_name) - else: - resource = data + resource = ensure_resource(data) column_hints: TTableSchemaColumns = {} if vectorize: diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py new file mode 100644 index 0000000000..d4b945a840 --- /dev/null +++ b/dlt/destinations/utils.py @@ -0,0 +1,16 @@ +from typing import Any + +from dlt.extract import DltResource, resource as make_resource + + +def ensure_resource(data: Any) -> DltResource: + """Wraps `data` in a DltResource if it's not a DltResource already.""" + resource: DltResource + if not isinstance(data, DltResource): + resource_name: str = None + if not hasattr(data, "__name__"): + resource_name = "content" + resource = make_resource(data, name=resource_name) + else: + resource = data + return resource diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 573d3d3ad0..cf7426e683 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -36,7 +36,6 @@ TAnySchemaColumns, TSchemaContract, TTableFormat, - TTableIndexType, ) from dlt.extract.utils import ( ensure_table_schema_columns_hint, @@ -257,7 +256,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -275,7 +273,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: ... 
@@ -293,7 +290,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: Literal[True] = True, @@ -312,7 +308,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, ) -> DltResource: ... @@ -329,7 +324,6 @@ def resource( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: bool = False, @@ -409,7 +403,6 @@ def make_resource( merge_key=merge_key, schema_contract=schema_contract, table_format=table_format, - table_index_type=table_index_type, ) return DltResource.from_data( _data, diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 36354eb0da..437dbbc6bd 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -12,7 +12,6 @@ TWriteDisposition, TAnySchemaColumns, TTableFormat, - TTableIndexType, TSchemaContract, ) from dlt.common.typing import TDataItem @@ -275,7 +274,6 @@ def new_table_template( merge_key: TTableHintTemplate[TColumnNames] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, table_format: TTableHintTemplate[TTableFormat] = None, - table_index_type: TTableHintTemplate[TTableIndexType] = None, ) -> TResourceHints: validator, schema_contract = create_item_validator(columns, schema_contract) clean_columns = columns @@ -291,7 +289,6 @@ def new_table_template( columns=clean_columns, # type: ignore schema_contract=schema_contract, # type: ignore table_format=table_format, # type: ignore - table_index_type=table_index_type, # type: ignore ) if not table_name: new_template.pop("name") diff --git a/tests/load/pipeline/test_table_indexing.py b/tests/load/synapse/test_table_indexing.py similarity index 81% rename from tests/load/pipeline/test_table_indexing.py rename to tests/load/synapse/test_table_indexing.py index 5f62cddfee..097bde09f9 100644 --- a/tests/load/pipeline/test_table_indexing.py +++ b/tests/load/synapse/test_table_indexing.py @@ -5,16 +5,13 @@ import dlt from dlt.common.schema import TColumnSchema -from dlt.common.schema.typing import TTableIndexType, TSchemaTables -from dlt.common.schema.utils import get_table_index_type from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.impl.synapse import synapse_adapter +from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType + from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES -from tests.load.pipeline.utils import ( - destinations_configs, - DestinationTestConfiguration, -) TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID = [ @@ -27,16 +24,10 @@ ] -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["synapse"]), - ids=lambda x: x.name, -) @pytest.mark.parametrize( "table_index_type,column_schema", TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID ) def test_default_table_index_type_configuration( - destination_config: 
DestinationTestConfiguration, table_index_type: TTableIndexType, column_schema: Union[List[TColumnSchema], None], ) -> None: @@ -51,10 +42,13 @@ def test_default_table_index_type_configuration( def items_without_table_index_type_specified() -> Iterator[Any]: yield TABLE_ROW_ALL_DATA_TYPES - pipeline = destination_config.setup_pipeline( - f"test_default_table_index_type_{table_index_type}", + pipeline = dlt.pipeline( + pipeline_name=f"test_default_table_index_type_{table_index_type}", + destination="synapse", + dataset_name=f"test_default_table_index_type_{table_index_type}", full_refresh=True, ) + job_client = pipeline.destination_client() # Assert configuration value gets properly propagated to job client configuration. assert job_client.config.default_table_index_type == table_index_type # type: ignore[attr-defined] @@ -80,13 +74,14 @@ def items_without_table_index_type_specified() -> Iterator[Any]: @dlt.resource( name="items_with_table_index_type_specified", write_disposition="append", - table_index_type="clustered_columnstore_index", columns=column_schema, ) def items_with_table_index_type_specified() -> Iterator[Any]: yield TABLE_ROW_ALL_DATA_TYPES - pipeline.run(items_with_table_index_type_specified) + pipeline.run( + synapse_adapter(items_with_table_index_type_specified, "clustered_columnstore_index") + ) applied_table_index_type = job_client.get_storage_table_index_type( # type: ignore[attr-defined] "items_with_table_index_type_specified" ) @@ -95,35 +90,34 @@ def items_with_table_index_type_specified() -> Iterator[Any]: assert applied_table_index_type == "clustered_columnstore_index" -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["synapse"]), - ids=lambda x: x.name, -) @pytest.mark.parametrize( "table_index_type,column_schema", TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID ) def test_resource_table_index_type_configuration( - destination_config: DestinationTestConfiguration, table_index_type: TTableIndexType, column_schema: Union[List[TColumnSchema], None], ) -> None: @dlt.resource( name="items_with_table_index_type_specified", write_disposition="append", - table_index_type=table_index_type, columns=column_schema, ) def items_with_table_index_type_specified() -> Iterator[Any]: yield TABLE_ROW_ALL_DATA_TYPES - pipeline = destination_config.setup_pipeline( - f"test_table_index_type_{table_index_type}", + pipeline = dlt.pipeline( + pipeline_name=f"test_table_index_type_{table_index_type}", + destination="synapse", + dataset_name=f"test_table_index_type_{table_index_type}", full_refresh=True, ) + # An invalid value for `table_index_type` should raise a ValueError. + with pytest.raises(ValueError): + pipeline.run(synapse_adapter(items_with_table_index_type_specified, "foo")) # type: ignore[arg-type] + # Run the pipeline and create the tables. - pipeline.run(items_with_table_index_type_specified) + pipeline.run(synapse_adapter(items_with_table_index_type_specified, table_index_type)) # For all tables, assert the applied index type equals the expected index type. # Child tables, if any, inherit the index type of their parent. 
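For reference, a minimal sketch of how the `synapse_adapter` introduced above is meant to be used from user code (assuming the `synapse` extra is installed and destination credentials are configured; the pipeline and dataset names below are illustrative, not part of the patch):

    import dlt
    from dlt.destinations.impl.synapse import synapse_adapter

    @dlt.resource(name="items", write_disposition="append")
    def items():
        # Any iterable of dicts works as sample data here.
        yield [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}]

    pipeline = dlt.pipeline(
        pipeline_name="synapse_index_demo",
        destination="synapse",
        dataset_name="synapse_index_demo",
    )
    # Override the default table index type; valid values are
    # "heap" and "clustered_columnstore_index".
    pipeline.run(synapse_adapter(items, "clustered_columnstore_index"))
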
From b4cdd36e41af7e13849e255133a2654dde79ac7e Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Fri, 26 Jan 2024 22:06:11 +0100 Subject: [PATCH 09/84] moved test function into tests folder and renamed test file --- dlt/destinations/impl/synapse/synapse.py | 18 -------------- ...xing.py => test_synapse_table_indexing.py} | 10 ++++---- tests/load/synapse/utils.py | 24 +++++++++++++++++++ 3 files changed, 30 insertions(+), 22 deletions(-) rename tests/load/synapse/{test_table_indexing.py => test_synapse_table_indexing.py} (91%) create mode 100644 tests/load/synapse/utils.py diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index d34fef1ab4..eb6eae3f20 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -144,24 +144,6 @@ def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema table[TABLE_INDEX_TYPE_HINT] = self.config.default_table_index_type # type: ignore[typeddict-unknown-key] return table - def get_storage_table_index_type(self, table_name: str) -> TTableIndexType: - """Returns table index type of table in storage destination.""" - with self.sql_client as sql_client: - schema_name = sql_client.fully_qualified_dataset_name(escape=False) - sql = dedent(f""" - SELECT - CASE i.type_desc - WHEN 'HEAP' THEN 'heap' - WHEN 'CLUSTERED COLUMNSTORE' THEN 'clustered_columnstore_index' - END AS table_index_type - FROM sys.indexes i - INNER JOIN sys.tables t ON t.object_id = i.object_id - INNER JOIN sys.schemas s ON s.schema_id = t.schema_id - WHERE s.name = '{schema_name}' AND t.name = '{table_name}' - """) - table_index_type = sql_client.execute_sql(sql)[0][0] - return cast(TTableIndexType, table_index_type) - def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: job = super().start_file_load(table, file_path, load_id) if not job: diff --git a/tests/load/synapse/test_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py similarity index 91% rename from tests/load/synapse/test_table_indexing.py rename to tests/load/synapse/test_synapse_table_indexing.py index 097bde09f9..af4786af9f 100644 --- a/tests/load/synapse/test_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -12,6 +12,7 @@ from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES +from tests.load.synapse.utils import get_storage_table_index_type TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID = [ @@ -60,7 +61,7 @@ def items_without_table_index_type_specified() -> Iterator[Any]: # Child tables, if any, inherit the index type of their parent. tables = pipeline.default_schema.tables for table_name in tables: - applied_table_index_type = job_client.get_storage_table_index_type(table_name) # type: ignore[attr-defined] + applied_table_index_type = get_storage_table_index_type(job_client.sql_client, table_name) # type: ignore[attr-defined] if table_name in pipeline.default_schema.data_table_names(): # For data tables, the applied table index type should be the default value. 
assert applied_table_index_type == job_client.config.default_table_index_type # type: ignore[attr-defined] @@ -82,8 +83,9 @@ def items_with_table_index_type_specified() -> Iterator[Any]: pipeline.run( synapse_adapter(items_with_table_index_type_specified, "clustered_columnstore_index") ) - applied_table_index_type = job_client.get_storage_table_index_type( # type: ignore[attr-defined] - "items_with_table_index_type_specified" + applied_table_index_type = get_storage_table_index_type( + job_client.sql_client, # type: ignore[attr-defined] + "items_with_table_index_type_specified", ) # While the default is "heap", the applied index type should be "clustered_columnstore_index" # because it was provided as argument to the resource. @@ -124,7 +126,7 @@ def items_with_table_index_type_specified() -> Iterator[Any]: job_client = pipeline.destination_client() tables = pipeline.default_schema.tables for table_name in tables: - applied_table_index_type = job_client.get_storage_table_index_type(table_name) # type: ignore[attr-defined] + applied_table_index_type = get_storage_table_index_type(job_client.sql_client, table_name) # type: ignore[attr-defined] if table_name in pipeline.default_schema.data_table_names(): # For data tables, the applied table index type should be the type # configured in the resource. diff --git a/tests/load/synapse/utils.py b/tests/load/synapse/utils.py new file mode 100644 index 0000000000..cd53716878 --- /dev/null +++ b/tests/load/synapse/utils.py @@ -0,0 +1,24 @@ +from typing import cast +from textwrap import dedent + +from dlt.destinations.impl.synapse.sql_client import SynapseSqlClient +from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType + + +def get_storage_table_index_type(sql_client: SynapseSqlClient, table_name: str) -> TTableIndexType: + """Returns table index type of table in storage destination.""" + with sql_client: + schema_name = sql_client.fully_qualified_dataset_name(escape=False) + sql = dedent(f""" + SELECT + CASE i.type_desc + WHEN 'HEAP' THEN 'heap' + WHEN 'CLUSTERED COLUMNSTORE' THEN 'clustered_columnstore_index' + END AS table_index_type + FROM sys.indexes i + INNER JOIN sys.tables t ON t.object_id = i.object_id + INNER JOIN sys.schemas s ON s.schema_id = t.schema_id + WHERE s.name = '{schema_name}' AND t.name = '{table_name}' + """) + table_index_type = sql_client.execute_sql(sql)[0][0] + return cast(TTableIndexType, table_index_type) From 13e6a0ea09402c8582972ecc73c613e0030870aa Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Thu, 9 Nov 2023 12:29:14 -0600 Subject: [PATCH 10/84] Add escape_databricks_identifier function to escape.py --- dlt/common/data_writers/escape.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 5bf8f29ccb..707dfb0e1f 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -124,3 +124,5 @@ def escape_snowflake_identifier(v: str) -> str: # Snowcase uppercase all identifiers unless quoted. 
Match this here so queries on information schema work without issue # See also https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers return escape_postgres_identifier(v.upper()) + +escape_databricks_identifier = escape_bigquery_identifier From 6c3e8afb834a702b9de022d5a4f5e005b76dbeaf Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Thu, 9 Nov 2023 12:38:52 -0600 Subject: [PATCH 11/84] Add Databricks destination client and configuration and rename classes and methods --- dlt/destinations/databricks/__init__.py | 47 ++++ dlt/destinations/databricks/configuration.py | 116 +++++++++ dlt/destinations/databricks/databricks.py | 236 +++++++++++++++++++ dlt/destinations/databricks/sql_client.py | 160 +++++++++++++ 4 files changed, 559 insertions(+) create mode 100644 dlt/destinations/databricks/__init__.py create mode 100644 dlt/destinations/databricks/configuration.py create mode 100644 dlt/destinations/databricks/databricks.py create mode 100644 dlt/destinations/databricks/sql_client.py diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/databricks/__init__.py new file mode 100644 index 0000000000..0ca1437d80 --- /dev/null +++ b/dlt/destinations/databricks/__init__.py @@ -0,0 +1,47 @@ +from typing import Type + +from dlt.common.schema.schema import Schema +from dlt.common.configuration import with_config, known_sections +from dlt.common.configuration.accessors import config +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration +from dlt.common.data_writers.escape import escape_databricks_identifier +from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE + +from dlt.destinations.databricks.configuration import DatabricksClientConfiguration + + +@with_config(spec=DatabricksClientConfiguration, sections=(known_sections.DESTINATION, "databricks",)) +def _configure(config: DatabricksClientConfiguration = config.value) -> DatabricksClientConfiguration: + return config + + +def capabilities() -> DestinationCapabilitiesContext: + caps = DestinationCapabilitiesContext() + caps.preferred_loader_file_format = "jsonl" + caps.supported_loader_file_formats = ["jsonl", "parquet"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.escape_identifier = escape_databricks_identifier + caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) + caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) + caps.max_identifier_length = 255 + caps.max_column_identifier_length = 255 + caps.max_query_length = 2 * 1024 * 1024 + caps.is_max_query_length_in_bytes = True + caps.max_text_data_type_length = 16 * 1024 * 1024 + caps.is_max_text_data_type_length_in_bytes = True + caps.supports_ddl_transactions = True + caps.alter_add_multi_column = True + return caps + + +def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: + # import client when creating instance so capabilities and config specs can be accessed without dependencies installed + from dlt.destinations.databricks.databricks import DatabricksClient + + return DatabricksClient(schema, _configure(initial_config)) # type: ignore + + +def spec() -> Type[DestinationClientConfiguration]: + return DatabricksClientConfiguration diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/databricks/configuration.py new file mode 100644 index 
0000000000..f24b3e2fcd --- /dev/null +++ b/dlt/destinations/databricks/configuration.py @@ -0,0 +1,116 @@ +import base64 +import binascii + +from typing import Final, Optional, Any, Dict, ClassVar, List + +from sqlalchemy.engine import URL + +from dlt import version +from dlt.common.exceptions import MissingDependencyException +from dlt.common.typing import TSecretStrValue +from dlt.common.configuration.specs import ConnectionStringCredentials +from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.configuration import configspec +from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration +from dlt.common.utils import digest128 + + +def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes: + """Load an encrypted or unencrypted private key from string. + """ + try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + except ModuleNotFoundError as e: + raise MissingDependencyException("DatabricksCredentials with private key", dependencies=[f"{version.DLT_PKG_NAME}[databricks]"]) from e + + try: + # load key from base64-encoded DER key + pkey = serialization.load_der_private_key( + base64.b64decode(private_key), + password=password.encode() if password is not None else None, + backend=default_backend(), + ) + except Exception: + # loading base64-encoded DER key failed -> assume it's a plain-text PEM key + pkey = serialization.load_pem_private_key( + private_key.encode(encoding="ascii"), + password=password.encode() if password is not None else None, + backend=default_backend(), + ) + + return pkey.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + +@configspec +class DatabricksCredentials(ConnectionStringCredentials): + drivername: Final[str] = "databricks" # type: ignore[misc] + password: Optional[TSecretStrValue] = None + host: str = None + database: str = None + warehouse: Optional[str] = None + role: Optional[str] = None + private_key: Optional[TSecretStrValue] = None + private_key_passphrase: Optional[TSecretStrValue] = None + + __config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"] + + def parse_native_representation(self, native_value: Any) -> None: + super().parse_native_representation(native_value) + self.warehouse = self.query.get('warehouse') + self.role = self.query.get('role') + self.private_key = self.query.get('private_key') # type: ignore + self.private_key_passphrase = self.query.get('private_key_passphrase') # type: ignore + if not self.is_partial() and (self.password or self.private_key): + self.resolve() + + def on_resolved(self) -> None: + if not self.password and not self.private_key: + raise ConfigurationValueError("Please specify password or private_key. 
DatabricksCredentials supports password and private key authentication and one of those must be specified.") + + def to_url(self) -> URL: + query = dict(self.query or {}) + if self.warehouse and 'warehouse' not in query: + query['warehouse'] = self.warehouse + if self.role and 'role' not in query: + query['role'] = self.role + return URL.create(self.drivername, self.username, self.password, self.host, self.port, self.database, query) + + def to_connector_params(self) -> Dict[str, Any]: + private_key: Optional[bytes] = None + if self.private_key: + private_key = _read_private_key(self.private_key, self.private_key_passphrase) + return dict( + self.query or {}, + user=self.username, + password=self.password, + account=self.host, + database=self.database, + warehouse=self.warehouse, + role=self.role, + private_key=private_key, + ) + + +@configspec +class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): + destination_name: Final[str] = "databricks" # type: ignore[misc] + credentials: DatabricksCredentials + + stage_name: Optional[str] = None + """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" + keep_staged_files: bool = True + """Whether to keep or delete the staged files after COPY INTO succeeds""" + + def fingerprint(self) -> str: + """Returns a fingerprint of host part of a connection string""" + if self.credentials and self.credentials.host: + return digest128(self.credentials.host) + return "" diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py new file mode 100644 index 0000000000..7af39082c6 --- /dev/null +++ b/dlt/destinations/databricks/databricks.py @@ -0,0 +1,236 @@ +from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any +from urllib.parse import urlparse, urlunparse + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration +from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentials, AzureCredentialsWithoutDefaults +from dlt.common.data_types import TDataType +from dlt.common.storages.file_storage import FileStorage +from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.schema.typing import TTableSchema, TColumnType + + +from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_impl import EmptyLoadJob +from dlt.destinations.exceptions import LoadJobTerminalException + +from dlt.destinations.databricks import capabilities +from dlt.destinations.databricks.configuration import DatabricksClientConfiguration +from dlt.destinations.databricks.sql_client import DatabricksSqlClient +from dlt.destinations.sql_jobs import SqlStagingCopyJob +from dlt.destinations.databricks.sql_client import DatabricksSqlClient +from dlt.destinations.job_impl import NewReferenceJob +from dlt.destinations.sql_client import SqlClientBase +from dlt.destinations.type_mapping import TypeMapper + + +class DatabricksTypeMapper(TypeMapper): + BIGINT_PRECISION = 19 + sct_to_unbound_dbt = { + "complex": "VARIANT", + "text": "VARCHAR", + "double": "FLOAT", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP_TZ", + "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Databricks has no integer types + "binary": "BINARY", + "time": "TIME", + } + + sct_to_dbt = { + "text": "VARCHAR(%i)", + "timestamp": "TIMESTAMP_TZ(%i)", + "decimal": 
"NUMBER(%i,%i)", + "time": "TIME(%i)", + "wei": "NUMBER(%i,%i)", + } + + dbt_to_sct = { + "VARCHAR": "text", + "FLOAT": "double", + "BOOLEAN": "bool", + "DATE": "date", + "TIMESTAMP_TZ": "timestamp", + "BINARY": "binary", + "VARIANT": "complex", + "TIME": "time" + } + + def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType: + if db_type == "NUMBER": + if precision == self.BIGINT_PRECISION and scale == 0: + return dict(data_type='bigint') + elif (precision, scale) == self.capabilities.wei_precision: + return dict(data_type='wei') + return dict(data_type='decimal', precision=precision, scale=scale) + return super().from_db_type(db_type, precision, scale) + + +class DatabricksLoadJob(LoadJob, FollowupJob): + def __init__( + self, file_path: str, table_name: str, load_id: str, client: DatabricksSqlClient, + stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None + ) -> None: + file_name = FileStorage.get_file_name_from_file_path(file_path) + super().__init__(file_name) + + qualified_table_name = client.make_qualified_table_name(table_name) + + # extract and prepare some vars + bucket_path = NewReferenceJob.resolve_reference(file_path) if NewReferenceJob.is_reference_job(file_path) else "" + file_name = FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + from_clause = "" + credentials_clause = "" + files_clause = "" + stage_file_path = "" + + if bucket_path: + bucket_url = urlparse(bucket_path) + bucket_scheme = bucket_url.scheme + # referencing an external s3/azure stage does not require explicit AWS credentials + if bucket_scheme in ["s3", "az", "abfs"] and stage_name: + from_clause = f"FROM '@{stage_name}'" + files_clause = f"FILES = ('{bucket_url.path.lstrip('/')}')" + # referencing an staged files via a bucket URL requires explicit AWS credentials + elif bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" + from_clause = f"FROM '{bucket_path}'" + elif bucket_scheme in ["az", "abfs"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + # Explicit azure credentials are needed to load from bucket without a named stage + credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" + # Converts an az:/// to azure://.blob.core.windows.net// + # as required by databricks + _path = "/" + bucket_url.netloc + bucket_url.path + bucket_path = urlunparse( + bucket_url._replace( + scheme="azure", + netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net", + path=_path + ) + ) + from_clause = f"FROM '{bucket_path}'" + else: + # ensure that gcs bucket path starts with gcs://, this is a requirement of databricks + bucket_path = bucket_path.replace("gs://", "gcs://") + if not stage_name: + # when loading from bucket stage must be given + raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. 
See https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for instructions on setting up the `stage_name`") + from_clause = f"FROM @{stage_name}/" + files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" + else: + # this means we have a local file + if not stage_name: + # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" + stage_name = client.make_qualified_table_name('%'+table_name) + stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' + from_clause = f"FROM {stage_file_path}" + + # decide on source format, stage_file_path will either be a local file or a bucket path + source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" + if file_name.endswith("parquet"): + source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE)" + + with client.begin_transaction(): + # PUT and COPY in one tx if local file, otherwise only copy + if not bucket_path: + client.execute_sql(f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE, AUTO_COMPRESS = FALSE') + client.execute_sql( + f"""COPY INTO {qualified_table_name} + {from_clause} + {files_clause} + {credentials_clause} + FILE_FORMAT = {source_format} + MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' + """ + ) + if stage_file_path and not keep_staged_files: + client.execute_sql(f'REMOVE {stage_file_path}') + + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() + +class DatabricksStagingCopyJob(SqlStagingCopyJob): + + @classmethod + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + sql: List[str] = [] + for table in table_chain: + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name = sql_client.make_qualified_table_name(table["name"]) + # drop destination table + sql.append(f"DROP TABLE IF EXISTS {table_name};") + # recreate destination table with data cloned from staging table + sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") + return sql + + +class DatabricksClient(SqlJobClientWithStaging): + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: + sql_client = DatabricksSqlClient( + config.normalize_dataset_name(schema), + config.credentials + ) + super().__init__(schema, config, sql_client) + self.config: DatabricksClientConfiguration = config + self.sql_client: DatabricksSqlClient = sql_client # type: ignore + self.type_mapper = DatabricksTypeMapper(self.capabilities) + + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + job = super().start_file_load(table, file_path, load_id) + + if not job: + job = DatabricksLoadJob( + file_path, + table['name'], + load_id, + self.sql_client, + stage_name=self.config.stage_name, + keep_staged_files=self.config.keep_staged_files, + staging_credentials=self.config.staging_config.credentials if self.config.staging_config else None + ) + return job + + def restore_file_load(self, file_path: str) -> LoadJob: + return EmptyLoadJob.from_file_path(file_path, "completed") + + def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: + # Override because databricks requires multiple columns in a single ADD COLUMN clause + return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + + def _create_optimized_replace_job(self, table_chain: 
Sequence[TTableSchema]) -> NewLoadJob: + return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + + def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]: + sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) + + cluster_list = [self.capabilities.escape_identifier(c['name']) for c in new_columns if c.get('cluster')] + + if cluster_list: + sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" + + return sql + + def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: + return self.type_mapper.from_db_type(bq_t, precision, scale) + + def _get_column_def_sql(self, c: TColumnSchema) -> str: + name = self.capabilities.escape_identifier(c["name"]) + return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: + table_name = table_name.upper() # All databricks tables are uppercased in information schema + exists, table = super().get_storage_table(table_name) + if not exists: + return exists, table + # Databricks converts all unquoted columns to UPPER CASE + # Convert back to lower case to enable comparison with dlt schema + table = {col_name.lower(): dict(col, name=col_name.lower()) for col_name, col in table.items()} # type: ignore + return exists, table diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py new file mode 100644 index 0000000000..07c5d67d03 --- /dev/null +++ b/dlt/destinations/databricks/sql_client.py @@ -0,0 +1,160 @@ +from contextlib import contextmanager, suppress +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List + +import databricks.connector as databricks_lib + +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation +from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, raise_open_connection_error +from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.databricks.configuration import DatabricksCredentials +from dlt.destinations.databricks import capabilities + +class DatabricksCursorImpl(DBApiCursorImpl): + native_cursor: databricks_lib.cursor.DatabricksCursor # type: ignore[assignment] + + def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: + if chunk_size is None: + return self.native_cursor.fetch_pandas_all(**kwargs) + return super().df(chunk_size=chunk_size, **kwargs) + + +class DatabricksSqlClient(SqlClientBase[databricks_lib.DatabricksConnection], DBTransaction): + + dbapi: ClassVar[DBApi] = databricks_lib + capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() + + def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: + super().__init__(credentials.database, dataset_name) + self._conn: databricks_lib.DatabricksConnection = None + self.credentials = credentials + + def open_connection(self) -> databricks_lib.DatabricksConnection: + conn_params = self.credentials.to_connector_params() + # set the timezone to UTC so when loading from file formats that do not have timezones + # we get dlt expected UTC + if "timezone" not in conn_params: + conn_params["timezone"] = "UTC" + self._conn = 
databricks_lib.connect( + schema=self.fully_qualified_dataset_name(), + **conn_params + ) + return self._conn + + @raise_open_connection_error + def close_connection(self) -> None: + if self._conn: + self._conn.close() + self._conn = None + + @contextmanager + def begin_transaction(self) -> Iterator[DBTransaction]: + try: + self._conn.autocommit(False) + yield self + self.commit_transaction() + except Exception: + self.rollback_transaction() + raise + + @raise_database_error + def commit_transaction(self) -> None: + self._conn.commit() + self._conn.autocommit(True) + + @raise_database_error + def rollback_transaction(self) -> None: + self._conn.rollback() + self._conn.autocommit(True) + + @property + def native_connection(self) -> "databricks_lib.DatabricksConnection": + return self._conn + + def drop_tables(self, *tables: str) -> None: + # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. + # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. + with suppress(DatabaseUndefinedRelation): + super().drop_tables(*tables) + + def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + with self.execute_query(sql, *args, **kwargs) as curr: + if curr.description is None: + return None + else: + f = curr.fetchall() + return f + + @contextmanager + @raise_database_error + def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: + curr: DBApiCursor = None + db_args = args if args else kwargs if kwargs else None + with self._conn.cursor() as curr: # type: ignore[assignment] + try: + curr.execute(query, db_args, num_statements=0) + yield DatabricksCursorImpl(curr) # type: ignore[abstract] + except databricks_lib.Error as outer: + try: + self._reset_connection() + except databricks_lib.Error: + self.close_connection() + self.open_connection() + raise outer + + def fully_qualified_dataset_name(self, escape: bool = True) -> str: + # Always escape for uppercase + if escape: + return self.capabilities.escape_identifier(self.dataset_name) + return self.dataset_name.upper() + + def _reset_connection(self) -> None: + self._conn.rollback() + self._conn.autocommit(True) + + @classmethod + def _make_database_exception(cls, ex: Exception) -> Exception: + if isinstance(ex, databricks_lib.errors.ProgrammingError): + if ex.sqlstate == 'P0000' and ex.errno == 100132: + # Error in a multi statement execution. These don't show the original error codes + msg = str(ex) + if "NULL result in a non-nullable column" in msg: + return DatabaseTerminalException(ex) + elif "does not exist or not authorized" in msg: # E.g. 
schema not found + return DatabaseUndefinedRelation(ex) + else: + return DatabaseTransientException(ex) + if ex.sqlstate in {'42S02', '02000'}: + return DatabaseUndefinedRelation(ex) + elif ex.sqlstate == '22023': # Adding non-nullable no-default column + return DatabaseTerminalException(ex) + elif ex.sqlstate == '42000' and ex.errno == 904: # Invalid identifier + return DatabaseTerminalException(ex) + elif ex.sqlstate == "22000": + return DatabaseTerminalException(ex) + else: + return DatabaseTransientException(ex) + + elif isinstance(ex, databricks_lib.errors.IntegrityError): + raise DatabaseTerminalException(ex) + elif isinstance(ex, databricks_lib.errors.DatabaseError): + term = cls._maybe_make_terminal_exception_from_data_error(ex) + if term: + return term + else: + return DatabaseTransientException(ex) + elif isinstance(ex, TypeError): + # databricks raises TypeError on malformed query parameters + return DatabaseTransientException(databricks_lib.errors.ProgrammingError(str(ex))) + elif cls.is_dbapi_exception(ex): + return DatabaseTransientException(ex) + else: + return ex + + @staticmethod + def _maybe_make_terminal_exception_from_data_error(databricks_ex: databricks_lib.DatabaseError) -> Optional[Exception]: + return None + + @staticmethod + def is_dbapi_exception(ex: Exception) -> bool: + return isinstance(ex, databricks_lib.DatabaseError) From e941f4f82508cea0b9e7fdb08baf06c85705f19a Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Fri, 10 Nov 2023 09:01:31 -0600 Subject: [PATCH 12/84] Refactor Databricks SQL client to use new API --- dlt/destinations/databricks/sql_client.py | 27 +++++++++++++---------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 07c5d67d03..8a825af3ea 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -1,7 +1,11 @@ from contextlib import contextmanager, suppress from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List -import databricks.connector as databricks_lib +from databricks import sql as databricks_lib +from databricks.sql.client import ( + Connection as DatabricksSQLConnection, + Cursor as DatabricksSQLCursor, +) from dlt.common.destination import DestinationCapabilitiesContext from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation @@ -11,32 +15,31 @@ from dlt.destinations.databricks import capabilities class DatabricksCursorImpl(DBApiCursorImpl): - native_cursor: databricks_lib.cursor.DatabricksCursor # type: ignore[assignment] + native_cursor: DatabricksSQLCursor # type: ignore[assignment] def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: if chunk_size is None: - return self.native_cursor.fetch_pandas_all(**kwargs) + return self.native_cursor.fetchall(**kwargs) return super().df(chunk_size=chunk_size, **kwargs) -class DatabricksSqlClient(SqlClientBase[databricks_lib.DatabricksConnection], DBTransaction): +class DatabricksSqlClient(SqlClientBase[DatabricksSQLConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: super().__init__(credentials.database, dataset_name) - self._conn: databricks_lib.DatabricksConnection = None + self._conn: databricks_lib.connect(credentials) = None self.credentials = credentials - def 
open_connection(self) -> databricks_lib.DatabricksConnection: + def open_connection(self) -> DatabricksSQLConnection: conn_params = self.credentials.to_connector_params() # set the timezone to UTC so when loading from file formats that do not have timezones # we get dlt expected UTC if "timezone" not in conn_params: conn_params["timezone"] = "UTC" self._conn = databricks_lib.connect( - schema=self.fully_qualified_dataset_name(), **conn_params ) return self._conn @@ -68,7 +71,7 @@ def rollback_transaction(self) -> None: self._conn.autocommit(True) @property - def native_connection(self) -> "databricks_lib.DatabricksConnection": + def native_connection(self) -> "DatabricksSQLConnection": return self._conn def drop_tables(self, *tables: str) -> None: @@ -114,7 +117,7 @@ def _reset_connection(self) -> None: @classmethod def _make_database_exception(cls, ex: Exception) -> Exception: - if isinstance(ex, databricks_lib.errors.ProgrammingError): + if isinstance(ex, databricks_lib.ProgrammingError): if ex.sqlstate == 'P0000' and ex.errno == 100132: # Error in a multi statement execution. These don't show the original error codes msg = str(ex) @@ -135,9 +138,9 @@ def _make_database_exception(cls, ex: Exception) -> Exception: else: return DatabaseTransientException(ex) - elif isinstance(ex, databricks_lib.errors.IntegrityError): + elif isinstance(ex, databricks_lib.IntegrityError): raise DatabaseTerminalException(ex) - elif isinstance(ex, databricks_lib.errors.DatabaseError): + elif isinstance(ex, databricks_lib.DatabaseError): term = cls._maybe_make_terminal_exception_from_data_error(ex) if term: return term @@ -145,7 +148,7 @@ def _make_database_exception(cls, ex: Exception) -> Exception: return DatabaseTransientException(ex) elif isinstance(ex, TypeError): # databricks raises TypeError on malformed query parameters - return DatabaseTransientException(databricks_lib.errors.ProgrammingError(str(ex))) + return DatabaseTransientException(databricks_lib.ProgrammingError(str(ex))) elif cls.is_dbapi_exception(ex): return DatabaseTransientException(ex) else: From 6bc1e4fd6f1c6d7fd5c8b3c83b1e1cfa53fe9c86 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Fri, 10 Nov 2023 11:29:34 -0600 Subject: [PATCH 13/84] Refactor Databricks SQL client code --- dlt/destinations/databricks/sql_client.py | 84 ++++++----------------- 1 file changed, 21 insertions(+), 63 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 8a825af3ea..68033b5e76 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -6,7 +6,9 @@ Connection as DatabricksSQLConnection, Cursor as DatabricksSQLCursor, ) +from databricks.sql.exc import Error as DatabricksSQLError +from dlt.common import logger from dlt.common.destination import DestinationCapabilitiesContext from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, raise_open_connection_error @@ -30,15 +32,11 @@ class DatabricksSqlClient(SqlClientBase[DatabricksSQLConnection], DBTransaction) def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: super().__init__(credentials.database, dataset_name) - self._conn: databricks_lib.connect(credentials) = None + self._conn: DatabricksSQLConnection = None self.credentials = credentials def open_connection(self) -> DatabricksSQLConnection: conn_params = 
self.credentials.to_connector_params() - # set the timezone to UTC so when loading from file formats that do not have timezones - # we get dlt expected UTC - if "timezone" not in conn_params: - conn_params["timezone"] = "UTC" self._conn = databricks_lib.connect( **conn_params ) @@ -46,29 +44,24 @@ def open_connection(self) -> DatabricksSQLConnection: @raise_open_connection_error def close_connection(self) -> None: - if self._conn: + try: self._conn.close() self._conn = None + except DatabricksSQLError as exc: + logger.warning("Exception while closing connection: {}".format(exc)) @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: - try: - self._conn.autocommit(False) - yield self - self.commit_transaction() - except Exception: - self.rollback_transaction() - raise + logger.warning("NotImplemented: Databricks does not support transactions. Each SQL statement is auto-committed separately.") + yield self @raise_database_error def commit_transaction(self) -> None: - self._conn.commit() - self._conn.autocommit(True) + pass @raise_database_error def rollback_transaction(self) -> None: - self._conn.rollback() - self._conn.autocommit(True) + logger.warning("NotImplemented: rollback") @property def native_connection(self) -> "DatabricksSQLConnection": @@ -106,58 +99,23 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB raise outer def fully_qualified_dataset_name(self, escape: bool = True) -> str: - # Always escape for uppercase if escape: return self.capabilities.escape_identifier(self.dataset_name) - return self.dataset_name.upper() + return self.dataset_name def _reset_connection(self) -> None: - self._conn.rollback() - self._conn.autocommit(True) - - @classmethod - def _make_database_exception(cls, ex: Exception) -> Exception: - if isinstance(ex, databricks_lib.ProgrammingError): - if ex.sqlstate == 'P0000' and ex.errno == 100132: - # Error in a multi statement execution. These don't show the original error codes - msg = str(ex) - if "NULL result in a non-nullable column" in msg: - return DatabaseTerminalException(ex) - elif "does not exist or not authorized" in msg: # E.g. 
schema not found - return DatabaseUndefinedRelation(ex) - else: - return DatabaseTransientException(ex) - if ex.sqlstate in {'42S02', '02000'}: - return DatabaseUndefinedRelation(ex) - elif ex.sqlstate == '22023': # Adding non-nullable no-default column - return DatabaseTerminalException(ex) - elif ex.sqlstate == '42000' and ex.errno == 904: # Invalid identifier - return DatabaseTerminalException(ex) - elif ex.sqlstate == "22000": - return DatabaseTerminalException(ex) - else: - return DatabaseTransientException(ex) + self.close_connection() + self.open_connection() - elif isinstance(ex, databricks_lib.IntegrityError): - raise DatabaseTerminalException(ex) + @staticmethod + def _make_database_exception(ex: Exception) -> Exception: + if isinstance(ex, databricks_lib.OperationalError): + if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): + return DatabaseUndefinedRelation(ex) + return DatabaseTerminalException(ex) + elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): + return DatabaseTerminalException(ex) elif isinstance(ex, databricks_lib.DatabaseError): - term = cls._maybe_make_terminal_exception_from_data_error(ex) - if term: - return term - else: - return DatabaseTransientException(ex) - elif isinstance(ex, TypeError): - # databricks raises TypeError on malformed query parameters - return DatabaseTransientException(databricks_lib.ProgrammingError(str(ex))) - elif cls.is_dbapi_exception(ex): return DatabaseTransientException(ex) else: return ex - - @staticmethod - def _maybe_make_terminal_exception_from_data_error(databricks_ex: databricks_lib.DatabaseError) -> Optional[Exception]: - return None - - @staticmethod - def is_dbapi_exception(ex: Exception) -> bool: - return isinstance(ex, databricks_lib.DatabaseError) From 8cd57b41d7efdb149ef80e30435833a51cde7542 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Fri, 10 Nov 2023 12:07:19 -0600 Subject: [PATCH 14/84] Refactor DatabricksCredentials configuration class --- dlt/destinations/databricks/configuration.py | 204 ++++++++++--------- 1 file changed, 108 insertions(+), 96 deletions(-) diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/databricks/configuration.py index f24b3e2fcd..23010aa477 100644 --- a/dlt/destinations/databricks/configuration.py +++ b/dlt/destinations/databricks/configuration.py @@ -1,102 +1,113 @@ -import base64 -import binascii +from typing import ClassVar, Final, Optional, Any, Dict, List -from typing import Final, Optional, Any, Dict, ClassVar, List - -from sqlalchemy.engine import URL - -from dlt import version -from dlt.common.exceptions import MissingDependencyException -from dlt.common.typing import TSecretStrValue -from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.configuration.exceptions import ConfigurationValueError -from dlt.common.configuration import configspec +from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -from dlt.common.utils import digest128 - - -def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes: - """Load an encrypted or unencrypted private key from string. 
- """ - try: - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives.asymmetric import rsa - from cryptography.hazmat.primitives.asymmetric import dsa - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes - except ModuleNotFoundError as e: - raise MissingDependencyException("DatabricksCredentials with private key", dependencies=[f"{version.DLT_PKG_NAME}[databricks]"]) from e - - try: - # load key from base64-encoded DER key - pkey = serialization.load_der_private_key( - base64.b64decode(private_key), - password=password.encode() if password is not None else None, - backend=default_backend(), - ) - except Exception: - # loading base64-encoded DER key failed -> assume it's a plain-text PEM key - pkey = serialization.load_pem_private_key( - private_key.encode(encoding="ascii"), - password=password.encode() if password is not None else None, - backend=default_backend(), - ) - - return pkey.private_bytes( - encoding=serialization.Encoding.DER, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ) +CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" + @configspec -class DatabricksCredentials(ConnectionStringCredentials): - drivername: Final[str] = "databricks" # type: ignore[misc] - password: Optional[TSecretStrValue] = None - host: str = None - database: str = None - warehouse: Optional[str] = None - role: Optional[str] = None - private_key: Optional[TSecretStrValue] = None - private_key_passphrase: Optional[TSecretStrValue] = None - - __config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"] - - def parse_native_representation(self, native_value: Any) -> None: - super().parse_native_representation(native_value) - self.warehouse = self.query.get('warehouse') - self.role = self.query.get('role') - self.private_key = self.query.get('private_key') # type: ignore - self.private_key_passphrase = self.query.get('private_key_passphrase') # type: ignore - if not self.is_partial() and (self.password or self.private_key): - self.resolve() - - def on_resolved(self) -> None: - if not self.password and not self.private_key: - raise ConfigurationValueError("Please specify password or private_key. DatabricksCredentials supports password and private key authentication and one of those must be specified.") - - def to_url(self) -> URL: - query = dict(self.query or {}) - if self.warehouse and 'warehouse' not in query: - query['warehouse'] = self.warehouse - if self.role and 'role' not in query: - query['role'] = self.role - return URL.create(self.drivername, self.username, self.password, self.host, self.port, self.database, query) +class DatabricksCredentials(CredentialsConfiguration): + database: Optional[str] = None # type: ignore[assignment] + schema: Optional[str] = None # type: ignore[assignment] + host: Optional[str] = None + http_path: Optional[str] = None + token: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + session_properties: Optional[Dict[str, Any]] = None + connection_parameters: Optional[Dict[str, Any]] = None + auth_type: Optional[str] = None + + connect_retries: int = 1 + connect_timeout: Optional[int] = None + retry_all: bool = False + + _credentials_provider: Optional[Dict[str, Any]] = None + + __config_gen_annotations__: ClassVar[List[str]] = ["server_hostname", "http_path", "catalog", "schema"] + + def __post_init__(self) -> None: + if "." 
in (self.schema or ""): + raise ConfigurationValueError( + f"The schema should not contain '.': {self.schema}\n" + "If you are trying to set a catalog, please use `catalog` instead.\n" + ) + + session_properties = self.session_properties or {} + if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: + if self.database is None: + self.database = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] + del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] + else: + raise ConfigurationValueError( + f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` " + 'in session_properties) all map to "database"' + ) + self.session_properties = session_properties + + if self.database is not None: + database = self.database.strip() + if not database: + raise ConfigurationValueError( + f"Invalid catalog name : `{self.database}`." + ) + self.database = database + else: + self.database = "hive_metastore" + + connection_parameters = self.connection_parameters or {} + for key in ( + "server_hostname", + "http_path", + "access_token", + "client_id", + "client_secret", + "session_configuration", + "catalog", + "schema", + "_user_agent_entry", + ): + if key in connection_parameters: + raise ConfigurationValueError( + f"The connection parameter `{key}` is reserved." + ) + if "http_headers" in connection_parameters: + http_headers = connection_parameters["http_headers"] + if not isinstance(http_headers, dict) or any( + not isinstance(key, str) or not isinstance(value, str) + for key, value in http_headers.items() + ): + raise ConfigurationValueError( + "The connection parameter `http_headers` should be dict of strings: " + f"{http_headers}." + ) + if "_socket_timeout" not in connection_parameters: + connection_parameters["_socket_timeout"] = 180 + self.connection_parameters = connection_parameters + + def validate_creds(self) -> None: + for key in ["host", "http_path"]: + if not getattr(self, key): + raise ConfigurationValueError( + "The config '{}' is required to connect to Databricks".format(key) + ) + if not self.token and self.auth_type != "oauth": + raise ConfigurationValueError( + ("The config `auth_type: oauth` is required when not using access token") + ) + + if not self.client_id and self.client_secret: + raise ConfigurationValueError( + ( + "The config 'client_id' is required to connect " + "to Databricks when 'client_secret' is present" + ) + ) def to_connector_params(self) -> Dict[str, Any]: - private_key: Optional[bytes] = None - if self.private_key: - private_key = _read_private_key(self.private_key, self.private_key_passphrase) - return dict( - self.query or {}, - user=self.username, - password=self.password, - account=self.host, - database=self.database, - warehouse=self.warehouse, - role=self.role, - private_key=private_key, - ) + return self.connection_parameters or {} @configspec @@ -109,8 +120,9 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration keep_staged_files: bool = True """Whether to keep or delete the staged files after COPY INTO succeeds""" - def fingerprint(self) -> str: - """Returns a fingerprint of host part of a connection string""" - if self.credentials and self.credentials.host: - return digest128(self.credentials.host) - return "" + def __str__(self) -> str: + """Return displayable destination location""" + if self.staging_config: + return str(self.staging_config.credentials) + else: + return "[no staging set]" From d3dd5f5d0abf61526c35a5175e51e1ed36f200ba Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 09:33:19 
-0600 Subject: [PATCH 15/84] Implement commit_transaction method in DatabricksSqlClient --- dlt/destinations/databricks/sql_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 68033b5e76..a2f69199e9 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -57,6 +57,7 @@ def begin_transaction(self) -> Iterator[DBTransaction]: @raise_database_error def commit_transaction(self) -> None: + logger.warning("NotImplemented: commit") pass @raise_database_error From a2a53fc3367d681dedafe46421409f18d80815a1 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 09:36:35 -0600 Subject: [PATCH 16/84] Fix DatabricksCredentials host and http_path types. --- dlt/destinations/databricks/configuration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/databricks/configuration.py index 23010aa477..18974ebf32 100644 --- a/dlt/destinations/databricks/configuration.py +++ b/dlt/destinations/databricks/configuration.py @@ -11,8 +11,8 @@ class DatabricksCredentials(CredentialsConfiguration): database: Optional[str] = None # type: ignore[assignment] schema: Optional[str] = None # type: ignore[assignment] - host: Optional[str] = None - http_path: Optional[str] = None + host: str = None + http_path: str = None token: Optional[str] = None client_id: Optional[str] = None client_secret: Optional[str] = None From b630d31e2cb17ffbf3773a6ef99c66747e093eb0 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 09:38:02 -0600 Subject: [PATCH 17/84] Refactor DatabricksTypeMapper and DatabricksLoadJob --- dlt/destinations/databricks/databricks.py | 79 +++++++++++++++-------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py index 7af39082c6..5d97c196d3 100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/databricks/databricks.py @@ -18,45 +18,73 @@ from dlt.destinations.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.databricks.sql_client import DatabricksSqlClient from dlt.destinations.sql_jobs import SqlStagingCopyJob -from dlt.destinations.databricks.sql_client import DatabricksSqlClient from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper class DatabricksTypeMapper(TypeMapper): - BIGINT_PRECISION = 19 sct_to_unbound_dbt = { - "complex": "VARIANT", - "text": "VARCHAR", - "double": "FLOAT", + "complex": "ARRAY", # Databricks supports complex types like ARRAY + "text": "STRING", + "double": "DOUBLE", "bool": "BOOLEAN", "date": "DATE", - "timestamp": "TIMESTAMP_TZ", - "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Databricks has no integer types + "timestamp": "TIMESTAMP", # TIMESTAMP for local timezone + "bigint": "BIGINT", "binary": "BINARY", - "time": "TIME", - } - - sct_to_dbt = { - "text": "VARCHAR(%i)", - "timestamp": "TIMESTAMP_TZ(%i)", - "decimal": "NUMBER(%i,%i)", - "time": "TIME(%i)", - "wei": "NUMBER(%i,%i)", + "decimal": "DECIMAL", # DECIMAL(p,s) format + "float": "FLOAT", + "int": "INT", + "smallint": "SMALLINT", + "tinyint": "TINYINT", + "void": "VOID", + "interval": "INTERVAL", + "map": "MAP", + "struct": "STRUCT", } dbt_to_sct = { - "VARCHAR": "text", - "FLOAT": "double", + "STRING": "text", 
+ "DOUBLE": "double", "BOOLEAN": "bool", "DATE": "date", - "TIMESTAMP_TZ": "timestamp", + "TIMESTAMP": "timestamp", + "BIGINT": "bigint", "BINARY": "binary", - "VARIANT": "complex", - "TIME": "time" + "DECIMAL": "decimal", + "FLOAT": "float", + "INT": "int", + "SMALLINT": "smallint", + "TINYINT": "tinyint", + "VOID": "void", + "INTERVAL": "interval", + "MAP": "map", + "STRUCT": "struct", + "ARRAY": "complex" } + sct_to_dbt = { + "text": "STRING", + "double": "DOUBLE", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP", + "bigint": "BIGINT", + "binary": "BINARY", + "decimal": "DECIMAL(%i,%i)", + "float": "FLOAT", + "int": "INT", + "smallint": "SMALLINT", + "tinyint": "TINYINT", + "void": "VOID", + "interval": "INTERVAL", + "map": "MAP", + "struct": "STRUCT", + "complex": "ARRAY" + } + + def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType: if db_type == "NUMBER": if precision == self.BIGINT_PRECISION and scale == 0: @@ -89,14 +117,14 @@ def __init__( bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme # referencing an external s3/azure stage does not require explicit AWS credentials - if bucket_scheme in ["s3", "az", "abfs"] and stage_name: + if bucket_scheme in ["s3", "az", "abfss"] and stage_name: from_clause = f"FROM '@{stage_name}'" - files_clause = f"FILES = ('{bucket_url.path.lstrip('/')}')" + files_clause = f"LOCATION ('{bucket_url.path.lstrip('/')}')" # referencing an staged files via a bucket URL requires explicit AWS credentials elif bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme in ["az", "abfs"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + elif bucket_scheme in ["az", "abfss"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" # Converts an az:/// to azure://.blob.core.windows.net// @@ -226,11 +254,8 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - table_name = table_name.upper() # All databricks tables are uppercased in information schema exists, table = super().get_storage_table(table_name) if not exists: return exists, table - # Databricks converts all unquoted columns to UPPER CASE - # Convert back to lower case to enable comparison with dlt schema table = {col_name.lower(): dict(col, name=col_name.lower()) for col_name, col in table.items()} # type: ignore return exists, table From c0ce477b51be6768303104e8604bf3ca3e502e4a Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 11:55:54 -0600 Subject: [PATCH 18/84] Add support for secret string values in Databricks credentials. 
--- dlt/destinations/databricks/configuration.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/databricks/configuration.py index 18974ebf32..15bb622ab7 100644 --- a/dlt/destinations/databricks/configuration.py +++ b/dlt/destinations/databricks/configuration.py @@ -1,5 +1,6 @@ from typing import ClassVar, Final, Optional, Any, Dict, List +from dlt.common.typing import TSecretStrValue from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -13,9 +14,9 @@ class DatabricksCredentials(CredentialsConfiguration): schema: Optional[str] = None # type: ignore[assignment] host: str = None http_path: str = None - token: Optional[str] = None + token: Optional[TSecretStrValue] = None client_id: Optional[str] = None - client_secret: Optional[str] = None + client_secret: Optional[TSecretStrValue] = None session_properties: Optional[Dict[str, Any]] = None connection_parameters: Optional[Dict[str, Any]] = None auth_type: Optional[str] = None From b85b930ceb517e9040c419bde0bce0d7c3441ed7 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 17:54:23 -0600 Subject: [PATCH 19/84] Fix DatabricksSqlClient super call to use catalog instead of database. Add has_dataset method to check if dataset exists. --- dlt/destinations/databricks/sql_client.py | 32 ++++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index a2f69199e9..8c6ab4e32f 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -31,7 +31,7 @@ class DatabricksSqlClient(SqlClientBase[DatabricksSQLConnection], DBTransaction) capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: - super().__init__(credentials.database, dataset_name) + super().__init__(credentials.catalog, dataset_name) self._conn: DatabricksSQLConnection = None self.credentials = credentials @@ -68,6 +68,18 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSQLConnection": return self._conn + def has_dataset(self) -> bool: + query = """ + SELECT 1 + FROM SYSTEM.INFORMATION_SCHEMA.SCHEMATA + WHERE """ + db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) + if len(db_params) == 2: + query += " catalog_name = %s AND " + query += "schema_name = %s" + rows = self.execute_sql(query, *db_params) + return len(rows) > 0 + def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. 
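
For illustration of the check above: when the dataset name is qualified with a catalog, e.g. `dlt_ci.my_dataset`, `has_dataset` binds two parameters; without a catalog only the schema filter remains. A small trace of that branching, using assumed example values in place of `fully_qualified_dataset_name(escape=False)`:

    # assumed example value standing in for fully_qualified_dataset_name(escape=False)
    db_params = "dlt_ci.my_dataset".split(".", 2)   # -> ["dlt_ci", "my_dataset"]
    query = "SELECT 1 FROM SYSTEM.INFORMATION_SCHEMA.SCHEMATA WHERE "
    if len(db_params) == 2:
        query += " catalog_name = %s AND "
    query += "schema_name = %s"
    # executed as execute_sql(query, *db_params); any returned row means the schema exists
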
@@ -89,7 +101,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB db_args = args if args else kwargs if kwargs else None with self._conn.cursor() as curr: # type: ignore[assignment] try: - curr.execute(query, db_args, num_statements=0) + curr.execute(query, db_args) yield DatabricksCursorImpl(curr) # type: ignore[abstract] except databricks_lib.Error as outer: try: @@ -101,8 +113,12 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB def fully_qualified_dataset_name(self, escape: bool = True) -> str: if escape: - return self.capabilities.escape_identifier(self.dataset_name) - return self.dataset_name + catalog = self.capabilities.escape_identifier(self.credentials.catalog) + dataset_name = self.capabilities.escape_identifier(self.dataset_name) + else: + catalog = self.credentials.catalog + dataset_name = self.dataset_name + return f"{catalog}.{dataset_name}" def _reset_connection(self) -> None: self.close_connection() @@ -120,3 +136,11 @@ def _make_database_exception(ex: Exception) -> Exception: return DatabaseTransientException(ex) else: return ex + + @staticmethod + def _maybe_make_terminal_exception_from_data_error(databricks_ex: databricks_lib.DatabaseError) -> Optional[Exception]: + return None + + @staticmethod + def is_dbapi_exception(ex: Exception) -> bool: + return isinstance(ex, databricks_lib.DatabaseError) From f6aac09e3c34fd0fdec8bd8537b85be3291b09f1 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 17:55:06 -0600 Subject: [PATCH 20/84] Update Databricks credentials configuration --- dlt/destinations/databricks/configuration.py | 37 +++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/databricks/configuration.py index 15bb622ab7..7184c175c3 100644 --- a/dlt/destinations/databricks/configuration.py +++ b/dlt/destinations/databricks/configuration.py @@ -10,11 +10,11 @@ @configspec class DatabricksCredentials(CredentialsConfiguration): - database: Optional[str] = None # type: ignore[assignment] + catalog: Optional[str] = None # type: ignore[assignment] schema: Optional[str] = None # type: ignore[assignment] - host: str = None + server_hostname: str = None http_path: str = None - token: Optional[TSecretStrValue] = None + access_token: Optional[TSecretStrValue] = None client_id: Optional[str] = None client_secret: Optional[TSecretStrValue] = None session_properties: Optional[Dict[str, Any]] = None @@ -38,25 +38,25 @@ def __post_init__(self) -> None: session_properties = self.session_properties or {} if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: - if self.database is None: - self.database = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] + if self.catalog is None: + self.catalog = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] else: raise ConfigurationValueError( f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` " - 'in session_properties) all map to "database"' + 'in session_properties) all map to "catalog"' ) self.session_properties = session_properties - if self.database is not None: - database = self.database.strip() - if not database: + if self.catalog is not None: + catalog = self.catalog.strip() + if not catalog: raise ConfigurationValueError( - f"Invalid catalog name : `{self.database}`." + f"Invalid catalog name : `{self.catalog}`." 
) - self.database = database + self.catalog = catalog else: - self.database = "hive_metastore" + self.catalog = "hive_metastore" connection_parameters = self.connection_parameters or {} for key in ( @@ -108,7 +108,18 @@ def validate_creds(self) -> None: ) def to_connector_params(self) -> Dict[str, Any]: - return self.connection_parameters or {} + return dict( + catalog=self.catalog, + schema=self.schema, + server_hostname=self.server_hostname, + http_path=self.http_path, + access_token=self.access_token, + client_id=self.client_id, + client_secret=self.client_secret, + session_properties=self.session_properties or {}, + connection_parameters=self.connection_parameters or {}, + auth_type=self.auth_type, + ) @configspec From f2ff1814148e95c8bbd2944f5d6429c5d0d6b3a4 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 18:32:17 -0600 Subject: [PATCH 21/84] Update Databricks destination capabilities and SQL client. --- dlt/destinations/databricks/__init__.py | 2 +- dlt/destinations/databricks/sql_client.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/databricks/__init__.py index 0ca1437d80..8d7fed338c 100644 --- a/dlt/destinations/databricks/__init__.py +++ b/dlt/destinations/databricks/__init__.py @@ -31,7 +31,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.is_max_query_length_in_bytes = True caps.max_text_data_type_length = 16 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True - caps.supports_ddl_transactions = True + caps.supports_ddl_transactions = False caps.alter_add_multi_column = True return caps diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 8c6ab4e32f..cddf88d6ff 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -3,10 +3,10 @@ from databricks import sql as databricks_lib from databricks.sql.client import ( - Connection as DatabricksSQLConnection, - Cursor as DatabricksSQLCursor, + Connection as DatabricksSqlConnection, + Cursor as DatabricksSqlCursor, ) -from databricks.sql.exc import Error as DatabricksSQLError +from databricks.sql.exc import Error as DatabricksSqlError from dlt.common import logger from dlt.common.destination import DestinationCapabilitiesContext @@ -17,7 +17,7 @@ from dlt.destinations.databricks import capabilities class DatabricksCursorImpl(DBApiCursorImpl): - native_cursor: DatabricksSQLCursor # type: ignore[assignment] + native_cursor: DatabricksSqlCursor # type: ignore[assignment] def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: if chunk_size is None: @@ -25,17 +25,16 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: return super().df(chunk_size=chunk_size, **kwargs) -class DatabricksSqlClient(SqlClientBase[DatabricksSQLConnection], DBTransaction): - +class DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> None: super().__init__(credentials.catalog, dataset_name) - self._conn: DatabricksSQLConnection = None + self._conn: DatabricksSqlConnection = None self.credentials = credentials - def open_connection(self) -> DatabricksSQLConnection: + def open_connection(self) -> DatabricksSqlConnection: conn_params = 
self.credentials.to_connector_params() self._conn = databricks_lib.connect( **conn_params @@ -47,7 +46,7 @@ def close_connection(self) -> None: try: self._conn.close() self._conn = None - except DatabricksSQLError as exc: + except DatabricksSqlError as exc: logger.warning("Exception while closing connection: {}".format(exc)) @contextmanager @@ -65,7 +64,7 @@ def rollback_transaction(self) -> None: logger.warning("NotImplemented: rollback") @property - def native_connection(self) -> "DatabricksSQLConnection": + def native_connection(self) -> "DatabricksSqlConnection": return self._conn def has_dataset(self) -> bool: From 1e9992d9406616fb7fa742c545fa18fdbd40ad5d Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sun, 12 Nov 2023 18:32:41 -0600 Subject: [PATCH 22/84] Update file formats for Databricks staging --- dlt/destinations/databricks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/databricks/__init__.py index 8d7fed338c..2e474a92bd 100644 --- a/dlt/destinations/databricks/__init__.py +++ b/dlt/destinations/databricks/__init__.py @@ -20,8 +20,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] caps.escape_identifier = escape_databricks_identifier caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) From 60e9b8b2a4070fb8d5871c7a15b4c71c57a9a266 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 13 Nov 2023 08:39:33 -0600 Subject: [PATCH 23/84] Refactored DatabricksSqlClient.has_dataset() method to use correct query based on presence of catalog in db_params. --- dlt/destinations/databricks/sql_client.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index cddf88d6ff..ba10ca6ce3 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -68,14 +68,18 @@ def native_connection(self) -> "DatabricksSqlConnection": return self._conn def has_dataset(self) -> bool: - query = """ - SELECT 1 - FROM SYSTEM.INFORMATION_SCHEMA.SCHEMATA - WHERE """ db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) + + # Determine the base query based on the presence of a catalog in db_params if len(db_params) == 2: - query += " catalog_name = %s AND " - query += "schema_name = %s" + # Use catalog from db_params + query = "SELECT 1 FROM %s.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `schema_name` = %s" % ( + self.capabilities.escape_identifier(db_params[0]), db_params[1]) + else: + # Use system catalog + query = "SELECT 1 FROM `SYSTEM`.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `catalog_name` = %s AND `schema_name` = %s" + + # Execute the query and check if any rows are returned rows = self.execute_sql(query, *db_params) return len(rows) > 0 From f82bc533281241813efb1bbcb0f86a4749060481 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 13 Nov 2023 08:43:54 -0600 Subject: [PATCH 24/84] Refactor DatabricksLoadJob constructor to improve readability and add optional parameters. 
--- dlt/destinations/databricks/databricks.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py index 5d97c196d3..314d27e7a6 100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/databricks/databricks.py @@ -97,8 +97,13 @@ def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Opt class DatabricksLoadJob(LoadJob, FollowupJob): def __init__( - self, file_path: str, table_name: str, load_id: str, client: DatabricksSqlClient, - stage_name: Optional[str] = None, keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None + self, file_path: str, + table_name: str, + load_id: str, + client: DatabricksSqlClient, + stage_name: Optional[str] = None, + keep_staged_files: bool = True, + staging_credentials: Optional[CredentialsConfiguration] = None ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) @@ -127,12 +132,12 @@ def __init__( elif bucket_scheme in ["az", "abfss"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" - # Converts an az:/// to azure://.blob.core.windows.net// + # Converts an az:/// to abfss://.blob.core.windows.net// # as required by databricks _path = "/" + bucket_url.netloc + bucket_url.path bucket_path = urlunparse( bucket_url._replace( - scheme="azure", + scheme="abfss", netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net", path=_path ) From 640a04b02e554d6fa490c173d1a2e0e5e6b820dc Mon Sep 17 00:00:00 2001 From: Dave Date: Wed, 15 Nov 2023 11:13:14 +0100 Subject: [PATCH 25/84] a few small changes --- dlt/destinations/databricks/sql_client.py | 37 ++++++++++++----------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index ba10ca6ce3..b9f3e930ef 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -67,23 +67,22 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSqlConnection": return self._conn - def has_dataset(self) -> bool: - db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) - - # Determine the base query based on the presence of a catalog in db_params - if len(db_params) == 2: - # Use catalog from db_params - query = "SELECT 1 FROM %s.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `schema_name` = %s" % ( - self.capabilities.escape_identifier(db_params[0]), db_params[1]) - else: - # Use system catalog - query = "SELECT 1 FROM `SYSTEM`.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `catalog_name` = %s AND `schema_name` = %s" - - # Execute the query and check if any rows are returned - rows = self.execute_sql(query, *db_params) - return len(rows) > 0 - - def drop_tables(self, *tables: str) -> None: + # def has_dataset(self) -> bool: + # db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) + + # # Determine the base query based on the presence of a catalog in db_params + # if len(db_params) == 2: + # # Use catalog from db_params + # query = "SELECT 1 FROM %s.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `schema_name` = %s" + # else: + # # Use system catalog + # query = "SELECT 1 
FROM `SYSTEM`.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `catalog_name` = %s AND `schema_name` = %s" + + # # Execute the query and check if any rows are returned + # rows = self.execute_sql(query, *db_params) + # return len(rows) > 0 + + def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. with suppress(DatabaseUndefinedRelation): @@ -129,9 +128,11 @@ def _reset_connection(self) -> None: @staticmethod def _make_database_exception(ex: Exception) -> Exception: - if isinstance(ex, databricks_lib.OperationalError): + + if isinstance(ex, databricks_lib.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) + elif isinstance(ex, databricks_lib.OperationalError): return DatabaseTerminalException(ex) elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): return DatabaseTerminalException(ex) From 0d9824853c5492e067691c7876bc4a9f81a4ced6 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 20 Nov 2023 08:29:40 -0600 Subject: [PATCH 26/84] Add and comment execute fragments method --- dlt/destinations/databricks/sql_client.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index b9f3e930ef..4386a3e1c3 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -82,7 +82,7 @@ def native_connection(self) -> "DatabricksSqlConnection": # rows = self.execute_sql(query, *db_params) # return len(rows) > 0 - def drop_tables(self, *tables: str) -> None: + def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. with suppress(DatabaseUndefinedRelation): @@ -96,6 +96,10 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen f = curr.fetchall() return f + # def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + # """Executes several SQL fragments as efficiently as possible to prevent data copying. 
Default implementation just joins the strings and executes them together.""" + # return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] # type: ignore + @contextmanager @raise_database_error def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: @@ -128,7 +132,7 @@ def _reset_connection(self) -> None: @staticmethod def _make_database_exception(ex: Exception) -> Exception: - + if isinstance(ex, databricks_lib.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) From 69404cca202daf7aaa9ff05070d3379d243663a0 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 20 Nov 2023 08:30:24 -0600 Subject: [PATCH 27/84] Update staging file format preference --- dlt/destinations/databricks/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/databricks/__init__.py index 2e474a92bd..bb0767cff5 100644 --- a/dlt/destinations/databricks/__init__.py +++ b/dlt/destinations/databricks/__init__.py @@ -20,8 +20,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.escape_identifier = escape_databricks_identifier caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) @@ -32,6 +32,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 16 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = False + caps.supports_truncate_command = False + # caps.supports_transactions = False caps.alter_add_multi_column = True return caps From 635fd7e88b3a65bcd935e20511e4f5a27c101409 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 20 Nov 2023 08:30:56 -0600 Subject: [PATCH 28/84] Refactor execute_fragments method --- dlt/destinations/databricks/sql_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 4386a3e1c3..32a1d0acf1 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -97,7 +97,10 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen return f # def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: - # """Executes several SQL fragments as efficiently as possible to prevent data copying. Default implementation just joins the strings and executes them together.""" + # """ + # Executes several SQL fragments as efficiently as possible to prevent data copying. + # Default implementation just joins the strings and executes them together. 
+ # """ # return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] # type: ignore @contextmanager From 01604bbe5632dcbd7b9bc86121ab961742605877 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Mon, 20 Nov 2023 08:32:47 -0600 Subject: [PATCH 29/84] Fix DatabricksLoadJob constructor arguments and update COPY INTO statement --- dlt/destinations/databricks/databricks.py | 47 +++++++++-------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py index 314d27e7a6..7251692675 100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/databricks/databricks.py @@ -97,7 +97,8 @@ def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Opt class DatabricksLoadJob(LoadJob, FollowupJob): def __init__( - self, file_path: str, + self, + file_path: str, table_name: str, load_id: str, client: DatabricksSqlClient, @@ -117,40 +118,30 @@ def __init__( credentials_clause = "" files_clause = "" stage_file_path = "" + format_options = "" + copy_options = "" if bucket_path: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - # referencing an external s3/azure stage does not require explicit AWS credentials - if bucket_scheme in ["s3", "az", "abfss"] and stage_name: - from_clause = f"FROM '@{stage_name}'" - files_clause = f"LOCATION ('{bucket_url.path.lstrip('/')}')" + # referencing an external s3/azure stage does not require explicit credentials + if bucket_scheme in ["s3", "abfss"] and stage_name: + from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials - elif bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): - credentials_clause = f"""CREDENTIALS=(AWS_KEY_ID='{staging_credentials.aws_access_key_id}' AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}')""" + if bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + credentials_clause = f"""WITH(CREDENTIAL(AWS_KEY_ID='{staging_credentials.aws_access_key_id}', AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}'))""" from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme in ["az", "abfss"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + elif bucket_scheme == "abfss" and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): # Explicit azure credentials are needed to load from bucket without a named stage - credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" - # Converts an az:/// to abfss://.blob.core.windows.net// - # as required by databricks - _path = "/" + bucket_url.netloc + bucket_url.path - bucket_path = urlunparse( - bucket_url._replace( - scheme="abfss", - netloc=f"{staging_credentials.azure_storage_account_name}.blob.core.windows.net", - path=_path - ) - ) + credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" from_clause = f"FROM '{bucket_path}'" else: - # ensure that gcs bucket path starts with gcs://, this is a requirement of databricks + # ensure that gcs bucket path starts with gs://, this is a requirement of databricks bucket_path = bucket_path.replace("gs://", "gcs://") if not stage_name: # when loading from bucket stage must be given raise 
LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. See https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for instructions on setting up the `stage_name`") - from_clause = f"FROM @{stage_name}/" - files_clause = f"FILES = ('{urlparse(bucket_path).path.lstrip('/')}')" + from_clause = f"FROM ('{bucket_path}')" else: # this means we have a local file if not stage_name: @@ -160,21 +151,19 @@ def __init__( from_clause = f"FROM {stage_file_path}" # decide on source format, stage_file_path will either be a local file or a bucket path - source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" + source_format = "JSON" if file_name.endswith("parquet"): - source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE)" + source_format = "PARQUET" with client.begin_transaction(): - # PUT and COPY in one tx if local file, otherwise only copy - if not bucket_path: - client.execute_sql(f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE, AUTO_COMPRESS = FALSE') client.execute_sql( f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} - FILE_FORMAT = {source_format} - MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' + FILEFORMAT = {source_format} + {format_options} + {copy_options} """ ) if stage_file_path and not keep_staged_files: From d9fbcdadb0293b85fc1d2ccc833452b2a70b293a Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Thu, 14 Dec 2023 09:23:08 -0600 Subject: [PATCH 30/84] Update Databricks destination capabilities --- dlt/destinations/databricks/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/databricks/__init__.py index bb0767cff5..c645f4cf69 100644 --- a/dlt/destinations/databricks/__init__.py +++ b/dlt/destinations/databricks/__init__.py @@ -18,10 +18,10 @@ def _configure(config: DatabricksClientConfiguration = config.value) -> Databric def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "jsonl" - caps.supported_loader_file_formats = ["jsonl", "parquet"] - caps.preferred_staging_file_format = "jsonl" - caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.preferred_loader_file_format = "parquet" + caps.supported_loader_file_formats = ["parquet", "jsonl"] + caps.preferred_staging_file_format = "parquet" + caps.supported_staging_file_formats = ["parquet", "jsonl"] caps.escape_identifier = escape_databricks_identifier caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) @@ -32,7 +32,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 16 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = False - caps.supports_truncate_command = False + caps.supports_truncate_command = True # caps.supports_transactions = False caps.alter_add_multi_column = True return caps From 749aa11f1abb0b42db1377cce824c26ef758e5c1 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Thu, 14 Dec 2023 09:24:50 -0600 Subject: [PATCH 31/84] Update Databricks destination code --- dlt/destinations/databricks/databricks.py | 46 ++++++++++++++++------- dlt/destinations/databricks/sql_client.py | 27 +++---------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py index 7251692675..dfc6f1914f 
100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/databricks/databricks.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any +from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext @@ -7,7 +7,7 @@ from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables from dlt.destinations.job_client_impl import SqlJobClientWithStaging @@ -119,21 +119,31 @@ def __init__( files_clause = "" stage_file_path = "" format_options = "" - copy_options = "" + copy_options = "COPY_OPTIONS ('mergeSchema'='true')" if bucket_path: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme # referencing an external s3/azure stage does not require explicit credentials - if bucket_scheme in ["s3", "abfss"] and stage_name: + if bucket_scheme in ["s3", "az"] and stage_name: from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials if bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): credentials_clause = f"""WITH(CREDENTIAL(AWS_KEY_ID='{staging_credentials.aws_access_key_id}', AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}'))""" from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme == "abfss" and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + elif bucket_scheme == "az" and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" + # Converts an az:/// to abfss://@.dfs.core.windows.net/ + # as required by snowflake + _path = bucket_url.path + bucket_path = urlunparse( + bucket_url._replace( + scheme="abfss", + netloc=f"{bucket_url.netloc}@{staging_credentials.azure_storage_account_name}.dfs.core.windows.net", + path=_path + ) + ) from_clause = f"FROM '{bucket_path}'" else: # ensure that gcs bucket path starts with gs://, this is a requirement of databricks @@ -142,13 +152,13 @@ def __init__( # when loading from bucket stage must be given raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. 
See https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for instructions on setting up the `stage_name`") from_clause = f"FROM ('{bucket_path}')" - else: - # this means we have a local file - if not stage_name: - # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - stage_name = client.make_qualified_table_name('%'+table_name) - stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - from_clause = f"FROM {stage_file_path}" + # else: + # # this means we have a local file + # if not stage_name: + # # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" + # stage_name = client.make_qualified_table_name('%'+table_name) + # stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' + # from_clause = f"FROM {stage_file_path}" # decide on source format, stage_file_path will either be a local file or a bucket path source_format = "JSON" @@ -166,8 +176,8 @@ def __init__( {copy_options} """ ) - if stage_file_path and not keep_staged_files: - client.execute_sql(f'REMOVE {stage_file_path}') + # if stage_file_path and not keep_staged_files: + # client.execute_sql(f'REMOVE {stage_file_path}') def state(self) -> TLoadJobState: @@ -240,6 +250,14 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc return sql + def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables: + sql_scripts, schema_update = self._build_schema_update_sql(only_tables) + # stay within max query size when doing DDL. some db backends use bytes not characters so decrease limit by half + # assuming that most of the characters in DDL encode into single bytes + self.sql_client.execute_fragments(sql_scripts) + self._update_schema_in_storage(self.schema) + return schema_update + def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/databricks/sql_client.py index 32a1d0acf1..876fda1a98 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/databricks/sql_client.py @@ -67,21 +67,6 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSqlConnection": return self._conn - # def has_dataset(self) -> bool: - # db_params = self.fully_qualified_dataset_name(escape=False).split(".", 2) - - # # Determine the base query based on the presence of a catalog in db_params - # if len(db_params) == 2: - # # Use catalog from db_params - # query = "SELECT 1 FROM %s.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `schema_name` = %s" - # else: - # # Use system catalog - # query = "SELECT 1 FROM `SYSTEM`.`INFORMATION_SCHEMA`.`SCHEMATA` WHERE `catalog_name` = %s AND `schema_name` = %s" - - # # Execute the query and check if any rows are returned - # rows = self.execute_sql(query, *db_params) - # return len(rows) > 0 - def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. @@ -96,12 +81,12 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen f = curr.fetchall() return f - # def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: - # """ - # Executes several SQL fragments as efficiently as possible to prevent data copying. 
- # Default implementation just joins the strings and executes them together. - # """ - # return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] # type: ignore + def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + """ + Executes several SQL fragments as efficiently as possible to prevent data copying. + Default implementation just joins the strings and executes them together. + """ + return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] # type: ignore @contextmanager @raise_database_error From e0fdf3f3a4599b44ac669286c991882e4bbc78f2 Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sat, 6 Jan 2024 12:48:00 -0600 Subject: [PATCH 32/84] Fix SQL execution in SqlLoadJob --- dlt/destinations/job_client_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index ac68cfea8a..7fe0a4b6c7 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -84,7 +84,8 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: # with sql_client.begin_transaction(): sql_client.execute_sql(sql) else: - sql_client.execute_sql(sql) + # sql_client.execute_sql(sql) + sql_client.execute_fragments(sql.split(";")) def state(self) -> TLoadJobState: # this job is always done From 8955cd9688efa5917591cba713dfec6ac9f87ccb Mon Sep 17 00:00:00 2001 From: Evan Phillips Date: Sat, 6 Jan 2024 12:48:08 -0600 Subject: [PATCH 33/84] Add SqlMergeJob to DatabricksLoadJob --- dlt/destinations/databricks/databricks.py | 47 +++++++++++++++-------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/databricks/databricks.py index dfc6f1914f..5ba76dfee5 100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/databricks/databricks.py @@ -17,7 +17,7 @@ from dlt.destinations.databricks import capabilities from dlt.destinations.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -117,7 +117,7 @@ def __init__( from_clause = "" credentials_clause = "" files_clause = "" - stage_file_path = "" + # stage_file_path = "" format_options = "" copy_options = "COPY_OPTIONS ('mergeSchema'='true')" @@ -125,13 +125,13 @@ def __init__( bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme # referencing an external s3/azure stage does not require explicit credentials - if bucket_scheme in ["s3", "az"] and stage_name: + if bucket_scheme in ["s3", "az", "abfs", "gc", "gcs"] and stage_name: from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials if bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): credentials_clause = f"""WITH(CREDENTIAL(AWS_KEY_ID='{staging_credentials.aws_access_key_id}', AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}'))""" from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme == "az" and staging_credentials and isinstance(staging_credentials, 
AzureCredentialsWithoutDefaults): + elif bucket_scheme in ["az", "abfs"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" # Converts an az:/// to abfss://@.dfs.core.windows.net/ @@ -152,6 +152,7 @@ def __init__( # when loading from bucket stage must be given raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. See https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for instructions on setting up the `stage_name`") from_clause = f"FROM ('{bucket_path}')" + # Databricks does not support loading from local files # else: # # this means we have a local file # if not stage_name: @@ -165,19 +166,19 @@ def __init__( if file_name.endswith("parquet"): source_format = "PARQUET" - with client.begin_transaction(): - client.execute_sql( - f"""COPY INTO {qualified_table_name} - {from_clause} - {files_clause} - {credentials_clause} - FILEFORMAT = {source_format} - {format_options} - {copy_options} - """ - ) - # if stage_file_path and not keep_staged_files: - # client.execute_sql(f'REMOVE {stage_file_path}') + client.execute_sql( + f"""COPY INTO {qualified_table_name} + {from_clause} + {files_clause} + {credentials_clause} + FILEFORMAT = {source_format} + {format_options} + {copy_options} + """ + ) + # Databricks does not support deleting staged files via sql + # if stage_file_path and not keep_staged_files: + # client.execute_sql(f'REMOVE {stage_file_path}') def state(self) -> TLoadJobState: @@ -202,6 +203,12 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient return sql +class DatabricksMergeJob(SqlMergeJob): + @classmethod + def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: + return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" + + class DatabricksClient(SqlJobClientWithStaging): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -233,6 +240,12 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") + def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) + + def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] From 38357a8d658177750ab430fb9154dba2314b2280 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 9 Jan 2024 13:41:46 -0500 Subject: [PATCH 34/84] Move databricks to new destination layout --- dlt/common/data_writers/escape.py | 1 + dlt/destinations/__init__.py | 2 + .../{ => impl}/databricks/__init__.py | 22 +++- .../{ => impl}/databricks/configuration.py | 24 ++-- .../{ => impl}/databricks/databricks.py | 114 ++++++++++++------ dlt/destinations/impl/databricks/factory.py | 54 +++++++++ .../{ => impl}/databricks/sql_client.py | 40 ++++-- 7 files changed, 190 insertions(+), 67 
deletions(-) rename dlt/destinations/{ => impl}/databricks/__init__.py (80%) rename dlt/destinations/{ => impl}/databricks/configuration.py (88%) rename dlt/destinations/{ => impl}/databricks/databricks.py (77%) create mode 100644 dlt/destinations/impl/databricks/factory.py rename dlt/destinations/{ => impl}/databricks/sql_client.py (83%) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index 707dfb0e1f..b454c476dc 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -125,4 +125,5 @@ def escape_snowflake_identifier(v: str) -> str: # See also https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers return escape_postgres_identifier(v.upper()) + escape_databricks_identifier = escape_bigquery_identifier diff --git a/dlt/destinations/__init__.py b/dlt/destinations/__init__.py index 980c4ce7f2..801d1d823a 100644 --- a/dlt/destinations/__init__.py +++ b/dlt/destinations/__init__.py @@ -10,6 +10,7 @@ from dlt.destinations.impl.qdrant.factory import qdrant from dlt.destinations.impl.motherduck.factory import motherduck from dlt.destinations.impl.weaviate.factory import weaviate +from dlt.destinations.impl.databricks.factory import databricks __all__ = [ @@ -25,4 +26,5 @@ "qdrant", "motherduck", "weaviate", + "databricks", ] diff --git a/dlt/destinations/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py similarity index 80% rename from dlt/destinations/databricks/__init__.py rename to dlt/destinations/impl/databricks/__init__.py index c645f4cf69..09f4e736d3 100644 --- a/dlt/destinations/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -8,11 +8,19 @@ from dlt.common.data_writers.escape import escape_databricks_identifier from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE -from dlt.destinations.databricks.configuration import DatabricksClientConfiguration - - -@with_config(spec=DatabricksClientConfiguration, sections=(known_sections.DESTINATION, "databricks",)) -def _configure(config: DatabricksClientConfiguration = config.value) -> DatabricksClientConfiguration: +from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration + + +@with_config( + spec=DatabricksClientConfiguration, + sections=( + known_sections.DESTINATION, + "databricks", + ), +) +def _configure( + config: DatabricksClientConfiguration = config.value, +) -> DatabricksClientConfiguration: return config @@ -38,7 +46,9 @@ def capabilities() -> DestinationCapabilitiesContext: return caps -def client(schema: Schema, initial_config: DestinationClientConfiguration = config.value) -> JobClientBase: +def client( + schema: Schema, initial_config: DestinationClientConfiguration = config.value +) -> JobClientBase: # import client when creating instance so capabilities and config specs can be accessed without dependencies installed from dlt.destinations.databricks.databricks import DatabricksClient diff --git a/dlt/destinations/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py similarity index 88% rename from dlt/destinations/databricks/configuration.py rename to dlt/destinations/impl/databricks/configuration.py index 7184c175c3..e71783df66 100644 --- a/dlt/destinations/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -8,6 +8,7 @@ CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" + @configspec class DatabricksCredentials(CredentialsConfiguration): catalog: 
Optional[str] = None # type: ignore[assignment] @@ -27,7 +28,12 @@ class DatabricksCredentials(CredentialsConfiguration): _credentials_provider: Optional[Dict[str, Any]] = None - __config_gen_annotations__: ClassVar[List[str]] = ["server_hostname", "http_path", "catalog", "schema"] + __config_gen_annotations__: ClassVar[List[str]] = [ + "server_hostname", + "http_path", + "catalog", + "schema", + ] def __post_init__(self) -> None: if "." in (self.schema or ""): @@ -51,9 +57,7 @@ def __post_init__(self) -> None: if self.catalog is not None: catalog = self.catalog.strip() if not catalog: - raise ConfigurationValueError( - f"Invalid catalog name : `{self.catalog}`." - ) + raise ConfigurationValueError(f"Invalid catalog name : `{self.catalog}`.") self.catalog = catalog else: self.catalog = "hive_metastore" @@ -71,9 +75,7 @@ def __post_init__(self) -> None: "_user_agent_entry", ): if key in connection_parameters: - raise ConfigurationValueError( - f"The connection parameter `{key}` is reserved." - ) + raise ConfigurationValueError(f"The connection parameter `{key}` is reserved.") if "http_headers" in connection_parameters: http_headers = connection_parameters["http_headers"] if not isinstance(http_headers, dict) or any( @@ -96,15 +98,13 @@ def validate_creds(self) -> None: ) if not self.token and self.auth_type != "oauth": raise ConfigurationValueError( - ("The config `auth_type: oauth` is required when not using access token") + "The config `auth_type: oauth` is required when not using access token" ) if not self.client_id and self.client_secret: raise ConfigurationValueError( - ( - "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present" - ) + "The config 'client_id' is required to connect " + "to Databricks when 'client_secret' is present" ) def to_connector_params(self) -> Dict[str, Any]: diff --git a/dlt/destinations/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py similarity index 77% rename from dlt/destinations/databricks/databricks.py rename to dlt/destinations/impl/databricks/databricks.py index 5ba76dfee5..d0933add6e 100644 --- a/dlt/destinations/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -2,8 +2,18 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration -from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentials, AzureCredentialsWithoutDefaults +from dlt.common.destination.reference import ( + FollowupJob, + NewLoadJob, + TLoadJobState, + LoadJob, + CredentialsConfiguration, +) +from dlt.common.configuration.specs import ( + AwsCredentialsWithoutDefaults, + AzureCredentials, + AzureCredentialsWithoutDefaults, +) from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns @@ -14,9 +24,9 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException -from dlt.destinations.databricks import capabilities -from dlt.destinations.databricks.configuration import DatabricksClientConfiguration -from dlt.destinations.databricks.sql_client import DatabricksSqlClient +from dlt.destinations.impl.databricks import capabilities +from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration +from 
dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase @@ -61,7 +71,7 @@ class DatabricksTypeMapper(TypeMapper): "INTERVAL": "interval", "MAP": "map", "STRUCT": "struct", - "ARRAY": "complex" + "ARRAY": "complex", } sct_to_dbt = { @@ -81,17 +91,18 @@ class DatabricksTypeMapper(TypeMapper): "interval": "INTERVAL", "map": "MAP", "struct": "STRUCT", - "complex": "ARRAY" + "complex": "ARRAY", } - - def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None) -> TColumnType: + def from_db_type( + self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None + ) -> TColumnType: if db_type == "NUMBER": if precision == self.BIGINT_PRECISION and scale == 0: - return dict(data_type='bigint') + return dict(data_type="bigint") elif (precision, scale) == self.capabilities.wei_precision: - return dict(data_type='wei') - return dict(data_type='decimal', precision=precision, scale=scale) + return dict(data_type="wei") + return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) @@ -104,7 +115,7 @@ def __init__( client: DatabricksSqlClient, stage_name: Optional[str] = None, keep_staged_files: bool = True, - staging_credentials: Optional[CredentialsConfiguration] = None + staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) @@ -112,8 +123,14 @@ def __init__( qualified_table_name = client.make_qualified_table_name(table_name) # extract and prepare some vars - bucket_path = NewReferenceJob.resolve_reference(file_path) if NewReferenceJob.is_reference_job(file_path) else "" - file_name = FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + bucket_path = ( + NewReferenceJob.resolve_reference(file_path) + if NewReferenceJob.is_reference_job(file_path) + else "" + ) + file_name = ( + FileStorage.get_file_name_from_file_path(bucket_path) if bucket_path else file_name + ) from_clause = "" credentials_clause = "" files_clause = "" @@ -128,10 +145,18 @@ def __init__( if bucket_scheme in ["s3", "az", "abfs", "gc", "gcs"] and stage_name: from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials - if bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults): + if ( + bucket_scheme == "s3" + and staging_credentials + and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + ): credentials_clause = f"""WITH(CREDENTIAL(AWS_KEY_ID='{staging_credentials.aws_access_key_id}', AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}'))""" from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme in ["az", "abfs"] and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults): + elif ( + bucket_scheme in ["az", "abfs"] + and staging_credentials + and isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + ): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" # Converts an az:/// to abfss://@.dfs.core.windows.net/ @@ -141,7 +166,7 @@ def __init__( bucket_url._replace( 
scheme="abfss", netloc=f"{bucket_url.netloc}@{staging_credentials.azure_storage_account_name}.dfs.core.windows.net", - path=_path + path=_path, ) ) from_clause = f"FROM '{bucket_path}'" @@ -150,7 +175,12 @@ def __init__( bucket_path = bucket_path.replace("gs://", "gcs://") if not stage_name: # when loading from bucket stage must be given - raise LoadJobTerminalException(file_path, f"Cannot load from bucket path {bucket_path} without a stage name. See https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for instructions on setting up the `stage_name`") + raise LoadJobTerminalException( + file_path, + f"Cannot load from bucket path {bucket_path} without a stage name. See" + " https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for" + " instructions on setting up the `stage_name`", + ) from_clause = f"FROM ('{bucket_path}')" # Databricks does not support loading from local files # else: @@ -166,31 +196,30 @@ def __init__( if file_name.endswith("parquet"): source_format = "PARQUET" - client.execute_sql( - f"""COPY INTO {qualified_table_name} + client.execute_sql(f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILEFORMAT = {source_format} {format_options} {copy_options} - """ - ) + """) # Databricks does not support deleting staged files via sql # if stage_file_path and not keep_staged_files: # client.execute_sql(f'REMOVE {stage_file_path}') - def state(self) -> TLoadJobState: return "completed" def exception(self) -> str: raise NotImplementedError() -class DatabricksStagingCopyJob(SqlStagingCopyJob): +class DatabricksStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql( + cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + ) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -213,13 +242,10 @@ class DatabricksClient(SqlJobClientWithStaging): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: - sql_client = DatabricksSqlClient( - config.normalize_dataset_name(schema), - config.credentials - ) + sql_client = DatabricksSqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config - self.sql_client: DatabricksSqlClient = sql_client # type: ignore + self.sql_client: DatabricksSqlClient = sql_client self.type_mapper = DatabricksTypeMapper(self.capabilities) def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: @@ -228,12 +254,14 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = DatabricksLoadJob( file_path, - table['name'], + table["name"], load_id, self.sql_client, stage_name=self.config.stage_name, keep_staged_files=self.config.keep_staged_files, - staging_credentials=self.config.staging_config.credentials if self.config.staging_config else None + staging_credentials=( + self.config.staging_config.credentials if self.config.staging_config else None + ), ) return job @@ -253,10 +281,18 @@ def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) - def 
_get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]: + def _get_table_update_sql( + self, + table_name: str, + new_columns: Sequence[TColumnSchema], + generate_alter: bool, + separate_alters: bool = False, + ) -> List[str]: sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) - cluster_list = [self.capabilities.escape_identifier(c['name']) for c in new_columns if c.get('cluster')] + cluster_list = [ + self.capabilities.escape_identifier(c["name"]) for c in new_columns if c.get("cluster") + ] if cluster_list: sql[0] = sql[0] + "\nCLUSTER BY (" + ",".join(cluster_list) + ")" @@ -271,12 +307,16 @@ def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTable self._update_schema_in_storage(self.schema) return schema_update - def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: + def _from_db_type( + self, bq_t: str, precision: Optional[int], scale: Optional[int] + ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + return ( + f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + ) def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: exists, table = super().get_storage_table(table_name) diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py new file mode 100644 index 0000000000..a22e4514b7 --- /dev/null +++ b/dlt/destinations/impl/databricks/factory.py @@ -0,0 +1,54 @@ +import typing as t + +from dlt.common.destination import Destination, DestinationCapabilitiesContext + +from dlt.destinations.impl.databricks.configuration import ( + DatabricksCredentials, + DatabricksClientConfiguration, +) +from dlt.destinations.impl.databricks import capabilities + +if t.TYPE_CHECKING: + from dlt.destinations.impl.databricks.databricks import DatabricksClient + + +class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]): + spec = DatabricksClientConfiguration + + def capabilities(self) -> DestinationCapabilitiesContext: + return capabilities() + + @property + def client_class(self) -> t.Type["DatabricksClient"]: + from dlt.destinations.impl.databricks.databricks import DatabricksClient + + return DatabricksClient + + def __init__( + self, + credentials: t.Union[DatabricksCredentials, t.Dict[str, t.Any], str] = None, + stage_name: t.Optional[str] = None, + keep_staged_files: bool = False, + destination_name: t.Optional[str] = None, + environment: t.Optional[str] = None, + **kwargs: t.Any, + ) -> None: + """Configure the Databricks destination to use in a pipeline. + + All arguments provided here supersede other configuration sources such as environment variables and dlt config files. + + Args: + credentials: Credentials to connect to the databricks database. Can be an instance of `DatabricksCredentials` or + a connection string in the format `databricks://user:password@host:port/database` + stage_name: Name of the stage to use for staging files. If not provided, the default stage will be used. + keep_staged_files: Should staged files be kept after loading. If False, staged files will be deleted after loading. 
+ **kwargs: Additional arguments passed to the destination config + """ + super().__init__( + credentials=credentials, + stage_name=stage_name, + keep_staged_files=keep_staged_files, + destination_name=destination_name, + environment=environment, + **kwargs, + ) diff --git a/dlt/destinations/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py similarity index 83% rename from dlt/destinations/databricks/sql_client.py rename to dlt/destinations/impl/databricks/sql_client.py index 876fda1a98..6671cd4e72 100644 --- a/dlt/destinations/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -10,11 +10,21 @@ from dlt.common import logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation -from dlt.destinations.sql_client import DBApiCursorImpl, SqlClientBase, raise_database_error, raise_open_connection_error +from dlt.destinations.exceptions import ( + DatabaseTerminalException, + DatabaseTransientException, + DatabaseUndefinedRelation, +) +from dlt.destinations.sql_client import ( + DBApiCursorImpl, + SqlClientBase, + raise_database_error, + raise_open_connection_error, +) from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame -from dlt.destinations.databricks.configuration import DatabricksCredentials -from dlt.destinations.databricks import capabilities +from dlt.destinations.impl.databricks.configuration import DatabricksCredentials +from dlt.destinations.impl.databricks import capabilities + class DatabricksCursorImpl(DBApiCursorImpl): native_cursor: DatabricksSqlCursor # type: ignore[assignment] @@ -36,9 +46,7 @@ def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> Non def open_connection(self) -> DatabricksSqlConnection: conn_params = self.credentials.to_connector_params() - self._conn = databricks_lib.connect( - **conn_params - ) + self._conn = databricks_lib.connect(**conn_params) return self._conn @raise_open_connection_error @@ -51,7 +59,10 @@ def close_connection(self) -> None: @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: - logger.warning("NotImplemented: Databricks does not support transactions. Each SQL statement is auto-committed separately.") + logger.warning( + "NotImplemented: Databricks does not support transactions. Each SQL statement is" + " auto-committed separately." + ) yield self @raise_database_error @@ -73,7 +84,9 @@ def drop_tables(self, *tables: str) -> None: with suppress(DatabaseUndefinedRelation): super().drop_tables(*tables) - def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_sql( + self, sql: AnyStr, *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: with self.execute_query(sql, *args, **kwargs) as curr: if curr.description is None: return None @@ -81,7 +94,9 @@ def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequen f = curr.fetchall() return f - def execute_fragments(self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]: + def execute_fragments( + self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: """ Executes several SQL fragments as efficiently as possible to prevent data copying. Default implementation just joins the strings and executes them together. 
@@ -120,7 +135,6 @@ def _reset_connection(self) -> None: @staticmethod def _make_database_exception(ex: Exception) -> Exception: - if isinstance(ex, databricks_lib.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) @@ -134,7 +148,9 @@ def _make_database_exception(ex: Exception) -> Exception: return ex @staticmethod - def _maybe_make_terminal_exception_from_data_error(databricks_ex: databricks_lib.DatabaseError) -> Optional[Exception]: + def _maybe_make_terminal_exception_from_data_error( + databricks_ex: databricks_lib.DatabaseError, + ) -> Optional[Exception]: return None @staticmethod From 058489c3075bcd3d750de6a5faf8632a15412e4a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 9 Jan 2024 13:51:19 -0500 Subject: [PATCH 35/84] Databricks dependency --- poetry.lock | 127 ++++++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 + 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index c5da40c604..777b660e1a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "about-time" @@ -1886,6 +1886,35 @@ nr-date = ">=2.0.0,<3.0.0" typeapi = ">=2.0.1,<3.0.0" typing-extensions = ">=3.10.0" +[[package]] +name = "databricks-sql-connector" +version = "3.0.1" +description = "Databricks SQL Connector for Python" +optional = true +python-versions = ">=3.8.0,<4.0.0" +files = [ + {file = "databricks_sql_connector-3.0.1-py3-none-any.whl", hash = "sha256:3824237732f4363f55e3a1b8dd90ac98b8008c66e869377c8a213581d13dcee2"}, + {file = "databricks_sql_connector-3.0.1.tar.gz", hash = "sha256:915648d5d43e41622d65446bf60c07b2a0d33f9e5ad03478712205703927fdb8"}, +] + +[package.dependencies] +lz4 = ">=4.0.2,<5.0.0" +numpy = [ + {version = ">=1.16.6", markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, + {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, +] +oauthlib = ">=3.1.0,<4.0.0" +openpyxl = ">=3.0.10,<4.0.0" +pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} +pyarrow = ">=14.0.1,<15.0.0" +requests = ">=2.18.1,<3.0.0" +thrift = ">=0.16.0,<0.17.0" +urllib3 = ">=1.0" + +[package.extras] +alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] +sqlalchemy = ["sqlalchemy (>=2.0.21)"] + [[package]] name = "dbt-athena-community" version = "1.5.2" @@ -2298,6 +2327,17 @@ files = [ blessed = ">=1.17.7" prefixed = ">=0.3.2" +[[package]] +name = "et-xmlfile" +version = "1.1.0" +description = "An implementation of lxml.xmlfile for the standard library" +optional = true +python-versions = ">=3.6" +files = [ + {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, + {file = "et_xmlfile-1.1.0.tar.gz", hash = "sha256:8eb9e2bc2f8c97e37a2dc85a09ecdcdec9d8a396530a6d5a33b30b9a92da0c5c"}, +] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -4284,6 +4324,56 @@ html5 = ["html5lib"] htmlsoup = ["BeautifulSoup4"] source = ["Cython (>=0.29.35)"] +[[package]] +name = "lz4" +version = "4.3.3" +description = "LZ4 Bindings for Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, + {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", 
hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f76176492ff082657ada0d0f10c794b6da5800249ef1692b35cf49b1e93e8ef7"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1d18718f9d78182c6b60f568c9a9cec8a7204d7cb6fad4e511a2ef279e4cb05"}, + {file = "lz4-4.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cdc60e21ec70266947a48839b437d46025076eb4b12c76bd47f8e5eb8a75dcc"}, + {file = "lz4-4.3.3-cp310-cp310-win32.whl", hash = "sha256:c81703b12475da73a5d66618856d04b1307e43428a7e59d98cfe5a5d608a74c6"}, + {file = "lz4-4.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:43cf03059c0f941b772c8aeb42a0813d68d7081c009542301637e5782f8a33e2"}, + {file = "lz4-4.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:30e8c20b8857adef7be045c65f47ab1e2c4fabba86a9fa9a997d7674a31ea6b6"}, + {file = "lz4-4.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7b1839f795315e480fb87d9bc60b186a98e3e5d17203c6e757611ef7dcef61"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edfd858985c23523f4e5a7526ca6ee65ff930207a7ec8a8f57a01eae506aaee7"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e9c410b11a31dbdc94c05ac3c480cb4b222460faf9231f12538d0074e56c563"}, + {file = "lz4-4.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d2507ee9c99dbddd191c86f0e0c8b724c76d26b0602db9ea23232304382e1f21"}, + {file = "lz4-4.3.3-cp311-cp311-win32.whl", hash = "sha256:f180904f33bdd1e92967923a43c22899e303906d19b2cf8bb547db6653ea6e7d"}, + {file = "lz4-4.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:b14d948e6dce389f9a7afc666d60dd1e35fa2138a8ec5306d30cd2e30d36b40c"}, + {file = "lz4-4.3.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e36cd7b9d4d920d3bfc2369840da506fa68258f7bb176b8743189793c055e43d"}, + {file = "lz4-4.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31ea4be9d0059c00b2572d700bf2c1bc82f241f2c3282034a759c9a4d6ca4dc2"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33c9a6fd20767ccaf70649982f8f3eeb0884035c150c0b818ea660152cf3c809"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca8fccc15e3add173da91be8f34121578dc777711ffd98d399be35487c934bf"}, + {file = "lz4-4.3.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d84b479ddf39fe3ea05387f10b779155fc0990125f4fb35d636114e1c63a2e"}, + {file = "lz4-4.3.3-cp312-cp312-win32.whl", hash = "sha256:337cb94488a1b060ef1685187d6ad4ba8bc61d26d631d7ba909ee984ea736be1"}, + {file = "lz4-4.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:5d35533bf2cee56f38ced91f766cd0038b6abf46f438a80d50c52750088be93f"}, + {file = "lz4-4.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:363ab65bf31338eb364062a15f302fc0fab0a49426051429866d71c793c23394"}, + {file = "lz4-4.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0a136e44a16fc98b1abc404fbabf7f1fada2bdab6a7e970974fb81cf55b636d0"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abc197e4aca8b63f5ae200af03eb95fb4b5055a8f990079b5bdf042f568469dd"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:56f4fe9c6327adb97406f27a66420b22ce02d71a5c365c48d6b656b4aaeb7775"}, + {file = "lz4-4.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0e822cd7644995d9ba248cb4b67859701748a93e2ab7fc9bc18c599a52e4604"}, + {file = "lz4-4.3.3-cp38-cp38-win32.whl", hash = "sha256:24b3206de56b7a537eda3a8123c644a2b7bf111f0af53bc14bed90ce5562d1aa"}, + {file = "lz4-4.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:b47839b53956e2737229d70714f1d75f33e8ac26e52c267f0197b3189ca6de24"}, + {file = "lz4-4.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6756212507405f270b66b3ff7f564618de0606395c0fe10a7ae2ffcbbe0b1fba"}, + {file = "lz4-4.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ee9ff50557a942d187ec85462bb0960207e7ec5b19b3b48949263993771c6205"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b901c7784caac9a1ded4555258207d9e9697e746cc8532129f150ffe1f6ba0d"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d9ec061b9eca86e4dcc003d93334b95d53909afd5a32c6e4f222157b50c071"}, + {file = "lz4-4.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4c7bf687303ca47d69f9f0133274958fd672efaa33fb5bcde467862d6c621f0"}, + {file = "lz4-4.3.3-cp39-cp39-win32.whl", hash = "sha256:054b4631a355606e99a42396f5db4d22046a3397ffc3269a348ec41eaebd69d2"}, + {file = "lz4-4.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:eac9af361e0d98335a02ff12fb56caeb7ea1196cf1a49dbf6f17828a131da807"}, + {file = "lz4-4.3.3.tar.gz", hash = "sha256:01fe674ef2889dbb9899d8a67361e0c4a2c833af5aeb37dd505727cf5d2a131e"}, +] + +[package.extras] +docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"] +flake8 = ["flake8"] +tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"] + [[package]] name = "makefun" version = "1.15.1" @@ -5140,6 +5230,20 @@ packaging = "*" protobuf = "*" sympy = "*" +[[package]] +name = "openpyxl" +version = "3.1.2" +description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +optional = true +python-versions = ">=3.6" +files = [ + {file = "openpyxl-3.1.2-py2.py3-none-any.whl", hash = "sha256:f91456ead12ab3c6c2e9491cf33ba6d08357d802192379bb482f1033ade496f5"}, + {file = "openpyxl-3.1.2.tar.gz", hash = "sha256:a6f5977418eff3b2d5500d54d9db50c8277a368436f4e4f8ddb1be3422870184"}, +] + +[package.dependencies] +et-xmlfile = "*" + [[package]] name = "opentelemetry-api" version = "1.15.0" @@ -7713,6 +7817,24 @@ files = [ {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, ] +[[package]] +name = "thrift" +version = "0.16.0" +description = "Python bindings for the Apache Thrift RPC system" +optional = true +python-versions = "*" +files = [ + {file = "thrift-0.16.0.tar.gz", hash = "sha256:2b5b6488fcded21f9d312aa23c9ff6a0195d0f6ae26ddbd5ad9e3e25dfc14408"}, +] + +[package.dependencies] +six = ">=1.7.2" + +[package.extras] +all = ["tornado (>=4.0)", "twisted"] +tornado = ["tornado (>=4.0)"] +twisted = ["twisted"] + [[package]] name = "tokenizers" version = "0.13.3" @@ -8453,6 +8575,7 @@ athena = ["botocore", "pyarrow", "pyathena", "s3fs"] az = ["adlfs"] bigquery = ["gcsfs", "google-cloud-bigquery", "grpcio", "pyarrow"] cli = ["cron-descriptor", "pipdeptree"] +databricks = ["databricks-sql-connector"] dbt = ["dbt-athena-community", "dbt-bigquery", "dbt-core", "dbt-duckdb", "dbt-redshift", "dbt-snowflake"] duckdb = ["duckdb"] filesystem = 
["botocore", "s3fs"] @@ -8471,4 +8594,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "cf751b2e1e9c66efde0a11774b5204e3206a14fd04ba4c79b2d37e38db5367ad" +content-hash = "2305ecdd34e888ae84eed985057df674263e856047a6d59d73c86f26a93e0783" diff --git a/pyproject.toml b/pyproject.toml index 6436ec23a7..69d4438089 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ weaviate-client = {version = ">=3.22", optional = true} adlfs = {version = ">=2022.4.0", optional = true} pyodbc = {version = "^4.0.39", optional = true} qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} +databricks-sql-connector = {version = "^3.0.1", optional = true} [tool.poetry.extras] dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community"] @@ -97,6 +98,7 @@ athena = ["pyathena", "pyarrow", "s3fs", "botocore"] weaviate = ["weaviate-client"] mssql = ["pyodbc"] qdrant = ["qdrant-client"] +databricks = ["databricks-sql-connector"] [tool.poetry.scripts] dlt = "dlt.cli._dlt:_main" From ed659f8aabf9a6f6dc9eb3ff5cc21ce440419e34 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 9 Jan 2024 14:00:44 -0500 Subject: [PATCH 36/84] Add databricks destination_type --- dlt/destinations/impl/databricks/configuration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index e71783df66..f8d43f4ef1 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -124,7 +124,7 @@ def to_connector_params(self) -> Dict[str, Any]: @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_name: Final[str] = "databricks" # type: ignore[misc] + destination_type: Final[str] = "databricks" # type: ignore[misc] credentials: DatabricksCredentials stage_name: Optional[str] = None From 36d1718e01bc2e86f4d2d795c5944c4aba6f8713 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 10 Jan 2024 12:14:18 -0500 Subject: [PATCH 37/84] Testing databricks with s3 staging --- .../configuration/specs/aws_credentials.py | 24 ++++ dlt/common/time.py | 37 ++++++ .../impl/databricks/databricks.py | 116 ++++++++---------- .../impl/databricks/sql_client.py | 43 ++++--- mypy.ini | 3 + tests/load/utils.py | 28 +++++ tests/utils.py | 1 + 7 files changed, 171 insertions(+), 81 deletions(-) diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index f6df1d8cce..ee7360e2cb 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -38,6 +38,13 @@ def to_native_representation(self) -> Dict[str, Optional[str]]: """Return a dict that can be passed as kwargs to boto3 session""" return dict(self) + def to_session_credentials(self) -> Dict[str, str]: + return dict( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + ) + @configspec class AwsCredentials(AwsCredentialsWithoutDefaults, CredentialsWithDefault): @@ -47,6 +54,23 @@ def on_partial(self) -> None: if self._from_session(session) and not self.is_partial(): self.resolve() + def to_session_credentials(self) -> Dict[str, str]: + """Return configured or new aws session token""" + if self.aws_session_token and self.aws_access_key_id 
and self.aws_secret_access_key: + return dict( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + ) + sess = self._to_botocore_session() + client = sess.create_client("sts") + token = client.get_session_token() + return dict( + aws_access_key_id=token["Credentials"]["AccessKeyId"], + aws_secret_access_key=token["Credentials"]["SecretAccessKey"], + aws_session_token=token["Credentials"]["SessionToken"], + ) + def _to_botocore_session(self) -> Any: try: import botocore.session diff --git a/dlt/common/time.py b/dlt/common/time.py index ed390c28bf..4f4dd05ef0 100644 --- a/dlt/common/time.py +++ b/dlt/common/time.py @@ -138,6 +138,43 @@ def ensure_pendulum_time(value: Union[str, datetime.time]) -> pendulum.Time: raise TypeError(f"Cannot coerce {value} to a pendulum.Time object.") +def to_py_datetime(value: datetime.datetime) -> datetime.datetime: + """Convert a pendulum.DateTime to a py datetime object. + + Args: + value: The value to convert. Can be a pendulum.DateTime or datetime. + + Returns: + A py datetime object + """ + if isinstance(value, pendulum.DateTime): + return datetime.datetime( + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond, + value.tzinfo, + ) + return value + + +def to_py_date(value: datetime.date) -> datetime.date: + """Convert a pendulum.Date to a py date object. + + Args: + value: The value to convert. Can be a pendulum.Date or date. + + Returns: + A py date object + """ + if isinstance(value, pendulum.Date): + return datetime.date(value.year, value.month, value.day) + return value + + def _datetime_from_ts_or_iso( value: Union[int, float, str] ) -> Union[pendulum.DateTime, pendulum.Date, pendulum.Time]: diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index d0933add6e..0236b96f4c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -8,6 +8,7 @@ TLoadJobState, LoadJob, CredentialsConfiguration, + SupportsStagingDestination, ) from dlt.common.configuration.specs import ( AwsCredentialsWithoutDefaults, @@ -17,7 +18,7 @@ from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables +from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging @@ -35,7 +36,7 @@ class DatabricksTypeMapper(TypeMapper): sct_to_unbound_dbt = { - "complex": "ARRAY", # Databricks supports complex types like ARRAY + "complex": "STRING", # Databricks supports complex types like ARRAY "text": "STRING", "double": "DOUBLE", "bool": "BOOLEAN", @@ -44,14 +45,7 @@ class DatabricksTypeMapper(TypeMapper): "bigint": "BIGINT", "binary": "BINARY", "decimal": "DECIMAL", # DECIMAL(p,s) format - "float": "FLOAT", - "int": "INT", - "smallint": "SMALLINT", - "tinyint": "TINYINT", - "void": "VOID", - "interval": "INTERVAL", - "map": "MAP", - "struct": "STRUCT", + "time": "STRING", } dbt_to_sct = { @@ -61,48 +55,37 @@ class DatabricksTypeMapper(TypeMapper): "DATE": "date", "TIMESTAMP": "timestamp", "BIGINT": "bigint", + "INT": "bigint", + "SMALLINT": "bigint", + "TINYINT": "bigint", "BINARY": "binary", "DECIMAL": "decimal", - "FLOAT": "float", - "INT": "int", - 
"SMALLINT": "smallint", - "TINYINT": "tinyint", - "VOID": "void", - "INTERVAL": "interval", - "MAP": "map", - "STRUCT": "struct", - "ARRAY": "complex", } sct_to_dbt = { - "text": "STRING", - "double": "DOUBLE", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP", - "bigint": "BIGINT", - "binary": "BINARY", "decimal": "DECIMAL(%i,%i)", - "float": "FLOAT", - "int": "INT", - "smallint": "SMALLINT", - "tinyint": "TINYINT", - "void": "VOID", - "interval": "INTERVAL", - "map": "MAP", - "struct": "STRUCT", - "complex": "ARRAY", + "wei": "DECIMAL(%i,%i)", } + def to_db_integer_type( + self, precision: Optional[int], table_format: TTableFormat = None + ) -> str: + if precision is None: + return "BIGINT" + if precision <= 8: + return "TINYINT" + if precision <= 16: + return "SMALLINT" + if precision <= 32: + return "INT" + return "BIGINT" + def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None ) -> TColumnType: - if db_type == "NUMBER": - if precision == self.BIGINT_PRECISION and scale == 0: - return dict(data_type="bigint") - elif (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return dict(data_type="decimal", precision=precision, scale=scale) + if db_type == "DECIMAL": + if (precision, scale) == self.wei_precision(): + return dict(data_type="wei", precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) @@ -145,12 +128,18 @@ def __init__( if bucket_scheme in ["s3", "az", "abfs", "gc", "gcs"] and stage_name: from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials - if ( + elif ( bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) ): - credentials_clause = f"""WITH(CREDENTIAL(AWS_KEY_ID='{staging_credentials.aws_access_key_id}', AWS_SECRET_KEY='{staging_credentials.aws_secret_access_key}'))""" + s3_creds = staging_credentials.to_session_credentials() + credentials_clause = f"""WITH(CREDENTIAL( + AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', + AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', + AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' + )) + """ from_clause = f"FROM '{bucket_path}'" elif ( bucket_scheme in ["az", "abfs"] @@ -171,39 +160,30 @@ def __init__( ) from_clause = f"FROM '{bucket_path}'" else: - # ensure that gcs bucket path starts with gs://, this is a requirement of databricks - bucket_path = bucket_path.replace("gs://", "gcs://") - if not stage_name: - # when loading from bucket stage must be given - raise LoadJobTerminalException( - file_path, - f"Cannot load from bucket path {bucket_path} without a stage name. See" - " https://dlthub.com/docs/dlt-ecosystem/destinations/databricks for" - " instructions on setting up the `stage_name`", - ) - from_clause = f"FROM ('{bucket_path}')" - # Databricks does not support loading from local files - # else: - # # this means we have a local file - # if not stage_name: - # # Use implicit table stage by default: "SCHEMA_NAME"."%TABLE_NAME" - # stage_name = client.make_qualified_table_name('%'+table_name) - # stage_file_path = f'@{stage_name}/"{load_id}"/{file_name}' - # from_clause = f"FROM {stage_file_path}" + raise LoadJobTerminalException( + file_path, + f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and azure buckets are supported", + ) + else: + raise LoadJobTerminalException( + file_path, + "Cannot load from local file. 
Databricks does not support loading from local files. Configure staging with an s3 or azure storage bucket.", + ) # decide on source format, stage_file_path will either be a local file or a bucket path source_format = "JSON" if file_name.endswith("parquet"): source_format = "PARQUET" - client.execute_sql(f"""COPY INTO {qualified_table_name} + statement = f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILEFORMAT = {source_format} {format_options} {copy_options} - """) + """ + client.execute_sql(statement) # Databricks does not support deleting staged files via sql # if stage_file_path and not keep_staged_files: # client.execute_sql(f'REMOVE {stage_file_path}') @@ -238,7 +218,7 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" -class DatabricksClient(SqlJobClientWithStaging): +class DatabricksClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: @@ -274,7 +254,9 @@ def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) - def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: + def _make_add_column_sql( + self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] @@ -312,7 +294,7 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 6671cd4e72..4ba808884f 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -8,6 +8,7 @@ ) from databricks.sql.exc import Error as DatabricksSqlError +from dlt.common import pendulum from dlt.common import logger from dlt.common.destination import DestinationCapabilitiesContext from dlt.destinations.exceptions import ( @@ -24,10 +25,11 @@ from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame from dlt.destinations.impl.databricks.configuration import DatabricksCredentials from dlt.destinations.impl.databricks import capabilities +from dlt.common.time import to_py_date, to_py_datetime class DatabricksCursorImpl(DBApiCursorImpl): - native_cursor: DatabricksSqlCursor # type: ignore[assignment] + native_cursor: DatabricksSqlCursor def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: if chunk_size is None: @@ -51,11 +53,9 @@ def open_connection(self) -> DatabricksSqlConnection: @raise_open_connection_error def close_connection(self) -> None: - try: + if self._conn: self._conn.close() self._conn = None - except DatabricksSqlError as exc: - logger.warning("Exception while closing 
connection: {}".format(exc)) @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: @@ -101,14 +101,27 @@ def execute_fragments( Executes several SQL fragments as efficiently as possible to prevent data copying. Default implementation just joins the strings and executes them together. """ - return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] # type: ignore + return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] @contextmanager @raise_database_error def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: curr: DBApiCursor = None - db_args = args if args else kwargs if kwargs else None - with self._conn.cursor() as curr: # type: ignore[assignment] + if args: + keys = [f"arg{i}" for i in range(len(args))] + # Replace position arguments (%s) with named arguments (:arg0, :arg1, ...) + query = query % tuple(f":{key}" for key in keys) + db_args = {} + for key, db_arg in zip(keys, args): + # Databricks connector doesn't accept pendulum objects + if isinstance(db_arg, pendulum.DateTime): + db_arg = to_py_datetime(db_arg) + elif isinstance(db_arg, pendulum.Date): + db_arg = to_py_date(db_arg) + db_args[key] = db_arg + else: + db_args = None + with self._conn.cursor() as curr: try: curr.execute(query, db_args) yield DatabricksCursorImpl(curr) # type: ignore[abstract] @@ -138,14 +151,16 @@ def _make_database_exception(ex: Exception) -> Exception: if isinstance(ex, databricks_lib.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) - elif isinstance(ex, databricks_lib.OperationalError): - return DatabaseTerminalException(ex) - elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): return DatabaseTerminalException(ex) - elif isinstance(ex, databricks_lib.DatabaseError): - return DatabaseTransientException(ex) - else: - return ex + return ex + # elif isinstance(ex, databricks_lib.OperationalError): + # return DatabaseTerminalException(ex) + # elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): + # return DatabaseTerminalException(ex) + # elif isinstance(ex, databricks_lib.DatabaseError): + # return DatabaseTransientException(ex) + # else: + # return ex @staticmethod def _maybe_make_terminal_exception_from_data_error( diff --git a/mypy.ini b/mypy.ini index 8a02cf80bd..d4da898a0f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -104,4 +104,7 @@ ignore_missing_imports=true [mypy-s3fs.*] ignore_missing_imports=true [mypy-win_precise_time] +ignore_missing_imports=true + +[mypy-databricks.*] ignore_missing_imports=true \ No newline at end of file diff --git a/tests/load/utils.py b/tests/load/utils.py index 6811ca59a6..d772bf5da5 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -250,6 +250,20 @@ def destinations_configs( bucket_url=AZ_BUCKET, extra_info="az-authorization", ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="jsonl", + bucket_url=AWS_BUCKET, + extra_info="s3-authorization", + ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="jsonl", + bucket_url=AZ_BUCKET, + extra_info="az-authorization", + ), ] if all_staging_configs: @@ -282,6 +296,20 @@ def destinations_configs( bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="parquet", + bucket_url=AWS_BUCKET, + 
extra_info="s3-authorization", + ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="parquet", + bucket_url=AZ_BUCKET, + extra_info="az-authorization", + ), ] # add local filesystem destinations if requested diff --git a/tests/utils.py b/tests/utils.py index cf172f9733..777dd4a27d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -45,6 +45,7 @@ "motherduck", "mssql", "qdrant", + "databricks", } NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck", "qdrant"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS From ff7d7c00f2c32c1294f46477ece24be48c768a05 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 10 Jan 2024 17:49:16 -0500 Subject: [PATCH 38/84] Type mapping fixes for databricks --- .../impl/databricks/databricks.py | 22 ++++++++++++++----- dlt/destinations/job_client_impl.py | 13 ++++++++--- tests/load/pipeline/test_pipelines.py | 2 +- tests/load/test_job_client.py | 4 +++- tests/load/utils.py | 2 +- 5 files changed, 31 insertions(+), 12 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 0236b96f4c..3ed7bf0651 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -83,6 +83,16 @@ def to_db_integer_type( def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None ) -> TColumnType: + # precision and scale arguments here are meaningless as they're not included separately in information schema + # We use full_data_type from databricks which is either in form "typename" or "typename(precision, scale)" + type_parts = db_type.split("(") + if len(type_parts) > 1: + db_type = type_parts[0] + scale_str = type_parts[1].strip(")") + precision, scale = [int(val) for val in scale_str.split(",")] + else: + scale = precision = None + db_type = db_type.upper() if db_type == "DECIMAL": if (precision, scale) == self.wei_precision(): return dict(data_type="wei", precision=precision, scale=scale) @@ -300,9 +310,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) - def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: - exists, table = super().get_storage_table(table_name) - if not exists: - return exists, table - table = {col_name.lower(): dict(col, name=col_name.lower()) for col_name, col in table.items()} # type: ignore - return exists, table + def _get_storage_table_query_columns(self) -> List[str]: + fields = super()._get_storage_table_query_columns() + fields[ + 1 + ] = "full_data_type" # Override because this is the only way to get data type with precision + return fields diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 7fe0a4b6c7..252504cd3f 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -296,6 +296,15 @@ def __exit__( ) -> None: self.sql_client.close_connection() + def _get_storage_table_query_columns(self) -> List[str]: + """Column names used when querying table from information schema. + Override for databases that use different namings. 
+ """ + fields = ["column_name", "data_type", "is_nullable"] + if self.capabilities.schema_supports_numeric_precision: + fields += ["numeric_precision", "numeric_scale"] + return fields + def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: def _null_to_bool(v: str) -> bool: if v == "NO": @@ -304,9 +313,7 @@ def _null_to_bool(v: str) -> bool: return True raise ValueError(v) - fields = ["column_name", "data_type", "is_nullable"] - if self.capabilities.schema_supports_numeric_precision: - fields += ["numeric_precision", "numeric_scale"] + fields = self._get_storage_table_query_columns() db_params = self.sql_client.make_qualified_table_name(table_name, escape=False).split( ".", 3 ) diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index d170fd553b..8fb5b6c292 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -788,7 +788,7 @@ def other_data(): column_schemas["col11_precision"]["precision"] = 0 # drop TIME from databases not supporting it via parquet - if destination_config.destination in ["redshift", "athena"]: + if destination_config.destination in ["redshift", "athena", "databricks"]: data_types.pop("col11") data_types.pop("col11_null") data_types.pop("col11_precision") diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 153504bf4a..65a19a5323 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -387,7 +387,9 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: "time", ): continue - if client.config.destination_type == "mssql" and c["data_type"] in ("complex"): + if client.config.destination_type in "mssql" and c["data_type"] in ("complex"): + continue + if client.config.destination_type == "databricks" and c["data_type"] in ("complex", "time"): continue assert c["data_type"] == expected_c["data_type"] diff --git a/tests/load/utils.py b/tests/load/utils.py index d772bf5da5..c0fe652722 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -163,7 +163,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination != "athena" + if destination not in ("athena", "databricks") ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet") From 7dbd51e693a31743e0a8e4f02fff3a622d2cd67f Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 11 Jan 2024 13:15:56 -0500 Subject: [PATCH 39/84] Fix some databricks bugs --- dlt/common/destination/capabilities.py | 2 ++ dlt/destinations/impl/databricks/__init__.py | 1 + .../impl/databricks/configuration.py | 26 -------------- .../impl/databricks/databricks.py | 34 ++++++++++++------- .../impl/databricks/sql_client.py | 2 +- dlt/destinations/job_client_impl.py | 10 ++++-- dlt/destinations/sql_client.py | 6 +++- 7 files changed, 38 insertions(+), 43 deletions(-) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 2596b2bf99..10d09d52b3 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -52,6 +52,7 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): schema_supports_numeric_precision: bool = True timestamp_precision: int = 6 max_rows_per_insert: Optional[int] = None + supports_multiple_statements: bool = True # do not allow to create default value, destination caps must be always explicitly inserted 
into container can_create_default: ClassVar[bool] = False @@ -77,4 +78,5 @@ def generic_capabilities( caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = True caps.supports_transactions = True + caps.supports_multiple_statements = True return caps diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 09f4e736d3..b1cfca286d 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -43,6 +43,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.supports_truncate_command = True # caps.supports_transactions = False caps.alter_add_multi_column = True + caps.supports_multiple_statements = False return caps diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index f8d43f4ef1..edb547e1a2 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -12,7 +12,6 @@ @configspec class DatabricksCredentials(CredentialsConfiguration): catalog: Optional[str] = None # type: ignore[assignment] - schema: Optional[str] = None # type: ignore[assignment] server_hostname: str = None http_path: str = None access_token: Optional[TSecretStrValue] = None @@ -36,12 +35,6 @@ class DatabricksCredentials(CredentialsConfiguration): ] def __post_init__(self) -> None: - if "." in (self.schema or ""): - raise ConfigurationValueError( - f"The schema should not contain '.': {self.schema}\n" - "If you are trying to set a catalog, please use `catalog` instead.\n" - ) - session_properties = self.session_properties or {} if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: if self.catalog is None: @@ -71,7 +64,6 @@ def __post_init__(self) -> None: "client_secret", "session_configuration", "catalog", - "schema", "_user_agent_entry", ): if key in connection_parameters: @@ -90,27 +82,9 @@ def __post_init__(self) -> None: connection_parameters["_socket_timeout"] = 180 self.connection_parameters = connection_parameters - def validate_creds(self) -> None: - for key in ["host", "http_path"]: - if not getattr(self, key): - raise ConfigurationValueError( - "The config '{}' is required to connect to Databricks".format(key) - ) - if not self.token and self.auth_type != "oauth": - raise ConfigurationValueError( - "The config `auth_type: oauth` is required when not using access token" - ) - - if not self.client_id and self.client_secret: - raise ConfigurationValueError( - "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present" - ) - def to_connector_params(self) -> Dict[str, Any]: return dict( catalog=self.catalog, - schema=self.schema, server_hostname=self.server_hostname, http_path=self.http_path, access_token=self.access_token, diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 3ed7bf0651..6ad21aff5e 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -28,7 +28,7 @@ from dlt.destinations.impl.databricks import capabilities from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams from dlt.destinations.job_impl import NewReferenceJob from 
dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -208,24 +208,26 @@ def exception(self) -> str: class DatabricksStagingCopyJob(SqlStagingCopyJob): @classmethod def generate_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: Optional[SqlJobParams] = None, ) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) - # drop destination table sql.append(f"DROP TABLE IF EXISTS {table_name};") # recreate destination table with data cloned from staging table sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") return sql -class DatabricksMergeJob(SqlMergeJob): - @classmethod - def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: - return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" +# class DatabricksMergeJob(SqlMergeJob): +# @classmethod +# def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: +# return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" class DatabricksClient(SqlJobClientWithStaging, SupportsStagingDestination): @@ -258,11 +260,11 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) + # def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + # return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) - def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + # def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None @@ -270,8 +272,14 @@ def _make_add_column_sql( # Override because databricks requires multiple columns in a single ADD COLUMN clause return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + # def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs( + self, table_chain: Sequence[TTableSchema] + ) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) def _get_table_update_sql( self, diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 4ba808884f..c174762abb 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ 
-48,7 +48,7 @@ def __init__(self, dataset_name: str, credentials: DatabricksCredentials) -> Non def open_connection(self) -> DatabricksSqlConnection: conn_params = self.credentials.to_connector_params() - self._conn = databricks_lib.connect(**conn_params) + self._conn = databricks_lib.connect(**conn_params, schema=self.dataset_name) return self._conn @raise_open_connection_error diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 252504cd3f..7b6e91baa2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -76,8 +76,11 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: with FileStorage.open_zipsafe_ro(file_path, "r", encoding="utf-8") as f: sql = f.read() + # Some clients (e.g. databricks) do not support multiple statements in one execute call + if not sql_client.capabilities.supports_multiple_statements: + sql_client.execute_fragments(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client - if ( + elif ( not self._string_containts_ddl_queries(sql) or sql_client.capabilities.supports_ddl_transactions ): @@ -85,7 +88,7 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: sql_client.execute_sql(sql) else: # sql_client.execute_sql(sql) - sql_client.execute_fragments(sql.split(";")) + sql_client.execute_fragments(self._split_fragments(sql)) def state(self) -> TLoadJobState: # this job is always done @@ -101,6 +104,9 @@ def _string_containts_ddl_queries(self, sql: str) -> bool: return True return False + def _split_fragments(self, sql: str) -> List[str]: + return [s for s in sql.split(";") if s.strip()] + @staticmethod def is_sql_job(file_path: str) -> bool: return os.path.splitext(file_path)[1][1:] == "sql" diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 1e5f7031a5..e6df2c265b 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -20,7 +20,11 @@ from dlt.common.typing import TFun from dlt.common.destination import DestinationCapabilitiesContext -from dlt.destinations.exceptions import DestinationConnectionError, LoadClientNotConnected +from dlt.destinations.exceptions import ( + DestinationConnectionError, + LoadClientNotConnected, + DatabaseTerminalException, +) from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction From ac916c8a7faaf15c22baac72448f73146b44cb50 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 12 Jan 2024 17:26:53 -0500 Subject: [PATCH 40/84] Databricks parquet only --- dlt/common/data_writers/writers.py | 2 +- dlt/destinations/impl/databricks/__init__.py | 4 ++-- .../impl/databricks/sql_client.py | 20 ++++++++++--------- tests/load/utils.py | 18 ++--------------- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 0f9ff09259..c2a3c68821 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -136,7 +136,7 @@ def data_format(cls) -> TFileFormatSpec: file_extension="jsonl", is_binary_format=True, supports_schema_changes=True, - supports_compression=True, + supports_compression=False, ) diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index b1cfca286d..247ab51e58 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -27,9 +27,9 @@ def _configure( 
def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet", "jsonl"] + caps.supported_loader_file_formats = ["parquet"] caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] + caps.supported_staging_file_formats = ["parquet"] caps.escape_identifier = escape_databricks_identifier caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index c174762abb..b94ff763cf 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -78,6 +78,9 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSqlConnection": return self._conn + def drop_dataset(self) -> None: + self.execute_sql("DROP SCHEMA IF EXISTS %s CASCADE;" % self.fully_qualified_dataset_name()) + def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. @@ -152,15 +155,14 @@ def _make_database_exception(ex: Exception) -> Exception: if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) return DatabaseTerminalException(ex) - return ex - # elif isinstance(ex, databricks_lib.OperationalError): - # return DatabaseTerminalException(ex) - # elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): - # return DatabaseTerminalException(ex) - # elif isinstance(ex, databricks_lib.DatabaseError): - # return DatabaseTransientException(ex) - # else: - # return ex + elif isinstance(ex, databricks_lib.OperationalError): + return DatabaseTerminalException(ex) + elif isinstance(ex, (databricks_lib.ProgrammingError, databricks_lib.IntegrityError)): + return DatabaseTerminalException(ex) + elif isinstance(ex, databricks_lib.DatabaseError): + return DatabaseTransientException(ex) + else: + return DatabaseTransientException(ex) @staticmethod def _maybe_make_terminal_exception_from_data_error( diff --git a/tests/load/utils.py b/tests/load/utils.py index c0fe652722..da8cd56328 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -253,14 +253,14 @@ def destinations_configs( DestinationTestConfiguration( destination="databricks", staging="filesystem", - file_format="jsonl", + file_format="parquet", bucket_url=AWS_BUCKET, extra_info="s3-authorization", ), DestinationTestConfiguration( destination="databricks", staging="filesystem", - file_format="jsonl", + file_format="parquet", bucket_url=AZ_BUCKET, extra_info="az-authorization", ), @@ -296,20 +296,6 @@ def destinations_configs( bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), - DestinationTestConfiguration( - destination="databricks", - staging="filesystem", - file_format="parquet", - bucket_url=AWS_BUCKET, - extra_info="s3-authorization", - ), - DestinationTestConfiguration( - destination="databricks", - staging="filesystem", - file_format="parquet", - bucket_url=AZ_BUCKET, - extra_info="az-authorization", - ), ] # add local filesystem destinations if requested From efd965f6f0862cc0954a44df56fda2ce4012f98a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Mon, 15 Jan 2024 07:53:27 -0500 Subject: [PATCH 41/84] Init 
databricks ci --- .../workflows/test_destination_databricks.yml | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 .github/workflows/test_destination_databricks.yml diff --git a/.github/workflows/test_destination_databricks.yml b/.github/workflows/test_destination_databricks.yml new file mode 100644 index 0000000000..f301a1b9ed --- /dev/null +++ b/.github/workflows/test_destination_databricks.yml @@ -0,0 +1,88 @@ + +name: test databricks + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + +env: + DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + + ACTIVE_DESTINATIONS: "[\"databricks\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + +jobs: + get_docs_changes: + uses: ./.github/workflows/get_docs_changes.yml + if: ${{ !github.event.pull_request.head.repo.fork }} + + run_loader: + name: Tests Databricks loader + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + # os: ["ubuntu-latest", "macos-latest", "windows-latest"] + defaults: + run: + shell: bash + runs-on: ${{ matrix.os }} + + steps: + + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-gcp + + - name: Install dependencies + run: poetry install --no-interaction -E databricks -E s3 -E gs -E az -E parquet --with sentry-sdk --with pipeline + + - name: create secrets.toml + run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml + + - run: | + poetry run pytest tests/load + if: runner.os != 'Windows' + name: Run tests Linux/MAC + - run: | + poetry run pytest tests/load + if: runner.os == 'Windows' + name: Run tests Windows + shell: cmd + + matrix_job_required_check: + name: Databricks loader tests + needs: run_loader + runs-on: ubuntu-latest + if: always() + steps: + - name: Check matrix job results + if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: | + echo "One or more matrix job tests failed or were cancelled. You may need to re-run them." 
&& exit 1 From f843731c8c19b51f6e377d4d30a03c5e582b546b Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Mon, 15 Jan 2024 11:25:04 -0500 Subject: [PATCH 42/84] Lint, cleanup --- dlt/destinations/impl/databricks/__init__.py | 32 ------------------- .../impl/databricks/configuration.py | 2 +- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 247ab51e58..bd1242b02a 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -1,29 +1,10 @@ -from typing import Type - -from dlt.common.schema.schema import Schema -from dlt.common.configuration import with_config, known_sections -from dlt.common.configuration.accessors import config from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import JobClientBase, DestinationClientConfiguration from dlt.common.data_writers.escape import escape_databricks_identifier from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration -@with_config( - spec=DatabricksClientConfiguration, - sections=( - known_sections.DESTINATION, - "databricks", - ), -) -def _configure( - config: DatabricksClientConfiguration = config.value, -) -> DatabricksClientConfiguration: - return config - - def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "parquet" @@ -45,16 +26,3 @@ def capabilities() -> DestinationCapabilitiesContext: caps.alter_add_multi_column = True caps.supports_multiple_statements = False return caps - - -def client( - schema: Schema, initial_config: DestinationClientConfiguration = config.value -) -> JobClientBase: - # import client when creating instance so capabilities and config specs can be accessed without dependencies installed - from dlt.destinations.databricks.databricks import DatabricksClient - - return DatabricksClient(schema, _configure(initial_config)) # type: ignore - - -def spec() -> Type[DestinationClientConfiguration]: - return DatabricksClientConfiguration diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index edb547e1a2..c037b32cf3 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -11,7 +11,7 @@ @configspec class DatabricksCredentials(CredentialsConfiguration): - catalog: Optional[str] = None # type: ignore[assignment] + catalog: Optional[str] = None server_hostname: str = None http_path: str = None access_token: Optional[TSecretStrValue] = None From 52ef9398d843989f14b566f0ed6da58b151551e3 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Mon, 15 Jan 2024 16:24:47 -0500 Subject: [PATCH 43/84] Support databricks insert_values --- dlt/common/data_writers/escape.py | 15 +++++++++++++++ dlt/destinations/impl/athena/athena.py | 18 ++++++++++-------- dlt/destinations/impl/databricks/__init__.py | 7 ++++--- dlt/destinations/impl/databricks/databricks.py | 6 +++--- dlt/destinations/impl/databricks/sql_client.py | 16 +++++++++------- dlt/destinations/impl/mssql/sql_client.py | 2 +- dlt/destinations/job_client_impl.py | 4 ++-- dlt/destinations/sql_client.py | 9 +++++++-- dlt/pipeline/pipeline.py | 7 ++++--- tests/load/test_sql_client.py | 4 ++-- tests/load/utils.py | 2 +- 11 files changed, 58 insertions(+), 32 
deletions(-) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index b454c476dc..a6613110f8 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -127,3 +127,18 @@ def escape_snowflake_identifier(v: str) -> str: escape_databricks_identifier = escape_bigquery_identifier + + +def escape_databricks_literal(v: Any) -> Any: + if isinstance(v, str): + return _escape_extended(v, prefix="'") + if isinstance(v, (datetime, date, time)): + return f"'{v.isoformat()}'" + if isinstance(v, (list, dict)): + return _escape_extended(json.dumps(v), prefix="'") + if isinstance(v, bytes): + return "X'{v.hex()}'" + if v is None: + return "NULL" + + return str(v) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 4837f0dbdf..91525d771c 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -232,7 +232,7 @@ def drop_tables(self, *tables: str) -> None: statements = [ f"DROP TABLE IF EXISTS {self.make_qualified_ddl_table_name(table)};" for table in tables ] - self.execute_fragments(statements) + self.execute_many(statements) @contextmanager @raise_database_error @@ -351,9 +351,7 @@ def _from_db_type( return self.type_mapper.from_db_type(hive_t, precision, scale) def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: - return ( - f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" - ) + return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool @@ -378,15 +376,19 @@ def _get_table_update_sql( # use qualified table names qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name) if is_iceberg and not generate_alter: - sql.append(f"""CREATE TABLE {qualified_table_name} + sql.append( + f"""CREATE TABLE {qualified_table_name} ({columns}) LOCATION '{location}' - TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""") + TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""" + ) elif not generate_alter: - sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name} + sql.append( + f"""CREATE EXTERNAL TABLE {qualified_table_name} ({columns}) STORED AS PARQUET - LOCATION '{location}';""") + LOCATION '{location}';""" + ) # alter table to add new columns at the end else: sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""") diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index bd1242b02a..97836a8ce2 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -1,5 +1,5 @@ from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers.escape import escape_databricks_identifier +from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration @@ -7,11 +7,12 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() - caps.preferred_loader_file_format = "parquet" - caps.supported_loader_file_formats = ["parquet"] + caps.preferred_loader_file_format = "insert_values" + 
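As a rough illustration of the literal escaping added above, the sketch below shows the kind of SQL literals a Databricks-style escaper is expected to produce. It is a hedged, simplified stand-in, not the dlt API; to_databricks_literal is a made-up helper and the string escape rules are abbreviated:

from datetime import datetime


def to_databricks_literal(v):
    # simplified: the real escaper also handles \n, \r, and lists/dicts via JSON
    if v is None:
        return "NULL"
    if isinstance(v, bytes):
        return f"X'{v.hex()}'"  # binary values as hex literals
    if isinstance(v, datetime):
        return f"'{v.isoformat()}'"
    if isinstance(v, str):
        return "'" + v.replace("\\", "\\\\").replace("'", "\\'") + "'"
    return str(v)


print(to_databricks_literal(b"\x01\xff"))            # X'01ff'
print(to_databricks_literal("it's"))                 # 'it\'s'
print(to_databricks_literal(datetime(2024, 1, 18)))  # '2024-01-18T00:00:00'
print(to_databricks_literal(None))                   # NULL
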
caps.supported_loader_file_formats = ["insert_values"] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["parquet"] caps.escape_identifier = escape_databricks_identifier + caps.escape_literal = escape_databricks_literal caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) caps.wei_precision = (DEFAULT_NUMERIC_PRECISION, 0) caps.max_identifier_length = 255 diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 6ad21aff5e..d8425f5b57 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -21,7 +21,7 @@ from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat -from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.exceptions import LoadJobTerminalException @@ -230,7 +230,7 @@ def generate_sql( # return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" -class DatabricksClient(SqlJobClientWithStaging, SupportsStagingDestination): +class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: DatabricksClientConfiguration) -> None: @@ -303,7 +303,7 @@ def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTable sql_scripts, schema_update = self._build_schema_update_sql(only_tables) # stay within max query size when doing DDL. some db backends use bytes not characters so decrease limit by half # assuming that most of the characters in DDL encode into single bytes - self.sql_client.execute_fragments(sql_scripts) + self.sql_client.execute_many(sql_scripts) self._update_schema_in_storage(self.schema) return schema_update diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index b94ff763cf..b2dbfcdd05 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -97,14 +97,16 @@ def execute_sql( f = curr.fetchall() return f - def execute_fragments( - self, fragments: Sequence[AnyStr], *args: Any, **kwargs: Any + def execute_many( + self, statements: Sequence[AnyStr], *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: - """ - Executes several SQL fragments as efficiently as possible to prevent data copying. - Default implementation just joins the strings and executes them together. 
- """ - return [self.execute_sql(fragment, *args, **kwargs) for fragment in fragments] + """Databricks does not support multi-statement execution""" + ret = [] + for statement in statements: + result = self.execute_sql(statement, *args, **kwargs) + if result is not None: + ret.append(result) + return ret @contextmanager @raise_database_error diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index 427518feeb..53ed7cfd90 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -115,7 +115,7 @@ def _drop_views(self, *tables: str) -> None: statements = [ f"DROP VIEW IF EXISTS {self.make_qualified_table_name(table)};" for table in tables ] - self.execute_fragments(statements) + self.execute_many(statements) def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 7b6e91baa2..90f4749c39 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -78,7 +78,7 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: # Some clients (e.g. databricks) do not support multiple statements in one execute call if not sql_client.capabilities.supports_multiple_statements: - sql_client.execute_fragments(self._split_fragments(sql)) + sql_client.execute_many(self._split_fragments(sql)) # if we detect ddl transactions, only execute transaction if supported by client elif ( not self._string_containts_ddl_queries(sql) @@ -88,7 +88,7 @@ def __init__(self, file_path: str, sql_client: SqlClientBase[Any]) -> None: sql_client.execute_sql(sql) else: # sql_client.execute_sql(sql) - sql_client.execute_fragments(self._split_fragments(sql)) + sql_client.execute_many(self._split_fragments(sql)) def state(self) -> TLoadJobState: # this job is always done diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index e6df2c265b..8803f88a3d 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -90,7 +90,7 @@ def drop_dataset(self) -> None: def truncate_tables(self, *tables: str) -> None: statements = [self._truncate_table_sql(self.make_qualified_table_name(t)) for t in tables] - self.execute_fragments(statements) + self.execute_many(statements) def drop_tables(self, *tables: str) -> None: if not tables: @@ -98,7 +98,7 @@ def drop_tables(self, *tables: str) -> None: statements = [ f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)};" for table in tables ] - self.execute_fragments(statements) + self.execute_many(statements) @abstractmethod def execute_sql( @@ -118,6 +118,11 @@ def execute_fragments( """Executes several SQL fragments as efficiently as possible to prevent data copying. 
Default implementation just joins the strings and executes them together.""" return self.execute_sql("".join(fragments), *args, **kwargs) # type: ignore + def execute_many( + self, statements: Sequence[AnyStr], *args: Any, **kwargs: Any + ) -> Optional[Sequence[Sequence[Any]]]: + return self.execute_sql("".join(statements), *args, **kwargs) # type: ignore + @abstractmethod def fully_qualified_dataset_name(self, escape: bool = True) -> str: pass diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 73c8f076d1..98fcd56b6d 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -445,6 +445,7 @@ def normalize( runner.run_pool(normalize_step.config, normalize_step) return self._get_step_info(normalize_step) except Exception as n_ex: + raise step_info = self._get_step_info(normalize_step) raise PipelineStepFailed( self, @@ -1163,9 +1164,9 @@ def _set_context(self, is_active: bool) -> None: # set destination context on activation if self.destination: # inject capabilities context - self._container[DestinationCapabilitiesContext] = ( - self._get_destination_capabilities() - ) + self._container[ + DestinationCapabilitiesContext + ] = self._get_destination_capabilities() else: # remove destination context on deactivation if DestinationCapabilitiesContext in self._container: diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 96f0db09bb..8b005a08c8 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -567,7 +567,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: # syntax error within tx statements = ["BEGIN TRANSACTION;", f"INVERT INTO {version_table} VALUES(1);", "COMMIT;"] with pytest.raises(DatabaseTransientException): - client.sql_client.execute_fragments(statements) + client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "ProgrammingError") assert client.get_stored_schema() is not None client.complete_load("EFG") @@ -581,7 +581,7 @@ def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: ] # cannot insert NULL value with pytest.raises(DatabaseTerminalException): - client.sql_client.execute_fragments(statements) + client.sql_client.execute_many(statements) # assert derives_from_class_of_name(term_ex.value.dbapi_exception, "IntegrityError") # assert isinstance(term_ex.value.dbapi_exception, (psycopg2.InternalError, psycopg2.)) assert client.get_stored_schema() is not None diff --git a/tests/load/utils.py b/tests/load/utils.py index da8cd56328..3587bd9fa5 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -163,7 +163,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination not in ("athena", "databricks") + if destination != "athena" ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet") From 4cd91a8cac7a71b3d24ae101d01ddb61dc5e775c Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Mon, 15 Jan 2024 21:13:53 -0500 Subject: [PATCH 44/84] Databricks merge disposition support --- .../impl/databricks/databricks.py | 27 ++++++++--- .../impl/databricks/sql_client.py | 26 +++-------- dlt/destinations/sql_jobs.py | 46 ++++++++++++++----- tests/load/pipeline/test_stage_loading.py | 1 + 4 files changed, 62 insertions(+), 38 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 
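To make the execute_many contrast concrete, here is a minimal sketch of the two strategies, assuming a generic DB-API style connection (conn and its cursor are stand-ins, not dlt objects):

from typing import Any, List, Sequence


def execute_many_joined(conn: Any, statements: Sequence[str]) -> None:
    # default strategy: concatenate and send in a single round trip,
    # for engines that accept multi-statement scripts
    with conn.cursor() as cur:
        cur.execute("".join(statements))


def execute_many_one_by_one(conn: Any, statements: Sequence[str]) -> List[Any]:
    # fallback for engines that accept only one statement per execute call:
    # run each statement separately and collect any result sets
    results = []
    with conn.cursor() as cur:
        for statement in statements:
            cur.execute(statement)
            if cur.description is not None:  # the statement returned rows
                results.append(cur.fetchall())
    return results
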
d8425f5b57..28bbd0dd2f 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -224,10 +224,21 @@ def generate_sql( return sql -# class DatabricksMergeJob(SqlMergeJob): -# @classmethod -# def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: -# return f"CREATE OR REPLACE TEMPORARY VIEW {temp_table_name} AS {select_sql};" +class DatabricksMergeJob(SqlMergeJob): + @classmethod + def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: + return f"CREATE TEMPORARY VIEW {temp_table_name} AS {select_sql};" + + @classmethod + def gen_delete_from_sql( + cls, table_name: str, column_name: str, temp_table_name: str, temp_table_column: str + ) -> str: + # Databricks does not support subqueries in DELETE FROM statements so we use a MERGE statement instead + return f"""MERGE INTO {table_name} + USING {temp_table_name} + ON {table_name}.{column_name} = {temp_table_name}.{temp_table_column} + WHEN MATCHED THEN DELETE; + """ class DatabricksClient(InsertValuesJobClient, SupportsStagingDestination): @@ -260,8 +271,11 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - # def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - # return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) + def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: + return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) + + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] # def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) @@ -274,6 +288,7 @@ def _make_add_column_sql( # def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index b2dbfcdd05..a19cd24811 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -59,20 +59,18 @@ def close_connection(self) -> None: @contextmanager def begin_transaction(self) -> Iterator[DBTransaction]: - logger.warning( - "NotImplemented: Databricks does not support transactions. Each SQL statement is" - " auto-committed separately." 
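The MERGE-based delete introduced above can be sketched as plain string templating; the table and column names below are illustrative only:

def delete_via_merge(table: str, key: str, temp_table: str, temp_key: str) -> str:
    # emulates DELETE ... WHERE key IN (subquery) on engines that reject
    # subqueries in DELETE statements
    return (
        f"MERGE INTO {table}\n"
        f"USING {temp_table}\n"
        f"ON {table}.{key} = {temp_table}.{temp_key}\n"
        f"WHEN MATCHED THEN DELETE;"
    )


print(delete_via_merge("event_user", "_dlt_id", "delete_tmp", "_dlt_id"))
# MERGE INTO event_user
# USING delete_tmp
# ON event_user._dlt_id = delete_tmp._dlt_id
# WHEN MATCHED THEN DELETE;
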
- ) + # Databricks does not support transactions yield self @raise_database_error def commit_transaction(self) -> None: - logger.warning("NotImplemented: commit") + # Databricks does not support transactions pass @raise_database_error def rollback_transaction(self) -> None: - logger.warning("NotImplemented: rollback") + # Databricks does not support transactions + pass @property def native_connection(self) -> "DatabricksSqlConnection": @@ -127,16 +125,8 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB else: db_args = None with self._conn.cursor() as curr: - try: - curr.execute(query, db_args) - yield DatabricksCursorImpl(curr) # type: ignore[abstract] - except databricks_lib.Error as outer: - try: - self._reset_connection() - except databricks_lib.Error: - self.close_connection() - self.open_connection() - raise outer + curr.execute(query, db_args) + yield DatabricksCursorImpl(curr) # type: ignore[abstract] def fully_qualified_dataset_name(self, escape: bool = True) -> str: if escape: @@ -147,10 +137,6 @@ def fully_qualified_dataset_name(self, escape: bool = True) -> str: dataset_name = self.dataset_name return f"{catalog}.{dataset_name}" - def _reset_connection(self) -> None: - self.close_connection() - self.open_connection() - @staticmethod def _make_database_exception(ex: Exception) -> Exception: if isinstance(ex, databricks_lib.ServerOperationError): diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index d97a098669..899947313d 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -186,6 +186,21 @@ def gen_insert_temp_table_sql( """ return [cls._to_temp_table(select_statement, temp_table_name)], temp_table_name + @classmethod + def gen_delete_from_sql( + cls, + table_name: str, + unique_column: str, + delete_temp_table_name: str, + temp_table_column: str, + ) -> str: + """Generate DELETE FROM statement deleting the records found in the deletes temp table.""" + return f"""DELETE FROM {table_name} + WHERE {unique_column} IN ( + SELECT * FROM {delete_temp_table_name} + ); + """ + @classmethod def _new_temp_table_name(cls, name_prefix: str) -> str: return f"{name_prefix}_{uniq_id()}" @@ -261,12 +276,9 @@ def gen_merge_sql( unique_column, key_table_clauses ) sql.extend(create_delete_temp_table_sql) - # delete top table - sql.append( - f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM" - f" {delete_temp_table_name});" - ) - # delete other tables + + # delete from child tables first. 
This is important for databricks which does not support temporary tables, + # but uses temporary views instead for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) root_key_columns = get_columns_names_with_prop(table, "root_key") @@ -281,15 +293,25 @@ def gen_merge_sql( ) root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0]) sql.append( - f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM" - f" {delete_temp_table_name});" + cls.gen_delete_from_sql( + table_name, root_key_column, delete_temp_table_name, unique_column + ) + ) + + # delete from top table now that child tables have been prcessed + sql.append( + cls.gen_delete_from_sql( + root_table_name, unique_column, delete_temp_table_name, unique_column ) + ) + # create temp table used to deduplicate, only when we have primary keys if primary_keys: - create_insert_temp_table_sql, insert_temp_table_name = ( - cls.gen_insert_temp_table_sql( - staging_root_table_name, primary_keys, unique_column - ) + ( + create_insert_temp_table_sql, + insert_temp_table_name, + ) = cls.gen_insert_temp_table_sql( + staging_root_table_name, primary_keys, unique_column ) sql.extend(create_insert_temp_table_sql) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index de4a7f4c3b..bba589b444 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -158,6 +158,7 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non if destination_config.destination in ( "redshift", "athena", + "databricks", ) and destination_config.file_format in ("parquet", "jsonl"): # Redshift copy doesn't support TIME column exclude_types.append("time") From 444d1966f445b4bb14fd7d8bfe2f30d087a78ef7 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 16 Jan 2024 17:25:58 -0500 Subject: [PATCH 45/84] Fix string escaping --- dlt/common/data_writers/escape.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index a6613110f8..f9b672fc65 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -129,13 +129,16 @@ def escape_snowflake_identifier(v: str) -> str: escape_databricks_identifier = escape_bigquery_identifier +DATABRICKS_ESCAPE_DICT = {"'": "\\'", "\\": "\\\\", "\n": "\\n", "\r": "\\r"} + + def escape_databricks_literal(v: Any) -> Any: if isinstance(v, str): - return _escape_extended(v, prefix="'") + return _escape_extended(v, prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT) if isinstance(v, (datetime, date, time)): return f"'{v.isoformat()}'" if isinstance(v, (list, dict)): - return _escape_extended(json.dumps(v), prefix="'") + return _escape_extended(json.dumps(v), prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT) if isinstance(v, bytes): return "X'{v.hex()}'" if v is None: From ed3d58b7ef653d71f54f3826accd83593f9355a1 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 16 Jan 2024 17:26:03 -0500 Subject: [PATCH 46/84] Remove keep-staged-files option --- dlt/destinations/impl/databricks/configuration.py | 2 -- dlt/destinations/impl/databricks/databricks.py | 5 ----- 2 files changed, 7 deletions(-) diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index c037b32cf3..eaf99038a5 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ 
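For comparison, the portable default above emits a subquery-based delete, and child tables are deleted before the root table so that a temp object implemented as a lazily evaluated view (as on Databricks) still returns the deleted keys. A small sketch with illustrative names:

def delete_via_subquery(table: str, key: str, temp_table: str) -> str:
    # the straightforward form, usable wherever DELETE supports IN (subquery)
    return (
        f"DELETE FROM {table}\n"
        f"WHERE {key} IN (SELECT * FROM {temp_table});"
    )


print(delete_via_subquery("event_user__parts", "_dlt_root_id", "delete_tmp"))
# DELETE FROM event_user__parts
# WHERE _dlt_root_id IN (SELECT * FROM delete_tmp);
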
b/dlt/destinations/impl/databricks/configuration.py @@ -103,8 +103,6 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration stage_name: Optional[str] = None """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" - keep_staged_files: bool = True - """Whether to keep or delete the staged files after COPY INTO succeeds""" def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 28bbd0dd2f..32084de240 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -107,7 +107,6 @@ def __init__( load_id: str, client: DatabricksSqlClient, stage_name: Optional[str] = None, - keep_staged_files: bool = True, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) @@ -194,9 +193,6 @@ def __init__( {copy_options} """ client.execute_sql(statement) - # Databricks does not support deleting staged files via sql - # if stage_file_path and not keep_staged_files: - # client.execute_sql(f'REMOVE {stage_file_path}') def state(self) -> TLoadJobState: return "completed" @@ -261,7 +257,6 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> load_id, self.sql_client, stage_name=self.config.stage_name, - keep_staged_files=self.config.keep_staged_files, staging_credentials=( self.config.staging_config.credentials if self.config.staging_config else None ), From 1fcdb5d60bfbec2fec7d1467682a93422e6588bd Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 16 Jan 2024 23:51:56 -0500 Subject: [PATCH 47/84] Exceptions fix, binary escape --- dlt/common/data_writers/escape.py | 2 +- .../impl/databricks/sql_client.py | 17 +++------ tests/load/test_job_client.py | 13 +++++-- tests/load/test_sql_client.py | 37 +++++++++++++------ 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/dlt/common/data_writers/escape.py b/dlt/common/data_writers/escape.py index f9b672fc65..f27b48a95f 100644 --- a/dlt/common/data_writers/escape.py +++ b/dlt/common/data_writers/escape.py @@ -140,7 +140,7 @@ def escape_databricks_literal(v: Any) -> Any: if isinstance(v, (list, dict)): return _escape_extended(json.dumps(v), prefix="'", escape_dict=DATABRICKS_ESCAPE_DICT) if isinstance(v, bytes): - return "X'{v.hex()}'" + return f"X'{v.hex()}'" if v is None: return "NULL" diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index a19cd24811..676d8be305 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -22,21 +22,12 @@ raise_database_error, raise_open_connection_error, ) -from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction, DataFrame +from dlt.destinations.typing import DBApi, DBApiCursor, DBTransaction from dlt.destinations.impl.databricks.configuration import DatabricksCredentials from dlt.destinations.impl.databricks import capabilities from dlt.common.time import to_py_date, to_py_datetime -class DatabricksCursorImpl(DBApiCursorImpl): - native_cursor: DatabricksSqlCursor - - def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: - if chunk_size is None: - return self.native_cursor.fetchall(**kwargs) - return super().df(chunk_size=chunk_size, **kwargs) - - class 
DatabricksSqlClient(SqlClientBase[DatabricksSqlConnection], DBTransaction): dbapi: ClassVar[DBApi] = databricks_lib capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -126,7 +117,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB db_args = None with self._conn.cursor() as curr: curr.execute(query, db_args) - yield DatabricksCursorImpl(curr) # type: ignore[abstract] + yield DBApiCursorImpl(curr) # type: ignore[abstract] def fully_qualified_dataset_name(self, escape: bool = True) -> str: if escape: @@ -142,6 +133,10 @@ def _make_database_exception(ex: Exception) -> Exception: if isinstance(ex, databricks_lib.ServerOperationError): if "TABLE_OR_VIEW_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) + elif "SCHEMA_NOT_FOUND" in str(ex): + return DatabaseUndefinedRelation(ex) + elif "PARSE_SYNTAX_ERROR" in str(ex): + return DatabaseTransientException(ex) return DatabaseTerminalException(ex) elif isinstance(ex, databricks_lib.OperationalError): return DatabaseTerminalException(ex) diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 65a19a5323..2503e464c0 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -28,13 +28,13 @@ from dlt.destinations.job_client_impl import SqlJobClientBase from dlt.common.destination.reference import WithStagingDataset +from tests.cases import table_update_and_row, assert_all_data_types_row from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage from tests.common.utils import load_json_case from tests.load.utils import ( TABLE_UPDATE, TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES, - assert_all_data_types_row, expect_load_file, load_table, yield_client_with_storage, @@ -504,9 +504,14 @@ def test_load_with_all_types( if not client.capabilities.preferred_loader_file_format: pytest.skip("preferred loader file format not set, destination will only work with staging") table_name = "event_test_table" + uniq_id() + column_schemas, data_types = table_update_and_row( + exclude_types=["time"] if client.config.destination_type == "databricks" else None, + ) # we should have identical content with all disposition types client.schema.update_table( - new_table(table_name, write_disposition=write_disposition, columns=TABLE_UPDATE) + new_table( + table_name, write_disposition=write_disposition, columns=list(column_schemas.values()) + ) ) client.schema.bump_version() client.update_stored_schema() @@ -523,12 +528,12 @@ def test_load_with_all_types( canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row with io.BytesIO() as f: - write_dataset(client, f, [TABLE_ROW_ALL_DATA_TYPES], TABLE_UPDATE_COLUMNS_SCHEMA) + write_dataset(client, f, [data_types], column_schemas) query = f.getvalue().decode() expect_load_file(client, file_storage, query, table_name) db_row = list(client.sql_client.execute_sql(f"SELECT * FROM {canonical_name}")[0]) # content must equal - assert_all_data_types_row(db_row) + assert_all_data_types_row(db_row, schema=column_schemas) @pytest.mark.parametrize( diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index 8b005a08c8..ff0a84e07e 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -311,16 +311,17 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: pass assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query( - 
"DELETE FROM TABLE_XXX WHERE 1=1;DELETE FROM ticket_forms__ticket_field_ids WHERE 1=1;" - ): - pass + client.sql_client.execute_many( + [ + "DELETE FROM TABLE_XXX WHERE 1=1;", + "DELETE FROM ticket_forms__ticket_field_ids WHERE 1=1;", + ] + ) assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) with pytest.raises(DatabaseUndefinedRelation) as term_ex: - with client.sql_client.execute_query( - "DROP TABLE TABLE_XXX;DROP TABLE ticket_forms__ticket_field_ids;" - ): - pass + client.sql_client.execute_many( + ["DROP TABLE TABLE_XXX;", "DROP TABLE ticket_forms__ticket_field_ids;"] + ) # invalid syntax with pytest.raises(DatabaseTransientException) as term_ex: @@ -360,7 +361,10 @@ def test_database_exceptions(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["databricks"]), + indirect=True, + ids=lambda x: x.name, ) def test_commit_transaction(client: SqlJobClientBase) -> None: table_name = prepare_temp_table(client) @@ -391,7 +395,10 @@ def test_commit_transaction(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["databricks"]), + indirect=True, + ids=lambda x: x.name, ) def test_rollback_transaction(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: @@ -449,7 +456,10 @@ def test_rollback_transaction(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["databricks"]), + indirect=True, + ids=lambda x: x.name, ) def test_transaction_isolation(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: @@ -546,7 +556,10 @@ def test_max_column_identifier_length(client: SqlJobClientBase) -> None: @pytest.mark.parametrize( - "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name + "client", + destinations_configs(default_sql_configs=True, exclude=["databricks"]), + indirect=True, + ids=lambda x: x.name, ) def test_recover_on_explicit_tx(client: SqlJobClientBase) -> None: if client.capabilities.supports_transactions is False: From e2c9aedcaef14bce5be9a3fe8c61e95bb8b8903d Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 17 Jan 2024 15:18:48 -0500 Subject: [PATCH 48/84] databricks dbt profile --- dlt/helpers/dbt/profiles.yml | 15 +- poetry.lock | 374 ++++++++++++++++++++++++----------- pyproject.toml | 5 +- 3 files changed, 281 insertions(+), 113 deletions(-) diff --git a/dlt/helpers/dbt/profiles.yml b/dlt/helpers/dbt/profiles.yml index 2414222cbd..a9a30106b9 100644 --- a/dlt/helpers/dbt/profiles.yml +++ b/dlt/helpers/dbt/profiles.yml @@ -141,4 +141,17 @@ athena: schema: "{{ var('destination_dataset_name', var('source_dataset_name')) }}" database: "{{ env_var('DLT__AWS_DATA_CATALOG') }}" # aws_profile_name: "{{ env_var('DLT__CREDENTIALS__PROFILE_NAME', '') }}" - work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" \ No newline at end of file + work_group: "{{ env_var('DLT__ATHENA_WORK_GROUP', '') }}" + + +databricks: + target: analytics + outputs: + analytics: + type: databricks + catalog: "{{ env_var('DLT__CREDENTIALS__CATALOG') }}" + schema: 
"{{ var('destination_dataset_name', var('source_dataset_name')) }}" + host: "{{ env_var('DLT__CREDENTIALS__SERVER_HOSTNAME') }}" + http_path: "{{ env_var('DLT__CREDENTIALS__HTTP_PATH') }}" + token: "{{ env_var('DLT__CREDENTIALS__ACCESS_TOKEN') }}" + threads: 4 diff --git a/poetry.lock b/poetry.lock index 777b660e1a..e8cb1eb1fe 100644 --- a/poetry.lock +++ b/poetry.lock @@ -35,27 +35,26 @@ docs = ["furo", "myst-parser", "numpydoc", "sphinx"] [[package]] name = "agate" -version = "1.6.3" +version = "1.7.1" description = "A data analysis library that is optimized for humans instead of machines." optional = false python-versions = "*" files = [ - {file = "agate-1.6.3-py2.py3-none-any.whl", hash = "sha256:2d568fd68a8eb8b56c805a1299ba4bc30ca0434563be1bea309c9d1c1c8401f4"}, - {file = "agate-1.6.3.tar.gz", hash = "sha256:e0f2f813f7e12311a4cdccc97d6ba0a6781e9c1aa8eca0ab00d5931c0113a308"}, + {file = "agate-1.7.1-py2.py3-none-any.whl", hash = "sha256:23f9f412f74f97b72f82b1525ab235cc816bc8c8525d968a091576a0dbc54a5f"}, + {file = "agate-1.7.1.tar.gz", hash = "sha256:eadf46d980168b8922d5d396d6258eecd5e7dbef7e6f0c0b71e968545ea96389"}, ] [package.dependencies] Babel = ">=2.0" isodate = ">=0.5.4" leather = ">=0.3.2" -parsedatetime = ">=2.1,<2.5 || >2.5,<2.6 || >2.6" +parsedatetime = ">=2.1,<2.5 || >2.5" python-slugify = ">=1.2.1" pytimeparse = ">=1.1.5" -six = ">=1.9.0" [package.extras] docs = ["Sphinx (>=1.2.2)", "sphinx-rtd-theme (>=0.1.6)"] -test = ["PyICU (>=2.4.2)", "coverage (>=3.7.1)", "cssselect (>=0.9.1)", "lxml (>=3.6.0)", "mock (>=1.3.0)", "nose (>=1.1.2)", "pytz (>=2015.4)", "unittest2 (>=1.1.0)"] +test = ["PyICU (>=2.4.2)", "coverage (>=3.7.1)", "cssselect (>=0.9.1)", "lxml (>=3.6.0)", "pytest", "pytest-cov", "pytz (>=2015.4)"] [[package]] name = "aiobotocore" @@ -1886,97 +1885,119 @@ nr-date = ">=2.0.0,<3.0.0" typeapi = ">=2.0.1,<3.0.0" typing-extensions = ">=3.10.0" +[[package]] +name = "databricks-sdk" +version = "0.17.0" +description = "Databricks SDK for Python (Beta)" +optional = true +python-versions = ">=3.7" +files = [ + {file = "databricks-sdk-0.17.0.tar.gz", hash = "sha256:0a1baa6783aba9b034b9a017da8d0cf839ec61ae8318792b78bfb3db0374dd9c"}, + {file = "databricks_sdk-0.17.0-py3-none-any.whl", hash = "sha256:ad90e01c7b1a9d60a3de6a35606c79ac982e8972d3ad3ff89c251c24439c8bb9"}, +] + +[package.dependencies] +google-auth = ">=2.0,<3.0" +requests = ">=2.28.1,<3" + +[package.extras] +dev = ["autoflake", "ipython", "ipywidgets", "isort", "pycodestyle", "pyfakefs", "pytest", "pytest-cov", "pytest-mock", "pytest-xdist", "requests-mock", "wheel", "yapf"] +notebook = ["ipython (>=8,<9)", "ipywidgets (>=8,<9)"] + [[package]] name = "databricks-sql-connector" -version = "3.0.1" +version = "2.9.3" description = "Databricks SQL Connector for Python" optional = true -python-versions = ">=3.8.0,<4.0.0" +python-versions = ">=3.7.1,<4.0.0" files = [ - {file = "databricks_sql_connector-3.0.1-py3-none-any.whl", hash = "sha256:3824237732f4363f55e3a1b8dd90ac98b8008c66e869377c8a213581d13dcee2"}, - {file = "databricks_sql_connector-3.0.1.tar.gz", hash = "sha256:915648d5d43e41622d65446bf60c07b2a0d33f9e5ad03478712205703927fdb8"}, + {file = "databricks_sql_connector-2.9.3-py3-none-any.whl", hash = "sha256:e37b5aa8bea22e84a9920e87ad9ba6cafbe656008c180a790baa53b711dd9889"}, + {file = "databricks_sql_connector-2.9.3.tar.gz", hash = "sha256:09a1686de3470091e78640de276053d4e18f8c03ba3627ed45b368f78bf87db9"}, ] [package.dependencies] +alembic = ">=1.0.11,<2.0.0" lz4 = ">=4.0.2,<5.0.0" numpy = [ - {version = ">=1.16.6", 
markers = "python_version >= \"3.8\" and python_version < \"3.11\""}, + {version = ">=1.16.6", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, {version = ">=1.23.4", markers = "python_version >= \"3.11\""}, ] oauthlib = ">=3.1.0,<4.0.0" openpyxl = ">=3.0.10,<4.0.0" pandas = {version = ">=1.2.5,<3.0.0", markers = "python_version >= \"3.8\""} -pyarrow = ">=14.0.1,<15.0.0" +pyarrow = [ + {version = ">=6.0.0", markers = "python_version >= \"3.7\" and python_version < \"3.11\""}, + {version = ">=10.0.1", markers = "python_version >= \"3.11\""}, +] requests = ">=2.18.1,<3.0.0" +sqlalchemy = ">=1.3.24,<2.0.0" thrift = ">=0.16.0,<0.17.0" urllib3 = ">=1.0" -[package.extras] -alembic = ["alembic (>=1.0.11,<2.0.0)", "sqlalchemy (>=2.0.21)"] -sqlalchemy = ["sqlalchemy (>=2.0.21)"] - [[package]] name = "dbt-athena-community" -version = "1.5.2" +version = "1.7.1" description = "The athena adapter plugin for dbt (data build tool)" optional = true -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "dbt-athena-community-1.5.2.tar.gz", hash = "sha256:9acd333ddf33514769189a7a0b6219e13966d370098211cb1d022fa32e64671a"}, - {file = "dbt_athena_community-1.5.2-py3-none-any.whl", hash = "sha256:c9f0f8425500211a1c1deddce5aff5ed24fe08530f0ffad38e63de9c9b9f3ee6"}, + {file = "dbt-athena-community-1.7.1.tar.gz", hash = "sha256:02c7bc461628e2adbfaf9d3f51fbe9a5cb5e06ee2ea8329259758518ceafdc12"}, + {file = "dbt_athena_community-1.7.1-py3-none-any.whl", hash = "sha256:2a376fa128e2bd98cb774fcbf718ebe4fbc9cac7857aa037b9e36bec75448361"}, ] [package.dependencies] boto3 = ">=1.26,<2.0" boto3-stubs = {version = ">=1.26,<2.0", extras = ["athena", "glue", "lakeformation", "sts"]} -dbt-core = ">=1.5.0,<1.6.0" +dbt-core = ">=1.7.0,<1.8.0" +mmh3 = ">=4.0.1,<4.1.0" pyathena = ">=2.25,<4.0" pydantic = ">=1.10,<3.0" tenacity = ">=8.2,<9.0" [[package]] name = "dbt-bigquery" -version = "1.5.6" +version = "1.7.2" description = "The Bigquery adapter plugin for dbt" optional = true python-versions = ">=3.8" files = [ - {file = "dbt-bigquery-1.5.6.tar.gz", hash = "sha256:4655cf2ee0acda986b80e6c5d55cae57871bef22d868dfe29d8d4a5bca98a1ba"}, - {file = "dbt_bigquery-1.5.6-py3-none-any.whl", hash = "sha256:3f37544716880cbd17b32bc0c9728a0407b5615b2cd08e1bb904a7a83c46eb6c"}, + {file = "dbt-bigquery-1.7.2.tar.gz", hash = "sha256:27c7f492f65ab5d1d43432a4467a436fc3637e3cb72c5b4ab07ddf7573c43596"}, + {file = "dbt_bigquery-1.7.2-py3-none-any.whl", hash = "sha256:75015755363d9e8b8cebe190d59a5e08375032b37bcfec41ec8753e7dea29f6e"}, ] [package.dependencies] -agate = ">=1.6.3,<1.7.0" -dbt-core = ">=1.5.0,<1.6.0" +dbt-core = ">=1.7.0,<1.8.0" +google-api-core = ">=2.11.0" google-cloud-bigquery = ">=3.0,<4.0" google-cloud-dataproc = ">=5.0,<6.0" google-cloud-storage = ">=2.4,<3.0" [[package]] name = "dbt-core" -version = "1.5.6" +version = "1.7.4" description = "With dbt, data analysts and engineers can build analytics the way engineers build applications." 
optional = false -python-versions = ">=3.7.2" +python-versions = ">=3.8" files = [ - {file = "dbt-core-1.5.6.tar.gz", hash = "sha256:af3c03cd4a1fc92481362888014ca1ffed2ffef0b0e0d98463ad0f26c49ef458"}, - {file = "dbt_core-1.5.6-py3-none-any.whl", hash = "sha256:030d2179f9efbf8ccea079296d0c79278d963bb2475c0bcce9ca4bbb0d8c393c"}, + {file = "dbt-core-1.7.4.tar.gz", hash = "sha256:769b95949210cb0d1eafdb7be48b01e59984650403f86510fdee65bd0f70f76d"}, + {file = "dbt_core-1.7.4-py3-none-any.whl", hash = "sha256:50050ae44fe9bad63e1b639810ed3629822cdc7a2af0eff6e08461c94c4527c0"}, ] [package.dependencies] -agate = ">=1.6,<1.7.1" +agate = ">=1.7.0,<1.8.0" cffi = ">=1.9,<2.0.0" -click = "<9" -colorama = ">=0.3.9,<0.4.7" -dbt-extractor = ">=0.4.1,<0.5.0" -hologram = ">=0.0.14,<=0.0.16" +click = ">=8.0.2,<9" +colorama = ">=0.3.9,<0.5" +dbt-extractor = ">=0.5.0,<0.6.0" +dbt-semantic-interfaces = ">=0.4.2,<0.5.0" idna = ">=2.5,<4" isodate = ">=0.6,<0.7" -Jinja2 = "3.1.2" +Jinja2 = ">=3.1.2,<3.2.0" +jsonschema = ">=3.0" logbook = ">=1.5,<1.6" -mashumaro = {version = "3.6", extras = ["msgpack"]} -minimal-snowplow-tracker = "0.0.2" -networkx = {version = ">=2.3,<3", markers = "python_version >= \"3.8\""} +mashumaro = {version = ">=3.9,<4.0", extras = ["msgpack"]} +minimal-snowplow-tracker = ">=0.0.2,<0.1.0" +networkx = ">=2.3,<4" packaging = ">20.9" pathspec = ">=0.9,<0.12" protobuf = ">=4.0.0" @@ -1985,99 +2006,160 @@ pyyaml = ">=6.0" requests = "<3.0.0" sqlparse = ">=0.2.3,<0.5" typing-extensions = ">=3.7.4" -werkzeug = ">=1,<3" +urllib3 = ">=1.0,<2.0" + +[[package]] +name = "dbt-databricks" +version = "1.7.3" +description = "The Databricks adapter plugin for dbt" +optional = true +python-versions = ">=3.8" +files = [ + {file = "dbt-databricks-1.7.3.tar.gz", hash = "sha256:045e26240c825342259a59004c2e35e7773b0b6cbb255e6896bd46d3810f9607"}, + {file = "dbt_databricks-1.7.3-py3-none-any.whl", hash = "sha256:7c2b7bd7228a401d8262781749fc496c825fe6050e661e5ab3f1c66343e311cc"}, +] + +[package.dependencies] +databricks-sdk = ">=0.9.0" +databricks-sql-connector = ">=2.9.3,<3.0.0" +dbt-spark = "1.7.1" +keyring = ">=23.13.0" [[package]] name = "dbt-duckdb" -version = "1.5.2" +version = "1.7.1" description = "The duckdb adapter plugin for dbt (data build tool)" optional = false -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "dbt-duckdb-1.5.2.tar.gz", hash = "sha256:3407216c21bf78fd128dccfcff3ec4bf260fb145e633432015bc7d0f123e8e4b"}, - {file = "dbt_duckdb-1.5.2-py3-none-any.whl", hash = "sha256:5d18254807bbc3e61daf4f360208ad886adf44b8525e1998168290fbe73a5cbb"}, + {file = "dbt-duckdb-1.7.1.tar.gz", hash = "sha256:e59b3e58d7a461988d000892b75ce95245cdf899c847e3a430eb2e9e10e63bb9"}, + {file = "dbt_duckdb-1.7.1-py3-none-any.whl", hash = "sha256:bd75b1a72924b942794d0c3293a1159a01f21ab9d82c9f18b22c253dedad101a"}, ] [package.dependencies] -dbt-core = ">=1.5.0,<1.6.0" -duckdb = ">=0.5.0" +dbt-core = ">=1.7.0,<1.8.0" +duckdb = ">=0.7.0" [package.extras] glue = ["boto3", "mypy-boto3-glue"] [[package]] name = "dbt-extractor" -version = "0.4.1" +version = "0.5.1" description = "A tool to analyze and extract information from Jinja used in dbt projects." 
optional = false python-versions = ">=3.6.1" files = [ - {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_7_x86_64.whl", hash = "sha256:4dc715bd740e418d8dc1dd418fea508e79208a24cf5ab110b0092a3cbe96bf71"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:bc9e0050e3a2f4ea9fe58e8794bc808e6709a0c688ed710fc7c5b6ef3e5623ec"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76872cdee659075d6ce2df92dc62e59a74ba571be62acab2e297ca478b49d766"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:81435841610be1b07806d72cd89b1956c6e2a84c360b9ceb3f949c62a546d569"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:7c291f9f483eae4f60dd5859097d7ba51d5cb6c4725f08973ebd18cdea89d758"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:822b1e911db230e1b9701c99896578e711232001027b518c44c32f79a46fa3f9"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:554d27741a54599c39e5c0b7dbcab77400d83f908caba284a3e960db812e5814"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a805d51a25317f53cbff951c79b9cf75421cf48e4b3e1dfb3e9e8de6d824b76c"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:cad90ddc708cb4182dc16fe2c87b1f088a1679877b93e641af068eb68a25d582"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:34783d788b133f223844e280e37b3f5244f2fb60acc457aa75c2667e418d5442"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:9da211869a1220ea55c5552c1567a3ea5233a6c52fa89ca87a22465481c37bc9"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-musllinux_1_2_i686.whl", hash = "sha256:7d7c47774dc051b8c18690281a55e2e3d3320e823b17e04b06bc3ff81b1874ba"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:037907a7c7ae0391045d81338ca77ddaef899a91d80f09958f09fe374594e19b"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-win32.whl", hash = "sha256:3fe8d8e28a7bd3e0884896147269ca0202ca432d8733113386bdc84c824561bf"}, - {file = "dbt_extractor-0.4.1-cp36-abi3-win_amd64.whl", hash = "sha256:35265a0ae0a250623b0c2e3308b2738dc8212e40e0aa88407849e9ea090bb312"}, - {file = "dbt_extractor-0.4.1.tar.gz", hash = "sha256:75b1c665699ec0f1ffce1ba3d776f7dfce802156f22e70a7b9c8f0b4d7e80f42"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:3b91e6106b967d908b34f83929d3f50ee2b498876a1be9c055fe060ed728c556"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3614ce9f83ae4cd0dc95f77730034a793a1c090a52dcf698ba1c94050afe3a8b"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ea4edf33035d0a060b1e01c42fb2d99316457d44c954d6ed4eed9f1948664d87"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3b9bf50eb062b4344d9546fe42038996c6e7e7daa10724aa955d64717260e5d"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c0ce901d4ebf0664977e4e1cbf596d4afc6c1339fcc7d2cf67ce3481566a626f"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = 
"sha256:cbe338b76e9ffaa18275456e041af56c21bb517f6fbda7a58308138703da0996"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1b25fa7a276ab26aa2d70ff6e0cf4cfb1490d7831fb57ee1337c24d2b0333b84"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c5651e458be910ff567c0da3ea2eb084fd01884cc88888ac2cf1e240dcddacc2"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62e4f040fd338b652683421ce48e903812e27fd6e7af58b1b70a4e1f9f2c79e3"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91e25ad78f1f4feadd27587ebbcc46ad909cfad843118908f30336d08d8400ca"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:cdf9938b36cd098bcdd80f43dc03864da3f69f57d903a9160a32236540d4ddcd"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:475e2c05b17eb4976eff6c8f7635be42bec33f15a74ceb87a40242c94a99cebf"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:100453ba06e169cbdb118234ab3f06f6722a2e0e316089b81c88dea701212abc"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-win32.whl", hash = "sha256:6916aae085fd5f2af069fd6947933e78b742c9e3d2165e1740c2e28ae543309a"}, + {file = "dbt_extractor-0.5.1-cp38-abi3-win_amd64.whl", hash = "sha256:eecc08f3743e802a8ede60c89f7b2bce872acc86120cbc0ae7df229bb8a95083"}, + {file = "dbt_extractor-0.5.1.tar.gz", hash = "sha256:cd5d95576a8dea4190240aaf9936a37fd74b4b7913ca69a3c368fc4472bb7e13"}, ] [[package]] name = "dbt-postgres" -version = "1.5.6" +version = "1.7.4" description = "The postgres adapter plugin for dbt (data build tool)" optional = true -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "dbt-postgres-1.5.6.tar.gz", hash = "sha256:b74e471dc661819a3d4bda2d11497935661ac2e25786c8a5b7314d8241b18582"}, - {file = "dbt_postgres-1.5.6-py3-none-any.whl", hash = "sha256:bc5711c9ab0ec4b57ab814b2c4e4c973554c8374b7da94b06814ac81c91f67ef"}, + {file = "dbt-postgres-1.7.4.tar.gz", hash = "sha256:16185b8de36d1a2052a2e4b85512306ab55085b1ea323a353d0dc3628473208d"}, + {file = "dbt_postgres-1.7.4-py3-none-any.whl", hash = "sha256:d414b070ca5e48925ea9ab12706bbb9e2294f7d4509c28e7af42268596334044"}, ] [package.dependencies] -dbt-core = "1.5.6" +agate = "*" +dbt-core = "1.7.4" psycopg2-binary = ">=2.8,<3.0" [[package]] name = "dbt-redshift" -version = "1.5.10" +version = "1.7.1" description = "The Redshift adapter plugin for dbt" optional = true python-versions = ">=3.8" files = [ - {file = "dbt-redshift-1.5.10.tar.gz", hash = "sha256:2b9ae1a7d05349e208b0937cd7cc920ea427341ef96096021b18e4070e927f5c"}, - {file = "dbt_redshift-1.5.10-py3-none-any.whl", hash = "sha256:b7689b043535b6b0d217c2abfe924db2336beaae71f3f36ab9aa1e920d2bb2e0"}, + {file = "dbt-redshift-1.7.1.tar.gz", hash = "sha256:6da69a83038d011570d131b85171842d0858a46bca3757419ae193b5724a2119"}, + {file = "dbt_redshift-1.7.1-py3-none-any.whl", hash = "sha256:2a48b9424934f5445e4285740ebe512afaa75882138121536ccc21d027ef62f2"}, ] [package.dependencies] agate = "*" -boto3 = ">=1.26.157,<1.27.0" -dbt-core = ">=1.5.0,<1.6.0" -dbt-postgres = ">=1.5.0,<1.6.0" -redshift-connector = "2.0.913" +dbt-core = ">=1.7.0,<1.8.0" +dbt-postgres = ">=1.7.0,<1.8.0" +redshift-connector = "2.0.915" + +[[package]] +name = "dbt-semantic-interfaces" +version = "0.4.3" +description = "The shared semantic layer definitions that dbt-core and MetricFlow use" +optional = false 
+python-versions = ">=3.8" +files = [ + {file = "dbt_semantic_interfaces-0.4.3-py3-none-any.whl", hash = "sha256:af6ab8509da81ae5f5f1d5631c9761cccaed8cd5311d4824a8d4168ecd0f2093"}, + {file = "dbt_semantic_interfaces-0.4.3.tar.gz", hash = "sha256:9a46d07ad022a4c48783565a776ebc6f1d19e0412e70c4759bc9d7bba461ea1c"}, +] + +[package.dependencies] +click = ">=7.0,<9.0" +importlib-metadata = ">=6.0,<7.0" +jinja2 = ">=3.0,<4.0" +jsonschema = ">=4.0,<5.0" +more-itertools = ">=8.0,<11.0" +pydantic = ">=1.10,<3" +python-dateutil = ">=2.0,<3.0" +pyyaml = ">=6.0,<7.0" +typing-extensions = ">=4.4,<5.0" [[package]] name = "dbt-snowflake" -version = "1.5.3" +version = "1.7.1" description = "The Snowflake adapter plugin for dbt" optional = true python-versions = ">=3.8" files = [ - {file = "dbt-snowflake-1.5.3.tar.gz", hash = "sha256:cf42772d2c2f1e29a2a64b039c66d80a8593f52a2dd711a144d43b4175802f9a"}, - {file = "dbt_snowflake-1.5.3-py3-none-any.whl", hash = "sha256:8aaa939d834798e5bb10a3ba4f52fc32a53e6e5568d6c0e8b3ac644f099972ff"}, + {file = "dbt-snowflake-1.7.1.tar.gz", hash = "sha256:842a9e87b9e2d999e3bc27aaa369398a4d02bb3f8bb7447aa6151204d4eb90f0"}, + {file = "dbt_snowflake-1.7.1-py3-none-any.whl", hash = "sha256:32ef8733f67dcf4eb594d1b80852ef0b67e920f25bb8a2953031a3868a8d2b3e"}, ] [package.dependencies] -dbt-core = ">=1.5.0,<1.6.0" +agate = "*" +dbt-core = ">=1.7.0,<1.8.0" snowflake-connector-python = {version = ">=3.0,<4.0", extras = ["secure-local-storage"]} +[[package]] +name = "dbt-spark" +version = "1.7.1" +description = "The Apache Spark adapter plugin for dbt" +optional = true +python-versions = ">=3.8" +files = [ + {file = "dbt-spark-1.7.1.tar.gz", hash = "sha256:a10e5d1bfdb2ca98e7ae2badd06150e2695d9d4fa18ae2354ed5bd093d77f947"}, + {file = "dbt_spark-1.7.1-py3-none-any.whl", hash = "sha256:99b5002edcdb82058a3b0ad33eb18b91a4bdde887d94855e8bd6f633d78837dc"}, +] + +[package.dependencies] +dbt-core = ">=1.7.0,<1.8.0" +sqlparams = ">=3.0.0" + +[package.extras] +all = ["PyHive[hive-pure-sasl] (>=0.7.0,<0.8.0)", "pyodbc (>=4.0.39,<4.1.0)", "pyspark (>=3.0.0,<4.0.0)", "thrift (>=0.11.0,<0.17.0)"] +odbc = ["pyodbc (>=4.0.39,<4.1.0)"] +pyhive = ["PyHive[hive-pure-sasl] (>=0.7.0,<0.8.0)", "thrift (>=0.11.0,<0.17.0)"] +session = ["pyspark (>=3.0.0,<4.0.0)"] + [[package]] name = "decopatch" version = "1.4.10" @@ -3679,21 +3761,6 @@ doc = ["sphinx (>=5.0.0)", "sphinx-rtd-theme (>=1.0.0)", "towncrier (>=21,<22)"] lint = ["black (>=22)", "flake8 (==6.0.0)", "flake8-bugbear (==23.3.23)", "isort (>=5.10.1)", "mypy (==0.971)", "pydocstyle (>=5.0.0)"] test = ["eth-utils (>=1.0.1,<3)", "hypothesis (>=3.44.24,<=6.31.6)", "pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)"] -[[package]] -name = "hologram" -version = "0.0.16" -description = "JSON schema generation from dataclasses" -optional = false -python-versions = "*" -files = [ - {file = "hologram-0.0.16-py3-none-any.whl", hash = "sha256:4e56bd525336bb64a18916f871977a4125b64be8aaa750233583003333cda361"}, - {file = "hologram-0.0.16.tar.gz", hash = "sha256:1c2c921b4e575361623ea0e0d0aa5aee377b1a333cc6c6a879e213ed34583e55"}, -] - -[package.dependencies] -jsonschema = ">=3.0" -python-dateutil = ">=2.8,<2.9" - [[package]] name = "hpack" version = "4.0.0" @@ -3816,22 +3883,22 @@ files = [ [[package]] name = "importlib-metadata" -version = "4.13.0" +version = "6.11.0" description = "Read metadata from Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "importlib_metadata-4.13.0-py3-none-any.whl", hash = 
"sha256:8a8a81bcf996e74fee46f0d16bd3eaa382a7eb20fd82445c3ad11f4090334116"}, - {file = "importlib_metadata-4.13.0.tar.gz", hash = "sha256:dd0173e8f150d6815e098fd354f6414b0f079af4644ddfe90c71e2fc6174346d"}, + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, ] [package.dependencies] zipp = ">=0.5" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] perf = ["ipython"] -testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] [[package]] name = "importlib-resources" @@ -4577,13 +4644,13 @@ tests = ["pytest", "pytest-lazy-fixture"] [[package]] name = "mashumaro" -version = "3.6" -description = "Fast serialization library on top of dataclasses" +version = "3.11" +description = "Fast and well tested serialization library" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mashumaro-3.6-py3-none-any.whl", hash = "sha256:77403e3e2ecd0a7d0e22d472c08e33282460e48726eabe356c5163efbdf9c7ee"}, - {file = "mashumaro-3.6.tar.gz", hash = "sha256:ceb3de53029219bbbb0385ca600b59348dcd14e0c68523986c6d51889ad338f5"}, + {file = "mashumaro-3.11-py3-none-any.whl", hash = "sha256:8f858bdb33790db6d9f3087dce793a26d109aeae38bed3ca9c2d7f16f19db412"}, + {file = "mashumaro-3.11.tar.gz", hash = "sha256:b0b2443be4bdad29bb209d91fe4a2a918fbd7b63cccfeb457c7eeb567db02f5e"}, ] [package.dependencies] @@ -4651,11 +4718,87 @@ files = [ requests = ">=2.2.1,<3.0" six = ">=1.9.0,<2.0" +[[package]] +name = "mmh3" +version = "4.0.1" +description = "Python extension for MurmurHash (MurmurHash3), a set of fast and robust hash functions." 
+optional = true +python-versions = "*" +files = [ + {file = "mmh3-4.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b719ba87232749095011d567a36a25e40ed029fc61c47e74a12416d8bb60b311"}, + {file = "mmh3-4.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f0ad423711c5096cf4a346011f3b3ec763208e4f4cc4b10ed41cad2a03dbfaed"}, + {file = "mmh3-4.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:80918e3f8ab6b717af0a388c14ffac5a89c15d827ff008c1ef545b8b32724116"}, + {file = "mmh3-4.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8222cd5f147defa1355b4042d590c34cef9b2bb173a159fcb72cda204061a4ac"}, + {file = "mmh3-4.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3821bcd1961ef19247c78c5d01b5a759de82ab0c023e2ff1d5ceed74322fa018"}, + {file = "mmh3-4.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59f7ed28c24249a54665f1ed3f6c7c1c56618473381080f79bcc0bd1d1db2e4a"}, + {file = "mmh3-4.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dacd8d07d4b9be8f0cb6e8fd9a08fc237c18578cf8d42370ee8af2f5a2bf1967"}, + {file = "mmh3-4.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd00883ef6bcf7831026ce42e773a4b2a4f3a7bf9003a4e781fecb1144b06c1"}, + {file = "mmh3-4.0.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:df73d1c7f0c50c0f8061cd349968fd9dcc6a9e7592d1c834fa898f9c98f8dd7e"}, + {file = "mmh3-4.0.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f41eeae98f15af0a4ba2a92bce11d8505b612012af664a7634bbfdba7096f5fc"}, + {file = "mmh3-4.0.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ce9bb622e9f1162cafd033071b32ac495c5e8d5863fca2a5144c092a0f129a5b"}, + {file = "mmh3-4.0.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:dd92e0ff9edee6af960d9862a3e519d651e6344321fd280fb082654fc96ecc4d"}, + {file = "mmh3-4.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1aefa8ac8c8fc8ad93365477baef2125dbfd7235880a9c47dca2c46a0af49ef7"}, + {file = "mmh3-4.0.1-cp310-cp310-win32.whl", hash = "sha256:a076ea30ec279a63f44f4c203e4547b5710d00581165fed12583d2017139468d"}, + {file = "mmh3-4.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:5aa1e87e448ee1ffa3737b72f2fe3f5960159ab75bbac2f49dca6fb9797132f6"}, + {file = "mmh3-4.0.1-cp310-cp310-win_arm64.whl", hash = "sha256:45155ff2f291c3a1503d1c93e539ab025a13fd8b3f2868650140702b8bd7bfc2"}, + {file = "mmh3-4.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:91f81d6dd4d0c3b4235b4a58a545493c946669c751a2e0f15084171dc2d81fee"}, + {file = "mmh3-4.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bbfddaf55207798f5b29341e5b3a24dbff91711c51b1665eabc9d910255a78f0"}, + {file = "mmh3-4.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0deb8e19121c0896fdc709209aceda30a367cda47f4a884fcbe56223dbf9e867"}, + {file = "mmh3-4.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df468ac7b61ec7251d7499e27102899ca39d87686f659baf47f84323f8f4541f"}, + {file = "mmh3-4.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84936c113814c6ef3bc4bd3d54f538d7ba312d1d0c2441ac35fdd7d5221c60f6"}, + {file = "mmh3-4.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b1df3cf5ce5786aa093f45462118d87ff485f0d69699cdc34f6289b1e833632"}, + {file = "mmh3-4.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:da281aa740aa9e7f9bebb879c1de0ea9366687ece5930f9f5027e7c87d018153"}, + {file = "mmh3-4.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ec380933a56eb9fea16d7fcd49f1b5a5c92d7d2b86f25e9a845b72758ee8c42"}, + {file = "mmh3-4.0.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2fa905fcec8a30e1c0ef522afae1d6170c4f08e6a88010a582f67c59209fb7c7"}, + {file = "mmh3-4.0.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9b23a06315a65ef0b78da0be32409cfce0d6d83e51d70dcebd3302a61e4d34ce"}, + {file = "mmh3-4.0.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:36c27089b12026db14be594d750f7ea6d5d785713b40a971b063f033f5354a74"}, + {file = "mmh3-4.0.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:6338341ae6fa5eaa46f69ed9ac3e34e8eecad187b211a6e552e0d8128c568eb1"}, + {file = "mmh3-4.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1aece29e27d0c8fb489d00bb712fba18b4dd10e39c9aec2e216c779ae6400b8f"}, + {file = "mmh3-4.0.1-cp311-cp311-win32.whl", hash = "sha256:2733e2160c142eed359e25e5529915964a693f0d043165b53933f904a731c1b3"}, + {file = "mmh3-4.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:09f9f643e0b7f8d98473efdfcdb155105824a38a1ada374625b84c1208197a9b"}, + {file = "mmh3-4.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:d93422f38bc9c4d808c5438a011b769935a87df92ce277e9e22b6ec0ae8ed2e2"}, + {file = "mmh3-4.0.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:41013c033dc446d3bfb573621b8b53223adcfcf07be1da0bcbe166d930276882"}, + {file = "mmh3-4.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:be46540eac024dd8d9b82899d35b2f23592d3d3850845aba6f10e6127d93246b"}, + {file = "mmh3-4.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0e64114b30c6c1e30f8201433b5fa6108a74a5d6f1a14af1b041360c0dd056aa"}, + {file = "mmh3-4.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:275637ecca755565e3b0505d3ecf8e1e0a51eb6a3cbe6e212ed40943f92f98cd"}, + {file = "mmh3-4.0.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:955178c8e8d3bc9ad18eab443af670cd13fe18a6b2dba16db2a2a0632be8a133"}, + {file = "mmh3-4.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:750afe0477e0c17904611045ad311ff10bc6c2ec5f5ddc5dd949a2b9bf71d5d5"}, + {file = "mmh3-4.0.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b7c18c35e9d6a59d6c5f94a6576f800ff2b500e41cd152ecfc7bb4330f32ba2"}, + {file = "mmh3-4.0.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b8635b1fc6b25d93458472c5d682a1a4b9e6c53e7f4ca75d2bf2a18fa9363ae"}, + {file = "mmh3-4.0.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:057b8de47adee8ad0f2e194ffa445b9845263c1c367ddb335e9ae19c011b25cc"}, + {file = "mmh3-4.0.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:78c0ee0197cfc912f57172aa16e784ad55b533e2e2e91b3a65188cc66fbb1b6e"}, + {file = "mmh3-4.0.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:d6acb15137467592691e41e6f897db1d2823ff3283111e316aa931ac0b5a5709"}, + {file = "mmh3-4.0.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:f91b2598e1f25e013da070ff641a29ebda76292d3a7bdd20ef1736e9baf0de67"}, + {file = "mmh3-4.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a78f6f2592395321e2f0dc6b618773398b2c9b15becb419364e0960df53e9f04"}, + {file = "mmh3-4.0.1-cp38-cp38-win32.whl", hash = "sha256:d8650982d0b70af24700bd32b15fab33bb3ef9be4af411100f4960a938b0dd0f"}, + {file 
= "mmh3-4.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:2489949c7261870a02eeaa2ec7b966881c1775df847c8ce6ea4de3e9d96b5f4f"}, + {file = "mmh3-4.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:dcd03a4bb0fa3db03648d26fb221768862f089b6aec5272f0df782a8b4fe5b5b"}, + {file = "mmh3-4.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3775fb0cc675977e5b506b12b8f23cd220be3d4c2d4db7df81f03c9f61baa4cc"}, + {file = "mmh3-4.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8f250f78328d41cdf73d3ad9809359636f4fb7a846d7a6586e1a0f0d2f5f2590"}, + {file = "mmh3-4.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4161009c9077d5ebf8b472dbf0f41b9139b3d380e0bbe71bf9b503efb2965584"}, + {file = "mmh3-4.0.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2cf986ebf530717fefeee8d0decbf3f359812caebba985e2c8885c0ce7c2ee4e"}, + {file = "mmh3-4.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b55741ed51e928b1eec94a119e003fa3bc0139f4f9802e19bea3af03f7dd55a"}, + {file = "mmh3-4.0.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8250375641b8c5ce5d56a00c6bb29f583516389b8bde0023181d5eba8aa4119"}, + {file = "mmh3-4.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29373e802bc094ffd490e39047bac372ac893c0f411dac3223ef11775e34acd0"}, + {file = "mmh3-4.0.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:071ba41e56f5c385d13ee84b288ccaf46b70cd9e9a6d8cbcbe0964dee68c0019"}, + {file = "mmh3-4.0.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:909e0b88d2c6285481fa6895c2a0faf6384e1b0093f72791aa57d1e04f4adc65"}, + {file = "mmh3-4.0.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:51d356f4380f9d9c2a0612156c3d1e7359933991e84a19304440aa04fd723e68"}, + {file = "mmh3-4.0.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:c4b2549949efa63d8decb6572f7e75fad4f2375d52fafced674323239dd9812d"}, + {file = "mmh3-4.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9bcc7b32a89c4e5c6fdef97d82e8087ba26a20c25b4aaf0723abd0b302525934"}, + {file = "mmh3-4.0.1-cp39-cp39-win32.whl", hash = "sha256:8edee21ae4f4337fb970810ef5a263e5d2212b85daca0d39daf995e13380e908"}, + {file = "mmh3-4.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:8cbb6f90f08952fcc90dbf08f0310fdf4d61096c5cb7db8adf03e23f3b857ae5"}, + {file = "mmh3-4.0.1-cp39-cp39-win_arm64.whl", hash = "sha256:ce71856cbca9d7c74d084eeee1bc5b126ed197c1c9530a4fdb994d099b9bc4db"}, + {file = "mmh3-4.0.1.tar.gz", hash = "sha256:ad8be695dc4e44a79631748ba5562d803f0ac42d36a6b97a53aca84a70809385"}, +] + +[package.extras] +test = ["mypy (>=1.0)", "pytest (>=7.0.0)"] + [[package]] name = "more-itertools" version = "10.1.0" description = "More routines for operating on iterables, beyond itertools" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "more-itertools-10.1.0.tar.gz", hash = "sha256:626c369fa0eb37bac0291bce8259b332fd59ac792fa5497b59837309cd5b114a"}, @@ -6747,12 +6890,12 @@ fastembed = ["fastembed (==0.1.1)"] [[package]] name = "redshift-connector" -version = "2.0.913" +version = "2.0.915" description = "Redshift interface library" optional = true python-versions = ">=3.6" files = [ - {file = "redshift_connector-2.0.913-py3-none-any.whl", hash = "sha256:bd70395c5b7ec9fcae9565daff6bcb88c7d3ea6182dafba2bac6138f68d00582"}, + {file = "redshift_connector-2.0.915-py3-none-any.whl", hash = 
"sha256:d02e8d6fa01dd46504c879953f6abd7fa72980edd1e6a80202448fe35fb4c9e4"}, ] [package.dependencies] @@ -7709,6 +7852,17 @@ toml = {version = "*", markers = "python_version < \"3.11\""} tqdm = "*" typing-extensions = "*" +[[package]] +name = "sqlparams" +version = "6.0.1" +description = "Convert between various DB API 2.0 parameter styles." +optional = true +python-versions = ">=3.8" +files = [ + {file = "sqlparams-6.0.1-py3-none-any.whl", hash = "sha256:566651376315c832876be4a0f58ffa23a23fab257d77ee492bdf8d301e169d0d"}, + {file = "sqlparams-6.0.1.tar.gz", hash = "sha256:032b2f949d4afbcbfa24003f6fb407f2fc8468184e3d8ca3d59ba6b30d4935bf"}, +] + [[package]] name = "sqlparse" version = "0.4.4" @@ -8576,7 +8730,7 @@ az = ["adlfs"] bigquery = ["gcsfs", "google-cloud-bigquery", "grpcio", "pyarrow"] cli = ["cron-descriptor", "pipdeptree"] databricks = ["databricks-sql-connector"] -dbt = ["dbt-athena-community", "dbt-bigquery", "dbt-core", "dbt-duckdb", "dbt-redshift", "dbt-snowflake"] +dbt = ["dbt-athena-community", "dbt-bigquery", "dbt-core", "dbt-databricks", "dbt-duckdb", "dbt-redshift", "dbt-snowflake"] duckdb = ["duckdb"] filesystem = ["botocore", "s3fs"] gcp = ["gcsfs", "google-cloud-bigquery", "grpcio"] @@ -8594,4 +8748,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "2305ecdd34e888ae84eed985057df674263e856047a6d59d73c86f26a93e0783" +content-hash = "2423c6a16d547ee9ab26d59d9fad49fe35fb3e1b85b7c95d82ab33efabb184f6" diff --git a/pyproject.toml b/pyproject.toml index 69d4438089..a8d63f7e6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,10 +76,11 @@ weaviate-client = {version = ">=3.22", optional = true} adlfs = {version = ">=2022.4.0", optional = true} pyodbc = {version = "^4.0.39", optional = true} qdrant-client = {version = "^1.6.4", optional = true, extras = ["fastembed"]} -databricks-sql-connector = {version = "^3.0.1", optional = true} +databricks-sql-connector = {version = ">=2.9.3,<3.0.0", optional = true} +dbt-databricks = {version = "^1.7.3", optional = true} [tool.poetry.extras] -dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community"] +dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"] gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"] # bigquery is alias on gcp extras bigquery = ["grpcio", "google-cloud-bigquery", "pyarrow", "db-dtypes", "gcsfs"] From d6584d9a3534d72b1e482c9e1713b653cb9407e0 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 17 Jan 2024 15:18:58 -0500 Subject: [PATCH 49/84] Handle databricks 2.9 paramstyle --- .../impl/databricks/sql_client.py | 34 ++++++++++++------- tests/load/pipeline/test_arrow_loading.py | 1 + 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 676d8be305..7c2e549b76 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager, suppress -from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List +from typing import Any, AnyStr, ClassVar, Iterator, Optional, Sequence, List, Union, Dict from databricks import sql as databricks_lib from databricks.sql.client import ( @@ -101,18 +101,26 @@ def execute_many( @raise_database_error def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) 
-> Iterator[DBApiCursor]: curr: DBApiCursor = None - if args: - keys = [f"arg{i}" for i in range(len(args))] - # Replace position arguments (%s) with named arguments (:arg0, :arg1, ...) - query = query % tuple(f":{key}" for key in keys) - db_args = {} - for key, db_arg in zip(keys, args): - # Databricks connector doesn't accept pendulum objects - if isinstance(db_arg, pendulum.DateTime): - db_arg = to_py_datetime(db_arg) - elif isinstance(db_arg, pendulum.Date): - db_arg = to_py_date(db_arg) - db_args[key] = db_arg + # TODO: databricks connector 3.0.0 will use :named paramstyle only + # if args: + # keys = [f"arg{i}" for i in range(len(args))] + # # Replace position arguments (%s) with named arguments (:arg0, :arg1, ...) + # # query = query % tuple(f":{key}" for key in keys) + # db_args = {} + # for key, db_arg in zip(keys, args): + # # Databricks connector doesn't accept pendulum objects + # if isinstance(db_arg, pendulum.DateTime): + # db_arg = to_py_datetime(db_arg) + # elif isinstance(db_arg, pendulum.Date): + # db_arg = to_py_date(db_arg) + # db_args[key] = db_arg + # else: + # db_args = None + db_args: Optional[Union[Dict[str, Any], Sequence[Any]]] + if kwargs: + db_args = kwargs + elif args: + db_args = args else: db_args = None with self._conn.cursor() as curr: diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 4a3c209c32..7f159f57b7 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -35,6 +35,7 @@ def test_load_item( include_time = destination_config.destination not in ( "athena", "redshift", + "databricks", ) # athena/redshift can't load TIME columns from parquet item, records = arrow_table_all_data_types( item_type, include_json=False, include_time=include_time From e63de19060b0192fbf0ce3d250db522623cea125 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 17 Jan 2024 19:33:57 -0500 Subject: [PATCH 50/84] Databricks docs --- .../dlt-ecosystem/destinations/databricks.md | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 docs/website/docs/dlt-ecosystem/destinations/databricks.md diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md new file mode 100644 index 0000000000..4fadea7fc0 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -0,0 +1,102 @@ +--- + +title: Databricks +description: Databricks `dlt` destination +keywords: [Databricks, destination, data warehouse] + +--- + +# Databricks + +## Install dlt with Databricks +**To install the DLT library with Databricks dependencies:** +``` +pip install dlt[databricks] +``` + +## Setup Guide + +**1. Initialize a project with a pipeline that loads to Databricks by running** +``` +dlt init chess databricks +``` + +**2. Install the necessary dependencies for Databricks by running** +``` +pip install -r requirements.txt +``` +This will install dlt with **databricks** extra which contains Databricks Python dbapi client. + +**4. Enter your credentials into `.dlt/secrets.toml`.** + +This should have your connection parameters and your personal access token. 
+
+It should now look like:
+
+```toml
+[destination.databricks.credentials]
+server_hostname = "MY_DATABRICKS.azuredatabricks.net"
+http_path = "/sql/1.0/warehouses/12345"
+access_token = "MY_ACCESS_TOKEN"
+catalog = "my_catalog"
+```
+
+## Write disposition
+All write dispositions are supported.
+
+## Data loading
+Data is loaded using `INSERT VALUES` statements by default.
+
+Efficient loading from a staging filesystem is also supported by configuring an Amazon S3 or Azure Blob Storage bucket as a staging destination. When staging is enabled, `dlt` will upload data in `parquet` files to the bucket and then use `COPY INTO` statements to ingest the data into Databricks.
+For more information on staging, see the [staging support](#staging-support) section below.
+
+## Supported file formats
+* [insert-values](../file-formats/insert-format.md) is used by default
+* [parquet](../file-formats/parquet.md) is supported when staging is enabled
+
+## Staging support
+
+Databricks supports both Amazon S3 and Azure Blob Storage as staging locations. `dlt` will upload files in `parquet` format to the staging location and will instruct Databricks to load data from there.
+
+### Databricks and Amazon S3
+
+Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your bucket with the `bucket_url` and credentials. Unless configured otherwise (see config options below), the `dlt` Databricks loader will use the AWS credentials provided for S3 to access the S3 bucket. You can specify your S3 bucket directly in your dlt configuration.
+
+To set up Databricks with S3 as a staging destination:
+
+```python
+import dlt
+
+# Create a dlt pipeline that will load
+# chess player data to the Databricks destination
+# via staging on S3
+pipeline = dlt.pipeline(
+    pipeline_name='chess_pipeline',
+    destination='databricks',
+    staging=dlt.destinations.filesystem('s3://your-bucket-name'), # add this to activate the staging location
+    dataset_name='player_data',
+)
+```
+
+### Databricks and Azure Blob Storage
+
+Refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) for setting up your container with the `bucket_url` and credentials. For Azure Blob Storage, Databricks can directly load data from the storage container specified in the configuration:
+
+```python
+# Create a dlt pipeline that will load
+# chess player data to the Databricks destination
+# via staging on Azure Blob Storage
+pipeline = dlt.pipeline(
+    pipeline_name='chess_pipeline',
+    destination='databricks',
+    staging=dlt.destinations.filesystem('az://your-container-name'), # add this to activate the staging location
+    dataset_name='player_data'
+)
+```
+
+### dbt support
+This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-databricks](https://github.com/databricks/dbt-databricks).
+
+### Syncing of `dlt` state
+This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination).
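To sanity-check the staging setup above end to end, a minimal run could look like the sketch below. The sample rows and the `players` table name are illustrative, the container name is a placeholder, and the Databricks and filesystem credentials are assumed to be configured in `.dlt/secrets.toml`.

```python
import dlt

# hypothetical sample rows; any iterable of dicts (or a dlt resource) works
players = [
    {"player_id": 1, "username": "magnus"},
    {"player_id": 2, "username": "hikaru"},
]

pipeline = dlt.pipeline(
    pipeline_name="chess_pipeline",
    destination="databricks",
    staging=dlt.destinations.filesystem("az://your-container-name"),
    dataset_name="player_data",
)

# parquet files are written to the staging container and then ingested with COPY INTO
info = pipeline.run(players, table_name="players")
print(info)
```

The returned load info can be printed or inspected to verify which files were loaded.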
From dad726a6012ad267e570f3bca36a21949b138dae Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 17 Jan 2024 20:22:32 -0500 Subject: [PATCH 51/84] Remove debug raise --- dlt/pipeline/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 98fcd56b6d..3fa8da6aee 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -445,7 +445,6 @@ def normalize( runner.run_pool(normalize_step.config, normalize_step) return self._get_step_info(normalize_step) except Exception as n_ex: - raise step_info = self._get_step_info(normalize_step) raise PipelineStepFailed( self, From 3fad7e40487c8d92b800feb581a93c37b01eb2f7 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 19 Jan 2024 16:13:24 -0500 Subject: [PATCH 52/84] Fix sql load job --- dlt/destinations/job_client_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 90f4749c39..da88f0d538 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -105,7 +105,7 @@ def _string_containts_ddl_queries(self, sql: str) -> bool: return False def _split_fragments(self, sql: str) -> List[str]: - return [s for s in sql.split(";") if s.strip()] + return [s + (";" if not s.endswith(";") else "") for s in sql.split(";") if s.strip()] @staticmethod def is_sql_job(file_path: str) -> bool: From b3cb533efa387e2c6a806904938a7e2f4b4f94cc Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 26 Jan 2024 09:43:10 -0500 Subject: [PATCH 53/84] Revert debug --- dlt/common/data_writers/writers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index c2a3c68821..0f9ff09259 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -136,7 +136,7 @@ def data_format(cls) -> TFileFormatSpec: file_extension="jsonl", is_binary_format=True, supports_schema_changes=True, - supports_compression=False, + supports_compression=True, ) From 07a5eec32230ca96d29f00f9d6e4e3af70c0fd36 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 26 Jan 2024 12:46:06 -0500 Subject: [PATCH 54/84] Typo fix --- tests/load/test_job_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 2503e464c0..774c494d28 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -387,7 +387,7 @@ def test_get_storage_table_with_all_types(client: SqlJobClientBase) -> None: "time", ): continue - if client.config.destination_type in "mssql" and c["data_type"] in ("complex"): + if client.config.destination_type == "mssql" and c["data_type"] in ("complex"): continue if client.config.destination_type == "databricks" and c["data_type"] in ("complex", "time"): continue From a615a0a737759a7aea045cc4f6bc7048d7bfe4fe Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 26 Jan 2024 12:47:27 -0500 Subject: [PATCH 55/84] General execute_many method in base sql client --- dlt/destinations/job_client_impl.py | 6 +----- dlt/destinations/sql_client.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index da88f0d538..e7dc4bcbe2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -50,7 +50,6 @@ FollowupJob, 
CredentialsConfiguration, ) -from dlt.common.utils import concat_strings_with_limit from dlt.destinations.exceptions import ( DatabaseUndefinedRelation, DestinationSchemaTampered, @@ -397,10 +396,7 @@ def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTable sql_scripts, schema_update = self._build_schema_update_sql(only_tables) # stay within max query size when doing DDL. some db backends use bytes not characters so decrease limit by half # assuming that most of the characters in DDL encode into single bytes - for sql_fragment in concat_strings_with_limit( - sql_scripts, "\n", self.capabilities.max_query_length // 2 - ): - self.sql_client.execute_sql(sql_fragment) + self.sql_client.execute_many(sql_scripts) self._update_schema_in_storage(self.schema) return schema_update diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 8803f88a3d..695f1a0972 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -19,6 +19,7 @@ from dlt.common.typing import TFun from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.utils import concat_strings_with_limit from dlt.destinations.exceptions import ( DestinationConnectionError, @@ -119,9 +120,23 @@ def execute_fragments( return self.execute_sql("".join(fragments), *args, **kwargs) # type: ignore def execute_many( - self, statements: Sequence[AnyStr], *args: Any, **kwargs: Any + self, statements: Sequence[str], *args: Any, **kwargs: Any ) -> Optional[Sequence[Sequence[Any]]]: - return self.execute_sql("".join(statements), *args, **kwargs) # type: ignore + """Executes multiple SQL statements as efficiently as possible. When client supports multiple statements in a single query + they are executed together in as few database calls as possible. 
+ """ + ret = [] + if self.capabilities.supports_multiple_statements: + for sql_fragment in concat_strings_with_limit( + list(statements), "\n", self.capabilities.max_query_length // 2 + ): + ret.append(self.execute_sql(sql_fragment, *args, **kwargs)) + else: + for statement in statements: + result = self.execute_sql(statement, *args, **kwargs) + if result is not None: + ret.append(result) + return ret @abstractmethod def fully_qualified_dataset_name(self, escape: bool = True) -> str: From 051e9d3331af5c1538970ace9c3ecc4637da4ef2 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 26 Jan 2024 12:47:39 -0500 Subject: [PATCH 56/84] Databricks client cleanup --- .../impl/databricks/databricks.py | 22 +------------------ .../impl/databricks/sql_client.py | 11 ---------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 32084de240..4141ff33df 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -127,8 +127,6 @@ def __init__( credentials_clause = "" files_clause = "" # stage_file_path = "" - format_options = "" - copy_options = "COPY_OPTIONS ('mergeSchema'='true')" if bucket_path: bucket_url = urlparse(bucket_path) @@ -180,17 +178,13 @@ def __init__( ) # decide on source format, stage_file_path will either be a local file or a bucket path - source_format = "JSON" - if file_name.endswith("parquet"): - source_format = "PARQUET" + source_format = "PARQUET" # Only parquet is supported statement = f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILEFORMAT = {source_format} - {format_options} - {copy_options} """ client.execute_sql(statement) @@ -272,18 +266,12 @@ def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] - # def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) - def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] - # def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - # return DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client) - def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: @@ -309,14 +297,6 @@ def _get_table_update_sql( return sql - def _execute_schema_update_sql(self, only_tables: Iterable[str]) -> TSchemaTables: - sql_scripts, schema_update = self._build_schema_update_sql(only_tables) - # stay within max query size when doing DDL. 
some db backends use bytes not characters so decrease limit by half - # assuming that most of the characters in DDL encode into single bytes - self.sql_client.execute_many(sql_scripts) - self._update_schema_in_storage(self.schema) - return schema_update - def _from_db_type( self, bq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index 7c2e549b76..68ea863cc4 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -86,17 +86,6 @@ def execute_sql( f = curr.fetchall() return f - def execute_many( - self, statements: Sequence[AnyStr], *args: Any, **kwargs: Any - ) -> Optional[Sequence[Sequence[Any]]]: - """Databricks does not support multi-statement execution""" - ret = [] - for statement in statements: - result = self.execute_sql(statement, *args, **kwargs) - if result is not None: - ret.append(result) - return ret - @contextmanager @raise_database_error def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DBApiCursor]: From 64b7e2e583209146c0e13bbd532f22c66b0adea6 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 26 Jan 2024 16:32:40 -0500 Subject: [PATCH 57/84] Implement staging clone table in base class --- dlt/common/destination/capabilities.py | 2 + dlt/destinations/impl/bigquery/__init__.py | 1 + dlt/destinations/impl/bigquery/bigquery.py | 40 +++---------------- dlt/destinations/impl/databricks/__init__.py | 1 + .../impl/databricks/databricks.py | 28 +------------ dlt/destinations/impl/snowflake/__init__.py | 1 + dlt/destinations/impl/snowflake/snowflake.py | 34 +++------------- dlt/destinations/sql_jobs.py | 32 ++++++++++++++- 8 files changed, 46 insertions(+), 93 deletions(-) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 10d09d52b3..b891d4b31f 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -53,6 +53,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): timestamp_precision: int = 6 max_rows_per_insert: Optional[int] = None supports_multiple_statements: bool = True + supports_clone_table: bool = False + """Destination supports CREATE TABLE ... CLONE ... 
statements""" # do not allow to create default value, destination caps must be always explicitly inserted into container can_create_default: ClassVar[bool] = False diff --git a/dlt/destinations/impl/bigquery/__init__.py b/dlt/destinations/impl/bigquery/__init__.py index 1304bd72bb..6d1491817a 100644 --- a/dlt/destinations/impl/bigquery/__init__.py +++ b/dlt/destinations/impl/bigquery/__init__.py @@ -20,5 +20,6 @@ def capabilities() -> DestinationCapabilitiesContext: caps.max_text_data_type_length = 10 * 1024 * 1024 caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = False + caps.supports_clone_table = True return caps diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 254184b96d..1058b1d2c9 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -30,6 +30,7 @@ from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams @@ -149,28 +150,6 @@ def gen_key_table_clauses( return sql -class BigqueryStagingCopyJob(SqlStagingCopyJob): - @classmethod - def generate_sql( - cls, - table_chain: Sequence[TTableSchema], - sql_client: SqlClientBase[Any], - params: Optional[SqlJobParams] = None, - ) -> List[str]: - sql: List[str] = [] - for table in table_chain: - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) - table_name = sql_client.make_qualified_table_name(table["name"]) - sql.extend( - ( - f"DROP TABLE IF EXISTS {table_name};", - f"CREATE TABLE {table_name} CLONE {staging_table_name};", - ) - ) - return sql - - class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -190,13 +169,6 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] - def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - if self.config.replace_strategy == "staging-optimized": - return [BigqueryStagingCopyJob.from_table_chain(table_chain, self.sql_client)] - return super()._create_replace_followup_jobs(table_chain) - def restore_file_load(self, file_path: str) -> LoadJob: """Returns a completed SqlLoadJob or restored BigQueryLoadJob @@ -280,9 +252,9 @@ def _get_table_update_sql( elif (c := partition_list[0])["data_type"] == "date": sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}" elif (c := partition_list[0])["data_type"] == "timestamp": - sql[0] = ( - f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" - ) + sql[ + 0 + ] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})" # Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp. 
# This is due to the bounds requirement of GENERATE_ARRAY function for partitioning. # The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range. @@ -300,9 +272,7 @@ def _get_table_update_sql( def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) - return ( - f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" - ) + return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: schema_table: TTableSchemaColumns = {} diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index 97836a8ce2..b2e79279d6 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -26,4 +26,5 @@ def capabilities() -> DestinationCapabilitiesContext: # caps.supports_transactions = False caps.alter_add_multi_column = True caps.supports_multiple_statements = False + caps.supports_clone_table = True return caps diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 4141ff33df..384daf82b0 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -28,7 +28,7 @@ from dlt.destinations.impl.databricks import capabilities from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlMergeJob, SqlJobParams from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -195,25 +195,6 @@ def exception(self) -> str: raise NotImplementedError() -class DatabricksStagingCopyJob(SqlStagingCopyJob): - @classmethod - def generate_sql( - cls, - table_chain: Sequence[TTableSchema], - sql_client: SqlClientBase[Any], - params: Optional[SqlJobParams] = None, - ) -> List[str]: - sql: List[str] = [] - for table in table_chain: - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) - table_name = sql_client.make_qualified_table_name(table["name"]) - sql.append(f"DROP TABLE IF EXISTS {table_name};") - # recreate destination table with data cloned from staging table - sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") - return sql - - class DatabricksMergeJob(SqlMergeJob): @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @@ -272,13 +253,6 @@ def _make_add_column_sql( # Override because databricks requires multiple columns in a single ADD COLUMN clause return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] - def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - if self.config.replace_strategy == "staging-optimized": - return [DatabricksStagingCopyJob.from_table_chain(table_chain, self.sql_client)] - return super()._create_replace_followup_jobs(table_chain) - def _get_table_update_sql( self, table_name: str, diff --git a/dlt/destinations/impl/snowflake/__init__.py b/dlt/destinations/impl/snowflake/__init__.py 
index d6bebd3fdd..dde4d5a382 100644 --- a/dlt/destinations/impl/snowflake/__init__.py +++ b/dlt/destinations/impl/snowflake/__init__.py @@ -21,4 +21,5 @@ def capabilities() -> DestinationCapabilitiesContext: caps.is_max_text_data_type_length_in_bytes = True caps.supports_ddl_transactions = True caps.alter_add_multi_column = True + caps.supports_clone_table = True return caps diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 67df78c138..fb51ab9d36 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -27,7 +27,7 @@ from dlt.destinations.impl.snowflake import capabilities from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams +from dlt.destinations.sql_jobs import SqlJobParams from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase @@ -175,13 +175,15 @@ def __init__( f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,' " AUTO_COMPRESS = FALSE" ) - client.execute_sql(f"""COPY INTO {qualified_table_name} + client.execute_sql( + f"""COPY INTO {qualified_table_name} {from_clause} {files_clause} {credentials_clause} FILE_FORMAT = {source_format} MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE' - """) + """ + ) if stage_file_path and not keep_staged_files: client.execute_sql(f"REMOVE {stage_file_path}") @@ -192,25 +194,6 @@ def exception(self) -> str: raise NotImplementedError() -class SnowflakeStagingCopyJob(SqlStagingCopyJob): - @classmethod - def generate_sql( - cls, - table_chain: Sequence[TTableSchema], - sql_client: SqlClientBase[Any], - params: Optional[SqlJobParams] = None, - ) -> List[str]: - sql: List[str] = [] - for table in table_chain: - with sql_client.with_staging_dataset(staging=True): - staging_table_name = sql_client.make_qualified_table_name(table["name"]) - table_name = sql_client.make_qualified_table_name(table["name"]) - sql.append(f"DROP TABLE IF EXISTS {table_name};") - # recreate destination table with data cloned from staging table - sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") - return sql - - class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -250,13 +233,6 @@ def _make_add_column_sql( + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) ] - def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] - ) -> List[NewLoadJob]: - if self.config.replace_strategy == "staging-optimized": - return [SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)] - return super()._create_replace_followup_jobs(table_chain) - def _get_table_update_sql( self, table_name: str, diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 899947313d..d0911d0bea 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -74,11 +74,28 @@ class SqlStagingCopyJob(SqlBaseJob): failed_text: str = "Tried to generate a staging copy sql job for the following tables:" @classmethod - def generate_sql( + def _generate_clone_sql( cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], - params: Optional[SqlJobParams] = None, + ) -> List[str]: + 
"""Drop and clone the table for supported destinations""" + sql: List[str] = [] + for table in table_chain: + with sql_client.with_staging_dataset(staging=True): + staging_table_name = sql_client.make_qualified_table_name(table["name"]) + table_name = sql_client.make_qualified_table_name(table["name"]) + sql.append(f"DROP TABLE IF EXISTS {table_name};") + # recreate destination table with data cloned from staging table + sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") + return sql + + @classmethod + def _generate_insert_sql( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: SqlJobParams = None, ) -> List[str]: sql: List[str] = [] for table in table_chain: @@ -98,6 +115,17 @@ def generate_sql( ) return sql + @classmethod + def generate_sql( + cls, + table_chain: Sequence[TTableSchema], + sql_client: SqlClientBase[Any], + params: SqlJobParams = None, + ) -> List[str]: + if params["replace"] and sql_client.capabilities.supports_clone_table: + return cls._generate_clone_sql(table_chain, sql_client) + return cls._generate_insert_sql(table_chain, sql_client, params) + class SqlMergeJob(SqlBaseJob): """Generates a list of sql statements that merge the data from staging dataset into destination dataset.""" From 97f66e28681e53fbfa8fac060f4f85d5cf05b82d Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 27 Jan 2024 10:25:19 +0100 Subject: [PATCH 58/84] ensure test data gets removed --- tests/load/synapse/test_synapse_table_indexing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/load/synapse/test_synapse_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py index af4786af9f..e87b83fa3f 100644 --- a/tests/load/synapse/test_synapse_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -12,6 +12,9 @@ from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType from tests.load.utils import TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES +from tests.load.pipeline.utils import ( + drop_pipeline, +) # this import ensures all test data gets removed from tests.load.synapse.utils import get_storage_table_index_type From 90685e7105c3b3f7c2a5981359fd6453cdafc721 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sat, 27 Jan 2024 11:30:00 +0100 Subject: [PATCH 59/84] add pyarrow to synapse dependencies for parquet loading --- poetry.lock | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 400bcb61e2..6b5625e10a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8466,10 +8466,10 @@ qdrant = ["qdrant-client"] redshift = ["psycopg2-binary", "psycopg2cffi"] s3 = ["botocore", "s3fs"] snowflake = ["snowflake-connector-python"] -synapse = ["adlfs", "pyodbc"] +synapse = ["adlfs", "pyarrow", "pyodbc"] weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "75a5f533e9456898ad0157b699d76d9c5a1abf8f4cd04ed7be2235ae3198e16c" +content-hash = "61fa24ff52200b5bf97906a376826f00350abc8f6810fb2fcea73abaf245437f" diff --git a/pyproject.toml b/pyproject.toml index f6ae77b593..fab301ad02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ cli = ["pipdeptree", "cron-descriptor"] athena = ["pyathena", "pyarrow", "s3fs", "botocore"] weaviate = ["weaviate-client"] mssql = ["pyodbc"] -synapse = ["pyodbc", "adlfs"] +synapse = ["pyodbc", "adlfs", "pyarrow"] qdrant = ["qdrant-client"] [tool.poetry.scripts] From 494e45b7b15ca041bdce15e66fcd52df59096b4d Mon Sep 17 00:00:00 2001 
From: Jorrit Sandbrink Date: Sun, 28 Jan 2024 00:35:10 +0100 Subject: [PATCH 60/84] added user docs for synapse destination --- dlt/destinations/impl/synapse/README.md | 58 ----- .../docs/dlt-ecosystem/destinations/mssql.md | 14 +- .../dlt-ecosystem/destinations/synapse.md | 208 ++++++++++++++++++ 3 files changed, 214 insertions(+), 66 deletions(-) delete mode 100644 dlt/destinations/impl/synapse/README.md create mode 100644 docs/website/docs/dlt-ecosystem/destinations/synapse.md diff --git a/dlt/destinations/impl/synapse/README.md b/dlt/destinations/impl/synapse/README.md deleted file mode 100644 index b133faf67a..0000000000 --- a/dlt/destinations/impl/synapse/README.md +++ /dev/null @@ -1,58 +0,0 @@ -# Set up loader user -Execute the following SQL statements to set up the [loader](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql/data-loading-best-practices#create-a-loading-user) user: -```sql --- on master database - -CREATE LOGIN loader WITH PASSWORD = 'YOUR_LOADER_PASSWORD_HERE'; -``` - -```sql --- on minipool database - -CREATE USER loader FOR LOGIN loader; - --- DDL permissions -GRANT CREATE TABLE ON DATABASE :: minipool TO loader; -GRANT CREATE VIEW ON DATABASE :: minipool TO loader; - --- DML permissions -GRANT SELECT ON DATABASE :: minipool TO loader; -GRANT INSERT ON DATABASE :: minipool TO loader; -GRANT ADMINISTER DATABASE BULK OPERATIONS TO loader; -``` - -```sql --- https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-workload-isolation - -CREATE WORKLOAD GROUP DataLoads -WITH ( - MIN_PERCENTAGE_RESOURCE = 0 - ,CAP_PERCENTAGE_RESOURCE = 50 - ,REQUEST_MIN_RESOURCE_GRANT_PERCENT = 25 -); - -CREATE WORKLOAD CLASSIFIER [wgcELTLogin] -WITH ( - WORKLOAD_GROUP = 'DataLoads' - ,MEMBERNAME = 'loader' -); -``` - -# config.toml -```toml -[destination.synapse.credentials] -database = "minipool" -username = "loader" -host = "dlt-synapse-ci.sql.azuresynapse.net" -port = 1433 -driver = "ODBC Driver 18 for SQL Server" - -[destination.synapse] -create_indexes = false -``` - -# secrets.toml -```toml -[destination.synapse.credentials] -password = "YOUR_LOADER_PASSWORD_HERE" -``` \ No newline at end of file diff --git a/docs/website/docs/dlt-ecosystem/destinations/mssql.md b/docs/website/docs/dlt-ecosystem/destinations/mssql.md index d64cf9b400..e98f8bf256 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/mssql.md +++ b/docs/website/docs/dlt-ecosystem/destinations/mssql.md @@ -16,16 +16,14 @@ pip install dlt[mssql] ### Prerequisites -Microsoft ODBC driver for SQL Server must be installed to use this destination. -This can't be included with `dlt`s python dependencies so you must installed it separately on your system. +_Microsoft ODBC Driver for SQL Server_ must be installed to use this destination. +This can't be included with `dlt`'s python dependencies, so you must install it separately on your system. You can find the official installation instructions [here](https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16). 
-See instructions here to [install Microsoft ODBC Driver 18 for SQL Server on Windows, Mac and Linux](https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16) +Supported driver versions: +* `ODBC Driver 18 for SQL Server` +* `ODBC Driver 17 for SQL Server` -Following ODBC drivers are supported: -* ODBC Driver 18 for SQL Server -* ODBC Driver 17 for SQL Server - -[You can configure driver name explicitly](#additional-destination-options) as well. +You can [configure driver name](#additional-destination-options) explicitly as well. ### Create a pipeline diff --git a/docs/website/docs/dlt-ecosystem/destinations/synapse.md b/docs/website/docs/dlt-ecosystem/destinations/synapse.md new file mode 100644 index 0000000000..4d66714ce3 --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/destinations/synapse.md @@ -0,0 +1,208 @@ +--- +title: Azure Synapse +description: Azure Synapse `dlt` destination +keywords: [synapse, destination, data warehouse] +--- + +# Synapse + +## Install dlt with Synapse +**To install the DLT library with Synapse dependencies:** +``` +pip install dlt[synapse] +``` + +## Setup guide + +### Prerequisites + +* **Microsoft ODBC Driver for SQL Server** + + _Microsoft ODBC Driver for SQL Server_ must be installed to use this destination. + This can't be included with `dlt`'s python dependencies, so you must install it separately on your system. You can find the official installation instructions [here](https://learn.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server?view=sql-server-ver16). + + Supported driver versions: + * `ODBC Driver 18 for SQL Server` + + > 💡 Older driver versions don't properly work, because they don't support the `LongAsMax` keyword that got [introduced](https://learn.microsoft.com/en-us/sql/connect/odbc/windows/features-of-the-microsoft-odbc-driver-for-sql-server-on-windows?view=sql-server-ver15#microsoft-odbc-driver-180-for-sql-server-on-windows) in `ODBC Driver 18 for SQL Server`. Synapse does not support the legacy ["long data types"](https://learn.microsoft.com/en-us/sql/t-sql/data-types/ntext-text-and-image-transact-sql), and requires "max data types" instead. `dlt` uses the `LongAsMax` keyword to automatically do the conversion. +* **Azure Synapse Workspace and dedicated SQL pool** + + You need an Azure Synapse workspace with a dedicated SQL pool to load data into. If you don't have one yet, you can use this [quickstart](https://learn.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-sql-pool-studio). + +### Steps + +**1. Initialize a project with a pipeline that loads to Synapse by running** +``` +dlt init chess synapse +``` + +**2. Install the necessary dependencies for Synapse by running** +``` +pip install -r requirements.txt +``` +This will install `dlt` with the **synapse** extra that contains all dependencies required for the Synapse destination. + +**3. Create a loader user** + +Execute the following SQL statements to set up the [loader](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql/data-loading-best-practices#create-a-loading-user) user. 
Change the password and replace `yourpool` with the name of your dedicated SQL pool: +```sql +-- on master database, using a SQL admin account + +CREATE LOGIN loader WITH PASSWORD = 'your_loader_password'; +``` + +```sql +-- on yourpool database + +CREATE USER loader FOR LOGIN loader; + +-- DDL permissions +GRANT CREATE SCHEMA ON DATABASE :: yourpool TO loader; +GRANT CREATE TABLE ON DATABASE :: yourpool TO loader; +GRANT CREATE VIEW ON DATABASE :: yourpool TO loader; + +-- DML permissions +GRANT ADMINISTER DATABASE BULK OPERATIONS TO loader; -- only required when loading from staging Storage Account +``` + +Optionally, you can create a `WORKLOAD GROUP` and add the `loader` user as a member to manage [workload isolation](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-workload-isolation). See the [instructions](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql/data-loading-best-practices#create-a-loading-user) on setting up a loader user for an example of how to do this. + +**3. Enter your credentials into `.dlt/secrets.toml`.** + +Example, replace with your database connection info: +```toml +[destination.synapse.credentials] +database = "yourpool" +username = "loader" +password = "your_loader_password" +host = "your_synapse_workspace_name.sql.azuresynapse.net" +``` + +Equivalently, you can also pass a connection string as follows: + +```toml +# keep it at the top of your toml file! before any section starts +destination.synapse.credentials = "synapse://loader:your_loader_password@your_synapse_workspace_name.azuresynapse.net/yourpool" +``` + +To pass credentials directly you can use the `credentials` argument of `dlt.destinations.synapse(...)`: +```python +pipeline = dlt.pipeline( + pipeline_name='chess', + destination=dlt.destinations.synapse( + credentials='synapse://loader:your_loader_password@your_synapse_workspace_name.azuresynapse.net/yourpool' + ), + dataset_name='chess_data' +) +``` + +## Write disposition +All write dispositions are supported + +If you set the [`replace` strategy](../../general-usage/full-loading.md) to `staging-optimized`, the destination tables will be dropped and replaced by the staging tables with an `ALTER SCHEMA ... TRANSFER` command. Please note that this operation is **not** atomic—it involves multiple DDL commands and Synapse does not support DDL transactions. + +## Data loading +Data is loaded via `INSERT` statements by default. + +> 💡 Multi-row `INSERT INTO ... VALUES` statements are **not** possible in Synapse, because it doesn't support the [Table Value Constructor](https://learn.microsoft.com/en-us/sql/t-sql/queries/table-value-constructor-transact-sql). `dlt` uses `INSERT INTO ... SELECT ... UNION` statements as described [here](https://stackoverflow.com/a/73579830) to work around this limitation. + +## Supported file formats +* [insert-values](../file-formats/insert-format.md) is used by default +* [parquet](../file-formats/parquet.md) is used when [staging](#staging-support) is enabled + +## Data type limitations +* **Synapse cannot load `TIME` columns from `parquet` files**. `dlt` will fail such jobs permanently. Use the `insert_values` file format instead, or convert `datetime.time` objects to `str` or `datetime.datetime`, to load `TIME` columns. +* **Synapse does not have a complex/JSON/struct data type**. The `dlt` `complex` data type is mapped to the `nvarchar` type in Synapse. 
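For the `datetime.time` workaround mentioned above, a minimal sketch (the `shifts` resource and its columns are made up for illustration) casts `time` values to ISO strings with `add_map` before extraction; this is only needed when the load goes through `parquet` staging:

```python
import datetime as dt

import dlt


@dlt.resource(name="shifts")  # illustrative resource, not part of dlt
def shifts():
    yield {"shift_id": 1, "starts_at": dt.time(9, 30)}
    yield {"shift_id": 2, "starts_at": dt.time(17, 0)}


def time_to_str(row):
    # cast datetime.time values to ISO strings so the staged parquet file
    # never contains a TIME column that Synapse cannot load
    for key, value in row.items():
        if isinstance(value, dt.time):
            row[key] = value.isoformat()
    return row


pipeline = dlt.pipeline(
    pipeline_name="chess", destination="synapse", dataset_name="chess_data"
)
# add_map transforms every yielded item before it is written to the load package
info = pipeline.run(shifts().add_map(time_to_str))
```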
+ +## Table index type +The [table index type](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index) of the created tables can be configured at the resource level with the `synapse_adapter`: + +```python +info = pipeline.run( + synapse_adapter( + data=your_resource, + table_index_type="clustered_columnstore_index", + ) +) +``` + +Possible values: +* `heap`: create [HEAP](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables) tables that do not have an index **(default)** +* `clustered_columnstore_index`: create [CLUSTERED COLUMNSTORE INDEX](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#clustered-columnstore-indexes) tables + + +> ❗ Important: +>* **Set `default_table_index_type` to `"clustered_columnstore_index"` if you want to change the default** (see [additional destination options](#additional-destination-options)). +>* **CLUSTERED COLUMNSTORE INDEX tables do not support the `varchar(max)`, `nvarchar(max)`, and `varbinary(max)` data types.** If you don't specify the `precision` for columns that map to any of these types, `dlt` will use the maximum lengths `varchar(4000)`, `nvarchar(4000)`, and `varbinary(8000)`. +>* **While Synapse creates CLUSTERED COLUMNSTORE INDEXES by default, `dlt` creates HEAP tables by default.** HEAP is a more robust choice, because it supports all data types and doesn't require conversions. +>* **When using the `staging-optimized` [`replace` strategy](../../general-usage/full-loading.md), the staging tables are always created as HEAP tables**—any configuration of the table index types is ignored. The HEAP strategy makes sense + for staging tables for reasons explained [here](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables). +>* **`dlt` system tables are always created as HEAP tables, regardless of any configuration.** This is in line with Microsoft's recommendation that "for small lookup tables, less than 60 million rows, consider using HEAP or clustered index for faster query performance." +>* Child tables, if any, inherent the table index type of their parent table. + +## Supported column hints + +Synapse supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): + +* `primary_key` - creates a `PRIMARY KEY NONCLUSTERED NOT ENFORCED` constraint on the column +* `unique` - creates a `UNIQUE NOT ENFORCED` constraint on the column + +> ❗ These hints are **disabled by default**. This is because the `PRIMARY KEY` and `UNIQUE` [constraints](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints) are tricky in Synapse: they are **not enforced** and can lead to innacurate results if the user does not ensure all column values are unique. For the column hints to take effect, the `create_indexes` configuration needs to be set to `True`, see [additional destination options](#additional-destination-options). + +## Load concurrency issue +`dlt` uses threading to enable concurrent processing and [parallel loading](../../reference/performance.md#load). Concurrency does not work properly in all cases when using the `staging-optimized` [`replace` strategy](../../general-usage/full-loading.md), because Synapse suspends the CTAS queries that `dlt` uses behind the scenes and gets stuck. 
To prevent this from happening, `dlt` automatically sets the number of load workers to 1 to disable concurrency when replacing data using the `staging-optimized` strategy. Set `auto_disable_concurrency = "false"` if you don't want this to happen (see [additional destination options](#additional-destination-options)) + +## Staging support +Synapse supports Azure Blob Storage (both standard and [ADLS Gen2](https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction)) as a file staging destination. `dlt` first uploads Parquet files to the blob container, and then instructs Synapse to read the Parquet file and load its data into a Synapse table using the [COPY INTO](https://learn.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql) statement. + +Please refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) to learn how to configure credentials for the staging destination. By default, `dlt` will use these credentials for both the write into the blob container, and the read from it to load into Synapse. Managed Identity authentication can be enabled through the `staging_use_msi` option (see [additional destination options](#additional-destination-options)). + +To run Synapse with staging on Azure Blob Storage: + +```python +# Create a dlt pipeline that will load +# chess player data to the snowflake destination +# via staging on Azure Blob Storage +pipeline = dlt.pipeline( + pipeline_name='chess_pipeline', + destination='synapse', + staging='filesystem', # add this to activate the staging location + dataset_name='player_data' +) +``` + +## Additional destination options +The following settings can optionally be configured: +```toml +[destination.synapse] +default_table_index_type = "heap" +create_indexes = "false" +auto_disable_concurrency = "true" +staging_use_msi = "false" + +[destination.synapse.credentials] +port = "1433" +connect_timeout = 15 +``` + +`port` and `connect_timeout` can also be included in the connection string: + +```toml +# keep it at the top of your toml file! before any section starts +destination.synapse.credentials = "synapse://loader:your_loader_password@your_synapse_workspace_name.azuresynapse.net:1433/yourpool?connect_timeout=15" +``` + +Descriptions: +- `default_table_index_type` sets the [table index type](#table-index-type) that is used if no table index type is specified on the resource. +- `create_indexes` determines if `primary_key` and `unique` [column hints](#supported-column-hints) are applied. +- `auto_disable_concurrency` determines if concurrency is automatically disabled in cases where it might cause issues. +- `staging_use_msi` determines if the Managed Identity of the Synapse workspace is used to authorize access to the [staging](#staging-support) Storage Account. Ensure the Managed Identity has the [Storage Blob Data Reader](https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#storage-blob-data-reader) role (or a higher-priviliged role) assigned on the blob container if you set this option to `"true"`. +- `port` used for the ODBC connection. +- `connect_timeout` sets the timeout for the `pyodbc` connection attempt, in seconds. + +### dbt support +Integration with [dbt](../transformations/dbt/dbt.md) is currently not supported. + +### Syncing of `dlt` state +This destination fully supports [dlt state sync](../../general-usage/state#syncing-state-with-destination). 
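The options documented in the TOML sections above can also be passed in code through the destination factory. A hedged sketch with placeholder connection values:

```python
import dlt
from dlt.destinations import synapse

# same options as in the TOML examples above; all values are placeholders
destination = synapse(
    credentials="synapse://loader:your_loader_password@your_synapse_workspace_name.azuresynapse.net/yourpool",
    default_table_index_type="heap",
    create_indexes=False,
    staging_use_msi=False,
)

pipeline = dlt.pipeline(
    pipeline_name="chess",
    destination=destination,
    dataset_name="chess_data",
)
```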
+ From e8c6b1dcf08cfe03c469526c5f98c8d1159ad539 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Sun, 28 Jan 2024 13:22:44 +0100 Subject: [PATCH 61/84] refactor dbt test skipping to prevent unnecessary venv creation --- .github/workflows/test_destination_mssql.yml | 4 ++-- tests/load/pipeline/test_dbt_helper.py | 20 +++++++++----------- tests/load/utils.py | 11 ++++++++--- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test_destination_mssql.yml b/.github/workflows/test_destination_mssql.yml index b8ea1db2d4..d1da25c067 100644 --- a/.github/workflows/test_destination_mssql.yml +++ b/.github/workflows/test_destination_mssql.yml @@ -71,11 +71,11 @@ jobs: run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - run: | - poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os != 'Windows' name: Run tests Linux/MAC - run: | - poetry run pytest tests/load --ignore tests/load/pipeline/test_dbt_helper.py + poetry run pytest tests/load if: runner.os == 'Windows' name: Run tests Windows shell: cmd diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index e919409311..91318d0f34 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -28,7 +28,9 @@ def dbt_venv() -> Iterator[Venv]: @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_dbt=True), + ids=lambda x: x.name, ) def test_run_jaffle_package( destination_config: DestinationTestConfiguration, dbt_venv: Venv @@ -37,8 +39,6 @@ def test_run_jaffle_package( pytest.skip( "dbt-athena requires database to be created and we don't do it in case of Jaffle" ) - if not destination_config.supports_dbt: - pytest.skip("dbt is not supported for this destination configuration") pipeline = destination_config.setup_pipeline("jaffle_jaffle", full_refresh=True) # get runner, pass the env from fixture dbt = dlt.dbt.package(pipeline, "https://github.com/dbt-labs/jaffle_shop.git", venv=dbt_venv) @@ -65,14 +65,13 @@ def test_run_jaffle_package( @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_dbt=True), + ids=lambda x: x.name, ) def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: from docs.examples.chess.chess import chess - if not destination_config.supports_dbt: - pytest.skip("dbt is not supported for this destination configuration") - # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" @@ -117,16 +116,15 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_dbt=True), + ids=lambda x: x.name, ) def test_run_chess_dbt_to_other_dataset( destination_config: DestinationTestConfiguration, dbt_venv: Venv ) -> None: from docs.examples.chess.chess import chess - if not destination_config.supports_dbt: - pytest.skip("dbt is not supported for this destination configuration") - # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" diff --git 
a/tests/load/utils.py b/tests/load/utils.py index 207e32209f..5fb706985d 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -152,6 +152,7 @@ def destinations_configs( subset: Sequence[str] = (), exclude: Sequence[str] = (), file_format: Optional[TLoaderFileFormat] = None, + supports_dbt: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: # sanity check for item in subset: @@ -165,7 +166,7 @@ def destinations_configs( destination_configs += [ DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS - if destination not in ("athena", "synapse") + if destination not in ("athena", "mssql", "synapse") ] destination_configs += [ DestinationTestConfiguration(destination="duckdb", file_format="parquet") @@ -192,9 +193,9 @@ def destinations_configs( extra_info="iceberg", ) ] - # dbt for Synapse has some complications and I couldn't get it to pass all tests. destination_configs += [ - DestinationTestConfiguration(destination="synapse", supports_dbt=False) + DestinationTestConfiguration(destination="mssql", supports_dbt=False), + DestinationTestConfiguration(destination="synapse", supports_dbt=False), ] if default_vector_configs: @@ -347,6 +348,10 @@ def destinations_configs( destination_configs = [ conf for conf in destination_configs if conf.file_format == file_format ] + if supports_dbt is not None: + destination_configs = [ + conf for conf in destination_configs if conf.supports_dbt == supports_dbt + ] # filter out excluded configs destination_configs = [ From e1e9bb38c48df79b8467ef68ae9f93a781c301b1 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 29 Jan 2024 00:27:38 +0100 Subject: [PATCH 62/84] replace CTAS with CREATE TABLE to eliminate concurrency issues --- .../impl/synapse/configuration.py | 49 ------------------- dlt/destinations/impl/synapse/factory.py | 5 +- dlt/destinations/impl/synapse/synapse.py | 20 ++++---- dlt/pipeline/pipeline.py | 4 -- .../load/pipeline/test_replace_disposition.py | 4 -- 5 files changed, 12 insertions(+), 70 deletions(-) diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index cc0e40114b..bb1ba632dc 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -53,60 +53,11 @@ class SynapseClientConfiguration(MsSqlClientConfiguration): create_indexes: bool = False """Whether `primary_key` and `unique` column hints are applied.""" - # Concurrency is disabled by overriding the configured number of workers to 1 at runtime. 
- auto_disable_concurrency: bool = True - """Whether concurrency is automatically disabled in cases where it might cause issues.""" - staging_use_msi: bool = False """Whether the managed identity of the Synapse workspace is used to authorize access to the staging Storage Account.""" __config_gen_annotations__: ClassVar[List[str]] = [ "default_table_index_type", "create_indexes", - "auto_disable_concurrency", "staging_use_msi", ] - - def get_load_workers(self, tables: TSchemaTables, workers: int) -> int: - """Returns the adjusted number of load workers to prevent concurrency issues.""" - - write_dispositions = [get_write_disposition(tables, table_name) for table_name in tables] - n_replace_dispositions = len([d for d in write_dispositions if d == "replace"]) - if ( - n_replace_dispositions > 1 - and self.replace_strategy == "staging-optimized" - and workers > 1 - ): - warning_msg_shared = ( - 'Data is being loaded into Synapse with write disposition "replace"' - ' and replace strategy "staging-optimized", while the number of' - f" load workers ({workers}) > 1. This configuration is problematic" - " in some cases, because Synapse does not always handle concurrency well" - " with the CTAS queries that are used behind the scenes to implement" - ' the "staging-optimized" strategy.' - ) - if self.auto_disable_concurrency: - logger.warning( - warning_msg_shared - + " The number of load workers will be automatically adjusted" - " and set to 1 to eliminate concurrency and prevent potential" - " issues. If you don't want this to happen, set the" - " DESTINATION__SYNAPSE__AUTO_DISABLE_CONCURRENCY environment" - ' variable to "false", or add the following to your config TOML:' - "\n\n[destination.synapse]\nauto_disable_concurrency = false\n" - ) - workers = 1 # adjust workers - else: - logger.warning( - warning_msg_shared - + " If you experience your pipeline gets stuck and doesn't finish," - " try reducing the number of load workers by exporting the LOAD__WORKERS" - " environment variable or by setting it in your config TOML:" - "\n\n[load]\nworkers = 1 # a value of 1 disables all concurrency," - " but perhaps a higher value also works\n\n" - "Alternatively, you can set the DESTINATION__SYNAPSE__AUTO_DISABLE_CONCURRENCY" - ' environment variable to "true", or add the following to your config TOML' - " to automatically disable concurrency where needed:" - "\n\n[destination.synapse]\nauto_disable_concurrency = true\n" - ) - return workers diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index 0ac58001ca..b7eddd6ef7 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -30,7 +30,6 @@ def __init__( credentials: t.Union[SynapseCredentials, t.Dict[str, t.Any], str] = None, default_table_index_type: t.Optional[TTableIndexType] = "heap", create_indexes: bool = False, - auto_disable_concurrency: bool = True, staging_use_msi: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, @@ -45,15 +44,13 @@ def __init__( a connection string in the format `synapse://user:password@host:port/database` default_table_index_type: Maps directly to the default_table_index_type attribute of the SynapseClientConfiguration object. create_indexes: Maps directly to the create_indexes attribute of the SynapseClientConfiguration object. - auto_disable_concurrency: Maps directly to the auto_disable_concurrency attribute of the SynapseClientConfiguration object. 
- auto_disable_concurrency: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object. + staging_use_msi: Maps directly to the staging_use_msi attribute of the SynapseClientConfiguration object. **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, default_table_index_type=default_table_index_type, create_indexes=create_indexes, - auto_disable_concurrency=auto_disable_concurrency, staging_use_msi=staging_use_msi, destination_name=destination_name, environment=environment, diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index eb6eae3f20..268ffad933 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -4,6 +4,8 @@ from textwrap import dedent from urllib.parse import urlparse, urlunparse +from dlt import current + from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( SupportsStagingDestination, @@ -181,15 +183,15 @@ def generate_sql( f" {staging_table_name};" ) # recreate staging table - # In some cases, when multiple instances of this CTAS query are - # executed concurrently, Synapse suspends the queries and hangs. - # This can be prevented by setting the env var LOAD__WORKERS = "1". - sql.append( - f"CREATE TABLE {staging_table_name}" - " WITH ( DISTRIBUTION = ROUND_ROBIN, HEAP )" # distribution must be explicitly specified with CTAS - f" AS SELECT * FROM {table_name}" - " WHERE 1 = 0;" # no data, table structure only - ) + job_client = current.pipeline().destination_client() # type: ignore[operator] + with job_client.with_staging_dataset(): + # get table columns from schema + columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()] + # generate CREATE TABLE statement + create_table_stmt = job_client._get_table_update_sql( + table["name"], columns, generate_alter=False + ) + sql.extend(create_table_stmt) return sql diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 3a0a8f3931..73c8f076d1 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -483,10 +483,6 @@ def load( # make sure that destination is set and client is importable and can be instantiated client, staging_client = self._get_destination_clients(self.default_schema) - # for synapse we might need to adjust the number of load workers - if self.destination.destination_name == "synapse": - workers = client.config.get_load_workers(self.default_schema.tables, workers) # type: ignore[attr-defined] - # create default loader config and the loader load_config = LoaderConfiguration( workers=workers, diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 65d3646f2d..c6db91efff 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -268,10 +268,6 @@ def test_replace_table_clearing( "test_replace_table_clearing", dataset_name="test_replace_table_clearing", full_refresh=True ) - if destination_config.destination == "synapse" and replace_strategy == "staging-optimized": - # this case requires load concurrency to be disabled (else the test gets stuck) - assert pipeline.destination_client().config.auto_disable_concurrency is True # type: ignore[attr-defined] - @dlt.resource(name="main_resource", write_disposition="replace", primary_key="id") def items_with_subitems(): data = { From 99a0718c74dadf51e4ff95db6d02b83cb5d64797 
Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 29 Jan 2024 00:30:30 +0100 Subject: [PATCH 63/84] change test config type to reduce unnecessary tests --- tests/load/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index 5fb706985d..ea4e2916cc 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -264,14 +264,6 @@ def destinations_configs( bucket_url=AZ_BUCKET, extra_info="az-authorization", ), - DestinationTestConfiguration( - destination="synapse", - staging="filesystem", - file_format="parquet", - bucket_url=AZ_BUCKET, - staging_use_msi=True, - extra_info="az-managed-identity", - ), ] if all_staging_configs: @@ -304,6 +296,14 @@ def destinations_configs( bucket_url=GCS_BUCKET, extra_info="gcs-authorization", ), + DestinationTestConfiguration( + destination="synapse", + staging="filesystem", + file_format="parquet", + bucket_url=AZ_BUCKET, + staging_use_msi=True, + extra_info="az-managed-identity", + ), ] # add local filesystem destinations if requested From 6d14d576a1c3c56e2d72646678cf2655eb929f07 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 29 Jan 2024 00:48:34 +0100 Subject: [PATCH 64/84] remove trailing whitespace --- tests/load/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index ea4e2916cc..805925ec6a 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -303,7 +303,7 @@ def destinations_configs( bucket_url=AZ_BUCKET, staging_use_msi=True, extra_info="az-managed-identity", - ), + ), ] # add local filesystem destinations if requested From b87dd1b744bc4c9fe3f2b6ac1cbea08c58296eb5 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 29 Jan 2024 16:08:22 +0100 Subject: [PATCH 65/84] refine staging table indexing --- dlt/destinations/impl/synapse/synapse.py | 32 +++++++++++++++---- .../dlt-ecosystem/destinations/synapse.md | 3 +- .../synapse/test_synapse_table_indexing.py | 11 ++++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 268ffad933..33e6194602 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -70,12 +70,22 @@ def __init__(self, schema: Schema, config: SynapseClientConfiguration) -> None: def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table = self.get_load_table(table_name) + table = self.get_load_table(table_name, staging=self.in_staging_mode) if table is None: table_index_type = self.config.default_table_index_type else: table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) - if table_index_type == "clustered_columnstore_index": + if self.in_staging_mode: + final_table = self.get_load_table(table_name, staging=False) + final_table_index_type = cast( + TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT) + ) + else: + final_table_index_type = table_index_type + if final_table_index_type == "clustered_columnstore_index": + # Even if the staging table has index type "heap", we still adjust + # the column data types to prevent errors when writing into the + # final table that has index type "clustered_columnstore_index". 
new_columns = self._get_columstore_valid_columns(new_columns) _sql_result = SqlJobClientBase._get_table_update_sql( @@ -129,12 +139,20 @@ def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema table = super().get_load_table(table_name, staging) if table is None: return None - if table_name in self.schema.dlt_table_names(): - # dlt tables should always be heap tables, regardless of the user - # configuration. Why? "For small lookup tables, less than 60 million rows, - # consider using HEAP or clustered index for faster query performance." - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables + if staging and self.config.replace_strategy == "insert-from-staging": + # Staging tables should always be heap tables, because "when you are + # temporarily landing data in dedicated SQL pool, you may find that + # using a heap table makes the overall process faster." + # "staging-optimized" is not included, because in that strategy the + # staging table becomes the final table, so we should already create + # it with the desired index type. + table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] + elif table_name in self.schema.dlt_table_names(): + # dlt tables should always be heap tables, because "for small lookup + # tables, less than 60 million rows, consider using HEAP or clustered + # index for faster query performance." table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] + # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables elif table_name in self.schema.data_table_names(): if TABLE_INDEX_TYPE_HINT not in table: # If present in parent table, fetch hint from there. diff --git a/docs/website/docs/dlt-ecosystem/destinations/synapse.md b/docs/website/docs/dlt-ecosystem/destinations/synapse.md index 4d66714ce3..dcfd92b9fb 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/synapse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/synapse.md @@ -135,8 +135,9 @@ Possible values: >* **Set `default_table_index_type` to `"clustered_columnstore_index"` if you want to change the default** (see [additional destination options](#additional-destination-options)). >* **CLUSTERED COLUMNSTORE INDEX tables do not support the `varchar(max)`, `nvarchar(max)`, and `varbinary(max)` data types.** If you don't specify the `precision` for columns that map to any of these types, `dlt` will use the maximum lengths `varchar(4000)`, `nvarchar(4000)`, and `varbinary(8000)`. >* **While Synapse creates CLUSTERED COLUMNSTORE INDEXES by default, `dlt` creates HEAP tables by default.** HEAP is a more robust choice, because it supports all data types and doesn't require conversions. ->* **When using the `staging-optimized` [`replace` strategy](../../general-usage/full-loading.md), the staging tables are always created as HEAP tables**—any configuration of the table index types is ignored. The HEAP strategy makes sense +>* **When using the `insert-from-staging` [`replace` strategy](../../general-usage/full-loading.md), the staging tables are always created as HEAP tables**—any configuration of the table index types is ignored. The HEAP strategy makes sense for staging tables for reasons explained [here](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-index#heap-tables). 
+>* **When using the `staging-optimized` [`replace` strategy](../../general-usage/full-loading.md), the staging tables are already created with the configured table index type**, because the staging table becomes the final table. >* **`dlt` system tables are always created as HEAP tables, regardless of any configuration.** This is in line with Microsoft's recommendation that "for small lookup tables, less than 60 million rows, consider using HEAP or clustered index for faster query performance." >* Child tables, if any, inherent the table index type of their parent table. diff --git a/tests/load/synapse/test_synapse_table_indexing.py b/tests/load/synapse/test_synapse_table_indexing.py index e87b83fa3f..df90933de4 100644 --- a/tests/load/synapse/test_synapse_table_indexing.py +++ b/tests/load/synapse/test_synapse_table_indexing.py @@ -98,13 +98,22 @@ def items_with_table_index_type_specified() -> Iterator[Any]: @pytest.mark.parametrize( "table_index_type,column_schema", TABLE_INDEX_TYPE_COLUMN_SCHEMA_PARAM_GRID ) +@pytest.mark.parametrize( + # Also test staging replace strategies, to make sure the final table index + # type is not affected by staging table index type adjustments. + "replace_strategy", + ["insert-from-staging", "staging-optimized"], +) def test_resource_table_index_type_configuration( table_index_type: TTableIndexType, column_schema: Union[List[TColumnSchema], None], + replace_strategy: str, ) -> None: + os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy + @dlt.resource( name="items_with_table_index_type_specified", - write_disposition="append", + write_disposition="replace", columns=column_schema, ) def items_with_table_index_type_specified() -> Iterator[Any]: From 1c817bddb475385205b3f332364199c76292c2c8 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Tue, 30 Jan 2024 15:09:35 +0100 Subject: [PATCH 66/84] use generic statement to prevent repeating info --- docs/website/docs/general-usage/full-loading.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/website/docs/general-usage/full-loading.md b/docs/website/docs/general-usage/full-loading.md index 92fdf064fd..4651d156f0 100644 --- a/docs/website/docs/general-usage/full-loading.md +++ b/docs/website/docs/general-usage/full-loading.md @@ -67,6 +67,4 @@ opportunities, you should use this strategy. The `staging-optimized` strategy be recreated with a [clone command](https://docs.snowflake.com/en/sql-reference/sql/create-clone) from the staging tables. This is a low cost and fast way to create a second independent table from the data of another. Learn more about [table cloning on snowflake](https://docs.snowflake.com/en/user-guide/object-clone). -For all other destinations the `staging-optimized` will fall back to the behavior of the `insert-from-staging` strategy. - - +For all other [destinations](../dlt-ecosystem/destinations/index.md), please look at their respective documentation pages to see if and how the `staging-optimized` strategy is implemented. If it is not implemented, `dlt` will fall back to the `insert-from-staging` strategy. 
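A short sketch of how the replace strategy discussed above is selected in user code, using the same configuration mechanism as the test; the resource and pipeline names are illustrative, and `staging-optimized` falls back to `insert-from-staging` on destinations that do not implement it:

```python
import os

import dlt

# choose the replace strategy before running the pipeline
os.environ["DESTINATION__REPLACE_STRATEGY"] = "staging-optimized"


@dlt.resource(name="items", write_disposition="replace")
def items():
    yield from ({"id": i, "value": f"row {i}"} for i in range(10))


pipeline = dlt.pipeline(
    pipeline_name="replace_demo", destination="synapse", dataset_name="replace_demo_data"
)
info = pipeline.run(items())
```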
From 30767004f474863143b9c07a8b70e53288df7e2e Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 30 Jan 2024 15:27:59 -0500 Subject: [PATCH 67/84] Personal access token auth only --- .../impl/databricks/configuration.py | 81 +++---------------- .../impl/databricks/databricks.py | 3 - tests/load/databricks/__init__.py | 0 .../test_databricks_configuration.py | 32 ++++++++ 4 files changed, 42 insertions(+), 74 deletions(-) create mode 100644 tests/load/databricks/__init__.py create mode 100644 tests/load/databricks/test_databricks_configuration.py diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index eaf99038a5..924047e30f 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -6,93 +6,35 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" - - @configspec class DatabricksCredentials(CredentialsConfiguration): - catalog: Optional[str] = None + catalog: str = None server_hostname: str = None http_path: str = None access_token: Optional[TSecretStrValue] = None - client_id: Optional[str] = None - client_secret: Optional[TSecretStrValue] = None - session_properties: Optional[Dict[str, Any]] = None + http_headers: Optional[Dict[str, str]] = None + session_configuration: Optional[Dict[str, Any]] = None + """Dict of session parameters that will be passed to `databricks.sql.connect`""" connection_parameters: Optional[Dict[str, Any]] = None - auth_type: Optional[str] = None - - connect_retries: int = 1 - connect_timeout: Optional[int] = None - retry_all: bool = False - - _credentials_provider: Optional[Dict[str, Any]] = None + """Additional keyword arguments that are passed to `databricks.sql.connect`""" + socket_timeout: Optional[int] = 180 __config_gen_annotations__: ClassVar[List[str]] = [ "server_hostname", "http_path", "catalog", - "schema", + "access_token", ] - def __post_init__(self) -> None: - session_properties = self.session_properties or {} - if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties: - if self.catalog is None: - self.catalog = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] - del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES] - else: - raise ConfigurationValueError( - f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` " - 'in session_properties) all map to "catalog"' - ) - self.session_properties = session_properties - - if self.catalog is not None: - catalog = self.catalog.strip() - if not catalog: - raise ConfigurationValueError(f"Invalid catalog name : `{self.catalog}`.") - self.catalog = catalog - else: - self.catalog = "hive_metastore" - - connection_parameters = self.connection_parameters or {} - for key in ( - "server_hostname", - "http_path", - "access_token", - "client_id", - "client_secret", - "session_configuration", - "catalog", - "_user_agent_entry", - ): - if key in connection_parameters: - raise ConfigurationValueError(f"The connection parameter `{key}` is reserved.") - if "http_headers" in connection_parameters: - http_headers = connection_parameters["http_headers"] - if not isinstance(http_headers, dict) or any( - not isinstance(key, str) or not isinstance(value, str) - for key, value in http_headers.items() - ): - raise ConfigurationValueError( - "The connection parameter `http_headers` should be dict of strings: " - f"{http_headers}." 
- ) - if "_socket_timeout" not in connection_parameters: - connection_parameters["_socket_timeout"] = 180 - self.connection_parameters = connection_parameters - def to_connector_params(self) -> Dict[str, Any]: return dict( catalog=self.catalog, server_hostname=self.server_hostname, http_path=self.http_path, access_token=self.access_token, - client_id=self.client_id, - client_secret=self.client_secret, - session_properties=self.session_properties or {}, - connection_parameters=self.connection_parameters or {}, - auth_type=self.auth_type, + session_configuration=self.session_configuration or {}, + _socket_timeout=self.socket_timeout, + **(self.connection_parameters or {}), ) @@ -101,9 +43,6 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration destination_type: Final[str] = "databricks" # type: ignore[misc] credentials: DatabricksCredentials - stage_name: Optional[str] = None - """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" - def __str__(self) -> str: """Return displayable destination location""" if self.staging_config: diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 384daf82b0..6773714d59 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -125,8 +125,6 @@ def __init__( ) from_clause = "" credentials_clause = "" - files_clause = "" - # stage_file_path = "" if bucket_path: bucket_url = urlparse(bucket_path) @@ -182,7 +180,6 @@ def __init__( statement = f"""COPY INTO {qualified_table_name} {from_clause} - {files_clause} {credentials_clause} FILEFORMAT = {source_format} """ diff --git a/tests/load/databricks/__init__.py b/tests/load/databricks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py new file mode 100644 index 0000000000..9127e39be4 --- /dev/null +++ b/tests/load/databricks/test_databricks_configuration.py @@ -0,0 +1,32 @@ +import pytest +import os + +pytest.importorskip("databricks") + + +from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration +from dlt.common.configuration import resolve_configuration +from tests.utils import preserve_environ + + +def test_databricks_credentials_to_connector_params(): + os.environ["CREDENTIALS__SERVER_HOSTNAME"] = "my-databricks.example.com" + os.environ["CREDENTIALS__HTTP_PATH"] = "/sql/1.0/warehouses/asdfe" + os.environ["CREDENTIALS__ACCESS_TOKEN"] = "my-token" + os.environ["CREDENTIALS__CATALOG"] = "my-catalog" + # JSON encoded dict of extra args + os.environ["CREDENTIALS__CONNECTION_PARAMETERS"] = '{"extra_a": "a", "extra_b": "b"}' + + config = resolve_configuration(DatabricksClientConfiguration(dataset_name="my-dataset")) + + credentials = config.credentials + + params = credentials.to_connector_params() + + assert params["server_hostname"] == "my-databricks.example.com" + assert params["http_path"] == "/sql/1.0/warehouses/asdfe" + assert params["access_token"] == "my-token" + assert params["catalog"] == "my-catalog" + assert params["extra_a"] == "a" + assert params["extra_b"] == "b" + assert params["_socket_timeout"] == credentials.socket_timeout From 0f329648fbd0279efc5f9a2aa7342c7b20ca2c0c Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Tue, 30 Jan 2024 15:50:20 -0500 Subject: [PATCH 68/84] stage_name is not relevant --- 
dlt/destinations/impl/databricks/databricks.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 6773714d59..d1b0666f57 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -106,7 +106,6 @@ def __init__( table_name: str, load_id: str, client: DatabricksSqlClient, - stage_name: Optional[str] = None, staging_credentials: Optional[CredentialsConfiguration] = None, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) @@ -129,11 +128,8 @@ def __init__( if bucket_path: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - # referencing an external s3/azure stage does not require explicit credentials - if bucket_scheme in ["s3", "az", "abfs", "gc", "gcs"] and stage_name: - from_clause = f"FROM ('{bucket_path}')" # referencing an staged files via a bucket URL requires explicit AWS credentials - elif ( + if ( bucket_scheme == "s3" and staging_credentials and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) @@ -228,7 +224,6 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table["name"], load_id, self.sql_client, - stage_name=self.config.stage_name, staging_credentials=( self.config.staging_config.credentials if self.config.staging_config else None ), From fa10a05f753f175521aebcdc9202fef9ad33d008 Mon Sep 17 00:00:00 2001 From: adrianbr Date: Wed, 31 Jan 2024 16:43:06 +0100 Subject: [PATCH 69/84] fix typo (#924) --- docs/website/docs/examples/transformers/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/website/docs/examples/transformers/index.md b/docs/website/docs/examples/transformers/index.md index 7ed8fd29c3..860e830aae 100644 --- a/docs/website/docs/examples/transformers/index.md +++ b/docs/website/docs/examples/transformers/index.md @@ -9,7 +9,7 @@ import Header from '../_examples-header.md';
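To illustrate the simplified Databricks credentials above, a sketch that mirrors the new configuration test; every value is a placeholder:

```python
import os

from dlt.common.configuration import resolve_configuration
from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration

# placeholder values, following the test added alongside the refactor
os.environ["CREDENTIALS__SERVER_HOSTNAME"] = "my-databricks.example.com"
os.environ["CREDENTIALS__HTTP_PATH"] = "/sql/1.0/warehouses/asdfe"
os.environ["CREDENTIALS__ACCESS_TOKEN"] = "my-token"
os.environ["CREDENTIALS__CATALOG"] = "my-catalog"

config = resolve_configuration(DatabricksClientConfiguration(dataset_name="my-dataset"))
params = config.credentials.to_connector_params()

# the dict is meant to be passed to databricks.sql.connect(**params);
# the default socket_timeout of 180 seconds ends up in _socket_timeout
assert params["access_token"] == "my-token"
assert params["_socket_timeout"] == 180
```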
From 8b4ad00bdf3efe147cd524b99bdb407fdf2a6bcb Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 1 Feb 2024 01:04:21 +0400 Subject: [PATCH 70/84] test(filesystem): drop folders ignore (#922) --- tests/common/storages/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 3146642536..bd30ad1d6f 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -48,9 +48,6 @@ def assert_sample_files( assert len(all_file_items) >= 10 for item in all_file_items: - # skip pseudo files that look like folders - if item["file_url"].endswith("/"): - continue # only accept file items we know assert item["file_name"] in minimally_expected_file_items From 5d212a9f99cb5e37f6671a692941ca8a361f5ff9 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 31 Jan 2024 20:39:05 -0500 Subject: [PATCH 71/84] databricks jsonl without compression --- dlt/destinations/impl/databricks/__init__.py | 4 ++-- .../impl/databricks/databricks.py | 15 ++++++++++++++- tests/load/utils.py | 19 ++++++++++++++++++- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/dlt/destinations/impl/databricks/__init__.py b/dlt/destinations/impl/databricks/__init__.py index b2e79279d6..f63d294818 100644 --- a/dlt/destinations/impl/databricks/__init__.py +++ b/dlt/destinations/impl/databricks/__init__.py @@ -9,8 +9,8 @@ def capabilities() -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "insert_values" caps.supported_loader_file_formats = ["insert_values"] - caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet"] + caps.preferred_staging_file_format = "jsonl" + caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.escape_identifier = escape_databricks_identifier caps.escape_literal = escape_databricks_literal caps.decimal_precision = (DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index d1b0666f57..b7807f49ee 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -32,6 +32,7 @@ from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper +from dlt import config class DatabricksTypeMapper(TypeMapper): @@ -124,6 +125,7 @@ def __init__( ) from_clause = "" credentials_clause = "" + format_options_clause = "" if bucket_path: bucket_url = urlparse(bucket_path) @@ -138,6 +140,7 @@ def __init__( credentials_clause = f"""WITH(CREDENTIAL( AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', + AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' )) """ @@ -172,12 +175,22 @@ def __init__( ) # decide on source format, stage_file_path will either be a local file or a bucket path - source_format = "PARQUET" # Only parquet is supported + if file_name.endswith(".parquet"): + source_format = "PARQUET" # Only parquet is supported + elif file_name.endswith(".jsonl"): + if not config.get("data_writer.disable_compression"): + raise LoadJobTerminalException( + file_path, + "Databricks loader does not support gzip compressed JSON files. 
Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", + ) + source_format = "JSON" + format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" statement = f"""COPY INTO {qualified_table_name} {from_clause} {credentials_clause} FILEFORMAT = {source_format} + {format_options_clause} """ client.execute_sql(statement) diff --git a/tests/load/utils.py b/tests/load/utils.py index 3587bd9fa5..bc9bd9ddce 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -99,6 +99,7 @@ class DestinationTestConfiguration: supports_merge: bool = True # TODO: take it from client base class force_iceberg: bool = False supports_dbt: bool = True + disable_compression: bool = False @property def name(self) -> str: @@ -121,7 +122,7 @@ def setup(self) -> None: os.environ["DESTINATION__FORCE_ICEBERG"] = str(self.force_iceberg) or "" """For the filesystem destinations we disable compression to make analyzing the result easier""" - if self.destination == "filesystem": + if self.destination == "filesystem" or self.disable_compression: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" def setup_pipeline( @@ -250,6 +251,22 @@ def destinations_configs( bucket_url=AZ_BUCKET, extra_info="az-authorization", ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="jsonl", + bucket_url=AWS_BUCKET, + extra_info="s3-authorization", + disable_compression=True, + ), + DestinationTestConfiguration( + destination="databricks", + staging="filesystem", + file_format="jsonl", + bucket_url=AZ_BUCKET, + extra_info="s3-authorization", + disable_compression=True, + ), DestinationTestConfiguration( destination="databricks", staging="filesystem", From 1e548e07d45064db2f780e3bde79d2173ec4bc13 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Wed, 31 Jan 2024 20:42:19 -0500 Subject: [PATCH 72/84] Update jsonl in docs --- docs/website/docs/dlt-ecosystem/destinations/databricks.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index 4fadea7fc0..18c41d3a8c 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -52,6 +52,7 @@ For more information on staging, see the [staging support](#staging-support) sec ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default +* [jsonl](../file-formats/jsonl.md) supported when staging is enabled. **Note**: Currently loading compressed jsonl files is not supported. 
`data_writer.disable_compression` should be set to `true` in dlt config * [parquet](../file-formats/parquet.md) supported when staging is enabled ## Staging support From bf453dd379c21159858386a6a82339cb730ccb3e Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 1 Feb 2024 12:10:26 +0300 Subject: [PATCH 73/84] Fix a typo in incremental-loading.md (#926) --- docs/website/docs/general-usage/incremental-loading.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index 0dc85332de..09b8ca7a96 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -290,7 +290,7 @@ def stripe(): yield data # create resources for several endpoints on a single decorator function - for endpoint in Endpoints: + for endpoint in endpoints: yield dlt.resource( get_resource, name=endpoint.value, From 5e5ec2b9f9708329ffc803eb929bb386cedc56b9 Mon Sep 17 00:00:00 2001 From: Sultan Iman Date: Thu, 1 Feb 2024 11:38:55 +0100 Subject: [PATCH 74/84] Adjust pipeline explanation --- docs/website/docs/general-usage/pipeline.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/website/docs/general-usage/pipeline.md b/docs/website/docs/general-usage/pipeline.md index 4850027f24..06e32d4e5a 100644 --- a/docs/website/docs/general-usage/pipeline.md +++ b/docs/website/docs/general-usage/pipeline.md @@ -7,9 +7,9 @@ keywords: [pipeline, source, full refresh] # Pipeline A [pipeline](glossary.md#pipeline) is a connection that moves the data from your Python code to a -[destination](glossary.md#destination). Typically, you pass the `dlt` [sources](source.md) or -[resources](resource.md) to the pipeline. You can also pass generators, lists and other iterables to -it. When the pipeline runs, the resources get executed and the data is loaded at destination. +[destination](glossary.md#destination). Pipeline accepts `dlt` [sources](source.md) or +[resources](resource.md) as well as generators, lists and any iterables. +Once pipeline runs, all resources get evaluated and the data is loaded at destination. Example: From 84fa7eded8e6dc73ac7312d764faacb7248f363c Mon Sep 17 00:00:00 2001 From: Sultan Iman Date: Thu, 1 Feb 2024 14:27:54 +0100 Subject: [PATCH 75/84] Add missing articles and more clarifications --- docs/website/docs/general-usage/pipeline.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/website/docs/general-usage/pipeline.md b/docs/website/docs/general-usage/pipeline.md index 06e32d4e5a..095e03e96d 100644 --- a/docs/website/docs/general-usage/pipeline.md +++ b/docs/website/docs/general-usage/pipeline.md @@ -7,9 +7,9 @@ keywords: [pipeline, source, full refresh] # Pipeline A [pipeline](glossary.md#pipeline) is a connection that moves the data from your Python code to a -[destination](glossary.md#destination). Pipeline accepts `dlt` [sources](source.md) or -[resources](resource.md) as well as generators, lists and any iterables. -Once pipeline runs, all resources get evaluated and the data is loaded at destination. +[destination](glossary.md#destination). The pipeline accepts `dlt` [sources](source.md) or +[resources](resource.md) as well as generators, async generators, lists and any iterables. +Once the pipeline runs, all resources get evaluated and the data is loaded at destination. 
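To make the point about generators and iterables concrete, a minimal sketch (names and the DuckDB destination are only illustrative):

```python
import dlt


def players():
    # a plain generator is a valid pipeline input; it is evaluated on run()
    for player_id in range(3):
        yield {"id": player_id, "rating": 1000 + player_id}


pipeline = dlt.pipeline(
    pipeline_name="players_demo", destination="duckdb", dataset_name="players_data"
)
load_info = pipeline.run(players(), table_name="players")
print(load_info)
```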
Example: From e8c08e275994489bdcb9b9f503decb056af05fe2 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 1 Feb 2024 10:40:08 -0500 Subject: [PATCH 76/84] Check and ignore empty json files in load job --- .../impl/databricks/databricks.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index b7807f49ee..a414baa381 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type +from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext @@ -32,6 +32,7 @@ from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper +from dlt.common.storages import FilesystemConfiguration, fsspec_from_config from dlt import config @@ -107,15 +108,16 @@ def __init__( table_name: str, load_id: str, client: DatabricksSqlClient, - staging_credentials: Optional[CredentialsConfiguration] = None, + staging_config: FilesystemConfiguration, ) -> None: file_name = FileStorage.get_file_name_from_file_path(file_path) super().__init__(file_name) + staging_credentials = staging_config.credentials qualified_table_name = client.make_qualified_table_name(table_name) # extract and prepare some vars - bucket_path = ( + bucket_path = orig_bucket_path = ( NewReferenceJob.resolve_reference(file_path) if NewReferenceJob.is_reference_job(file_path) else "" @@ -131,10 +133,8 @@ def __init__( bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme # referencing an staged files via a bucket URL requires explicit AWS credentials - if ( - bucket_scheme == "s3" - and staging_credentials - and isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + if bucket_scheme == "s3" and isinstance( + staging_credentials, AwsCredentialsWithoutDefaults ): s3_creds = staging_credentials.to_session_credentials() credentials_clause = f"""WITH(CREDENTIAL( @@ -145,10 +145,8 @@ def __init__( )) """ from_clause = f"FROM '{bucket_path}'" - elif ( - bucket_scheme in ["az", "abfs"] - and staging_credentials - and isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + elif bucket_scheme in ["az", "abfs"] and isinstance( + staging_credentials, AzureCredentialsWithoutDefaults ): # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" @@ -185,6 +183,11 @@ def __init__( ) source_format = "JSON" format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" + # Databricks fails when trying to load empty json files, so we have to check the file size + fs, _ = fsspec_from_config(staging_config) + file_size = fs.size(orig_bucket_path) + if file_size == 0: # Empty file, do nothing + return statement = f"""COPY INTO {qualified_table_name} {from_clause} @@ -237,9 +240,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> table["name"], load_id, self.sql_client, - staging_credentials=( - self.config.staging_config.credentials if self.config.staging_config else None - ), + staging_config=cast(FilesystemConfiguration, 
self.config.staging_config), ) return job From 2dd979eb8d6a79334c5e624c4d68a73e86dd5d5d Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 1 Feb 2024 20:13:09 +0100 Subject: [PATCH 77/84] remove outdated documentation --- docs/website/docs/dlt-ecosystem/destinations/synapse.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/synapse.md b/docs/website/docs/dlt-ecosystem/destinations/synapse.md index dcfd92b9fb..8c1a7b29bc 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/synapse.md +++ b/docs/website/docs/dlt-ecosystem/destinations/synapse.md @@ -150,9 +150,6 @@ Synapse supports the following [column hints](https://dlthub.com/docs/general-us > ❗ These hints are **disabled by default**. This is because the `PRIMARY KEY` and `UNIQUE` [constraints](https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-table-constraints) are tricky in Synapse: they are **not enforced** and can lead to innacurate results if the user does not ensure all column values are unique. For the column hints to take effect, the `create_indexes` configuration needs to be set to `True`, see [additional destination options](#additional-destination-options). -## Load concurrency issue -`dlt` uses threading to enable concurrent processing and [parallel loading](../../reference/performance.md#load). Concurrency does not work properly in all cases when using the `staging-optimized` [`replace` strategy](../../general-usage/full-loading.md), because Synapse suspends the CTAS queries that `dlt` uses behind the scenes and gets stuck. To prevent this from happening, `dlt` automatically sets the number of load workers to 1 to disable concurrency when replacing data using the `staging-optimized` strategy. Set `auto_disable_concurrency = "false"` if you don't want this to happen (see [additional destination options](#additional-destination-options)) - ## Staging support Synapse supports Azure Blob Storage (both standard and [ADLS Gen2](https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction)) as a file staging destination. `dlt` first uploads Parquet files to the blob container, and then instructs Synapse to read the Parquet file and load its data into a Synapse table using the [COPY INTO](https://learn.microsoft.com/en-us/sql/t-sql/statements/copy-into-transact-sql) statement. @@ -178,7 +175,6 @@ The following settings can optionally be configured: [destination.synapse] default_table_index_type = "heap" create_indexes = "false" -auto_disable_concurrency = "true" staging_use_msi = "false" [destination.synapse.credentials] @@ -196,7 +192,6 @@ destination.synapse.credentials = "synapse://loader:your_loader_password@your_sy Descriptions: - `default_table_index_type` sets the [table index type](#table-index-type) that is used if no table index type is specified on the resource. - `create_indexes` determines if `primary_key` and `unique` [column hints](#supported-column-hints) are applied. -- `auto_disable_concurrency` determines if concurrency is automatically disabled in cases where it might cause issues. - `staging_use_msi` determines if the Managed Identity of the Synapse workspace is used to authorize access to the [staging](#staging-support) Storage Account. 
Ensure the Managed Identity has the [Storage Blob Data Reader](https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#storage-blob-data-reader) role (or a higher-priviliged role) assigned on the blob container if you set this option to `"true"`. - `port` used for the ODBC connection. - `connect_timeout` sets the timeout for the `pyodbc` connection attempt, in seconds. From da5cdac7f8cce1c59a36322c2a3bdc8735591f6e Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 1 Feb 2024 20:23:14 +0100 Subject: [PATCH 78/84] add synapse destination to sidebar --- docs/website/sidebars.js | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 2c9b55e6da..f92f43564a 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -87,6 +87,7 @@ const sidebars = { 'dlt-ecosystem/destinations/bigquery', 'dlt-ecosystem/destinations/duckdb', 'dlt-ecosystem/destinations/mssql', + 'dlt-ecosystem/destinations/synapse', 'dlt-ecosystem/destinations/filesystem', 'dlt-ecosystem/destinations/postgres', 'dlt-ecosystem/destinations/redshift', From d7d9e35cf49691b2e877d9e4b905b9db3e77de67 Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Thu, 1 Feb 2024 23:17:53 +0100 Subject: [PATCH 79/84] add support for additional table hints --- dlt/destinations/impl/synapse/synapse_adapter.py | 6 ++++-- dlt/extract/hints.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/dlt/destinations/impl/synapse/synapse_adapter.py b/dlt/destinations/impl/synapse/synapse_adapter.py index f135dd967a..24932736f9 100644 --- a/dlt/destinations/impl/synapse/synapse_adapter.py +++ b/dlt/destinations/impl/synapse/synapse_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Set, get_args, Final +from typing import Any, Literal, Set, get_args, Final, Dict from dlt.extract import DltResource, resource as make_resource from dlt.extract.typing import TTableHintTemplate @@ -39,6 +39,7 @@ def synapse_adapter(data: Any, table_index_type: TTableIndexType = None) -> DltR """ resource = ensure_resource(data) + additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {} if table_index_type is not None: if table_index_type not in TABLE_INDEX_TYPES: allowed_types = ", ".join(TABLE_INDEX_TYPES) @@ -46,5 +47,6 @@ def synapse_adapter(data: Any, table_index_type: TTableIndexType = None) -> DltR f"Table index type {table_index_type} is invalid. Allowed table index" f" types are: {allowed_types}." ) - resource._hints[TABLE_INDEX_TYPE_HINT] = table_index_type # type: ignore[typeddict-unknown-key] + additional_table_hints[TABLE_INDEX_TYPE_HINT] = table_index_type + resource.apply_hints(additional_table_hints=additional_table_hints) return resource diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index 437dbbc6bd..e483f035fc 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -1,5 +1,5 @@ from copy import copy, deepcopy -from typing import List, TypedDict, cast, Any +from typing import List, TypedDict, cast, Any, Optional, Dict from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION, merge_columns, new_column, new_table from dlt.common.schema.typing import ( @@ -125,6 +125,7 @@ def apply_hints( merge_key: TTableHintTemplate[TColumnNames] = None, incremental: Incremental[Any] = None, schema_contract: TTableHintTemplate[TSchemaContract] = None, + additional_table_hints: Optional[Dict[str, TTableHintTemplate[Any]]] = None, ) -> None: """Creates or modifies existing table schema by setting provided hints. 
Accepts both static and dynamic hints based on data. @@ -208,6 +209,14 @@ def apply_hints( t["incremental"] = None else: t["incremental"] = incremental + if additional_table_hints is not None: + # loop through provided hints and add, overwrite, or remove them + for k, v in additional_table_hints.items(): + if v: + t[k] = v # type: ignore[literal-required] + else: + t.pop(k, None) # type: ignore[misc] + self.set_hints(t) def set_hints(self, hints_template: TResourceHints) -> None: From 2d50cc7a63584ce64aa257c817784fac940793ed Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 1 Feb 2024 19:12:18 -0500 Subject: [PATCH 80/84] Cleanup --- dlt/destinations/impl/databricks/databricks.py | 3 --- dlt/destinations/impl/databricks/factory.py | 6 ------ 2 files changed, 9 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index a414baa381..1ea176ef1b 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -247,9 +247,6 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, "completed") - def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return DatabricksMergeJob.from_table_chain(table_chain, self.sql_client) - def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index a22e4514b7..7c6c95137d 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -27,8 +27,6 @@ def client_class(self) -> t.Type["DatabricksClient"]: def __init__( self, credentials: t.Union[DatabricksCredentials, t.Dict[str, t.Any], str] = None, - stage_name: t.Optional[str] = None, - keep_staged_files: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -40,14 +38,10 @@ def __init__( Args: credentials: Credentials to connect to the databricks database. Can be an instance of `DatabricksCredentials` or a connection string in the format `databricks://user:password@host:port/database` - stage_name: Name of the stage to use for staging files. If not provided, the default stage will be used. - keep_staged_files: Should staged files be kept after loading. If False, staged files will be deleted after loading. 
**kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, - stage_name=stage_name, - keep_staged_files=keep_staged_files, destination_name=destination_name, environment=environment, **kwargs, From dc984df86b8c41a3797c171c038c97ea2fa8cbf7 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 1 Feb 2024 20:14:38 -0500 Subject: [PATCH 81/84] Update unsupported data types --- .../impl/databricks/databricks.py | 24 +++++++++++++++++++ .../dlt-ecosystem/destinations/databricks.md | 19 ++++++++++----- tests/load/pipeline/test_stage_loading.py | 3 +++ 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 1ea176ef1b..b5a404302f 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -19,6 +19,7 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat +from dlt.common.schema.utils import table_schema_has_type from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -104,6 +105,7 @@ def from_db_type( class DatabricksLoadJob(LoadJob, FollowupJob): def __init__( self, + table: TTableSchema, file_path: str, table_name: str, load_id: str, @@ -181,6 +183,27 @@ def __init__( file_path, "Databricks loader does not support gzip compressed JSON files. Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) + if table_schema_has_type(table, "decimal"): + raise LoadJobTerminalException( + file_path, + "Databricks loader cannot load DECIMAL type columns from json files. Switch to parquet format to load decimals.", + ) + if table_schema_has_type(table, "binary"): + raise LoadJobTerminalException( + file_path, + "Databricks loader cannot load BINARY type columns from json files. Switch to parquet format to load byte values.", + ) + if table_schema_has_type(table, "complex"): + raise LoadJobTerminalException( + file_path, + "Databricks loader cannot load complex columns (lists and dicts) from json files. Switch to parquet format to load complex types.", + ) + if table_schema_has_type(table, "date"): + raise LoadJobTerminalException( + file_path, + "Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates.", + ) + source_format = "JSON" format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" # Databricks fails when trying to load empty json files, so we have to check the file size @@ -236,6 +259,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if not job: job = DatabricksLoadJob( + table, file_path, table["name"], load_id, diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index 18c41d3a8c..679988f918 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -52,20 +52,25 @@ For more information on staging, see the [staging support](#staging-support) sec ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default -* [jsonl](../file-formats/jsonl.md) supported when staging is enabled. 
**Note**: Currently loading compressed jsonl files is not supported. `data_writer.disable_compression` should be set to `true` in dlt config +* [jsonl](../file-formats/jsonl.md) supported when staging is enabled (see limitations below) * [parquet](../file-formats/parquet.md) supported when staging is enabled +The `jsonl` format has some limitations when used with Databricks: + +1. Compression must be disabled to load jsonl files in databricks. Set `data_writer.disable_compression` to `true` in dlt config when using this format. +2. The following data types are not supported when using `jsonl` format with `databricks`: `decimal`, `complex`, `date`, `binary`. Use `parquet` if your data contains these types. +3. `bigint` data type with precision is not supported with `jsonl` format + + ## Staging support Databricks supports both Amazon S3 and Azure Blob Storage as staging locations. `dlt` will upload files in `parquet` format to the staging location and will instruct Databricks to load data from there. ### Databricks and Amazon S3 -Please refer to the [S3 documentation](./filesystem.md#aws-s3) to learn how to set up your bucket with the bucket_url and credentials. For s3, the dlt Databricks loader will use the AWS credentials provided for s3 to access the s3 bucket if not specified otherwise (see config options below). You can specify your s3 bucket directly in your d +Please refer to the [S3 documentation](./filesystem.md#aws-s3) for details on connecting your s3 bucket with the bucket_url and credentials. -lt configuration: - -To set up Databricks with s3 as a staging destination: +Example to set up Databricks with s3 as a staging destination: ```python import dlt @@ -83,7 +88,9 @@ pipeline = dlt.pipeline( ### Databricks and Azure Blob Storage -Refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) for setting up your container with the bucket_url and credentials. For Azure Blob Storage, Databricks can directly load data from the storage container specified in the configuration: +Refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) for details on connecting your Azure Blob Storage container with the bucket_url and credentials. 
+ +Example to set up Databricks with Azure as a staging destination: ```python # Create a dlt pipeline that will load diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index bba589b444..0c3030ebaf 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -168,6 +168,9 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non ): # Redshift can't load fixed width binary columns from parquet exclude_columns.append("col7_precision") + if destination_config.destination == "databricks" and destination_config.file_format == "jsonl": + exclude_types.extend(["decimal", "binary", "wei", "complex", "date"]) + exclude_columns.append("col1_precision") column_schemas, data_types = table_update_and_row( exclude_types=exclude_types, exclude_columns=exclude_columns From bab216d8e09bc2012cceb419ed48ccfe9ac0d0ef Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 5 Feb 2024 14:45:39 +0100 Subject: [PATCH 82/84] correct content-hash after merge conflict resolution --- poetry.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5ea4d19f2b..915152e0c2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "about-time" @@ -8749,4 +8749,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "61fa24ff52200b5bf97906a376826f00350abc8f6810fb2fcea73abaf245437f" +content-hash = "7b829a75b59316147385e16456395bebf2155e68cdeac3f9fa70523c3c33924a" From c3efe33d8c71469de77c54e6b4ec44758185da2e Mon Sep 17 00:00:00 2001 From: Jorrit Sandbrink Date: Mon, 5 Feb 2024 14:47:14 +0100 Subject: [PATCH 83/84] only remove hint if it is None, not if it is empty --- dlt/extract/hints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index e483f035fc..c1a39041d8 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -212,7 +212,7 @@ def apply_hints( if additional_table_hints is not None: # loop through provided hints and add, overwrite, or remove them for k, v in additional_table_hints.items(): - if v: + if v is not None: t[k] = v # type: ignore[literal-required] else: t.pop(k, None) # type: ignore[misc] From cbd1b84a37fb8e9923fab6e9b2a812b1d5be9373 Mon Sep 17 00:00:00 2001 From: dat-a-man <98139823+dat-a-man@users.noreply.github.com> Date: Mon, 5 Feb 2024 21:36:28 +0530 Subject: [PATCH 84/84] Added documentation to remove columns (#921) * Added documentation to remove columns * Update documentation to remove columns * Update documentation to remove columns * Update documentation to remove columns * Update documentation to remove columns * Updated --- .../customising-pipelines/removing_columns.md | 91 +++++++++++++++++++ docs/website/sidebars.js | 1 + 2 files changed, 92 insertions(+) create mode 100644 docs/website/docs/general-usage/customising-pipelines/removing_columns.md diff --git a/docs/website/docs/general-usage/customising-pipelines/removing_columns.md b/docs/website/docs/general-usage/customising-pipelines/removing_columns.md new file mode 100644 index 0000000000..8493ffaec5 --- /dev/null +++ b/docs/website/docs/general-usage/customising-pipelines/removing_columns.md @@ -0,0 +1,91 @@ +--- +title: Removing 
columns
+description: Removing columns by passing a list of column names
+keywords: [deleting, removing, columns, drop]
+---
+
+# Removing columns
+
+Removing columns before loading data into a database is a reliable method to eliminate sensitive or
+unnecessary fields. For example, in the given scenario, a source is created with a "country_code" column,
+which is then excluded from the database before loading.
+
+Let's create a sample pipeline demonstrating the process of removing a column.
+
+1. Create a source function that creates dummy data as follows:
+
+   ```python
+   import dlt
+
+   # This function creates a dummy data source.
+   @dlt.source
+   def dummy_source():
+       @dlt.resource(write_disposition="replace")
+       def dummy_data():
+           for i in range(3):
+               yield {"id": i, "name": f"Jane Washington {i}", "country_code": 40 + i}
+
+       return dummy_data()
+   ```
+   This function creates three columns `id`, `name` and `country_code`.
+
+1. Next, create a function to filter out columns from the data before loading it into a database as follows:
+
+   ```python
+   from typing import Dict, List, Optional
+
+   def remove_columns(doc: Dict, remove_columns: Optional[List[str]] = None) -> Dict:
+       if remove_columns is None:
+           remove_columns = []
+
+       # Iterating over the list of columns to be removed
+       for column_name in remove_columns:
+           # Removing the column if it exists in the document
+           if column_name in doc:
+               del doc[column_name]
+
+       return doc
+   ```
+
+   `doc`: The document (dict) from which columns will be removed.
+
+   `remove_columns`: List of column names to be removed, defaults to None.
+
+1. Next, declare the columns to be removed from the table, and then modify the source as follows:
+
+   ```python
+   # Example columns to remove:
+   remove_columns_list = ["country_code"]
+
+   # Create an instance of the source so you can edit it.
+   data_source = dummy_source()
+
+   # Modify this source instance's resource
+   data_source = data_source.dummy_data.add_map(
+       lambda doc: remove_columns(doc, remove_columns_list)
+   )
+   ```
+1. You can optionally inspect the result:
+
+   ```python
+   for row in data_source:
+       print(row)
+   #{'id': 0, 'name': 'Jane Washington 0'}
+   #{'id': 1, 'name': 'Jane Washington 1'}
+   #{'id': 2, 'name': 'Jane Washington 2'}
+   ```
+
+1. Finally, create a pipeline:
+
+   ```python
+   # Integrating with a DLT pipeline
+   pipeline = dlt.pipeline(
+       pipeline_name='example',
+       destination='bigquery',
+       dataset_name='filtered_data'
+   )
+   # Run the pipeline with the transformed source
+   load_info = pipeline.run(data_source)
+   print(load_info)
+   ```
+
diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js
index 2e23c5ca45..5e2726f4e6 100644
--- a/docs/website/sidebars.js
+++ b/docs/website/sidebars.js
@@ -220,6 +220,7 @@ const sidebars = {
       items: [
         'general-usage/customising-pipelines/renaming_columns',
         'general-usage/customising-pipelines/pseudonymizing_columns',
+        'general-usage/customising-pipelines/removing_columns',
       ]
     },
     {