Skip to content

Commit

Permalink
refactor databricks auth test
Browse files Browse the repository at this point in the history
  • Loading branch information
donotpush committed Dec 11, 2024
1 parent 27fe53d commit a59ba3e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
10 changes: 10 additions & 0 deletions tests/load/databricks/test_databricks_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
pytest.importorskip("databricks")

from dlt.common.exceptions import TerminalValueError
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob
from dlt.common.configuration import resolve_configuration

Expand Down Expand Up @@ -86,3 +87,12 @@ def test_databricks_abfss_converter() -> None:
abfss_url
== "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet"
)


def test_databricks_auth_invalid() -> None:
    """Resolving the Databricks configuration with every credential blanked
    out must fail with ConfigurationValueError.

    Clears client_id / client_secret / access_token via env vars so no
    authentication method can be detected during resolution.
    """
    # Arrange outside the raises block: a failure while setting env vars
    # must surface as a test error, not be mistaken for the expected exception.
    os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = ""
    os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = ""
    os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = ""
    bricks = databricks()
    # Only the call under test is inside the raises context.
    with pytest.raises(ConfigurationValueError, match="No valid authentication method detected.*"):
        bricks.configuration(None, accept_partial=True)
32 changes: 28 additions & 4 deletions tests/load/pipeline/test_databricks_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

from dlt.common.utils import uniq_id
from dlt.destinations import databricks
from tests.load.utils import (
GCS_BUCKET,
DestinationTestConfiguration,
Expand Down Expand Up @@ -151,12 +152,10 @@ def test_databricks_gcs_external_location(destination_config: DestinationTestCon
destinations_configs(default_sql_configs=True, subset=("databricks",)),
ids=lambda x: x.name,
)
def test_databricks_oauth(destination_config: DestinationTestConfiguration) -> None:
from dlt.destinations import databricks

def test_databricks_auth_oauth(destination_config: DestinationTestConfiguration) -> None:
os.environ["DESTINATION__DATABRICKS__CREDENTIALS__ACCESS_TOKEN"] = ""
bricks = databricks()
config = bricks.configuration(None, accept_partial=True)

assert config.credentials.client_id and config.credentials.client_secret

dataset_name = "test_databricks_oauth" + uniq_id()
Expand All @@ -170,3 +169,28 @@ def test_databricks_oauth(destination_config: DestinationTestConfiguration) -> N
with pipeline.sql_client() as client:
rows = client.execute_sql(f"select * from {dataset_name}.digits")
assert len(rows) == 3


@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, subset=("databricks",)),
    ids=lambda x: x.name,
)
def test_databricks_auth_token(destination_config: DestinationTestConfiguration) -> None:
    """Run a minimal pipeline against Databricks authenticating with an
    access token only (OAuth client credentials are cleared via env vars)."""
    # Disable the OAuth path so token auth is the only candidate.
    os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_ID"] = ""
    os.environ["DESTINATION__DATABRICKS__CREDENTIALS__CLIENT_SECRET"] = ""

    dest = databricks()
    resolved = dest.configuration(None, accept_partial=True)
    # Token must have been picked up from the environment/secrets.
    assert resolved.credentials.access_token

    dataset_name = "test_databricks_token" + uniq_id()
    pipeline = destination_config.setup_pipeline(
        "test_databricks_token", dataset_name=dataset_name, destination=dest
    )

    load_info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs)
    assert load_info.has_failed_jobs is False

    # Verify the three rows actually landed in the destination table.
    with pipeline.sql_client() as client:
        result = client.execute_sql(f"select * from {dataset_name}.digits")
        assert len(result) == 3

0 comments on commit a59ba3e

Please sign in to comment.