Skip to content

Commit

Permalink
Merge pull request #273 from mheilman/s3_retry
Browse files Browse the repository at this point in the history
Add a retry for downloading files via civis.io.civis_to_file
  • Loading branch information
mheilman authored Dec 3, 2018
2 parents 7a08b8c + 4562a2d commit bdb197b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
28 changes: 22 additions & 6 deletions civis/io/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,14 @@ def _single_upload(buf, name, client, **kwargs):
form_key = OrderedDict(key=form.pop('key'))
form_key.update(form)

# Store the current buffer position in case we need to retry below.
buf_orig_position = buf.tell()

@retry(RETRY_EXCEPTIONS)
def _post():
buf.seek(0)
# Reset the buffer in case we had to retry.
buf.seek(buf_orig_position)

form_key['file'] = buf
# requests will not stream multipart/form-data, but _single_upload
# is only used for small file objects or non-seekable file objects
Expand Down Expand Up @@ -315,11 +320,22 @@ def _civis_to_file(file_id, buf, api_key=None, client=None):
raise EmptyResultError('Unable to locate file {}. If it previously '
'existed, it may have '
'expired.'.format(file_id))
response = requests.get(url, stream=True)
response.raise_for_status()
chunked = response.iter_content(CHUNK_SIZE)
for lines in chunked:
buf.write(lines)

# Store the current buffer position in case we need to retry below.
buf_orig_position = buf.tell()

@retry(RETRY_EXCEPTIONS)
def _download_url_to_buf():
    """Stream the contents of ``url`` into ``buf``.

    Writing always starts at ``buf_orig_position`` (the position ``buf``
    had before the first attempt), so a call that follows a partially
    completed attempt overwrites the partial data instead of appending
    after it.

    Raises whatever ``response.raise_for_status()`` raises for an
    unsuccessful HTTP status, plus any error raised while streaming.
    NOTE(review): presumably ``retry`` re-invokes this function for
    errors listed in ``RETRY_EXCEPTIONS`` -- confirm against its
    definition.
    """
    from contextlib import closing

    # Reset the buffer in case we had to retry.
    buf.seek(buf_orig_position)

    # With stream=True the connection stays checked out until the body is
    # fully consumed or the response is closed. closing() guarantees the
    # release even when a chunk read fails partway through, so failed
    # attempts do not leave connections open.
    with closing(requests.get(url, stream=True)) as response:
        response.raise_for_status()
        for chunk in response.iter_content(CHUNK_SIZE):
            buf.write(chunk)

_download_url_to_buf()


def file_id_from_run_output(name, job_id, run_id, regex=False, client=None):
Expand Down
46 changes: 46 additions & 0 deletions civis/tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections import OrderedDict
import io
import json
import os
from six import StringIO, BytesIO
import zipfile

import pytest
import requests
import vcr

try:
Expand Down Expand Up @@ -431,6 +433,50 @@ def test_civis_to_file_local(mock_requests):
mock_civis.files.get.return_value.file_url, stream=True)


@mock.patch.object(_files, 'requests', autospec=True)
def test_civis_to_file_retries(mock_requests):
    """Check that civis_to_file retries a download that fails midway.

    The mocked ``iter_content`` raises ``requests.ConnectionError`` after
    yielding three chunks on its first call and succeeds on the second.
    The buffer is pre-filled so the test can tell whether the retry
    rewinds to the original write position rather than to 0 or not at all.
    """
    mock_civis = create_client_mock()

    # Mock the request iter_content so it fails partway through the first
    # time. Python 2.7 doesn't have the nonlocal keyword, so here's a
    # little class to track whether it's the first call. It derives from
    # ``object`` so that it is a new-style class on Python 2.7 as well.
    class UnreliableIterContent(object):
        def __init__(self):
            self.first_try = True

        def mock_iter_content(self, _):
            chunks = [char.encode() for char in 'abcdef']
            for i, chunk in enumerate(chunks):

                # Fail partway through on the first try, after 'a', 'b'
                # and 'c' have already been yielded.
                if self.first_try and i == 3:
                    self.first_try = False
                    raise requests.ConnectionError()

                yield chunk

    mock_requests.get.return_value.iter_content = (
        UnreliableIterContent().mock_iter_content)

    # Add some data to the buffer to test that we seek to the right place
    # when retrying.
    buf = io.BytesIO(b'0123')
    buf.seek(4)

    _files.civis_to_file(137, buf, client=mock_civis)

    # Check that retries work and that the buffer position is reset.
    # If we didn't seek when retrying, we'd get 0123abcabcdef.
    # If we seeked to position 0 instead, we'd get abcdefc (the retry
    # would overwrite the prefix and leave the stale trailing 'c').
    buf.seek(0)
    assert buf.read() == b'0123abcdef'

    mock_civis.files.get.assert_called_once_with(137)
    assert mock_requests.get.call_count == 2
    mock_requests.get.assert_called_with(
        mock_civis.files.get.return_value.file_url, stream=True)


@mock.patch.object(_files, 'requests', autospec=True)
def test_file_to_civis(mock_requests):
# Test that file_to_civis posts a Civis File with the API client
Expand Down

0 comments on commit bdb197b

Please sign in to comment.