Skip to content

Commit

Permalink
Async TLS: Ensure wrapped transport is closed once (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Jul 28, 2024
1 parent c560002 commit 9d4bf79
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
11 changes: 8 additions & 3 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from ....exceptions import UnsupportedOperation
from ... import _utils, constants, socket as socket_tools
from ..backend.abc import AsyncBackend, TaskGroup
from ..backend.abc import AsyncBackend, IEvent, TaskGroup
from .abc import AsyncListener, AsyncStreamReadTransport, AsyncStreamTransport
from .utils import aclose_forcefully

Expand Down Expand Up @@ -69,9 +69,11 @@ class AsyncTLSStreamTransport(AsyncStreamTransport):
_data_deque: deque[memoryview] = dataclasses.field(init=False, default_factory=deque)
__incoming_reader: _IncomingDataReader = dataclasses.field(init=False)
__closing: bool = dataclasses.field(init=False, default=False)
__closed: IEvent = dataclasses.field(init=False)

def __post_init__(self) -> None:
self.__incoming_reader = _IncomingDataReader(transport=self._transport)
self.__closed = self._transport.backend().create_event()

@classmethod
async def wrap(
Expand Down Expand Up @@ -160,13 +162,16 @@ def is_closing(self) -> bool:
async def aclose(self) -> None:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used

already_closing = self.__closing
if self.__closing:
await self.__closed.wait()
return
with contextlib.ExitStack() as stack:
stack.callback(self.__closed.set)
stack.callback(self.__incoming_reader.close)
stack.callback(self._data_deque.clear)

self.__closing = True
if not already_closing and self._standard_compatible and not self._transport.is_closing():
if self._standard_compatible and not self._transport.is_closing():
with self._transport.backend().move_on_after(self._shutdown_timeout) as shutdown_timeout_scope:
try:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def tls_transport(
_read_bio=read_bio,
_write_bio=write_bio,
)
mock_wrapped_transport.reset_mock()
async with transport:
yield transport

Expand Down Expand Up @@ -205,7 +206,7 @@ async def test____wrap____with_parameters(
)
mock_tls_transport_retry.assert_awaited_once_with(tls_transport, mock_ssl_object.do_handshake)
assert mock_ssl_object.mock_calls == [mocker.call.do_handshake(), mocker.call.getpeercert()]
assert mock_wrapped_transport.mock_calls == [mocker.call.backend()]
assert mock_wrapped_transport.mock_calls == [mocker.call.backend(), mocker.call.backend()]
## Attributes
assert tls_transport._shutdown_timeout == shutdown_timeout
assert tls_transport.extra(mocker.sentinel.attr_1) is mocker.sentinel.value_1
Expand Down Expand Up @@ -380,7 +381,7 @@ async def test____aclose____idempotent(
mock_ssl_object.unwrap.assert_not_called()
assert not read_bio.eof
assert not write_bio.eof
assert mock_wrapped_transport.aclose.await_count == 2
mock_wrapped_transport.aclose.assert_awaited_once()

@pytest.mark.parametrize("standard_compatible", [True], indirect=True, ids=lambda p: f"standard_compatible=={p}")
@pytest.mark.parametrize("shutdown_timeout", [1], indirect=True, ids=lambda p: f"shutdown_timeout=={p}")
Expand Down

0 comments on commit 9d4bf79

Please sign in to comment.