Skip to content

Commit

Permalink
Add unit test for non-seekable stream
Browse files Browse the repository at this point in the history
  • Loading branch information
renaudhartert-db committed Nov 13, 2024
1 parent 787241f commit 4a196f7
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,38 @@ def mock_iter_content(chunk_size):
assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size


def test_no_retry_on_non_seekable_stream():
requests = []

# Always respond with a response that triggers a retry.
def inner(h: BaseHTTPRequestHandler):
content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
requests.append(h.rfile.read(content_length))

h.send_response(429)
h.send_header('Retry-After', '1')
h.end_headers()

stream = io.BytesIO(b"test data")
stream.seekable = lambda: False # makes the stream appear non-seekable

with http_fixture_server(inner) as host:
client = _BaseClient()

# Should raise error immediately without retry.
with pytest.raises(DatabricksError):
client.do('POST', f'{host}/foo', data=stream)

# Verify that only one request was made (no retries).
assert len(requests) == 1
assert requests[0] == b"test data"


def test_perform_resets_seekable_stream_on_error():
received_data = []

# Response that triggers a retry.
# Always respond with a response that triggers a retry.
def inner(h: BaseHTTPRequestHandler):
content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
Expand All @@ -340,7 +368,7 @@ def inner(h: BaseHTTPRequestHandler):
stream.read(4)
assert stream.tell() == 4

# Call perform which should fail but reset the stream.
# Should fail but reset the stream.
with pytest.raises(DatabricksError):
client._perform('POST', f'{host}/foo', data=stream)

Expand All @@ -353,7 +381,7 @@ def inner(h: BaseHTTPRequestHandler):
def test_perform_does_not_reset_nonseekable_stream_on_error():
received_data = []

# Response that triggers a retry.
# Always respond with a response that triggers a retry.
def inner(h: BaseHTTPRequestHandler):
content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
Expand All @@ -374,7 +402,7 @@ def inner(h: BaseHTTPRequestHandler):
stream.read(4)
assert stream.tell() == 4

# Call perform which should fail but reset the stream.
# Should fail without resetting the stream.
with pytest.raises(DatabricksError):
client._perform('POST', f'{host}/foo', data=stream)

Expand Down

0 comments on commit 4a196f7

Please sign in to comment.