Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Packet send exceptions are now converted to RuntimeErrors #187

Merged
merged 3 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/easynetwork/api_async/server/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {}
server = server_ref
del server_ref
return dict(server.extra_attributes) | {
return {
**server.extra_attributes,
INETClientAttribute.socket: lambda: self.__socket_proxy,
INETClientAttribute.local_address: lambda: new_socket_address(
server.extra(INETSocketAttribute.sockname),
Expand Down
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def __next__(self) -> bytes:
chunk = next(filter(None, map(bytes, generator)))
except StopIteration:
pass
except Exception as exc:
raise RuntimeError("protocol.generate_chunks() crashed") from exc
else:
self.__g = generator
return chunk
Expand Down
9 changes: 8 additions & 1 deletion src/easynetwork/lowlevel/api_async/endpoints/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ async def send_packet(self, packet: _SentPacketT) -> None:
if not self.__supports_write(transport):
raise UnsupportedOperation("transport does not support sending data")

await transport.send(protocol.make_datagram(packet))
try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet

await transport.send(datagram)

async def recv_packet(self) -> _ReceivedPacketT:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ async def send_packet_to(self, packet: _ResponseT, address: _T_Address) -> None:

try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet
try:
Expand Down
9 changes: 8 additions & 1 deletion src/easynetwork/lowlevel/api_sync/endpoints/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ def send_packet(self, packet: _SentPacketT, *, timeout: float | None = None) ->
if not self.__supports_write(transport):
raise UnsupportedOperation("transport does not support sending data")

transport.send(protocol.make_datagram(packet), timeout)
try:
datagram: bytes = protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
finally:
del packet

transport.send(datagram, timeout)

def recv_packet(self, *, timeout: float | None = None) -> _ReceivedPacketT:
"""
Expand Down
24 changes: 11 additions & 13 deletions src/easynetwork/lowlevel/std_asyncio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,22 @@ async def aclose(self) -> None:
await asyncio.sleep(0)

async def accept(self) -> _socket.socket:
listener_socket = self.__check_not_closed()
with self.__conflict_detection("accept"):
listener_socket = self.__check_not_closed()
client_socket, _ = await self.__loop.sock_accept(listener_socket)
return client_socket

async def sendall(self, data: ReadableBuffer, /) -> None:
socket = self.__check_not_closed()
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendall(socket, data)

async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
raise UnsupportedOperation("sendmsg() is not supported")
socket = self.__check_not_closed()
if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket):
raise UnsupportedOperation("sendmsg() is not supported")

with self.__conflict_detection("send"):
loop = self.__loop
buffers = cast("deque[memoryview]", deque(memoryview(data).cast("B") for data in buffers))

Expand All @@ -145,18 +145,18 @@ async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None:
_utils.adjust_leftover_buffer(buffers, sent)

async def sendto(self, data: ReadableBuffer, address: _socket._Address, /) -> None:
socket = self.__check_not_closed()
with self.__conflict_detection("send"):
socket = self.__check_not_closed()
await self.__loop.sock_sendto(socket, data, address)

async def recv(self, bufsize: int, /) -> bytes:
socket = self.__check_not_closed()
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv(socket, bufsize)

async def recv_into(self, buffer: WriteableBuffer, /) -> int:
socket = self.__check_not_closed()
with self.__conflict_detection("recv"):
socket = self.__check_not_closed()
return await self.__loop.sock_recv_into(socket, buffer)

async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
Expand All @@ -165,8 +165,7 @@ async def recvfrom(self, bufsize: int, /) -> tuple[bytes, _socket._RetAddress]:
return await self.__loop.sock_recvfrom(socket, bufsize)

async def shutdown(self, how: int, /) -> None:
# Checks if we are within the bound loop
TaskUtils.check_current_asyncio_task(self.__loop)
TaskUtils.check_current_event_loop(self.__loop)

if how in {_socket.SHUT_RDWR, _socket.SHUT_WR}:
while (waiter := self.__waiters.get("send")) is not None:
Expand All @@ -184,8 +183,7 @@ def __conflict_detection(self, task_id: _SocketTaskId) -> Iterator[None]:
if task_id in self.__waiters:
raise _utils.error_from_errno(_errno.EBUSY)

# Checks if we are within the bound loop
TaskUtils.check_current_asyncio_task(self.__loop)
TaskUtils.check_current_event_loop(self.__loop)

with contextlib.ExitStack() as stack, CancelScope() as scope:
self.__scopes.add(scope)
Expand Down
5 changes: 3 additions & 2 deletions src/easynetwork/lowlevel/std_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ async def send_eof(self) -> None:

async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
try:
await self.__socket.sendmsg(iterable_of_data)
return await self.__socket.sendmsg(iterable_of_data)
except UnsupportedOperation:
await super().send_all_from_iterable(iterable_of_data)
pass
return await super().send_all_from_iterable(iterable_of_data)

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
Expand Down
6 changes: 4 additions & 2 deletions src/easynetwork/lowlevel/std_asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,10 @@ def __cancel_task_unless_done(task: asyncio.Task[Any], cancel_msg: str | None) -
@final
class TaskUtils:
@staticmethod
def check_current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> None:
_ = TaskUtils.current_asyncio_task(loop=loop)
def check_current_event_loop(loop: asyncio.AbstractEventLoop) -> None:
running_loop = asyncio.get_running_loop()
if running_loop is not loop:
raise RuntimeError(f"{running_loop=!r} is not {loop=!r}")

@staticmethod
def current_asyncio_task(loop: asyncio.AbstractEventLoop | None = None) -> asyncio.Task[Any]:
Expand Down
12 changes: 11 additions & 1 deletion tests/functional_test/test_communication/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import pytest
import trustme

from .serializer import BufferedStringSerializer, NotGoodBufferedStringSerializer, NotGoodStringSerializer, StringSerializer
from .serializer import (
BadSerializeStringSerializer,
BufferedStringSerializer,
NotGoodBufferedStringSerializer,
NotGoodStringSerializer,
StringSerializer,
)

_FAMILY_TO_LOCALHOST: dict[int, str] = {
AF_INET: "127.0.0.1",
Expand Down Expand Up @@ -68,6 +74,8 @@ def one_shot_serializer(request: pytest.FixtureRequest) -> StringSerializer:
return StringSerializer()
case "invalid":
return NotGoodStringSerializer()
case "bad_serialize":
return BadSerializeStringSerializer()
case _:
pytest.fail("Invalid parameter")

Expand All @@ -83,6 +91,8 @@ def incremental_serializer(request: pytest.FixtureRequest) -> StringSerializer:
return NotGoodStringSerializer()
case "invalid_buffered":
return NotGoodBufferedStringSerializer()
case "bad_serialize":
return BadSerializeStringSerializer()
case _:
pytest.fail("Invalid parameter")

Expand Down
11 changes: 11 additions & 0 deletions tests/functional_test/test_communication/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,14 @@ class NotGoodBufferedStringSerializer(NotGoodStringSerializer, BufferedStringSer
def buffered_incremental_deserialize(self, write_buffer: bytearray) -> Generator[int, int, tuple[str, bytearray]]:
yield 0
raise SystemError("CRASH")


class BadSerializeStringSerializer(StringSerializer):
__slots__ = ()

def serialize(self, packet: str) -> bytes:
raise SystemError("CRASH")

def incremental_serialize(self, packet: str) -> Generator[bytes, None, None]:
raise SystemError("CRASH")
yield b"chunk" # type: ignore[unreachable]
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ async def client(
assert client.is_connected()
yield client

@pytest.fixture
@staticmethod
def is_buffered_protocol(stream_protocol: StreamProtocol[str, str]) -> bool:
try:
stream_protocol.buffered_receiver()
except NotImplementedError:
return False
return True

async def test____aclose____idempotent(self, client: AsyncTCPNetworkClient[str, str]) -> None:
assert not client.is_closing()
await client.aclose()
Expand All @@ -80,9 +89,9 @@ async def test____send_packet____connection_error____fresh_connection_closed_by_
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")
server.close()
with pytest.raises(ConnectionAbortedError):
for _ in range(3): # Windows and macOS catch the issue after several send()
Expand All @@ -97,9 +106,9 @@ async def test____send_packet____connection_error____after_previous_successful_t
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")

await client.send_packet("ABCDEF")
assert await readline(event_loop, server) == b"ABCDEF\n"
Expand All @@ -117,9 +126,9 @@ async def test____send_packet____connection_error____partial_read_then_close(
server: Socket,
) -> None:
if use_asyncio_transport:
pytest.skip("It is not mandadtory for asyncio.Transport implementations to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for asyncio.Transport implementations to raise ConnectionAbortedError")
if is_uvloop_event_loop(event_loop):
pytest.skip("It is not mandadtory for uvloop to raise ConnectionAbortedError")
pytest.skip("It is not mandatory for uvloop to raise ConnectionAbortedError")
await client.send_packet("ABC")
assert await event_loop.sock_recv(server, 1) == b"A"
server.close()
Expand All @@ -133,6 +142,14 @@ async def test____send_packet____closed_client(self, client: AsyncTCPNetworkClie
with pytest.raises(ClientClosedError):
await client.send_packet("ABCDEF")

@pytest.mark.parametrize("incremental_serializer", [pytest.param("bad_serialize", id="serializer_crash")], indirect=True)
async def test____send_packet____protocol_crashed(
self,
client: AsyncTCPNetworkClient[str, str],
) -> None:
with pytest.raises(RuntimeError, match=r"^protocol\.generate_chunks\(\) crashed$"):
await client.send_packet("ABCDEF")

async def test____send_eof____close_write_stream(
self,
event_loop: asyncio.AbstractEventLoop,
Expand Down Expand Up @@ -240,6 +257,31 @@ async def test____recv_packet____invalid_data(
await client.recv_packet()
assert await client.recv_packet() == "valid"

@pytest.mark.parametrize(
"incremental_serializer",
[
pytest.param("invalid", id="serializer_crash"),
pytest.param("invalid_buffered", id="buffered_serializer_crash"),
],
indirect=True,
)
async def test____recv_packet____protocol_crashed(
self,
event_loop: asyncio.AbstractEventLoop,
use_asyncio_transport: bool,
is_buffered_protocol: bool,
client: AsyncTCPNetworkClient[str, str],
server: Socket,
) -> None:
expected_pattern: str
if is_buffered_protocol and not use_asyncio_transport:
expected_pattern = r"^protocol\.build_packet_from_buffer\(\) crashed$"
else:
expected_pattern = r"^protocol\.build_packet_from_chunks\(\) crashed$"
await event_loop.sock_sendall(server, b"ABCDEF\n")
with pytest.raises(RuntimeError, match=expected_pattern):
await client.recv_packet()

async def test____iter_received_packets____yields_available_packets_until_eof(
self,
event_loop: asyncio.AbstractEventLoop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ async def test____send_packet____closed_client(self, client: AsyncUDPNetworkClie
with pytest.raises(ClientClosedError):
await client.send_packet("ABCDEF")

@pytest.mark.parametrize("one_shot_serializer", [pytest.param("bad_serialize", id="serializer_crash")], indirect=True)
async def test____send_packet____protocol_crashed(
self,
client: AsyncUDPNetworkClient[str, str],
) -> None:
with pytest.raises(RuntimeError, match=r"^protocol\.make_datagram\(\) crashed$"):
await client.send_packet("ABCDEF")

@use_asyncio_transport_xfail_uvloop
async def test____recv_packet____default(self, client: AsyncUDPNetworkClient[str, str], server: DatagramEndpoint) -> None:
await server.sendto(b"ABCDEF", client.get_local_address())
Expand Down Expand Up @@ -176,6 +184,22 @@ async def test____recv_packet____invalid_data(
async with asyncio.timeout(3):
await client.recv_packet()

@pytest.mark.parametrize("one_shot_serializer", [pytest.param("invalid", id="serializer_crash")], indirect=True)
@use_asyncio_transport_xfail_uvloop
async def test____recv_packet____protocol_crashed(
self,
client: AsyncUDPNetworkClient[str, str],
server: DatagramEndpoint,
) -> None:
await server.sendto(b"ABCDEF", client.get_local_address())
try:
await client.recv_packet()
except NotImplementedError:
raise
except Exception:
with pytest.raises(RuntimeError, match=r"^protocol\.build_packet_from_datagram\(\) crashed$"):
raise

@use_asyncio_transport_xfail_uvloop
async def test____iter_received_packets____yields_available_packets_until_close(
self,
Expand Down
Loading