diff --git a/dlt/common/configuration/specs/azure_credentials.py b/dlt/common/configuration/specs/azure_credentials.py index 63f42056fa..cf6ec493de 100644 --- a/dlt/common/configuration/specs/azure_credentials.py +++ b/dlt/common/configuration/specs/azure_credentials.py @@ -9,21 +9,36 @@ configspec, ) from dlt import version +from dlt.common.utils import without_none _AZURE_STORAGE_EXTRA = f"{version.DLT_PKG_NAME}[az]" @configspec -class AzureCredentialsWithoutDefaults(CredentialsConfiguration): +class AzureCredentialsBase(CredentialsConfiguration): + azure_storage_account_name: str = None + azure_account_host: Optional[str] = None + """Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net""" + + def to_adlfs_credentials(self) -> Dict[str, Any]: + pass + + def to_object_store_rs_credentials(self) -> Dict[str, str]: + # https://docs.rs/object_store/latest/object_store/azure + creds: Dict[str, Any] = without_none(self.to_adlfs_credentials()) # type: ignore[assignment] + # only string options accepted + creds.pop("anon", None) + return creds + + +@configspec +class AzureCredentialsWithoutDefaults(AzureCredentialsBase): """Credentials for Azure Blob Storage, compatible with adlfs""" - azure_storage_account_name: str = None azure_storage_account_key: Optional[TSecretStrValue] = None azure_storage_sas_token: TSecretStrValue = None azure_sas_token_permissions: str = "racwdl" """Permissions to use when generating a SAS token. Ignored when sas token is provided directly""" - azure_account_host: Optional[str] = None - """Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net""" def to_adlfs_credentials(self) -> Dict[str, Any]: """Return a dict that can be passed as kwargs to adlfs""" @@ -34,15 +49,6 @@ def to_adlfs_credentials(self) -> Dict[str, Any]: account_host=self.azure_account_host, ) - def to_object_store_rs_credentials(self) -> Dict[str, str]: - # https://docs.rs/object_store/latest/object_store/azure - creds = self.to_adlfs_credentials() - if creds["sas_token"] is None: - creds.pop("sas_token") - if creds["account_key"] is None: - creds.pop("account_key") - return creds - def create_sas_token(self) -> None: try: from azure.storage.blob import generate_account_sas, ResourceTypes @@ -66,13 +72,10 @@ def on_partial(self) -> None: @configspec -class AzureServicePrincipalCredentialsWithoutDefaults(CredentialsConfiguration): - azure_storage_account_name: str = None +class AzureServicePrincipalCredentialsWithoutDefaults(AzureCredentialsBase): azure_tenant_id: str = None azure_client_id: str = None azure_client_secret: TSecretStrValue = None - azure_account_host: Optional[str] = None - """Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net""" def to_adlfs_credentials(self) -> Dict[str, Any]: return dict( @@ -83,10 +86,6 @@ def to_adlfs_credentials(self) -> Dict[str, Any]: client_secret=self.azure_client_secret, ) - def to_object_store_rs_credentials(self) -> Dict[str, str]: - # https://docs.rs/object_store/latest/object_store/azure - return self.to_adlfs_credentials() - @configspec class AzureCredentials(AzureCredentialsWithoutDefaults, CredentialsWithDefault): diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 4220716706..9fb176a42a 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -63,7 +63,7 @@ def ensure_canonical_az_url( ) -> str: """Converts any of the forms of azure blob storage into canonical form of {target_scheme}://@.{account_host}/ - `azure_storage_account_name` is optional only if not present in bucket_url, `account_host` assumes "dfs.core.windows.net" by default + `azure_storage_account_name` is optional only if not present in bucket_url, `account_host` assumes ".dfs.core.windows.net" by default """ parsed_bucket_url = urlparse(bucket_url) # Converts an az:/// to abfss://@.dfs.core.windows.net/ @@ -80,13 +80,16 @@ def ensure_canonical_az_url( ) account_host = account_host or f"{storage_account_name}.dfs.core.windows.net" + netloc = ( + f"{parsed_bucket_url.netloc}@{account_host}" if parsed_bucket_url.netloc else account_host + ) # as required by databricks _path = parsed_bucket_url.path return urlunparse( parsed_bucket_url._replace( scheme=target_scheme, - netloc=f"{parsed_bucket_url.netloc}@{account_host}", + netloc=netloc, path=_path, ) ) diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 6e320dba8a..3a5f5c3e28 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -150,10 +150,11 @@ def run(self) -> None: # Authenticated access. account_name = self._staging_credentials.azure_storage_account_name - account_host = self._staging_credentials.azure_account_host - storage_account_url = ensure_canonical_az_url( - bucket_path, "https", account_name, account_host + account_host = ( + self._staging_credentials.azure_account_host + or f"{account_name}.blob.core.windows.net" ) + storage_account_url = ensure_canonical_az_url("", "https", account_name, account_host) account_key = self._staging_credentials.azure_storage_account_key # build table func diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index b66923002e..d83f4d7369 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -141,11 +141,9 @@ def gen_copy_sql( and staging_credentials and isinstance(staging_credentials, AzureCredentialsWithoutDefaults) ): - # Explicit azure credentials are needed to load from bucket without a named stage credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')" - file_url = ensure_canonical_az_url( + file_url = cls.ensure_snowflake_azure_url( file_url, - "azure", staging_credentials.azure_storage_account_name, staging_credentials.azure_account_host, ) @@ -201,6 +199,28 @@ def gen_copy_sql( {on_error_clause} """ + @staticmethod + def ensure_snowflake_azure_url( + file_url: str, account_name: str = None, account_host: str = None + ) -> str: + # Explicit azure credentials are needed to load from bucket without a named stage + if not account_host and account_name: + account_host = f"{account_name}.blob.core.windows.net" + # get canonical url first to convert it into snowflake form + canonical_url = ensure_canonical_az_url( + file_url, + "azure", + account_name, + account_host, + ) + parsed_file_url = urlparse(canonical_url) + return urlunparse( + parsed_file_url._replace( + path=f"/{parsed_file_url.username}{parsed_file_url.path}", + netloc=parsed_file_url.hostname, + ) + ) + class SnowflakeClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py index 0b6d0777de..072a254bcc 100644 --- a/dlt/sources/sql_database/helpers.py +++ b/dlt/sources/sql_database/helpers.py @@ -9,7 +9,6 @@ Literal, Optional, Iterator, - Protocol, Union, ) import operator @@ -46,23 +45,11 @@ ) TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"] - SelectClause = Union[SelectAny, TextClause] - - -class TQueryAdapter(Protocol): - @staticmethod - def __call__( - query: SelectAny, - table: Table, - incremental: Optional[Incremental[Any]] = None, - engine: Optional[Engine] = None, - /, - ) -> SelectClause: - pass - - -# TQueryAdapter = Union[Callable[[SelectAny, Table], SelectAny], Callable[[SelectAny, Table, Incremental], SelectAny]] +TQueryAdapter = Union[ + Callable[[SelectAny, Table], SelectClause], + Callable[[SelectAny, Table, Incremental[Any], Engine], SelectClause], +] class TableLoader: @@ -153,12 +140,14 @@ def _make_query(self) -> SelectAny: def make_query(self) -> SelectClause: if self.query_adapter_callback: try: - return self.query_adapter_callback( + return self.query_adapter_callback( # type: ignore[call-arg] self._make_query(), self.table, self.incremental, self.engine ) except TypeError: try: - return self.query_adapter_callback(self._make_query(), self.table) + return self.query_adapter_callback( # type: ignore[call-arg] + self._make_query(), self.table + ) except TypeError: raise diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 002e256cff..811eb41f75 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -126,6 +126,7 @@ def test_azure_credentials_from_default(environment: Dict[str, str]) -> None: "account_key": None, "sas_token": None, "anon": False, + "account_host": None, } @@ -143,6 +144,7 @@ def test_azure_service_principal_credentials(environment: Dict[str, str]) -> Non assert config.to_adlfs_credentials() == { "account_name": environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"], + "account_host": None, "client_id": environment["CREDENTIALS__AZURE_CLIENT_ID"], "client_secret": environment["CREDENTIALS__AZURE_CLIENT_SECRET"], "tenant_id": environment["CREDENTIALS__AZURE_TENANT_ID"], diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index 21973025c7..265a6a0935 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -4,6 +4,8 @@ from urllib3.util import parse_url from dlt.common.configuration.utils import add_config_to_env +from dlt.common.exceptions import TerminalValueError +from dlt.destinations.impl.snowflake.snowflake import SnowflakeLoadJob from tests.utils import TEST_DICT_CONFIG_PROVIDER pytest.importorskip("snowflake") @@ -271,3 +273,27 @@ def test_snowflake_configuration() -> None: explicit_value="snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1", ) assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1") + + +def test_snowflake_azure_converter() -> None: + with pytest.raises(TerminalValueError): + SnowflakeLoadJob.ensure_snowflake_azure_url("az://dlt-ci-test-bucket") + + azure_url = SnowflakeLoadJob.ensure_snowflake_azure_url("az://dlt-ci-test-bucket", "my_account") + assert azure_url == "azure://my_account.blob.core.windows.net/dlt-ci-test-bucket" + + azure_url = SnowflakeLoadJob.ensure_snowflake_azure_url( + "az://dlt-ci-test-bucket/path/to/file.parquet", "my_account" + ) + assert ( + azure_url + == "azure://my_account.blob.core.windows.net/dlt-ci-test-bucket/path/to/file.parquet" + ) + + azure_url = SnowflakeLoadJob.ensure_snowflake_azure_url( + "abfss://dlt-ci-test-bucket@my_account.blob.core.windows.net/path/to/file.parquet" + ) + assert ( + azure_url + == "azure://my_account.blob.core.windows.net/dlt-ci-test-bucket/path/to/file.parquet" + ) diff --git a/tests/load/sources/sql_database/test_helpers.py b/tests/load/sources/sql_database/test_helpers.py index 4748f226a9..ca7312e2f6 100644 --- a/tests/load/sources/sql_database/test_helpers.py +++ b/tests/load/sources/sql_database/test_helpers.py @@ -1,10 +1,8 @@ import pytest - import dlt from dlt.common.typing import TDataItem - from dlt.common.exceptions import MissingDependencyException try: diff --git a/tests/load/sources/sql_database/test_sql_database_source.py b/tests/load/sources/sql_database/test_sql_database_source.py index 388b1c7d24..5a311eafa5 100644 --- a/tests/load/sources/sql_database/test_sql_database_source.py +++ b/tests/load/sources/sql_database/test_sql_database_source.py @@ -221,8 +221,8 @@ def add_new_columns(table: Table) -> None: last_query: str = None - def query_adapter_callback( - query: SelectAny, table: Table, incremental: Incremental[Any] = None, engine: Engine = None + def query_adapter( + query: SelectAny, table: Table, incremental: Optional[Incremental[Any]], engine: Engine ) -> TextClause: nonlocal last_query @@ -244,7 +244,7 @@ def query_adapter_callback( reflection_level="full", backend=backend, table_adapter_callback=add_new_columns, - query_adapter_callback=query_adapter_callback, + query_adapter_callback=query_adapter, incremental=dlt.sources.incremental("updated_at"), ) @@ -1102,7 +1102,10 @@ def test_sql_table_included_columns( def test_query_adapter_callback( sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool ) -> None: - def query_adapter_callback(query, table): + from dlt.sources.sql_database.helpers import SelectAny + from dlt.common.libs.sql_alchemy import Table + + def query_adapter_callback(query: SelectAny, table: Table) -> SelectAny: if table.name == "chat_channel": # Only select active channels return query.where(table.c.active.is_(True)) @@ -1114,7 +1117,6 @@ def query_adapter_callback(query, table): schema=sql_source_db.schema, reflection_level="full", backend=backend, - query_adapter_callback=query_adapter_callback, ) if standalone_resource: @@ -1124,11 +1126,13 @@ def dummy_source(): yield sql_table( **common_kwargs, # type: ignore[arg-type] table="chat_channel", + query_adapter_callback=query_adapter_callback, ) yield sql_table( **common_kwargs, # type: ignore[arg-type] table="chat_message", + query_adapter_callback=query_adapter_callback, ) source = dummy_source() @@ -1136,6 +1140,7 @@ def dummy_source(): source = sql_database( **common_kwargs, # type: ignore[arg-type] table_names=["chat_message", "chat_channel"], + query_adapter_callback=query_adapter_callback, ) pipeline = make_pipeline("duckdb")