Skip to content

Commit

Permalink
Endpoints: Forbid closing transport while sending data
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Jun 27, 2024
1 parent b30d27b commit 7c5f2f6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/api_async/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ async def aclose(self) -> None:
"""
Closes the endpoint.
"""
await self.__transport.aclose()
self.__receiver.clear()
with self.__send_guard:
await self.__transport.aclose()
self.__receiver.clear()

async def send_packet(self, packet: _T_SentPacket) -> None:
"""
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ async def aclose(self) -> None:
"""
Closes the endpoint.
"""
await self.__transport.aclose()
await self.__exit_stack.aclose()
with self.__send_guard:
await self.__transport.aclose()
await self.__exit_stack.aclose()

async def send_packet(self, packet: _T_Response) -> None:
"""
Expand Down
15 changes: 5 additions & 10 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ class AsyncTLSStreamTransport(AsyncStreamTransport):
_ssl_object: SSLObject
_read_bio: MemoryBIO
_write_bio: MemoryBIO
_data_deque: deque[memoryview] = dataclasses.field(init=False, default_factory=deque)
__incoming_reader: _IncomingDataReader = dataclasses.field(init=False)
__closing: bool = dataclasses.field(init=False, default=False)
__data_to_send: deque[memoryview] = dataclasses.field(init=False, default_factory=deque)

def __post_init__(self) -> None:
self.__incoming_reader = _IncomingDataReader(transport=self._transport)
Expand Down Expand Up @@ -160,7 +160,7 @@ async def aclose(self) -> None:
already_closing = self.__closing
with contextlib.ExitStack() as stack:
stack.callback(self.__incoming_reader.close)
stack.callback(self.__data_to_send.clear)
stack.callback(self._data_deque.clear)

self.__closing = True
if not already_closing and self._standard_compatible and not self._transport.is_closing():
Expand Down Expand Up @@ -215,22 +215,22 @@ async def recv_into(self, buffer: WriteableBuffer) -> int:
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self.__data_to_send.append(memoryview(data))
self._data_deque.append(memoryview(data))
del data
return await self.__flush_data_to_send()

@_utils.inherit_doc(AsyncStreamTransport)
async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self.__data_to_send.extend(map(memoryview, iterable_of_data))
self._data_deque.extend(map(memoryview, iterable_of_data))
del iterable_of_data
return await self.__flush_data_to_send()

async def __flush_data_to_send(self) -> None:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used
try:
await self._retry_ssl_method(self.__write_all_to_ssl_object, self._ssl_object, self.__data_to_send)
await self._retry_ssl_method(self.__write_all_to_ssl_object, self._ssl_object, self._data_deque)
except _ssl_module.SSLZeroReturnError as exc:
raise _utils.error_from_errno(errno.ECONNRESET) from exc

Expand Down Expand Up @@ -292,11 +292,6 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket_tools.TLSAttribute.standard_compatible: lambda: self._standard_compatible,
}

if __debug__:

def _test__data_queue(self) -> list[bytes]:
return list(map(bytes, self.__data_to_send))


class AsyncTLSListener(AsyncListener[AsyncTLSStreamTransport]):
__slots__ = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ async def test____send_all____closed_transport(

mock_ssl_object.write.assert_not_called()
mock_tls_transport_retry.assert_not_awaited()
assert tls_transport._test__data_queue() == []
assert len(tls_transport._data_deque) == 0

@pytest.mark.usefixtures("mock_tls_transport_retry")
async def test____send_all_from_iterable____default(
Expand Down Expand Up @@ -801,7 +801,7 @@ async def test____send_all_from_iterable____closed_transport(

mock_ssl_object.write.assert_not_called()
mock_tls_transport_retry.assert_not_awaited()
assert tls_transport._test__data_queue() == []
assert len(tls_transport._data_deque) == 0

async def test____send_eof____default(
self,
Expand Down

0 comments on commit 7c5f2f6

Please sign in to comment.