Skip to content

Commit

Permalink
fixes typing and azure url generation
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Nov 21, 2024
1 parent 19adf75 commit be7a602
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 55 deletions.
41 changes: 20 additions & 21 deletions dlt/common/configuration/specs/azure_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,36 @@
configspec,
)
from dlt import version
from dlt.common.utils import without_none

_AZURE_STORAGE_EXTRA = f"{version.DLT_PKG_NAME}[az]"


@configspec
class AzureCredentialsWithoutDefaults(CredentialsConfiguration):
class AzureCredentialsBase(CredentialsConfiguration):
    """Common base for Azure credential specs.

    Carries the storage account name and optional alternative endpoint host,
    and derives object_store (Rust crate) credentials from the adlfs-style
    dict produced by subclasses.
    """

    # name of the Azure storage account; required, resolved via configspec
    azure_storage_account_name: str = None
    azure_account_host: Optional[str] = None
    """Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

    def to_adlfs_credentials(self) -> Dict[str, Any]:
        """Return a kwargs dict for adlfs. Stub here; subclasses provide the
        actual mapping (base implementation returns None)."""
        pass

    def to_object_store_rs_credentials(self) -> Dict[str, str]:
        """Convert adlfs credentials into options for the Rust object_store crate.

        Drops None-valued entries and the boolean ``anon`` flag, since
        object_store accepts string options only.
        """
        # https://docs.rs/object_store/latest/object_store/azure
        creds: Dict[str, Any] = without_none(self.to_adlfs_credentials())  # type: ignore[assignment]
        # only string options accepted
        creds.pop("anon", None)
        return creds


@configspec
class AzureCredentialsWithoutDefaults(AzureCredentialsBase):
"""Credentials for Azure Blob Storage, compatible with adlfs"""

azure_storage_account_name: str = None
azure_storage_account_key: Optional[TSecretStrValue] = None
azure_storage_sas_token: TSecretStrValue = None
azure_sas_token_permissions: str = "racwdl"
"""Permissions to use when generating a SAS token. Ignored when sas token is provided directly"""
azure_account_host: Optional[str] = None
"""Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

def to_adlfs_credentials(self) -> Dict[str, Any]:
"""Return a dict that can be passed as kwargs to adlfs"""
Expand All @@ -34,15 +49,6 @@ def to_adlfs_credentials(self) -> Dict[str, Any]:
account_host=self.azure_account_host,
)

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/azure
creds = self.to_adlfs_credentials()
if creds["sas_token"] is None:
creds.pop("sas_token")
if creds["account_key"] is None:
creds.pop("account_key")
return creds

def create_sas_token(self) -> None:
try:
from azure.storage.blob import generate_account_sas, ResourceTypes
Expand All @@ -66,13 +72,10 @@ def on_partial(self) -> None:


@configspec
class AzureServicePrincipalCredentialsWithoutDefaults(CredentialsConfiguration):
azure_storage_account_name: str = None
class AzureServicePrincipalCredentialsWithoutDefaults(AzureCredentialsBase):
azure_tenant_id: str = None
azure_client_id: str = None
azure_client_secret: TSecretStrValue = None
azure_account_host: Optional[str] = None
"""Alternative host when accessing blob storage endpoint ie. my_account.dfs.core.windows.net"""

def to_adlfs_credentials(self) -> Dict[str, Any]:
return dict(
Expand All @@ -83,10 +86,6 @@ def to_adlfs_credentials(self) -> Dict[str, Any]:
client_secret=self.azure_client_secret,
)

def to_object_store_rs_credentials(self) -> Dict[str, str]:
# https://docs.rs/object_store/latest/object_store/azure
return self.to_adlfs_credentials()


@configspec
class AzureCredentials(AzureCredentialsWithoutDefaults, CredentialsWithDefault):
Expand Down
7 changes: 5 additions & 2 deletions dlt/common/storages/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def ensure_canonical_az_url(
) -> str:
"""Converts any of the forms of azure blob storage into canonical form of {target_scheme}://<container_name>@<storage_account_name>.{account_host}/<path>
`azure_storage_account_name` is optional only if not present in bucket_url, `account_host` assumes "dfs.core.windows.net" by default
`azure_storage_account_name` is optional only if not present in bucket_url, `account_host` assumes "<azure_storage_account_name>.dfs.core.windows.net" by default
"""
parsed_bucket_url = urlparse(bucket_url)
# Converts an az://<container_name>/<path> to abfss://<container_name>@<storage_account_name>.dfs.core.windows.net/<path>
Expand All @@ -80,13 +80,16 @@ def ensure_canonical_az_url(
)

account_host = account_host or f"{storage_account_name}.dfs.core.windows.net"
netloc = (
f"{parsed_bucket_url.netloc}@{account_host}" if parsed_bucket_url.netloc else account_host
)

# as required by databricks
_path = parsed_bucket_url.path
return urlunparse(
parsed_bucket_url._replace(
scheme=target_scheme,
netloc=f"{parsed_bucket_url.netloc}@{account_host}",
netloc=netloc,
path=_path,
)
)
Expand Down
7 changes: 4 additions & 3 deletions dlt/destinations/impl/clickhouse/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ def run(self) -> None:

# Authenticated access.
account_name = self._staging_credentials.azure_storage_account_name
account_host = self._staging_credentials.azure_account_host
storage_account_url = ensure_canonical_az_url(
bucket_path, "https", account_name, account_host
account_host = (
self._staging_credentials.azure_account_host
or f"{account_name}.blob.core.windows.net"
)
storage_account_url = ensure_canonical_az_url("", "https", account_name, account_host)
account_key = self._staging_credentials.azure_storage_account_key

# build table func
Expand Down
26 changes: 23 additions & 3 deletions dlt/destinations/impl/snowflake/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ def gen_copy_sql(
and staging_credentials
and isinstance(staging_credentials, AzureCredentialsWithoutDefaults)
):
# Explicit azure credentials are needed to load from bucket without a named stage
credentials_clause = f"CREDENTIALS=(AZURE_SAS_TOKEN='?{staging_credentials.azure_storage_sas_token}')"
file_url = ensure_canonical_az_url(
file_url = cls.ensure_snowflake_azure_url(
file_url,
"azure",
staging_credentials.azure_storage_account_name,
staging_credentials.azure_account_host,
)
Expand Down Expand Up @@ -201,6 +199,28 @@ def gen_copy_sql(
{on_error_clause}
"""

@staticmethod
def ensure_snowflake_azure_url(
    file_url: str, account_name: str = None, account_host: str = None
) -> str:
    """Convert an Azure bucket url into the ``azure://<host>/<container>/<path>``
    form that Snowflake's COPY INTO expects.

    Args:
        file_url: bucket url in any supported Azure form (az://, abfss://, ...).
        account_name: storage account; used to derive the host when
            ``account_host`` is not given. Optional if already present in the url.
        account_host: explicit endpoint host, e.g. my_account.blob.core.windows.net.

    Returns:
        Canonical Snowflake azure:// url with the container folded into the path.

    NOTE(review): when neither account name nor host can be determined,
    ``ensure_canonical_az_url`` presumably raises TerminalValueError — confirm
    against its implementation.
    """
    # Explicit azure credentials are needed to load from bucket without a named stage
    if not account_host and account_name:
        account_host = f"{account_name}.blob.core.windows.net"
    # get canonical url first to convert it into snowflake form
    canonical_url = ensure_canonical_az_url(
        file_url,
        "azure",
        account_name,
        account_host,
    )
    # canonical form is azure://<container>@<host>/<path>; Snowflake wants the
    # container as the first path segment, so move it (parsed as the userinfo
    # component) into the path and keep only the host as netloc
    parsed_file_url = urlparse(canonical_url)
    return urlunparse(
        parsed_file_url._replace(
            path=f"/{parsed_file_url.username}{parsed_file_url.path}",
            netloc=parsed_file_url.hostname,
        )
    )


class SnowflakeClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
def __init__(
Expand Down
27 changes: 8 additions & 19 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Literal,
Optional,
Iterator,
Protocol,
Union,
)
import operator
Expand Down Expand Up @@ -46,23 +45,11 @@
)

TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"]

SelectClause = Union[SelectAny, TextClause]


class TQueryAdapter(Protocol):
@staticmethod
def __call__(
query: SelectAny,
table: Table,
incremental: Optional[Incremental[Any]] = None,
engine: Optional[Engine] = None,
/,
) -> SelectClause:
pass


# TQueryAdapter = Union[Callable[[SelectAny, Table], SelectAny], Callable[[SelectAny, Table, Incremental], SelectAny]]
TQueryAdapter = Union[
Callable[[SelectAny, Table], SelectClause],
Callable[[SelectAny, Table, Incremental[Any], Engine], SelectClause],
]


class TableLoader:
Expand Down Expand Up @@ -153,12 +140,14 @@ def _make_query(self) -> SelectAny:
def make_query(self) -> SelectClause:
if self.query_adapter_callback:
try:
return self.query_adapter_callback(
return self.query_adapter_callback( # type: ignore[call-arg]
self._make_query(), self.table, self.incremental, self.engine
)
except TypeError:
try:
return self.query_adapter_callback(self._make_query(), self.table)
return self.query_adapter_callback( # type: ignore[call-arg]
self._make_query(), self.table
)
except TypeError:
raise

Expand Down
2 changes: 2 additions & 0 deletions tests/load/filesystem/test_azure_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def test_azure_credentials_from_default(environment: Dict[str, str]) -> None:
"account_key": None,
"sas_token": None,
"anon": False,
"account_host": None,
}


Expand All @@ -143,6 +144,7 @@ def test_azure_service_principal_credentials(environment: Dict[str, str]) -> Non

assert config.to_adlfs_credentials() == {
"account_name": environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"],
"account_host": None,
"client_id": environment["CREDENTIALS__AZURE_CLIENT_ID"],
"client_secret": environment["CREDENTIALS__AZURE_CLIENT_SECRET"],
"tenant_id": environment["CREDENTIALS__AZURE_TENANT_ID"],
Expand Down
26 changes: 26 additions & 0 deletions tests/load/snowflake/test_snowflake_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from urllib3.util import parse_url

from dlt.common.configuration.utils import add_config_to_env
from dlt.common.exceptions import TerminalValueError
from dlt.destinations.impl.snowflake.snowflake import SnowflakeLoadJob
from tests.utils import TEST_DICT_CONFIG_PROVIDER

pytest.importorskip("snowflake")
Expand Down Expand Up @@ -271,3 +273,27 @@ def test_snowflake_configuration() -> None:
explicit_value="snowflake://user1:pass@host1/db1?warehouse=warehouse1&role=role1",
)
assert SnowflakeClientConfiguration(credentials=c).fingerprint() == digest128("host1")


def test_snowflake_azure_converter() -> None:
    """Verify az:// / abfss:// urls are rewritten into Snowflake's azure:// form."""
    convert = SnowflakeLoadJob.ensure_snowflake_azure_url

    # no account name and no host: nothing to build the endpoint host from
    with pytest.raises(TerminalValueError):
        convert("az://dlt-ci-test-bucket")

    # host is derived from the account name, container becomes the first path segment
    assert (
        convert("az://dlt-ci-test-bucket", "my_account")
        == "azure://my_account.blob.core.windows.net/dlt-ci-test-bucket"
    )

    expected_with_path = (
        "azure://my_account.blob.core.windows.net/dlt-ci-test-bucket/path/to/file.parquet"
    )

    # path inside the bucket is preserved after the container segment
    assert (
        convert("az://dlt-ci-test-bucket/path/to/file.parquet", "my_account")
        == expected_with_path
    )

    # an already-canonical abfss url keeps its host; container moves into the path
    assert (
        convert("abfss://dlt-ci-test-bucket@my_account.blob.core.windows.net/path/to/file.parquet")
        == expected_with_path
    )
2 changes: 0 additions & 2 deletions tests/load/sources/sql_database/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pytest


import dlt
from dlt.common.typing import TDataItem


from dlt.common.exceptions import MissingDependencyException

try:
Expand Down
15 changes: 10 additions & 5 deletions tests/load/sources/sql_database/test_sql_database_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def add_new_columns(table: Table) -> None:

last_query: str = None

def query_adapter_callback(
query: SelectAny, table: Table, incremental: Incremental[Any] = None, engine: Engine = None
def query_adapter(
query: SelectAny, table: Table, incremental: Optional[Incremental[Any]], engine: Engine
) -> TextClause:
nonlocal last_query

Expand All @@ -244,7 +244,7 @@ def query_adapter_callback(
reflection_level="full",
backend=backend,
table_adapter_callback=add_new_columns,
query_adapter_callback=query_adapter_callback,
query_adapter_callback=query_adapter,
incremental=dlt.sources.incremental("updated_at"),
)

Expand Down Expand Up @@ -1102,7 +1102,10 @@ def test_sql_table_included_columns(
def test_query_adapter_callback(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool
) -> None:
def query_adapter_callback(query, table):
from dlt.sources.sql_database.helpers import SelectAny
from dlt.common.libs.sql_alchemy import Table

def query_adapter_callback(query: SelectAny, table: Table) -> SelectAny:
if table.name == "chat_channel":
# Only select active channels
return query.where(table.c.active.is_(True))
Expand All @@ -1114,7 +1117,6 @@ def query_adapter_callback(query, table):
schema=sql_source_db.schema,
reflection_level="full",
backend=backend,
query_adapter_callback=query_adapter_callback,
)

if standalone_resource:
Expand All @@ -1124,18 +1126,21 @@ def dummy_source():
yield sql_table(
**common_kwargs, # type: ignore[arg-type]
table="chat_channel",
query_adapter_callback=query_adapter_callback,
)

yield sql_table(
**common_kwargs, # type: ignore[arg-type]
table="chat_message",
query_adapter_callback=query_adapter_callback,
)

source = dummy_source()
else:
source = sql_database(
**common_kwargs, # type: ignore[arg-type]
table_names=["chat_message", "chat_channel"],
query_adapter_callback=query_adapter_callback,
)

pipeline = make_pipeline("duckdb")
Expand Down

0 comments on commit be7a602

Please sign in to comment.