Skip to content

Commit

Permalink
Minor bug fix in servers (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Sep 26, 2023
1 parent c58962b commit 6bf1e24
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 84 deletions.
4 changes: 2 additions & 2 deletions src/easynetwork/api_async/server/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]

from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Protocol, Self
from typing import TYPE_CHECKING, Any, NoReturn, Protocol, Self

if TYPE_CHECKING:
from types import TracebackType
Expand Down Expand Up @@ -74,7 +74,7 @@ def is_serving(self) -> bool:
raise NotImplementedError

@abstractmethod
async def serve_forever(self, *, is_up_event: SupportsEventSet | None = ...) -> None:
async def serve_forever(self, *, is_up_event: SupportsEventSet | None = ...) -> NoReturn:
"""
Starts the server's main loop.
Expand Down
38 changes: 21 additions & 17 deletions src/easynetwork/api_async/server/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import weakref
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, final
from typing import TYPE_CHECKING, Any, Generic, NoReturn, final

from ..._typevars import _RequestT, _ResponseT
from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError
Expand Down Expand Up @@ -210,8 +210,8 @@ def _value_or_default(value: float | None, default: float) -> float:
self.__is_shutdown.set()
self.__shutdown_asked: bool = False
self.__max_recv_size: int = max_recv_size
self.__listener_tasks: deque[Task[None]] = deque()
self.__mainloop_task: Task[None] | None = None
self.__listener_tasks: deque[Task[NoReturn]] = deque() # type: ignore[assignment]
self.__mainloop_task: Task[NoReturn] | None = None
self.__logger: logging.Logger = logger or logging.getLogger(__name__)
self.__client_connection_log_level: int
if log_client_connection:
Expand Down Expand Up @@ -280,13 +280,8 @@ async def shutdown(self) -> None:

shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__

async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None:
async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> NoReturn:
async with _contextlib.AsyncExitStack() as server_exit_stack:
is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack())
if is_up_event is not None:
# Force is_up_event to be set, in order not to stuck the waiting task
is_up_callback.callback(is_up_event.set)

# Wake up server
if not self.__is_shutdown.is_set():
raise ServerAlreadyRunning("Server is already running")
Expand Down Expand Up @@ -338,8 +333,8 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
#################

# Server is up
is_up_callback.close()
del is_up_callback
if is_up_event is not None and not self.__shutdown_asked:
is_up_event.set()
##############

# Main loop
Expand All @@ -351,9 +346,11 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
finally:
self.__mainloop_task = None

raise AssertionError("sleep_forever() does not return")

serve_forever.__doc__ = AbstractAsyncNetworkServer.serve_forever.__doc__

async def __listener_accept(self, listener: AsyncListenerSocketAdapter, task_group: TaskGroup) -> None:
async def __listener_accept(self, listener: AsyncListenerSocketAdapter, task_group: TaskGroup) -> NoReturn:
backend = self.__backend
client_task = self.__client_coroutine
async with listener:
Expand All @@ -367,7 +364,7 @@ async def __listener_accept(self, listener: AsyncListenerSocketAdapter, task_gro
_errno.errorcode[exc.errno],
os.strerror(exc.errno),
ACCEPT_CAPACITY_ERROR_SLEEP_TIME,
exc_info=True,
exc_info=exc,
)
await backend.sleep(ACCEPT_CAPACITY_ERROR_SLEEP_TIME)
else:
Expand Down Expand Up @@ -431,7 +428,14 @@ async def __client_coroutine(self, accepted_socket: AcceptedSocket) -> None:
assert inspect.isawaitable(_on_connection_hook) # nosec assert_used
await _on_connection_hook
del _on_connection_hook
client_exit_stack.push_async_callback(self.__request_handler.on_disconnection, client)

async def disconnect_client() -> None:
try:
await self.__request_handler.on_disconnection(client)
except* ConnectionError:
self.__logger.warning("ConnectionError raised in request_handler.on_disconnection()")

client_exit_stack.push_async_callback(disconnect_client)

del client_exit_stack

Expand Down Expand Up @@ -507,7 +511,7 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress |
self.__logger.warning(
"There have been attempts to do operation on closed client %s",
client_address,
exc_info=True,
exc_info=excgrp,
)
except* ConnectionError:
# This exception come from the request handler ( most likely due to client.send_packet() )
Expand All @@ -518,9 +522,9 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress |
_remove_traceback_frames_in_place(exc, 1) # Removes the 'yield' frame just above
self.__logger.error("-" * 40)
if client_address is None:
self.__logger.exception("Error in client task")
self.__logger.error("Error in client task", exc_info=exc)
else:
self.__logger.exception("Exception occurred during processing of request from %s", client_address)
self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc)
self.__logger.error("-" * 40)

def get_addresses(self) -> Sequence[SocketAddress]:
Expand Down
23 changes: 10 additions & 13 deletions src/easynetwork/api_async/server/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import weakref
from collections import Counter, deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine, Iterator, Mapping
from typing import TYPE_CHECKING, Any, Generic, TypeVar, final
from typing import TYPE_CHECKING, Any, Generic, NoReturn, TypeVar, final
from weakref import WeakValueDictionary

from ..._typevars import _RequestT, _ResponseT
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
self.__is_shutdown.set()
self.__shutdown_asked: bool = False
self.__sendto_lock: ILock = backend.create_lock()
self.__mainloop_task: Task[None] | None = None
self.__mainloop_task: Task[NoReturn] | None = None
self.__logger: logging.Logger = logger or logging.getLogger(__name__)
self.__client_manager: _ClientAPIManager[_ResponseT] = _ClientAPIManager(
self.__backend,
Expand Down Expand Up @@ -184,13 +184,8 @@ async def shutdown(self) -> None:

shutdown.__doc__ = AbstractAsyncNetworkServer.shutdown.__doc__

async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> None:
async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> NoReturn:
async with _contextlib.AsyncExitStack() as server_exit_stack:
is_up_callback = server_exit_stack.enter_context(_contextlib.ExitStack())
if is_up_event is not None:
# Force is_up_event to be set, in order not to stuck the waiting task
is_up_callback.callback(is_up_event.set)

# Wake up server
if not self.__is_shutdown.is_set():
raise ServerAlreadyRunning("Server is already running")
Expand Down Expand Up @@ -237,8 +232,8 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
#################

# Server is up
is_up_callback.close()
del is_up_callback
if is_up_event is not None and not self.__shutdown_asked:
is_up_event.set()
##############

# Main loop
Expand All @@ -250,13 +245,15 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
finally:
self.__mainloop_task = None

raise AssertionError("received_datagrams() does not return")

serve_forever.__doc__ = AbstractAsyncNetworkServer.serve_forever.__doc__

async def __receive_datagrams(
self,
socket: AsyncDatagramSocketAdapter,
task_group: TaskGroup,
) -> None:
) -> NoReturn:
backend = self.__backend
socket_family: int = socket.socket().family
datagram_received_task_method = self.__datagram_received_coroutine
Expand Down Expand Up @@ -374,12 +371,12 @@ def __suppress_and_log_remaining_exception(self, client_address: SocketAddress)
self.__logger.warning(
"There have been attempts to do operation on closed client %s",
client_address,
exc_info=True,
exc_info=excgrp,
)
except Exception as exc:
_remove_traceback_frames_in_place(exc, 1) # Removes the 'yield' frame just above
self.__logger.error("-" * 40)
self.__logger.exception("Exception occurred during processing of request from %s", client_address)
self.__logger.error("Exception occurred during processing of request from %s", client_address, exc_info=exc)
self.__logger.error("-" * 40)

@_contextlib.contextmanager
Expand Down
24 changes: 9 additions & 15 deletions src/easynetwork/api_sync/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import threading as _threading
import time
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, NoReturn

from ...api_async.backend.abc import ThreadsPortal
from ...api_async.server.abc import SupportsEventSet
Expand Down Expand Up @@ -78,7 +78,7 @@ def server_close(self) -> None:

def shutdown(self, timeout: float | None = None) -> None:
if (portal := self._portal) is not None:
try:
with _contextlib.suppress(RuntimeError, concurrent.futures.CancelledError):
# If shutdown() have been cancelled, that means the scheduler itself is shutting down, and this is what we want
if timeout is None:
portal.run_coroutine(self.__server.shutdown)
Expand All @@ -88,8 +88,6 @@ def shutdown(self, timeout: float | None = None) -> None:
portal.run_coroutine(self.__do_shutdown_with_timeout, timeout)
finally:
timeout -= time.perf_counter() - _start
except (RuntimeError, concurrent.futures.CancelledError):
pass
self.__is_shutdown.wait(timeout)

shutdown.__doc__ = AbstractNetworkServer.shutdown.__doc__
Expand Down Expand Up @@ -119,10 +117,6 @@ def serve_forever(

backend = self.__server.get_backend()
with _contextlib.ExitStack() as server_exit_stack, _contextlib.suppress(backend.get_cancelled_exc_class()):
if is_up_event is not None:
# Force is_up_event to be set, in order not to stuck the waiting thread
server_exit_stack.callback(is_up_event.set)

# locks_stack is used to acquire locks until
# serve_forever() coroutine creates the thread portal
locks_stack = server_exit_stack.enter_context(_contextlib.ExitStack())
Expand All @@ -138,16 +132,16 @@ def serve_forever(
self.__is_shutdown.clear()
server_exit_stack.callback(self.__is_shutdown.set)

async def serve_forever() -> None:
def reset_threads_portal() -> None:
self.__threads_portal = None
def reset_threads_portal() -> None:
self.__threads_portal = None

def acquire_bootstrap_lock() -> None:
locks_stack.enter_context(self.__bootstrap_lock.get())
def acquire_bootstrap_lock() -> None:
locks_stack.enter_context(self.__bootstrap_lock.get())

server_exit_stack.callback(reset_threads_portal)
server_exit_stack.callback(acquire_bootstrap_lock)
server_exit_stack.callback(reset_threads_portal)
server_exit_stack.callback(acquire_bootstrap_lock)

async def serve_forever() -> NoReturn:
async with backend.create_threads_portal() as self.__threads_portal:
# Initialization finished; release the locks
locks_stack.close()
Expand Down
24 changes: 8 additions & 16 deletions src/easynetwork/api_sync/server/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,23 @@ def __init__(
daemon: bool | None = None,
) -> None:
super().__init__(group=group, target=None, name=name, daemon=daemon)
self.__server: AbstractNetworkServer | None = server
self.__server: AbstractNetworkServer = server
self.__is_up_event: _threading.Event = _threading.Event()

def start(self) -> None:
super().start()
self.__is_up_event.wait()

def run(self) -> None:
assert self.__server is not None, f"{self.__server=}" # nosec assert_used
try:
return self.__server.serve_forever(is_up_event=self.__is_up_event)
self.__server.serve_forever(is_up_event=self.__is_up_event)
finally:
self.__server = None
self.__is_up_event.set()

def join(self, timeout: float | None = None) -> None:
server = self.__server
if server is not None:
_start = time.perf_counter()
try:
server.shutdown(timeout=timeout)
finally:
_end = time.perf_counter()
if timeout is not None:
timeout -= _end - _start
super().join(timeout=timeout)
else:
super().join(timeout=timeout)
_start = time.perf_counter()
self.__server.shutdown(timeout=timeout)
_end = time.perf_counter()
if timeout is not None:
timeout -= _end - _start
super().join(timeout=timeout)
18 changes: 12 additions & 6 deletions src/easynetwork_asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import enum
import math
from collections import deque
from collections.abc import Callable, Coroutine, Iterable
from collections.abc import Callable, Coroutine, Iterable, Iterator
from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, Self, TypeVar, final
from weakref import WeakKeyDictionary

Expand Down Expand Up @@ -225,13 +225,13 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException
del self.__delayed_task_cancel_dict[host_task]
delayed_task_cancel.handle.cancel()

current_task_scope = self._current_task_scope(host_task)
if current_task_scope is None:
for cancel_scope in self._inner_to_outer_task_scopes(host_task):
if cancel_scope.__cancel_called:
self._reschedule_delayed_task_cancel(host_task, cancel_scope.__cancellation_id())
break
else:
if task_cancelling > 0:
self._reschedule_delayed_task_cancel(host_task, None)
else:
if current_task_scope.__cancel_called:
self._reschedule_delayed_task_cancel(host_task, current_task_scope.__cancellation_id())

return self.__cancelled_caught

Expand Down Expand Up @@ -290,6 +290,12 @@ def _current_task_scope(cls, task: asyncio.Task[Any]) -> CancelScope | None:
except LookupError:
return None

@classmethod
def _inner_to_outer_task_scopes(cls, task: asyncio.Task[Any]) -> Iterator[CancelScope]:
if cls._current_task_scope(task) is None:
return iter(())
return iter(cls.__current_task_scope_dict[task])

@classmethod
def _reschedule_delayed_task_cancel(cls, task: asyncio.Task[Any], cancel_msg: str | None) -> asyncio.Handle:
if task in cls.__delayed_task_cancel_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1139,7 +1139,7 @@ async def coroutine(value: int) -> int:
await cancel_shielded_coroutine()
checkpoints.append("cancel_shielded_coroutine")

with backend.timeout(0):
with backend.timeout(0), backend.open_cancel_scope():
await cancel_shielded_coroutine()
checkpoints.append("inner_cancel_shielded_coroutine")
assert current_task.cancelling() == 2
Expand Down Expand Up @@ -1222,8 +1222,9 @@ async def coroutine() -> None:
outer_scope = backend.move_on_after(1.5)
inner_scope = backend.move_on_after(0.5)
with outer_scope:
with inner_scope:
await backend.ignore_cancellation(backend.sleep(1))
with backend.open_cancel_scope(), backend.open_cancel_scope():
with inner_scope:
await backend.ignore_cancellation(backend.sleep(1))
assert not inner_scope.cancelled_caught()
try:
await backend.coro_yield()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def test____serve_forever____shutdown_during_setup(
await asyncio.sleep(0)
assert not event.is_set()
await server.shutdown()
assert event.is_set()
assert not event.is_set()

async def test____serve_forever____server_close_during_setup(
self,
Expand Down
Loading

0 comments on commit 6bf1e24

Please sign in to comment.