diff --git a/src/easynetwork/lowlevel/api_async/transports/tls.py b/src/easynetwork/lowlevel/api_async/transports/tls.py index 897d3c68..a1cb5f9f 100644 --- a/src/easynetwork/lowlevel/api_async/transports/tls.py +++ b/src/easynetwork/lowlevel/api_async/transports/tls.py @@ -24,7 +24,8 @@ import functools import logging import warnings -from collections.abc import Callable, Coroutine, Mapping +from collections import deque +from collections.abc import Callable, Coroutine, Iterable, Mapping from typing import TYPE_CHECKING, Any, Final, NoReturn, Self, TypeVar, TypeVarTuple try: @@ -64,6 +65,7 @@ class AsyncTLSStreamTransport(AsyncStreamTransport): _write_bio: MemoryBIO __incoming_reader: _IncomingDataReader = dataclasses.field(init=False) __closing: bool = dataclasses.field(init=False, default=False) + __data_to_send: deque[memoryview] = dataclasses.field(init=False, default_factory=deque) def __post_init__(self) -> None: self.__incoming_reader = _IncomingDataReader(transport=self._transport) @@ -203,12 +205,35 @@ async def recv_into(self, buffer: WriteableBuffer) -> int: @_utils.inherit_doc(AsyncStreamTransport) async def send_all(self, data: bytes | bytearray | memoryview) -> None: + self.__data_to_send.append(memoryview(data)) + del data + return await self.__flush_data_to_send() + + @_utils.inherit_doc(AsyncStreamTransport) + async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None: + self.__data_to_send.extend(map(memoryview, iterable_of_data)) + del iterable_of_data + return await self.__flush_data_to_send() + + async def __flush_data_to_send(self) -> None: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used try: - await self._retry_ssl_method(self._ssl_object.write, data) + await self._retry_ssl_method(self.__write_all_to_ssl_object, self._ssl_object, self.__data_to_send) except _ssl_module.SSLZeroReturnError as exc: raise _utils.error_from_errno(errno.ECONNRESET) from exc + @staticmethod + def __write_all_to_ssl_object(ssl_object: SSLObject, write_backlog: deque[memoryview]) -> None: + while write_backlog: + data = write_backlog[0] + if data.itemsize != 1: + write_backlog[0] = data = data.cast("B") + sent = ssl_object.write(data) + if sent < len(data): + write_backlog[0] = data[sent:] + else: + del write_backlog[0] + @_utils.inherit_doc(AsyncStreamTransport) async def send_eof(self) -> None: raise UnsupportedOperation("SSL/TLS API does not support sending EOF.") diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py index 08852931..51293dbb 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py @@ -550,10 +550,10 @@ async def test____recv_into____unrelated_ssl_error( with pytest.raises(ssl.SSLError): _ = await tls_transport.recv_into(buffer) + @pytest.mark.usefixtures("mock_tls_transport_retry") async def test____send_all____default( self, tls_transport: AsyncTLSStreamTransport, - mock_tls_transport_retry: AsyncMock, mock_ssl_object: MagicMock, ) -> None: # Arrange @@ -563,13 +563,51 @@ async def test____send_all____default( await tls_transport.send_all(b"decrypted-data") # Assert - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.write, b"decrypted-data") mock_ssl_object.write.assert_called_once_with(b"decrypted-data") + @pytest.mark.usefixtures("mock_tls_transport_retry") + async def test____send_all____partial_data( + self, + tls_transport: AsyncTLSStreamTransport, + mock_ssl_object: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + data = b"decrypted-data" + mock_ssl_object.write.side_effect = [len(data) - 4, 4] + + # Act + await tls_transport.send_all(data) + + # Assert + assert mock_ssl_object.write.mock_calls == [ + mocker.call(b"decrypted-data"), + mocker.call(b"data"), + ] + + @pytest.mark.usefixtures("mock_tls_transport_retry") + async def test____send_all____properly_handle_views_with_different_size( + self, + tls_transport: AsyncTLSStreamTransport, + mock_ssl_object: MagicMock, + ) -> None: + # Arrange + import array + + data = array.array("I", [42, 56]) + + mock_ssl_object.write.side_effect = lambda data: memoryview(data).nbytes + + # Act + await tls_transport.send_all(memoryview(data)) + + # Assert + mock_ssl_object.write.assert_called_once_with(memoryview(data).cast("B")) + + @pytest.mark.usefixtures("mock_tls_transport_retry") async def test____send_all____null_buffer( self, tls_transport: AsyncTLSStreamTransport, - mock_tls_transport_retry: AsyncMock, mock_ssl_object: MagicMock, ) -> None: # Arrange @@ -579,7 +617,6 @@ async def test____send_all____null_buffer( await tls_transport.send_all(b"") # Assert - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.write, b"") mock_ssl_object.write.assert_called_once_with(b"") @pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}") @@ -624,6 +661,47 @@ async def test____send_all____unrelated_ssl_error( with pytest.raises(ssl.SSLError): await tls_transport.send_all(b"decrypted-data") + @pytest.mark.usefixtures("mock_tls_transport_retry") + async def test____send_all_from_iterable____default( + self, + tls_transport: AsyncTLSStreamTransport, + mock_ssl_object: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_ssl_object.write.side_effect = lambda data: memoryview(data).nbytes + + # Act + await tls_transport.send_all_from_iterable([b"decrypted-data-1", b"decrypted-data-2"]) + + # Assert + assert mock_ssl_object.write.mock_calls == [ + mocker.call(b"decrypted-data-1"), + mocker.call(b"decrypted-data-2"), + ] + + @pytest.mark.usefixtures("mock_tls_transport_retry") + async def test____send_all_from_iterable____partial_data( + self, + tls_transport: AsyncTLSStreamTransport, + mock_ssl_object: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + data_list = [b"decrypted-data-1", b"decrypted-data-2"] + mock_ssl_object.write.side_effect = [len(data_list[0]) - 4, 4, len(data_list[1]) - 6, 6] + + # Act + await tls_transport.send_all_from_iterable(data_list) + + # Assert + assert mock_ssl_object.write.mock_calls == [ + mocker.call(b"decrypted-data-1"), + mocker.call(b"ta-1"), + mocker.call(b"decrypted-data-2"), + mocker.call(b"data-2"), + ] + async def test____send_eof____default( self, tls_transport: AsyncTLSStreamTransport,