Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standalone servers: Improved robustness #273

Merged
merged 1 commit into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions docs/source/_extensions/sphinx_easynetwork.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""
Changelog:

v0.1.0: Initial
v0.1.1 (current): Fix base is not replaced if the class is generic.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, get_origin

if TYPE_CHECKING:
from sphinx.application import Sphinx
Expand All @@ -12,7 +19,7 @@
def _replace_base_in_place(klass: type, bases: list[type], base_to_replace: type, base_to_set_instead: type) -> None:
if issubclass(klass, base_to_replace):
for index, base in enumerate(bases):
if base is base_to_replace:
if get_origin(base) is base_to_replace:
bases[index] = base_to_set_instead


Expand All @@ -25,7 +32,7 @@ def setup(app: Sphinx) -> dict[str, Any]:
app.connect("autodoc-process-bases", autodoc_process_bases)

return {
"version": "0.1",
"version": "0.1.1",
"parallel_read_safe": True,
"parallel_write_safe": True,
}
110 changes: 58 additions & 52 deletions src/easynetwork/servers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
#
#
"""Asynchronous network server module"""
"""Generic network servers module"""

from __future__ import annotations

Expand All @@ -22,7 +22,7 @@
import contextlib
import threading as _threading
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, Generic, TypeVar

from ..exceptions import ServerAlreadyRunning, ServerClosedError
from ..lowlevel import _utils
Expand All @@ -31,24 +31,28 @@
from ..lowlevel.socket import SocketAddress
from .abc import AbstractAsyncNetworkServer, AbstractNetworkServer, SupportsEventSet

_T_Return = TypeVar("_T_Return")
_T_Default = TypeVar("_T_Default")
_T_AsyncServer = TypeVar("_T_AsyncServer", bound=AbstractAsyncNetworkServer)

class BaseStandaloneNetworkServerImpl(AbstractNetworkServer):

class BaseStandaloneNetworkServerImpl(AbstractNetworkServer, Generic[_T_AsyncServer]):
__slots__ = (
"__server_factory",
"__default_runner_options",
"__private_server",
"__server",
"__backend",
"__close_lock",
"__bootstrap_lock",
"__private_threads_portal",
"__threads_portal",
"__is_shutdown",
"__is_closed",
)

def __init__(
self,
backend: AsyncBackend | None,
server_factory: Callable[[AsyncBackend], AbstractAsyncNetworkServer],
server_factory: Callable[[AsyncBackend], _T_AsyncServer],
*,
runner_options: Mapping[str, Any] | None = None,
) -> None:
Expand All @@ -64,22 +68,37 @@ def __init__(
raise TypeError(f"Expected an AsyncBackend instance, got {backend!r}")

self.__backend: AsyncBackend = backend
self.__server_factory: Callable[[AsyncBackend], AbstractAsyncNetworkServer] = server_factory
self.__private_server: AbstractAsyncNetworkServer | None = None
self.__private_threads_portal: ThreadsPortal | None = None
self.__server_factory: Callable[[AsyncBackend], _T_AsyncServer] = server_factory
self.__server: _T_AsyncServer | None = None
self.__threads_portal: ThreadsPortal | None = None
self.__is_shutdown = _threading.Event()
self.__is_shutdown.set()
self.__is_closed = _threading.Event()
self.__close_lock = ForkSafeLock()
self.__bootstrap_lock = ForkSafeLock()
self.__default_runner_options: dict[str, Any] = dict(runner_options) if runner_options else {}

def _run_sync_or_else(
self,
f: Callable[[ThreadsPortal, _T_AsyncServer], _T_Return],
default: Callable[[], _T_Default],
) -> _T_Return | _T_Default:
with self.__bootstrap_lock.get():
if (portal := self.__threads_portal) is not None and (server := self.__server) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError):
return f(portal, server)
return default()

def _run_sync_or(
self,
f: Callable[[ThreadsPortal, _T_AsyncServer], _T_Return],
default: _T_Default,
) -> _T_Return | _T_Default:
return self._run_sync_or_else(f, lambda: default)

@_utils.inherit_doc(AbstractNetworkServer)
def is_serving(self) -> bool:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
return portal.run_sync(server.is_serving)
return False
return self._run_sync_or(lambda portal, server: portal.run_sync(server.is_serving), False)

@_utils.inherit_doc(AbstractNetworkServer)
def server_close(self) -> None:
Expand All @@ -89,28 +108,29 @@ def server_close(self) -> None:
# Ensure we are not in the interval between the server shutdown and the scheduler shutdown
stack.callback(self.__is_shutdown.wait)

if (server := self._server) is not None and (portal := self._portal) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError):
portal.run_coroutine(server.server_close)
self._run_sync_or(lambda portal, server: portal.run_coroutine(server.server_close), None)

@_utils.inherit_doc(AbstractNetworkServer)
def shutdown(self, timeout: float | None = None) -> None:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError), _utils.ElapsedTime() as elapsed:
# If shutdown() have been cancelled, that means the scheduler itself is shutting down, and this is what we want
if timeout is None:
portal.run_coroutine(server.shutdown)
else:
portal.run_coroutine(self.__do_shutdown_with_timeout, server, timeout)
if timeout is not None:
timeout = elapsed.recompute_timeout(timeout)
with self.__bootstrap_lock.get():
if (portal := self.__threads_portal) is not None and (server := self.__server) is not None:

async def do_shutdown_with_timeout(server: AbstractAsyncNetworkServer, timeout: float) -> None:
with server.backend().move_on_after(timeout):
await server.shutdown()

with contextlib.suppress(RuntimeError, concurrent.futures.CancelledError), _utils.ElapsedTime() as elapsed:
# If shutdown() have been cancelled, that means the scheduler itself is shutting down,
# and this is what we want
if timeout is None:
portal.run_coroutine(server.shutdown)
else:

portal.run_coroutine(do_shutdown_with_timeout, server, timeout)
if timeout is not None:
timeout = elapsed.recompute_timeout(timeout)
self.__is_shutdown.wait(timeout)

@staticmethod
async def __do_shutdown_with_timeout(server: AbstractAsyncNetworkServer, timeout_delay: float) -> None:
with server.backend().move_on_after(timeout_delay):
await server.shutdown()

def serve_forever(
self,
*,
Expand Down Expand Up @@ -153,41 +173,27 @@ def serve_forever(
server_exit_stack.callback(self.__is_shutdown.set)

def reset_values() -> None:
self.__private_threads_portal = None
self.__private_server = None
self.__threads_portal = None
self.__server = None

def acquire_bootstrap_lock() -> None:
def reacquire_bootstrap_lock_on_shutdown() -> None:
locks_stack.enter_context(self.__bootstrap_lock.get())

server_exit_stack.callback(reset_values)
server_exit_stack.callback(acquire_bootstrap_lock)
server_exit_stack.callback(reacquire_bootstrap_lock_on_shutdown)

async def serve_forever() -> None:
async with (
self.__server_factory(backend) as self.__private_server,
backend.create_threads_portal() as self.__private_threads_portal,
self.__server_factory(backend) as self.__server,
backend.create_threads_portal() as self.__threads_portal,
):
server = self.__private_server
# Initialization finished; release the locks
locks_stack.close()

await server.serve_forever(is_up_event=is_up_event)
await self.__server.serve_forever(is_up_event=is_up_event)

backend.bootstrap(serve_forever, runner_options=runner_options)

@_utils.inherit_doc(AbstractNetworkServer)
def get_addresses(self) -> Sequence[SocketAddress]:
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
return portal.run_sync(server.get_addresses)
return ()

@property
def _server(self) -> AbstractAsyncNetworkServer | None:
with self.__bootstrap_lock.get():
return self.__private_server

@property
def _portal(self) -> ThreadsPortal | None:
with self.__bootstrap_lock.get():
return self.__private_threads_portal
return self._run_sync_or(lambda portal, server: portal.run_sync(server.get_addresses), ())
26 changes: 11 additions & 15 deletions src/easynetwork/servers/standalone_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"StandaloneTCPNetworkServer",
]

import contextlib
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Literal

Expand All @@ -38,7 +37,10 @@
from .handlers import AsyncStreamRequestHandler


class StandaloneTCPNetworkServer(_base.BaseStandaloneNetworkServerImpl, Generic[_T_Request, _T_Response]):
class StandaloneTCPNetworkServer(
_base.BaseStandaloneNetworkServerImpl[AsyncTCPNetworkServer[_T_Request, _T_Response]],
Generic[_T_Request, _T_Response],
):
"""
A network server for TCP connections.

Expand Down Expand Up @@ -107,9 +109,7 @@ def stop_listening(self) -> None:

Further calls to :meth:`is_serving` will return :data:`False`.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
portal.run_sync(server.stop_listening)
self._run_sync_or(lambda portal, server: portal.run_sync(server.stop_listening), None)

def get_sockets(self) -> Sequence[SocketProxy]:
"""Gets the listeners sockets. Thread-safe.
Expand All @@ -119,13 +119,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:

If the server is not running, an empty sequence is returned.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
sockets = portal.run_sync(server.get_sockets)
return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets)
return ()

if TYPE_CHECKING:

@property
def _server(self) -> AsyncTCPNetworkServer[_T_Request, _T_Response] | None: ...
return self._run_sync_or(
lambda portal, server: tuple(
SocketProxy(sock, runner=portal.run_sync) for sock in portal.run_sync(server.get_sockets)
),
(),
)
22 changes: 10 additions & 12 deletions src/easynetwork/servers/standalone_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"StandaloneUDPNetworkServer",
]

import contextlib
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic

Expand All @@ -37,7 +36,10 @@
from .handlers import AsyncDatagramRequestHandler


class StandaloneUDPNetworkServer(_base.BaseStandaloneNetworkServerImpl, Generic[_T_Request, _T_Response]):
class StandaloneUDPNetworkServer(
_base.BaseStandaloneNetworkServerImpl[AsyncUDPNetworkServer[_T_Request, _T_Response]],
Generic[_T_Request, _T_Response],
):
"""
A network server for UDP communication.

Expand Down Expand Up @@ -89,13 +91,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:

If the server is not running, an empty sequence is returned.
"""
if (portal := self._portal) is not None and (server := self._server) is not None:
with contextlib.suppress(RuntimeError):
sockets = portal.run_sync(server.get_sockets)
return tuple(SocketProxy(sock, runner=portal.run_sync) for sock in sockets)
return ()

if TYPE_CHECKING:

@property
def _server(self) -> AsyncUDPNetworkServer[_T_Request, _T_Response] | None: ...
return self._run_sync_or(
lambda portal, server: tuple(
SocketProxy(sock, runner=portal.run_sync) for sock in portal.run_sync(server.get_sockets)
),
(),
)