Skip to content

Commit

Permalink
[Feature] Files API client: recover on download failures (#844) (#845)
Browse files Browse the repository at this point in the history
## What changes are proposed in this pull request?

1. Extend the Files API client to support resuming downloads after failures.
The new implementation tracks the current offset in the input stream and, if
an error occurs, issues a new download request starting from that offset.
2. The new code path is enabled by the `enable_experimental_files_api_client`
config parameter (environment variable 'DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT').

## How is this tested?

Added unit tests for the new code path:
`% python3 -m pytest tests/test_files.py`

---------

Signed-off-by: Kirill Safonov <[email protected]>
  • Loading branch information
ksafonov-db authored Jan 8, 2025
1 parent 6d6923e commit d907c0c
Show file tree
Hide file tree
Showing 6 changed files with 559 additions and 11 deletions.
11 changes: 9 additions & 2 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 16 additions & 3 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import logging
import urllib.parse
from abc import ABC, abstractmethod
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand Down Expand Up @@ -285,8 +286,20 @@ def _record_request_log(self, response: requests.Response, raw: bool = False) ->
logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())


class _RawResponse(ABC):
    """Abstract interface of a raw response exposing `iter_content` and `close`."""

    # Mirrors the signature of requests.Response.iter_content:
    # https://github.com/psf/requests/blob/main/src/requests/models.py#L799
    @abstractmethod
    def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False):
        """Iterate over the response content; concrete subclasses must implement this."""

    @abstractmethod
    def close(self):
        """Close the response; concrete subclasses must implement this."""


class _StreamingResponse(BinaryIO):
_response: requests.Response
_response: _RawResponse
_buffer: bytes
_content: Union[Iterator[bytes], None]
_chunk_size: Union[int, None]
Expand All @@ -298,7 +311,7 @@ def fileno(self) -> int:
def flush(self) -> int:
pass

def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None):
def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None):
self._response = response
self._buffer = b''
self._content = None
Expand All @@ -308,7 +321,7 @@ def _open(self) -> None:
if self._closed:
raise ValueError("I/O operation on closed file")
if not self._content:
self._content = self._response.iter_content(chunk_size=self._chunk_size)
self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False)

def __enter__(self) -> BinaryIO:
self._open()
Expand Down
5 changes: 5 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class Config:
max_connections_per_pool: int = ConfigAttribute()
databricks_environment: Optional[DatabricksEnvironment] = None

enable_experimental_files_api_client: bool = ConfigAttribute(
env='DATABRICKS_ENABLE_EXPERIMENTAL_FILES_API_CLIENT')
files_api_client_download_max_total_recovers = None
files_api_client_download_max_total_recovers_without_progressing = 1

def __init__(
self,
*,
Expand Down
185 changes: 184 additions & 1 deletion databricks/sdk/mixins/files.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,35 @@
from __future__ import annotations

import base64
import logging
import os
import pathlib
import platform
import shutil
import sys
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from io import BytesIO
from types import TracebackType
from typing import (TYPE_CHECKING, AnyStr, BinaryIO, Generator, Iterable,
Iterator, Type, Union)
Optional, Type, Union)
from urllib import parse

from requests import RequestException

from .._base_client import _RawResponse, _StreamingResponse
from .._property import _cached_property
from ..errors import NotFound
from ..service import files
from ..service._internal import _escape_multi_segment_path_parameter
from ..service.files import DownloadResponse

if TYPE_CHECKING:
from _typeshed import Self

_LOG = logging.getLogger(__name__)


class _DbfsIO(BinaryIO):
MAX_CHUNK_SIZE = 1024 * 1024
Expand Down Expand Up @@ -636,3 +645,177 @@ def delete(self, path: str, *, recursive=False):
if p.is_dir and not recursive:
raise IOError('deleting directories requires recursive flag')
p.delete(recursive=recursive)


class FilesExt(files.FilesAPI):
    __doc__ = files.FilesAPI.__doc__

    def __init__(self, api_client, config: Config):
        super().__init__(api_client)
        # Work on a copy of the supplied config.
        self._config = config.copy()

    def download(self, file_path: str) -> DownloadResponse:
        """Download a file.

        Downloads a file of any size. The file contents are the response body.
        This is a standard HTTP file download, not a JSON RPC.
        For fault tolerance it is strongly recommended to consume the stream
        iteratively with a bounded read(size) rather than with
        indefinite-size reads.

        :param file_path: str
          The remote path of the file, e.g. /Volumes/path/to/your/file

        :returns: :class:`DownloadResponse`
        """
        response: DownloadResponse = self._download_raw_stream(file_path=file_path,
                                                               start_byte_offset=0,
                                                               if_unmodified_since_timestamp=None)

        # Swap the raw response inside the streaming contents for the
        # recovering wrapper, so reads go through _ResilientResponse.
        response.contents._response = self._wrap_stream(file_path, response)
        return response

    def _download_raw_stream(self,
                             file_path: str,
                             start_byte_offset: int,
                             if_unmodified_since_timestamp: Optional[str] = None) -> DownloadResponse:
        """Issue a single GET for the file, optionally starting at a byte offset.

        :param file_path: remote path of the file
        :param start_byte_offset: offset to start from; a non-zero value adds a `Range` header
        :param if_unmodified_since_timestamp: value for the `If-Unmodified-Since` header;
          required whenever a non-zero offset is given
        :returns: :class:`DownloadResponse` whose contents is a `_StreamingResponse`
        :raises Exception: if an offset is given without a timestamp, or the
          response contents is not a `_StreamingResponse`
        """
        resuming = bool(start_byte_offset)
        if resuming and not if_unmodified_since_timestamp:
            raise Exception("if_unmodified_since_timestamp is required if start_byte_offset is specified")

        request_headers = {'Accept': 'application/octet-stream'}
        if resuming:
            request_headers['Range'] = f'bytes={start_byte_offset}-'
        if if_unmodified_since_timestamp:
            request_headers['If-Unmodified-Since'] = if_unmodified_since_timestamp

        raw = self._api.do('GET',
                           f'/api/2.0/fs/files{_escape_multi_segment_path_parameter(file_path)}',
                           headers=request_headers,
                           response_headers=['content-length', 'content-type', 'last-modified'],
                           raw=True)

        download_response = DownloadResponse.from_dict(raw)
        contents = download_response.contents
        if isinstance(contents, _StreamingResponse):
            return download_response

        raise Exception("Internal error: response contents is of unexpected type: " +
                        type(contents).__name__)

    def _wrap_stream(self, file_path: str, downloadResponse: DownloadResponse):
        """Wrap the raw response held by `downloadResponse` in a `_ResilientResponse` at offset 0."""
        raw_response = _ResilientIterator._extract_raw_response(downloadResponse)
        return _ResilientResponse(self,
                                  file_path,
                                  downloadResponse.last_modified,
                                  offset=0,
                                  underlying_response=raw_response)


class _ResilientResponse(_RawResponse):
    """Raw-response wrapper whose content iterator is a `_ResilientIterator`.

    Stores the file path, last-modified value and starting offset that the
    iterator needs in order to re-request the download after a failure.
    """

    def __init__(self, api: FilesExt, file_path: str, file_last_modified: str, offset: int,
                 underlying_response: _RawResponse):
        self.api = api
        self.file_path = file_path
        self.underlying_response = underlying_response
        self.offset = offset
        self.file_last_modified = file_last_modified
        # Set by iter_content(); stays None until iteration is requested so
        # that close() can be called safely on a never-read response.
        self.iterator: Optional[_ResilientIterator] = None

    def iter_content(self, chunk_size=1, decode_unicode=False):
        """Return a `_ResilientIterator` over the underlying response content.

        :param chunk_size: chunk size forwarded to the underlying `iter_content`
        :param decode_unicode: must be falsy; decoding is not supported
        :raises ValueError: if `decode_unicode` is truthy
        """
        if decode_unicode:
            raise ValueError('Decode unicode is not supported')

        iterator = self.underlying_response.iter_content(chunk_size=chunk_size, decode_unicode=False)
        self.iterator = _ResilientIterator(iterator, self.file_path, self.file_last_modified, self.offset,
                                           self.api, chunk_size)
        return self.iterator

    def close(self):
        """Close the active iterator, or the underlying response if none was created."""
        if self.iterator is not None:
            self.iterator.close()
        else:
            # iter_content() was never called: there is no iterator to close,
            # so release the underlying response directly instead of failing
            # with AttributeError.
            self.underlying_response.close()


class _ResilientIterator(Iterator):
    """Chunk iterator that re-requests the download after a request failure.

    Tracks the current offset (bytes already returned to the caller) and, when
    the underlying iterator raises `RequestException`, asks the API for a new
    download stream starting at that offset. Recovery attempts are bounded by
    the `files_api_client_download_max_total_recovers` and
    `files_api_client_download_max_total_recovers_without_progressing`
    settings read from the API's config.
    """

    @staticmethod
    def _extract_raw_response(download_response: DownloadResponse) -> _RawResponse:
        """Return the raw response stored inside `download_response.contents`."""
        streaming_response: _StreamingResponse = download_response.contents  # this is an instance of _StreamingResponse
        return streaming_response._response

    def __init__(self, underlying_iterator, file_path: str, file_last_modified: str, offset: int,
                 api: FilesExt, chunk_size: int):
        self._underlying_iterator = underlying_iterator
        self._api = api
        self._file_path = file_path

        # Absolute current offset (0-based), i.e. number of bytes from the beginning of the file
        # that were so far returned to the caller code.
        self._offset = offset
        self._file_last_modified = file_last_modified
        self._chunk_size = chunk_size

        self._total_recovers_count: int = 0
        self._recovers_without_progressing_count: int = 0
        self._closed: bool = False

    def _should_recover(self) -> bool:
        """Return True if another recovery attempt is allowed by the configured limits.

        A limit set to None is treated as "no limit".
        """
        config = self._api._config

        max_total = config.files_api_client_download_max_total_recovers
        # `>=` (rather than `==`) so a counter that is already past the limit
        # still stops further recoveries.
        if max_total is not None and self._total_recovers_count >= max_total:
            _LOG.debug("Total recovers limit exceeded")
            return False

        max_without_progress = config.files_api_client_download_max_total_recovers_without_progressing
        if max_without_progress is not None and self._recovers_without_progressing_count >= max_without_progress:
            _LOG.debug("No progression recovers limit exceeded")
            return False

        return True

    def _recover(self) -> bool:
        """Try to replace the underlying iterator with a new one starting at the current offset.

        :returns: True if a new stream was obtained; False if recovery is not
          allowed or failed, in which case the caller re-raises the original error.
        """
        if not self._should_recover():
            return False  # recover suppressed, rethrow original exception

        self._total_recovers_count += 1
        self._recovers_without_progressing_count += 1

        try:
            self._underlying_iterator.close()

            _LOG.debug("Trying to recover from offset " + str(self._offset))

            # following call includes all the required network retries
            downloadResponse = self._api._download_raw_stream(self._file_path, self._offset,
                                                              self._file_last_modified)
            underlying_response = _ResilientIterator._extract_raw_response(downloadResponse)
            self._underlying_iterator = underlying_response.iter_content(chunk_size=self._chunk_size,
                                                                         decode_unicode=False)
            _LOG.debug("Recover succeeded")
            return True
        except Exception:
            # Only ordinary errors mean "recovery failed"; KeyboardInterrupt and
            # SystemExit must propagate instead of being swallowed.
            _LOG.debug("Recover failed", exc_info=True)
            return False  # recover failed, rethrow original exception

    def __next__(self):
        """Return the next chunk, recovering from `RequestException` when allowed.

        :raises ValueError: if the iterator has been closed
        :raises StopIteration: when the underlying iterator is exhausted
        :raises RequestException: the original error, if recovery is refused or fails
        """
        if self._closed:
            # following _BaseClient
            raise ValueError("I/O operation on closed file")

        while True:
            try:
                returned_bytes = next(self._underlying_iterator)
                self._offset += len(returned_bytes)
                # Progress was made, so the "stalled" counter starts over.
                self._recovers_without_progressing_count = 0
                return returned_bytes

            except StopIteration:
                raise

            # https://requests.readthedocs.io/en/latest/user/quickstart/#errors-and-exceptions
            except RequestException:
                if not self._recover():
                    raise

    def close(self):
        """Close the underlying iterator and mark this iterator as closed."""
        try:
            self._underlying_iterator.close()
        finally:
            # Mark as closed even if closing the underlying iterator raised.
            self._closed = True
10 changes: 5 additions & 5 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
from unittest.mock import Mock

import pytest
import requests

from databricks.sdk import errors, useragent
from databricks.sdk._base_client import _BaseClient, _StreamingResponse
from databricks.sdk._base_client import (_BaseClient, _RawResponse,
_StreamingResponse)
from databricks.sdk.core import DatabricksError

from .clock import FakeClock
from .fixture_server import http_fixture_server


class DummyResponse(requests.Response):
class DummyResponse(_RawResponse):
_content: Iterator[bytes]
_closed: bool = False

Expand Down Expand Up @@ -293,9 +293,9 @@ def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
test_data = bytes(rng.getrandbits(8) for _ in range(data_size))

content_chunks = []
mock_response = Mock(spec=requests.Response)
mock_response = Mock(spec=_RawResponse)

def mock_iter_content(chunk_size):
def mock_iter_content(chunk_size: int, decode_unicode: bool):
# Simulate how requests would chunk the data.
for i in range(0, len(test_data), chunk_size):
chunk = test_data[i:i + chunk_size]
Expand Down
Loading

0 comments on commit d907c0c

Please sign in to comment.