Skip to content

Commit

Permalink
Standalone servers: Improved robustness (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Apr 20, 2024
1 parent dff87e6 commit 4ab44c0
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 82 deletions.
13 changes: 10 additions & 3 deletions docs/source/_extensions/sphinx_easynetwork.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""
Changelog:
v0.1.0: Initial
v0.1.1 (current): Fix base is not replaced if the class is generic.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, get_origin

if TYPE_CHECKING:
from sphinx.application import Sphinx
Expand All @@ -12,7 +19,7 @@
def _replace_base_in_place(klass: type, bases: list[type], base_to_replace: type, base_to_set_instead: type) -> None:
if issubclass(klass, base_to_replace):
for index, base in enumerate(bases):
if base is base_to_replace:
if get_origin(base) is base_to_replace:
bases[index] = base_to_set_instead


Expand All @@ -25,7 +32,7 @@ def setup(app: Sphinx) -> dict[str, Any]:
app.connect("autodoc-process-bases", autodoc_process_bases)

return {
"version": "0.1",
"version": "0.1.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}
110 changes: 58 additions & 52 deletions src/easynetwork/servers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
#
#
"""Asynchronous network server module"""
"""Generic network servers module"""

from __future__ import annotations

Expand All @@ -22,7 +22,7 @@
import contextlib
import threading as _threading
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, Generic, TypeVar

from ..exceptions import ServerAlreadyRunning, ServerClosedError
from ..lowlevel import _utils
Expand All @@ -31,24 +31,28 @@
from ..lowlevel.socket import SocketAddress
from .abc import AbstractAsyncNetworkServer, AbstractNetworkServer, SupportsEventSet

_T_Return = TypeVar("_T_Return")
_T_Default = TypeVar("_T_Default")
_T_AsyncServer = TypeVar("_T_AsyncServer", bound=AbstractAsyncNetworkServer)

class BaseStandaloneNetworkServerImpl(AbstractNetworkServer):

class BaseStandaloneNetworkServerImpl(AbstractNetworkServer, Generic[_T_AsyncServer]):
__slots__ = (
"__server_factory",
"__default_runner_options",
"__private_server",
"__server",
"__backend",
"__close_lock",
"__bootstrap_lock",
"__private_threads_portal",
"__threads_portal",
"__is_shutdown",
"__is_closed",
)

def __init__(
self,
backend: AsyncBackend | None,
server_factory: Callable[[AsyncBackend], AbstractAsyncNetworkServer],
server_factory: Callable[[AsyncBackend], _T_AsyncServer],
*,
runner_options: Mapping[str, Any] | None = None,
) -> None:
Expand All @@ -64,22 +68,37 @@ def __init__(
raise TypeError(f"Expected an AsyncBackend instance, got {backend!r}")

self.__backend: AsyncBackend = backend
self.__server_factory: Callable[[AsyncBackend], AbstractAsyncNetworkServer] = server_factory
self.__private_server: AbstractAsyncNetworkServer | None = None
self.__private_threads_portal: ThreadsPortal | None = None
self.__server_factory: Callable[[AsyncBackend], _T_AsyncServer] = server_factory
self.__server: _T_AsyncServer | None = None
self.__threads_portal: ThreadsPortal | None = None
self.__is_shutdown = _threading.Event()
self.__is_shutdown.set()
self.__is_closed = _threading.Event()
self.__close_lock = ForkSafeLock()
self.__bootstrap_lock = ForkSafeLock()
self.__default_runner_options: dict[str, Any] = dict(runner_options) if runner_options else {}

def _run_sync_or_else(
self,
f: Callable[[ThreadsPortal, _T_AsyncServer], _T_Return],
default: Callable[[], _T_Default],
) -> _T_Return | _T_Default:
with self.__bootstrap_lock.get():
if (portal := self.__threads_portal) is not None and (server := self.__server) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError):
return f(portal, server)
return default()

def _run_sync_or(
self,
f: Callable[[ThreadsPortal, _T_AsyncServer], _T_Return],
default: _T_Default,
) -> _T_Return | _T_Default:
return self._run_sync_or_else(f, lambda: default)

@_utils.inherit_doc(AbstractNetworkServer)
def is_serving(self) -> bool:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
return portal.run_sync(server.is_serving)
return False
return self._run_sync_or(lambda portal, server: portal.run_sync(server.is_serving), False)

@_utils.inherit_doc(AbstractNetworkServer)
def server_close(self) -> None:
Expand All @@ -89,28 +108,29 @@ def server_close(self) -> None:
# Ensure we are not in the interval between the server shutdown and the scheduler shutdown
stack.callback(self.__is_shutdown.wait)

if (server := self._server) is not None and (portal := self._portal) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError):
portal.run_coroutine(server.server_close)
self._run_sync_or(lambda portal, server: portal.run_coroutine(server.server_close), None)

@_utils.inherit_doc(AbstractNetworkServer)
def shutdown(self, timeout: float | None = None) -> None:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError), _utils.ElapsedTime() as elapsed:
# If shutdown() have been cancelled, that means the scheduler itself is shutting down, and this is what we want
if timeout is None:
portal.run_coroutine(server.shutdown)
else:
portal.run_coroutine(self.__do_shutdown_with_timeout, server, timeout)
if timeout is not None:
timeout = elapsed.recompute_timeout(timeout)
with self.__bootstrap_lock.get():
if (portal := self.__threads_portal) is not None and (server := self.__server) is not None:

async def do_shutdown_with_timeout(server: AbstractAsyncNetworkServer, timeout: float) -> None:
with server.backend().move_on_after(timeout):
await server.shutdown()

with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError), _utils.ElapsedTime() as elapsed:
# If shutdown() have been cancelled, that means the scheduler itself is shutting down,
# and this is what we want
if timeout is None:
portal.run_coroutine(server.shutdown)
else:

portal.run_coroutine(do_shutdown_with_timeout, server, timeout)
if timeout is not None:
timeout = elapsed.recompute_timeout(timeout)
self.__is_shutdown.wait(timeout)

@staticmethod
async def __do_shutdown_with_timeout(server: AbstractAsyncNetworkServer, timeout_delay: float) -> None:
with server.backend().move_on_after(timeout_delay):
await server.shutdown()

def serve_forever(
self,
*,
Expand Down Expand Up @@ -153,41 +173,27 @@ def serve_forever(
server_exit_stack.callback(self.__is_shutdown.set)

def reset_values() -> None:
self.__private_threads_portal = None
self.__private_server = None
self.__threads_portal = None
self.__server = None

def acquire_bootstrap_lock() -> None:
def reacquire_bootstrap_lock_on_shutdown() -> None:
locks_stack.enter_context(self.__bootstrap_lock.get())

server_exit_stack.callback(reset_values)
server_exit_stack.callback(acquire_bootstrap_lock)
server_exit_stack.callback(reacquire_bootstrap_lock_on_shutdown)

async def serve_forever() -> None:
async with (
self.__server_factory(backend) as self.__private_server,
backend.create_threads_portal() as self.__private_threads_portal,
self.__server_factory(backend) as self.__server,
backend.create_threads_portal() as self.__threads_portal,
):
server = self.__private_server
# Initialization finished; release the locks
locks_stack.close()

await server.serve_forever(is_up_event=is_up_event)
await self.__server.serve_forever(is_up_event=is_up_event)

backend.bootstrap(serve_forever, runner_options=runner_options)

@_utils.inherit_doc(AbstractNetworkServer)
def get_addresses(self) -> Sequence[SocketAddress]:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
return portal.run_sync(server.get_addresses)
return ()

@property
def _server(self) -> AbstractAsyncNetworkServer | None:
with self.__bootstrap_lock.get():
return self.__private_server

@property
def _portal(self) -> ThreadsPortal | None:
with self.__bootstrap_lock.get():
return self.__private_threads_portal
return self._run_sync_or(lambda portal, server: portal.run_sync(server.get_addresses), ())
26 changes: 11 additions & 15 deletions src/easynetwork/servers/standalone_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"StandaloneTCPNetworkServer",
]

import contextlib
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Literal

Expand All @@ -38,7 +37,10 @@
from .handlers import AsyncStreamRequestHandler


class StandaloneTCPNetworkServer(_base.BaseStandaloneNetworkServerImpl, Generic[_T_Request, _T_Response]):
class StandaloneTCPNetworkServer(
_base.BaseStandaloneNetworkServerImpl[AsyncTCPNetworkServer[_T_Request, _T_Response]],
Generic[_T_Request, _T_Response],
):
"""
A network server for TCP connections.
Expand Down Expand Up @@ -107,9 +109,7 @@ def stop_listening(self) -> None:
Further calls to :meth:`is_serving` will return :data:`False`.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
portal.run_sync(server.stop_listening)
self._run_sync_or(lambda portal, server: portal.run_sync(server.stop_listening), None)

def get_sockets(self) -> Sequence[SocketProxy]:
"""Gets the listeners sockets. Thread-safe.
Expand All @@ -119,13 +119,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:
If the server is not running, an empty sequence is returned.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
sockets = portal.run_sync(server.get_sockets)
return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets)
return ()

if TYPE_CHECKING:

@property
def _server(self) -> AsyncTCPNetworkServer[_T_Request, _T_Response] | None: ...
return self._run_sync_or(
lambda portal, server: tuple(
SocketProxy(sock, runner=portal.run_sync) for sock in portal.run_sync(server.get_sockets)
),
(),
)
22 changes: 10 additions & 12 deletions src/easynetwork/servers/standalone_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"StandaloneUDPNetworkServer",
]

import contextlib
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic

Expand All @@ -37,7 +36,10 @@
from .handlers import AsyncDatagramRequestHandler


class StandaloneUDPNetworkServer(_base.BaseStandaloneNetworkServerImpl, Generic[_T_Request, _T_Response]):
class StandaloneUDPNetworkServer(
_base.BaseStandaloneNetworkServerImpl[AsyncUDPNetworkServer[_T_Request, _T_Response]],
Generic[_T_Request, _T_Response],
):
"""
A network server for UDP communication.
Expand Down Expand Up @@ -89,13 +91,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:
If the server is not running, an empty sequence is returned.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
sockets = portal.run_sync(server.get_sockets)
return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets)
return ()

if TYPE_CHECKING:

@property
def _server(self) -> AsyncUDPNetworkServer[_T_Request, _T_Response] | None: ...
return self._run_sync_or(
lambda portal, server: tuple(
SocketProxy(sock, runner=portal.run_sync) for sock in portal.run_sync(server.get_sockets)
),
(),
)

0 comments on commit 4ab44c0

Please sign in to comment.