diff --git a/dlt/common/configuration/specs/__init__.py b/dlt/common/configuration/specs/__init__.py index 9acf14bde3..f1d7d819ff 100644 --- a/dlt/common/configuration/specs/__init__.py +++ b/dlt/common/configuration/specs/__init__.py @@ -20,7 +20,13 @@ from .connection_string_credentials import ConnectionStringCredentials from .api_credentials import OAuth2Credentials from .aws_credentials import AwsCredentials, AwsCredentialsWithoutDefaults -from .azure_credentials import AzureCredentials, AzureCredentialsWithoutDefaults +from .azure_credentials import ( + AzureCredentials, + AzureCredentialsWithoutDefaults, + AzureServicePrincipalCredentials, + AzureServicePrincipalCredentialsWithoutDefaults, + AnyAzureCredentials, +) # backward compatibility for service account credentials @@ -51,6 +57,9 @@ "AwsCredentialsWithoutDefaults", "AzureCredentials", "AzureCredentialsWithoutDefaults", + "AzureServicePrincipalCredentials", + "AzureServicePrincipalCredentialsWithoutDefaults", + "AnyAzureCredentials", "GcpClientCredentials", "GcpClientCredentialsWithDefault", ] diff --git a/dlt/common/configuration/specs/azure_credentials.py b/dlt/common/configuration/specs/azure_credentials.py index 52d33ec0d3..96119206c3 100644 --- a/dlt/common/configuration/specs/azure_credentials.py +++ b/dlt/common/configuration/specs/azure_credentials.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union from dlt.common.pendulum import pendulum from dlt.common.typing import TSecretStrValue @@ -50,6 +50,22 @@ def on_partial(self) -> None: self.resolve() +@configspec +class AzureServicePrincipalCredentialsWithoutDefaults(CredentialsConfiguration): + azure_storage_account_name: str = None + azure_tenant_id: str = None + azure_client_id: str = None + azure_client_secret: TSecretStrValue = None + + def to_adlfs_credentials(self) -> Dict[str, Any]: + return dict( + account_name=self.azure_storage_account_name, + tenant_id=self.azure_tenant_id, + client_id=self.azure_client_id, + client_secret=self.azure_client_secret, + ) + + @configspec class AzureCredentials(AzureCredentialsWithoutDefaults, CredentialsWithDefault): def on_partial(self) -> None: @@ -67,3 +83,31 @@ def to_adlfs_credentials(self) -> Dict[str, Any]: if self.has_default_credentials(): base_kwargs["anon"] = False return base_kwargs + + +@configspec +class AzureServicePrincipalCredentials( + AzureServicePrincipalCredentialsWithoutDefaults, CredentialsWithDefault +): + def on_partial(self) -> None: + from azure.identity import DefaultAzureCredential + + self._set_default_credentials(DefaultAzureCredential()) + if self.azure_storage_account_name: + self.resolve() + + def to_adlfs_credentials(self) -> Dict[str, Any]: + base_kwargs = super().to_adlfs_credentials() + if self.has_default_credentials(): + base_kwargs["anon"] = False + return base_kwargs + + +AnyAzureCredentials = Union[ + # Credentials without defaults come first because union types are attempted in order + # and explicit config should supersede system defaults + AzureCredentialsWithoutDefaults, + AzureServicePrincipalCredentialsWithoutDefaults, + AzureCredentials, + AzureServicePrincipalCredentials, +] diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index a1838fab6e..6e100536af 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -10,8 +10,7 @@ GcpServiceAccountCredentials, AwsCredentials, GcpOAuthCredentials, - AzureCredentials, - AzureCredentialsWithoutDefaults, + AnyAzureCredentials, BaseConfiguration, ) from dlt.common.typing import DictStrAny @@ -49,7 +48,7 @@ class LoadStorageConfiguration(BaseConfiguration): FileSystemCredentials = Union[ - AwsCredentials, GcpServiceAccountCredentials, AzureCredentials, GcpOAuthCredentials + AwsCredentials, GcpServiceAccountCredentials, AnyAzureCredentials, GcpOAuthCredentials ] @@ -70,9 +69,9 @@ class FilesystemConfiguration(BaseConfiguration): "gcs": Union[GcpServiceAccountCredentials, GcpOAuthCredentials], "gdrive": Union[GcpServiceAccountCredentials, GcpOAuthCredentials], "s3": AwsCredentials, - "az": Union[AzureCredentialsWithoutDefaults, AzureCredentials], - "abfs": Union[AzureCredentialsWithoutDefaults, AzureCredentials], - "adl": Union[AzureCredentialsWithoutDefaults, AzureCredentials], + "az": AnyAzureCredentials, + "abfs": AnyAzureCredentials, + "adl": AnyAzureCredentials, } bucket_url: str = None diff --git a/tests/load/filesystem/test_azure_credentials.py b/tests/load/filesystem/test_azure_credentials.py index 467ba55a4f..194d6db33d 100644 --- a/tests/load/filesystem/test_azure_credentials.py +++ b/tests/load/filesystem/test_azure_credentials.py @@ -6,7 +6,12 @@ from dlt.common import pendulum from dlt.common.time import ensure_pendulum_datetime from dlt.common.configuration import resolve_configuration, ConfigFieldMissingException -from dlt.common.configuration.specs import AzureCredentials +from dlt.common.configuration.specs import ( + AzureCredentials, + AzureServicePrincipalCredentials, + AzureServicePrincipalCredentialsWithoutDefaults, + AzureCredentialsWithoutDefaults, +) from tests.load.utils import ALL_FILESYSTEM_DRIVERS from tests.common.configuration.utils import environment from tests.utils import preserve_environ, autouse_test_storage @@ -95,3 +100,53 @@ def test_azure_credentials_from_default(environment: Dict[str, str]) -> None: "sas_token": None, "anon": False, } + + +def test_azure_service_principal_credentials(environment: Dict[str, str]) -> None: + environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] = "fake_account_name" + environment["CREDENTIALS__AZURE_CLIENT_ID"] = "fake_client_id" + environment["CREDENTIALS__AZURE_CLIENT_SECRET"] = "fake_client_secret" + environment["CREDENTIALS__AZURE_TENANT_ID"] = "fake_tenant_id" + + config = resolve_configuration(AzureServicePrincipalCredentials()) + + assert config.azure_client_id == environment["CREDENTIALS__AZURE_CLIENT_ID"] + assert config.azure_client_secret == environment["CREDENTIALS__AZURE_CLIENT_SECRET"] + assert config.azure_tenant_id == environment["CREDENTIALS__AZURE_TENANT_ID"] + + assert config.to_adlfs_credentials() == { + "account_name": environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"], + "client_id": environment["CREDENTIALS__AZURE_CLIENT_ID"], + "client_secret": environment["CREDENTIALS__AZURE_CLIENT_SECRET"], + "tenant_id": environment["CREDENTIALS__AZURE_TENANT_ID"], + } + + +from dlt.common.storages.configuration import FilesystemConfiguration + + +def test_azure_filesystem_configuration_service_principal(environment: Dict[str, str]) -> None: + """Filesystem config resolves correct credentials type""" + environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] = "fake_account_name" + environment["CREDENTIALS__AZURE_CLIENT_ID"] = "fake_client_id" + environment["CREDENTIALS__AZURE_CLIENT_SECRET"] = "asdsadas" + environment["CREDENTIALS__AZURE_TENANT_ID"] = "fake_tenant_id" + + config = FilesystemConfiguration(bucket_url="az://my-bucket") + + resolved_config = resolve_configuration(config) + + assert isinstance(resolved_config.credentials, AzureServicePrincipalCredentialsWithoutDefaults) + + +def test_azure_filesystem_configuration_sas_token(environment: Dict[str, str]) -> None: + environment["CREDENTIALS__AZURE_STORAGE_ACCOUNT_NAME"] = "fake_account_name" + environment["CREDENTIALS__AZURE_STORAGE_SAS_TOKEN"] = ( + "sp=rwdlacx&se=2021-01-01T00:00:00Z&sv=2019-12-12&sr=c&sig=1234567890" + ) + + config = FilesystemConfiguration(bucket_url="az://my-bucket") + + resolved_config = resolve_configuration(config) + + assert isinstance(resolved_config.credentials, AzureCredentialsWithoutDefaults)