Skip to content

Commit

Permalink
[FIX] Serializers and converters exceptions are now converted to Runt…
Browse files Browse the repository at this point in the history
…imeErrors
  • Loading branch information
francis-clairicia committed Dec 7, 2023
1 parent 3600654 commit 8ddf388
Show file tree
Hide file tree
Showing 20 changed files with 353 additions and 33 deletions.
3 changes: 2 additions & 1 deletion src/easynetwork/api_async/server/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {}
server = server_ref
del server_ref
return dict(server.extra_attributes) | {
return {
**server.extra_attributes,
INETClientAttribute.socket: lambda: self.__socket_proxy,
INETClientAttribute.local_address: lambda: new_socket_address(
server.extra(INETSocketAttribute.sockname),
Expand Down
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __next__(self) -> bytes:
chunk = next(filter(None, map(bytes, generator)))
except StopIteration:
pass
except Exception as exc:
raise RuntimeError("protocol.generate_chunks() crashed") from exc
else:
self.__g = generator
return chunk
Expand Down
9 changes: 8 additions & 1 deletion src/easynetwork/lowlevel/api_async/endpoints/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ async def send_packet(self, packet: _SentPacketT) -> None:
if not self.__supports_write(transport):
raise UnsupportedOperation("transport does not support sending data")

await transport.send(protocol.make_datagram(packet))
try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet

await transport.send(datagram)

async def recv_packet(self) -> _ReceivedPacketT:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ async def send_packet_to(self, packet: _ResponseT, address: _T_Address) -> None:

try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet
try:
Expand Down
9 changes: 8 additions & 1 deletion src/easynetwork/lowlevel/api_sync/endpoints/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ def send_packet(self, packet: _SentPacketT, *, timeout: float | None = None) ->
if not self.__supports_write(transport):
raise UnsupportedOperation("transport does not support sending data")

transport.send(protocol.make_datagram(packet), timeout)
try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet

transport.send(datagram, timeout)

def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT:
"""
Expand Down
24 changes: 11 additions & 13 deletions src/easynetwork/lowlevel/std_asyncio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,22 @@ async def aclose(self) -> None:
await asyncio.sleep(0)

async def accept(self) -> _socket.socket:
listener_socket = self.__check_not_closed()
with self.__conflict_detection("accept"):
listener_socket = self.__check_not_closed()
client_socket, _ = await self.__loop.sock_accept(listener_socket)
return client_socket

async def sendall(self, data: ReadableBuffer, /) -> None:
socket = self.__check_not_closed()
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendall(socket, data)

async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
raise UnsupportedOperation("sendmsg() is not supported")
socket = self.__check_not_closed()
if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
raise UnsupportedOperation("sendmsg() is not supported")

with self.__conflict_detection("send"):
loop = self.__loop
buffers = cast("deque[memoryview]", deque(memoryview(data).cast("B") for data in buffers))

Expand All @@ -145,18 +145,18 @@ async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
_utils.adjust_leftover_buffer(buffers, sent)

async def sendto(self, data: ReadableBuffer, address: _socket._Address, /) -> None:
socket = self.__check_not_closed()
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendto(socket, data, address)

async def recv(self, bufsize: int, /) -> bytes:
socket = self.__check_not_closed()
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv(socket, bufsize)

async def recv_into(self, buffer: WriteableBuffer, /) -> int:
socket = self.__check_not_closed()
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv_into(socket, buffer)

async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
Expand All @@ -165,8 +165,7 @@ async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
return await self.__loop.sock_recvfrom(socket, bufsize)

async def shutdown(self, how: int, /) -> None:
# Checks if we are within the bound loop
TaskUtils.check_current_asyncio_task(self.__loop)
TaskUtils.check_current_event_loop(self.__loop)

if how in {_socket.SHUT_RDWR, _socket.SHUT_WR}:
while (waiter := self.__waiters.get("send")) is not None:
Expand All @@ -184,8 +183,7 @@ def __conflict_detection(self, task_id: _SocketTaskId) -> Iterator[None]:
if task_id in self.__waiters:
raise _utils.error_from_errno(_errno.EBUSY)

# Checks if we are within the bound loop
TaskUtils.check_current_asyncio_task(self.__loop)
TaskUtils.check_current_event_loop(self.__loop)

with contextlib.ExitStack() as stack, CancelScope() as scope:
self.__scopes.add(scope)
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/std_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ async def send_eof(self) -> None:

async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
try:
await self.__socket.sendmsg(iterable_of_data)
return await self.__socket.sendmsg(iterable_of_data)
except UnsupportedOperation:
await super().send_all_from_iterable(iterable_of_data)
pass
return await super().send_all_from_iterable(iterable_of_data)

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
Expand Down
6 changes: 4 additions & 2 deletions src/easynetwork/lowlevel/std_asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,10 @@ def __cancel_task_unless_done(task: asyncio.Task[Any], cancel_msg: str | None) -
@final
class TaskUtils:
@staticmethod
def check_current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> None:
_ = TaskUtils.current_asyncio_task(loop=loop)
def check_current_event_loop(loop: asyncio.AbstractEventLoop) -> None:
running_loop = asyncio.get_running_loop()
if running_loop is not loop:
raise RuntimeError(f"{running_loop=!r} is not {loop=!r}")

@staticmethod
def current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Task[Any]:
Expand Down
12 changes: 11 additions & 1 deletion tests/functional_test/test_communication/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import pytest
import trustme

from .serializer import BufferedStringSerializer, NotGoodBufferedStringSerializer, NotGoodStringSerializer, StringSerializer
from .serializer import (
BadSerializeStringSerializer,
BufferedStringSerializer,
NotGoodBufferedStringSerializer,
NotGoodStringSerializer,
StringSerializer,
)

_FAMILY_TO_LOCALHOST: dict[int, str] = {
AF_INET: "127.0.0.1",
Expand Down Expand Up @@ -68,6 +74,8 @@ def one_shot_serializer(request: pytest.FixtureRequest) -> StringSerializer:
return StringSerializer()
case "invalid":
return NotGoodStringSerializer()
case "bad_serialize":
return BadSerializeStringSerializer()
case _:
pytest.fail("Invalid parameter")

Expand All @@ -83,6 +91,8 @@ def incremental_serializer(request: pytest.FixtureRequest) -> StringSerializer:
return NotGoodStringSerializer()
case "invalid_buffered":
return NotGoodBufferedStringSerializer()
case "bad_serialize":
return BadSerializeStringSerializer()
case _:
pytest.fail("Invalid parameter")

Expand Down
11 changes: 11 additions & 0 deletions tests/functional_test/test_communication/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,14 @@ class NotGoodBufferedStringSerializer(NotGoodStringSerializer, BufferedStringSer
def buffered_incremental_deserialize(self, write_buffer: bytearray) -> Generator[int, int, tuple[str, bytearray]]:
yield 0
raise SystemError("CRASH")


class BadSerializeStringSerializer(StringSerializer):
__slots__ = ()

def serialize(self, packet: str) -> bytes:
raise SystemError("CRASH")

def incremental_serialize(self, packet: str) -> Generator[bytes, None, None]:
raise SystemError("CRASH")
yield b"chunk" # type: ignore[unreachable]
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ async def client(
assert client.is_connected()
yield client

@pytest.fixture
@staticmethod
def is_buffered_protocol(stream_protocol: StreamProtocol[str, str]) -> bool:
try:
stream_protocol.buffered_receiver()
except NotImplementedError:
return False
return True

async def test____aclose____idempotent(self, client: AsyncTCPNetworkClient[str, str]) -> None:
assert not client.is_closing()
await client.aclose()
Expand All @@ -80,9 +89,9 @@ async def test____send_packet____connection_error____fresh_connection_closed_by_
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")
server.close()
with pytest.raises(ConnectionAbortedError):
for _ in range(3): # Windows and macOS catch the issue after several send()
Expand All @@ -97,9 +106,9 @@ async def test____send_packet____connection_error____after_previous_successful_t
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")

await client.send_packet("ABCDEF")
assert await readline(event_loop, server) == b"ABCDEF\n"
Expand All @@ -117,9 +126,9 @@ async def test____send_packet____connection_error____partial_read_then_close(
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")
await client.send_packet("ABC")
assert await event_loop.sock_recv(server, 1) == b"A"
server.close()
Expand All @@ -133,6 +142,14 @@ async def test____send_packet____closed_client(self, client: AsyncTCPNetworkClie
with pytest.raises(ClientClosedError):
await client.send_packet("ABCDEF")

@pytest.mark.parametrize("incremental_serializer", [pytest.param("bad_serialize", id="serializer_crash")], indirect=True)
async def test____send_packet____protocol_crashed(
self,
client: AsyncTCPNetworkClient[str, str],
) -> None:
with pytest.raises(RuntimeError, match=r"^protocol\.generate_chunks\(\) crashed$"):
await client.send_packet("ABCDEF")

async def test____send_eof____close_write_stream(
self,
event_loop: asyncio.AbstractEventLoop,
Expand Down Expand Up @@ -240,6 +257,31 @@ async def test____recv_packet____invalid_data(
await client.recv_packet()
assert await client.recv_packet() == "valid"

@pytest.mark.parametrize(
"incremental_serializer",
[
pytest.param("invalid", id="serializer_crash"),
pytest.param("invalid_buffered", id="buffered_serializer_crash"),
],
indirect=True,
)
async def test____recv_packet____protocol_crashed(
self,
event_loop: asyncio.AbstractEventLoop,
use_asyncio_transport: bool,
is_buffered_protocol: bool,
client: AsyncTCPNetworkClient[str, str],
server: Socket,
) -> None:
expected_pattern: str
if is_buffered_protocol and not use_asyncio_transport:
expected_pattern = r"^protocol\.build_packet_from_buffer\(\) crashed$"
else:
expected_pattern = r"^protocol\.build_packet_from_chunks\(\) crashed$"
await event_loop.sock_sendall(server, b"ABCDEF\n")
with pytest.raises(RuntimeError, match=expected_pattern):
await client.recv_packet()

async def test____iter_received_packets____yields_available_packets_until_eof(
self,
event_loop: asyncio.AbstractEventLoop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ async def test____send_packet____closed_client(self, client: AsyncUDPNetworkClie
with pytest.raises(ClientClosedError):
await client.send_packet("ABCDEF")

@pytest.mark.parametrize("one_shot_serializer", [pytest.param("bad_serialize", id="serializer_crash")], indirect=True)
async def test____send_packet____protocol_crashed(
self,
client: AsyncUDPNetworkClient[str, str],
) -> None:
with pytest.raises(RuntimeError, match=r"^protocol\.make_datagram\(\) crashed$"):
await client.send_packet("ABCDEF")

@use_asyncio_transport_xfail_uvloop
async def test____recv_packet____default(self, client: AsyncUDPNetworkClient[str, str], server: DatagramEndpoint) -> None:
await server.sendto(b"ABCDEF", client.get_local_address())
Expand Down Expand Up @@ -176,6 +184,16 @@ async def test____recv_packet____invalid_data(
async with asyncio.timeout(3):
await client.recv_packet()

@pytest.mark.parametrize("one_shot_serializer", [pytest.param("invalid", id="serializer_crash")], indirect=True)
async def test____recv_packet____protocol_crashed(
self,
client: AsyncUDPNetworkClient[str, str],
server: DatagramEndpoint,
) -> None:
await server.sendto(b"ABCDEF", client.get_local_address())
with pytest.raises(RuntimeError, match=r"^protocol\.build_packet_from_datagram\(\) crashed$"):
await client.recv_packet()

@use_asyncio_transport_xfail_uvloop
async def test____iter_received_packets____yields_available_packets_until_close(
self,
Expand Down
Loading

0 comments on commit 8ddf388

Please sign in to comment.