diff --git a/requests_toolbelt/exceptions.py b/requests_toolbelt/exceptions.py
index 32ade215..a80043a8 100644
--- a/requests_toolbelt/exceptions.py
+++ b/requests_toolbelt/exceptions.py
@@ -35,3 +35,17 @@ class IgnoringGAECertificateValidation(Warning):
     In :class:`requests_toolbelt.adapters.appengine.InsecureAppEngineAdapter`.
     """
     pass
+
+
+class MultipartEncoderSourceShrunkError(Exception):
+    """Used to indicate that a data source passed to
+    :class:`requests_toolbelt.multipart.encoder.MultipartEncoder` shrank in
+    size during encoding.
+
+    If the data source shrinks during encoding, the overall encoded data
+    length would no longer match the initially calculated length (sent to the
+    server as ``Content-Length``), potentially causing an HTTP 408 request
+    timeout error.
+
+    """
+    pass
diff --git a/requests_toolbelt/multipart/encoder.py b/requests_toolbelt/multipart/encoder.py
index 2d539617..d4117a26 100644
--- a/requests_toolbelt/multipart/encoder.py
+++ b/requests_toolbelt/multipart/encoder.py
@@ -15,6 +15,7 @@
 import requests
 
 from .._compat import fields
+from ..exceptions import MultipartEncoderSourceShrunkError
 
 
 class FileNotSupportedError(Exception):
@@ -69,6 +70,17 @@ class MultipartEncoder(object):
             {'X-My-Header': 'my-value'})
         ])
 
+    .. warning::
+
+        If the content of a field's stream (e.g. the underlying file)
+        changes during encoding, the uploaded data may contain a mix of
+        old and new data. The size of the stream must not shrink during
+        encoding; if it does, a
+        :class:`~requests_toolbelt.exceptions.MultipartEncoderSourceShrunkError`
+        is raised. It is acceptable for a stream to grow, but only the
+        amount of data present at the start of encoding will be
+        transmitted.
+
     .. warning::
 
         This object will end up directly in :mod:`httplib`. Currently,
@@ -485,7 +497,8 @@ def __init__(self, headers, body):
         self.headers = headers
         self.body = body
         self.headers_unread = True
-        self.len = len(self.headers) + total_len(self.body)
+        self.body_len = total_len(self.body)
+        self.len = len(self.headers) + self.body_len
 
     @classmethod
     def from_field(cls, field, encoding):
@@ -504,7 +517,7 @@ def bytes_left_to_write(self):
         if self.headers_unread:
             to_read += len(self.headers)
 
-        return (to_read + total_len(self.body)) > 0
+        return (to_read + self.body_len) > 0
 
     def write_to(self, buffer, size):
         """Write the requested amount of bytes to the buffer provided.
@@ -521,11 +534,27 @@ def write_to(self, buffer, size):
             written += buffer.append(self.headers)
             self.headers_unread = False
 
-        while total_len(self.body) > 0 and (size == -1 or written < size):
+        while self.body_len > 0 and (size == -1 or written < size):
+            # Check that the body hasn't shrunk since we started encoding.
+            if total_len(self.body) < self.body_len:
+                raise MultipartEncoderSourceShrunkError()
+
             amount_to_read = size
             if size != -1:
                 amount_to_read = size - written
-            written += buffer.append(self.body.read(amount_to_read))
+
+            if amount_to_read == -1:
+                amount_to_read = self.body_len
+
+            # Cap the amount of data read at the initial body length (even
+            # if more data is now available), since that is the amount we
+            # committed to in the Content-Length header.
+            amount_to_read = min(self.body_len, amount_to_read)
+
+            body_data = self.body.read(amount_to_read)
+            self.body_len -= len(body_data)
+
+            written += buffer.append(body_data)
 
         return written
 
diff --git a/tests/test_multipart_encoder.py b/tests/test_multipart_encoder.py
index 575f54c4..430e7ef1 100644
--- a/tests/test_multipart_encoder.py
+++ b/tests/test_multipart_encoder.py
@@ -7,6 +7,7 @@
 from requests_toolbelt.multipart.encoder import (
     CustomBytesIO, MultipartEncoder, FileFromURLWrapper,
     FileNotSupportedError)
 from requests_toolbelt._compat import filepost
+from requests_toolbelt.exceptions import MultipartEncoderSourceShrunkError
 
 from . import get_betamax
@@ -191,6 +192,38 @@ def test_reads_open_file_objects(self):
             m = MultipartEncoder([('field', 'foo'), ('file', fd)])
             assert m.read() is not None
 
+    def test_reads_growing_sources(self):
+        large_file = LargeFileMock()
+        large_file.bytes_max = 128
+        m = MultipartEncoder([('file', large_file)], boundary=self.boundary)
+        # This is the value that ends up in the Content-Length header.
+        initial_len = m.len
+        # Start the encoding, including some of the stream contents...
+        data = m.read(128)
+        # ...meanwhile, the stream grows.
+        large_file.bytes_max *= 2
+        # Ensure that the additional data is not transmitted.
+        assert m.len == initial_len
+        data += m.read()
+        assert data == (
+            '--this-is-a-boundary\r\n'
+            'Content-Disposition: form-data; name="file"\r\n\r\n' +
+            ('a' * 128) + '\r\n'
+            '--this-is-a-boundary--\r\n'
+        ).encode()
+
+    def test_reads_shrinking_sources(self):
+        large_file = LargeFileMock()
+        m = MultipartEncoder([('file', large_file)], boundary=self.boundary)
+        initial_len = m.len
+        # Start the encoding...
+        data = m.read(1)
+        # ...meanwhile, the source stream shrinks.
+        large_file.bytes_max = large_file.bytes_read
+        # Further encoding fails.
+        with self.assertRaises(MultipartEncoderSourceShrunkError):
+            m.read()
+
     def test_reads_file_from_url_wrapper(self):
         s = requests.Session()
         recorder = get_betamax(s)
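
For reviewers, a minimal sketch (not part of the diff) of how the new error surfaces from calling code. ShrinkableSource is a hypothetical stand-in for the test suite's LargeFileMock, whose definition is not shown above; it assumes, as the tests suggest, that a file-like object exposing read() and a len attribute is accepted directly as a field value and measured via total_len().

# Hedged sketch, assuming the patched requests_toolbelt is installed.
# ShrinkableSource is a hypothetical helper, not part of the library.
from requests_toolbelt.multipart.encoder import MultipartEncoder
from requests_toolbelt.exceptions import MultipartEncoderSourceShrunkError


class ShrinkableSource(object):
    """File-like source whose reported length can be changed externally."""

    def __init__(self, size):
        self.len = size         # picked up by the encoder via total_len()
        self.bytes_read = 0

    def read(self, amount=-1):
        # Serve filler bytes, never more than currently remain.
        remaining = max(self.len - self.bytes_read, 0)
        if amount < 0 or amount > remaining:
            amount = remaining
        self.bytes_read += amount
        return b'a' * amount


source = ShrinkableSource(size=64)
m = MultipartEncoder([('file', source)])

data = m.read(32)                # encoding starts normally
source.len = source.bytes_read   # ...then the source shrinks mid-upload

try:
    m.read()                     # the encoder detects the shrinkage
except MultipartEncoderSourceShrunkError:
    print('source shrank during encoding; the upload cannot complete')

The growth case is symmetric but silent: raising source.len instead would not change what is transmitted, since write_to() caps each read at the remaining initial body_len committed to in Content-Length.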