Skip to content

Commit

Permalink
Personal access token auth only
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Jan 30, 2024
1 parent 64b7e2e commit 3076700
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 74 deletions.
81 changes: 10 additions & 71 deletions dlt/destinations/impl/databricks/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,93 +6,35 @@
from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration


CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"


@configspec
class DatabricksCredentials(CredentialsConfiguration):
catalog: Optional[str] = None
catalog: str = None
server_hostname: str = None
http_path: str = None
access_token: Optional[TSecretStrValue] = None
client_id: Optional[str] = None
client_secret: Optional[TSecretStrValue] = None
session_properties: Optional[Dict[str, Any]] = None
http_headers: Optional[Dict[str, str]] = None
session_configuration: Optional[Dict[str, Any]] = None
"""Dict of session parameters that will be passed to `databricks.sql.connect`"""
connection_parameters: Optional[Dict[str, Any]] = None
auth_type: Optional[str] = None

connect_retries: int = 1
connect_timeout: Optional[int] = None
retry_all: bool = False

_credentials_provider: Optional[Dict[str, Any]] = None
"""Additional keyword arguments that are passed to `databricks.sql.connect`"""
socket_timeout: Optional[int] = 180

__config_gen_annotations__: ClassVar[List[str]] = [
"server_hostname",
"http_path",
"catalog",
"schema",
"access_token",
]

def __post_init__(self) -> None:
session_properties = self.session_properties or {}
if CATALOG_KEY_IN_SESSION_PROPERTIES in session_properties:
if self.catalog is None:
self.catalog = session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES]
del session_properties[CATALOG_KEY_IN_SESSION_PROPERTIES]
else:
raise ConfigurationValueError(
f"Got duplicate keys: (`{CATALOG_KEY_IN_SESSION_PROPERTIES}` "
'in session_properties) all map to "catalog"'
)
self.session_properties = session_properties

if self.catalog is not None:
catalog = self.catalog.strip()
if not catalog:
raise ConfigurationValueError(f"Invalid catalog name : `{self.catalog}`.")
self.catalog = catalog
else:
self.catalog = "hive_metastore"

connection_parameters = self.connection_parameters or {}
for key in (
"server_hostname",
"http_path",
"access_token",
"client_id",
"client_secret",
"session_configuration",
"catalog",
"_user_agent_entry",
):
if key in connection_parameters:
raise ConfigurationValueError(f"The connection parameter `{key}` is reserved.")
if "http_headers" in connection_parameters:
http_headers = connection_parameters["http_headers"]
if not isinstance(http_headers, dict) or any(
not isinstance(key, str) or not isinstance(value, str)
for key, value in http_headers.items()
):
raise ConfigurationValueError(
"The connection parameter `http_headers` should be dict of strings: "
f"{http_headers}."
)
if "_socket_timeout" not in connection_parameters:
connection_parameters["_socket_timeout"] = 180
self.connection_parameters = connection_parameters

def to_connector_params(self) -> Dict[str, Any]:
return dict(
catalog=self.catalog,
server_hostname=self.server_hostname,
http_path=self.http_path,
access_token=self.access_token,
client_id=self.client_id,
client_secret=self.client_secret,
session_properties=self.session_properties or {},
connection_parameters=self.connection_parameters or {},
auth_type=self.auth_type,
session_configuration=self.session_configuration or {},
_socket_timeout=self.socket_timeout,
**(self.connection_parameters or {}),
)


Expand All @@ -101,9 +43,6 @@ class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration
destination_type: Final[str] = "databricks" # type: ignore[misc]
credentials: DatabricksCredentials

stage_name: Optional[str] = None
"""Use an existing named stage instead of the default. Default uses the implicit table stage per table"""

def __str__(self) -> str:
"""Return displayable destination location"""
if self.staging_config:
Expand Down
3 changes: 0 additions & 3 deletions dlt/destinations/impl/databricks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ def __init__(
)
from_clause = ""
credentials_clause = ""
files_clause = ""
# stage_file_path = ""

if bucket_path:
bucket_url = urlparse(bucket_path)
Expand Down Expand Up @@ -182,7 +180,6 @@ def __init__(

statement = f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILEFORMAT = {source_format}
"""
Expand Down
Empty file.
32 changes: 32 additions & 0 deletions tests/load/databricks/test_databricks_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import os

pytest.importorskip("databricks")


from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration
from dlt.common.configuration import resolve_configuration
from tests.utils import preserve_environ


def test_databricks_credentials_to_connector_params():
    """Resolve Databricks credentials from the environment and check that
    `to_connector_params()` exposes the configured values, merges the extra
    JSON-encoded `connection_parameters`, and defaults `_socket_timeout`
    from `socket_timeout`."""
    # Credentials are picked up from CREDENTIALS__* environment variables.
    env = {
        "CREDENTIALS__SERVER_HOSTNAME": "my-databricks.example.com",
        "CREDENTIALS__HTTP_PATH": "/sql/1.0/warehouses/asdfe",
        "CREDENTIALS__ACCESS_TOKEN": "my-token",
        "CREDENTIALS__CATALOG": "my-catalog",
        # JSON encoded dict of extra args
        "CREDENTIALS__CONNECTION_PARAMETERS": '{"extra_a": "a", "extra_b": "b"}',
    }
    os.environ.update(env)

    resolved = resolve_configuration(DatabricksClientConfiguration(dataset_name="my-dataset"))
    creds = resolved.credentials
    params = creds.to_connector_params()

    # Directly configured values plus the extra connection parameters
    # must all appear flattened in the connector kwargs.
    expected = {
        "server_hostname": "my-databricks.example.com",
        "http_path": "/sql/1.0/warehouses/asdfe",
        "access_token": "my-token",
        "catalog": "my-catalog",
        "extra_a": "a",
        "extra_b": "b",
    }
    for key, value in expected.items():
        assert params[key] == value
    # Socket timeout is forwarded under the connector's private kwarg name.
    assert params["_socket_timeout"] == creds.socket_timeout

0 comments on commit 3076700

Please sign in to comment.