From 1f8cf7333164b8b9ddfe3caf8625c539c28c845d Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Thu, 22 Aug 2024 19:11:28 +0300 Subject: [PATCH 01/13] [PECO-1857] Use SSL options with HTTPS connection pool (#425) * [PECO-1857] Use SSL options with HTTPS connection pool Signed-off-by: Levko Kravets * Some cleanup Signed-off-by: Levko Kravets * Resolve circular dependencies Signed-off-by: Levko Kravets * Update existing tests Signed-off-by: Levko Kravets * Fix MyPy issues Signed-off-by: Levko Kravets * Fix `_tls_no_verify` handling Signed-off-by: Levko Kravets * Add tests Signed-off-by: Levko Kravets --------- Signed-off-by: Levko Kravets --- src/databricks/sql/auth/thrift_http_client.py | 41 ++-- src/databricks/sql/client.py | 18 +- .../sql/cloudfetch/download_manager.py | 9 +- src/databricks/sql/cloudfetch/downloader.py | 12 +- src/databricks/sql/thrift_backend.py | 43 +--- src/databricks/sql/types.py | 48 +++++ src/databricks/sql/utils.py | 18 +- tests/unit/test_cloud_fetch_queue.py | 30 +-- tests/unit/test_download_manager.py | 5 +- tests/unit/test_downloader.py | 16 +- tests/unit/test_thrift_backend.py | 186 ++++++++++++------ 11 files changed, 267 insertions(+), 159 deletions(-) diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index f7c22a1e..6273ab28 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -1,13 +1,11 @@ import base64 import logging import urllib.parse -from typing import Dict, Union +from typing import Dict, Union, Optional import six import thrift -logger = logging.getLogger(__name__) - import ssl import warnings from http.client import HTTPResponse @@ -16,6 +14,9 @@ from urllib3 import HTTPConnectionPool, HTTPSConnectionPool, ProxyManager from urllib3.util import make_headers from databricks.sql.auth.retry import CommandType, DatabricksRetryPolicy +from databricks.sql.types import SSLOptions + +logger = logging.getLogger(__name__) class THttpClient(thrift.transport.THttpClient.THttpClient): @@ -25,13 +26,12 @@ def __init__( uri_or_host, port=None, path=None, - cafile=None, - cert_file=None, - key_file=None, - ssl_context=None, + ssl_options: Optional[SSLOptions] = None, max_connections: int = 1, retry_policy: Union[DatabricksRetryPolicy, int] = 0, ): + self._ssl_options = ssl_options + if port is not None: warnings.warn( "Please use the THttpClient('http{s}://host:port/path') constructor", @@ -48,13 +48,11 @@ def __init__( self.scheme = parsed.scheme assert self.scheme in ("http", "https") if self.scheme == "https": - self.certfile = cert_file - self.keyfile = key_file - self.context = ( - ssl.create_default_context(cafile=cafile) - if (cafile and not ssl_context) - else ssl_context - ) + if self._ssl_options is not None: + # TODO: Not sure if those options are used anywhere - need to double-check + self.certfile = self._ssl_options.tls_client_cert_file + self.keyfile = self._ssl_options.tls_client_cert_key_file + self.context = self._ssl_options.create_ssl_context() self.port = parsed.port self.host = parsed.hostname self.path = parsed.path @@ -109,12 +107,23 @@ def startRetryTimer(self): def open(self): # self.__pool replaces the self.__http used by the original THttpClient + _pool_kwargs = {"maxsize": self.max_connections} + if self.scheme == "http": pool_class = HTTPConnectionPool elif self.scheme == "https": pool_class = HTTPSConnectionPool - - _pool_kwargs = {"maxsize": self.max_connections} + _pool_kwargs.update( + { + "cert_reqs": 
ssl.CERT_REQUIRED + if self._ssl_options.tls_verify + else ssl.CERT_NONE, + "ca_certs": self._ssl_options.tls_trusted_ca_file, + "cert_file": self._ssl_options.tls_client_cert_file, + "key_file": self._ssl_options.tls_client_cert_key_file, + "key_password": self._ssl_options.tls_client_cert_key_password, + } + ) if self.using_proxy(): proxy_manager = ProxyManager( diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c0bf534d..addc340e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -35,7 +35,7 @@ ) -from databricks.sql.types import Row +from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence @@ -178,8 +178,9 @@ def read(self) -> Optional[OAuthToken]: # _tls_trusted_ca_file # Set to the path of the file containing trusted CA certificates for server certificate # verification. If not provide, uses system truststore. - # _tls_client_cert_file, _tls_client_cert_key_file + # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password # Set client SSL certificate. + # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain # _retry_stop_after_attempts_count # The maximum number of attempts during a request retry sequence (defaults to 24) # _socket_timeout @@ -220,12 +221,25 @@ def read(self) -> Optional[OAuthToken]: base_headers = [("User-Agent", useragent_header)] + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + self.thrift_backend = ThriftBackend( self.host, self.port, http_path, (http_headers or []) + base_headers, auth_provider, + ssl_options=self._ssl_options, _use_arrow_native_complex_types=_use_arrow_native_complex_types, **kwargs, ) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index e30adcd6..7e96cd32 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -1,6 +1,5 @@ import logging -from ssl import SSLContext from concurrent.futures import ThreadPoolExecutor, Future from typing import List, Union @@ -9,6 +8,8 @@ DownloadableResultSettings, DownloadedFile, ) +from databricks.sql.types import SSLOptions + from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink logger = logging.getLogger(__name__) @@ -20,7 +21,7 @@ def __init__( links: List[TSparkArrowResultLink], max_download_threads: int, lz4_compressed: bool, - ssl_context: SSLContext, + ssl_options: SSLOptions, ): self._pending_links: List[TSparkArrowResultLink] = [] for link in links: @@ -38,7 +39,7 @@ def __init__( self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) - self._ssl_context = ssl_context + self._ssl_options = ssl_options def get_next_downloaded_file( self, next_row_offset: int @@ -95,7 +96,7 @@ def _schedule_downloads(self): 
handler = ResultSetDownloadHandler( settings=self._downloadable_result_settings, link=link, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) task = self._thread_pool.submit(handler.run) self._download_tasks.append(task) diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index 00ffecd0..03c70054 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -3,13 +3,12 @@ import requests from requests.adapters import HTTPAdapter, Retry -from ssl import SSLContext, CERT_NONE import lz4.frame import time from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink - from databricks.sql.exc import Error +from databricks.sql.types import SSLOptions logger = logging.getLogger(__name__) @@ -66,11 +65,11 @@ def __init__( self, settings: DownloadableResultSettings, link: TSparkArrowResultLink, - ssl_context: SSLContext, + ssl_options: SSLOptions, ): self.settings = settings self.link = link - self._ssl_context = ssl_context + self._ssl_options = ssl_options def run(self) -> DownloadedFile: """ @@ -95,14 +94,13 @@ def run(self) -> DownloadedFile: session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) session.mount("https://", HTTPAdapter(max_retries=retryPolicy)) - ssl_verify = self._ssl_context.verify_mode != CERT_NONE - try: # Get the file via HTTP request response = session.get( self.link.fileLink, timeout=self.settings.download_timeout, - verify=ssl_verify, + verify=self._ssl_options.tls_verify, + # TODO: Pass cert from `self._ssl_options` ) response.raise_for_status() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 56412fce..e89bff26 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -5,7 +5,6 @@ import time import uuid import threading -from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union import pyarrow @@ -36,6 +35,7 @@ convert_decimals_in_arrow_table, convert_column_based_set_to_arrow_table, ) +from databricks.sql.types import SSLOptions logger = logging.getLogger(__name__) @@ -85,6 +85,7 @@ def __init__( http_path: str, http_headers, auth_provider: AuthProvider, + ssl_options: SSLOptions, staging_allowed_local_path: Union[None, str, List[str]] = None, **kwargs, ): @@ -93,16 +94,6 @@ def __init__( # Tag to add to User-Agent header. For use by partners. # _username, _password # Username and password Basic authentication (no official support) - # _tls_no_verify - # Set to True (Boolean) to completely disable SSL verification. - # _tls_verify_hostname - # Set to False (Boolean) to disable SSL hostname verification, but check certificate. - # _tls_trusted_ca_file - # Set to the path of the file containing trusted CA certificates for server certificate - # verification. If not provide, uses system truststore. - # _tls_client_cert_file, _tls_client_cert_key_file, _tls_client_cert_key_password - # Set client SSL certificate. - # See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.load_cert_chain # _connection_uri # Overrides server_hostname and http_path. 
# RETRY/ATTEMPT POLICY @@ -162,29 +153,7 @@ def __init__( # Cloud fetch self.max_download_threads = kwargs.get("max_download_threads", 10) - # Configure tls context - ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file")) - if kwargs.get("_tls_no_verify") is True: - ssl_context.check_hostname = False - ssl_context.verify_mode = CERT_NONE - elif kwargs.get("_tls_verify_hostname") is False: - ssl_context.check_hostname = False - ssl_context.verify_mode = CERT_REQUIRED - else: - ssl_context.check_hostname = True - ssl_context.verify_mode = CERT_REQUIRED - - tls_client_cert_file = kwargs.get("_tls_client_cert_file") - tls_client_cert_key_file = kwargs.get("_tls_client_cert_key_file") - tls_client_cert_key_password = kwargs.get("_tls_client_cert_key_password") - if tls_client_cert_file: - ssl_context.load_cert_chain( - certfile=tls_client_cert_file, - keyfile=tls_client_cert_key_file, - password=tls_client_cert_key_password, - ) - - self._ssl_context = ssl_context + self._ssl_options = ssl_options self._auth_provider = auth_provider @@ -225,7 +194,7 @@ def __init__( self._transport = databricks.sql.auth.thrift_http_client.THttpClient( auth_provider=self._auth_provider, uri_or_host=uri, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, **additional_transport_args, # type: ignore ) @@ -776,7 +745,7 @@ def _results_message_to_execute_response(self, resp, operation_state): max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) else: arrow_queue_opt = None @@ -1008,7 +977,7 @@ def fetch_results( max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/types.py b/src/databricks/sql/types.py index aa11b954..fef22cd9 100644 --- a/src/databricks/sql/types.py +++ b/src/databricks/sql/types.py @@ -19,6 +19,54 @@ from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar import datetime import decimal +from ssl import SSLContext, CERT_NONE, CERT_REQUIRED, create_default_context + + +class SSLOptions: + tls_verify: bool + tls_verify_hostname: bool + tls_trusted_ca_file: Optional[str] + tls_client_cert_file: Optional[str] + tls_client_cert_key_file: Optional[str] + tls_client_cert_key_password: Optional[str] + + def __init__( + self, + tls_verify: bool = True, + tls_verify_hostname: bool = True, + tls_trusted_ca_file: Optional[str] = None, + tls_client_cert_file: Optional[str] = None, + tls_client_cert_key_file: Optional[str] = None, + tls_client_cert_key_password: Optional[str] = None, + ): + self.tls_verify = tls_verify + self.tls_verify_hostname = tls_verify_hostname + self.tls_trusted_ca_file = tls_trusted_ca_file + self.tls_client_cert_file = tls_client_cert_file + self.tls_client_cert_key_file = tls_client_cert_key_file + self.tls_client_cert_key_password = tls_client_cert_key_password + + def create_ssl_context(self) -> SSLContext: + ssl_context = create_default_context(cafile=self.tls_trusted_ca_file) + + if self.tls_verify is False: + ssl_context.check_hostname = False + ssl_context.verify_mode = CERT_NONE + elif self.tls_verify_hostname is False: + ssl_context.check_hostname = False + ssl_context.verify_mode = CERT_REQUIRED + else: + ssl_context.check_hostname = True + ssl_context.verify_mode = CERT_REQUIRED + + if self.tls_client_cert_file: + 
ssl_context.load_cert_chain( + certfile=self.tls_client_cert_file, + keyfile=self.tls_client_cert_key_file, + password=self.tls_client_cert_key_password, + ) + + return ssl_context class Row(tuple): diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c22688bb..2807bd2b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -9,7 +9,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union import re -from ssl import SSLContext import lz4.frame import pyarrow @@ -21,13 +20,14 @@ TSparkArrowResultLink, TSparkRowSetType, ) +from databricks.sql.types import SSLOptions from databricks.sql.parameters.native import ParameterStructure, TDbsqlParameter -BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] - import logging +BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] + logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ def build_queue( t_row_set: TRowSet, arrow_schema_bytes: bytes, max_download_threads: int, - ssl_context: SSLContext, + ssl_options: SSLOptions, lz4_compressed: bool = True, description: Optional[List[List[Any]]] = None, ) -> ResultSetQueue: @@ -62,7 +62,7 @@ def build_queue( lz4_compressed (bool): Whether result data has been lz4 compressed. description (List[List[Any]]): Hive table schema description. max_download_threads (int): Maximum number of downloader thread pool threads. - ssl_context (SSLContext): SSLContext object for CloudFetchQueue + ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue Returns: ResultSetQueue @@ -91,7 +91,7 @@ def build_queue( lz4_compressed=lz4_compressed, description=description, max_download_threads=max_download_threads, - ssl_context=ssl_context, + ssl_options=ssl_options, ) else: raise AssertionError("Row set type is not valid") @@ -137,7 +137,7 @@ def __init__( self, schema_bytes, max_download_threads: int, - ssl_context: SSLContext, + ssl_options: SSLOptions, start_row_offset: int = 0, result_links: Optional[List[TSparkArrowResultLink]] = None, lz4_compressed: bool = True, @@ -160,7 +160,7 @@ def __init__( self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description - self._ssl_context = ssl_context + self._ssl_options = ssl_options logger.debug( "Initialize CloudFetch loader, row set start offset: {}, file list:".format( @@ -178,7 +178,7 @@ def __init__( links=result_links or [], max_download_threads=self.max_download_threads, lz4_compressed=self.lz4_compressed, - ssl_context=self._ssl_context, + ssl_options=self._ssl_options, ) self.table = self._create_next_table() diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index cd14c676..acd0c392 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -1,10 +1,10 @@ import pyarrow import unittest from unittest.mock import MagicMock, patch -from ssl import create_default_context from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink import databricks.sql.utils as utils +from databricks.sql.types import SSLOptions class CloudFetchQueueSuite(unittest.TestCase): @@ -51,7 +51,7 @@ def test_initializer_adds_links(self, mock_create_next_table): schema_bytes, result_links=result_links, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert len(queue.download_manager._pending_links) == 10 @@ -65,7 +65,7 @@ def test_initializer_no_links_to_add(self): schema_bytes, result_links=result_links, max_download_threads=10, - ssl_context=create_default_context(), + 
ssl_options=SSLOptions(), ) assert len(queue.download_manager._pending_links) == 0 @@ -78,7 +78,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): MagicMock(), result_links=[], max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue._create_next_table() is None @@ -95,7 +95,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) expected_result = self.make_arrow_table() @@ -120,7 +120,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -140,7 +140,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -160,7 +160,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -180,7 +180,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -199,7 +199,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table is None @@ -216,7 +216,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -235,7 +235,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -254,7 +254,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -273,7 +273,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 @@ -293,7 +293,7 @@ def test_remaining_rows_empty_table(self, mock_create_next_table): result_links=[], description=description, max_download_threads=10, - ssl_context=create_default_context(), + 
ssl_options=SSLOptions(), ) assert queue.table is None diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index c084d8e4..a11bc8d4 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -1,9 +1,8 @@ import unittest from unittest.mock import patch, MagicMock -from ssl import create_default_context - import databricks.sql.cloudfetch.download_manager as download_manager +from databricks.sql.types import SSLOptions from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink @@ -17,7 +16,7 @@ def create_download_manager(self, links, max_download_threads=10, lz4_compressed links, max_download_threads, lz4_compressed, - ssl_context=create_default_context(), + ssl_options=SSLOptions(), ) def create_result_link( diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index b6e473b5..7075ef6c 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -2,10 +2,10 @@ from unittest.mock import Mock, patch, MagicMock import requests -from ssl import create_default_context import databricks.sql.cloudfetch.downloader as downloader from databricks.sql.exc import Error +from databricks.sql.types import SSLOptions def create_response(**kwargs) -> requests.Response: @@ -26,7 +26,7 @@ def test_run_link_expired(self, mock_time): result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -40,7 +40,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(Error) as context: d.run() @@ -58,7 +58,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() self.assertTrue('404' in str(context.exception)) @@ -73,7 +73,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -89,7 +89,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -102,7 +102,7 @@ def test_download_connection_error(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ 
b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(ConnectionError): d.run() @@ -114,6 +114,6 @@ def test_download_timeout(self, mock_time, mock_session): mock_session.return_value.get.return_value.content = \ b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_context=create_default_context()) + d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 4bcf84d2..0333766c 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -4,11 +4,13 @@ import unittest from unittest.mock import patch, MagicMock, Mock from ssl import CERT_NONE, CERT_REQUIRED +from urllib3 import HTTPSConnectionPool import pyarrow import databricks.sql from databricks.sql import utils +from databricks.sql.types import SSLOptions from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider @@ -67,7 +69,7 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -77,7 +79,7 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -155,7 +157,7 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider()) + ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider(), ssl_options=SSLOptions()) t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) def test_proxy_headers_are_set(self): @@ -175,75 +177,140 @@ def test_proxy_headers_are_set(self): assert isinstance(result.get('proxy-authorization'), type(str())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): mock_cert_key_file = Mock() mock_cert_key_password = Mock() mock_trusted_ca_file = Mock() mock_cert_file = Mock() + mock_ssl_options = SSLOptions( + tls_client_cert_file=mock_cert_file, + tls_client_cert_key_file=mock_cert_key_file, + 
tls_client_cert_key_password=mock_cert_key_password, + tls_trusted_ca_file=mock_trusted_ca_file, + ) + mock_ssl_context = mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) + ThriftBackend( "foo", 123, "bar", [], auth_provider=AuthProvider(), - _tls_client_cert_file=mock_cert_file, - _tls_client_cert_key_file=mock_cert_key_file, - _tls_client_cert_key_password=mock_cert_key_password, - _tls_trusted_ca_file=mock_trusted_ca_file, + ssl_options=mock_ssl_options, ) - mock_create_default_context.assert_called_once_with(cafile=mock_trusted_ca_file) - mock_ssl_context = mock_create_default_context.return_value mock_ssl_context.load_cert_chain.assert_called_once_with( certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + + @patch("databricks.sql.types.create_default_context") + def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): + from databricks.sql.auth.thrift_http_client import THttpClient + + mock_cert_key_file = Mock() + mock_cert_key_password = Mock() + mock_trusted_ca_file = Mock() + mock_cert_file = Mock() + + mock_ssl_options = SSLOptions( + tls_verify=True, + tls_client_cert_file=mock_cert_file, + tls_client_cert_key_file=mock_cert_key_file, + tls_client_cert_key_password=mock_cert_key_password, + tls_trusted_ca_file=mock_trusted_ca_file, + ) + + http_client = THttpClient( + auth_provider=None, + uri_or_host="https://example.com", + ssl_options=mock_ssl_options, + ) + + self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.certfile, mock_ssl_options.tls_client_cert_file) + self.assertEqual(http_client.keyfile, mock_ssl_options.tls_client_cert_key_file) + self.assertIsNotNone(http_client.certfile) + mock_create_default_context.assert_called() + + http_client.open() + + conn_pool = http_client._THttpClient__pool + self.assertIsInstance(conn_pool, HTTPSConnectionPool) + self.assertEqual(conn_pool.cert_reqs, CERT_REQUIRED) + self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) + self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) + self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) + self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) + + def test_tls_no_verify_is_respected_by_http_client(self): + from databricks.sql.auth.thrift_http_client import THttpClient + + http_client = THttpClient( + auth_provider=None, + uri_or_host="https://example.com", + ssl_options=SSLOptions(tls_verify=False), + ) + self.assertEqual(http_client.scheme, 'https') + + http_client.open() + + conn_pool = http_client._THttpClient__pool + self.assertIsInstance(conn_pool, HTTPSConnectionPool) + self.assertEqual(conn_pool.cert_reqs, CERT_NONE) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): - ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_no_verify=True) + mock_ssl_options = SSLOptions(tls_verify=False) + mock_ssl_context = 
mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called() + + ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options) - mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend.create_default_context") + @patch("databricks.sql.types.create_default_context") def test_tls_verify_hostname_is_respected( self, mock_create_default_context, t_http_client_class ): + mock_ssl_options = SSLOptions(tls_verify_hostname=False) + mock_ssl_context = mock_ssl_options.create_ssl_context() + mock_create_default_context.assert_called() + ThriftBackend( - "foo", 123, "bar", [], auth_provider=AuthProvider(), _tls_verify_hostname=False + "foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options ) - mock_ssl_context = mock_create_default_context.return_value self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_context"], mock_ssl_context) + self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual( t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" ) @@ -251,17 +318,17 @@ def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_cla @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=129 + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=129 ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=0 + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), 
_socket_timeout=0 ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider()) + ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), _socket_timeout=None + "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=None ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) @@ -350,7 +417,7 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): def test_make_request_checks_status_code(self): error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) for code in error_codes: mock_error_response = Mock() @@ -388,7 +455,7 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): ), ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -417,7 +484,7 @@ def test_handle_execute_response_sets_compression_in_direct_results(self, build_ closeOperation=None, ), ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @@ -449,7 +516,7 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv tcli_service_instance.GetOperationStatus.return_value = op_state_resp thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -477,7 +544,7 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -510,7 +577,7 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -562,7 +629,7 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for 
error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) with self.assertRaises(DatabaseError) as cm: @@ -603,7 +670,7 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se op_state_3, ] thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock() @@ -634,7 +701,7 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider() + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() ) thrift_backend._results_message_to_execute_response = Mock() @@ -818,7 +885,7 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -836,7 +903,7 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -854,7 +921,7 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -871,7 +938,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -897,7 +964,7 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -927,7 +994,7 @@ def test_get_columns_calls_client_and_handle_execute_response(self, 
tcli_service tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -957,14 +1024,14 @@ def test_open_session_user_provided_session_id_optional(self, tcli_service_class tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.close_command(self.operation_handle) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, @@ -974,7 +1041,7 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) thrift_backend.close_session(self.session_handle) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle @@ -1012,7 +1079,7 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1021,7 +1088,7 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1043,7 +1110,7 @@ def test_create_arrow_table_calls_correct_conversion_method( @patch("lz4.frame.decompress") @patch("pyarrow.ipc.open_stream") def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + 
thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1221,6 +1288,7 @@ def test_make_request_will_retry_GetOperationStatus( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1287,6 +1355,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=EXPECTED_RETRIES, _retry_delay_default=1, ) @@ -1308,7 +1377,7 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1333,6 +1402,7 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( "path", [], auth_provider=AuthProvider(), + ssl_options=SSLOptions(), _retry_stop_after_attempts_count=14, _retry_delay_max=0, _retry_delay_min=0, @@ -1353,7 +1423,7 @@ def test_make_request_will_read_error_message_headers_if_set(self, t_transport_c mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) error_headers = [ [("x-thriftserver-error-message", "thrift server error message")], @@ -1454,7 +1524,7 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_duration": 100, } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args ) for arg, val in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1470,7 +1540,7 @@ def test_retry_args_bounding(self, mock_http_client): k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **retry_delay_args + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args ) retry_delay_expected_vals = { k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() @@ -1490,7 +1560,7 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1501,7 +1571,7 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + 
backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1520,7 +1590,7 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1540,7 +1610,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1550,7 +1620,7 @@ def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_ def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. 
If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1588,7 +1658,7 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), ) - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) + backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") @@ -1616,7 +1686,7 @@ def test_execute_command_sets_complex_type_fields_correctly( complex_arg_types["_use_arrow_native_decimals"] = decimals thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), **complex_arg_types + "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **complex_arg_types ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] From d31063ca918167412153a368c13a99055bf89c02 Mon Sep 17 00:00:00 2001 From: Levko Kravets Date: Tue, 27 Aug 2024 11:20:32 +0300 Subject: [PATCH 02/13] Prepare release v3.4.0 (#430) Prepare release 3.4.0 Signed-off-by: Levko Kravets --- CHANGELOG.md | 6 ++++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fca2046..14db0f8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Release History +# 3.4.0 (2024-08-27) + +- Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal) +- Make OAuth as the default authenticator if no authentication setting is provided (databricks/databricks-sql-python#419 by @jackyhu-db) +- Fix (regression): use SSL options with HTTPS connection pool (databricks/databricks-sql-python#425 by @kravets-levko) + # 3.3.0 (2024-07-18) - Don't retry requests that fail with HTTP code 401 (databricks/databricks-sql-python#408 by @Hodnebo) diff --git a/pyproject.toml b/pyproject.toml index 44d25ef9..69d0c274 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "databricks-sql-connector" -version = "3.3.0" +version = "3.4.0" description = "Databricks SQL Connector for Python" authors = ["Databricks "] license = "Apache-2.0" diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 767cbf05..c0fdc2f1 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -68,7 +68,7 @@ def __repr__(self): DATE = DBAPITypeObject("date") ROWID = DBAPITypeObject() -__version__ = "3.3.0" +__version__ = "3.4.0" USER_AGENT_NAME = "PyDatabricksSqlConnector" # These two functions are pyhive legacy From a151df2b925aa4b7800919b3eefd57452af3f608 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Thu, 3 Oct 2024 18:21:15 +0530 Subject: [PATCH 03/13] [PECO-1926] Create a non pyarrow flow to handle small results for the column set (#440) * Implemented the columnar flow for non arrow users * Minor fixes * Introduced the Column Table structure * Added test for the new column table * Minor fix * Removed unnecessary files --- src/databricks/sql/client.py | 96 +++++++++++++++-- src/databricks/sql/thrift_backend.py | 25 +++-- src/databricks/sql/utils.py | 152 ++++++++++++++++++++++++--- tests/unit/test_column_queue.py | 22 ++++ 4 files changed, 263 insertions(+), 32 deletions(-) create mode 100644 tests/unit/test_column_queue.py diff --git 
a/src/databricks/sql/client.py b/src/databricks/sql/client.py index addc340e..4df67a08 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,7 +1,10 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import requests import json import os @@ -22,6 +25,8 @@ ParamEscaper, inject_parameters, transform_paramstyle, + ColumnTable, + ColumnQueue ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -991,14 +996,14 @@ def fetchmany(self, size: int) -> List[Row]: else: raise Error("There is no active result set") - def fetchall_arrow(self) -> pyarrow.Table: + def fetchall_arrow(self) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchall_arrow() else: raise Error("There is no active result set") - def fetchmany_arrow(self, size) -> pyarrow.Table: + def fetchmany_arrow(self, size) -> "pyarrow.Table": self._check_not_closed() if self.active_result_set: return self.active_result_set.fetchmany_arrow(size) @@ -1143,6 +1148,18 @@ def _fill_results_buffer(self): self.results = results self.has_more_rows = has_more_rows + def _convert_columnar_table(self, table): + column_names = [c[0] for c in self.description] + ResultRow = Row(*column_names) + result = [] + for row_index in range(table.num_rows): + curr_row = [] + for col_index in range(table.num_columns): + curr_row.append(table.get_item(col_index, row_index)) + result.append(ResultRow(*curr_row)) + + return result + def _convert_arrow_table(self, table): column_names = [c[0] for c in self.description] ResultRow = Row(*column_names) @@ -1185,7 +1202,7 @@ def _convert_arrow_table(self, table): def rownumber(self): return self._next_row_index - def fetchmany_arrow(self, size: int) -> pyarrow.Table: + def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ Fetch the next set of rows of a query result, returning a PyArrow table. @@ -1210,7 +1227,46 @@ def fetchmany_arrow(self, size: int) -> pyarrow.Table: return results - def fetchall_arrow(self) -> pyarrow.Table: + def merge_columnar(self, result1, result2): + """ + Function to merge / combining the columnar results into a single result + :param result1: + :param result2: + :return: + """ + + if result1.column_names != result2.column_names: + raise ValueError("The columns in the results don't match") + + merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)] + return ColumnTable(merged_result, result1.column_names) + + def fetchmany_columnar(self, size: int): + """ + Fetch the next set of rows of a query result, returning a Columnar Table. + An empty sequence is returned when no more rows are available. 
+ """ + if size < 0: + raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + results = self.results.next_n_rows(size) + n_remaining_rows = size - results.num_rows + self._next_row_index += results.num_rows + + while ( + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows + ): + self._fill_results_buffer() + partial_results = self.results.next_n_rows(n_remaining_rows) + results = self.merge_columnar(results, partial_results) + n_remaining_rows -= partial_results.num_rows + self._next_row_index += partial_results.num_rows + + return results + + def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" results = self.results.remaining_rows() self._next_row_index += results.num_rows @@ -1223,12 +1279,30 @@ def fetchall_arrow(self) -> pyarrow.Table: return results + def fetchall_columnar(self): + """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" + results = self.results.remaining_rows() + self._next_row_index += results.num_rows + + while not self.has_been_closed_server_side and self.has_more_rows: + self._fill_results_buffer() + partial_results = self.results.remaining_rows() + results = self.merge_columnar(results, partial_results) + self._next_row_index += partial_results.num_rows + + return results + def fetchone(self) -> Optional[Row]: """ Fetch the next row of a query result set, returning a single sequence, or None when no more data is available. """ - res = self._convert_arrow_table(self.fetchmany_arrow(1)) + + if isinstance(self.results, ColumnQueue): + res = self._convert_columnar_table(self.fetchmany_columnar(1)) + else: + res = self._convert_arrow_table(self.fetchmany_arrow(1)) + if len(res) > 0: return res[0] else: @@ -1238,7 +1312,10 @@ def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result, returning them as a list of rows. """ - return self._convert_arrow_table(self.fetchall_arrow()) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchall_columnar()) + else: + return self._convert_arrow_table(self.fetchall_arrow()) def fetchmany(self, size: int) -> List[Row]: """ @@ -1246,7 +1323,10 @@ def fetchmany(self, size: int) -> List[Row]: An empty sequence is returned when no more rows are available. 
""" - return self._convert_arrow_table(self.fetchmany_arrow(size)) + if isinstance(self.results, ColumnQueue): + return self._convert_columnar_table(self.fetchmany_columnar(size)) + else: + return self._convert_arrow_table(self.fetchmany_arrow(size)) def close(self) -> None: """ diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e89bff26..7f6ada9d 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -7,7 +7,10 @@ import threading from typing import List, Union -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None import thrift.transport.THttpClient import thrift.protocol.TBinaryProtocol import thrift.transport.TSocket @@ -621,6 +624,7 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): + def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -726,12 +730,17 @@ def _results_message_to_execute_response(self, resp, operation_state): description = self._hive_schema_to_description( t_result_set_metadata_resp.schema ) - schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() - ) + + if pyarrow: + schema_bytes = ( + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() + ) + else: + schema_bytes = None + lz4_compressed = t_result_set_metadata_resp.lz4Compressed is_staging_operation = t_result_set_metadata_resp.isStagingOperation if direct_results and direct_results.resultSet: @@ -827,7 +836,7 @@ def execute_command( getDirectResults=ttypes.TSparkGetDirectResults( maxRows=max_rows, maxBytes=max_bytes ), - canReadArrowResult=True, + canReadArrowResult=True if pyarrow else False, canDecompressLZ4Result=lz4_compression, canDownloadResult=use_cloud_fetch, confOverlay={ diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 2807bd2b..97df6d4d 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pytz import datetime import decimal from abc import ABC, abstractmethod @@ -11,7 +12,10 @@ import re import lz4.frame -import pyarrow +try: + import pyarrow +except ImportError: + pyarrow = None from databricks.sql import OperationalError, exc from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager @@ -33,11 +37,11 @@ class ResultSetQueue(ABC): @abstractmethod - def next_n_rows(self, num_rows: int) -> pyarrow.Table: + def next_n_rows(self, num_rows: int): pass @abstractmethod - def remaining_rows(self) -> pyarrow.Table: + def remaining_rows(self): pass @@ -76,13 +80,15 @@ def build_queue( ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: - arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( + column_table, column_names = convert_column_based_set_to_column_table( t_row_set.columns, description ) - converted_arrow_table = convert_decimals_in_arrow_table( - arrow_table, description + + converted_column_table = convert_to_assigned_datatypes_in_column_table( + column_table, description ) - return ArrowQueue(converted_arrow_table, n_valid_rows) + + return ColumnQueue(ColumnTable(converted_column_table, column_names)) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( schema_bytes=arrow_schema_bytes, @@ -96,11 
         else:
             raise AssertionError("Row set type is not valid")
 
+class ColumnTable:
+    def __init__(self, column_table, column_names):
+        self.column_table = column_table
+        self.column_names = column_names
+
+    @property
+    def num_rows(self):
+        if len(self.column_table) == 0:
+            return 0
+        else:
+            return len(self.column_table[0])
+
+    @property
+    def num_columns(self):
+        return len(self.column_names)
+
+    def get_item(self, col_index, row_index):
+        return self.column_table[col_index][row_index]
+
+    def slice(self, curr_index, length):
+        sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table]
+        return ColumnTable(sliced_column_table, self.column_names)
+
+    def __eq__(self, other):
+        return self.column_table == other.column_table and self.column_names == other.column_names
+
+class ColumnQueue(ResultSetQueue):
+    def __init__(self, column_table: ColumnTable):
+        self.column_table = column_table
+        self.cur_row_index = 0
+        self.n_valid_rows = column_table.num_rows
+
+    def next_n_rows(self, num_rows):
+        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+
+        slice = self.column_table.slice(self.cur_row_index, length)
+        self.cur_row_index += slice.num_rows
+        return slice
+
+    def remaining_rows(self):
+        slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index)
+        self.cur_row_index += slice.num_rows
+        return slice
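+
+    # Note: ColumnQueue mirrors ArrowQueue's interface (next_n_rows /
+    # remaining_rows), so the fetch paths in client.py can treat
+    # Arrow-backed and list-backed results uniformly.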
+
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: pyarrow.Table,
+        arrow_table: "pyarrow.Table",
         n_valid_rows: int,
         start_row_index: int = 0,
     ):
@@ -115,7 +165,7 @@ def __init__(
         self.arrow_table = arrow_table
         self.n_valid_rows = n_valid_rows
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """Get upto the next n rows of the Arrow dataframe"""
         length = min(num_rows, self.n_valid_rows - self.cur_row_index)
         # Note that the table.slice API is not the same as Python's slice
@@ -124,7 +174,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         self.cur_row_index += slice.num_rows
         return slice
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         slice = self.arrow_table.slice(
             self.cur_row_index, self.n_valid_rows - self.cur_row_index
         )
@@ -184,7 +234,7 @@ def __init__(
         self.table = self._create_next_table()
         self.table_row_index = 0
 
-    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
+    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
         """
         Get up to the next n rows of the cloud fetch Arrow dataframes.
 
@@ -216,7 +266,7 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table:
         logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
         return results
 
-    def remaining_rows(self) -> pyarrow.Table:
+    def remaining_rows(self) -> "pyarrow.Table":
         """
         Get all remaining rows of the cloud fetch Arrow dataframes.
 
@@ -237,7 +287,7 @@ def remaining_rows(self) -> pyarrow.Table:
             self.table_row_index = 0
         return results
 
-    def _create_next_table(self) -> Union[pyarrow.Table, None]:
+    def _create_next_table(self) -> Union["pyarrow.Table", None]:
         logger.debug(
             "CloudFetchQueue: Trying to get downloaded file for row {}".format(
                 self.start_row_index
@@ -276,7 +326,7 @@ def _create_next_table(self) -> Union[pyarrow.Table, None]:
 
         return arrow_table
 
-    def _create_empty_table(self) -> pyarrow.Table:
+    def _create_empty_table(self) -> "pyarrow.Table":
         # Create a 0-row table with just the schema bytes
         return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
 
@@ -515,7 +565,7 @@ def transform_paramstyle(
     return output
 
 
-def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table:
+def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table":
     arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes)
     return convert_decimals_in_arrow_table(arrow_table, description)
 
@@ -542,7 +592,7 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
     return arrow_table, n_rows
 
 
-def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
+def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
     for i, col in enumerate(table.itercolumns()):
         if description[i][1] == "decimal":
             decimal_col = col.to_pandas().apply(
@@ -560,6 +610,33 @@ def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table:
 
     return table
 
 
+def convert_to_assigned_datatypes_in_column_table(column_table, description):
+
+    converted_column_table = []
+    for i, col in enumerate(column_table):
+        if description[i][1] == "decimal":
+            converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col))
+        elif description[i][1] == "date":
+            converted_column_table.append(tuple(
+                v if v is None else datetime.date.fromisoformat(v) for v in col
+            ))
+        elif description[i][1] == "timestamp":
+            converted_column_table.append(tuple(
+                (
+                    v
+                    if v is None
+                    else datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S.%f").replace(
+                        tzinfo=pytz.UTC
+                    )
+                )
+                for v in col
+            ))
+        else:
+            converted_column_table.append(col)
+
+    return converted_column_table
+
+
 def convert_column_based_set_to_arrow_table(columns, description):
     arrow_table = pyarrow.Table.from_arrays(
         [_convert_column_to_arrow_array(c) for c in columns],
@@ -571,6 +648,13 @@ def convert_column_based_set_to_arrow_table(columns, description):
 
     return arrow_table, arrow_table.num_rows
 
 
+def convert_column_based_set_to_column_table(columns, description):
+    column_names = [c[0] for c in description]
+    column_table = [_convert_column_to_list(c) for c in columns]
+
+    return column_table, column_names
+
+
 def _convert_column_to_arrow_array(t_col):
     """
     Return a pyarrow array from the values in a TColumn instance.
@@ -595,6 +679,26 @@ def _convert_column_to_arrow_array(t_col):
     raise OperationalError("Empty TColumn instance {}".format(t_col))
 
 
+def _convert_column_to_list(t_col):
+    SUPPORTED_FIELD_TYPES = (
+        "boolVal",
+        "byteVal",
+        "i16Val",
+        "i32Val",
+        "i64Val",
+        "doubleVal",
+        "stringVal",
+        "binaryVal",
+    )
+
+    for field in SUPPORTED_FIELD_TYPES:
+        wrapper = getattr(t_col, field)
+        if wrapper:
+            return _create_python_tuple(wrapper)
+
+    raise OperationalError("Empty TColumn instance {}".format(t_col))
+
+
 def _create_arrow_array(t_col_value_wrapper, arrow_type):
     result = t_col_value_wrapper.values
     nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
@@ -609,3 +713,19 @@ def _create_arrow_array(t_col_value_wrapper, arrow_type):
             result[i] = None
 
     return pyarrow.array(result, type=arrow_type)
+
+
+def _create_python_tuple(t_col_value_wrapper):
+    result = t_col_value_wrapper.values
+    nulls = t_col_value_wrapper.nulls  # bitfield describing which values are null
+    assert isinstance(nulls, bytes)
+
+    # The number of bits in nulls can be both larger or smaller than the number of
+    # elements in result, so take the minimum of both to iterate over.
+    length = min(len(result), len(nulls) * 8)
+
+    for i in range(length):
+        if nulls[i >> 3] & BIT_MASKS[i & 0x7]:
+            result[i] = None
+
+    return tuple(result)
\ No newline at end of file
diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py
new file mode 100644
index 00000000..130b589b
--- /dev/null
+++ b/tests/unit/test_column_queue.py
@@ -0,0 +1,22 @@
+from databricks.sql.utils import ColumnQueue, ColumnTable
+
+
+class TestColumnQueueSuite:
+    @staticmethod
+    def make_column_table(table):
+        n_cols = len(table) if table else 0
+        return ColumnTable(table, [f"col_{i}" for i in range(n_cols)])
+
+    def test_fetchmany_respects_n_rows(self):
+        column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]])
+        column_queue = ColumnQueue(column_table)
+
+        assert column_queue.next_n_rows(2) == column_table.slice(0, 2)
+        assert column_queue.next_n_rows(2) == column_table.slice(2, 2)
+
+    def test_fetch_remaining_rows_respects_n_rows(self):
+        column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]])
+        column_queue = ColumnQueue(column_table)
+
+        assert column_queue.next_n_rows(2) == column_table.slice(0, 2)
+        assert column_queue.remaining_rows() == column_table.slice(2, 2)
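Patch 04 below appends the underlying exception's text to the user-facing retry
message. A minimal sketch of the pattern (the helper name is illustrative; only
DEFAULT_ERROR_CONTEXT comes from the diff):

    # Best-effort stringification of the wrapped error, with a constant
    # fallback if str() itself raises.
    DEFAULT_ERROR_CONTEXT = "Unknown error"

    def describe_error(error) -> str:
        try:
            return str(error)
        except Exception:
            return DEFAULT_ERROR_CONTEXT

    message = "Request failed" + ". " + describe_error(ValueError("bad gateway"))
    assert message == "Request failed. bad gateway"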
" + error_context # Taken from PyHive From 97c815eb03249070bcdec40e8d4c4fa86a53e278 Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Thu, 3 Oct 2024 22:46:16 +0530 Subject: [PATCH 05/13] Reformatted all the files using black (#448) Reformatted the files using black --- examples/custom_cred_provider.py | 22 +- examples/insert_data.py | 26 +- examples/interactive_oauth.py | 8 +- examples/m2m_oauth.py | 12 +- examples/persistent_oauth.py | 47 +- examples/query_cancel.py | 69 +- examples/query_execute.py | 18 +- examples/set_user_agent.py | 20 +- examples/v3_retries_query_execute.py | 24 +- src/databricks/sql/client.py | 16 +- src/databricks/sql/thrift_backend.py | 9 +- src/databricks/sql/utils.py | 50 +- tests/e2e/common/core_tests.py | 31 +- tests/e2e/common/decimal_tests.py | 20 +- tests/e2e/common/large_queries_mixin.py | 12 +- tests/e2e/common/predicates.py | 34 +- tests/e2e/common/retry_test_mixins.py | 41 +- tests/e2e/common/staging_ingestion_tests.py | 89 ++- tests/e2e/common/timestamp_tests.py | 17 +- tests/e2e/common/uc_volume_tests.py | 73 +- tests/e2e/test_complex_types.py | 4 +- tests/e2e/test_driver.py | 93 ++- tests/e2e/test_parameterized_queries.py | 20 +- tests/unit/test_arrow_queue.py | 16 +- tests/unit/test_auth.py | 12 +- tests/unit/test_client.py | 178 +++-- tests/unit/test_cloud_fetch_queue.py | 65 +- tests/unit/test_column_queue.py | 8 +- tests/unit/test_download_manager.py | 30 +- tests/unit/test_downloader.py | 83 +- tests/unit/test_fetches.py | 91 ++- tests/unit/test_fetches_bench.py | 20 +- tests/unit/test_oauth_persistence.py | 11 +- tests/unit/test_param_escaper.py | 26 +- tests/unit/test_retry.py | 5 +- tests/unit/test_thrift_backend.py | 801 ++++++++++++++++---- 36 files changed, 1521 insertions(+), 580 deletions(-) diff --git a/examples/custom_cred_provider.py b/examples/custom_cred_provider.py index 4c43280f..67945f23 100644 --- a/examples/custom_cred_provider.py +++ b/examples/custom_cred_provider.py @@ -4,23 +4,27 @@ from databricks.sdk.oauth import OAuthClient import os -oauth_client = OAuthClient(host=os.getenv("DATABRICKS_SERVER_HOSTNAME"), - client_id=os.getenv("DATABRICKS_CLIENT_ID"), - client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), - redirect_url=os.getenv("APP_REDIRECT_URL"), - scopes=['all-apis', 'offline_access']) +oauth_client = OAuthClient( + host=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + client_id=os.getenv("DATABRICKS_CLIENT_ID"), + client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), + redirect_url=os.getenv("APP_REDIRECT_URL"), + scopes=["all-apis", "offline_access"], +) consent = oauth_client.initiate_consent() creds = consent.launch_external_browser() -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - credentials_provider=creds) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + credentials_provider=creds, +) as connection: for x in range(1, 5): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/insert_data.py b/examples/insert_data.py index b304a0e9..053ed158 100644 --- a/examples/insert_data.py +++ b/examples/insert_data.py @@ -1,21 +1,23 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) 
as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)") + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE IF NOT EXISTS squares (x int, x_squared int)") - squares = [(i, i * i) for i in range(100)] - values = ",".join([f"({x}, {y})" for (x, y) in squares]) + squares = [(i, i * i) for i in range(100)] + values = ",".join([f"({x}, {y})" for (x, y) in squares]) - cursor.execute(f"INSERT INTO squares VALUES {values}") + cursor.execute(f"INSERT INTO squares VALUES {values}") - cursor.execute("SELECT * FROM squares LIMIT 10") + cursor.execute("SELECT * FROM squares LIMIT 10") - result = cursor.fetchall() + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/interactive_oauth.py b/examples/interactive_oauth.py index dad5cac6..8dbc8c47 100644 --- a/examples/interactive_oauth.py +++ b/examples/interactive_oauth.py @@ -13,12 +13,14 @@ token across script executions. """ -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/m2m_oauth.py b/examples/m2m_oauth.py index eba2095c..1c8c7278 100644 --- a/examples/m2m_oauth.py +++ b/examples/m2m_oauth.py @@ -22,17 +22,19 @@ def credential_provider(): # Service Principal UUID client_id=os.getenv("DATABRICKS_CLIENT_ID"), # Service Principal Secret - client_secret=os.getenv("DATABRICKS_CLIENT_SECRET")) + client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"), + ) return oauth_service_principal(config) with sql.connect( - server_hostname=server_hostname, - http_path=os.getenv("DATABRICKS_HTTP_PATH"), - credentials_provider=credential_provider) as connection: + server_hostname=server_hostname, + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + credentials_provider=credential_provider, +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/persistent_oauth.py b/examples/persistent_oauth.py index 0f2ba077..1a2eded2 100644 --- a/examples/persistent_oauth.py +++ b/examples/persistent_oauth.py @@ -17,37 +17,44 @@ from typing import Optional from databricks import sql -from databricks.sql.experimental.oauth_persistence import OAuthPersistence, OAuthToken, DevOnlyFilePersistence +from databricks.sql.experimental.oauth_persistence import ( + OAuthPersistence, + OAuthToken, + DevOnlyFilePersistence, +) class SampleOAuthPersistence(OAuthPersistence): - def persist(self, hostname: str, oauth_token: OAuthToken): - """To be implemented by the end user to persist in the preferred storage medium. + def persist(self, hostname: str, oauth_token: OAuthToken): + """To be implemented by the end user to persist in the preferred storage medium. - OAuthToken has two properties: - 1. OAuthToken.access_token - 2. OAuthToken.refresh_token + OAuthToken has two properties: + 1. 
OAuthToken.access_token + 2. OAuthToken.refresh_token - Both should be persisted. - """ - pass + Both should be persisted. + """ + pass - def read(self, hostname: str) -> Optional[OAuthToken]: - """To be implemented by the end user to fetch token from the preferred storage + def read(self, hostname: str) -> Optional[OAuthToken]: + """To be implemented by the end user to fetch token from the preferred storage - Fetch the access_token and refresh_token for the given hostname. - Return OAuthToken(access_token, refresh_token) - """ - pass + Fetch the access_token and refresh_token for the given hostname. + Return OAuthToken(access_token, refresh_token) + """ + pass -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - auth_type="databricks-oauth", - experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json")) as connection: + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + auth_type="databricks-oauth", + experimental_oauth_persistence=DevOnlyFilePersistence("./sample.json"), +) as connection: for x in range(1, 100): cursor = connection.cursor() - cursor.execute('SELECT 1+1') + cursor.execute("SELECT 1+1") result = cursor.fetchall() for row in result: print(row) diff --git a/examples/query_cancel.py b/examples/query_cancel.py index 4e0b74a5..b67fc085 100644 --- a/examples/query_cancel.py +++ b/examples/query_cancel.py @@ -5,47 +5,52 @@ The current operation of a cursor may be cancelled by calling its `.cancel()` method as shown in the example below. """ -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - def execute_really_long_query(): - try: - cursor.execute("SELECT SUM(A.id - B.id) " + - "FROM range(1000000000) A CROSS JOIN range(100000000) B " + - "GROUP BY (A.id - B.id)") - except sql.exc.RequestError: - print("It looks like this query was cancelled.") + with connection.cursor() as cursor: - exec_thread = threading.Thread(target=execute_really_long_query) + def execute_really_long_query(): + try: + cursor.execute( + "SELECT SUM(A.id - B.id) " + + "FROM range(1000000000) A CROSS JOIN range(100000000) B " + + "GROUP BY (A.id - B.id)" + ) + except sql.exc.RequestError: + print("It looks like this query was cancelled.") - print("\n Beginning to execute long query") - exec_thread.start() + exec_thread = threading.Thread(target=execute_really_long_query) - # Make sure the query has started before cancelling - print("\n Waiting 15 seconds before canceling", end="", flush=True) + print("\n Beginning to execute long query") + exec_thread.start() - seconds_waited = 0 - while seconds_waited < 15: - seconds_waited += 1 - print(".", end="", flush=True) - time.sleep(1) + # Make sure the query has started before cancelling + print("\n Waiting 15 seconds before canceling", end="", flush=True) - print("\n Cancelling the cursor's operation. 
This can take a few seconds.") - cursor.cancel() + seconds_waited = 0 + while seconds_waited < 15: + seconds_waited += 1 + print(".", end="", flush=True) + time.sleep(1) - print("\n Now checking the cursor status:") - exec_thread.join(5) + print("\n Cancelling the cursor's operation. This can take a few seconds.") + cursor.cancel() - assert not exec_thread.is_alive() - print("\n The previous command was successfully canceled") + print("\n Now checking the cursor status:") + exec_thread.join(5) - print("\n Now reusing the cursor to run a separate query.") + assert not exec_thread.is_alive() + print("\n The previous command was successfully canceled") - # We can still execute a new command on the cursor - cursor.execute("SELECT * FROM range(3)") + print("\n Now reusing the cursor to run a separate query.") - print("\n Execution was successful. Results appear below:") + # We can still execute a new command on the cursor + cursor.execute("SELECT * FROM range(3)") - print(cursor.fetchall()) + print("\n Execution was successful. Results appear below:") + + print(cursor.fetchall()) diff --git a/examples/query_execute.py b/examples/query_execute.py index a851ab50..38d2f17a 100644 --- a/examples/query_execute.py +++ b/examples/query_execute.py @@ -1,13 +1,15 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN")) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/set_user_agent.py b/examples/set_user_agent.py index 449692cf..93eb2e0b 100644 --- a/examples/set_user_agent.py +++ b/examples/set_user_agent.py @@ -1,14 +1,16 @@ from databricks import sql import os -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN"), - _user_agent_entry="ExamplePartnerTag") as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + _user_agent_entry="ExamplePartnerTag", +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/examples/v3_retries_query_execute.py b/examples/v3_retries_query_execute.py index 4b6772fe..aaab47d1 100644 --- a/examples/v3_retries_query_execute.py +++ b/examples/v3_retries_query_execute.py @@ -28,16 +28,18 @@ # # For complete information about configuring retries, see the docstring for databricks.sql.thrift_backend.ThriftBackend -with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"), - http_path = os.getenv("DATABRICKS_HTTP_PATH"), - access_token = os.getenv("DATABRICKS_TOKEN"), - _enable_v3_retries 
= True, - _retry_dangerous_codes=[502,400], - _retry_max_redirects=2) as connection: +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), + _enable_v3_retries=True, + _retry_dangerous_codes=[502, 400], + _retry_max_redirects=2, +) as connection: - with connection.cursor() as cursor: - cursor.execute("SELECT * FROM default.diamonds LIMIT 2") - result = cursor.fetchall() + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM default.diamonds LIMIT 2") + result = cursor.fetchall() - for row in result: - print(row) + for row in result: + print(row) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4df67a08..4e0ab941 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1,6 +1,7 @@ from typing import Dict, Tuple, List, Optional, Any, Union, Sequence import pandas + try: import pyarrow except ImportError: @@ -26,7 +27,7 @@ inject_parameters, transform_paramstyle, ColumnTable, - ColumnQueue + ColumnQueue, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -1155,7 +1156,7 @@ def _convert_columnar_table(self, table): for row_index in range(table.num_rows): curr_row = [] for col_index in range(table.num_columns): - curr_row.append(table.get_item(col_index, row_index)) + curr_row.append(table.get_item(col_index, row_index)) result.append(ResultRow(*curr_row)) return result @@ -1238,7 +1239,10 @@ def merge_columnar(self, result1, result2): if result1.column_names != result2.column_names: raise ValueError("The columns in the results don't match") - merged_result = [result1.column_table[i] + result2.column_table[i] for i in range(result1.num_columns)] + merged_result = [ + result1.column_table[i] + result2.column_table[i] + for i in range(result1.num_columns) + ] return ColumnTable(merged_result, result1.column_names) def fetchmany_columnar(self, size: int): @@ -1254,9 +1258,9 @@ def fetchmany_columnar(self, size: int): self._next_row_index += results.num_rows while ( - n_remaining_rows > 0 - and not self.has_been_closed_server_side - and self.has_more_rows + n_remaining_rows > 0 + and not self.has_been_closed_server_side + and self.has_more_rows ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7f6ada9d..cf5cd906 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -624,7 +624,6 @@ def _get_metadata_resp(self, op_handle): @staticmethod def _hive_schema_to_arrow_schema(t_table_schema): - def map_type(t_type_entry): if t_type_entry.primitiveEntry: return { @@ -733,10 +732,10 @@ def _results_message_to_execute_response(self, resp, operation_state): if pyarrow: schema_bytes = ( - t_result_set_metadata_resp.arrowSchema - or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) - .serialize() - .to_pybytes() + t_result_set_metadata_resp.arrowSchema + or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema) + .serialize() + .to_pybytes() ) else: schema_bytes = None diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ffeaeaf0..cd655c4e 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -12,6 +12,7 @@ import re import lz4.frame + try: import pyarrow except ImportError: @@ -103,6 +104,7 @@ def build_queue( else: raise AssertionError("Row set type is 
not valid") + class ColumnTable: def __init__(self, column_table, column_names): self.column_table = column_table @@ -123,11 +125,17 @@ def get_item(self, col_index, row_index): return self.column_table[col_index][row_index] def slice(self, curr_index, length): - sliced_column_table = [column[curr_index : curr_index + length] for column in self.column_table] + sliced_column_table = [ + column[curr_index : curr_index + length] for column in self.column_table + ] return ColumnTable(sliced_column_table, self.column_names) def __eq__(self, other): - return self.column_table == other.column_table and self.column_names == other.column_names + return ( + self.column_table == other.column_table + and self.column_names == other.column_names + ) + class ColumnQueue(ResultSetQueue): def __init__(self, column_table: ColumnTable): @@ -143,7 +151,9 @@ def next_n_rows(self, num_rows): return slice def remaining_rows(self): - slice = self.column_table.slice(self.cur_row_index, self.n_valid_rows - self.cur_row_index) + slice = self.column_table.slice( + self.cur_row_index, self.n_valid_rows - self.cur_row_index + ) self.cur_row_index += slice.num_rows return slice @@ -571,7 +581,9 @@ def transform_paramstyle( return output -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> "pyarrow.Table": +def create_arrow_table_from_arrow_file( + file_bytes: bytes, description +) -> "pyarrow.Table": arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -621,22 +633,26 @@ def convert_to_assigned_datatypes_in_column_table(column_table, description): converted_column_table = [] for i, col in enumerate(column_table): if description[i][1] == "decimal": - converted_column_table.append(tuple(v if v is None else Decimal(v) for v in col)) + converted_column_table.append( + tuple(v if v is None else Decimal(v) for v in col) + ) elif description[i][1] == "date": - converted_column_table.append(tuple( - v if v is None else datetime.date.fromisoformat(v) for v in col - )) + converted_column_table.append( + tuple(v if v is None else datetime.date.fromisoformat(v) for v in col) + ) elif description[i][1] == "timestamp": - converted_column_table.append(tuple( - ( - v - if v is None - else datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S.%f").replace( - tzinfo=pytz.UTC + converted_column_table.append( + tuple( + ( + v + if v is None + else datetime.datetime.strptime( + v, "%Y-%m-%d %H:%M:%S.%f" + ).replace(tzinfo=pytz.UTC) ) + for v in col ) - for v in col - )) + ) else: converted_column_table.append(col) @@ -734,4 +750,4 @@ def _create_python_tuple(t_col_value_wrapper): if nulls[i >> 3] & BIT_MASKS[i & 0x7]: result[i] = None - return tuple(result) \ No newline at end of file + return tuple(result) diff --git a/tests/e2e/common/core_tests.py b/tests/e2e/common/core_tests.py index e89289ef..3f0fdc05 100644 --- a/tests/e2e/common/core_tests.py +++ b/tests/e2e/common/core_tests.py @@ -4,15 +4,18 @@ TypeFailure = namedtuple( "TypeFailure", - "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf", ) ResultFailure = namedtuple( "ResultFailure", - "query,columnType,resultType,resultValue," "actualValue,actualType,description,conf", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf", ) ExecFailure = namedtuple( "ExecFailure", - "query,columnType,resultType,resultValue," 
"actualValue,actualType,description,conf,error", + "query,columnType,resultType,resultValue," + "actualValue,actualType,description,conf,error", ) @@ -58,7 +61,9 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.range_queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + self.run_query( + cursor, query, columnType, rowValueType, answer, default_conf + ) ) failures.extend( self.run_range_query( @@ -69,7 +74,9 @@ def run_tests_on_queries(self, default_conf): for query, columnType, rowValueType, answer in self.queries: with self.cursor(default_conf) as cursor: failures.extend( - self.run_query(cursor, query, columnType, rowValueType, answer, default_conf) + self.run_query( + cursor, query, columnType, rowValueType, answer, default_conf + ) ) if failures: @@ -84,7 +91,9 @@ def run_query(self, cursor, query, columnType, rowValueType, answer, conf): try: cursor.execute(full_query) (result,) = cursor.fetchone() - if not all(cursor.description[0][1] == type for type in expected_column_types): + if not all( + cursor.description[0][1] == type for type in expected_column_types + ): return [ TypeFailure( full_query, @@ -150,7 +159,10 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con if len(rows) <= 0: break for index, (result, id) in enumerate(rows): - if not all(cursor.description[0][1] == type for type in expected_column_types): + if not all( + cursor.description[0][1] == type + for type in expected_column_types + ): return [ TypeFailure( full_query, @@ -163,7 +175,10 @@ def run_range_query(self, cursor, query, columnType, rowValueType, expected, con conf, ) ] - if self.validate_row_value_type and type(result) is not rowValueType: + if ( + self.validate_row_value_type + and type(result) is not rowValueType + ): return [ TypeFailure( full_query, diff --git a/tests/e2e/common/decimal_tests.py b/tests/e2e/common/decimal_tests.py index 5005cdf1..0029f30c 100644 --- a/tests/e2e/common/decimal_tests.py +++ b/tests/e2e/common/decimal_tests.py @@ -7,8 +7,16 @@ class DecimalTestsMixin: decimal_and_expected_results = [ ("100.001 AS DECIMAL(6, 3)", Decimal("100.001"), pyarrow.decimal128(6, 3)), - ("1000000.0000 AS DECIMAL(11, 4)", Decimal("1000000.0000"), pyarrow.decimal128(11, 4)), - ("-10.2343 AS DECIMAL(10, 6)", Decimal("-10.234300"), pyarrow.decimal128(10, 6)), + ( + "1000000.0000 AS DECIMAL(11, 4)", + Decimal("1000000.0000"), + pyarrow.decimal128(11, 4), + ), + ( + "-10.2343 AS DECIMAL(10, 6)", + Decimal("-10.234300"), + pyarrow.decimal128(10, 6), + ), # TODO(SC-90767): Re-enable this test after we have a way of passing `ansi_mode` = False # ("-13872347.2343 AS DECIMAL(10, 10)", None, pyarrow.decimal128(10, 10)), ("NULL AS DECIMAL(1, 1)", None, pyarrow.decimal128(1, 1)), @@ -30,7 +38,9 @@ class DecimalTestsMixin: ), ] - @pytest.mark.parametrize("decimal, expected_value, expected_type", decimal_and_expected_results) + @pytest.mark.parametrize( + "decimal, expected_value, expected_type", decimal_and_expected_results + ) def test_decimals(self, decimal, expected_value, expected_type): with self.cursor({}) as cursor: query = "SELECT CAST ({})".format(decimal) @@ -44,7 +54,9 @@ def test_decimals(self, decimal, expected_value, expected_type): ) def test_multi_decimals(self, decimals, expected_values, expected_type): with self.cursor({}) as cursor: - union_str = " UNION ".join(["(SELECT CAST ({}))".format(dec) for dec in decimals]) + 
union_str = " UNION ".join( + ["(SELECT CAST ({}))".format(dec) for dec in decimals] + ) query = "SELECT * FROM ({}) ORDER BY 1 NULLS LAST".format(union_str) cursor.execute(query) diff --git a/tests/e2e/common/large_queries_mixin.py b/tests/e2e/common/large_queries_mixin.py index 9ebc3f01..41ef029b 100644 --- a/tests/e2e/common/large_queries_mixin.py +++ b/tests/e2e/common/large_queries_mixin.py @@ -36,7 +36,9 @@ def fetch_rows(self, cursor, row_count, fetchmany_size): num_fetches = max(math.ceil(n / 10000), 1) latency_ms = int((time.time() - start_time) * 1000 / num_fetches), 1 print( - "Fetched {} rows with an avg latency of {} per fetch, ".format(n, latency_ms) + "Fetched {} rows with an avg latency of {} per fetch, ".format( + n, latency_ms + ) + "assuming 10K fetch size." ) @@ -55,10 +57,14 @@ def test_query_with_large_wide_result_set(self): cursor.connection.lz4_compression = lz4_compression uuids = ", ".join(["uuid() uuid{}".format(i) for i in range(cols)]) cursor.execute( - "SELECT id, {uuids} FROM RANGE({rows})".format(uuids=uuids, rows=rows) + "SELECT id, {uuids} FROM RANGE({rows})".format( + uuids=uuids, rows=rows + ) ) assert lz4_compression == cursor.active_result_set.lz4_compressed - for row_id, row in enumerate(self.fetch_rows(cursor, rows, fetchmany_size)): + for row_id, row in enumerate( + self.fetch_rows(cursor, rows, fetchmany_size) + ): assert row[0] == row_id # Verify no rows are dropped in the middle. assert len(row[1]) == 36 diff --git a/tests/e2e/common/predicates.py b/tests/e2e/common/predicates.py index 88b14961..61de69fd 100644 --- a/tests/e2e/common/predicates.py +++ b/tests/e2e/common/predicates.py @@ -10,7 +10,8 @@ def pysql_supports_arrow(): """Import databricks.sql and test whether Cursor has fetchall_arrow.""" from databricks.sql.client import Cursor - return hasattr(Cursor, 'fetchall_arrow') + + return hasattr(Cursor, "fetchall_arrow") def pysql_has_version(compare, version): @@ -25,6 +26,7 @@ def test_some_pyhive_v1_stuff(): ... """ from databricks import sql + return compare_module_version(sql, compare, version) @@ -38,7 +40,7 @@ def is_endpoint_test(cli_args=None): def compare_dbr_versions(cli_args, compare, major_version, minor_version): if MAJOR_DBR_V_KEY in cli_args and MINOR_DBR_V_KEY in cli_args: if cli_args[MINOR_DBR_V_KEY] == "x": - actual_minor_v = float('inf') + actual_minor_v = float("inf") else: actual_minor_v = int(cli_args[MINOR_DBR_V_KEY]) dbr_version = (int(cli_args[MAJOR_DBR_V_KEY]), actual_minor_v) @@ -47,8 +49,10 @@ def compare_dbr_versions(cli_args, compare, major_version, minor_version): if not is_endpoint_test(): raise ValueError( - "DBR version not provided for non-endpoint test. Please pass the {} and {} params". - format(MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY)) + "DBR version not provided for non-endpoint test. 
Please pass the {} and {} params".format( + MAJOR_DBR_V_KEY, MINOR_DBR_V_KEY + ) + ) def is_thrift_v5_plus(cli_args): @@ -56,18 +60,18 @@ def is_thrift_v5_plus(cli_args): _compare_fns = { - '<': '__lt__', - '<=': '__le__', - '>': '__gt__', - '>=': '__ge__', - '==': '__eq__', - '!=': '__ne__', + "<": "__lt__", + "<=": "__le__", + ">": "__gt__", + ">=": "__ge__", + "==": "__eq__", + "!=": "__ne__", } def compare_versions(compare, v1_tuple, v2_tuple): compare_fn_name = _compare_fns.get(compare) - assert compare_fn_name, 'Received invalid compare string: ' + compare + assert compare_fn_name, "Received invalid compare string: " + compare return getattr(v1_tuple, compare_fn_name)(v2_tuple) @@ -87,13 +91,15 @@ def test_some_pyhive_v1_stuff(): NOTE: This comparison leverages packaging.version.parse, and compares _release_ versions, thus ignoring pre/post release tags (eg -rc1, -dev, etc). """ - assert module, 'Received invalid module: ' + module - assert getattr(module, '__version__'), 'Received module with no version: ' + module + assert module, "Received invalid module: " + module + assert getattr(module, "__version__"), "Received module with no version: " + module def validate_version(version): v = parse_version(str(version)) # assert that we get a PEP-440 Version back -- LegacyVersion doesn't have major/minor. - assert hasattr(v, 'major'), 'Module has incompatible "Legacy" version: ' + version + assert hasattr(v, "major"), ( + 'Module has incompatible "Legacy" version: ' + version + ) return (v.major, v.minor, v.micro) mod_version = validate_version(module.__version__) diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py index 106a8fb5..7dd5f745 100755 --- a/tests/e2e/common/retry_test_mixins.py +++ b/tests/e2e/common/retry_test_mixins.py @@ -59,7 +59,9 @@ def _test_retry_disabled_with_message(self, error_msg_substring, exception_type) @contextmanager -def mocked_server_response(status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None): +def mocked_server_response( + status: int = 200, headers: dict = {}, redirect_location: Optional[str] = None +): """Context manager for patching urllib3 responses""" # When mocking mocking a BaseHTTPResponse for urllib3 the mock must include @@ -97,7 +99,9 @@ def mock_sequential_server_responses(responses: List[dict]): # Each resp should have these members: for resp in responses: - _mock = MagicMock(headers=resp["headers"], msg=resp["headers"], status=resp["status"]) + _mock = MagicMock( + headers=resp["headers"], msg=resp["headers"], status=resp["status"] + ) _mock.get_redirect_location.return_value = ( False if resp["redirect_location"] is None else resp["redirect_location"] ) @@ -176,7 +180,9 @@ def test_retry_exponential_backoff(self): retry_policy["_retry_delay_min"] = 1 time_start = time.time() - with mocked_server_response(status=429, headers={"Retry-After": "3"}) as mock_obj: + with mocked_server_response( + status=429, headers={"Retry-After": "3"} + ) as mock_obj: with pytest.raises(RequestError) as cm: with self.connection(extra_params=retry_policy) as conn: pass @@ -256,7 +262,9 @@ def test_retry_dangerous_codes(self): assert isinstance(cm.value.args[1], UnsafeToRetryError) # Prove that these codes are retried if forced by the user - with self.connection(extra_params={**self._retry_policy, **additional_settings}) as conn: + with self.connection( + extra_params={**self._retry_policy, **additional_settings} + ) as conn: with conn.cursor() as cursor: for dangerous_code in DANGEROUS_CODES: with 
mocked_server_response(status=dangerous_code): @@ -326,7 +334,9 @@ def test_retry_abort_close_operation_on_404(self, caplog): curs.execute("SELECT 1") with mock_sequential_server_responses(responses): curs.close() - assert "Operation was canceled by a prior request" in caplog.text + assert ( + "Operation was canceled by a prior request" in caplog.text + ) def test_retry_max_redirects_raises_too_many_redirects_exception(self): """GIVEN the connector is configured with a custom max_redirects @@ -337,7 +347,9 @@ def test_retry_max_redirects_raises_too_many_redirects_exception(self): max_redirects, expected_call_count = 1, 2 # Code 302 is a redirect - with mocked_server_response(status=302, redirect_location="/foo.bar") as mock_obj: + with mocked_server_response( + status=302, redirect_location="/foo.bar" + ) as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -359,7 +371,9 @@ def test_retry_max_redirects_unset_doesnt_redirect_forever(self): _stop_after_attempts_count is enforced. """ # Code 302 is a redirect - with mocked_server_response(status=302, redirect_location="/foo.bar/") as mock_obj: + with mocked_server_response( + status=302, redirect_location="/foo.bar/" + ) as mock_obj: with pytest.raises(MaxRetryError) as cm: with self.connection( extra_params={ @@ -385,7 +399,9 @@ def test_retry_max_redirects_is_bounded_by_stop_after_attempts_count(self): with pytest.raises(RequestError) as cm: with mock_sequential_server_responses(responses): - with self.connection(extra_params={**self._retry_policy, **additional_settings}): + with self.connection( + extra_params={**self._retry_policy, **additional_settings} + ): pass # The error should be the result of the 500, not because of too many requests. @@ -405,9 +421,12 @@ def test_retry_max_redirects_exceeds_max_attempts_count_warns_user(self, caplog) assert "it will have no affect!" in caplog.text def test_retry_legacy_behavior_warns_user(self, caplog): - with self.connection(extra_params={**self._retry_policy, "_enable_v3_retries": False}): - assert "Legacy retry behavior is enabled for this connection." in caplog.text - + with self.connection( + extra_params={**self._retry_policy, "_enable_v3_retries": False} + ): + assert ( + "Legacy retry behavior is enabled for this connection." 
in caplog.text + ) def test_403_not_retried(self): """GIVEN the server returns a code 403 diff --git a/tests/e2e/common/staging_ingestion_tests.py b/tests/e2e/common/staging_ingestion_tests.py index d8d0429f..008055e3 100644 --- a/tests/e2e/common/staging_ingestion_tests.py +++ b/tests/e2e/common/staging_ingestion_tests.py @@ -41,7 +41,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -51,7 +53,9 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): new_fh, new_temp_path = tempfile.mkstemp() - with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": new_temp_path} + ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -71,17 +75,19 @@ def test_staging_ingestion_life_cycle(self, ingestion_user): # GET after REMOVE should fail - with pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): + with pytest.raises( + Error, match="Staging operation over HTTP was unsuccessful: 404" + ): cursor = conn.cursor() - query = ( - f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" - ) + query = f"GET 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' TO '{new_temp_path}'" cursor.execute(query) os.remove(temp_path) os.remove(new_temp_path) - def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, ingestion_user): + def test_staging_ingestion_put_fails_without_staging_allowed_local_path( + self, ingestion_user + ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -93,7 +99,9 @@ def test_staging_ingestion_put_fails_without_staging_allowed_local_path(self, in with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): + with pytest.raises( + Error, match="You must provide at least one staging_allowed_local_path" + ): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" @@ -119,12 +127,16 @@ def test_staging_ingestion_put_fails_if_localFile_not_in_staging_allowed_local_p Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": base_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, ingestion_user): + def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set( + self, ingestion_user + ): """PUT a file into the staging location twice. First command should succeed. 
Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -135,16 +147,22 @@ def test_staging_ingestion_put_fails_if_file_exists_and_overwrite_not_set(self, fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" + remove_query = ( + f"REMOVE 'stage://tmp/{ingestion_user}/tmp/12/15/file1.csv'" + ) - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -178,7 +196,9 @@ def test_staging_ingestion_fails_to_modify_another_staging_user(self): fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv' OVERWRITE" cursor.execute(query) @@ -186,12 +206,16 @@ def perform_put(): def perform_remove(): remove_query = f"REMOVE 'stage://tmp/{some_other_user}/tmp/12/15/file1.csv'" - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) def perform_get(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"GET 'stage://tmp/{some_other_user}/tmp/11/15/file1.csv' TO '{temp_path}'" cursor.execute(query) @@ -232,7 +256,9 @@ def test_staging_ingestion_put_fails_if_absolute_localFile_not_in_staging_allowe query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, ingestion_user): + def test_staging_ingestion_empty_local_path_fails_to_parse_at_server( + self, ingestion_user + ): staging_allowed_local_path = "/var/www/html" target_file = "" @@ -244,7 +270,9 @@ def test_staging_ingestion_empty_local_path_fails_to_parse_at_server(self, inges query = f"PUT '{target_file}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/file1.csv' OVERWRITE" cursor.execute(query) - def test_staging_ingestion_invalid_staging_path_fails_at_server(self, ingestion_user): + def test_staging_ingestion_invalid_staging_path_fails_at_server( + self, ingestion_user + ): staging_allowed_local_path = "/var/www/html" target_file = "index.html" @@ -278,12 +306,29 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + remove_query = ( + f"REMOVE 'stage://tmp/{ingestion_user}/tmp/11/15/{id(temp_path)}.csv'" + ) return fh, temp_path, put_query, remove_query - fh1, temp_path1, put_query1, 
remove_query1 = generate_file_and_path_and_queries() - fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() - fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() + ( + fh1, + temp_path1, + put_query1, + remove_query1, + ) = generate_file_and_path_and_queries() + ( + fh2, + temp_path2, + put_query2, + remove_query2, + ) = generate_file_and_path_and_queries() + ( + fh3, + temp_path3, + put_query3, + remove_query3, + ) = generate_file_and_path_and_queries() with self.connection( extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} diff --git a/tests/e2e/common/timestamp_tests.py b/tests/e2e/common/timestamp_tests.py index f25aed7e..70ded7d0 100644 --- a/tests/e2e/common/timestamp_tests.py +++ b/tests/e2e/common/timestamp_tests.py @@ -15,7 +15,10 @@ class TimestampTestsMixin: ] timestamp_and_expected_results = [ - ("2021-09-30 11:27:35.123+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35, 123000)), + ( + "2021-09-30 11:27:35.123+04:00", + datetime.datetime(2021, 9, 30, 7, 27, 35, 123000), + ), ("2021-09-30 11:27:35+04:00", datetime.datetime(2021, 9, 30, 7, 27, 35)), ("2021-09-30 11:27:35.123", datetime.datetime(2021, 9, 30, 11, 27, 35, 123000)), ("2021-09-30 11:27:35", datetime.datetime(2021, 9, 30, 11, 27, 35)), @@ -45,18 +48,24 @@ def assertTimestampsEqual(self, result, expected): def multi_query(self, n_rows=10): row_sql = "SELECT " + ", ".join( - ["TIMESTAMP('{}')".format(ts) for (ts, _) in self.timestamp_and_expected_results] + [ + "TIMESTAMP('{}')".format(ts) + for (ts, _) in self.timestamp_and_expected_results + ] ) query = " UNION ALL ".join([row_sql for _ in range(n_rows)]) expected_matrix = [ - [dt for (_, dt) in self.timestamp_and_expected_results] for _ in range(n_rows) + [dt for (_, dt) in self.timestamp_and_expected_results] + for _ in range(n_rows) ] return query, expected_matrix def test_timestamps(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + cursor.execute( + "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) + ) result = cursor.fetchone()[0] self.assertTimestampsEqual(result, expected) diff --git a/tests/e2e/common/uc_volume_tests.py b/tests/e2e/common/uc_volume_tests.py index 21e43036..72e2f502 100644 --- a/tests/e2e/common/uc_volume_tests.py +++ b/tests/e2e/common/uc_volume_tests.py @@ -40,19 +40,21 @@ def test_uc_volume_life_cycle(self, catalog, schema): with open(fh, "wb") as fp: fp.write(original_text) - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() - query = ( - f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" - ) + query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) # GET should succeed new_fh, new_temp_path = tempfile.mkstemp() - with self.connection(extra_params={"staging_allowed_local_path": new_temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": new_temp_path} + ) as conn: cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -72,7 +74,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): # GET after REMOVE should fail - with 
pytest.raises(Error, match="Staging operation over HTTP was unsuccessful: 404"): + with pytest.raises( + Error, match="Staging operation over HTTP was unsuccessful: 404" + ): cursor = conn.cursor() query = f"GET '/Volumes/{catalog}/{schema}/e2etests/file1.csv' TO '{new_temp_path}'" cursor.execute(query) @@ -80,7 +84,9 @@ def test_uc_volume_life_cycle(self, catalog, schema): os.remove(temp_path) os.remove(new_temp_path) - def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, schema): + def test_uc_volume_put_fails_without_staging_allowed_local_path( + self, catalog, schema + ): """PUT operations are not supported unless the connection was built with a parameter called staging_allowed_local_path """ @@ -92,7 +98,9 @@ def test_uc_volume_put_fails_without_staging_allowed_local_path(self, catalog, s with open(fh, "wb") as fp: fp.write(original_text) - with pytest.raises(Error, match="You must provide at least one staging_allowed_local_path"): + with pytest.raises( + Error, match="You must provide at least one staging_allowed_local_path" + ): with self.connection() as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" @@ -118,12 +126,16 @@ def test_uc_volume_put_fails_if_localFile_not_in_staging_allowed_local_path( Error, match="Local file operations are restricted to paths within the configured staging_allowed_local_path", ): - with self.connection(extra_params={"staging_allowed_local_path": base_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": base_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, schema): + def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set( + self, catalog, schema + ): """PUT a file into the staging location twice. First command should succeed. Second should fail.""" fh, temp_path = tempfile.mkstemp() @@ -134,16 +146,22 @@ def test_uc_volume_put_fails_if_file_exists_and_overwrite_not_set(self, catalog, fp.write(original_text) def perform_put(): - with self.connection(extra_params={"staging_allowed_local_path": temp_path}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": temp_path} + ) as conn: cursor = conn.cursor() query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" cursor.execute(query) def perform_remove(): try: - remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" + remove_query = ( + f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/file1.csv'" + ) - with self.connection(extra_params={"staging_allowed_local_path": "/"}) as conn: + with self.connection( + extra_params={"staging_allowed_local_path": "/"} + ) as conn: cursor = conn.cursor() cursor.execute(remove_query) except Exception: @@ -212,7 +230,9 @@ def test_uc_volume_invalid_volume_path_fails_at_server(self, catalog, schema): query = f"PUT '{target_file}' INTO '/Volumes/RANDOMSTRINGOFCHARACTERS/{catalog}/{schema}/e2etests/file1.csv' OVERWRITE" cursor.execute(query) - def test_uc_volume_supports_multiple_staging_allowed_local_path_values(self, catalog, schema): + def test_uc_volume_supports_multiple_staging_allowed_local_path_values( + self, catalog, schema + ): """staging_allowed_local_path may be either a path-like object or a list of path-like objects. 
This test confirms that two configured base paths: @@ -232,12 +252,29 @@ def generate_file_and_path_and_queries(): original_text = "hello world!".encode("utf-8") fp.write(original_text) put_query = f"PUT '{temp_path}' INTO '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv' OVERWRITE" - remove_query = f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" + remove_query = ( + f"REMOVE '/Volumes/{catalog}/{schema}/e2etests/{id(temp_path)}.csv'" + ) return fh, temp_path, put_query, remove_query - fh1, temp_path1, put_query1, remove_query1 = generate_file_and_path_and_queries() - fh2, temp_path2, put_query2, remove_query2 = generate_file_and_path_and_queries() - fh3, temp_path3, put_query3, remove_query3 = generate_file_and_path_and_queries() + ( + fh1, + temp_path1, + put_query1, + remove_query1, + ) = generate_file_and_path_and_queries() + ( + fh2, + temp_path2, + put_query2, + remove_query2, + ) = generate_file_and_path_and_queries() + ( + fh3, + temp_path3, + put_query3, + remove_query3, + ) = generate_file_and_path_and_queries() with self.connection( extra_params={"staging_allowed_local_path": [temp_path1, temp_path2]} diff --git a/tests/e2e/test_complex_types.py b/tests/e2e/test_complex_types.py index 0a7f514a..446a6b50 100644 --- a/tests/e2e/test_complex_types.py +++ b/tests/e2e/test_complex_types.py @@ -53,7 +53,9 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture): @pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")]) def test_read_complex_types_as_string(self, field, table_fixture): """Confirms the return type of a complex type that is returned as a string""" - with self.cursor(extra_params={"_use_arrow_native_complex_types": False}) as cursor: + with self.cursor( + extra_params={"_use_arrow_native_complex_types": False} + ) as cursor: result = cursor.execute( "SELECT * FROM pysql_test_complex_types_table LIMIT 1" ).fetchone() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index c23e4f79..cfd1e969 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -40,7 +40,10 @@ from tests.e2e.common.large_queries_mixin import LargeQueriesMixin from tests.e2e.common.timestamp_tests import TimestampTestsMixin from tests.e2e.common.decimal_tests import DecimalTestsMixin -from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin +from tests.e2e.common.retry_test_mixins import ( + Client429ResponseMixin, + Client503ResponseMixin, +) from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin @@ -57,7 +60,9 @@ # manually decorate DecimalTestsMixin to need arrow support for name in loader.getTestCaseNames(DecimalTestsMixin, "test_"): fn = getattr(DecimalTestsMixin, name) - decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")(fn) + decorated = skipUnless(pysql_supports_arrow(), "Decimal tests need arrow support")( + fn + ) setattr(DecimalTestsMixin, name, decorated) @@ -68,7 +73,9 @@ class PySQLPytestTestCase: error_type = Error conf_to_disable_rate_limit_retries = {"_retry_stop_after_attempts_count": 1} - conf_to_disable_temporarily_unavailable_retries = {"_retry_stop_after_attempts_count": 1} + conf_to_disable_temporarily_unavailable_retries = { + "_retry_stop_after_attempts_count": 1 + } arraysize = 1000 buffer_size_bytes = 104857600 @@ -105,7 +112,9 @@ def connection(self, extra_params=()): @contextmanager def cursor(self, 
extra_params=()): with self.connection(extra_params) as conn: - cursor = conn.cursor(arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes) + cursor = conn.cursor( + arraysize=self.arraysize, buffer_size_bytes=self.buffer_size_bytes + ) try: yield cursor finally: @@ -144,7 +153,9 @@ def test_cloud_fetch(self): limits, threads, [True, False] ): with self.subTest( - num_limit=num_limit, num_threads=num_threads, lz4_compression=lz4_compression + num_limit=num_limit, + num_threads=num_threads, + lz4_compression=lz4_compression, ): cf_result, noop_result = None, None query = base_query + "LIMIT " + str(num_limit) @@ -289,7 +300,15 @@ def test_get_tables(self): ("TYPE_CAT", "string", None, None, None, None, None), ("TYPE_SCHEM", "string", None, None, None, None, None), ("TYPE_NAME", "string", None, None, None, None, None), - ("SELF_REFERENCING_COL_NAME", "string", None, None, None, None, None), + ( + "SELF_REFERENCING_COL_NAME", + "string", + None, + None, + None, + None, + None, + ), ("REF_GENERATION", "string", None, None, None, None, None), ] assert tables_desc == expected @@ -390,15 +409,21 @@ def test_escape_single_quotes(self): table_name = "table_{uuid}".format(uuid=str(uuid4()).replace("-", "_")) # Test escape syntax directly cursor.execute( - "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format(table_name) + "CREATE TABLE IF NOT EXISTS {} AS (SELECT 'you\\'re' AS col_1)".format( + table_name + ) + ) + cursor.execute( + "SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name) ) - cursor.execute("SELECT * FROM {} WHERE col_1 LIKE 'you\\'re'".format(table_name)) rows = cursor.fetchall() assert rows[0]["col_1"] == "you're" # Test escape syntax in parameter cursor.execute( - "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format(table_name, table_name), + "SELECT * FROM {} WHERE {}.col_1 LIKE %(var)s".format( + table_name, table_name + ), parameters={"var": "you're"}, ) rows = cursor.fetchall() @@ -427,7 +452,9 @@ def test_get_catalogs(self): cursor.catalogs() cursor.fetchall() catalogs_desc = cursor.description - assert catalogs_desc == [("TABLE_CAT", "string", None, None, None, None, None)] + assert catalogs_desc == [ + ("TABLE_CAT", "string", None, None, None, None, None) + ] @skipUnless(pysql_supports_arrow(), "arrow test need arrow support") def test_get_arrow(self): @@ -564,7 +591,8 @@ def test_temp_view_fetch(self): @skipIf(pysql_has_version("<", "2"), "requires pysql v2") @skipIf( - True, "Unclear the purpose of this test since urllib3 does not complain when timeout == 0" + True, + "Unclear the purpose of this test since urllib3 does not complain when timeout == 0", ) def test_socket_timeout(self): # We expect to see a BlockingIO error when the socket is opened @@ -587,7 +615,9 @@ def test_socket_timeout_user_defined(self): def test_ssp_passthrough(self): for enable_ansi in (True, False): - with self.cursor({"session_configuration": {"ansi_mode": enable_ansi}}) as cursor: + with self.cursor( + {"session_configuration": {"ansi_mode": enable_ansi}} + ) as cursor: cursor.execute("SET ansi_mode") assert list(cursor.fetchone()) == ["ansi_mode", str(enable_ansi)] @@ -595,7 +625,9 @@ def test_ssp_passthrough(self): def test_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: for timestamp, expected in self.timestamp_and_expected_results: - cursor.execute("SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp)) + cursor.execute( + "SELECT TIMESTAMP('{timestamp}')".format(timestamp=timestamp) + ) 
arrow_table = cursor.fetchmany_arrow(1) if self.should_add_timezone(): ts_type = pyarrow.timestamp("us", tz="Etc/UTC") @@ -606,7 +638,9 @@ def test_timestamps_arrow(self): # To work consistently across different local timezones, we specify the timezone # of the expected result to # be UTC (what it should be by default on the server) - aware_timestamp = expected and expected.replace(tzinfo=datetime.timezone.utc) + aware_timestamp = expected and expected.replace( + tzinfo=datetime.timezone.utc + ) assert result_value == ( aware_timestamp and aware_timestamp.timestamp() * 1000000 ), "timestamp {} did not match {}".format(timestamp, expected) @@ -616,14 +650,16 @@ def test_multi_timestamps_arrow(self): with self.cursor({"session_configuration": {"ansi_mode": False}}) as cursor: query, expected = self.multi_query() expected = [ - [self.maybe_add_timezone_to_timestamp(ts) for ts in row] for row in expected + [self.maybe_add_timezone_to_timestamp(ts) for ts in row] + for row in expected ] cursor.execute(query) table = cursor.fetchall_arrow() # Transpose columnar result to list of rows list_of_cols = [c.to_pylist() for c in table] result = [ - [col[row_index] for col in list_of_cols] for row_index in range(table.num_rows) + [col[row_index] for col in list_of_cols] + for row_index in range(table.num_rows) ] assert result == expected @@ -640,7 +676,9 @@ def test_timezone_with_timestamp(self): cursor.execute("select CAST('2022-03-02 12:54:56' as TIMESTAMP)") arrow_result_table = cursor.fetchmany_arrow(1) - arrow_result_value = arrow_result_table.column(0).combine_chunks()[0].value + arrow_result_value = ( + arrow_result_table.column(0).combine_chunks()[0].value + ) ts_type = pyarrow.timestamp("us", tz="Europe/Amsterdam") assert arrow_result_table.field(0).type == ts_type @@ -700,7 +738,9 @@ def test_close_connection_closes_cursors(self): with self.connection() as conn: cursor = conn.cursor() - cursor.execute("SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()") + cursor.execute( + "SELECT id, id `id2`, id `id3` FROM RANGE(1000000) order by RANDOM()" + ) ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True @@ -709,14 +749,21 @@ def test_close_connection_closes_cursors(self): status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False ) - op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) - assert op_status_at_server.operationState != ttypes.TOperationState.CLOSED_STATE + op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + status_request + ) + assert ( + op_status_at_server.operationState + != ttypes.TOperationState.CLOSED_STATE + ) conn.close() # When connection closes, any cursor operations should no longer exist at the server with pytest.raises(SessionAlreadyClosedError) as cm: - op_status_at_server = ars.thrift_backend._client.GetOperationStatus(status_request) + op_status_at_server = ars.thrift_backend._client.GetOperationStatus( + status_request + ) def test_closing_a_closed_connection_doesnt_fail(self, caplog): caplog.set_level(logging.DEBUG) @@ -737,7 +784,9 @@ class HTTP429Suite(Client429ResponseMixin, PySQLPytestTestCase): class HTTP503Suite(Client503ResponseMixin, PySQLPytestTestCase): # 503Response suite gets custom error here vs PyODBC def test_retry_disabled(self): - self._test_retry_disabled_with_message("TEMPORARILY_UNAVAILABLE", OperationalError) + self._test_retry_disabled_with_message( + 
"TEMPORARILY_UNAVAILABLE", OperationalError + ) class TestPySQLUnityCatalogSuite(PySQLPytestTestCase): diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 47dfc38c..d346ad5c 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -167,12 +167,8 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle): This is a no-op but is included to make the test-code easier to read. """ target_column = self._get_inline_table_column(params.get("p")) - INSERT_QUERY = ( - f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" - ) - SELECT_QUERY = ( - f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" - ) + INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)" + SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" with self.connection(extra_params={"use_inline_params": True}) as conn: @@ -308,11 +304,15 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) If a user explicitly sets use_inline_params, don't warn them about it. """ - extra_args = {"use_inline_params": use_inline_params} if use_inline_params else {} + extra_args = ( + {"use_inline_params": use_inline_params} if use_inline_params else {} + ) with self.connection(extra_params=extra_args) as conn: with conn.cursor() as cursor: - with self.patch_server_supports_native_params(supports_native_params=True): + with self.patch_server_supports_native_params( + supports_native_params=True + ): cursor.execute("SELECT %(p)s", parameters={"p": 1}) if use_inline_params is True: assert ( @@ -402,7 +402,9 @@ def test_inline_ordinals_can_break_sql(self): query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] with self.cursor(extra_params={"use_inline_params": True}) as cursor: - with pytest.raises(TypeError, match="not enough arguments for format string"): + with pytest.raises( + TypeError, match="not enough arguments for format string" + ): cursor.execute(query, parameters=params) def test_inline_named_dont_break_sql(self): diff --git a/tests/unit/test_arrow_queue.py b/tests/unit/test_arrow_queue.py index 6834cc9c..b3dff45f 100644 --- a/tests/unit/test_arrow_queue.py +++ b/tests/unit/test_arrow_queue.py @@ -14,13 +14,21 @@ def make_arrow_table(batch): return pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) def test_fetchmany_respects_n_rows(self): - arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) + arrow_table = self.make_arrow_table( + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + ) aq = ArrowQueue(arrow_table, 3) - self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]])) + self.assertEqual( + aq.next_n_rows(2), self.make_arrow_table([[0, 1, 2], [3, 4, 5]]) + ) self.assertEqual(aq.next_n_rows(2), self.make_arrow_table([[6, 7, 8]])) def test_fetch_remaining_rows_respects_n_rows(self): - arrow_table = self.make_arrow_table([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]) + arrow_table = self.make_arrow_table( + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + ) aq = ArrowQueue(arrow_table, 3) self.assertEqual(aq.next_n_rows(1), self.make_arrow_table([[0, 1, 2]])) - self.assertEqual(aq.remaining_rows(), self.make_arrow_table([[3, 4, 5], [6, 7, 8]])) + self.assertEqual( + aq.remaining_rows(), 
self.make_arrow_table([[3, 4, 5], [6, 7, 8]])
+        )
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index d6541525..d5b06bbf 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -9,7 +9,10 @@
     ExternalAuthProvider,
     AuthType,
 )
-from databricks.sql.auth.auth import get_python_sql_connector_auth_provider, PYSQL_OAUTH_CLIENT_ID
+from databricks.sql.auth.auth import (
+    get_python_sql_connector_auth_provider,
+    PYSQL_OAUTH_CLIENT_ID,
+)
 from databricks.sql.auth.oauth import OAuthManager
 from databricks.sql.auth.authenticators import DatabricksOAuthProvider
 from databricks.sql.auth.endpoint import (
@@ -177,12 +180,13 @@ def test_get_python_sql_connector_basic_auth(self):
         }
         with self.assertRaises(ValueError) as e:
             get_python_sql_connector_auth_provider("foo.cloud.databricks.com", **kwargs)
-        self.assertIn("Username/password authentication is no longer supported", str(e.exception))
+        self.assertIn(
+            "Username/password authentication is no longer supported", str(e.exception)
+        )
 
     @patch.object(DatabricksOAuthProvider, "_initial_get_token")
     def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
         hostname = "foo.cloud.databricks.com"
         auth_provider = get_python_sql_connector_auth_provider(hostname)
         self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider")
-        self.assertTrue(auth_provider._client_id,PYSQL_OAUTH_CLIENT_ID)
-
+        self.assertEqual(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index c86a9f7f..0ff660d5 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -13,7 +13,7 @@
     TExecuteStatementResp,
     TOperationHandle,
     THandleIdentifier,
-    TOperationType
+    TOperationType,
 )
 
 from databricks.sql.thrift_backend import ThriftBackend
@@ -26,8 +26,8 @@
 from tests.unit.test_thrift_backend import ThriftBackendTestSuite
 from tests.unit.test_arrow_queue import ArrowQueueSuite
 
-class ThriftBackendMockFactory:
 
+class ThriftBackendMockFactory:
     @classmethod
     def new(cls):
         ThriftBackendMock = Mock(spec=ThriftBackend)
@@ -68,10 +68,6 @@ def apply_property_to_mock(self, mock_obj, **kwargs):
 
             setattr(type(mock_obj), key, prop)
 
-
-
-
-
 class ClientTestSuite(unittest.TestCase):
     """
     Unit tests for isolated client behaviour. 
@@ -89,7 +85,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -97,7 +93,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): # Check the close session request has an id of x22 close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b'\x22') + self.assertEqual(close_session_id, b"\x22") @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): @@ -155,13 +151,19 @@ def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) http_headers = mock_client_class.call_args[0][3] - user_agent_header = ("User-Agent", "{}/{}".format(databricks.sql.USER_AGENT_NAME, - databricks.sql.__version__)) + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) self.assertIn(user_agent_header, http_headers) databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, _user_agent_entry="foobar") - user_agent_header_with_entry = ("User-Agent", "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar")) + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) @@ -177,7 +179,9 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): cursor.execute("SELECT 1;") connection.close() - self.assertTrue(mock_result_set_class.return_value.has_been_closed_server_side) + self.assertTrue( + mock_result_set_class.return_value.has_been_closed_server_side + ) mock_result_set_class.return_value.close.assert_called_once_with() @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @@ -192,7 +196,9 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) - def test_arraysize_buffer_size_passthrough(self, mock_cursor_class, mock_client_class): + def test_arraysize_buffer_size_passthrough( + self, mock_cursor_class, mock_client_class + ): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) connection.cursor(arraysize=999, buffer_size_bytes=1234) kwargs = mock_cursor_class.call_args[1] @@ -204,7 +210,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): mock_connection = Mock() mock_backend = Mock() result_set = client.ResultSet( - connection=mock_connection, thrift_backend=mock_backend, execute_response=Mock()) + connection=mock_connection, + thrift_backend=mock_backend, + execute_response=Mock(), + ) mock_connection.open = False result_set.close() @@ -218,20 +227,27 @@ def test_closing_result_set_hard_closes_commands(self): mock_connection = Mock() mock_thrift_backend = Mock() mock_connection.open = True - result_set = client.ResultSet(mock_connection, mock_results_response, mock_thrift_backend) + result_set = client.ResultSet( + mock_connection, mock_results_response, mock_thrift_backend + ) result_set.close() 
mock_thrift_backend.close_command.assert_called_once_with( - mock_results_response.command_handle) + mock_results_response.command_handle + ) @patch("%s.client.ResultSet" % PACKAGE_NAME) - def test_executing_multiple_commands_uses_the_most_recent_command(self, mock_result_set_class): + def test_executing_multiple_commands_uses_the_most_recent_command( + self, mock_result_set_class + ): mock_result_sets = [Mock(), Mock()] mock_result_set_class.side_effect = mock_result_sets - cursor = client.Cursor(connection=Mock(), thrift_backend=ThriftBackendMockFactory.new()) + cursor = client.Cursor( + connection=Mock(), thrift_backend=ThriftBackendMockFactory.new() + ) cursor.execute("SELECT 1;") cursor.execute("SELECT 1;") @@ -272,7 +288,7 @@ def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b'\x22' + mock_open_session_resp.sessionHandle.sessionId = b"\x22" instance.open_session.return_value = mock_open_session_resp with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: @@ -280,7 +296,7 @@ def test_context_manager_closes_connection(self, mock_client_class): # Check the close session request has an id of x22 close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b'\x22') + self.assertEqual(close_session_id, b"\x22") def dict_product(self, dicts): """ @@ -299,7 +315,9 @@ def test_get_schemas_parameters_passed_to_thrift_backend(self, mock_thrift_backe req_args_combinations = self.dict_product( dict( catalog_name=["NOT_SET", None, "catalog_pattern"], - schema_name=["NOT_SET", None, "schema_pattern"])) + schema_name=["NOT_SET", None, "schema_pattern"], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -320,7 +338,9 @@ def test_get_tables_parameters_passed_to_thrift_backend(self, mock_thrift_backen catalog_name=["NOT_SET", None, "catalog_pattern"], schema_name=["NOT_SET", None, "schema_pattern"], table_name=["NOT_SET", None, "table_pattern"], - table_types=["NOT_SET", [], ["type1", "type2"]])) + table_types=["NOT_SET", [], ["type1", "type2"]], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -341,7 +361,9 @@ def test_get_columns_parameters_passed_to_thrift_backend(self, mock_thrift_backe catalog_name=["NOT_SET", None, "catalog_pattern"], schema_name=["NOT_SET", None, "schema_pattern"], table_name=["NOT_SET", None, "table_pattern"], - column_name=["NOT_SET", None, "column_pattern"])) + column_name=["NOT_SET", None, "column_pattern"], + ) + ) for req_args in req_args_combinations: req_args = {k: v for k, v in req_args.items() if v != "NOT_SET"} @@ -365,7 +387,8 @@ def test_cancel_command_calls_the_backend(self): @patch("databricks.sql.client.logger") def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( - self, logger_instance): + self, logger_instance + ): mock_thrift_backend = Mock() cursor = client.Cursor(Mock(), mock_thrift_backend) cursor.cancel() @@ -375,9 +398,13 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect(_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS) + databricks.sql.connect( + 
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) - self.assertEqual(mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54) + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): @@ -386,27 +413,38 @@ def test_socket_timeout_passthrough(self, mock_client_class): def test_version_is_canonical(self): version = databricks.sql.__version__ - canonical_version_re = r'^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)' \ - r'(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$' + canonical_version_re = ( + r"^([1-9][0-9]*!)?(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*((a|b|rc)" + r"(0|[1-9][0-9]*))?(\.post(0|[1-9][0-9]*))?(\.dev(0|[1-9][0-9]*))?$" + ) self.assertIsNotNone(re.match(canonical_version_re, version)) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS) + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) @patch("%s.client.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][1], mock_cat) - self.assertEqual(mock_client_class.return_value.open_session.call_args[0][2], mock_schem) + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() @@ -436,7 +474,8 @@ def test_execute_parameter_passthrough(self): @patch("%s.client.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_executemany_parameter_passhthrough_and_uses_last_result_set( - self, mock_result_set_class, mock_thrift_backend): + self, mock_result_set_class, mock_thrift_backend + ): # Create a new mock result set each time the class is instantiated mock_result_set_instances = [Mock(), Mock(), Mock()] mock_result_set_class.side_effect = mock_result_set_instances @@ -449,17 +488,22 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( cursor.executemany("SELECT %(x)s", seq_of_parameters=params) self.assertEqual( - len(mock_thrift_backend.execute_command.call_args_list), len(expected_queries), - "Expected execute_command to be called the same number of times as params were passed") + len(mock_thrift_backend.execute_command.call_args_list), + len(expected_queries), + "Expected execute_command to be called the same number of times as params were passed", + ) - for expected_query, call_args in zip(expected_queries, - mock_thrift_backend.execute_command.call_args_list): + for expected_query, call_args in zip( + expected_queries, mock_thrift_backend.execute_command.call_args_list + ): 
self.assertEqual(call_args[1]["operation"], expected_query)
 
         self.assertEqual(
-            cursor.active_result_set, mock_result_set_instances[2],
+            cursor.active_result_set,
+            mock_result_set_instances[2],
             "Expected the active result set to be the result set corresponding to the"
-            "last operation")
+            " last operation",
+        )
 
     @patch("%s.client.ThriftBackend" % PACKAGE_NAME)
     def test_commit_a_noop(self, mock_thrift_backend_class):
@@ -495,7 +539,7 @@ def make_fake_row_slice(n_rows):
         mock_thrift_backend.fetch_results.return_value = (mock_aq, True)
 
         cursor = client.Cursor(Mock(), mock_thrift_backend)
-        cursor.execute('foo')
+        cursor.execute("foo")
 
         self.assertEqual(cursor.rownumber, 0)
         cursor.fetchmany_arrow(10)
@@ -516,12 +560,14 @@ def test_disable_pandas_respected(self, mock_thrift_backend_class):
         mock_aq = Mock()
         mock_aq.remaining_rows.return_value = mock_table
         mock_thrift_backend.execute_command.return_value.arrow_queue = mock_aq
-        mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = True
+        mock_thrift_backend.execute_command.return_value.has_been_closed_server_side = (
+            True
+        )
         mock_con = Mock()
         mock_con.disable_pandas = True
 
         cursor = client.Cursor(mock_con, mock_thrift_backend)
-        cursor.execute('foo')
+        cursor.execute("foo")
         cursor.fetchall()
 
         mock_table.itercolumns.assert_called_once_with()
@@ -548,18 +594,21 @@ def test_column_name_api(self):
         self.assertEqual(row[1], expected[1])
         self.assertEqual(row[2], expected[2])
 
-        self.assertEqual(row.asDict(), {
-            "first_col": expected[0],
-            "second_col": expected[1],
-            "third_col": expected[2]
-        })
+        self.assertEqual(
+            row.asDict(),
+            {
+                "first_col": expected[0],
+                "second_col": expected[1],
+                "third_col": expected[2],
+            },
+        )
 
     @patch("%s.client.ThriftBackend" % PACKAGE_NAME)
     def test_finalizer_closes_abandoned_connection(self, mock_client_class):
         instance = mock_client_class.return_value
 
         mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
-        mock_open_session_resp.sessionHandle.sessionId = b'\x22'
+        mock_open_session_resp.sessionHandle.sessionId = b"\x22"
         instance.open_session.return_value = mock_open_session_resp
 
         databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
@@ -569,14 +618,14 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class):
 
         # Check the close session request has an id of x22
         close_session_id = instance.close_session.call_args[0][0].sessionId
-        self.assertEqual(close_session_id, b'\x22')
+        self.assertEqual(close_session_id, b"\x22")
 
     @patch("%s.client.ThriftBackend" % PACKAGE_NAME)
     def test_cursor_keeps_connection_alive(self, mock_client_class):
         instance = mock_client_class.return_value
 
         mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
-        mock_open_session_resp.sessionHandle.sessionId = b'\x22'
+        mock_open_session_resp.sessionHandle.sessionId = b"\x22"
         instance.open_session.return_value = mock_open_session_resp
 
         connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
@@ -591,11 +640,14 @@ def test_cursor_keeps_connection_alive(self, mock_client_class):
     @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True)
     @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME)
     @patch("%s.client.ThriftBackend" % PACKAGE_NAME)
-    def test_staging_operation_response_is_handled(self, mock_client_class, mock_handle_staging_operation, mock_execute_response):
+    def test_staging_operation_response_is_handled(
+        self, mock_client_class, mock_handle_staging_operation, mock_execute_response
+    ):
         # If server sets ExecuteResponse.is_staging_operation True then 
_handle_staging_operation should be called - - ThriftBackendMockFactory.apply_property_to_mock(mock_execute_response, is_staging_operation=True) + ThriftBackendMockFactory.apply_property_to_mock( + mock_execute_response, is_staging_operation=True + ) mock_client_class.execute_command.return_value = mock_execute_response mock_client_class.return_value = mock_client_class @@ -608,7 +660,7 @@ def test_staging_operation_response_is_handled(self, mock_client_class, mock_han @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): - operation_id = 'EE6A8778-21FC-438B-92D8-96AC51EE3821' + operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) cursor = connection.cursor() @@ -617,17 +669,23 @@ def test_access_current_query_id(self): cursor.active_op_handle = TOperationHandle( operationId=THandleIdentifier(guid=UUID(operation_id).bytes, secret=0x00), - operationType=TOperationType.EXECUTE_STATEMENT) + operationType=TOperationType.EXECUTE_STATEMENT, + ) self.assertEqual(cursor.query_id.upper(), operation_id.upper()) cursor.close() self.assertIsNone(cursor.query_id) -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() - test_classes = [ClientTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite] + test_classes = [ + ClientTestSuite, + FetchTests, + ThriftBackendTestSuite, + ArrowQueueSuite, + ] suites_list = [] for test_class in test_classes: suite = loader.loadTestsFromTestCase(test_class) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index acd0c392..01d8a79b 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -6,22 +6,26 @@ import databricks.sql.utils as utils from databricks.sql.types import SSLOptions -class CloudFetchQueueSuite(unittest.TestCase): +class CloudFetchQueueSuite(unittest.TestCase): def create_result_link( - self, - file_link: str = "fileLink", - start_row_offset: int = 0, - row_count: int = 8000, - bytes_num: int = 20971520 + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520, ): - return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + return TSparkArrowResultLink( + file_link, None, start_row_offset, row_count, bytes_num + ) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] for i in range(num_files): file_link = "fileLink_" + str(i) - result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_link = self.create_result_link( + file_link=file_link, start_row_offset=start_row_offset + ) result_links.append(result_link) start_row_offset += result_link.rowCount return result_links @@ -42,8 +46,10 @@ def get_schema_bytes(): writer.close() return sink.getvalue().to_pybytes() - - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None]) + @patch( + "databricks.sql.utils.CloudFetchQueue._create_next_table", + return_value=[None, None], + ) def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) @@ -72,7 +78,10 @@ def test_initializer_no_links_to_add(self): assert len(queue.download_manager._download_tasks) == 0 assert queue.table is None - 
@patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) + @patch( + "databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + return_value=None, + ) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): queue = utils.CloudFetchQueue( MagicMock(), @@ -85,9 +94,13 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): mock_get_next_downloaded_file.assert_called_with(0) @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") - @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", - return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) - def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): + @patch( + "databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + return_value=MagicMock(file_bytes=b"1234567890", row_count=4), + ) + def test_initializer_create_next_table_success( + self, mock_get_next_downloaded_file, mock_create_arrow_table + ): mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue( @@ -169,7 +182,12 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): result = queue.next_n_rows(7) assert result.num_rows == 7 assert queue.table_row_index == 3 - assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] + assert ( + result + == pyarrow.concat_tables( + [self.make_arrow_table(), self.make_arrow_table()] + )[:7] + ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): @@ -265,8 +283,14 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result == self.make_arrow_table() @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") - def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): - mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None] + def test_remaining_rows_multiple_tables_fully_returned( + self, mock_create_next_table + ): + mock_create_next_table.side_effect = [ + self.make_arrow_table(), + self.make_arrow_table(), + None, + ] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue( schema_bytes, @@ -282,7 +306,12 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta result = queue.remaining_rows() assert mock_create_next_table.call_count == 3 assert result.num_rows == 5 - assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] + assert ( + result + == pyarrow.concat_tables( + [self.make_arrow_table(), self.make_arrow_table()] + )[3:] + ) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_remaining_rows_empty_table(self, mock_create_next_table): diff --git a/tests/unit/test_column_queue.py b/tests/unit/test_column_queue.py index 130b589b..234af88e 100644 --- a/tests/unit/test_column_queue.py +++ b/tests/unit/test_column_queue.py @@ -8,14 +8,18 @@ def make_column_table(table): return ColumnTable(table, [f"col_{i}" for i in range(n_cols)]) def test_fetchmany_respects_n_rows(self): - column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + 
column_table = self.make_column_table( + [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] + ) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) assert column_queue.next_n_rows(2) == column_table.slice(2, 2) def test_fetch_remaining_rows_respects_n_rows(self): - column_table = self.make_column_table([[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]) + column_table = self.make_column_table( + [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]] + ) column_queue = ColumnQueue(column_table) assert column_queue.next_n_rows(2) == column_table.slice(0, 2) diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index a11bc8d4..64edbdeb 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -11,7 +11,9 @@ class DownloadManagerTests(unittest.TestCase): Unit tests for checking download manager logic. """ - def create_download_manager(self, links, max_download_threads=10, lz4_compressed=True): + def create_download_manager( + self, links, max_download_threads=10, lz4_compressed=True + ): return download_manager.ResultFileDownloadManager( links, max_download_threads, @@ -20,19 +22,23 @@ def create_download_manager(self, links, max_download_threads=10, lz4_compressed ) def create_result_link( - self, - file_link: str = "fileLink", - start_row_offset: int = 0, - row_count: int = 8000, - bytes_num: int = 20971520 + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520, ): - return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + return TSparkArrowResultLink( + file_link, None, start_row_offset, row_count, bytes_num + ) def create_result_links(self, num_files: int, start_row_offset: int = 0): result_links = [] for i in range(num_files): file_link = "fileLink_" + str(i) - result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_link = self.create_result_link( + file_link=file_link, start_row_offset=start_row_offset + ) result_links.append(result_link) start_row_offset += result_link.rowCount return result_links @@ -41,7 +47,9 @@ def test_add_file_links_zero_row_count(self): links = [self.create_result_link(row_count=0, bytes_num=0)] manager = self.create_download_manager(links) - assert len(manager._pending_links) == 0 # the only link supplied contains no data, so should be skipped + assert ( + len(manager._pending_links) == 0 + ) # the only link supplied contains no data, so should be skipped assert len(manager._download_tasks) == 0 def test_add_file_links_success(self): @@ -55,7 +63,9 @@ def test_add_file_links_success(self): def test_schedule_downloads(self, mock_submit): max_download_threads = 4 links = self.create_result_links(num_files=10) - manager = self.create_download_manager(links, max_download_threads=max_download_threads) + manager = self.create_download_manager( + links, max_download_threads=max_download_threads + ) manager._schedule_downloads() assert mock_submit.call_count == max_download_threads diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 7075ef6c..2a3b715b 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -20,36 +20,40 @@ class DownloaderTests(unittest.TestCase): Unit tests for checking downloader logic. 
""" - @patch('time.time', return_value=1000) + @patch("time.time", return_value=1000) def test_run_link_expired(self, mock_time): settings = Mock() result_link = Mock() # Already expired result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(Error) as context: d.run() - self.assertTrue('link has expired' in context.exception.message) + self.assertTrue("link has expired" in context.exception.message) mock_time.assert_called_once() - @patch('time.time', return_value=1000) + @patch("time.time", return_value=1000) def test_run_link_past_expiry_buffer(self, mock_time): settings = Mock(link_expiry_buffer_secs=5) result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(Error) as context: d.run() - self.assertTrue('link has expired' in context.exception.message) + self.assertTrue("link has expired" in context.exception.message) mock_time.assert_called_once() - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) - @patch('time.time', return_value=1000) + @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) + @patch("time.time", return_value=1000) def test_run_get_response_not_ok(self, mock_time, mock_session): mock_session.return_value.get.return_value = create_response(status_code=404) @@ -58,62 +62,81 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() - self.assertTrue('404' in str(context.exception)) + self.assertTrue("404" in str(context.exception)) - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=None))) - @patch('time.time', return_value=1000) + @patch("requests.Session", return_value=MagicMock(get=MagicMock(return_value=None))) + @patch("time.time", return_value=1000) def test_run_uncompressed_successful(self, mock_time, mock_session): file_bytes = b"1234567890" * 10 - mock_session.return_value.get.return_value = create_response(status_code=200, _content=file_bytes) + mock_session.return_value.get.return_value = create_response( + status_code=200, _content=file_bytes + ) settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) file = d.run() assert file.file_bytes == b"1234567890" * 10 - @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True)))) - @patch('time.time', return_value=1000) + @patch( + "requests.Session", + return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))), + ) + @patch("time.time", return_value=1000) def test_run_compressed_successful(self, mock_time, mock_session): file_bytes = 
b"1234567890" * 10 compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - mock_session.return_value.get.return_value = create_response(status_code=200, _content=compressed_bytes) + mock_session.return_value.get.return_value = create_response( + status_code=200, _content=compressed_bytes + ) settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False) settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) file = d.run() assert file.file_bytes == b"1234567890" * 10 - @patch('requests.Session.get', side_effect=ConnectionError('foo')) - @patch('time.time', return_value=1000) + @patch("requests.Session.get", side_effect=ConnectionError("foo")) + @patch("time.time", return_value=1000) def test_download_connection_error(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) + settings = Mock( + link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True + ) result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(ConnectionError): d.run() - @patch('requests.Session.get', side_effect=TimeoutError('foo')) - @patch('time.time', return_value=1000) + @patch("requests.Session.get", side_effect=TimeoutError("foo")) + @patch("time.time", return_value=1000) def test_download_timeout(self, mock_time, mock_session): - settings = Mock(link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True) + settings = Mock( + link_expiry_buffer_secs=0, use_proxy=False, is_lz4_compressed=True + ) result_link = Mock(bytesNum=100, expiryTime=1001) - mock_session.return_value.get.return_value.content = \ - b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' + mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler(settings, result_link, ssl_options=SSLOptions()) + d = downloader.ResultSetDownloadHandler( + settings, result_link, ssl_options=SSLOptions() + ) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7d5686f8..e9a58acd 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -17,7 +17,9 @@ def make_arrow_table(batch): n_cols = len(batch[0]) if batch else 0 schema = pa.schema({"col%s" % i: pa.uint32() for i in range(n_cols)}) cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)] - return schema, pa.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) + return schema, pa.Table.from_pydict( + dict(zip(schema.names, cols)), schema=schema + ) @staticmethod def make_arrow_queue(batch): @@ -42,18 +44,29 @@ def 
make_dummy_result_set_from_initial_results(initial_results): command_handle=None, arrow_queue=arrow_queue, arrow_schema_bytes=schema.serialize().to_pybytes(), - is_staging_operation=False)) + is_staging_operation=False, + ), + ) num_cols = len(initial_results[0]) if initial_results else 0 - rs.description = [(f'col{col_id}', 'integer', None, None, None, None, None) - for col_id in range(num_cols)] + rs.description = [ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ] return rs @staticmethod def make_dummy_result_set_from_batch_list(batch_list): batch_index = 0 - def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4_compressed, - arrow_schema_bytes, description): + def fetch_results( + op_handle, + max_rows, + max_bytes, + expected_row_start_offset, + lz4_compressed, + arrow_schema_bytes, + description, + ): nonlocal batch_index results = FetchTests.make_arrow_queue(batch_list[batch_index]) batch_index += 1 @@ -71,13 +84,17 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, lz4 status=None, has_been_closed_server_side=False, has_more_rows=True, - description=[(f'col{col_id}', 'integer', None, None, None, None, None) - for col_id in range(num_cols)], + description=[ + (f"col{col_id}", "integer", None, None, None, None, None) + for col_id in range(num_cols) + ], lz4_compressed=Mock(), command_handle=None, arrow_queue=None, arrow_schema_bytes=None, - is_staging_operation=False)) + is_staging_operation=False, + ), + ) return rs def assertEqualRowValues(self, actual, expected): @@ -87,30 +104,44 @@ def assertEqualRowValues(self, actual, expected): def test_fetchmany_with_initial_results(self): # Fetch all in one go - initial_results_1 = [[1], [2], [3]] # This is a list of rows, each row with 1 col - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + initial_results_1 = [ + [1], + [2], + [3], + ] # This is a list of rows, each row with 1 col + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) # Fetch in small amounts initial_results_2 = [[1], [2], [3], [4]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_2) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_2 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[1]]) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[2], [3]]) self.assertEqualRowValues(dummy_result_set.fetchmany(1), [[4]]) # Fetch too many initial_results_3 = [[2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_3) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_3 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(5), [[2], [3]]) # Empty results initial_results_4 = [[]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_4) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_4 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(0), []) def test_fetch_many_without_initial_results(self): # Fetch all in one go; single batch - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = 
self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(3), [[1], [2], [3]]) @@ -140,7 +171,9 @@ def test_fetch_many_without_initial_results(self): # Fetch too many; multiple batches batch_list_6 = [[[1]], [[2], [3], [4]], [[5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_6) - self.assertEqualRowValues(dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]]) + self.assertEqualRowValues( + dummy_result_set.fetchmany(100), [[1], [2], [3], [4], [5], [6]] + ) # Fetch 0; 1 empty batch batch_list_7 = [[]] @@ -154,19 +187,25 @@ def test_fetch_many_without_initial_results(self): def test_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) def test_fetchall_without_initial_results(self): # Fetch all, single batch - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3]]) # Fetch all, multiple batches batch_list_2 = [[[1], [2]], [[3]], [[4], [5], [6]]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) - self.assertEqualRowValues(dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]]) + self.assertEqualRowValues( + dummy_result_set.fetchall(), [[1], [2], [3], [4], [5], [6]] + ) batch_list_3 = [[]] dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_3) @@ -174,12 +213,16 @@ def test_fetchall_without_initial_results(self): def test_fetchmany_fetchall_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) def test_fetchmany_fetchall_without_initial_results(self): - batch_list_1 = [[[1], [2], [3]]] # This is a list of one batch of rows, each row with 1 col + batch_list_1 = [ + [[1], [2], [3]] + ] # This is a list of one batch of rows, each row with 1 col dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_1) self.assertEqualRowValues(dummy_result_set.fetchmany(2), [[1], [2]]) self.assertEqualRowValues(dummy_result_set.fetchall(), [[3]]) @@ -191,7 +234,9 @@ def test_fetchmany_fetchall_without_initial_results(self): def test_fetchone_with_initial_results(self): initial_results_1 = [[1], [2], [3]] - dummy_result_set = self.make_dummy_result_set_from_initial_results(initial_results_1) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + initial_results_1 + ) self.assertSequenceEqual(dummy_result_set.fetchone(), [1]) self.assertSequenceEqual(dummy_result_set.fetchone(), [2]) self.assertSequenceEqual(dummy_result_set.fetchone(), [3]) @@ -210,5 +255,5 @@ def test_fetchone_without_initial_results(self): self.assertEqual(dummy_result_set.fetchone(), None) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git 
a/tests/unit/test_fetches_bench.py b/tests/unit/test_fetches_bench.py index e322b44a..9382c3b3 100644 --- a/tests/unit/test_fetches_bench.py +++ b/tests/unit/test_fetches_bench.py @@ -35,12 +35,18 @@ def make_dummy_result_set_from_initial_results(arrow_table): description=Mock(), command_handle=None, arrow_queue=arrow_queue, - arrow_schema=arrow_table.schema)) - rs.description = [(f'col{col_id}', 'string', None, None, None, None, None) - for col_id in range(arrow_table.num_columns)] + arrow_schema=arrow_table.schema, + ), + ) + rs.description = [ + (f"col{col_id}", "string", None, None, None, None, None) + for col_id in range(arrow_table.num_columns) + ] return rs - @pytest.mark.skip(reason="Test has not been updated for latest connector API (June 2022)") + @pytest.mark.skip( + reason="Test has not been updated for latest connector API (June 2022)" + ) def test_benchmark_fetchall(self): print("preparing dummy arrow table") arrow_table = FetchBenchmarkTests.make_arrow_table(10, 25000) @@ -50,7 +56,9 @@ def test_benchmark_fetchall(self): start_time = time.time() count = 0 while time.time() < start_time + benchmark_seconds: - dummy_result_set = self.make_dummy_result_set_from_initial_results(arrow_table) + dummy_result_set = self.make_dummy_result_set_from_initial_results( + arrow_table + ) res = dummy_result_set.fetchall() for _ in res: pass @@ -59,5 +67,5 @@ def test_benchmark_fetchall(self): print(f"Executed query {count} times, in {time.time() - start_time} seconds") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_oauth_persistence.py b/tests/unit/test_oauth_persistence.py index 28b3cab3..a8ceb14e 100644 --- a/tests/unit/test_oauth_persistence.py +++ b/tests/unit/test_oauth_persistence.py @@ -1,16 +1,17 @@ - import unittest -from databricks.sql.experimental.oauth_persistence import DevOnlyFilePersistence, OAuthToken +from databricks.sql.experimental.oauth_persistence import ( + DevOnlyFilePersistence, + OAuthToken, +) import tempfile import os class OAuthPersistenceTests(unittest.TestCase): - def test_DevOnlyFilePersistence_read_my_write(self): with tempfile.TemporaryDirectory() as tempdir: - test_json_file_path = os.path.join(tempdir, 'test.json') + test_json_file_path = os.path.join(tempdir, "test.json") persistence_manager = DevOnlyFilePersistence(test_json_file_path) access_token = "abc#$%%^&^*&*()()_=-/" refresh_token = "#$%%^^&**()+)_gter243]xyz" @@ -23,7 +24,7 @@ def test_DevOnlyFilePersistence_read_my_write(self): def test_DevOnlyFilePersistence_file_does_not_exist(self): with tempfile.TemporaryDirectory() as tempdir: - test_json_file_path = os.path.join(tempdir, 'test.json') + test_json_file_path = os.path.join(tempdir, "test.json") persistence_manager = DevOnlyFilePersistence(test_json_file_path) new_token = persistence_manager.read("https://randomserver") diff --git a/tests/unit/test_param_escaper.py b/tests/unit/test_param_escaper.py index 472a0843..925fcea5 100644 --- a/tests/unit/test_param_escaper.py +++ b/tests/unit/test_param_escaper.py @@ -3,7 +3,12 @@ from typing import Any, Dict from databricks.sql.parameters.native import dbsql_parameter_from_primitive -from databricks.sql.utils import ParamEscaper, inject_parameters, transform_paramstyle, ParameterStructure +from databricks.sql.utils import ( + ParamEscaper, + inject_parameters, + transform_paramstyle, + ParameterStructure, +) pe = ParamEscaper() @@ -200,26 +205,31 @@ class TestInlineToNativeTransformer(object): "query with like wildcard", 'select * from table 
where field like "%"', {}, - 'select * from table where field like "%"' + 'select * from table where field like "%"', ), ( "query with named param and like wildcard", 'select :param from table where field like "%"', {"param": None}, - 'select :param from table where field like "%"' + 'select :param from table where field like "%"', ), ( "query with doubled wildcards", - 'select 1 where '' like "%%"', + "select 1 where " ' like "%%"', {"param": None}, - 'select 1 where '' like "%%"', - ) + "select 1 where " ' like "%%"', + ), ), ) def test_transformer( self, label: str, query: str, params: Dict[str, Any], expected: str ): - _params = [dbsql_parameter_from_primitive(value=value, name=name) for name, value in params.items()] - output = transform_paramstyle(query, _params, param_structure=ParameterStructure.NAMED) + _params = [ + dbsql_parameter_from_primitive(value=value, name=name) + for name, value in params.items() + ] + output = transform_paramstyle( + query, _params, param_structure=ParameterStructure.NAMED + ) assert output == expected diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 798bac2e..2108af4f 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -8,7 +8,6 @@ class TestRetry: - @pytest.fixture() def retry_policy(self) -> DatabricksRetryPolicy: return DatabricksRetryPolicy( @@ -41,7 +40,9 @@ def test_sleep__retry_after_is_binding(self, t_mock, retry_policy, error_history t_mock.assert_called_with(3) @patch("time.sleep") - def test_sleep__retry_after_present_but_not_binding(self, t_mock, retry_policy, error_history): + def test_sleep__retry_after_present_but_not_binding( + self, t_mock, retry_policy, error_history + ): retry_policy._retry_start_time = time.time() retry_policy.history = [error_history, error_history] retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "1"})) diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 0333766c..293467af 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -69,7 +69,14 @@ def test_make_request_checks_thrift_status_code(self): mock_method = Mock() mock_method.__name__ = "method name" mock_method.return_value = mock_response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError): thrift_backend.make_request(mock_method, Mock()) @@ -79,7 +86,14 @@ def _make_type_desc(self, type): ) def _make_fake_thrift_backend(self): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._hive_schema_to_arrow_schema = Mock() thrift_backend._hive_schema_to_description = Mock() thrift_backend._create_arrow_table = MagicMock() @@ -89,13 +103,16 @@ def _make_fake_thrift_backend(self): def test_hive_schema_to_arrow_schema_preserves_column_names(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 2", + 
typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) @@ -136,7 +153,9 @@ def test_bad_protocol_versions_are_rejected(self, tcli_service_client_cass): thrift_backend = self._make_fake_thrift_backend() thrift_backend.open_session({}, None, None) - self.assertIn("expected server to use a protocol version", str(cm.exception)) + self.assertIn( + "expected server to use a protocol version", str(cm.exception) + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @@ -157,8 +176,17 @@ def test_okay_protocol_versions_succeed(self, tcli_service_client_cass): @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_headers_are_set(self, t_http_client_class): - ThriftBackend("foo", 123, "bar", [("header", "value")], auth_provider=AuthProvider(), ssl_options=SSLOptions()) - t_http_client_class.return_value.setCustomHeaders.assert_called_with({"header": "value"}) + ThriftBackend( + "foo", + 123, + "bar", + [("header", "value")], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + t_http_client_class.return_value.setCustomHeaders.assert_called_with( + {"header": "value"} + ) def test_proxy_headers_are_set(self): @@ -174,11 +202,13 @@ def test_proxy_headers_are_set(self): assert False assert isinstance(result, type(dict())) - assert isinstance(result.get('proxy-authorization'), type(str())) + assert isinstance(result.get("proxy-authorization"), type(str())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_client_class): + def test_tls_cert_args_are_propagated( + self, mock_create_default_context, t_http_client_class + ): mock_cert_key_file = Mock() mock_cert_key_password = Mock() mock_trusted_ca_file = Mock() @@ -203,11 +233,15 @@ def test_tls_cert_args_are_propagated(self, mock_create_default_context, t_http_ ) mock_ssl_context.load_cert_chain.assert_called_once_with( - certfile=mock_cert_file, keyfile=mock_cert_key_file, password=mock_cert_key_password + certfile=mock_cert_file, + keyfile=mock_cert_key_file, + password=mock_cert_key_password, ) self.assertTrue(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.types.create_default_context") def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context): @@ -232,7 +266,7 @@ def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context ssl_options=mock_ssl_options, ) - self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.scheme, "https") self.assertEqual(http_client.certfile, mock_ssl_options.tls_client_cert_file) self.assertEqual(http_client.keyfile, mock_ssl_options.tls_client_cert_key_file) self.assertIsNotNone(http_client.certfile) @@ -246,7 +280,9 @@ def test_tls_cert_args_are_used_by_http_client(self, mock_create_default_context self.assertEqual(conn_pool.ca_certs, mock_ssl_options.tls_trusted_ca_file) 
self.assertEqual(conn_pool.cert_file, mock_ssl_options.tls_client_cert_file) self.assertEqual(conn_pool.key_file, mock_ssl_options.tls_client_cert_key_file) - self.assertEqual(conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password) + self.assertEqual( + conn_pool.key_password, mock_ssl_options.tls_client_cert_key_password + ) def test_tls_no_verify_is_respected_by_http_client(self): from databricks.sql.auth.thrift_http_client import THttpClient @@ -256,7 +292,7 @@ def test_tls_no_verify_is_respected_by_http_client(self): uri_or_host="https://example.com", ssl_options=SSLOptions(tls_verify=False), ) - self.assertEqual(http_client.scheme, 'https') + self.assertEqual(http_client.scheme, "https") http_client.open() @@ -266,16 +302,27 @@ def test_tls_no_verify_is_respected_by_http_client(self): @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") - def test_tls_no_verify_is_respected(self, mock_create_default_context, t_http_client_class): + def test_tls_no_verify_is_respected( + self, mock_create_default_context, t_http_client_class + ): mock_ssl_options = SSLOptions(tls_verify=False) mock_ssl_context = mock_ssl_options.create_ssl_context() mock_create_default_context.assert_called() - ThriftBackend("foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options) + ThriftBackend( + "foo", + 123, + "bar", + [], + auth_provider=AuthProvider(), + ssl_options=mock_ssl_options, + ) self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_NONE) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") @patch("databricks.sql.types.create_default_context") @@ -287,61 +334,126 @@ def test_tls_verify_hostname_is_respected( mock_create_default_context.assert_called() ThriftBackend( - "foo", 123, "bar", [], auth_provider=AuthProvider(), ssl_options=mock_ssl_options + "foo", + 123, + "bar", + [], + auth_provider=AuthProvider(), + ssl_options=mock_ssl_options, ) self.assertFalse(mock_ssl_context.check_hostname) self.assertEqual(mock_ssl_context.verify_mode, CERT_REQUIRED) - self.assertEqual(t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options) + self.assertEqual( + t_http_client_class.call_args[1]["ssl_options"], mock_ssl_options + ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_port_and_host_are_respected(self, t_http_client_class): - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_https_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "https://hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + 
t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_host_with_trailing_backslash_does_not_duplicate(self, t_http_client_class): - ThriftBackend("https://hostname/", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + ThriftBackend( + "https://hostname/", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) self.assertEqual( - t_http_client_class.call_args[1]["uri_or_host"], "https://hostname:123/path_value" + t_http_client_class.call_args[1]["uri_or_host"], + "https://hostname:123/path_value", ) @patch("databricks.sql.auth.thrift_http_client.THttpClient") def test_socket_timeout_is_propagated(self, t_http_client_class): ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=129 + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=129, + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000 ) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 129 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=0 + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=0, ) self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 0) - ThriftBackend("hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000) ThriftBackend( - "hostname", 123, "path_value", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), _socket_timeout=None + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], 900 * 1000 + ) + ThriftBackend( + "hostname", + 123, + "path_value", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + _socket_timeout=None, + ) + self.assertEqual( + t_http_client_class.return_value.setTimeout.call_args[0][0], None ) - self.assertEqual(t_http_client_class.return_value.setTimeout.call_args[0][0], None) def test_non_primitive_types_raise_error(self): columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( columnName="column 2", typeDesc=ttypes.TTypeDesc( types=[ - ttypes.TTypeEntry(userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo")) + ttypes.TTypeEntry( + userDefinedTypeEntry=ttypes.TUserDefinedTypeEntry("foo") + ) ] ), ), @@ -358,13 +470,16 @@ def test_hive_schema_to_description_preserves_column_names_and_types(self): # canary test columns = [ ttypes.TColumnDesc( - columnName="column 1", typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE) + columnName="column 1", + typeDesc=self._make_type_desc(ttypes.TTypeId.INT_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE) + columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.BOOLEAN_TYPE), ), ttypes.TColumnDesc( - columnName="column 2", typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE) + 
columnName="column 2", + typeDesc=self._make_type_desc(ttypes.TTypeId.MAP_TYPE), ), ttypes.TColumnDesc( columnName="", typeDesc=self._make_type_desc(ttypes.TTypeId.STRUCT_TYPE) @@ -395,8 +510,12 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): type=ttypes.TTypeId.DECIMAL_TYPE, typeQualifiers=ttypes.TTypeQualifiers( qualifiers={ - "precision": ttypes.TTypeQualifierValue(i32Value=10), - "scale": ttypes.TTypeQualifierValue(i32Value=100), + "precision": ttypes.TTypeQualifierValue( + i32Value=10 + ), + "scale": ttypes.TTypeQualifierValue( + i32Value=100 + ), } ), ) @@ -416,8 +535,18 @@ def test_hive_schema_to_description_preserves_scale_and_precision(self): ) def test_make_request_checks_status_code(self): - error_codes = [ttypes.TStatusCode.ERROR_STATUS, ttypes.TStatusCode.INVALID_HANDLE_STATUS] - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + error_codes = [ + ttypes.TStatusCode.ERROR_STATUS, + ttypes.TStatusCode.INVALID_HANDLE_STATUS, + ] + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) for code in error_codes: mock_error_response = Mock() @@ -455,15 +584,24 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): ), ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) - def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) + def test_handle_execute_response_sets_compression_in_direct_results( + self, build_queue + ): for resp_type in self.execute_response_types: lz4Compressed = Mock() resultSet = MagicMock() @@ -484,13 +622,24 @@ def test_handle_execute_response_sets_compression_in_direct_results(self, build_ closeOperation=None, ), ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) - execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + execute_response = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual(execute_response.lz4_compressed, lz4Compressed) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_service_class): + def test_handle_execute_response_checks_operation_state_in_polls( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value error_resp = ttypes.TGetOperationStatusResp( @@ -506,7 +655,9 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv for op_state_resp, exec_resp_type in itertools.product( [error_resp, closed_resp], self.execute_response_types ): - with self.subTest(op_state_resp=op_state_resp, exec_resp_type=exec_resp_type): + with self.subTest( + op_state_resp=op_state_resp, 
exec_resp_type=exec_resp_type + ): tcli_service_instance = tcli_service_class.return_value t_execute_resp = exec_resp_type( status=self.okay_status, @@ -516,7 +667,12 @@ def test_handle_execute_response_checks_operation_state_in_polls(self, tcli_serv tcli_service_instance.GetOperationStatus.return_value = op_state_resp thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: @@ -539,12 +695,23 @@ def test_get_status_uses_display_message_if_available(self, tcli_service_class): ) t_execute_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, + ) + tcli_service_instance.GetOperationStatus.return_value = ( + t_get_operation_status_resp ) - tcli_service_instance.GetOperationStatus.return_value = t_get_operation_status_resp tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -577,7 +744,14 @@ def test_direct_results_uses_display_message_if_available(self, tcli_service_cla tcli_service_instance.ExecuteStatement.return_value = t_execute_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(DatabaseError) as cm: thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) @@ -589,7 +763,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): resp_1 = resp_type( status=self.okay_status, directResults=ttypes.TSparkDirectResults( - operationStatus=ttypes.TGetOperationStatusResp(status=self.bad_status), + operationStatus=ttypes.TGetOperationStatusResp( + status=self.bad_status + ), resultSetMetadata=None, resultSet=None, closeOperation=None, @@ -600,7 +776,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): status=self.okay_status, directResults=ttypes.TSparkDirectResults( operationStatus=None, - resultSetMetadata=ttypes.TGetResultSetMetadataResp(status=self.bad_status), + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + status=self.bad_status + ), resultSet=None, closeOperation=None, ), @@ -629,7 +807,12 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): for error_resp in [resp_1, resp_2, resp_3, resp_4]: with self.subTest(error_resp=error_resp): thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) with self.assertRaises(DatabaseError) as cm: @@ -637,7 +820,9 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self): self.assertIn("this is a bad error", str(cm.exception)) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def 
test_handle_execute_response_can_handle_without_direct_results(self, tcli_service_class): + def test_handle_execute_response_can_handle_without_direct_results( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value for resp_type in self.execute_response_types: @@ -660,23 +845,32 @@ def test_handle_execute_response_can_handle_without_direct_results(self, tcli_se ) op_state_3 = ttypes.TGetOperationStatusResp( - status=self.okay_status, operationState=ttypes.TOperationState.FINISHED_STATE + status=self.okay_status, + operationState=ttypes.TOperationState.FINISHED_STATE, ) - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) tcli_service_instance.GetOperationStatus.side_effect = [ op_state_1, op_state_2, op_state_3, ] thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) results_message_response = thrift_backend._handle_execute_response( execute_resp, Mock() ) self.assertEqual( - results_message_response.status, ttypes.TOperationState.FINISHED_STATE + results_message_response.status, + ttypes.TOperationState.FINISHED_STATE, ) def test_handle_execute_response_can_handle_with_direct_results(self): @@ -701,7 +895,12 @@ def test_handle_execute_response_can_handle_with_direct_results(self): ) thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions() + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), ) thrift_backend._results_message_to_execute_response = Mock() @@ -731,9 +930,13 @@ def test_use_arrow_schema_if_available(self, tcli_service_class): operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = t_get_result_set_metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + t_get_result_set_metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock()) + execute_response = thrift_backend._handle_execute_response( + t_execute_resp, Mock() + ) self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock) @@ -760,10 +963,13 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertEqual( - hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0] + hive_schema_mock, + thrift_backend._hive_schema_to_arrow_schema.call_args[0][0], ) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_direct_results( self, tcli_service_class, build_queue @@ -794,14 +1000,20 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( operationHandle=self.operation_handle, ) - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() - execute_response = thrift_backend._handle_execute_response(execute_resp, Mock()) + 
execute_response = thrift_backend._handle_execute_response( + execute_resp, Mock() + ) self.assertEqual(has_more_rows, execute_response.has_more_rows) - @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + @patch( + "databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock() + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_handle_execute_response_reads_has_more_rows_in_result_response( self, tcli_service_class, build_queue @@ -836,8 +1048,12 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( ) tcli_service_instance.FetchResults.return_value = fetch_results_resp - tcli_service_instance.GetOperationStatus.return_value = operation_status_resp - tcli_service_instance.GetResultSetMetadata.return_value = self.metadata_resp + tcli_service_instance.GetOperationStatus.return_value = ( + operation_status_resp + ) + tcli_service_instance.GetResultSetMetadata.return_value = ( + self.metadata_resp + ) thrift_backend = self._make_fake_thrift_backend() thrift_backend._handle_execute_response(execute_resp, Mock()) @@ -864,7 +1080,8 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): startRowOffset=0, rows=[], arrowBatches=[ - ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) + ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) + for _ in range(10) ], ), resultSetMetadata=ttypes.TGetResultSetMetadataResp( @@ -885,7 +1102,14 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): .to_pybytes() ) - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) arrow_queue, has_more_results = thrift_backend.fetch_results( op_handle=Mock(), max_rows=1, @@ -899,11 +1123,20 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): self.assertEqual(arrow_queue.n_valid_rows, 15 * 10) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_execute_statement_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.ExecuteStatement.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -914,14 +1147,25 @@ def test_execute_statement_calls_client_and_handle_execute_response(self, tcli_s self.assertEqual(req.getDirectResults, get_direct_results) self.assertEqual(req.statement, "foo") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_catalogs_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() 
tcli_service_instance.GetCatalogs.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -931,14 +1175,25 @@ def test_get_catalogs_calls_client_and_handle_execute_response(self, tcli_servic get_direct_results = ttypes.TSparkGetDirectResults(maxRows=100, maxBytes=200) self.assertEqual(req.getDirectResults, get_direct_results) # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_schemas_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetSchemas.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -957,14 +1212,25 @@ def test_get_schemas_calls_client_and_handle_execute_response(self, tcli_service self.assertEqual(req.catalogName, "catalog_pattern") self.assertEqual(req.schemaName, "schema_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_tables_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetTables.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -987,14 +1253,25 @@ def test_get_tables_calls_client_and_handle_execute_response(self, tcli_service_ self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.tableTypes, ["type1", "type2"]) # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service_class): + def test_get_columns_calls_client_and_handle_execute_response( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value response = Mock() tcli_service_instance.GetColumns.return_value = response - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), 
ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend._handle_execute_response = Mock() cursor_mock = Mock() @@ -1017,21 +1294,37 @@ def test_get_columns_calls_client_and_handle_execute_response(self, tcli_service self.assertEqual(req.tableName, "table_pattern") self.assertEqual(req.columnName, "column_pattern") # Check response handling - thrift_backend._handle_execute_response.assert_called_with(response, cursor_mock) + thrift_backend._handle_execute_response.assert_called_with( + response, cursor_mock + ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_open_session_user_provided_session_id_optional(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.open_session({}, None, None) self.assertEqual(len(tcli_service_instance.OpenSession.call_args_list), 1) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_op_handle_respected_in_close_command(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.close_command(self.operation_handle) self.assertEqual( tcli_service_instance.CloseOperation.call_args[0][0].operationHandle, @@ -1041,20 +1334,32 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) def test_session_handle_respected_in_close_session(self, tcli_service_class): tcli_service_instance = tcli_service_class.return_value - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) thrift_backend.close_session(self.session_handle) self.assertEqual( - tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle + tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, + self.session_handle, ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_class): + def test_non_arrow_non_column_based_set_triggers_exception( + self, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value results_mock = Mock() results_mock.startRowOffset = 0 execute_statement_resp = ttypes.TExecuteStatementResp( - status=self.okay_status, directResults=None, operationHandle=self.operation_handle + status=self.okay_status, + directResults=None, + operationHandle=self.operation_handle, ) metadata_resp = ttypes.TGetResultSetMetadataResp( @@ -1075,11 +1380,20 @@ def test_non_arrow_non_column_based_set_triggers_exception(self, tcli_service_cl with self.assertRaises(OperationalError) as cm: thrift_backend.execute_command("foo", Mock(), 100, 100, Mock(), Mock()) 
- self.assertIn("Expected results to be in Arrow or column based format", str(cm.exception)) + self.assertIn( + "Expected results to be in Arrow or column based format", str(cm.exception) + ) def test_create_arrow_table_raises_error_for_unsupported_type(self): t_row_set = ttypes.TRowSet() - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) @@ -1088,7 +1402,14 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): def test_create_arrow_table_calls_correct_conversion_method( self, convert_col_mock, convert_arrow_mock ): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) convert_arrow_mock.return_value = (MagicMock(), Mock()) convert_col_mock.return_value = (MagicMock(), Mock()) @@ -1099,18 +1420,31 @@ def test_create_arrow_table_calls_correct_conversion_method( description = Mock() t_col_set = ttypes.TRowSet(columns=cols) - thrift_backend._create_arrow_table(t_col_set, lz4_compressed, schema, description) + thrift_backend._create_arrow_table( + t_col_set, lz4_compressed, schema, description + ) convert_arrow_mock.assert_not_called() convert_col_mock.assert_called_once_with(cols, description) t_arrow_set = ttypes.TRowSet(arrowBatches=arrow_batches) thrift_backend._create_arrow_table(t_arrow_set, lz4_compressed, schema, Mock()) - convert_arrow_mock.assert_called_once_with(arrow_batches, lz4_compressed, schema) + convert_arrow_mock.assert_called_once_with( + arrow_batches, lz4_compressed, schema + ) @patch("lz4.frame.decompress") @patch("pyarrow.ipc.open_stream") - def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_decompress_mock): - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + def test_convert_arrow_based_set_to_arrow_table( + self, open_stream_mock, lz4_decompress_mock + ): + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) lz4_decompress_mock.return_value = bytearray("Testing", "utf-8") @@ -1142,15 +1476,23 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes(1) + ) ), - ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1175,19 +1517,29 @@ def 
test_convert_column_based_set_to_arrow_table_with_nulls(self): description = [(name,) for name in field_names] t_cols = [ - ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1]))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes([2])) + i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1])) ), ttypes.TColumn( - doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes([4])) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes([2]) + ) ), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3])) + doubleVal=ttypes.TDoubleColumn( + values=[1.15, 2.2, 3.3], nulls=bytes([4]) + ) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes([3]) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check data @@ -1203,15 +1555,23 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): t_cols = [ ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))), ttypes.TColumn( - stringVal=ttypes.TStringColumn(values=["s1", "s2", "s3"], nulls=bytes(1)) + stringVal=ttypes.TStringColumn( + values=["s1", "s2", "s3"], nulls=bytes(1) + ) ), - ttypes.TColumn(doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1))), ttypes.TColumn( - binaryVal=ttypes.TBinaryColumn(values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1)) + doubleVal=ttypes.TDoubleColumn(values=[1.15, 2.2, 3.3], nulls=bytes(1)) + ), + ttypes.TColumn( + binaryVal=ttypes.TBinaryColumn( + values=[b"\x11", b"\x22", b"\x33"], nulls=bytes(1) + ) ), ] - arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table(t_cols, description) + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( + t_cols, description + ) self.assertEqual(n_rows, 3) # Check schema, column names and types @@ -1257,8 +1617,12 @@ def test_handle_execute_response_sets_active_op_handle(self): self.assertEqual(mock_resp.operationHandle, mock_cursor.active_op_handle) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" + ) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_GetOperationStatus( self, mock_retry_policy, mock_GetOperationStatus, t_transport_class ): @@ -1270,7 +1634,9 @@ def test_make_request_will_retry_GetOperationStatus( this_gos_name = "GetOperationStatus" mock_GetOperationStatus.__name__ = this_gos_name - mock_GetOperationStatus.side_effect = OSError(errno.ETIMEDOUT, "Connection timed out") + mock_GetOperationStatus.side_effect = OSError( + errno.ETIMEDOUT, "Connection timed out" + ) protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(t_transport_class) client = Client(protocol) @@ -1299,12 +1665,18 @@ def test_make_request_will_retry_GetOperationStatus( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) + self.assertEqual( + 
f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] + ) # Unusual OSError code - mock_GetOperationStatus.side_effect = OSError(errno.EEXIST, "File does not exist") + mock_GetOperationStatus.side_effect = OSError( + errno.EEXIST, "File does not exist" + ) - with self.assertLogs("databricks.sql.thrift_backend", level=logging.WARNING) as cm: + with self.assertLogs( + "databricks.sql.thrift_backend", level=logging.WARNING + ) as cm: with self.assertRaises(RequestError): thrift_backend.make_request(client.GetOperationStatus, req) @@ -1320,8 +1692,12 @@ def test_make_request_will_retry_GetOperationStatus( cm.output[0], ) - @patch("databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_api.TCLIService.TCLIService.Client.GetOperationStatus" + ) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos ): @@ -1366,10 +1742,14 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( self.assertEqual( NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"] ) - self.assertEqual(f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"]) + self.assertEqual( + f"{EXPECTED_RETRIES}/{EXPECTED_RETRIES}", cm.exception.context["attempt"] + ) @patch("thrift.transport.THttpClient.THttpClient") - def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_class): + def test_make_request_wont_retry_if_error_code_not_429_or_503( + self, t_transport_class + ): t_transport_instance = t_transport_class.return_value t_transport_instance.code = 430 t_transport_instance.headers = {"Retry-After": "1"} @@ -1377,7 +1757,14 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(OperationalError) as cm: thrift_backend.make_request(mock_method, Mock()) @@ -1385,7 +1772,9 @@ def test_make_request_wont_retry_if_error_code_not_429_or_503(self, t_transport_ self.assertIn("This method fails", str(cm.exception.message_with_context())) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) + @patch( + "databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory + ) def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self, mock_retry_policy, t_transport_class ): @@ -1417,13 +1806,22 @@ def test_make_request_will_retry_stop_after_attempts_count_if_retryable( self.assertEqual(mock_method.call_count, 14) @patch("databricks.sql.auth.thrift_http_client.THttpClient") - def test_make_request_will_read_error_message_headers_if_set(self, t_transport_class): + def test_make_request_will_read_error_message_headers_if_set( + self, t_transport_class + ): t_transport_instance = t_transport_class.return_value mock_method = Mock() mock_method.__name__ = "method name" mock_method.side_effect = Exception("This method fails") - thrift_backend = 
ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + thrift_backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) error_headers = [ [("x-thriftserver-error-message", "thrift server error message")], @@ -1457,12 +1855,18 @@ def make_table_and_desc( ): int_col = [int_constant for _ in range(height)] decimal_col = [decimal_constant for _ in range(height)] - data = OrderedDict({"col{}".format(i): int_col for i in range(width - n_decimal_cols)}) - decimals = OrderedDict({"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)}) + data = OrderedDict( + {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} + ) + decimals = OrderedDict( + {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} + ) data.update(decimals) int_desc = [("", "int")] * (width - n_decimal_cols) - decimal_desc = [("", "decimal", None, None, precision, scale, None)] * n_decimal_cols + decimal_desc = [ + ("", "decimal", None, None, precision, scale, None) + ] * n_decimal_cols description = int_desc + decimal_desc table = pyarrow.Table.from_pydict(data) @@ -1495,25 +1899,36 @@ def test_arrow_decimal_conversion(self): if height > 0: if i < width - n_decimal_cols: self.assertEqual( - decimal_converted_table.field(i).type, pyarrow.int64() + decimal_converted_table.field(i).type, + pyarrow.int64(), ) else: self.assertEqual( decimal_converted_table.field(i).type, - pyarrow.decimal128(precision=precision, scale=scale), + pyarrow.decimal128( + precision=precision, scale=scale + ), ) int_col = [int_constant for _ in range(height)] decimal_col = [Decimal(decimal_constant) for _ in range(height)] expected_result = OrderedDict( - {"col{}".format(i): int_col for i in range(width - n_decimal_cols)} + { + "col{}".format(i): int_col + for i in range(width - n_decimal_cols) + } ) decimals = OrderedDict( - {"col_dec{}".format(i): decimal_col for i in range(n_decimal_cols)} + { + "col_dec{}".format(i): decimal_col + for i in range(n_decimal_cols) + } ) expected_result.update(decimals) - self.assertEqual(decimal_converted_table.to_pydict(), expected_result) + self.assertEqual( + decimal_converted_table.to_pydict(), expected_result + ) @patch("thrift.transport.THttpClient.THttpClient") def test_retry_args_passthrough(self, mock_http_client): @@ -1524,7 +1939,13 @@ def test_retry_args_passthrough(self, mock_http_client): "_retry_stop_after_attempts_duration": 100, } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **retry_delay_args + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **retry_delay_args, ) for arg, val in retry_delay_args.items(): self.assertEqual(getattr(backend, arg), val) @@ -1533,17 +1954,28 @@ def test_retry_args_passthrough(self, mock_http_client): def test_retry_args_bounding(self, mock_http_client): retry_delay_test_args_and_expected_values = {} for k, (_, _, min, max) in databricks.sql.thrift_backend._retry_policy.items(): - retry_delay_test_args_and_expected_values[k] = ((min - 1, min), (max + 1, max)) + retry_delay_test_args_and_expected_values[k] = ( + (min - 1, min), + (max + 1, max), + ) for i in range(2): retry_delay_args = { - k: v[i][0] for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][0] + for (k, v) in retry_delay_test_args_and_expected_values.items() } backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), 
ssl_options=SSLOptions(), **retry_delay_args + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **retry_delay_args, ) retry_delay_expected_vals = { - k: v[i][1] for (k, v) in retry_delay_test_args_and_expected_values.items() + k: v[i][1] + for (k, v) in retry_delay_test_args_and_expected_values.items() } for arg, val in retry_delay_expected_vals.items(): self.assertEqual(getattr(backend, arg), val) @@ -1560,7 +1992,14 @@ def test_configuration_passthrough(self, tcli_client_class): "42": "42", } - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) backend.open_session(mock_config, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] @@ -1571,7 +2010,14 @@ def test_cant_set_timestamp_as_string_to_true(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp mock_config = {"spark.thriftserver.arrowBasedRowSet.timestampAsString": True} - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(databricks.sql.Error) as cm: backend.open_session(mock_config, None, None) @@ -1590,7 +2036,14 @@ def _construct_open_session_with_namespace(self, can_use_multiple_cats, cat, sch def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) initial_cat_schem_args = [("cat", None), (None, "schem"), ("cat", "schem")] for cat, schem in initial_cat_schem_args: @@ -1601,26 +2054,46 @@ def test_initial_namespace_passthrough_to_open_session(self, tcli_client_class): backend.open_session({}, cat, schem) - open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] + open_session_req = tcli_client_class.return_value.OpenSession.call_args[ + 0 + ][0] self.assertEqual(open_session_req.initialNamespace.catalogName, cat) self.assertEqual(open_session_req.initialNamespace.schemaName, schem) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_can_use_multiple_catalogs_is_set_in_open_session_req(self, tcli_client_class): + def test_can_use_multiple_catalogs_is_set_in_open_session_req( + self, tcli_client_class + ): tcli_service_instance = tcli_client_class.return_value tcli_service_instance.OpenSession.return_value = self.open_session_resp - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) backend.open_session({}, None, None) open_session_req = tcli_client_class.return_value.OpenSession.call_args[0][0] self.assertTrue(open_session_req.canUseMultipleCatalogs) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) - def test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog(self, tcli_client_class): + def 
test_can_use_multiple_catalogs_is_false_fails_with_initial_catalog( + self, tcli_client_class + ): tcli_service_instance = tcli_client_class.return_value - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) # If the initial catalog is set, but server returns canUseMultipleCatalogs=False, we # expect failure. If the initial catalog isn't set, then canUseMultipleCatalogs=False # is fine @@ -1658,13 +2131,21 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class): initialNamespace=ttypes.TNamespace(catalogName="cat", schemaName="schem"), ) - backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions()) + backend = ThriftBackend( + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + ) with self.assertRaises(InvalidServerResponseError) as cm: backend.open_session({}, "cat", "schem") self.assertIn( - "Setting initial namespace not supported by the DBR version", str(cm.exception) + "Setting initial namespace not supported by the DBR version", + str(cm.exception), ) @patch("databricks.sql.thrift_backend.TCLIService.Client", autospec=True) @@ -1686,10 +2167,18 @@ def test_execute_command_sets_complex_type_fields_correctly( complex_arg_types["_use_arrow_native_decimals"] = decimals thrift_backend = ThriftBackend( - "foobar", 443, "path", [], auth_provider=AuthProvider(), ssl_options=SSLOptions(), **complex_arg_types + "foobar", + 443, + "path", + [], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + **complex_arg_types, ) thrift_backend.execute_command(Mock(), Mock(), 100, 100, Mock(), Mock()) - t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[0][0] + t_execute_statement_req = tcli_service_instance.ExecuteStatement.call_args[ + 0 + ][0] # If the value is unset, the native type should default to True self.assertEqual( t_execute_statement_req.useArrowNativeTypes.timestampAsArrow, @@ -1703,7 +2192,9 @@ def test_execute_command_sets_complex_type_fields_correctly( t_execute_statement_req.useArrowNativeTypes.complexTypesAsArrow, complex_arg_types.get("_use_arrow_native_complex_types", True), ) - self.assertFalse(t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow) + self.assertFalse( + t_execute_statement_req.useArrowNativeTypes.intervalTypesAsArrow + ) if __name__ == "__main__": From 55105febf4c771cc9286bba81443d3c6eb6b09b4 Mon Sep 17 00:00:00 2001 From: Jacky Hu Date: Fri, 18 Oct 2024 10:35:58 -0700 Subject: [PATCH 06/13] Prepare release v3.5.0 (#457) Prepare release 3.5.0 Signed-off-by: Jacky Hu --- CHANGELOG.md | 5 +++++ pyproject.toml | 2 +- src/databricks/sql/__init__.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 14db0f8b..93391193 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Release History +# 3.5.0 (2024-10-18) + +- Create a non pyarrow flow to handle small results for the column set (databricks/databricks-sql-python#440 by @jprakash-db) +- Fix: On non-retryable error, ensure PySQL includes useful information in error (databricks/databricks-sql-python#447 by @shivam2680) + # 3.4.0 (2024-08-27) - Unpin pandas to support v2.2.2 (databricks/databricks-sql-python#416 by @kfollesdal) diff --git a/pyproject.toml b/pyproject.toml index 69d0c274..79b7b926 100644 --- a/pyproject.toml +++ 
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "databricks-sql-connector"
-version = "3.4.0"
+version = "3.5.0"
 description = "Databricks SQL Connector for Python"
 authors = ["Databricks "]
 license = "Apache-2.0"

diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py
index c0fdc2f1..8bbcc5f6 100644
--- a/src/databricks/sql/__init__.py
+++ b/src/databricks/sql/__init__.py
@@ -68,7 +68,7 @@ def __repr__(self):
 DATE = DBAPITypeObject("date")
 ROWID = DBAPITypeObject()

-__version__ = "3.4.0"
+__version__ = "3.5.0"
 USER_AGENT_NAME = "PyDatabricksSqlConnector"

 # These two functions are pyhive legacy

From d3cb62c4c4c0a7dc955a02504d36e506e48cf0e4 Mon Sep 17 00:00:00 2001
From: Jacky Hu
Date: Fri, 25 Oct 2024 09:41:40 -0700
Subject: [PATCH 07/13] [PECO-2051] Add custom auth headers into cloud fetch request (#460)

Signed-off-by: Jacky Hu
---
 src/databricks/sql/cloudfetch/downloader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 03c70054..228e07d6 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -100,6 +100,7 @@ def run(self) -> DownloadedFile:
             self.link.fileLink,
             timeout=self.settings.download_timeout,
             verify=self._ssl_options.tls_verify,
+            headers=self.link.httpHeaders
             # TODO: Pass cert from `self._ssl_options`
         )
         response.raise_for_status()

From ecdddba46a49c5122a5f8d5694b6b148466165d5 Mon Sep 17 00:00:00 2001
From: Jacky Hu
Date: Fri, 25 Oct 2024 11:47:49 -0700
Subject: [PATCH 08/13] Prepare release 3.6.0 (#461)

Signed-off-by: Jacky Hu
---
 CHANGELOG.md                   | 4 ++++
 pyproject.toml                 | 2 +-
 src/databricks/sql/__init__.py | 2 +-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93391193..e1a70f96 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Release History

+# 3.6.0 (2024-10-25)
+
+- Support encryption headers in the cloud fetch request (https://github.com/databricks/databricks-sql-python/pull/460 by @jackyhu-db)
+
 # 3.5.0 (2024-10-18)

 - Create a non pyarrow flow to handle small results for the column set (databricks/databricks-sql-python#440 by @jprakash-db)

diff --git a/pyproject.toml b/pyproject.toml
index 79b7b926..1747d21b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "databricks-sql-connector"
-version = "3.5.0"
+version = "3.6.0"
 description = "Databricks SQL Connector for Python"
 authors = ["Databricks "]
 license = "Apache-2.0"

diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py
index 8bbcc5f6..42167b00 100644
--- a/src/databricks/sql/__init__.py
+++ b/src/databricks/sql/__init__.py
@@ -68,7 +68,7 @@ def __repr__(self):
 DATE = DBAPITypeObject("date")
 ROWID = DBAPITypeObject()

-__version__ = "3.5.0"
+__version__ = "3.6.0"
 USER_AGENT_NAME = "PyDatabricksSqlConnector"

 # These two functions are pyhive legacy

From 43fa964c4bc679d73eaa21219f01705dfb021cd9 Mon Sep 17 00:00:00 2001
From: Jothi Prakash
Date: Wed, 20 Nov 2024 13:44:55 +0530
Subject: [PATCH 09/13] [PECO-1768] PySQL: adjust HTTP retry logic to align with Go and Node.js drivers (#467)

* Added the exponential backoff code
* Added the exponential backoff algorithm and refactored the code
* Added jitter and added unit tests
* Reformatted
* Fixed the test_retry_exponential_backoff integration test
---
 src/databricks/sql/auth/retry.py      | 26 +++++++-----
 src/databricks/sql/thrift_backend.py  |  4 +-
 tests/e2e/common/retry_test_mixins.py |  8 ++--
 tests/unit/test_retry.py              | 57 ++++++++++++++++++---------
 4 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py
index 0c6547cb..ec321bed 100755
--- a/src/databricks/sql/auth/retry.py
+++ b/src/databricks/sql/auth/retry.py
@@ -1,4 +1,5 @@
 import logging
+import random
 import time
 import typing
 from enum import Enum
@@ -285,25 +286,30 @@ def sleep_for_retry(self, response: BaseHTTPResponse) -> bool:
         """
         retry_after = self.get_retry_after(response)
         if retry_after:
-            backoff = self.get_backoff_time()
-            proposed_wait = max(backoff, retry_after)
-            self.check_proposed_wait(proposed_wait)
-            time.sleep(proposed_wait)
-            return True
+            proposed_wait = retry_after
+        else:
+            proposed_wait = self.get_backoff_time()

-        return False
+        proposed_wait = min(proposed_wait, self.delay_max)
+        self.check_proposed_wait(proposed_wait)
+        time.sleep(proposed_wait)
+        return True

     def get_backoff_time(self) -> float:
-        """Calls urllib3's built-in get_backoff_time.
+        """
+        Calculate the delay between retries using an exponential backoff algorithm.

         Never returns a value larger than self.delay_max.
         A MaxRetryDurationError will be raised if the calculated backoff would exceed self.max_attempts_duration.

-        Note: within urllib3, a backoff is only calculated in cases where a Retry-After header is not present
-        in the previous unsuccessful request and `self.respect_retry_after_header` is True (which is always true)
+        :return:
         """

-        proposed_backoff = super().get_backoff_time()
+        current_attempt = self.stop_after_attempts_count - self.total
+        proposed_backoff = (2**current_attempt) * self.delay_min
+        if self.backoff_jitter != 0.0:
+            proposed_backoff += random.random() * self.backoff_jitter
+
+        proposed_backoff = min(proposed_backoff, self.delay_max)
         self.check_proposed_wait(proposed_backoff)

diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index cf5cd906..29be5482 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -64,8 +64,8 @@
 # - 900s attempts-duration lines up w ODBC/JDBC drivers (for cluster startup > 10 mins)
 _retry_policy = {  # (type, default, min, max)
     "_retry_delay_min": (float, 1, 0.1, 60),
-    "_retry_delay_max": (float, 60, 5, 3600),
-    "_retry_stop_after_attempts_count": (int, 30, 1, 60),
+    "_retry_delay_max": (float, 30, 5, 3600),
+    "_retry_stop_after_attempts_count": (int, 5, 1, 60),
     "_retry_stop_after_attempts_duration": (float, 900, 1, 86400),
     "_retry_delay_default": (float, 5, 1, 60),
 }

diff --git a/tests/e2e/common/retry_test_mixins.py b/tests/e2e/common/retry_test_mixins.py
index 7dd5f745..942955ca 100755
--- a/tests/e2e/common/retry_test_mixins.py
+++ b/tests/e2e/common/retry_test_mixins.py
@@ -174,7 +174,7 @@ def test_retry_max_count_not_exceeded(self):
     def test_retry_exponential_backoff(self):
         """GIVEN the retry policy is configured for reasonable exponential backoff
         WHEN the server sends nothing but 429 responses with retry-afters
-        THEN the connector will use those retry-afters as a floor
+        THEN the connector will use those retry-after values as the delay
         """
         retry_policy = self._retry_policy.copy()
         retry_policy["_retry_delay_min"] = 1
@@ -191,10 +191,10 @@ def test_retry_exponential_backoff(self):
         assert isinstance(cm.value.args[1], MaxRetryDurationError)

         # With setting delay_min to 1, the expected retry delays should be:
-        # 3, 3, 4
-        # The first 2 retries are allowed, the 3rd retry puts the total duration over the limit
+        # 3, 3, 3, 3
+        # The first 3 retries are allowed, the 4th retry puts the total duration over the limit
         # of 10 seconds
-        assert mock_obj.return_value.getresponse.call_count == 3
+        assert mock_obj.return_value.getresponse.call_count == 4
         assert duration > 6
         # Should be less than 7, but this is a safe margin for CI/CD slowness

diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py
index 2108af4f..b7648ffb 100644
--- a/tests/unit/test_retry.py
+++ b/tests/unit/test_retry.py
@@ -1,11 +1,9 @@
-from os import error
 import time
-from unittest.mock import Mock, patch
+from unittest.mock import patch, call
 import pytest
-from requests import Request
 from urllib3 import HTTPResponse
-from databricks.sql.auth.retry import DatabricksRetryPolicy, RequestHistory
-
+from databricks.sql.auth.retry import DatabricksRetryPolicy, RequestHistory, CommandType
+from urllib3.exceptions import MaxRetryError

 class TestRetry:
     @pytest.fixture()
     def retry_policy(self) -> DatabricksRetryPolicy:
@@ -25,32 +23,55 @@ def error_history(self) -> RequestHistory:
             method="POST", url=None, error=None, status=503, redirect_location=None
         )

+    def calculate_backoff_time(self, attempt, delay_min, delay_max):
+        exponential_backoff_time = (2**attempt) * delay_min
+        return min(exponential_backoff_time, delay_max)
+
     @patch("time.sleep")
     def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history):
         retry_policy._retry_start_time = time.time()
         retry_policy.history = [error_history, error_history]
         retry_policy.sleep(HTTPResponse(status=503))
-        t_mock.assert_called_with(2)
+
+        expected_backoff_time = self.calculate_backoff_time(0, retry_policy.delay_min, retry_policy.delay_max)
+        t_mock.assert_called_with(expected_backoff_time)

     @patch("time.sleep")
-    def test_sleep__retry_after_is_binding(self, t_mock, retry_policy, error_history):
+    def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_policy):
+        num_attempts = retry_policy.stop_after_attempts_count
+
         retry_policy._retry_start_time = time.time()
-        retry_policy.history = [error_history, error_history]
-        retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "3"}))
-        t_mock.assert_called_with(3)
+        retry_policy.command_type = CommandType.OTHER
+
+        for attempt in range(num_attempts):
+            retry_policy.sleep(HTTPResponse(status=503))
+            # Internally urllib3 calls the increment function, generating a new instance for every retry
+            retry_policy = retry_policy.increment()
+
+        expected_backoff_times = []
+        for attempt in range(num_attempts):
+            expected_backoff_times.append(self.calculate_backoff_time(attempt, retry_policy.delay_min, retry_policy.delay_max))
+
+        # Assert that sleep was called with the expected values in the expected order
+        t_mock.assert_has_calls([call(expected_time) for expected_time in expected_backoff_times])

     @patch("time.sleep")
-    def test_sleep__retry_after_present_but_not_binding(
-        self, t_mock, retry_policy, error_history
-    ):
+    def test_excessive_retry_attempts_error(self, t_mock, retry_policy):
+        # Attempting more than stop_after_attempts_count
+        num_attempts = retry_policy.stop_after_attempts_count + 1
+
         retry_policy._retry_start_time = time.time()
-        retry_policy.history = [error_history, error_history]
-        retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "1"}))
-        t_mock.assert_called_with(2)
+        retry_policy.command_type = CommandType.OTHER
+
+        with pytest.raises(MaxRetryError):
+            for attempt in range(num_attempts):
+                retry_policy.sleep(HTTPResponse(status=503))
+                # Internally urllib3 calls the increment function, generating a new instance for every retry
+                retry_policy = retry_policy.increment()

     @patch("time.sleep")
-    def test_sleep__retry_after_surpassed(self, t_mock, retry_policy, error_history):
+    def test_sleep__retry_after_present(self, t_mock, retry_policy, error_history):
         retry_policy._retry_start_time = time.time()
         retry_policy.history = [error_history, error_history, error_history]
         retry_policy.sleep(HTTPResponse(status=503, headers={"Retry-After": "3"}))
-        t_mock.assert_called_with(4)
+        t_mock.assert_called_with(3)

From 328aeb528f714cb22055e923d27c6d8cf1a9e0fd Mon Sep 17 00:00:00 2001
From: Jothi Prakash
Date: Tue, 26 Nov 2024 21:31:32 +0530
Subject: [PATCH 10/13] [PECO-2065] Create the async execution flow for the PySQL Connector (#463)

* Built the basic flow for the async pipeline - testing is remaining
* Implemented the flow for the get_execution_result, but the problem of invalid operation handle still persists
* Missed adding some files in previous commit
* Working prototype of execute_async, get_query_state and get_execution_result
* Added integration tests for execute_async
* Added docs for functions
* Refactored the async code
* Fixed java doc
* Reformatted
---
 src/databricks/sql/client.py         | 105 +++++++++++++++++++++++++++
 src/databricks/sql/thrift_backend.py |  76 ++++++++++++++++++-
 tests/e2e/test_driver.py             |  23 ++++++
 3 files changed, 203 insertions(+), 1 deletion(-)

diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py
index 4e0ab941..8ea81e12 100755
--- a/src/databricks/sql/client.py
+++ b/src/databricks/sql/client.py
@@ -1,3 +1,4 @@
+import time
 from typing import Dict, Tuple, List, Optional, Any, Union, Sequence

 import pandas
@@ -47,6 +48,7 @@
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
+    TOperationState,
 )

@@ -430,6 +432,8 @@ def __init__(
         self.escaper = ParamEscaper()
         self.lastrowid = None

+        self.ASYNC_DEFAULT_POLLING_INTERVAL = 2
+
     # The ideal return type for this method is perhaps Self, but that was not added until 3.11, and we support pre-3.11 pythons, currently.
     def __enter__(self) -> "Cursor":
         return self
@@ -796,6 +800,7 @@ def execute(
             cursor=self,
             use_cloud_fetch=self.connection.use_cloud_fetch,
             parameters=prepared_params,
+            async_op=False,
         )
         self.active_result_set = ResultSet(
             self.connection,
@@ -812,6 +817,106 @@ def execute(

         return self

+    def execute_async(
+        self,
+        operation: str,
+        parameters: Optional[TParameterCollection] = None,
+    ) -> "Cursor":
+        """
+        Execute a query asynchronously: submit the statement and return immediately,
+        without waiting for the query to complete.
+
+        :param operation: the SQL statement or command to execute
+        :param parameters: optional parameters to bind into the operation
+        :return: self
+        """
+        param_approach = self._determine_parameter_approach(parameters)
+        if param_approach == ParameterApproach.NONE:
+            prepared_params = NO_NATIVE_PARAMS
+            prepared_operation = operation
+
+        elif param_approach == ParameterApproach.INLINE:
+            prepared_operation, prepared_params = self._prepare_inline_parameters(
+                operation, parameters
+            )
+        elif param_approach == ParameterApproach.NATIVE:
+            normalized_parameters = self._normalize_tparametercollection(parameters)
+            param_structure = self._determine_parameter_structure(normalized_parameters)
+            transformed_operation = transform_paramstyle(
+                operation, normalized_parameters, param_structure
+            )
+            prepared_operation, prepared_params = self._prepare_native_parameters(
+                transformed_operation, normalized_parameters, param_structure
+            )
+
+        self._check_not_closed()
+        self._close_and_clear_active_result_set()
+        self.thrift_backend.execute_command(
+            operation=prepared_operation,
+            session_handle=self.connection._session_handle,
+            max_rows=self.arraysize,
+            max_bytes=self.buffer_size_bytes,
+            lz4_compression=self.connection.lz4_compression,
+            cursor=self,
+            use_cloud_fetch=self.connection.use_cloud_fetch,
+            parameters=prepared_params,
+            async_op=True,
+        )
+
+        return self
+
+    def get_query_state(self) -> "TOperationState":
+        """
+        Poll the server for the current state of the asynchronously executing query.
+
+        :return: the operation state of the running query
+        """
+        self._check_not_closed()
+        return self.thrift_backend.get_query_state(self.active_op_handle)
+
+    def get_async_execution_result(self):
+        """
+        Poll the status of the asynchronously executing query until it is no longer
+        pending or running, then fetch the result if the query finished successfully.
+
+        :return: self, with the result set attached
+        """
+        self._check_not_closed()
+
+        def is_executing(operation_state) -> "bool":
+            return not operation_state or operation_state in [
+                ttypes.TOperationState.RUNNING_STATE,
+                ttypes.TOperationState.PENDING_STATE,
+            ]
+
+        while is_executing(self.get_query_state()):
+            # Poll again after the default polling interval
+            time.sleep(self.ASYNC_DEFAULT_POLLING_INTERVAL)
+
+        operation_state = self.get_query_state()
+        if operation_state == ttypes.TOperationState.FINISHED_STATE:
+            execute_response = self.thrift_backend.get_execution_result(
+                self.active_op_handle, self
+            )
+            self.active_result_set = ResultSet(
+                self.connection,
+                execute_response,
+                self.thrift_backend,
+                self.buffer_size_bytes,
+                self.arraysize,
+            )
+
+            if execute_response.is_staging_operation:
+                self._handle_staging_operation(
+                    staging_allowed_local_path=self.thrift_backend.staging_allowed_local_path
+                )
+
+            return self
+        else:
+            raise Error(
+                f"get_execution_result failed with Operation status {operation_state}"
+            )
+
     def executemany(self, operation, seq_of_parameters):
         """
         Execute the operation once for every set of passed in parameters.
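A minimal sketch of how the new execute_async / get_query_state / get_async_execution_result API above can be driven from caller code. The connect() arguments are placeholders, and the polling loop mirrors the integration test added at the end of this patch:

    import time
    import databricks.sql
    from databricks.sql.thrift_api.TCLIService import ttypes

    with databricks.sql.connect(
        server_hostname="...", http_path="...", access_token="..."
    ) as connection:
        cursor = connection.cursor()
        cursor.execute_async("SELECT COUNT(*) FROM RANGE(10000)")

        # Keep polling while the operation is still pending or running
        while cursor.get_query_state() in (
            ttypes.TOperationState.PENDING_STATE,
            ttypes.TOperationState.RUNNING_STATE,
        ):
            time.sleep(2)  # matches ASYNC_DEFAULT_POLLING_INTERVAL above

        cursor.get_async_execution_result()  # attaches the result set once FINISHED
        print(cursor.fetchall())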
diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py
index 29be5482..8c212c55 100644
--- a/src/databricks/sql/thrift_backend.py
+++ b/src/databricks/sql/thrift_backend.py
@@ -7,6 +7,8 @@
 import threading
 from typing import List, Union

+from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
+
 try:
     import pyarrow
 except ImportError:
@@ -769,6 +771,63 @@ def _results_message_to_execute_response(self, resp, operation_state):
             arrow_schema_bytes=schema_bytes,
         )

+    def get_execution_result(self, op_handle, cursor):
+
+        assert op_handle is not None
+
+        req = ttypes.TFetchResultsReq(
+            operationHandle=ttypes.TOperationHandle(
+                op_handle.operationId,
+                op_handle.operationType,
+                False,
+                op_handle.modifiedRowCount,
+            ),
+            maxRows=cursor.arraysize,
+            maxBytes=cursor.buffer_size_bytes,
+            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
+            includeResultSetMetadata=True,
+        )
+
+        resp = self.make_request(self._client.FetchResults, req)
+
+        t_result_set_metadata_resp = resp.resultSetMetadata
+
+        lz4_compressed = t_result_set_metadata_resp.lz4Compressed
+        is_staging_operation = t_result_set_metadata_resp.isStagingOperation
+        has_more_rows = resp.hasMoreRows
+        description = self._hive_schema_to_description(
+            t_result_set_metadata_resp.schema
+        )
+
+        schema_bytes = (
+            t_result_set_metadata_resp.arrowSchema
+            or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+            .serialize()
+            .to_pybytes()
+        )
+
+        queue = ResultSetQueueFactory.build_queue(
+            row_set_type=resp.resultSetMetadata.resultFormat,
+            t_row_set=resp.results,
+            arrow_schema_bytes=schema_bytes,
+            max_download_threads=self.max_download_threads,
+            lz4_compressed=lz4_compressed,
+            description=description,
+            ssl_options=self._ssl_options,
+        )
+
+        return ExecuteResponse(
+            arrow_queue=queue,
+            status=resp.status,
+            has_been_closed_server_side=False,
+            has_more_rows=has_more_rows,
+            lz4_compressed=lz4_compressed,
+            is_staging_operation=is_staging_operation,
+            command_handle=op_handle,
+            description=description,
+            arrow_schema_bytes=schema_bytes,
+        )
+
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
         if initial_operation_status_resp:
             self._check_command_not_in_error_or_closed_state(
@@ -787,6 +846,12 @@ def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
             self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
         return operation_state

+    def get_query_state(self, op_handle) -> "TOperationState":
+        poll_resp = self._poll_for_status(op_handle)
+        operation_state = poll_resp.operationState
+        self._check_command_not_in_error_or_closed_state(op_handle, poll_resp)
+        return operation_state
+
     @staticmethod
     def _check_direct_results_for_error(t_spark_direct_results):
         if t_spark_direct_results:
@@ -817,6 +882,7 @@ def execute_command(
         cursor,
         use_cloud_fetch=True,
         parameters=[],
+        async_op=False,
     ):
         assert session_handle is not None
@@ -846,7 +912,11 @@ def execute_command(
             parameters=parameters,
         )
         resp = self.make_request(self._client.ExecuteStatement, req)
-        return self._handle_execute_response(resp, cursor)
+
+        if async_op:
+            self._handle_execute_response_async(resp, cursor)
+        else:
+            return self._handle_execute_response(resp, cursor)

     def get_catalogs(self, session_handle, max_rows, max_bytes, cursor):
         assert session_handle is not None
@@ -945,6 +1015,10 @@ def _handle_execute_response(self, resp, cursor):

         return self._results_message_to_execute_response(resp, final_operation_state)

+    def _handle_execute_response_async(self, resp, cursor):
+        cursor.active_op_handle = resp.operationHandle
+        self._check_direct_results_for_error(resp.directResults)
+
     def fetch_results(
         self,
         op_handle,

diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py
index cfd1e969..2f0881cd 100644
--- a/tests/e2e/test_driver.py
+++ b/tests/e2e/test_driver.py
@@ -36,6 +36,7 @@
     compare_dbr_versions,
     is_thrift_v5_plus,
 )
+from databricks.sql.thrift_api.TCLIService import ttypes
 from tests.e2e.common.core_tests import CoreTestMixin, SmokeTestMixin
 from tests.e2e.common.large_queries_mixin import LargeQueriesMixin
 from tests.e2e.common.timestamp_tests import TimestampTestsMixin
@@ -78,6 +79,7 @@ class PySQLPytestTestCase:
     }
     arraysize = 1000
     buffer_size_bytes = 104857600
+    POLLING_INTERVAL = 2

     @pytest.fixture(autouse=True)
     def get_details(self, connection_details):
@@ -175,6 +177,27 @@ def test_cloud_fetch(self):
         for i in range(len(cf_result)):
             assert cf_result[i] == noop_result[i]

+    def test_execute_async(self):
+        def isExecuting(operation_state):
+            return not operation_state or operation_state in [
+                ttypes.TOperationState.RUNNING_STATE,
+                ttypes.TOperationState.PENDING_STATE,
+            ]
+
+        long_running_query = "SELECT COUNT(*) FROM RANGE(10000 * 16) x JOIN RANGE(10000) y ON FROM_UNIXTIME(x.id * y.id, 'yyyy-MM-dd') LIKE '%not%a%date%'"
+        with self.cursor() as cursor:
+            cursor.execute_async(long_running_query)
+
+            # Poll after every POLLING_INTERVAL seconds
+            while isExecuting(cursor.get_query_state()):
+                time.sleep(self.POLLING_INTERVAL)
+                log.info("Polling the status in test_execute_async")
+
+            cursor.get_async_execution_result()
+            result = cursor.fetchall()
+
+            assert result[0].asDict() == {"count(1)": 0}
+
 # Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
 # tests

From 980af886677edcb3fbfd1fd482a187e780d318e9 Mon Sep 17 00:00:00 2001
From: Jothi Prakash
Date: Tue, 26 Nov 2024 22:53:06 +0530
Subject: [PATCH 11/13] Fix for check_types GitHub action failing (#472)

Fixed the failing check_types GitHub action
---
 src/databricks/sql/auth/retry.py |  2 +-
 tests/unit/test_retry.py         | 15 ++++++++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py
index ec321bed..0243d0aa 100755
--- a/src/databricks/sql/auth/retry.py
+++ b/src/databricks/sql/auth/retry.py
@@ -305,7 +305,7 @@ def get_backoff_time(self) -> float:
         :return:
         """

-        current_attempt = self.stop_after_attempts_count - self.total
+        current_attempt = self.stop_after_attempts_count - int(self.total or 0)
         proposed_backoff = (2**current_attempt) * self.delay_min
         if self.backoff_jitter != 0.0:
             proposed_backoff += random.random() * self.backoff_jitter

diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py
index b7648ffb..1e18e1f4 100644
--- a/tests/unit/test_retry.py
+++ b/tests/unit/test_retry.py
@@ -5,6 +5,7 @@
 from databricks.sql.auth.retry import DatabricksRetryPolicy, RequestHistory, CommandType
 from urllib3.exceptions import MaxRetryError

+
 class TestRetry:
     @pytest.fixture()
     def retry_policy(self) -> DatabricksRetryPolicy:
@@ -33,7 +34,9 @@ def test_sleep__no_retry_after(self, t_mock, retry_policy, error_history):
         retry_policy.history = [error_history, error_history]
         retry_policy.sleep(HTTPResponse(status=503))

-        expected_backoff_time = self.calculate_backoff_time(0, retry_policy.delay_min, retry_policy.delay_max)
+        expected_backoff_time = self.calculate_backoff_time(
+            0, retry_policy.delay_min, retry_policy.delay_max
+        )
         t_mock.assert_called_with(expected_backoff_time)

     @patch("time.sleep")
@@ -50,10 +53,16 @@ def test_sleep__no_retry_after_header__multiple_retries(self, t_mock, retry_poli
         expected_backoff_times = []
         for attempt in range(num_attempts):
-            expected_backoff_times.append(self.calculate_backoff_time(attempt, retry_policy.delay_min, retry_policy.delay_max))
+            expected_backoff_times.append(
+                self.calculate_backoff_time(
+                    attempt, retry_policy.delay_min, retry_policy.delay_max
+                )
+            )

         # Assert that sleep was called with the expected values in the expected order
-        t_mock.assert_has_calls([call(expected_time) for expected_time in expected_backoff_times])
+        t_mock.assert_has_calls(
+            [call(expected_time) for expected_time in expected_backoff_times]
+        )

     @patch("time.sleep")
     def test_excessive_retry_attempts_error(self, t_mock, retry_policy):

From d6905164f5df4591f5c311ca39c16ef3b230624d Mon Sep 17 00:00:00 2001
From: arredond
Date: Thu, 5 Dec 2024 18:44:15 +0100
Subject: [PATCH 12/13] Remove upper caps on dependencies (#452)

* Remove upper caps on numpy and pyarrow versions
---
 poetry.lock    | 2 +-
 pyproject.toml | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 9fe49690..576adbd3 100755
--- a/poetry.lock
+++ b/poetry.lock
@@ -1202,4 +1202,4 @@ sqlalchemy = ["sqlalchemy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.8.0"
-content-hash = "31066a85f646d0009d6fe9ffc833a64fcb4b6923c2e7f2652e7aa8540acba298"
+content-hash = "9d8a91369fc79f9ca9f7502e2ed284b66531c954ae59a723e465a76073966998"

diff --git a/pyproject.toml b/pyproject.toml
index 1747d21b..23ffed58 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,14 +14,14 @@ thrift = ">=0.16.0,<0.21.0"
 pandas = [
     { version = ">=1.2.5,<2.3.0", python = ">=3.8" }
 ]
-pyarrow = ">=14.0.1,<17"
+pyarrow = ">=14.0.1"
 lz4 = "^4.0.2"
 requests = "^2.18.1"
 oauthlib = "^3.1.0"
 numpy = [
-    { version = "^1.16.6", python = ">=3.8,<3.11" },
-    { version = "^1.23.4", python = ">=3.11" },
+    { version = ">=1.16.6", python = ">=3.8,<3.11" },
+    { version = ">=1.23.4", python = ">=3.11" },
 ]
 sqlalchemy = { version = ">=2.0.21", optional = true }
 openpyxl = "^3.0.10"

From 680b3b6df55518bb22939a9c1f052759c1b680df Mon Sep 17 00:00:00 2001
From: Jothi Prakash
Date: Fri, 6 Dec 2024 10:15:23 +0530
Subject: [PATCH 13/13] Updated the doc to specify that native parameters are not supported in the PUT operation from the >=3.x connector (#477)

Added doc update
---
 docs/parameters.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/parameters.md b/docs/parameters.md
index a538af1a..f9f4c5ff 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -17,6 +17,7 @@ See `examples/parameters.py` in this repository for a runnable demo.

 - A query executed with native parameters can contain at most 255 parameter markers
 - The maximum size of all parameterized values cannot exceed 1MB
+- For volume operations such as PUT, native parameters are not supported

 ## SQL Syntax
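Taken together, the retry changes in PATCH 09 and PATCH 11 reduce to a single delay rule: delay = min((2 ** attempt) * delay_min + jitter, delay_max). A standalone sketch of that rule follows, with illustrative constants; the connector's new defaults are delay_min=1 s and delay_max=30 s, and the real policy derives `attempt` from stop_after_attempts_count minus urllib3's remaining-retries counter rather than counting up from zero:

    import random

    def backoff_delay(attempt: int, delay_min: float = 1.0,
                      delay_max: float = 30.0, backoff_jitter: float = 0.0) -> float:
        # Exponential backoff capped at delay_max, with optional additive jitter,
        # mirroring DatabricksRetryPolicy.get_backoff_time() above
        proposed = (2 ** attempt) * delay_min
        if backoff_jitter != 0.0:
            proposed += random.random() * backoff_jitter
        return min(proposed, delay_max)

    print([backoff_delay(n) for n in range(6)])
    # [1.0, 2.0, 4.0, 8.0, 16.0, 30.0] -- the 32 s value is capped at delay_max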