diff --git a/.gitignore b/.gitignore index 7a7b978f..2d948221 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ build/ *.egg .env .cache/ +.vscode +.idea diff --git a/AUTHORS.rst b/AUTHORS.rst index 797d51bd..c80b2f42 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -56,4 +56,6 @@ Patches and Suggestions - Chris van Marle (https://github.com/qistoph) -- Florence Blanc-Renaud (@flo-renaud) \ No newline at end of file +- Florence Blanc-Renaud (@flo-renaud) + +- Francis Charette-Migneault (https://github.com/fmigneault) diff --git a/requests_toolbelt/multipart/decoder.py b/requests_toolbelt/multipart/decoder.py index 2a0d1c46..13105408 100644 --- a/requests_toolbelt/multipart/decoder.py +++ b/requests_toolbelt/multipart/decoder.py @@ -59,6 +59,14 @@ def __init__(self, content, encoding): first, self.content = _split_on_find(content, b'\r\n\r\n') if first != b'': headers = _header_parser(first.lstrip(), encoding) + # process 'no content' part + elif content.endswith(b'\r\n') and not content.endswith(b'\r\n\r\n'): + self.content = None + headers = _header_parser(content.strip(), encoding) + if not headers: + raise ImproperBodyPartContentException( + 'No contents part without any header is invalid.' + ) else: raise ImproperBodyPartContentException( 'content does not contain CR-LF-CR-LF' @@ -68,6 +76,8 @@ def __init__(self, content, encoding): @property def text(self): """Content of the ``BodyPart`` in unicode.""" + if self.content is None: + return None return self.content.decode(self.encoding) diff --git a/requests_toolbelt/multipart/encoder.py b/requests_toolbelt/multipart/encoder.py index 2d539617..0c4721e7 100644 --- a/requests_toolbelt/multipart/encoder.py +++ b/requests_toolbelt/multipart/encoder.py @@ -13,6 +13,7 @@ from uuid import uuid4 import requests +from requests.structures import CaseInsensitiveDict from .._compat import fields @@ -21,7 +22,12 @@ class FileNotSupportedError(Exception): """File not supported error.""" -class MultipartEncoder(object): +class ContentIO(object): + def __init__(self, no_content=False): + self.no_content = no_content + + +class MultipartEncoder(ContentIO): """ @@ -84,13 +90,18 @@ class MultipartEncoder(object): """ - def __init__(self, fields, boundary=None, encoding='utf-8'): + def __init__(self, fields, boundary=None, encoding='utf-8', content_type='multipart/form-data'): + ContentIO.__init__(self) + #: Boundary value either passed in by the user or created self.boundary_value = boundary or uuid4().hex # Computed boundary self.boundary = '--{}'.format(self.boundary_value) + # Multipart content + self._content_type = content_type + #: Encoding of the data being passed in self.encoding = encoding @@ -191,7 +202,9 @@ def _load(self, amount): while amount == -1 or amount > 0: written = 0 if part and not part.bytes_left_to_write(): - written += self._write(b'\r\n') + # distinguish no content from empty string + if not part.body.no_content: + written += self._write(b'\r\n') written += self._write_boundary() part = self._next_part() @@ -227,13 +240,24 @@ def _iter_fields(self): file_name, file_pointer, file_type = v else: file_name, file_pointer, file_type, file_headers = v + elif isinstance(v, Part): + file_pointer = v.body + file_headers = v.headers else: file_pointer = v field = fields.RequestField(name=k, data=file_pointer, filename=file_name, headers=file_headers) - field.make_multipart(content_type=file_type) + file_headers = CaseInsensitiveDict(file_headers or {}) + file_type = file_type or file_headers.get("Content-Type") + file_loc = file_headers.get("Content-Location") + file_dis = (file_headers.get("Content-Disposition") or "").split(";", 1)[0].strip() + field.make_multipart( + content_type=file_type, + content_location=file_loc, + content_disposition=file_dis, + ) yield field def _prepare_parts(self): @@ -272,9 +296,7 @@ def _write_headers(self, headers): @property def content_type(self): - return str( - 'multipart/form-data; boundary={}'.format(self.boundary_value) - ) + return '{}; boundary={}'.format(self._content_type, self.boundary_value) def to_string(self): """Return the entirety of the data in the encoder. @@ -319,7 +341,7 @@ def IDENTITY(monitor): return monitor -class MultipartEncoderMonitor(object): +class MultipartEncoderMonitor(ContentIO): """ An object used to monitor the progress of a :class:`MultipartEncoder`. @@ -371,6 +393,8 @@ def callback(monitor): """ def __init__(self, encoder, callback=None): + ContentIO.__init__(self) + #: Instance of the :class:`MultipartEncoder` being monitored self.encoder = encoder @@ -461,6 +485,9 @@ def reset(buffer): def coerce_data(data, encoding): """Ensure that every object's __len__ behaves uniformly.""" + if data is None: + return CustomBytesIO(no_content=True) + if not isinstance(data, CustomBytesIO): if hasattr(data, 'getvalue'): return CustomBytesIO(data.getvalue(), encoding) @@ -500,6 +527,9 @@ def bytes_left_to_write(self): :returns: bool -- ``True`` if there are bytes left to write, otherwise ``False`` """ + if getattr(self.body, "finished", False): # part is a nested multipart and has finished reading + return False + to_read = 0 if self.headers_unread: to_read += len(self.headers) @@ -526,14 +556,19 @@ def write_to(self, buffer, size): if size != -1: amount_to_read = size - written written += buffer.append(self.body.read(amount_to_read)) + if getattr(self, 'bytes_left_to_write', False): + return written return written -class CustomBytesIO(io.BytesIO): - def __init__(self, buffer=None, encoding='utf-8'): +class CustomBytesIO(ContentIO, io.BytesIO): + def __init__(self, buffer=None, encoding='utf-8', no_content=False): + ContentIO.__init__(self, no_content=no_content) + if self.no_content: + buffer = None buffer = encode_with(buffer, encoding) - super(CustomBytesIO, self).__init__(buffer) + io.BytesIO.__init__(self, buffer) def _get_end(self): current_pos = self.tell() @@ -564,8 +599,9 @@ def smart_truncate(self): self.seek(0, 0) # We want to be at the beginning -class FileWrapper(object): +class FileWrapper(ContentIO): def __init__(self, file_object): + ContentIO.__init__(self) self.fd = file_object @property @@ -576,7 +612,7 @@ def read(self, length=-1): return self.fd.read(length) -class FileFromURLWrapper(object): +class FileFromURLWrapper(ContentIO): """File from URL wrapper. The :class:`FileFromURLWrapper` object gives you the ability to stream file @@ -623,6 +659,7 @@ class FileFromURLWrapper(object): """ def __init__(self, file_url, session=None): + ContentIO.__init__(self) self.session = session or requests.Session() requested_file = self._request_for_file(file_url) self.len = int(requested_file.headers['content-length']) diff --git a/tests/test_multipart_decoder.py b/tests/test_multipart_decoder.py index e9229183..7e53cf2b 100644 --- a/tests/test_multipart_decoder.py +++ b/tests/test_multipart_decoder.py @@ -92,7 +92,7 @@ def test_no_headers(self): assert part_3.content == b'No headers\r\nTwo lines' def test_no_crlf_crlf_in_content(self): - content = b'no CRLF CRLF here!\r\n' + content = b'no CRLF CRLF here!' with pytest.raises(ImproperBodyPartContentException): BodyPart(content, 'utf-8') @@ -191,3 +191,40 @@ def test_from_responsecaplarge(self): assert decoder_2.parts[0].headers[b'Header-1'] == b'Header-Value-1' assert len(decoder_2.parts[1].headers) == 0 assert decoder_2.parts[1].content == b'Body 2, Line 1' + + def test_no_content_empty_string_and_contents(self): + contents = b'\r\n'.join([ + b'--boundary', + b'Header-1: Header-Value-1', + b'Header-2: Header-Value-2', + b'', + b'--boundary', + b'Header-3: Header-Value-3', + b'Header-4: Header-Value-4', + b'', + b'some contents', + b'--boundary', + b'Header-5: Header-Value-5', + b'Header-6: Header-Value-6', + b'', + b'', + b'--boundary--', + ]) + dec = MultipartDecoder(contents, content_type='multipart/mixed; boundary="boundary"') + assert dec.parts[0].headers == {b'Header-1': b'Header-Value-1', b'Header-2': b'Header-Value-2'} + assert dec.parts[0].content is None + assert dec.parts[1].headers == {b'Header-3': b'Header-Value-3', b'Header-4': b'Header-Value-4'} + assert dec.parts[1].content == b'some contents' + assert dec.parts[2].headers == {b'Header-5': b'Header-Value-5', b'Header-6': b'Header-Value-6'} + assert dec.parts[2].content == b'' + + def test_no_content_crlf_separator_required(self): + contents = b'\r\n'.join([ + b'--boundary', + b'Header-1: Header-Value-1', + b'Header-2: Header-Value-2', + # missing CRLF here + b'--boundary', + ]) + with pytest.raises(ImproperBodyPartContentException): + MultipartDecoder(contents, content_type='multipart/mixed; boundary="boundary"') diff --git a/tests/test_multipart_encoder.py b/tests/test_multipart_encoder.py index f864487c..5a61d09e 100644 --- a/tests/test_multipart_encoder.py +++ b/tests/test_multipart_encoder.py @@ -319,5 +319,51 @@ def test_no_parts(self): output = m.read().decode('utf-8') assert output == '----90967316f8404798963cce746a4f4ef9--\r\n' + def test_nested_multipart(self): + sub_fields = { + 'item1': (None, b'data', 'text/plain', {'Extra-Header': 'data'}), + 'item2': (None, b'[1,2,3]', 'application/json', {'Extra-Header': 'json'}), + } + sub_multi = MultipartEncoder(fields=sub_fields, content_type='multipart/mixed') + fields = { + 'file': ('filename', b'{"item1": "data", "item2": [1,2,3]}', 'application/json', {'Extra-Header': 'file'}), + 'multi': (None, sub_multi, sub_multi.content_type, {'Extra-Header': 'multi'}), + } + multi = MultipartEncoder(fields=fields, content_type='multipart/alternate') + sub_boundary = sub_multi.boundary.encode().strip() + top_boundary = multi.boundary.encode().strip() + content = multi.read() + expect = b'\r\n'.join([ + top_boundary, + b'Content-Disposition: form-data; name="file"; filename="filename"', + b'Content-Type: application/json', + b'Extra-Header: file', + b'', + b'{"item1": "data", "item2": [1,2,3]}', + top_boundary, + b'Content-Disposition: form-data; name="multi"', + b'Content-Type: multipart/mixed; boundary=' + sub_boundary.strip(b'-'), + b'Extra-Header: multi', + b'', + sub_boundary, + b'Content-Disposition: form-data; name="item1"', + b'Content-Type: text/plain', + b'Extra-Header: data', + b'', + b'data', + sub_boundary, + b'Content-Disposition: form-data; name="item2"', + b'Content-Type: application/json', + b'Extra-Header: json', + b'', + b'[1,2,3]', + sub_boundary + b'--', + b'', + top_boundary + b'--', + b'', + ]) + assert content == expect + + if __name__ == '__main__': unittest.main()