add databricks oauth authentication
donotpush committed Dec 11, 2024
1 parent 4e5a240 commit 6bb0a30
Showing 7 changed files with 139 additions and 7 deletions.
12 changes: 12 additions & 0 deletions dlt/destinations/impl/databricks/configuration.py
@@ -4,6 +4,7 @@
from dlt.common.typing import TSecretStrValue
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration
from dlt.common.configuration.exceptions import ConfigurationValueError


DATABRICKS_APPLICATION_ID = "dltHub_dlt"
@@ -15,6 +16,8 @@ class DatabricksCredentials(CredentialsConfiguration):
server_hostname: str = None
http_path: str = None
access_token: Optional[TSecretStrValue] = None
client_id: Optional[TSecretStrValue] = None
client_secret: Optional[TSecretStrValue] = None
http_headers: Optional[Dict[str, str]] = None
session_configuration: Optional[Dict[str, Any]] = None
"""Dict of session parameters that will be passed to `databricks.sql.connect`"""
@@ -27,9 +30,18 @@ class DatabricksCredentials(CredentialsConfiguration):
"server_hostname",
"http_path",
"catalog",
"client_id",
"client_secret",
"access_token",
]

def on_resolved(self) -> None:
if not ((self.client_id and self.client_secret) or self.access_token):
raise ConfigurationValueError(
"No valid authentication method detected. Provide either 'client_id' and"
" 'client_secret' for OAuth, or 'access_token' for token-based authentication."
)

def to_connector_params(self) -> Dict[str, Any]:
conn_params = dict(
catalog=self.catalog,
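For context, a minimal sketch of how this guard surfaces during configuration resolution — the hostname, warehouse path, and catalog below are placeholders, and `resolve_configuration` is dlt's standard config resolver:

```py
from dlt.common.configuration import resolve_configuration
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.destinations.impl.databricks.configuration import DatabricksCredentials

# Connection fields are present, but neither an access token nor an
# OAuth client_id/client_secret pair is set, so on_resolved() raises.
try:
    resolve_configuration(
        DatabricksCredentials(),
        explicit_value={
            "server_hostname": "my-workspace.azuredatabricks.net",
            "http_path": "/sql/1.0/warehouses/12345",
            "catalog": "my_catalog",
        },
    )
except ConfigurationValueError as exc:
    print(exc)  # "No valid authentication method detected. ..."
```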
20 changes: 18 additions & 2 deletions dlt/destinations/impl/databricks/sql_client.py
@@ -11,10 +11,12 @@
Tuple,
Union,
Dict,
cast,
Callable,
)


-from databricks import sql as databricks_lib
+from databricks.sdk.core import Config, oauth_service_principal
+from databricks import sql as databricks_lib  # type: ignore[attr-defined]
from databricks.sql.client import (
Connection as DatabricksSqlConnection,
Cursor as DatabricksSqlCursor,
@@ -73,8 +75,22 @@ def __init__(
self._conn: DatabricksSqlConnection = None
self.credentials = credentials

def _get_oauth_credentials(self) -> Optional[Callable[[], Dict[str, str]]]:
config = Config(
host=f"https://{self.credentials.server_hostname}",
client_id=self.credentials.client_id,
client_secret=self.credentials.client_secret,
)
return cast(Callable[[], Dict[str, str]], oauth_service_principal(config))

def open_connection(self) -> DatabricksSqlConnection:
conn_params = self.credentials.to_connector_params()

if self.credentials.client_id and self.credentials.client_secret:
conn_params["credentials_provider"] = self._get_oauth_credentials
else:
conn_params["access_token"] = self.credentials.access_token

self._conn = databricks_lib.connect(
**conn_params, schema=self.dataset_name, use_inline_params="silent"
)
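For reference, the same service-principal flow written directly against `databricks-sql-connector` and `databricks-sdk`, outside of dlt — a sketch with placeholder workspace values, following the connector's documented `credentials_provider` pattern:

```py
from databricks import sql as databricks_sql
from databricks.sdk.core import Config, oauth_service_principal

SERVER_HOSTNAME = "my-workspace.azuredatabricks.net"  # placeholder


def credentials_provider():
    # Returns an OAuth header factory for the service principal.
    config = Config(
        host=f"https://{SERVER_HOSTNAME}",
        client_id="my-client-id",  # placeholder
        client_secret="my-client-secret",  # placeholder
    )
    return oauth_service_principal(config)


with databricks_sql.connect(
    server_hostname=SERVER_HOSTNAME,
    http_path="/sql/1.0/warehouses/12345",  # placeholder
    credentials_provider=credentials_provider,
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
```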
25 changes: 24 additions & 1 deletion docs/website/docs/dlt-ecosystem/destinations/databricks.md
@@ -90,6 +90,29 @@ If you already have your Databricks workspace set up, you can skip to the [Loader setup guide]
Click your email in the top right corner and go to "User Settings". Go to "Developer" -> "Access Tokens".
Generate a new token and save it. You will use it in your `dlt` configuration.

## OAuth M2M (Machine-to-Machine) Authentication

You can authenticate to Databricks using a service principal via OAuth M2M. This method allows for secure, programmatic access to Databricks resources without requiring a user-managed personal access token.

### Create a Service Principal in Databricks
Follow the instructions in the Databricks documentation to create a service principal and retrieve the `client_id` and `client_secret`:

[Authenticate access to Databricks using OAuth M2M](https://docs.databricks.com/en/dev-tools/auth/oauth-m2m.html)

Once you have the service principal credentials, update your `secrets.toml` as shown below.

### Configuration

Add the following fields to your `.dlt/secrets.toml` file:
```toml
[destination.databricks.credentials]
server_hostname = "MY_DATABRICKS.azuredatabricks.net"
http_path = "/sql/1.0/warehouses/12345"
catalog = "my_catalog"
client_id = "XXX"
client_secret = "XXX"
```
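
With these credentials in place, no code changes are needed — dlt resolves the OAuth fields when the pipeline runs. A minimal example (pipeline and dataset names are illustrative):

```py
import dlt

pipeline = dlt.pipeline(
    pipeline_name="databricks_oauth_demo",
    destination="databricks",
    dataset_name="my_dataset",
)
load_info = pipeline.run([1, 2, 3], table_name="digits")
print(load_info)
```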

## Loader setup guide

**1. Initialize a project with a pipeline that loads to Databricks by running**
@@ -118,7 +141,7 @@ Example:
[destination.databricks.credentials]
server_hostname = "MY_DATABRICKS.azuredatabricks.net"
http_path = "/sql/1.0/warehouses/12345"
access_token = "MY_ACCESS_TOKEN"
access_token = "MY_ACCESS_TOKEN" # Replace for client_id and client_secret when using OAuth
catalog = "my_catalog"
```

Expand Down
26 changes: 23 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -95,6 +95,7 @@ db-dtypes = { version = ">=1.2.0", optional = true }
# pyiceberg = { version = ">=0.7.1", optional = true, extras = ["sql-sqlite"] }
# we will rely on manual installation of `sqlalchemy>=2.0.18` instead
pyiceberg = { version = ">=0.8.1", python = ">=3.9", optional = true }
databricks-sdk = {version = ">=0.38.0", optional = true}

[tool.poetry.extras]
gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"]
@@ -117,7 +118,7 @@ weaviate = ["weaviate-client"]
mssql = ["pyodbc"]
synapse = ["pyodbc", "adlfs", "pyarrow"]
qdrant = ["qdrant-client"]
databricks = ["databricks-sql-connector"]
databricks = ["databricks-sql-connector", "databricks-sdk"]
clickhouse = ["clickhouse-driver", "clickhouse-connect", "s3fs", "gcsfs", "adlfs", "pyarrow"]
dremio = ["pyarrow"]
lancedb = ["lancedb", "pyarrow", "tantivy"]
10 changes: 10 additions & 0 deletions tests/load/databricks/test_databricks_configuration.py
@@ -4,6 +4,7 @@
pytest.importorskip("databricks")

from dlt.common.exceptions import TerminalValueError
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob
from dlt.common.configuration import resolve_configuration

@@ -86,3 +87,12 @@ def test_databricks_abfss_converter() -> None:
abfss_url
== "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet"
)


def test_databricks_auth_invalid() -> None:
with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"):
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = ""
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = ""
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = ""
bricks = databricks()
bricks.configuration(None, accept_partial=True)
50 changes: 50 additions & 0 deletions tests/load/pipeline/test_databricks_pipeline.py
@@ -2,6 +2,7 @@
import os

from dlt.common.utils import uniq_id
from dlt.destinations import databricks
from tests.load.utils import (
GCS_BUCKET,
DestinationTestConfiguration,
@@ -145,3 +146,52 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon
assert (
"credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message
)


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=("databricks",)),
ids=lambda x: x.name,
)
def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None:
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = ""
bricks = databricks()
config = bricks.configuration(None, accept_partial=True)
assert config.credentials.client_id and config.credentials.client_secret

dataset_name = "test_databricks_oauth" + uniq_id()
pipeline = destination_config.setup_pipeline(
"test_databricks_oauth", dataset_name=dataset_name, destination=bricks
)

info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs)
assert info.has_failed_jobs is False

with pipeline.sql_client() as client:
rows = client.execute_sql(f"select * from {dataset_name}.digits")
assert len(rows) == 3


@pytest.mark.parametrize(
"destination_config",
destinations_configs(default_sql_configs=True, subset=("databricks",)),
ids=lambda x: x.name,
)
def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None:
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = ""
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = ""
bricks = databricks()
config = bricks.configuration(None, accept_partial=True)
assert config.credentials.access_token

dataset_name = "test_databricks_token" + uniq_id()
pipeline = destination_config.setup_pipeline(
"test_databricks_token", dataset_name=dataset_name, destination=bricks
)

info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs)
assert info.has_failed_jobs is False

with pipeline.sql_client() as client:
rows = client.execute_sql(f"select * from {dataset_name}.digits")
assert len(rows) == 3
