Skip to content

Commit

Permalink
Fixed bug in DatagramServer where a datagram could never be freed aft…
Browse files Browse the repository at this point in the history
…er usage
  • Loading branch information
francis-clairicia committed Jun 23, 2024
1 parent 22beb8e commit df84bc1
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 92 deletions.
141 changes: 65 additions & 76 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import weakref
from collections import deque
from collections.abc import AsyncGenerator, Callable, Hashable, Mapping
from contextlib import AsyncExitStack, ExitStack
from contextlib import AsyncExitStack
from typing import Any, Generic, NoReturn, TypeVar

from ...._typevars import _T_Request, _T_Response
Expand Down Expand Up @@ -189,11 +189,12 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None:
else:
new_client_task = client.data.state is None

await client.data.push_datagram(datagram)

if new_client_task:
del datagram
client.data.mark_pending()
await self.__client_coroutine(datagram_received_cb, datagram, client, task_group, default_context)
else:
await client.data.push_datagram(datagram)
await self.__client_coroutine(datagram_received_cb, client, task_group, default_context)

await listener.serve(handler, task_group)

Expand All @@ -204,65 +205,67 @@ async def __client_coroutine(
datagram_received_cb: Callable[
[DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request]
],
datagram: bytes,
client: _ClientToken[_T_Response, _T_Address],
task_group: TaskGroup,
default_context: contextvars.Context,
) -> None:
client_data = client.data
async with client_data.task_lock:
with ExitStack() as exit_stack:
#####################################################################################################
# CRITICAL SECTION
# This block must not have any asynchronous function calls
# or add any asynchronous callbacks/contexts to the exit stack.
client_data.mark_running()
exit_stack.callback(
self.__on_task_done,
async with client.data.task_lock:
client.data.mark_running()
try:
await self.__client_coroutine_inner_loop(
request_handler_generator=datagram_received_cb(client.ctx),
client_data=client.data,
)
finally:
self.__on_task_done(
datagram_received_cb=datagram_received_cb,
client=client,
task_group=task_group,
default_context=default_context,
)
#####################################################################################################

request_handler_generator = datagram_received_cb(client.ctx)
timeout: float | None

async def __client_coroutine_inner_loop(
self,
*,
request_handler_generator: AsyncGenerator[float | None, _T_Request],
client_data: _ClientData,
) -> None:
timeout: float | None
datagram: bytes = client_data.pop_datagram_no_wait()
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
except StopAsyncIteration:
return
else:
action: AsyncGenAction[_T_Request] | None
action = self.__parse_datagram(datagram, self.__protocol)
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
return
finally:
action = None

del datagram
null_timeout_ctx = contextlib.nullcontext()
while True:
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
except StopAsyncIteration:
return
else:
action: AsyncGenAction[_T_Request] | None
with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout):
datagram = await client_data.pop_datagram()
action = self.__parse_datagram(datagram, self.__protocol)
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
return
finally:
action = None

del datagram
null_timeout_ctx = contextlib.nullcontext()
while True:
try:
with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout):
datagram = await client_data.pop_datagram()
action = self.__parse_datagram(datagram, self.__protocol)
except BaseException as exc:
action = ThrowAction(exc)
finally:
datagram = b""
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
break
finally:
action = None
except BaseException as exc:
action = ThrowAction(exc)
finally:
datagram = b""
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
break
finally:
await request_handler_generator.aclose()
action = None
finally:
await request_handler_generator.aclose()

def __on_task_done(
self,
Expand All @@ -274,17 +277,14 @@ def __on_task_done(
default_context: contextvars.Context,
) -> None:
client.data.mark_done()
try:
pending_datagram = client.data.pop_datagram_no_wait()
except IndexError:
if client.data.queue_is_empty():
return

client.data.mark_pending()
default_context.run(
task_group.start_soon,
self.__client_coroutine,
datagram_received_cb,
pending_datagram,
client,
task_group,
default_context,
Expand Down Expand Up @@ -349,8 +349,8 @@ def __init__(self, backend: AsyncBackend) -> None:
self.__backend: AsyncBackend = backend
self.__task_lock: ILock = backend.create_lock()
self.__state: _ClientState | None = None
self._queue_condition: ICondition | None = None
self._datagram_queue: deque[bytes] | None = None
self._queue_condition: ICondition = backend.create_condition_var()
self._datagram_queue: deque[bytes] = deque()

@property
def backend(self) -> AsyncBackend:
Expand All @@ -364,21 +364,20 @@ def task_lock(self) -> ILock:
def state(self) -> _ClientState | None:
return self.__state

def queue_is_empty(self) -> bool:
return not self._datagram_queue

async def push_datagram(self, datagram: bytes) -> None:
self.__ensure_queue().append(datagram)
if (queue_condition := self._queue_condition) is not None:
async with queue_condition:
queue_condition.notify()
self._datagram_queue.append(datagram)
async with (queue_condition := self._queue_condition):
queue_condition.notify()

def pop_datagram_no_wait(self) -> bytes:
if not (queue := self._datagram_queue):
raise IndexError("pop from an empty deque")
return queue.popleft()
return self._datagram_queue.popleft()

async def pop_datagram(self) -> bytes:
queue_condition = self.__ensure_queue_condition_var()
async with queue_condition:
queue = self.__ensure_queue()
async with (queue_condition := self._queue_condition):
queue = self._datagram_queue
while not queue:
await queue_condition.wait()
return queue.popleft()
Expand All @@ -398,16 +397,6 @@ def mark_running(self) -> None:
self.handle_inconsistent_state_error()
self.__state = _ClientState.TASK_RUNNING

def __ensure_queue(self) -> deque[bytes]:
if (queue := self._datagram_queue) is None:
self._datagram_queue = queue = deque()
return queue

def __ensure_queue_condition_var(self) -> ICondition:
if (cond := self._queue_condition) is None:
self._queue_condition = cond = self.__backend.create_condition_var()
return cond

@staticmethod
def handle_inconsistent_state_error() -> NoReturn:
msg = "The server has created too many tasks and ends up in an inconsistent state."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,6 @@ def client_data(mock_backend: MagicMock) -> _ClientData:
def get_client_state(client_data: _ClientData) -> _ClientState | None:
return client_data.state

@staticmethod
def get_client_queue(client_data: _ClientData) -> deque[bytes] | None:
return client_data._datagram_queue

def test____dunder_init____default(
self,
client_data: _ClientData,
Expand All @@ -219,8 +215,6 @@ def test____dunder_init____default(
# Act & Assert
assert isinstance(client_data.task_lock, asyncio.Lock)
assert client_data.state is None
assert client_data._datagram_queue is None
assert client_data._queue_condition is None

def test____client_state____regular_state_transition(
self,
Expand Down Expand Up @@ -279,16 +273,13 @@ async def test____datagram_queue____push_datagram(
client_data: _ClientData,
) -> None:
# Arrange
assert self.get_client_queue(client_data) is None

# Act
await client_data.push_datagram(b"datagram_1")
await client_data.push_datagram(b"datagram_2")
await client_data.push_datagram(b"datagram_3")

# Assert
assert client_data._datagram_queue is not None
assert client_data._queue_condition is None
assert list(client_data._datagram_queue) == [b"datagram_1", b"datagram_2", b"datagram_3"]

@pytest.mark.asyncio
Expand All @@ -313,19 +304,12 @@ async def test____datagram_queue____pop_datagram(

# Assert
assert len(client_data._datagram_queue) == 0
if no_wait:
assert client_data._queue_condition is None
else:
assert client_data._queue_condition is not None

@pytest.mark.parametrize("queue", [deque(), None], ids=lambda p: f"queue=={p!r}")
def test____datagram_queue____pop_datagram_no_wait____empty_list(
self,
queue: deque[bytes] | None,
client_data: _ClientData,
) -> None:
# Arrange
client_data._datagram_queue = queue

# Act & Assert
with pytest.raises(IndexError):
Expand Down

0 comments on commit df84bc1

Please sign in to comment.