diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 42167b00..6387fff8 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -10,7 +10,7 @@ import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: # Use this import purely for type annotations, a la https://mypy.readthedocs.io/en/latest/runtime_troubles.html#import-cycles @@ -84,7 +84,32 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) -def connect(server_hostname, http_path, access_token=None, **kwargs) -> "Connection": +def singleton(class_): + instances = {} + + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + + return getinstance + + +@singleton +class DefaultNone(object): + """Used to represent a default value of None so that this code can distinguish between + the user passing None versus a default value of None being used. + """ + + pass + + +def connect( + server_hostname, + http_path, + access_token: Optional[Union[str, DefaultNone]] = DefaultNone, + **kwargs +) -> "Connection": from .client import Connection return Connection(server_hostname, http_path, access_token, **kwargs) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4e0ab941..2b83465f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -12,7 +12,7 @@ import decimal from uuid import UUID -from databricks.sql import __version__ +from databricks.sql import __version__, DefaultNone from databricks.sql import * from databricks.sql.exc import ( OperationalError, @@ -63,7 +63,7 @@ def __init__( self, server_hostname: str, http_path: str, - access_token: Optional[str] = None, + access_token: Optional[Union[str, DefaultNone]] = None, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -204,7 +204,13 @@ def read(self) -> Optional[OAuthToken]: # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage - if access_token: + if access_token is DefaultNone: + access_token = None + elif access_token is None: + logger.info( + "Connection access_token was passed a None value. U2M OAuth will be attempted" + ) + else: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv}