diff --git a/src/easynetwork/lowlevel/api_async/transports/tls.py b/src/easynetwork/lowlevel/api_async/transports/tls.py index e28fccf3..ce09b9d0 100644 --- a/src/easynetwork/lowlevel/api_async/transports/tls.py +++ b/src/easynetwork/lowlevel/api_async/transports/tls.py @@ -213,15 +213,17 @@ async def recv_into(self, buffer: WriteableBuffer) -> int: @_utils.inherit_doc(AsyncStreamTransport) async def send_all(self, data: bytes | bytearray | memoryview) -> None: - if not self.__closing: - self.__data_to_send.append(memoryview(data)) + if self.__closing: + raise _utils.error_from_errno(errno.ECONNABORTED) + self.__data_to_send.append(memoryview(data)) del data return await self.__flush_data_to_send() @_utils.inherit_doc(AsyncStreamTransport) async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None: - if not self.__closing: - self.__data_to_send.extend(map(memoryview, iterable_of_data)) + if self.__closing: + raise _utils.error_from_errno(errno.ECONNABORTED) + self.__data_to_send.extend(map(memoryview, iterable_of_data)) del iterable_of_data return await self.__flush_data_to_send() @@ -254,8 +256,6 @@ async def _retry_ssl_method( *args: *_T_PosArgs, ) -> _T_Return: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used - if self.__closing: - raise _utils.error_from_errno(errno.ECONNABORTED) while True: try: result = ssl_object_method(*args) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py index 476bf362..8fb3fa78 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py @@ -102,8 +102,12 @@ async def tls_transport( @pytest.fixture @staticmethod def mock_tls_transport_retry(mocker: MockerFixture) -> AsyncMock: - side_effect: Callable[..., Any] = lambda ssl_object_method, *args: ssl_object_method(*args) - return mocker.patch.object(AsyncTLSStreamTransport, "_retry_ssl_method", side_effect=side_effect) + return mocker.patch.object( + AsyncTLSStreamTransport, + "_retry_ssl_method", + autospec=True, + wraps=AsyncTLSStreamTransport._retry_ssl_method, + ) async def test____wrap____default( self, @@ -135,7 +139,7 @@ async def test____wrap____default( server_hostname="server_hostname", session=None, ) - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.do_handshake) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake) assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()] ## Attributes assert tls_transport._shutdown_timeout == DEFAULT_SSL_SHUTDOWN_TIMEOUT @@ -199,7 +203,7 @@ async def test____wrap____with_parameters( server_hostname=server_hostname, session=session, ) - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.do_handshake) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake) assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()] assert mock_wrapped_transport.mock_calls == [mocker.call.backend()] ## Attributes @@ -257,7 +261,7 @@ async def test____wrap____handshake_timeout( mock_ssl_object: MagicMock, ) -> None: # Arrange - async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any: + async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any: await asyncio.sleep(5) return ssl_object_method(*args) @@ -325,7 +329,6 @@ async def test____dunder_del____ResourceWarning( async def test____aclose____close_transport( self, tls_transport: AsyncTLSStreamTransport, - mock_tls_transport_retry: AsyncMock, mock_wrapped_transport: MagicMock, mock_ssl_object: MagicMock, standard_compatible: bool, @@ -341,22 +344,19 @@ async def test____aclose____close_transport( # Assert assert tls_transport.is_closing() if standard_compatible: - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap) mock_ssl_object.unwrap.assert_called_once_with() assert read_bio.eof assert write_bio.eof else: - mock_tls_transport_retry.assert_not_called() mock_ssl_object.unwrap.assert_not_called() assert not read_bio.eof assert not write_bio.eof mock_wrapped_transport.aclose.assert_awaited_once_with() @pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}") - async def test____aclose____indempotent( + async def test____aclose____idempotent( self, tls_transport: AsyncTLSStreamTransport, - mock_tls_transport_retry: AsyncMock, mock_wrapped_transport: MagicMock, mock_ssl_object: MagicMock, standard_compatible: bool, @@ -373,12 +373,10 @@ async def test____aclose____indempotent( # Assert if standard_compatible: - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap) mock_ssl_object.unwrap.assert_called_once_with() assert read_bio.eof assert write_bio.eof else: - mock_tls_transport_retry.assert_not_called() mock_ssl_object.unwrap.assert_not_called() assert not read_bio.eof assert not write_bio.eof @@ -396,7 +394,7 @@ async def test____aclose____shutdown_timeout( write_bio: ssl.MemoryBIO, ) -> None: # Arrange - async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any: + async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any: await asyncio.sleep(5) return ssl_object_method(*args) @@ -406,7 +404,7 @@ async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) - await tls_transport.aclose() # Assert - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.unwrap) assert not read_bio.eof assert not write_bio.eof mock_wrapped_transport.aclose.assert_awaited_once_with() @@ -422,7 +420,7 @@ async def test____aclose____mask_unwrap_error( write_bio: ssl.MemoryBIO, ) -> None: # Arrange - async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) -> Any: + async def retry_side_effect(self: AsyncTLSStreamTransport, ssl_object_method: Callable[..., Any], *args: Any) -> Any: try: return ssl_object_method(*args) except ssl.SSLError: @@ -437,7 +435,7 @@ async def retry_side_effect(ssl_object_method: Callable[..., Any], *args: Any) - await tls_transport.aclose() # Assert - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.unwrap) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.unwrap) assert read_bio.eof assert write_bio.eof mock_wrapped_transport.aclose.assert_awaited_once_with() @@ -456,7 +454,7 @@ async def test____recv____default( # Assert assert data == b"decrypted-data" - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 123456) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 123456) mock_ssl_object.read.assert_called_once_with(123456) async def test____recv____null_buffer( @@ -473,7 +471,7 @@ async def test____recv____null_buffer( # Assert assert data == b"" - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 0) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 0) mock_ssl_object.read.assert_called_once_with(0) @pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}") @@ -540,7 +538,7 @@ async def test____recv_into____default( # Assert assert nbytes == 42 - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 1234, buffer) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 1234, buffer) mock_ssl_object.read.assert_called_once_with(1234, buffer) async def test____recv_into____null_buffer( @@ -558,7 +556,7 @@ async def test____recv_into____null_buffer( # Assert assert nbytes == 0 - mock_tls_transport_retry.assert_awaited_once_with(mock_ssl_object.read, 1024, buffer) + mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.read, 1024, buffer) mock_ssl_object.read.assert_called_once_with(1024, buffer) @pytest.mark.parametrize("standard_compatible", [False, True], indirect=True, ids=lambda p: f"standard_compatible=={p}") @@ -728,17 +726,20 @@ async def test____send_all____unrelated_ssl_error( async def test____send_all____closed_transport( self, tls_transport: AsyncTLSStreamTransport, + mock_tls_transport_retry: AsyncMock, mock_ssl_object: MagicMock, ) -> None: # Arrange mock_ssl_object.unwrap.return_value = None await tls_transport.aclose() + mock_tls_transport_retry.reset_mock() # Act & Assert with pytest.raises(ConnectionAbortedError): await tls_transport.send_all(b"decrypted-data") mock_ssl_object.write.assert_not_called() + mock_tls_transport_retry.assert_not_awaited() assert tls_transport._test__data_queue() == [] @pytest.mark.usefixtures("mock_tls_transport_retry") @@ -786,17 +787,20 @@ async def test____send_all_from_iterable____partial_data( async def test____send_all_from_iterable____closed_transport( self, tls_transport: AsyncTLSStreamTransport, + mock_tls_transport_retry: AsyncMock, mock_ssl_object: MagicMock, ) -> None: # Arrange mock_ssl_object.unwrap.return_value = None await tls_transport.aclose() + mock_tls_transport_retry.reset_mock() # Act & Assert with pytest.raises(ConnectionAbortedError): await tls_transport.send_all_from_iterable([b"decrypted-data-1", b"decrypted-data-2"]) mock_ssl_object.write.assert_not_called() + mock_tls_transport_retry.assert_not_awaited() assert tls_transport._test__data_queue() == [] async def test____send_eof____default( @@ -992,24 +996,6 @@ async def test____retry____unrelated_ssl_error( assert read_bio.eof assert write_bio.eof - @pytest.mark.parametrize("standard_compatible", [False], ids=lambda p: f"standard_compatible=={p}") - async def test____retry____closed_transport( - self, - tls_transport: AsyncTLSStreamTransport, - mock_wrapped_transport: MagicMock, - mocker: MockerFixture, - ) -> None: - # Arrange - mock_wrapped_transport.send_all.return_value = None - ssl_object_method = mocker.stub() - await tls_transport.aclose() - - # Act & Assert - with pytest.raises(ConnectionAbortedError): - await tls_transport._retry_ssl_method(ssl_object_method) - - ssl_object_method.assert_not_called() - async def test____get_backend____returns_inner_transport_backend( self, tls_transport: AsyncTLSStreamTransport,