Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS transport: Fixed SSL object consistency errors when sending data and closing the object #309

Merged
merged 7 commits into from
Jun 27, 2024
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/api_async/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ async def aclose(self) -> None:
"""
Closes the endpoint.
"""
await self.__transport.aclose()
self.__receiver.clear()
with self.__send_guard:
await self.__transport.aclose()
self.__receiver.clear()

async def send_packet(self, packet: _T_SentPacket) -> None:
"""
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ async def aclose(self) -> None:
"""
Closes the endpoint.
"""
await self.__transport.aclose()
await self.__exit_stack.aclose()
with self.__send_guard:
await self.__transport.aclose()
await self.__exit_stack.aclose()

async def send_packet(self, packet: _T_Response) -> None:
"""
Expand Down
45 changes: 41 additions & 4 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import functools
import logging
import warnings
from collections.abc import Callable, Coroutine, Mapping
from collections import deque
from collections.abc import Callable, Coroutine, Iterable, Mapping
from typing import TYPE_CHECKING, Any, Final, NoReturn, Self, TypeVar, TypeVarTuple

try:
Expand Down Expand Up @@ -62,6 +63,7 @@ class AsyncTLSStreamTransport(AsyncStreamTransport):
_ssl_object: SSLObject
_read_bio: MemoryBIO
_write_bio: MemoryBIO
_data_deque: deque[memoryview] = dataclasses.field(init=False, default_factory=deque)
__incoming_reader: _IncomingDataReader = dataclasses.field(init=False)
__closing: bool = dataclasses.field(init=False, default=False)

Expand Down Expand Up @@ -132,6 +134,7 @@ async def wrap(

_ = ssl_object.getpeercert()
except BaseException:
self.__closing = True
await aclose_forcefully(transport)
raise
return self
Expand All @@ -152,14 +155,21 @@ def is_closing(self) -> bool:

@_utils.inherit_doc(AsyncStreamTransport)
async def aclose(self) -> None:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used

already_closing = self.__closing
with contextlib.ExitStack() as stack:
stack.callback(self.__incoming_reader.close)
stack.callback(self._data_deque.clear)

self.__closing = True
if self._standard_compatible:
if not already_closing and self._standard_compatible and not self._transport.is_closing():
with self._transport.backend().move_on_after(self._shutdown_timeout) as shutdown_timeout_scope:
try:
await self._retry_ssl_method(self._ssl_object.unwrap)
try:
await self._retry_ssl_method(self._ssl_object.unwrap)
except OSError:
pass
self._read_bio.write_eof()
self._write_bio.write_eof()
except BaseException:
Expand Down Expand Up @@ -203,12 +213,39 @@ async def recv_into(self, buffer: WriteableBuffer) -> int:

@_utils.inherit_doc(AsyncStreamTransport)
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self._data_deque.append(memoryview(data))
del data
return await self.__flush_data_to_send()

@_utils.inherit_doc(AsyncStreamTransport)
async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
if self.__closing:
raise _utils.error_from_errno(errno.ECONNABORTED)
self._data_deque.extend(map(memoryview, iterable_of_data))
del iterable_of_data
return await self.__flush_data_to_send()

async def __flush_data_to_send(self) -> None:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used
try:
await self._retry_ssl_method(self._ssl_object.write, data)
await self._retry_ssl_method(self.__write_all_to_ssl_object, self._ssl_object, self._data_deque)
except _ssl_module.SSLZeroReturnError as exc:
raise _utils.error_from_errno(errno.ECONNRESET) from exc

@staticmethod
def __write_all_to_ssl_object(ssl_object: SSLObject, write_backlog: deque[memoryview]) -> None:
while write_backlog:
data = write_backlog[0]
if data.itemsize != 1:
write_backlog[0] = data = data.cast("B")
sent = ssl_object.write(data)
if sent < len(data):
write_backlog[0] = data[sent:]
else:
del write_backlog[0]

@_utils.inherit_doc(AsyncStreamTransport)
async def send_eof(self) -> None:
raise UnsupportedOperation("SSL/TLS API does not support sending EOF.")
Expand Down
11 changes: 10 additions & 1 deletion src/easynetwork/servers/async_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ISocket,
SocketAddress,
SocketProxy,
TLSAttribute,
enable_socket_linger,
new_socket_address,
set_tcp_keepalive,
Expand Down Expand Up @@ -405,7 +406,15 @@ async def __client_initializer(
client_exit_stack.enter_context(self.__suppress_and_log_remaining_exception(client_address=client_address))
# If the socket was not closed gracefully, (i.e. client.aclose() failed )
# tell the OS to immediately abort the connection when calling socket.socket.close()
client_exit_stack.callback(self.__set_socket_linger_if_not_closed, lowlevel_client.extra(INETSocketAttribute.socket))
# NOTE: Do not set this option if SSL/TLS is enabled
if lowlevel_client.extra(TLSAttribute.sslcontext, None) is None:
client_exit_stack.callback(
self.__set_socket_linger_if_not_closed,
lowlevel_client.extra(INETSocketAttribute.socket),
)
elif lowlevel_client.extra(TLSAttribute.standard_compatible, False):
# We expect a TLS close handshake, so we must (try to) properly close the transport before
await client_exit_stack.enter_async_context(contextlib.aclosing(lowlevel_client))

logger: logging.Logger = self.__logger
client = _ConnectedClientAPI(client_address, lowlevel_client)
Expand Down
Loading