diff --git a/adafruit_requests.py b/adafruit_requests.py index 8b0539e..6231185 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -50,7 +50,7 @@ if not sys.implementation.name == "circuitpython": from types import TracebackType - from typing import Any, Dict, Optional, Type + from typing import IO, Any, Dict, Optional, Type from circuitpython_typing.socket import ( SocketpoolModuleType, @@ -387,19 +387,7 @@ def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals boundary_objects.append("\r\n") if hasattr(file_handle, "read"): - is_binary = False - try: - content = file_handle.read(1) - is_binary = isinstance(content, bytes) - except UnicodeError: - is_binary = False - - if not is_binary: - raise ValueError("Files must be opened in binary mode") - - file_handle.seek(0, SEEK_END) - content_length += file_handle.tell() - file_handle.seek(0) + content_length += self._get_file_length(file_handle) boundary_objects.append(file_handle) boundary_objects.append("\r\n") @@ -428,6 +416,25 @@ def _check_headers(headers: Dict[str, str]): f"Header part ({value}) from {key} must be of type str or bytes, not {type(value)}" ) + @staticmethod + def _get_file_length(file_handle: IO): + is_binary = False + try: + file_handle.seek(0) + # read at least 4 bytes incase we are reading a b64 stream + content = file_handle.read(4) + is_binary = isinstance(content, bytes) + except UnicodeError: + is_binary = False + + if not is_binary: + raise ValueError("Files must be opened in binary mode") + + file_handle.seek(0, SEEK_END) + content_length = file_handle.tell() + file_handle.seek(0) + return content_length + @staticmethod def _send(socket: SocketType, data: bytes): total_sent = 0 @@ -458,13 +465,16 @@ def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any): if isinstance(boundary_object, str): self._send_as_bytes(socket, boundary_object) else: - chunk_size = 32 - b = bytearray(chunk_size) - while True: - size = boundary_object.readinto(b) - if size == 0: - break - self._send(socket, b[:size]) + self._send_file(socket, boundary_object) + + def _send_file(self, socket: SocketType, file_handle: IO): + chunk_size = 36 + b = bytearray(chunk_size) + while True: + size = file_handle.readinto(b) + if size == 0: + break + self._send(socket, b[:size]) def _send_header(self, socket, header, value): if value is None: @@ -517,12 +527,16 @@ def _send_request( # pylint: disable=too-many-arguments # If files are send, build data to send and calculate length content_length = 0 + data_is_file = False boundary_objects = None if files and isinstance(files, dict): boundary_string, content_length, boundary_objects = ( self._build_boundary_data(files) ) content_type_header = f"multipart/form-data; boundary={boundary_string}" + elif data and hasattr(data, "read"): + data_is_file = True + content_length = self._get_file_length(data) else: if data is None: data = b"" @@ -551,7 +565,9 @@ def _send_request( # pylint: disable=too-many-arguments self._send(socket, b"\r\n") # Send data - if data: + if data_is_file: + self._send_file(socket, data) + elif data: self._send(socket, bytes(data)) elif boundary_objects: self._send_boundary_objects(socket, boundary_objects) diff --git a/tests/files_test.py b/tests/files_test.py index 8cac77c..4026b9d 100644 --- a/tests/files_test.py +++ b/tests/files_test.py @@ -50,7 +50,7 @@ def get_actual_request_data(log_stream): boundary = boundary_search[0] if content_length_search: content_length = content_length_search[0] - if "Content-Disposition" in log_arg: + if "Content-Disposition" in log_arg or "\\x" in log_arg: # this will look like: # b\'{content}\' # and escaped characters look like: @@ -63,6 +63,28 @@ def get_actual_request_data(log_stream): return boundary, content_length, actual_request_post +def test_post_file_as_data( # pylint: disable=unused-argument + requests, sock, log_stream, post_url, request_logging +): + with open("tests/files/red_green.png", "rb") as file_1: + python_requests.post(post_url, data=file_1, timeout=30) + __, content_length, actual_request_post = get_actual_request_data(log_stream) + + requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=file_1) + + sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) + sock.send.assert_has_calls( + [ + mock.call(b"Content-Length"), + mock.call(b": "), + mock.call(content_length.encode()), + mock.call(b"\r\n"), + ] + ) + sent = b"".join(sock.sent_data) + assert sent.endswith(actual_request_post) + + def test_post_files_text( # pylint: disable=unused-argument sock, requests, log_stream, post_url, request_logging ):