diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 3792d6d0..cc421e69 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -10,7 +10,7 @@ from databricks.sql.auth.common import AuthType, ClientContext -def get_auth_provider(cfg: ClientContext): +def get_auth_provider(cfg: ClientContext, http_client): if cfg.credentials_provider: return ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: @@ -35,6 +35,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_client_id, cfg.oauth_scopes, cfg.auth_type, + http_client=http_client, ) elif cfg.access_token is not None: return AccessTokenAuthProvider(cfg.access_token) @@ -53,6 +54,7 @@ def get_auth_provider(cfg: ClientContext): cfg.oauth_redirect_port_range, cfg.oauth_client_id, cfg.oauth_scopes, + http_client=http_client, ) else: raise RuntimeError("No valid authentication settings!") @@ -79,7 +81,7 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): ) -def get_python_sql_connector_auth_provider(hostname: str, **kwargs): +def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs): # TODO : unify all the auth mechanisms with the Python SDK auth_type = kwargs.get("auth_type") @@ -111,4 +113,4 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), ) - return get_auth_provider(cfg) + return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 26c1f370..80f44812 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -63,6 +63,7 @@ def __init__( redirect_port_range: List[int], client_id: str, scopes: List[str], + http_client, auth_type: str = "databricks-oauth", ): try: @@ -79,6 +80,7 @@ def __init__( port_range=redirect_port_range, client_id=client_id, idp_endpoint=idp_endpoint, + http_client=http_client, ) self._hostname = hostname self._scopes_as_str = DatabricksOAuthProvider.SCOPE_DELIM.join(cloud_scopes) diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 5cfbc37c..61b07cb9 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -2,7 +2,6 @@ import logging from typing import Optional, List from urllib.parse import urlparse -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod logger = logging.getLogger(__name__) @@ -36,6 +35,21 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + # HTTP client configuration parameters + ssl_options=None, # SSLOptions type + socket_timeout: Optional[float] = None, + retry_stop_after_attempts_count: Optional[int] = None, + retry_delay_min: Optional[float] = None, + retry_delay_max: Optional[float] = None, + retry_stop_after_attempts_duration: Optional[float] = None, + retry_delay_default: Optional[float] = None, + retry_dangerous_codes: Optional[List[int]] = None, + http_proxy: Optional[str] = None, + proxy_username: Optional[str] = None, + proxy_password: Optional[str] = None, + pool_connections: Optional[int] = None, + pool_maxsize: Optional[int] = None, + user_agent: Optional[str] = None, ): self.hostname = hostname self.access_token = access_token @@ -52,6 +66,24 @@ def __init__( self.oauth_persistence = oauth_persistence self.credentials_provider = 
credentials_provider
+        # HTTP client configuration
+        self.ssl_options = ssl_options
+        self.socket_timeout = socket_timeout
+        self.retry_stop_after_attempts_count = retry_stop_after_attempts_count or 30
+        self.retry_delay_min = retry_delay_min or 1.0
+        self.retry_delay_max = retry_delay_max or 60.0
+        self.retry_stop_after_attempts_duration = (
+            retry_stop_after_attempts_duration or 900.0
+        )
+        self.retry_delay_default = retry_delay_default or 5.0
+        self.retry_dangerous_codes = retry_dangerous_codes or []
+        self.http_proxy = http_proxy
+        self.proxy_username = proxy_username
+        self.proxy_password = proxy_password
+        self.pool_connections = pool_connections or 10
+        self.pool_maxsize = pool_maxsize or 20
+        self.user_agent = user_agent
+

 def get_effective_azure_login_app_id(hostname) -> str:
     """
@@ -69,7 +101,7 @@ def get_effective_azure_login_app_id(hostname) -> str:
     return AzureAppId.PROD.value[1]


-def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
+def get_azure_tenant_id_from_host(host: str, http_client) -> str:
     """
     Load the Azure tenant ID from the Azure Databricks login page.

@@ -78,23 +110,22 @@ def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
     the Azure login page, and the tenant ID is extracted from the redirect URL.
     """
-    if http_client is None:
-        http_client = DatabricksHttpClient.get_instance()
-
     login_url = f"{host}/aad/auth"
     logger.debug("Loading tenant ID from %s", login_url)
-    with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
-        if resp.status_code // 100 != 3:
+
+    with http_client.request_context("GET", login_url, allow_redirects=False) as resp:
+        if resp.status // 100 != 3:
             raise ValueError(
-                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
+                f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
             )
-        entra_id_endpoint = resp.headers.get("Location")
+        entra_id_endpoint = dict(resp.headers).get("Location")
         if entra_id_endpoint is None:
             raise ValueError(f"No Location header in response from {login_url}")
-    # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
-    # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
-    url = urlparse(entra_id_endpoint)
-    path_segments = url.path.split("/")
-    if len(path_segments) < 2:
-        raise ValueError(f"Invalid path in Location header: {url.path}")
-    return path_segments[1]
+
+        # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
+        # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
+ url = urlparse(entra_id_endpoint) + path_segments = url.path.split("/") + if len(path_segments) < 2: + raise ValueError(f"Invalid path in Location header: {url.path}") + return path_segments[1] diff --git a/src/databricks/sql/auth/oauth.py b/src/databricks/sql/auth/oauth.py index aa3184d8..7f96a230 100644 --- a/src/databricks/sql/auth/oauth.py +++ b/src/databricks/sql/auth/oauth.py @@ -9,10 +9,8 @@ from typing import List, Optional import oauthlib.oauth2 -import requests from oauthlib.oauth2.rfc6749.errors import OAuth2Error -from requests.exceptions import RequestException -from databricks.sql.common.http import HttpMethod, DatabricksHttpClient, HttpHeader +from databricks.sql.common.http import HttpMethod, HttpHeader from databricks.sql.common.http import OAuthResponse from databricks.sql.auth.oauth_http_handler import OAuthHttpSingleRequestHandler from databricks.sql.auth.endpoint import OAuthEndpointCollection @@ -63,33 +61,19 @@ def refresh(self) -> Token: pass -class IgnoreNetrcAuth(requests.auth.AuthBase): - """This auth method is a no-op. - - We use it to force requestslib to not use .netrc to write auth headers - when making .post() requests to the oauth token endpoints, since these - don't require authentication. - - In cases where .netrc is outdated or corrupt, these requests will fail. - - See issue #121 - """ - - def __call__(self, r): - return r - - class OAuthManager: def __init__( self, port_range: List[int], client_id: str, idp_endpoint: OAuthEndpointCollection, + http_client, ): self.port_range = port_range self.client_id = client_id self.redirect_port = None self.idp_endpoint = idp_endpoint + self.http_client = http_client @staticmethod def __token_urlsafe(nbytes=32): @@ -103,8 +87,11 @@ def __fetch_well_known_config(self, hostname: str): known_config_url = self.idp_endpoint.get_openid_config_url(hostname) try: - response = requests.get(url=known_config_url, auth=IgnoreNetrcAuth()) - except RequestException as e: + response = self.http_client.request("GET", url=known_config_url) + # Convert urllib3 response to requests-like response for compatibility + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) + except Exception as e: logger.error( f"Unable to fetch OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -122,7 +109,7 @@ def __fetch_well_known_config(self, hostname: str): raise RuntimeError(msg) try: return response.json() - except requests.exceptions.JSONDecodeError as e: + except Exception as e: logger.error( f"Unable to decode OAuth configuration from {known_config_url}.\n" "Verify it is a valid workspace URL and that OAuth is " @@ -209,10 +196,12 @@ def __send_token_request(token_request_url, data): "Accept": "application/json", "Content-Type": "application/x-www-form-urlencoded", } - response = requests.post( - url=token_request_url, data=data, headers=headers, auth=IgnoreNetrcAuth() + # Use unified HTTP client + response = self.http_client.request( + "POST", url=token_request_url, body=data, headers=headers ) - return response.json() + # Convert urllib3 response to dict for compatibility + return json.loads(response.data.decode()) def __send_refresh_token_request(self, hostname, refresh_token): oauth_config = self.__fetch_well_known_config(hostname) @@ -320,6 +309,7 @@ def __init__( token_url, client_id, client_secret, + http_client, extra_params: dict = {}, ): self.client_id = client_id @@ -327,7 +317,7 @@ def __init__( self.token_url = token_url 
self.extra_params = extra_params self.token: Optional[Token] = None - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client def get_token(self) -> Token: if self.token is None or self.token.is_expired(): diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 130f0c5b..4a319c44 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -50,6 +50,7 @@ def build_queue( max_download_threads: int, sea_client: SeaDatabricksClient, lz4_compressed: bool, + http_client, ) -> ResultSetQueue: """ Factory method to build a result set queue for SEA backend. @@ -94,6 +95,7 @@ def build_queue( total_chunk_count=manifest.total_chunk_count, lz4_compressed=lz4_compressed, description=description, + http_client=http_client, ) raise ProgrammingError("Invalid result format") @@ -309,6 +311,7 @@ def __init__( sea_client: SeaDatabricksClient, statement_id: str, total_chunk_count: int, + http_client, lz4_compressed: bool = False, description: List[Tuple] = [], ): @@ -337,6 +340,7 @@ def __init__( # TODO: fix these arguments when telemetry is implemented in SEA session_id_hex=None, chunk_id=0, + http_client=http_client, ) logger.debug( diff --git a/src/databricks/sql/backend/sea/result_set.py b/src/databricks/sql/backend/sea/result_set.py index afa70bc8..17838ed8 100644 --- a/src/databricks/sql/backend/sea/result_set.py +++ b/src/databricks/sql/backend/sea/result_set.py @@ -64,6 +64,7 @@ def __init__( max_download_threads=sea_client.max_download_threads, sea_client=sea_client, lz4_compressed=execute_response.lz4_compressed, + http_client=connection.session.http_client, ) # Call parent constructor with common attributes diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index b404b166..801632a4 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -105,6 +105,7 @@ def __init__( http_headers, auth_provider: AuthProvider, ssl_options: SSLOptions, + http_client=None, **kwargs, ): # Internal arguments in **kwargs: @@ -145,10 +146,8 @@ def __init__( # Number of threads for handling cloud fetch downloads. 
Defaults to 10 logger.debug( - "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", - server_hostname, - port, - http_path, + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)" + % (server_hostname, port, http_path) ) port = port or 443 @@ -177,8 +176,8 @@ def __init__( self._max_download_threads = kwargs.get("max_download_threads", 10) self._ssl_options = ssl_options - self._auth_provider = auth_provider + self._http_client = http_client # Connector version 3 retry approach self.enable_v3_retries = kwargs.get("_enable_v3_retries", True) @@ -1292,6 +1291,7 @@ def fetch_results( session_id_hex=self._session_id_hex, statement_id=command_id.to_hex_guid(), chunk_id=chunk_id, + http_client=self._http_client, ) return ( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 73ee0e03..1a35f97d 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -6,7 +6,6 @@ import pyarrow except ImportError: pyarrow = None -import requests import json import os import decimal @@ -51,6 +50,9 @@ from databricks.sql.session import Session from databricks.sql.backend.types import CommandId, BackendType, CommandState, SessionId +from databricks.sql.auth.common import ClientContext +from databricks.sql.common.unified_http_client import UnifiedHttpClient + from databricks.sql.thrift_api.TCLIService.ttypes import ( TOpenSessionResp, TSparkParameter, @@ -252,10 +254,14 @@ def read(self) -> Optional[OAuthToken]: "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE ) + client_context = self._build_client_context(server_hostname, **kwargs) + http_client = UnifiedHttpClient(client_context) + try: self.session = Session( server_hostname, http_path, + http_client, http_headers, session_configuration, catalog, @@ -271,6 +277,7 @@ def read(self) -> Optional[OAuthToken]: host_url=server_hostname, http_path=http_path, port=kwargs.get("_port", 443), + http_client=http_client, user_agent=self.session.useragent_header if hasattr(self, "session") else None, @@ -292,6 +299,7 @@ def read(self) -> Optional[OAuthToken]: auth_provider=self.session.auth_provider, host_url=self.session.host, batch_size=self.telemetry_batch_size, + http_client=self.session.http_client, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( @@ -342,6 +350,50 @@ def _set_use_inline_params_with_warning(self, value: Union[bool, str]): return value + def _build_client_context(self, server_hostname: str, **kwargs): + """Build ClientContext for HTTP client configuration.""" + from databricks.sql.auth.common import ClientContext + from databricks.sql.types import SSLOptions + + # Extract SSL options + ssl_options = SSLOptions( + tls_verify=not kwargs.get("_tls_no_verify", False), + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + # Build user agent + user_agent_entry = kwargs.get("user_agent_entry", "") + if user_agent_entry: + user_agent = f"PyDatabricksSqlConnector/{__version__} ({user_agent_entry})" + else: + user_agent = f"PyDatabricksSqlConnector/{__version__}" + + return ClientContext( + hostname=server_hostname, + ssl_options=ssl_options, + socket_timeout=kwargs.get("_socket_timeout"), + retry_stop_after_attempts_count=kwargs.get( + "_retry_stop_after_attempts_count", 30 
+ ), + retry_delay_min=kwargs.get("_retry_delay_min", 1.0), + retry_delay_max=kwargs.get("_retry_delay_max", 60.0), + retry_stop_after_attempts_duration=kwargs.get( + "_retry_stop_after_attempts_duration", 900.0 + ), + retry_delay_default=kwargs.get("_retry_delay_default", 1.0), + retry_dangerous_codes=kwargs.get("_retry_dangerous_codes", []), + http_proxy=kwargs.get("_http_proxy"), + proxy_username=kwargs.get("_proxy_username"), + proxy_password=kwargs.get("_proxy_password"), + pool_connections=kwargs.get("_pool_connections", 1), + pool_maxsize=kwargs.get("_pool_maxsize", 1), + user_agent=user_agent, + ) + # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently. def __enter__(self) -> "Connection": return self @@ -395,7 +447,7 @@ def get_protocol_version(openSessionResp: TOpenSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.is_open + return hasattr(self, "session") and self.session.is_open def cursor( self, @@ -744,16 +796,22 @@ def _handle_staging_put( ) with open(local_file, "rb") as fh: - r = requests.put(url=presigned_url, data=fh, headers=headers) + r = self.connection.session.http_client.request( + "PUT", presigned_url, body=fh.read(), headers=headers + ) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, "data"): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # fmt: off - # Design borrowed from: https://stackoverflow.com/a/2342589/5093960 - - OK = requests.codes.ok # 200 - CREATED = requests.codes.created # 201 - ACCEPTED = requests.codes.accepted # 202 - NO_CONTENT = requests.codes.no_content # 204 - + # HTTP status codes + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NO_CONTENT = 204 # fmt: on if r.status_code not in [OK, CREATED, NO_CONTENT, ACCEPTED]: @@ -783,7 +841,15 @@ def _handle_staging_get( session_id_hex=self.connection.get_session_id_hex(), ) - r = requests.get(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request( + "GET", presigned_url, headers=headers + ) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, "data"): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" # response.ok verifies the status code is not between 400-600. 
# Any 2xx or 3xx will evaluate r.ok == True @@ -802,7 +868,15 @@ def _handle_staging_remove( ): """Make an HTTP DELETE request to the presigned_url""" - r = requests.delete(url=presigned_url, headers=headers) + r = self.connection.session.http_client.request( + "DELETE", presigned_url, headers=headers + ) + # Add compatibility attributes for urllib3 response + r.status_code = r.status + if hasattr(r, "data"): + r.content = r.data + r.ok = r.status < 400 + r.text = r.data.decode() if r.data else "" if not r.ok: raise OperationalError( diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index 32b698be..27265720 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -25,6 +25,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, ): self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] self.chunk_id = chunk_id @@ -47,6 +48,7 @@ def __init__( self._ssl_options = ssl_options self.session_id_hex = session_id_hex self.statement_id = statement_id + self._http_client = http_client def get_next_downloaded_file( self, next_row_offset: int @@ -109,6 +111,7 @@ def _schedule_downloads(self): chunk_id=chunk_id, session_id_hex=self.session_id_hex, statement_id=self.statement_id, + http_client=self._http_client, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 1331fa20..cef4ca27 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -2,10 +2,9 @@ from dataclasses import dataclass from typing import Optional -from requests.adapters import Retry import lz4.frame import time -from databricks.sql.common.http import DatabricksHttpClient, HttpMethod +from databricks.sql.common.http import HttpMethod from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.exc import Error from databricks.sql.types import SSLOptions @@ -16,16 +15,6 @@ # TODO: Ideally, we should use a common retry policy (DatabricksRetryPolicy) for all the requests across the library. # But DatabricksRetryPolicy should be updated first - currently it can work only with Thrift requests -retryPolicy = Retry( - total=5, # max retry attempts - backoff_factor=1, # min delay, 1 second - # TODO: `backoff_max` is supported since `urllib3` v2.0.0, but we allow >= 1.26. 
- # The default value (120 seconds) used since v1.26 looks reasonable enough - # backoff_max=60, # max delay, 60 seconds - # retry all status codes below 100, 429 (Too Many Requests), and all codes above 500, - # excluding 501 Not implemented - status_forcelist=[*range(0, 101), 429, 500, *range(502, 1000)], -) @dataclass @@ -73,11 +62,12 @@ def __init__( chunk_id: int, session_id_hex: Optional[str], statement_id: str, + http_client, ): self.settings = settings self.link = link self._ssl_options = ssl_options - self._http_client = DatabricksHttpClient.get_instance() + self._http_client = http_client self.chunk_id = chunk_id self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -104,50 +94,47 @@ def run(self) -> DownloadedFile: start_time = time.time() - with self._http_client.execute( - method=HttpMethod.GET, + with self._http_client.request_context( + method="GET", url=self.link.fileLink, timeout=self.settings.download_timeout, - verify=self._ssl_options.tls_verify, - headers=self.link.httpHeaders - # TODO: Pass cert from `self._ssl_options` + headers=self.link.httpHeaders, ) as response: - response.raise_for_status() - - # Save (and decompress if needed) the downloaded file - compressed_data = response.content - - # Log download metrics - download_duration = time.time() - start_time - self._log_download_metrics( - self.link.fileLink, len(compressed_data), download_duration - ) - - decompressed_data = ( - ResultSetDownloadHandler._decompress_data(compressed_data) - if self.settings.is_lz4_compressed - else compressed_data - ) + if response.status >= 400: + raise Exception(f"HTTP {response.status}: {response.data.decode()}") + compressed_data = response.data + + # Log download metrics + download_duration = time.time() - start_time + self._log_download_metrics( + self.link.fileLink, len(compressed_data), download_duration + ) - # The size of the downloaded file should match the size specified from TSparkArrowResultLink - if len(decompressed_data) != self.link.bytesNum: - logger.debug( - "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( - len(decompressed_data), self.link.bytesNum - ) - ) + decompressed_data = ( + ResultSetDownloadHandler._decompress_data(compressed_data) + if self.settings.is_lz4_compressed + else compressed_data + ) + # The size of the downloaded file should match the size specified from TSparkArrowResultLink + if len(decompressed_data) != self.link.bytesNum: logger.debug( - "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( - self.link.startRowOffset, self.link.rowCount + "ResultSetDownloadHandler: downloaded file size {} does not match the expected value {}".format( + len(decompressed_data), self.link.bytesNum ) ) - return DownloadedFile( - decompressed_data, - self.link.startRowOffset, - self.link.rowCount, + logger.debug( + "ResultSetDownloadHandler: successfully downloaded file, offset {}, row count {}".format( + self.link.startRowOffset, self.link.rowCount ) + ) + + return DownloadedFile( + decompressed_data, + self.link.startRowOffset, + self.link.rowCount, + ) def _log_download_metrics( self, url: str, bytes_downloaded: int, duration_seconds: float diff --git a/src/databricks/sql/common/feature_flag.py b/src/databricks/sql/common/feature_flag.py index 53add925..1b920b00 100644 --- a/src/databricks/sql/common/feature_flag.py +++ b/src/databricks/sql/common/feature_flag.py @@ -1,6 +1,6 @@ +import json import threading import time -import requests from dataclasses 
import dataclass, field from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, List, Any, TYPE_CHECKING @@ -49,7 +49,9 @@ class FeatureFlagsContext: in the background, returning stale data until the refresh completes. """ - def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): + def __init__( + self, connection: "Connection", executor: ThreadPoolExecutor, http_client + ): from databricks.sql import __version__ self._connection = connection @@ -66,6 +68,9 @@ def __init__(self, connection: "Connection", executor: ThreadPoolExecutor): f"https://{self._connection.session.host}{endpoint_suffix}" ) + # Use the provided HTTP client + self._http_client = http_client + def _is_refresh_needed(self) -> bool: """Checks if the cache is due for a proactive background refresh.""" if self._flags is None: @@ -105,9 +110,12 @@ def _refresh_flags(self): self._connection.session.auth_provider.add_headers(headers) headers["User-Agent"] = self._connection.session.useragent_header - response = requests.get( - self._feature_flag_endpoint, headers=headers, timeout=30 + response = self._http_client.request( + "GET", self._feature_flag_endpoint, headers=headers, timeout=30 ) + # Add compatibility attributes for urllib3 response + response.status_code = response.status + response.json = lambda: json.loads(response.data.decode()) if response.status_code == 200: ff_response = FeatureFlagsResponse.from_dict(response.json()) @@ -159,7 +167,9 @@ def get_instance(cls, connection: "Connection") -> FeatureFlagsContext: # Use the unique session ID as the key key = connection.get_session_id_hex() if key not in cls._context_map: - cls._context_map[key] = FeatureFlagsContext(connection, cls._executor) + cls._context_map[key] = FeatureFlagsContext( + connection, cls._executor, connection.session.http_client + ) return cls._context_map[key] @classmethod diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index 0cd2919c..cf76a5fb 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -38,115 +38,3 @@ class OAuthResponse: resource: str = "" access_token: str = "" refresh_token: str = "" - - -# Singleton class for common Http Client -class DatabricksHttpClient: - ## TODO: Unify all the http clients in the PySQL Connector - - _instance = None - _lock = threading.Lock() - - def __init__(self): - self.session = requests.Session() - adapter = HTTPAdapter( - pool_connections=5, - pool_maxsize=10, - max_retries=Retry(total=10, backoff_factor=0.1), - ) - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "DatabricksHttpClient": - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = DatabricksHttpClient() - return cls._instance - - @contextmanager - def execute( - self, method: HttpMethod, url: str, **kwargs - ) -> Generator[requests.Response, None, None]: - logger.info("Executing HTTP request: %s with url: %s", method.value, url) - response = None - try: - response = self.session.request(method.value, url, **kwargs) - yield response - except Exception as e: - logger.error("Error executing HTTP request in DatabricksHttpClient: %s", e) - raise e - finally: - if response is not None: - response.close() - - def close(self): - self.session.close() - - -class TelemetryHTTPAdapter(HTTPAdapter): - """ - Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. 
- This ensures the retry timer is started and the command type is set correctly, - allowing the policy to manage its state for the duration of the request retries. - """ - - def send(self, request, **kwargs): - self.max_retries.command_type = CommandType.OTHER - self.max_retries.start_retry_timer() - return super().send(request, **kwargs) - - -class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector - """Singleton HTTP client for sending telemetry data.""" - - _instance: Optional["TelemetryHttpClient"] = None - _lock = threading.Lock() - - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 - TELEMETRY_RETRY_DELAY_MIN = 1.0 - TELEMETRY_RETRY_DELAY_MAX = 10.0 - TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 - - def __init__(self): - """Initializes the session and mounts the custom retry adapter.""" - retry_policy = DatabricksRetryPolicy( - delay_min=self.TELEMETRY_RETRY_DELAY_MIN, - delay_max=self.TELEMETRY_RETRY_DELAY_MAX, - stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, - stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, - delay_default=1.0, - force_dangerous_codes=[], - ) - adapter = TelemetryHTTPAdapter(max_retries=retry_policy) - self.session = requests.Session() - self.session.mount("https://", adapter) - self.session.mount("http://", adapter) - - @classmethod - def get_instance(cls) -> "TelemetryHttpClient": - """Get the singleton instance of the TelemetryHttpClient.""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - logger.debug("Initializing singleton TelemetryHttpClient") - cls._instance = TelemetryHttpClient() - return cls._instance - - def post(self, url: str, **kwargs) -> requests.Response: - """ - Executes a POST request using the configured session. - - This is a blocking call intended to be run in a background thread. - """ - logger.debug("Executing telemetry POST request to: %s", url) - return self.session.post(url, **kwargs) - - def close(self): - """Closes the underlying requests.Session.""" - logger.debug("Closing TelemetryHttpClient session.") - self.session.close() - # Clear the instance to allow for re-initialization if needed - with TelemetryHttpClient._lock: - TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/common/unified_http_client.py b/src/databricks/sql/common/unified_http_client.py new file mode 100644 index 00000000..a296704b --- /dev/null +++ b/src/databricks/sql/common/unified_http_client.py @@ -0,0 +1,231 @@ +import logging +import ssl +import urllib.parse +from contextlib import contextmanager +from typing import Dict, Any, Optional, Generator, Union + +import urllib3 +from urllib3 import PoolManager, ProxyManager +from urllib3.util import make_headers +from urllib3.exceptions import MaxRetryError + +from databricks.sql.auth.retry import DatabricksRetryPolicy +from databricks.sql.exc import RequestError + +logger = logging.getLogger(__name__) + + +class UnifiedHttpClient: + """ + Unified HTTP client for all Databricks SQL connector HTTP operations. + + This client uses urllib3 for robust HTTP communication with retry policies, + connection pooling, SSL support, and proxy support. It replaces the various + singleton HTTP clients and direct requests usage throughout the codebase. + """ + + def __init__(self, client_context): + """ + Initialize the unified HTTP client. 
+ + Args: + client_context: ClientContext instance containing HTTP configuration + """ + self.config = client_context + self._pool_manager = None + self._setup_pool_manager() + + def _setup_pool_manager(self): + """Set up the urllib3 PoolManager with configuration from ClientContext.""" + + # SSL context setup + ssl_context = None + if self.config.ssl_options: + ssl_context = ssl.create_default_context() + + # Configure SSL verification + if not self.config.ssl_options.tls_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif not self.config.ssl_options.tls_verify_hostname: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_REQUIRED + + # Load custom CA file if specified + if self.config.ssl_options.tls_trusted_ca_file: + ssl_context.load_verify_locations( + self.config.ssl_options.tls_trusted_ca_file + ) + + # Load client certificate if specified + if ( + self.config.ssl_options.tls_client_cert_file + and self.config.ssl_options.tls_client_cert_key_file + ): + ssl_context.load_cert_chain( + self.config.ssl_options.tls_client_cert_file, + self.config.ssl_options.tls_client_cert_key_file, + self.config.ssl_options.tls_client_cert_key_password, + ) + + # Create retry policy + retry_policy = DatabricksRetryPolicy( + delay_min=self.config.retry_delay_min, + delay_max=self.config.retry_delay_max, + stop_after_attempts_count=self.config.retry_stop_after_attempts_count, + stop_after_attempts_duration=self.config.retry_stop_after_attempts_duration, + delay_default=self.config.retry_delay_default, + force_dangerous_codes=self.config.retry_dangerous_codes, + ) + + # Common pool manager kwargs + pool_kwargs = { + "num_pools": self.config.pool_connections, + "maxsize": self.config.pool_maxsize, + "retries": retry_policy, + "timeout": urllib3.Timeout( + connect=self.config.socket_timeout, read=self.config.socket_timeout + ) + if self.config.socket_timeout + else None, + "ssl_context": ssl_context, + } + + # Create proxy or regular pool manager + if self.config.http_proxy: + proxy_headers = None + if self.config.proxy_username and self.config.proxy_password: + proxy_headers = make_headers( + proxy_basic_auth=f"{self.config.proxy_username}:{self.config.proxy_password}" + ) + + self._pool_manager = ProxyManager( + self.config.http_proxy, proxy_headers=proxy_headers, **pool_kwargs + ) + else: + self._pool_manager = PoolManager(**pool_kwargs) + + def _prepare_headers( + self, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: + """Prepare headers for the request, including User-Agent.""" + request_headers = {} + + if self.config.user_agent: + request_headers["User-Agent"] = self.config.user_agent + + if headers: + request_headers.update(headers) + + return request_headers + + @contextmanager + def request_context( + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> Generator[urllib3.HTTPResponse, None, None]: + """ + Context manager for making HTTP requests with proper resource cleanup. 
+ + Args: + method: HTTP method (GET, POST, PUT, DELETE) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Yields: + urllib3.HTTPResponse: The HTTP response object + """ + logger.debug("Making %s request to %s", method, url) + + request_headers = self._prepare_headers(headers) + response = None + + try: + response = self._pool_manager.request( + method=method, url=url, headers=request_headers, **kwargs + ) + yield response + except MaxRetryError as e: + logger.error("HTTP request failed after retries: %s", e) + raise RequestError(f"HTTP request failed: {e}") + except Exception as e: + logger.error("HTTP request error: %s", e) + raise RequestError(f"HTTP request error: {e}") + finally: + if response: + response.close() + + def request( + self, method: str, url: str, headers: Optional[Dict[str, str]] = None, **kwargs + ) -> urllib3.HTTPResponse: + """ + Make an HTTP request. + + Args: + method: HTTP method (GET, POST, PUT, DELETE, etc.) + url: URL to request + headers: Optional headers dict + **kwargs: Additional arguments passed to urllib3 request + + Returns: + urllib3.HTTPResponse: The HTTP response object with data pre-loaded + """ + with self.request_context(method, url, headers=headers, **kwargs) as response: + # Read the response data to ensure it's available after context exit + response._body = response.data + return response + + def upload_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> urllib3.HTTPResponse: + """ + Upload a file using PUT method. + + Args: + url: URL to upload to + file_path: Path to the file to upload + headers: Optional headers + + Returns: + urllib3.HTTPResponse: The response from the server + """ + with open(file_path, "rb") as file_obj: + return self.request("PUT", url, body=file_obj.read(), headers=headers) + + def download_file( + self, url: str, file_path: str, headers: Optional[Dict[str, str]] = None + ) -> None: + """ + Download a file using GET method. + + Args: + url: URL to download from + file_path: Path where to save the downloaded file + headers: Optional headers + """ + response = self.request("GET", url, headers=headers) + with open(file_path, "wb") as file_obj: + file_obj.write(response.data) + + def close(self): + """Close the underlying connection pools.""" + if self._pool_manager: + self._pool_manager.clear() + self._pool_manager = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +# Compatibility class to maintain requests-like interface for OAuth +class IgnoreNetrcAuth: + """ + Compatibility class for OAuth code that expects requests.auth.AuthBase interface. + This is a no-op auth handler since OAuth handles auth differently. 
+ """ + + def __call__(self, request): + return request diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 9feb6e92..77673db9 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -244,6 +244,7 @@ def __init__( session_id_hex=connection.get_session_id_hex(), statement_id=execute_response.command_id.to_hex_guid(), chunk_id=self.num_chunks, + http_client=connection.session.http_client, ) if t_row_set.resultLinks: self.num_chunks += len(t_row_set.resultLinks) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f1bc35be..0cba8be4 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -4,6 +4,7 @@ from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.auth.common import ClientContext from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME @@ -11,6 +12,7 @@ from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.common.unified_http_client import UnifiedHttpClient logger = logging.getLogger(__name__) @@ -20,6 +22,7 @@ def __init__( self, server_hostname: str, http_path: str, + http_client: UnifiedHttpClient, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -42,10 +45,6 @@ def __init__( self.schema = schema self.http_path = http_path - self.auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -77,6 +76,14 @@ def __init__( tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), ) + # Use the provided HTTP client (created in Connection) + self.http_client = http_client + + # Create auth provider with HTTP client context + self.auth_provider = get_python_sql_connector_auth_provider( + server_hostname, http_client=self.http_client, **kwargs + ) + self.backend = self._create_backend( server_hostname, http_path, @@ -115,6 +122,7 @@ def _create_backend( "http_headers": all_headers, "auth_provider": auth_provider, "ssl_options": self.ssl_options, + "http_client": self.http_client, "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } @@ -185,3 +193,7 @@ def close(self) -> None: logger.error("Attempt to close session raised a local exception: %s", e) self.is_open = False + + # Close HTTP client if it exists + if hasattr(self, "http_client") and self.http_client: + self.http_client.close() diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 55f06c8d..2785d3cc 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -3,7 +3,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional, TYPE_CHECKING -from databricks.sql.common.http import TelemetryHttpClient from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -38,6 +37,8 @@ from 
databricks.sql.telemetry.utils import BaseTelemetryClient
 from databricks.sql.common.feature_flag import FeatureFlagsContextFactory
+from databricks.sql.common.unified_http_client import UnifiedHttpClient
+
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
@@ -168,6 +169,7 @@ def __init__(
         host_url,
         executor,
         batch_size,
+        http_client,
     ):
         logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
         self._telemetry_enabled = telemetry_enabled
@@ -180,7 +182,7 @@ def __init__(
         self._driver_connection_params = None
         self._host_url = host_url
         self._executor = executor
-        self._http_client = TelemetryHttpClient.get_instance()
+        self._http_client = http_client

     def _export_event(self, event):
         """Add an event to the batch queue and flush if batch is full"""
@@ -228,19 +230,38 @@ def _send_telemetry(self, events):

         try:
             logger.debug("Submitting telemetry request to thread pool")
+
+            # Use unified HTTP client
             future = self._executor.submit(
-                self._http_client.post,
+                self._send_with_unified_client,
                 url,
                 data=request.to_json(),
                 headers=headers,
                 timeout=900,
             )
+
             future.add_done_callback(
                 lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count)
             )
         except Exception as e:
             logger.debug("Failed to submit telemetry request: %s", e)

+    def _send_with_unified_client(self, url, data, headers, timeout=900):
+        """Helper method to send telemetry using the unified HTTP client."""
+        try:
+            response = self._http_client.request(
+                "POST", url, body=data, headers=headers, timeout=timeout
+            )
+            # Convert urllib3 response to requests-like response for compatibility
+            response.status_code = response.status
+            response.json = (
+                lambda: json.loads(response.data.decode()) if response.data else {}
+            )
+            return response
+        except Exception as e:
+            logger.error("Failed to send telemetry with unified client: %s", e)
+            raise
+
     def _telemetry_request_callback(self, future, sent_count: int):
         """Callback function to handle telemetry request completion"""
         try:
@@ -431,6 +452,7 @@ def initialize_telemetry_client(
         auth_provider,
         host_url,
         batch_size,
+        http_client,
     ):
         """Initialize a telemetry client for a specific connection if telemetry is enabled"""
         try:
@@ -453,6 +475,7 @@ def initialize_telemetry_client(
                     host_url=host_url,
                     executor=TelemetryClientFactory._executor,
                     batch_size=batch_size,
+                    http_client=http_client,
                 )
             else:
                 TelemetryClientFactory._clients[
@@ -493,7 +516,6 @@ def close(session_id_hex):
            try:
                TelemetryClientFactory._stop_flush_thread()
                TelemetryClientFactory._executor.shutdown(wait=True)
-                TelemetryHttpClient.close()
            except Exception as e:
                logger.debug("Failed to shutdown thread pool executor: %s", e)
            TelemetryClientFactory._executor = None
@@ -506,6 +528,7 @@ def connection_failure_log(
        host_url: str,
        http_path: str,
        port: int,
+        http_client: UnifiedHttpClient,
        user_agent: Optional[str] = None,
    ):
        """Send error telemetry when connection creation fails, without requiring a session"""
@@ -518,6 +541,7 @@ def connection_failure_log(
            auth_provider=None,
            host_url=host_url,
            batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+            http_client=http_client,
        )

        telemetry_client = TelemetryClientFactory.get_telemetry_client(
diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py
index c1d89ca5..ff48e0e9 100644
--- a/src/databricks/sql/utils.py
+++ b/src/databricks/sql/utils.py
@@ -64,6 +64,7 @@ def build_queue(
         session_id_hex: Optional[str],
         statement_id: str,
         chunk_id: int,
+        http_client,
         lz4_compressed: bool = True,
         description: List[Tuple] = [],
     ) -> ResultSetQueue:
@@ -104,15 
+105,16 @@ def build_queue( elif row_set_type == TSparkRowSetType.URL_BASED_SET: return ThriftCloudFetchQueue( schema_bytes=arrow_schema_bytes, - start_row_offset=t_row_set.startRowOffset, - result_links=t_row_set.resultLinks, - lz4_compressed=lz4_compressed, - description=description, max_download_threads=max_download_threads, ssl_options=ssl_options, session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, + start_row_offset=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description, ) else: raise AssertionError("Row set type is not valid") @@ -224,6 +226,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, description: List[Tuple] = [], @@ -247,6 +250,7 @@ def __init__( self.session_id_hex = session_id_hex self.statement_id = statement_id self.chunk_id = chunk_id + self._http_client = http_client # Table state self.table = None @@ -261,6 +265,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) def next_n_rows(self, num_rows: int) -> "pyarrow.Table": @@ -370,6 +375,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + http_client, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -396,6 +402,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + http_client=http_client, ) self.start_row_index = start_row_offset diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 8bf91470..333782fd 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -24,8 +24,8 @@ AzureOAuthEndpointCollection, ) from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory -from databricks.sql.common.http import DatabricksHttpClient from databricks.sql.experimental.oauth_persistence import OAuthPersistenceCache +import json class Auth(unittest.TestCase): @@ -98,12 +98,14 @@ def test_oauth_auth_provider(self, mock_get_tokens, mock_check_and_refresh): ) in params: with self.subTest(cloud_type.value): oauth_persistence = OAuthPersistenceCache() + mock_http_client = MagicMock() auth_provider = DatabricksOAuthProvider( hostname=host, oauth_persistence=oauth_persistence, redirect_port_range=[8020], client_id=client_id, scopes=scopes, + http_client=mock_http_client, auth_type=AuthType.AZURE_OAUTH.value if use_azure_auth else AuthType.DATABRICKS_OAUTH.value, @@ -142,7 +144,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -159,7 +162,8 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) 
self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") headers = {} @@ -174,7 +178,8 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_tls_client_cert_file": tls_client_cert_file, "_use_cert_as_auth": use_cert_as_auth, } - auth_provider = get_python_sql_connector_auth_provider(hostname, **kwargs) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -182,8 +187,9 @@ def test_get_python_sql_connector_basic_auth(self): "username": "username", "password": "password", } + mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs) + get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -191,7 +197,8 @@ def test_get_python_sql_connector_basic_auth(self): @patch.object(DatabricksOAuthProvider, "_initial_get_token") def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" - auth_provider = get_python_sql_connector_auth_provider(hostname) + mock_http_client = MagicMock() + auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) @@ -223,10 +230,12 @@ def status_response(response_status_code): @pytest.fixture def token_source(self): + mock_http_client = MagicMock() return ClientCredentialsTokenSource( token_url="https://token_url.com", client_id="client_id", client_secret="client_secret", + http_client=mock_http_client, ) def test_no_token_refresh__when_token_is_not_expired( @@ -249,10 +258,21 @@ def test_no_token_refresh__when_token_is_not_expired( assert mock_get_token.call_count == 1 def test_get_token_success(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with patch.object( - databricks_http_client.session, "request", return_value=http_response(200) - ) as mock_request: + mock_http_client = MagicMock() + + with patch.object(token_source, "_http_client", mock_http_client): + # Create a mock response with the expected format + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "abc123", + "token_type": "Bearer", + "refresh_token": None, + } + # Mock the context manager (execute returns context manager) + mock_http_client.execute.return_value.__enter__.return_value = mock_response + mock_http_client.execute.return_value.__exit__.return_value = None + token = token_source.get_token() # Assert @@ -262,10 +282,18 @@ def test_get_token_success(self, token_source, http_response): assert token.refresh_token is None def test_get_token_failure(self, token_source, http_response): - databricks_http_client = DatabricksHttpClient.get_instance() - with patch.object( - databricks_http_client.session, "request", return_value=http_response(400) - ) as mock_request: + mock_http_client = MagicMock() + + with patch.object(token_source, "_http_client", mock_http_client): + # Create a mock response with error + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + 
mock_response.json.return_value = {"error": "invalid_client"} + # Mock the context manager (execute returns context manager) + mock_http_client.execute.return_value.__enter__.return_value = mock_response + mock_http_client.execute.return_value.__exit__.return_value = None + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index faa8e2f9..0c3fc710 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -13,6 +13,31 @@ @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed") class CloudFetchQueueSuite(unittest.TestCase): + def create_queue(self, schema_bytes=None, result_links=None, description=None, **kwargs): + """Helper method to create ThriftCloudFetchQueue with sensible defaults""" + # Set up defaults for commonly used parameters + defaults = { + 'max_download_threads': 10, + 'ssl_options': SSLOptions(), + 'session_id_hex': Mock(), + 'statement_id': Mock(), + 'chunk_id': 0, + 'start_row_offset': 0, + 'lz4_compressed': True, + } + + # Override defaults with any provided kwargs + defaults.update(kwargs) + + mock_http_client = MagicMock() + return utils.ThriftCloudFetchQueue( + schema_bytes=schema_bytes or MagicMock(), + result_links=result_links or [], + description=description or [], + http_client=mock_http_client, + **defaults + ) + def create_result_link( self, file_link: str = "fileLink", @@ -58,15 +83,7 @@ def get_schema_bytes(): def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=result_links, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, result_links=result_links) assert len(queue.download_manager._pending_links) == 10 assert len(queue.download_manager._download_tasks) == 0 @@ -74,16 +91,7 @@ def test_initializer_adds_links(self, mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() - result_links = [] - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=result_links, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, result_links=[]) assert len(queue.download_manager._pending_links) == 0 assert len(queue.download_manager._download_tasks) == 0 @@ -94,15 +102,7 @@ def test_initializer_no_links_to_add(self): return_value=None, ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.ThriftCloudFetchQueue( - MagicMock(), - result_links=[], - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=MagicMock(), result_links=[]) assert queue._create_next_table() is None mock_get_next_downloaded_file.assert_called_with(0) @@ -117,16 +117,7 @@ def test_initializer_create_next_table_success( ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - 
statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) expected_result = self.make_arrow_table() mock_get_next_downloaded_file.assert_called_with(0) @@ -145,16 +136,7 @@ def test_initializer_create_next_table_success( def test_next_n_rows_0_rows(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -167,16 +149,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): def test_next_n_rows_partial_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -190,16 +163,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): def test_next_n_rows_more_than_one_table(self, mock_create_next_table): mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -218,16 +182,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + queue = self.create_queue(schema_bytes=schema_bytes, description=description) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -242,17 +197,9 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): ) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() - description = MagicMock() - queue = utils.ThriftCloudFetchQueue( - schema_bytes, - result_links=[], - description=description, - max_download_threads=10, - ssl_options=SSLOptions(), - session_id_hex=Mock(), - statement_id=Mock(), - chunk_id=0, - ) + # Create description that matches the 4-column schema + description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")] + queue 
 
         assert queue.table is None
 
         result = queue.next_n_rows(100)
@@ -263,16 +210,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table):
     def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 4
@@ -285,16 +223,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table)
     def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 2
@@ -307,16 +236,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl
     def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
         mock_create_next_table.side_effect = [self.make_arrow_table(), None]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         assert queue.table_row_index == 0
@@ -335,16 +255,7 @@ def test_remaining_rows_multiple_tables_fully_returned(
             None,
         ]
         schema_bytes, description = MagicMock(), MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
         assert queue.table == self.make_arrow_table()
         assert queue.table.num_rows == 4
         queue.table_row_index = 3
@@ -365,17 +276,9 @@ def test_remaining_rows_multiple_tables_fully_returned(
     )
     def test_remaining_rows_empty_table(self, mock_create_next_table):
         schema_bytes = self.get_schema_bytes()
-        description = MagicMock()
-        queue = utils.ThriftCloudFetchQueue(
-            schema_bytes,
-            result_links=[],
-            description=description,
-            max_download_threads=10,
-            ssl_options=SSLOptions(),
-            session_id_hex=Mock(),
-            statement_id=Mock(),
-            chunk_id=0,
-        )
+        # Create description that matches the 4-column schema
+        description = [("col0", "uint32"), ("col1", "uint32"), ("col2", "uint32"), ("col3", "uint32")]
+        queue = self.create_queue(schema_bytes=schema_bytes, description=description)
 
         assert queue.table is None
 
         result = queue.remaining_rows()
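Every hunk above swaps the nine-argument ThriftCloudFetchQueue construction for a shared create_queue helper. The helper's definition lands earlier in test_cloud_fetch_queue.py and is not shown in this excerpt; a minimal sketch of what it presumably looks like, assuming it just centralizes the old constructor call plus the newly injected http_client and reuses the module's existing imports (utils, SSLOptions, Mock, MagicMock):

    # Hypothetical sketch; the patch's actual helper may differ.
    def create_queue(self, schema_bytes, description):
        return utils.ThriftCloudFetchQueue(
            schema_bytes,
            result_links=[],
            description=description,
            max_download_threads=10,
            ssl_options=SSLOptions(),
            session_id_hex=Mock(),
            statement_id=Mock(),
            chunk_id=0,
            http_client=MagicMock(),  # assumed: the dependency this change injects
        )

Centralizing construction this way means the next signature change touches one place instead of a dozen tests.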
diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py
index 6eb17a05..1c77226a 100644
--- a/tests/unit/test_download_manager.py
+++ b/tests/unit/test_download_manager.py
@@ -14,6 +14,7 @@ class DownloadManagerTests(unittest.TestCase):
     def create_download_manager(
         self, links, max_download_threads=10, lz4_compressed=True
     ):
+        mock_http_client = MagicMock()
         return download_manager.ResultFileDownloadManager(
             links,
             max_download_threads,
@@ -22,6 +23,7 @@ def create_download_manager(
             session_id_hex=Mock(),
             statement_id=Mock(),
             chunk_id=0,
+            http_client=mock_http_client,
         )
 
     def create_result_link(
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index c514980e..00b1b849 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -1,21 +1,19 @@
-from contextlib import contextmanager
 import unittest
-from unittest.mock import Mock, patch, MagicMock
-
+from unittest.mock import patch, MagicMock, Mock
 import requests
 
 import databricks.sql.cloudfetch.downloader as downloader
-from databricks.sql.common.http import DatabricksHttpClient
 from databricks.sql.exc import Error
 from databricks.sql.types import SSLOptions
 
 
-def create_response(**kwargs) -> requests.Response:
-    result = requests.Response()
+def create_mock_response(**kwargs):
+    """Create a mock response object for testing"""
+    mock_response = MagicMock()
     for k, v in kwargs.items():
-        setattr(result, k, v)
-    result.close = Mock()
-    return result
+        setattr(mock_response, k, v)
+    mock_response.close = Mock()
+    return mock_response
 
 
 class DownloaderTests(unittest.TestCase):
@@ -23,6 +21,17 @@ class DownloaderTests(unittest.TestCase):
     Unit tests for checking downloader logic.
     """
 
+    def _setup_mock_http_response(self, mock_http_client, status=200, data=b""):
+        """Helper method to setup mock HTTP client with response context manager."""
+        mock_response = MagicMock()
+        mock_response.status = status
+        mock_response.data = data
+        mock_context_manager = MagicMock()
+        mock_context_manager.__enter__.return_value = mock_response
+        mock_context_manager.__exit__.return_value = None
+        mock_http_client.request_context.return_value = mock_context_manager
+        return mock_response
+
     def _setup_time_mock_for_download(self, mock_time, end_time):
         """Helper to setup time mock that handles logging system calls."""
         call_count = [0]
@@ -38,6 +47,7 @@ def time_side_effect():
 
     @patch("time.time", return_value=1000)
     def test_run_link_expired(self, mock_time):
+        mock_http_client = MagicMock()
         settings = Mock()
         result_link = Mock()
         # Already expired
@@ -49,6 +59,7 @@ def test_run_link_expired(self, mock_time):
             chunk_id=0,
             session_id_hex=Mock(),
             statement_id=Mock(),
+            http_client=mock_http_client,
         )
 
         with self.assertRaises(Error) as context:
@@ -59,6 +70,7 @@ def test_run_link_expired(self, mock_time):
 
     @patch("time.time", return_value=1000)
     def test_run_link_past_expiry_buffer(self, mock_time):
+        mock_http_client = MagicMock()
         settings = Mock(link_expiry_buffer_secs=5)
         result_link = Mock()
         # Within the expiry buffer time
@@ -70,6 +82,7 @@ def test_run_link_past_expiry_buffer(self, mock_time):
             chunk_id=0,
             session_id_hex=Mock(),
             statement_id=Mock(),
+            http_client=mock_http_client,
         )
 
         with self.assertRaises(Error) as context:
@@ -80,46 +93,45 @@ def test_run_link_past_expiry_buffer(self, mock_time):
 
     @patch("time.time", return_value=1000)
     def test_run_get_response_not_ok(self, mock_time):
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
         settings.download_timeout = 0
         settings.use_proxy = False
         result_link = Mock(expiryTime=1001)
 
-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=404, _content=b"1234"),
-        ):
-            d = downloader.ResultSetDownloadHandler(
-                settings,
-                result_link,
-                ssl_options=SSLOptions(),
-                chunk_id=0,
-                session_id_hex=Mock(),
-                statement_id=Mock(),
-            )
-            with self.assertRaises(requests.exceptions.HTTPError) as context:
-                d.run()
-                self.assertTrue("404" in str(context.exception))
+        # Setup mock HTTP response using helper method
+        self._setup_mock_http_response(mock_http_client, status=404, data=b"1234")
+
+        d = downloader.ResultSetDownloadHandler(
+            settings,
+            result_link,
+            ssl_options=SSLOptions(),
+            chunk_id=0,
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            http_client=mock_http_client,
+        )
+        with self.assertRaises(Exception) as context:
+            d.run()
+        self.assertTrue("404" in str(context.exception))
 
     @patch("time.time")
     def test_run_uncompressed_successful(self, mock_time):
         self._setup_time_mock_for_download(mock_time, 1000.5)
 
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         file_bytes = b"1234567890" * 10
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = False
         settings.min_cloudfetch_download_speed = 1.0
-        result_link = Mock(bytesNum=100, expiryTime=1001)
-        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123"
+        result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes))
+        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
+
+        # Setup mock HTTP response using helper method
+        self._setup_mock_http_response(mock_http_client, status=200, data=file_bytes)
 
-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=200, _content=file_bytes),
-        ):
+        # Patch the log metrics method to avoid division by zero
+        with patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
@@ -127,29 +139,32 @@ def test_run_uncompressed_successful(self, mock_time):
                 chunk_id=0,
                 session_id_hex=Mock(),
                 statement_id=Mock(),
+                http_client=mock_http_client,
             )
             file = d.run()
-
-            assert file.file_bytes == b"1234567890" * 10
+            self.assertEqual(file.file_bytes, file_bytes)
+            self.assertEqual(file.start_row_offset, result_link.startRowOffset)
+            self.assertEqual(file.row_count, result_link.rowCount)
 
     @patch("time.time")
     def test_run_compressed_successful(self, mock_time):
        self._setup_time_mock_for_download(mock_time, 1000.2)
 
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
-
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
         settings.min_cloudfetch_download_speed = 1.0
-        result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link = Mock(expiryTime=1001, bytesNum=len(file_bytes))
         result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
-        with patch.object(
-            http_client,
-            "execute",
-            return_value=create_response(status_code=200, _content=compressed_bytes),
-        ):
+
+        # Setup mock HTTP response using helper method
+        self._setup_mock_http_response(mock_http_client, status=200, data=compressed_bytes)
+
+        # Mock the decompression method and log metrics to avoid issues
+        with patch.object(downloader.ResultSetDownloadHandler, '_decompress_data', return_value=file_bytes), \
+             patch.object(downloader.ResultSetDownloadHandler, '_log_download_metrics'):
             d = downloader.ResultSetDownloadHandler(
                 settings,
                 result_link,
@@ -157,48 +172,53 @@ def test_run_compressed_successful(self, mock_time):
                 chunk_id=0,
                 session_id_hex=Mock(),
                 statement_id=Mock(),
+                http_client=mock_http_client,
             )
             file = d.run()
-
-            assert file.file_bytes == b"1234567890" * 10
+            self.assertEqual(file.file_bytes, file_bytes)
+            self.assertEqual(file.start_row_offset, result_link.startRowOffset)
+            self.assertEqual(file.row_count, result_link.rowCount)
 
     @patch("time.time", return_value=1000)
     def test_download_connection_error(self, mock_time):
-
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
-        with patch.object(http_client, "execute", side_effect=ConnectionError("foo")):
-            d = downloader.ResultSetDownloadHandler(
-                settings,
-                result_link,
-                ssl_options=SSLOptions(),
-                chunk_id=0,
-                session_id_hex=Mock(),
-                statement_id=Mock(),
-            )
-            with self.assertRaises(ConnectionError):
-                d.run()
+        mock_http_client.request_context.side_effect = ConnectionError("foo")
+
+        d = downloader.ResultSetDownloadHandler(
+            settings,
+            result_link,
+            ssl_options=SSLOptions(),
+            chunk_id=0,
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            http_client=mock_http_client,
+        )
+        with self.assertRaises(ConnectionError):
+            d.run()
 
     @patch("time.time", return_value=1000)
     def test_download_timeout(self, mock_time):
-        http_client = DatabricksHttpClient.get_instance()
+        mock_http_client = MagicMock()
         settings = Mock(
             link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True
         )
         result_link = Mock(bytesNum=100, expiryTime=1001)
 
-        with patch.object(http_client, "execute", side_effect=TimeoutError("foo")):
-            d = downloader.ResultSetDownloadHandler(
-                settings,
-                result_link,
-                ssl_options=SSLOptions(),
-                chunk_id=0,
-                session_id_hex=Mock(),
-                statement_id=Mock(),
-            )
-            with self.assertRaises(TimeoutError):
-                d.run()
+        mock_http_client.request_context.side_effect = TimeoutError("foo")
+
+        d = downloader.ResultSetDownloadHandler(
+            settings,
+            result_link,
+            ssl_options=SSLOptions(),
+            chunk_id=0,
+            session_id_hex=Mock(),
+            statement_id=Mock(),
+            http_client=mock_http_client,
+        )
+        with self.assertRaises(TimeoutError):
+            d.run()
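The rewritten downloader tests encode an assumed contract for the new client: request_context() returns a context manager yielding a urllib3-style response that exposes .status and .data, rather than requests' .status_code and .content. A self-contained stand-in that satisfies the same contract, useful for sanity-checking the tests' expectations in isolation (FakeResponse and FakeUnifiedHttpClient are illustrative names, not the connector's real classes):

    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class FakeResponse:
        status: int = 200   # urllib3-style attribute, not requests' status_code
        data: bytes = b""   # raw body bytes

    class FakeUnifiedHttpClient:
        def __init__(self, response):
            self._response = response

        @contextmanager
        def request_context(self, method, url, **kwargs):
            # The caller enters the context, reads .status/.data, then exits;
            # a real client would release the pooled connection on exit.
            yield self._response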
diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py
index cbeae098..6471cb4f 100644
--- a/tests/unit/test_sea_queue.py
+++ b/tests/unit/test_sea_queue.py
@@ -7,7 +7,7 @@
 """
 
 import pytest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, MagicMock
 
 from databricks.sql.backend.sea.queue import (
     JsonQueue,
@@ -184,6 +184,7 @@ def description(self):
     def test_build_queue_json_array(self, json_manifest, sample_data):
         """Test building a JSON array queue."""
         result_data = ResultData(data=sample_data)
+        mock_http_client = MagicMock()
 
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -194,6 +195,7 @@ def test_build_queue_json_array(self, json_manifest, sample_data):
             max_download_threads=10,
             sea_client=Mock(),
             lz4_compressed=False,
+            http_client=mock_http_client,
         )
 
         assert isinstance(queue, JsonQueue)
@@ -217,6 +219,8 @@ def test_build_queue_arrow_stream(
         ]
         result_data = ResultData(data=None, external_links=external_links)
+        mock_http_client = MagicMock()
+
         with patch(
             "databricks.sql.backend.sea.queue.ResultFileDownloadManager"
         ), patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
@@ -229,6 +233,7 @@ def test_build_queue_arrow_stream(
                 max_download_threads=10,
                 sea_client=mock_sea_client,
                 lz4_compressed=False,
+                http_client=mock_http_client,
             )
 
             assert isinstance(queue, SeaCloudFetchQueue)
@@ -236,6 +241,7 @@ def test_build_queue_arrow_stream(
     def test_build_queue_invalid_format(self, invalid_manifest):
         """Test building a queue with invalid format."""
         result_data = ResultData(data=[])
+        mock_http_client = MagicMock()
 
         with pytest.raises(ProgrammingError, match="Invalid result format"):
             SeaResultSetQueueFactory.build_queue(
@@ -247,6 +253,7 @@ def test_build_queue_invalid_format(self, invalid_manifest):
                 max_download_threads=10,
                 sea_client=Mock(),
                 lz4_compressed=False,
+                http_client=mock_http_client,
             )
 
 
@@ -339,6 +346,7 @@ def test_init_with_valid_initial_link(
     ):
         """Test initialization with valid initial link."""
         # Create a queue with valid initial link
+        mock_http_client = MagicMock()
         with patch.object(SeaCloudFetchQueue, "_create_next_table", return_value=None):
             queue = SeaCloudFetchQueue(
                 result_data=ResultData(external_links=[sample_external_link]),
@@ -349,6 +357,7 @@ def test_init_with_valid_initial_link(
                 total_chunk_count=1,
                 lz4_compressed=False,
                 description=description,
+                http_client=mock_http_client,
             )
 
         # Verify attributes
@@ -367,6 +376,7 @@ def test_init_no_initial_links(
     ):
         """Test initialization with no initial links."""
         # Create a queue with empty initial links
+        mock_http_client = MagicMock()
         queue = SeaCloudFetchQueue(
             result_data=ResultData(external_links=[]),
             max_download_threads=5,
@@ -376,6 +386,7 @@ def test_init_no_initial_links(
             total_chunk_count=0,
             lz4_compressed=False,
             description=description,
+            http_client=mock_http_client,
         )
 
         assert queue.table is None
@@ -462,7 +473,7 @@ def test_hybrid_disposition_with_attachment(
         # Create result data with attachment
         attachment_data = b"mock_arrow_data"
         result_data = ResultData(attachment=attachment_data)
-
+        mock_http_client = MagicMock()
         # Build queue
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -473,6 +484,7 @@ def test_hybrid_disposition_with_attachment(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )
 
         # Verify ArrowQueue was created
@@ -508,7 +520,8 @@ def test_hybrid_disposition_with_external_links(
 
         # Create result data with external links but no attachment
         result_data = ResultData(external_links=external_links, attachment=None)
-        # Build queue
+        # Build queue
+        mock_http_client = MagicMock()
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
             manifest=arrow_manifest,
@@ -518,6 +531,7 @@ def test_hybrid_disposition_with_external_links(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=False,
+            http_client=mock_http_client,
         )
 
         # Verify SeaCloudFetchQueue was created
@@ -548,7 +562,7 @@ def test_hybrid_disposition_with_compressed_attachment(
         # Create result data with attachment
         result_data = ResultData(attachment=compressed_data)
-
+        mock_http_client = MagicMock()
         # Build queue with lz4_compressed=True
         queue = SeaResultSetQueueFactory.build_queue(
             result_data=result_data,
@@ -559,6 +573,7 @@ def test_hybrid_disposition_with_compressed_attachment(
             max_download_threads=10,
             sea_client=mock_sea_client,
             lz4_compressed=True,
+            http_client=mock_http_client,
         )
 
         # Verify ArrowQueue was created with decompressed data
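Nearly every test in test_sea_queue.py now opens with the same mock_http_client = MagicMock() line. Since this module is pytest-based, that repetition could be folded into a fixture; a possible cleanup, not part of this patch:

    # Hypothetical conftest.py fixture; not included in this change.
    import pytest
    from unittest.mock import MagicMock

    @pytest.fixture
    def mock_http_client():
        return MagicMock()

Tests would then accept mock_http_client as a parameter instead of constructing it inline, e.g. def test_build_queue_json_array(self, json_manifest, sample_data, mock_http_client).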
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 6823b1b3..e019e05a 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -75,8 +75,9 @@ def test_http_header_passthrough(self, mock_client_class):
         call_kwargs = mock_client_class.call_args[1]
         assert ("foo", "bar") in call_kwargs["http_headers"]
 
+    @patch("%s.client.UnifiedHttpClient" % PACKAGE_NAME)
     @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
-    def test_tls_arg_passthrough(self, mock_client_class):
+    def test_tls_arg_passthrough(self, mock_client_class, mock_http_client):
         databricks.sql.connect(
             **self.DUMMY_CONNECTION_ARGS,
             _tls_verify_hostname="hostname",
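The session test stacks a second @patch on top of the existing one. unittest.mock applies stacked @patch decorators bottom-up, so the patch closest to the function produces the first injected argument, which is why mock_client_class still precedes mock_http_client in the new signature. A standalone illustration of that ordering rule:

    # Decorator order vs. argument order: the bottom-most @patch is applied
    # first and therefore becomes the first mock argument after self.
    from unittest.mock import patch

    @patch("os.getcwd")    # applied second -> second mock argument
    @patch("os.getpid")    # applied first  -> first mock argument
    def demo(mock_getpid, mock_getcwd):
        import os
        mock_getpid.return_value = 1234
        assert os.getpid() == 1234

    demo()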
diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py
index d85e4171..989b2351 100644
--- a/tests/unit/test_telemetry.py
+++ b/tests/unit/test_telemetry.py
@@ -1,6 +1,7 @@
 import uuid
 import pytest
 from unittest.mock import patch, MagicMock
+import json
 
 from databricks.sql.telemetry.telemetry_client import (
     TelemetryClient,
@@ -23,6 +24,7 @@ def mock_telemetry_client():
     session_id = str(uuid.uuid4())
     auth_provider = AccessTokenAuthProvider("test-token")
     executor = MagicMock()
+    mock_http_client = MagicMock()
 
     return TelemetryClient(
         telemetry_enabled=True,
@@ -31,6 +33,7 @@ def mock_telemetry_client():
         host_url="test-host.com",
         executor=executor,
         batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+        http_client=mock_http_client,
     )
 
 
@@ -72,10 +75,15 @@ def test_event_batching_and_flushing_flow(self, mock_telemetry_client):
             mock_send.assert_called_once()
             assert len(client._events_batch) == 0  # Batch cleared after flush
 
-    @patch("requests.post")
-    def test_network_request_flow(self, mock_post, mock_telemetry_client):
+    @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request")
+    def test_network_request_flow(self, mock_http_request, mock_telemetry_client):
         """Test the complete network request flow with authentication."""
-        mock_post.return_value.status_code = 200
+        # Mock response for unified HTTP client
+        mock_response = MagicMock()
+        mock_response.status = 200
+        mock_response.status_code = 200
+        mock_http_request.return_value = mock_response
+
         client = mock_telemetry_client
 
         # Create mock events
@@ -91,7 +99,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client):
         args, kwargs = client._executor.submit.call_args
 
         # Verify correct function and URL
-        assert args[0] == client._http_client.post
+        assert args[0] == client._send_with_unified_client
         assert args[1] == "https://test-host.com/telemetry-ext"
         assert kwargs["headers"]["Authorization"] == "Bearer test-token"
 
@@ -208,6 +216,7 @@ def test_client_lifecycle_flow(self):
         """Test complete client lifecycle: initialize -> use -> close."""
         session_id_hex = "test-session"
         auth_provider = AccessTokenAuthProvider("token")
+        mock_http_client = MagicMock()
 
         # Initialize enabled client
         TelemetryClientFactory.initialize_telemetry_client(
@@ -216,6 +225,7 @@ def test_client_lifecycle_flow(self):
             auth_provider=auth_provider,
             host_url="test-host.com",
             batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+            http_client=mock_http_client,
         )
 
         client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -234,6 +244,7 @@ def test_client_lifecycle_flow(self):
     def test_disabled_telemetry_flow(self):
         """Test that disabled telemetry uses NoopTelemetryClient."""
         session_id_hex = "test-session"
+        mock_http_client = MagicMock()
 
         TelemetryClientFactory.initialize_telemetry_client(
             telemetry_enabled=False,
@@ -241,6 +252,7 @@ def test_disabled_telemetry_flow(self):
             auth_provider=None,
             host_url="test-host.com",
             batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+            http_client=mock_http_client,
         )
 
         client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -249,6 +261,7 @@ def test_factory_error_handling(self):
         """Test that factory errors fall back to NoopTelemetryClient."""
         session_id = "test-session"
+        mock_http_client = MagicMock()
 
         # Simulate initialization error
         with patch(
@@ -261,6 +274,7 @@ def test_factory_error_handling(self):
                 auth_provider=AccessTokenAuthProvider("token"),
                 host_url="test-host.com",
                 batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+                http_client=mock_http_client,
             )
 
         # Should fall back to NoopTelemetryClient
@@ -271,6 +285,7 @@ def test_factory_shutdown_flow(self):
         """Test factory shutdown when last client is removed."""
         session1 = "session-1"
         session2 = "session-2"
+        mock_http_client = MagicMock()
 
         # Initialize multiple clients
         for session in [session1, session2]:
@@ -280,6 +295,7 @@ def test_factory_shutdown_flow(self):
                 auth_provider=AccessTokenAuthProvider("token"),
                 host_url="test-host.com",
                 batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
+                http_client=mock_http_client,
             )
 
         # Factory should be initialized
@@ -325,10 +341,11 @@ def test_connection_failure_sends_correct_telemetry_payload(
 class TestTelemetryFeatureFlag:
     """Tests the interaction between the telemetry feature flag and connection parameters."""
 
-    def _mock_ff_response(self, mock_requests_get, enabled: bool):
-        """Helper to configure the mock response for the feature flag endpoint."""
+    def _mock_ff_response(self, mock_http_request, enabled: bool):
+        """Helper method to mock feature flag response for unified HTTP client."""
         mock_response = MagicMock()
-        mock_response.status_code = 200
+        mock_response.status = 200
+        mock_response.status_code = 200  # Compatibility attribute
         payload = {
             "flags": [
                 {
@@ -339,15 +356,21 @@ def _mock_ff_response(self, mock_http_request, enabled: bool):
             "ttl_seconds": 3600,
         }
         mock_response.json.return_value = payload
-        mock_requests_get.return_value = mock_response
+        mock_response.data = json.dumps(payload).encode()
+        mock_http_request.return_value = mock_response
 
-    @patch("databricks.sql.common.feature_flag.requests.get")
-    def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession):
+    @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request")
+    def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSession):
         """Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
-        self._mock_ff_response(mock_requests_get, enabled=True)
+        self._mock_ff_response(mock_http_request, enabled=True)
         mock_session_instance = MockSession.return_value
         mock_session_instance.guid_hex = "test-session-ff-true"
         mock_session_instance.auth_provider = AccessTokenAuthProvider("token")
+
+        # Set up mock HTTP client on the session
+        mock_http_client = MagicMock()
+        mock_http_client.request = mock_http_request
+        mock_session_instance.http_client = mock_http_client
 
         conn = sql.client.Connection(
             server_hostname="test",
@@ -357,19 +380,24 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio
         )
 
         assert conn.telemetry_enabled is True
-        mock_requests_get.assert_called_once()
+        mock_http_request.assert_called_once()
         client = TelemetryClientFactory.get_telemetry_client("test-session-ff-true")
         assert isinstance(client, TelemetryClient)
 
@patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_is_false( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should be OFF when enable_telemetry=True but server flag is 'false'.""" - self._mock_ff_response(mock_requests_get, enabled=False) + self._mock_ff_response(mock_http_request, enabled=False) mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -379,19 +407,24 @@ def test_telemetry_disabled_when_flag_is_false( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-false") assert isinstance(client, NoopTelemetryClient) - @patch("databricks.sql.common.feature_flag.requests.get") + @patch("databricks.sql.common.unified_http_client.UnifiedHttpClient.request") def test_telemetry_disabled_when_flag_request_fails( - self, mock_requests_get, MockSession + self, mock_http_request, MockSession ): """Telemetry should default to OFF if the feature flag network request fails.""" - mock_requests_get.side_effect = Exception("Network is down") + mock_http_request.side_effect = Exception("Network is down") mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") + + # Set up mock HTTP client on the session + mock_http_client = MagicMock() + mock_http_client.request = mock_http_request + mock_session_instance.http_client = mock_http_client conn = sql.client.Connection( server_hostname="test", @@ -401,6 +434,6 @@ def test_telemetry_disabled_when_flag_request_fails( ) assert conn.telemetry_enabled is False - mock_requests_get.assert_called_once() + mock_http_request.assert_called_once() client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail") assert isinstance(client, NoopTelemetryClient) diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py deleted file mode 100644 index d5287deb..00000000 --- a/tests/unit/test_telemetry_retry.py +++ /dev/null @@ -1,124 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import io -import time - -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory -from databricks.sql.auth.retry import DatabricksRetryPolicy - -PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn" - - -def create_mock_conn(responses): - """Creates a mock connection object whose getresponse() method yields a series of responses.""" - mock_conn = MagicMock() - mock_http_responses = [] - for resp in responses: - mock_http_response = MagicMock() - mock_http_response.status = resp.get("status") - mock_http_response.headers = resp.get("headers", {}) - body = resp.get("body", b"{}") - mock_http_response.fp = io.BytesIO(body) - - def release(): - mock_http_response.fp.close() - - mock_http_response.release_conn = release - mock_http_responses.append(mock_http_response) - mock_conn.getresponse.side_effect = mock_http_responses - 
diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py
deleted file mode 100644
index d5287deb..00000000
--- a/tests/unit/test_telemetry_retry.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import pytest
-from unittest.mock import patch, MagicMock
-import io
-import time
-
-from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
-from databricks.sql.auth.retry import DatabricksRetryPolicy
-
-PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
-
-
-def create_mock_conn(responses):
-    """Creates a mock connection object whose getresponse() method yields a series of responses."""
-    mock_conn = MagicMock()
-    mock_http_responses = []
-    for resp in responses:
-        mock_http_response = MagicMock()
-        mock_http_response.status = resp.get("status")
-        mock_http_response.headers = resp.get("headers", {})
-        body = resp.get("body", b"{}")
-        mock_http_response.fp = io.BytesIO(body)
-
-        def release():
-            mock_http_response.fp.close()
-
-        mock_http_response.release_conn = release
-        mock_http_responses.append(mock_http_response)
-    mock_conn.getresponse.side_effect = mock_http_responses
-    return mock_conn
-
-
-class TestTelemetryClientRetries:
-    @pytest.fixture(autouse=True)
-    def setup_and_teardown(self):
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-        yield
-        if TelemetryClientFactory._executor:
-            TelemetryClientFactory._executor.shutdown(wait=True)
-        TelemetryClientFactory._initialized = False
-        TelemetryClientFactory._clients = {}
-        TelemetryClientFactory._executor = None
-
-    def get_client(self, session_id, num_retries=3):
-        """
-        Configures a client with a specific number of retries.
-        """
-        TelemetryClientFactory.initialize_telemetry_client(
-            telemetry_enabled=True,
-            session_id_hex=session_id,
-            auth_provider=None,
-            host_url="test.databricks.com",
-            batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
-        )
-        client = TelemetryClientFactory.get_telemetry_client(session_id)
-
-        retry_policy = DatabricksRetryPolicy(
-            delay_min=0.01,
-            delay_max=0.02,
-            stop_after_attempts_duration=2.0,
-            stop_after_attempts_count=num_retries,
-            delay_default=0.1,
-            force_dangerous_codes=[],
-            urllib3_kwargs={"total": num_retries},
-        )
-        adapter = client._http_client.session.adapters.get("https://")
-        adapter.max_retries = retry_policy
-        return client
-
-    @pytest.mark.parametrize(
-        "status_code, description",
-        [
-            (401, "Unauthorized"),
-            (403, "Forbidden"),
-            (501, "Not Implemented"),
-            (200, "Success"),
-        ],
-    )
-    def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
-        """
-        Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried.
-        """
-        # Use the status code in the session ID for easier debugging if it fails
-        client = self.get_client(f"session-{status_code}")
-        mock_responses = [{"status": status_code}]
-
-        with patch(
-            PATCH_TARGET, return_value=create_mock_conn(mock_responses)
-        ) as mock_get_conn:
-            client.export_failure_log("TestError", "Test message")
-            TelemetryClientFactory.close(client._session_id_hex)
-
-        mock_get_conn.return_value.getresponse.assert_called_once()
-
-    def test_exceeds_retry_count_limit(self):
-        """
-        Verifies that the client retries up to the specified number of times before giving up.
-        Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
-        """
-        num_retries = 3
-        expected_total_calls = num_retries + 1
-        retry_after = 1
-        client = self.get_client("session-exceed-limit", num_retries=num_retries)
-        mock_responses = [
-            {"status": 503, "headers": {"Retry-After": str(retry_after)}},
-            {"status": 429},
-            {"status": 502},
-            {"status": 503},
-        ]
-
-        with patch(
-            PATCH_TARGET, return_value=create_mock_conn(mock_responses)
-        ) as mock_get_conn:
-            start_time = time.time()
-            client.export_failure_log("TestError", "Test message")
-            TelemetryClientFactory.close(client._session_id_hex)
-            end_time = time.time()
-
-            assert (
-                mock_get_conn.return_value.getresponse.call_count
-                == expected_total_calls
-            )
-            assert end_time - start_time > retry_after