diff --git a/examples/experimental/tests/test_sea_sync_query.py b/examples/experimental/tests/test_sea_sync_query.py index 16ee80a78..05f0f314e 100644 --- a/examples/experimental/tests/test_sea_sync_query.py +++ b/examples/experimental/tests/test_sea_sync_query.py @@ -4,9 +4,10 @@ import os import sys import logging +import time from databricks.sql.client import Connection -logging.basicConfig(level=logging.INFO) +logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -51,13 +52,10 @@ def test_sea_sync_query_with_cloud_fetch(): ) # Execute a query that generates large rows to force multiple chunks - requested_row_count = 10000 + requested_row_count = 100000000 cursor = connection.cursor() query = f""" - SELECT - id, - concat('value_', repeat('a', 10000)) as test_value - FROM range(1, {requested_row_count} + 1) AS t(id) + SELECT * FROM samples.tpch.lineitem LIMIT {requested_row_count} """ logger.info( @@ -65,6 +63,8 @@ def test_sea_sync_query_with_cloud_fetch(): ) cursor.execute(query) results = [cursor.fetchone()] + logger.info("SLEEPING FOR 1000 SECONDS TO EXPIRE LINKS") + time.sleep(1000) results.extend(cursor.fetchmany(10)) results.extend(cursor.fetchall()) actual_row_count = len(results) diff --git a/src/databricks/sql/backend/sea/queue.py b/src/databricks/sql/backend/sea/queue.py index 12ed38ee7..dd8e7b5fe 100644 --- a/src/databricks/sql/backend/sea/queue.py +++ b/src/databricks/sql/backend/sea/queue.py @@ -189,9 +189,11 @@ def _add_links(self, links: List[ExternalLink]): len(links), ", ".join(str(l.chunk_index) for l in links) if links else "", ) - for link in links: - self.chunk_index_to_link[link.chunk_index] = link - self.download_manager.add_link(LinkFetcher._convert_to_thrift_link(link)) + + self.chunk_index_to_link.update({link.chunk_index: link for link in links}) + self.download_manager.add_links( + [LinkFetcher._convert_to_thrift_link(link) for link in links] + ) def _get_next_chunk_index(self) -> Optional[int]: """Return the next *chunk_index* that should be requested from the backend, or ``None`` if we have them all.""" @@ -281,9 +283,27 @@ def _worker_loop(self): with self._link_data_update: self._link_data_update.notify_all() + def _restart_from_expired_link(self, link: TSparkArrowResultLink): + """Restart the link fetcher from the expired link.""" + self.stop() + + with self._link_data_update: + self.download_manager.cancel_tasks_from_offset(link.startRowOffset) + + chunks_to_restart = [] + for chunk_index, l in self.chunk_index_to_link.items(): + if l.row_offset < link.startRowOffset: + continue + chunks_to_restart.append(chunk_index) + for chunk_index in chunks_to_restart: + self.chunk_index_to_link.pop(chunk_index) + + self.start() + def start(self): """Spawn the worker thread.""" logger.debug("LinkFetcher[%s]: starting worker thread", self._statement_id) + self._shutdown_event.clear() self._worker_thread = threading.Thread( target=self._worker_loop, name=f"LinkFetcher-{self._statement_id}" ) @@ -333,6 +353,7 @@ def __init__( schema_bytes=None, lz4_compressed=lz4_compressed, description=description, + expiry_callback=self._expiry_callback, # TODO: fix these arguments when telemetry is implemented in SEA session_id_hex=None, chunk_id=0, @@ -363,6 +384,14 @@ def __init__( # Initialize table and position self.table = self._create_next_table() + def _expiry_callback(self, link: TSparkArrowResultLink): + logger.info( + f"SeaCloudFetchQueue: Link expired, restarting from offset {link.startRowOffset}" + ) + if not self.link_fetcher: + return + 
self.link_fetcher._restart_from_expired_link(link) + def _create_next_table(self) -> "pyarrow.Table": """Create next table by retrieving the logical next downloaded file.""" if self.link_fetcher is None: diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index e187771f7..a0eb73a1b 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -2,7 +2,7 @@ from concurrent.futures import ThreadPoolExecutor, Future import threading -from typing import List, Union, Tuple, Optional +from typing import Callable, List, Optional, Union, Generic, TypeVar, Tuple, Optional from databricks.sql.cloudfetch.downloader import ( ResultSetDownloadHandler, @@ -16,6 +16,27 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") + + +class TaskWithMetadata(Generic[T]): + """ + Wrapper around Future that stores additional metadata (the link). + Provides type-safe access to both the Future result and the associated link. + """ + + def __init__(self, future: Future[T], link: TSparkArrowResultLink): + self.future = future + self.link = link + + def result(self, timeout: Optional[float] = None) -> T: + """Get the result of the Future, blocking if necessary.""" + return self.future.result(timeout) + + def cancel(self) -> bool: + """Cancel the Future if possible.""" + return self.future.cancel() + class ResultFileDownloadManager: def __init__( @@ -27,6 +48,7 @@ def __init__( session_id_hex: Optional[str], statement_id: str, chunk_id: int, + expiry_callback: Callable[[TSparkArrowResultLink], None], ): self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = [] self.chunk_id = chunk_id @@ -44,11 +66,12 @@ def __init__( self._max_download_threads: int = max_download_threads self._download_condition = threading.Condition() - self._download_tasks: List[Future[DownloadedFile]] = [] + self._download_tasks: List[TaskWithMetadata[DownloadedFile]] = [] self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads) self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed) self._ssl_options = ssl_options + self._expiry_callback = expiry_callback self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -89,6 +112,41 @@ def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile: return file + def cancel_tasks_from_offset(self, start_row_offset: int): + """ + Cancel all download tasks starting from a specific row offset. + This is used when links expire and we need to restart from a certain point. 
+
+        Args:
+            start_row_offset (int): Row offset from which to cancel tasks
+        """
+
+        def to_cancel(link: TSparkArrowResultLink) -> bool:
+            return link.startRowOffset >= start_row_offset
+
+        tasks_to_cancel = [
+            task for task in self._download_tasks if to_cancel(task.link)
+        ]
+        for task in tasks_to_cancel:
+            task.cancel()
+        logger.info(
+            f"ResultFileDownloadManager: cancelled {len(tasks_to_cancel)} tasks from offset {start_row_offset}"
+        )
+
+        # Remove cancelled tasks from the download queue
+        tasks_to_keep = [
+            task for task in self._download_tasks if not to_cancel(task.link)
+        ]
+        self._download_tasks = tasks_to_keep
+
+        pending_links_to_keep = [
+            link for link in self._pending_links if not to_cancel(link[1])
+        ]
+        logger.info(
+            f"ResultFileDownloadManager: removed {len(self._pending_links) - len(pending_links_to_keep)} links from pending links"
+        )
+        self._pending_links = pending_links_to_keep
+
     def _schedule_downloads(self):
         """
         While download queue has a capacity, peek pending links and submit them to thread pool.
@@ -107,34 +165,35 @@ def _schedule_downloads(self):
             settings=self._downloadable_result_settings,
             link=link,
             ssl_options=self._ssl_options,
+            expiry_callback=self._expiry_callback,
             chunk_id=chunk_id,
             session_id_hex=self.session_id_hex,
             statement_id=self.statement_id,
         )
-        task = self._thread_pool.submit(handler.run)
+        future = self._thread_pool.submit(handler.run)
+        task = TaskWithMetadata(future, link)
         self._download_tasks.append(task)
 
         with self._download_condition:
             self._download_condition.notify_all()
 
-    def add_link(self, link: TSparkArrowResultLink):
+    def add_links(self, links: List[TSparkArrowResultLink]):
         """
         Add more links to the download manager.
 
         Args:
-            link: Link to add
+            links (List[TSparkArrowResultLink]): The links to add to the download manager.
""" - - if link.rowCount <= 0: - return - - logger.debug( - "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( - link.startRowOffset, link.rowCount + for link in links: + if link.rowCount <= 0: + continue + logger.debug( + "ResultFileDownloadManager: adding file link, start offset {}, row count: {}".format( + link.startRowOffset, link.rowCount + ) ) - ) - self._pending_links.append((self.chunk_id, link)) - self.chunk_id += 1 + self._pending_links.append((self.chunk_id, link)) + self.chunk_id += 1 self._schedule_downloads() diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py index e19a69046..a857f7ae5 100644 --- a/src/databricks/sql/cloudfetch/downloader.py +++ b/src/databricks/sql/cloudfetch/downloader.py @@ -1,5 +1,6 @@ import logging from dataclasses import dataclass +from typing import Callable from typing import Optional import requests @@ -69,6 +70,7 @@ def __init__( settings: DownloadableResultSettings, link: TSparkArrowResultLink, ssl_options: SSLOptions, + expiry_callback: Callable[[TSparkArrowResultLink], None], chunk_id: int, session_id_hex: Optional[str], statement_id: str, @@ -76,6 +78,7 @@ def __init__( self.settings = settings self.link = link self._ssl_options = ssl_options + self._expiry_callback = expiry_callback self.chunk_id = chunk_id self.session_id_hex = session_id_hex self.statement_id = statement_id @@ -96,9 +99,7 @@ def run(self) -> DownloadedFile: ) # Check if link is already expired or is expiring - ResultSetDownloadHandler._validate_link( - self.link, self.settings.link_expiry_buffer_secs - ) + self._validate_link(self.link, self.settings.link_expiry_buffer_secs) session = requests.Session() session.mount("http://", HTTPAdapter(max_retries=retryPolicy)) @@ -146,8 +147,7 @@ def run(self) -> DownloadedFile: if session: session.close() - @staticmethod - def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): + def _validate_link(self, link: TSparkArrowResultLink, expiry_buffer_secs: int): """ Check if a link has expired or will expire. @@ -159,7 +159,7 @@ def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int): link.expiryTime <= current_time or link.expiryTime - current_time <= expiry_buffer_secs ): - raise Error("CloudFetch link has expired") + self._expiry_callback(link) @staticmethod def _decompress_data(compressed_data: bytes) -> bytes: diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 098638158..9db3055b0 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from dateutil import parser import datetime @@ -14,6 +14,8 @@ import lz4.frame +from databricks.sql.exc import Error + try: import pyarrow except ImportError: @@ -227,6 +229,7 @@ def __init__( schema_bytes: Optional[bytes] = None, lz4_compressed: bool = True, description: List[Tuple] = [], + expiry_callback: Callable[[TSparkArrowResultLink], None] = lambda _: None, ): """ Initialize the base CloudFetchQueue. 
@@ -261,6 +264,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + expiry_callback=expiry_callback, ) def next_n_rows(self, num_rows: int) -> "pyarrow.Table": @@ -381,6 +385,7 @@ def __init__( session_id_hex=session_id_hex, statement_id=statement_id, chunk_id=chunk_id, + expiry_callback=self._expiry_callback, ) self.start_row_index = start_row_offset @@ -404,11 +409,14 @@ def __init__( result_link.startRowOffset, result_link.rowCount ) ) - self.download_manager.add_link(result_link) + self.download_manager.add_links(self.result_links) # Initialize table and position self.table = self._create_next_table() + def _expiry_callback(self, link: TSparkArrowResultLink): + raise Error("Cloudfetch link has expired") + def _create_next_table(self) -> "pyarrow.Table": if self.num_links_downloaded >= len(self.result_links): return self._create_empty_table() diff --git a/tests/unit/test_download_manager.py b/tests/unit/test_download_manager.py index 6eb17a05a..593fe0e5d 100644 --- a/tests/unit/test_download_manager.py +++ b/tests/unit/test_download_manager.py @@ -14,11 +14,15 @@ class DownloadManagerTests(unittest.TestCase): def create_download_manager( self, links, max_download_threads=10, lz4_compressed=True ): + def expiry_callback(link: TSparkArrowResultLink): + return None + return download_manager.ResultFileDownloadManager( links, max_download_threads, lz4_compressed, ssl_options=SSLOptions(), + expiry_callback=expiry_callback, session_id_hex=Mock(), statement_id=Mock(), chunk_id=0, diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py index 687b7db7f..2be700820 100644 --- a/tests/unit/test_downloader.py +++ b/tests/unit/test_downloader.py @@ -5,6 +5,7 @@ import databricks.sql.cloudfetch.downloader as downloader from databricks.sql.exc import Error +from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink from databricks.sql.types import SSLOptions @@ -20,21 +21,30 @@ class DownloaderTests(unittest.TestCase): Unit tests for checking downloader logic. 
""" - @patch("time.time", return_value=1000) - def test_run_link_expired(self, mock_time): - settings = Mock() - result_link = Mock() - # Already expired - result_link.expiryTime = 999 - d = downloader.ResultSetDownloadHandler( + def create_download_handler( + self, settings: Mock, result_link: Mock + ) -> downloader.ResultSetDownloadHandler: + def expiry_callback(link: TSparkArrowResultLink): + raise Error("Cloudfetch link has expired") + + return downloader.ResultSetDownloadHandler( settings, result_link, ssl_options=SSLOptions(), + expiry_callback=expiry_callback, chunk_id=0, session_id_hex=Mock(), statement_id=Mock(), ) + @patch("time.time", return_value=1000) + def test_run_link_expired(self, mock_time): + settings = Mock() + result_link = Mock() + # Already expired + result_link.expiryTime = 999 + d = self.create_download_handler(settings, result_link) + with self.assertRaises(Error) as context: d.run() self.assertTrue("link has expired" in context.exception.message) @@ -47,14 +57,7 @@ def test_run_link_past_expiry_buffer(self, mock_time): result_link = Mock() # Within the expiry buffer time result_link.expiryTime = 1004 - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) with self.assertRaises(Error) as context: d.run() @@ -72,14 +75,7 @@ def test_run_get_response_not_ok(self, mock_time, mock_session): settings.use_proxy = False result_link = Mock(expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) with self.assertRaises(requests.exceptions.HTTPError) as context: d.run() self.assertTrue("404" in str(context.exception)) @@ -96,14 +92,7 @@ def test_run_uncompressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = False result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -124,14 +113,7 @@ def test_run_compressed_successful(self, mock_time, mock_session): settings.is_lz4_compressed = True result_link = Mock(bytesNum=100, expiryTime=1001) - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) file = d.run() assert file.file_bytes == b"1234567890" * 10 @@ -145,14 +127,7 @@ def test_download_connection_error(self, mock_time, mock_session): result_link = Mock(bytesNum=100, expiryTime=1001) mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) with self.assertRaises(ConnectionError): d.run() @@ -165,13 +140,6 @@ def test_download_timeout(self, mock_time, mock_session): result_link = Mock(bytesNum=100, expiryTime=1001) 
mock_session.return_value.get.return_value.content = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00' - d = downloader.ResultSetDownloadHandler( - settings, - result_link, - ssl_options=SSLOptions(), - chunk_id=0, - session_id_hex=Mock(), - statement_id=Mock(), - ) + d = self.create_download_handler(settings, result_link) with self.assertRaises(TimeoutError): d.run() diff --git a/tests/unit/test_sea_queue.py b/tests/unit/test_sea_queue.py index f0dcf5297..706d60702 100644 --- a/tests/unit/test_sea_queue.py +++ b/tests/unit/test_sea_queue.py @@ -638,7 +638,7 @@ def test_add_links_and_get_next_chunk_index(self, sample_links): fetcher, _backend, download_manager = self._create_fetcher([link0]) # add_link should have been called for the initial link - download_manager.add_link.assert_called_once() + download_manager.add_links.assert_called_once() # Internal mapping should contain the link assert fetcher.chunk_index_to_link[0] == link0 @@ -668,7 +668,7 @@ def test_trigger_next_batch_download_success(self, sample_links): backend.get_chunk_links.assert_called_once_with("statement-123", 1) assert fetcher.chunk_index_to_link[1] == link1 # Two calls to add_link: one for initial link, one for new link - assert download_manager.add_link.call_count == 2 + assert download_manager.add_links.call_count == 2 def test_trigger_next_batch_download_error(self, sample_links): """Ensure that errors from backend are captured and surfaced."""
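
The new unit tests cover the `add_links` rename and the downloader's callback plumbing, but `cancel_tasks_from_offset` itself is not exercised. A rough sketch of such a test, following the `create_download_manager` pattern from tests/unit/test_download_manager.py and assuming the offset semantics described in the docstring (tasks and pending links at or after the offset are dropped):

```python
from unittest.mock import Mock

import databricks.sql.cloudfetch.download_manager as download_manager
from databricks.sql.types import SSLOptions


def _mock_link(start_row_offset, row_count=100):
    link = Mock()
    link.startRowOffset = start_row_offset
    link.rowCount = row_count
    return link


def test_cancel_tasks_from_offset_drops_pending_links():
    # Sketch only: nothing is scheduled on the thread pool, so the check is
    # limited to the pending-links bookkeeping that the restart path relies on.
    manager = download_manager.ResultFileDownloadManager(
        [],  # no initial links
        10,  # max_download_threads
        False,  # lz4_compressed
        ssl_options=SSLOptions(),
        expiry_callback=lambda link: None,
        session_id_hex=Mock(),
        statement_id=Mock(),
        chunk_id=0,
    )
    manager._pending_links = [(i, _mock_link(i * 100)) for i in range(4)]

    manager.cancel_tasks_from_offset(200)

    kept_offsets = [link.startRowOffset for _, link in manager._pending_links]
    assert kept_offsets == [0, 100]
```

Because `_pending_links` is internal state, a fuller test would drive the same path end to end through `LinkFetcher._restart_from_expired_link` instead of poking the attribute directly.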