diff --git a/src/easynetwork/lowlevel/api_async/transports/tls.py b/src/easynetwork/lowlevel/api_async/transports/tls.py index da23c757..a2c35dd4 100644 --- a/src/easynetwork/lowlevel/api_async/transports/tls.py +++ b/src/easynetwork/lowlevel/api_async/transports/tls.py @@ -41,7 +41,7 @@ from ....exceptions import UnsupportedOperation from ... import _utils, constants, socket as socket_tools -from ..backend.abc import AsyncBackend, TaskGroup +from ..backend.abc import AsyncBackend, IEvent, TaskGroup from .abc import AsyncListener, AsyncStreamReadTransport, AsyncStreamTransport from .utils import aclose_forcefully @@ -69,9 +69,11 @@ class AsyncTLSStreamTransport(AsyncStreamTransport): _data_deque: deque[memoryview] = dataclasses.field(init=False, default_factory=deque) __incoming_reader: _IncomingDataReader = dataclasses.field(init=False) __closing: bool = dataclasses.field(init=False, default=False) + __closed: IEvent = dataclasses.field(init=False) def __post_init__(self) -> None: self.__incoming_reader = _IncomingDataReader(transport=self._transport) + self.__closed = self._transport.backend().create_event() @classmethod async def wrap( @@ -160,13 +162,16 @@ def is_closing(self) -> bool: async def aclose(self) -> None: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used - already_closing = self.__closing + if self.__closing: + await self.__closed.wait() + return with contextlib.ExitStack() as stack: + stack.callback(self.__closed.set) stack.callback(self.__incoming_reader.close) stack.callback(self._data_deque.clear) self.__closing = True - if not already_closing and self._standard_compatible and not self._transport.is_closing(): + if self._standard_compatible and not self._transport.is_closing(): with self._transport.backend().move_on_after(self._shutdown_timeout) as shutdown_timeout_scope: try: try: diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py index 327a0d41..1c079b6a 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_tls.py @@ -96,6 +96,7 @@ async def tls_transport( _read_bio=read_bio, _write_bio=write_bio, ) + mock_wrapped_transport.reset_mock() async with transport: yield transport @@ -205,7 +206,7 @@ async def test____wrap____with_parameters( ) mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake) assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()] - assert mock_wrapped_transport.mock_calls == [mocker.call.backend()] + assert mock_wrapped_transport.mock_calls == [mocker.call.backend(), mocker.call.backend()] ## Attributes assert tls_transport._shutdown_timeout == shutdown_timeout assert tls_transport.extra(mocker.sentinel.attr_1) is mocker.sentinel.value_1 @@ -380,7 +381,7 @@ async def test____aclose____idempotent( mock_ssl_object.unwrap.assert_not_called() assert not read_bio.eof assert not write_bio.eof - assert mock_wrapped_transport.aclose.await_count == 2 + mock_wrapped_transport.aclose.assert_awaited_once() @pytest.mark.parametrize("standard_compatible", [True], indirect=True, ids=lambda p: f"standard_compatible=={p}") @pytest.mark.parametrize("shutdown_timeout", [1], indirect=True, ids=lambda p: f"shutdown_timeout=={p}")