Skip to content

Commit

Permalink
Servers: Fixed slow access to client's extra attributes provided (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Sep 29, 2024
1 parent 96e73c0 commit fecf422
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 61 deletions.
36 changes: 21 additions & 15 deletions src/easynetwork/servers/async_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import logging
import weakref
from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NoReturn, final
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Generic, NoReturn, TypeVar, final

from .._typevars import _T_Request, _T_Response
from ..exceptions import ClientClosedError
Expand Down Expand Up @@ -366,6 +367,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:
)


_T_Value = TypeVar("_T_Value")


@final
@runtime_final_class
class _ConnectedClientAPI(AsyncStreamClient[_T_Response]):
Expand All @@ -388,7 +392,17 @@ def __init__(
self.__send_lock = client.backend().create_fair_lock()
self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket))
self.__address: SocketAddress = address
self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None

local_address = new_socket_address(client.extra(INETSocketAttribute.sockname), client.extra(INETSocketAttribute.family))

self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] = MappingProxyType(
{
**client.extra_attributes,
INETClientAttribute.socket: _utils.make_callback(self.__simple_attribute_return, self.__proxy),
INETClientAttribute.local_address: _utils.make_callback(self.__simple_attribute_return, local_address),
INETClientAttribute.remote_address: _utils.make_callback(self.__simple_attribute_return, self.__address),
}
)

with contextlib.suppress(OSError):
set_tcp_nodelay(self.__proxy, True)
Expand Down Expand Up @@ -425,18 +439,10 @@ async def send_packet(self, packet: _T_Response, /) -> None:
def backend(self) -> AsyncBackend:
return self.__client.backend()

@staticmethod
def __simple_attribute_return(value: _T_Value) -> _T_Value:
return value

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
if (extra_attributes_cache := self.__extra_attributes_cache) is not None:
return extra_attributes_cache
client = self.__client
self.__extra_attributes_cache = extra_attributes_cache = {
**client.extra_attributes,
INETClientAttribute.socket: lambda: self.__proxy,
INETClientAttribute.local_address: lambda: new_socket_address(
client.extra(INETSocketAttribute.sockname),
client.extra(INETSocketAttribute.family),
),
INETClientAttribute.remote_address: lambda: self.__address,
}
return extra_attributes_cache
return self.__extra_attributes_cache
34 changes: 17 additions & 17 deletions src/easynetwork/servers/async_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class _ClientAPI(AsyncDatagramClient[_T_Response]):
"__context",
"__service_available",
"__h",
"__extra_attributes_cache",
)

def __init__(
Expand All @@ -185,7 +184,6 @@ def __init__(
super().__init__()
self.__context: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]] = context
self.__h: int | None = None
self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None
self.__service_available: _utils.Flag = service_available

def __repr__(self) -> str:
Expand Down Expand Up @@ -221,24 +219,26 @@ async def send_packet(self, packet: _T_Response, /) -> None:
def backend(self) -> AsyncBackend:
return self.__context.backend()

def __get_server_socket(self) -> SocketProxy:
server = self.__context.server
return SocketProxy(server.extra(INETSocketAttribute.socket))

def __get_server_address(self) -> SocketAddress:
server = self.__context.server
return new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family))

def __get_remote_address(self) -> SocketAddress:
server = self.__context.server
address = self.__context.address
return new_socket_address(address, server.extra(INETSocketAttribute.family))

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
if (extra_attributes_cache := self.__extra_attributes_cache) is not None:
return extra_attributes_cache
server = self.__context.server
self.__extra_attributes_cache = extra_attributes_cache = {
**server.extra_attributes,
INETClientAttribute.socket: lambda: SocketProxy(server.extra(INETSocketAttribute.socket)),
INETClientAttribute.local_address: lambda: new_socket_address(
server.extra(INETSocketAttribute.sockname),
server.extra(INETSocketAttribute.family),
),
INETClientAttribute.remote_address: lambda: new_socket_address(
self.__context.address,
server.extra(INETSocketAttribute.family),
),
return {
INETClientAttribute.socket: self.__get_server_socket,
INETClientAttribute.local_address: self.__get_server_address,
INETClientAttribute.remote_address: self.__get_remote_address,
}
return extra_attributes_cache


@final
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
IncrementalDeserializeError,
StreamProtocolParseError,
)
from easynetwork.lowlevel._utils import remove_traceback_frames_in_place
from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend
from easynetwork.lowlevel.api_async.backend._asyncio.dns_resolver import AsyncIODNSResolver
from easynetwork.lowlevel.api_async.backend._asyncio.stream.listener import ListenerSocketAdapter
from easynetwork.lowlevel.api_async.transports.utils import aclose_forcefully
from easynetwork.lowlevel.socket import SocketAddress, enable_socket_linger
from easynetwork.lowlevel.socket import SocketAddress, SocketProxy, TLSAttribute, enable_socket_linger
from easynetwork.protocol import AnyStreamProtocolType
from easynetwork.servers.async_tcp import AsyncTCPNetworkServer
from easynetwork.servers.handlers import AsyncStreamClient, AsyncStreamRequestHandler, INETClientAttribute
Expand Down Expand Up @@ -56,7 +57,7 @@ class RandomError(Exception):


class MyAsyncTCPRequestHandler(AsyncStreamRequestHandler[str, str]):
connected_clients: WeakValueDictionary[SocketAddress, AsyncStreamClient[str]]
connected_clients: WeakValueDictionary[tuple[Any, ...], AsyncStreamClient[str]]
request_received: collections.defaultdict[tuple[Any, ...], list[str]]
request_count: collections.Counter[tuple[Any, ...]]
bad_request_received: collections.defaultdict[tuple[Any, ...], list[BaseProtocolParseError]]
Expand All @@ -70,19 +71,23 @@ async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Asyn
await super().service_init(exit_stack, server)
self.server = server
assert isinstance(self.server, AsyncTCPNetworkServer)

self.connected_clients = WeakValueDictionary()
exit_stack.callback(self.connected_clients.clear)

self.request_received = collections.defaultdict(list)
exit_stack.callback(self.request_received.clear)

self.request_count = collections.Counter()
exit_stack.callback(self.request_count.clear)

self.bad_request_received = collections.defaultdict(list)
exit_stack.callback(self.bad_request_received.clear)

exit_stack.push_async_callback(self.service_quit)

async def service_quit(self) -> None:
del (
self.connected_clients,
self.request_received,
self.request_count,
self.bad_request_received,
)
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
assert client_address(client) not in self.connected_clients
Expand Down Expand Up @@ -163,6 +168,7 @@ async def handle_bad_requests(self, client: AsyncStreamClient[str]) -> AsyncIter
try:
yield
except StreamProtocolParseError as exc:
remove_traceback_frames_in_place(exc, 1)
self.bad_request_received[client_address(client)].append(exc)
await client.send_packet("wrong encoding man.")

Expand All @@ -171,9 +177,6 @@ class TimeoutYieldedRequestHandler(AsyncStreamRequestHandler[str, str]):
request_timeout: float = 1.0
timeout_on_second_yield: bool = False

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
await client.send_packet("milk")

Expand Down Expand Up @@ -222,9 +225,6 @@ class InitialHandshakeRequestHandler(AsyncStreamRequestHandler[str, str]):
bypass_handshake: bool = False
handshake_2fa: bool = False

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> AsyncGenerator[float | None, str]:
await client.send_packet("milk")
if self.bypass_handshake:
Expand Down Expand Up @@ -263,6 +263,7 @@ class RequestRefusedHandler(AsyncStreamRequestHandler[str, str]):

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
self.request_count: collections.Counter[AsyncStreamClient[str]] = collections.Counter()
exit_stack.callback(self.request_count.clear)

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
await client.send_packet("milk")
Expand Down Expand Up @@ -643,20 +644,42 @@ async def test____serve_forever____accept_client____client_sent_RST_packet_right
assert caplog.records[0].levelno == logging.WARNING
assert caplog.records[0].message == "A client connection was interrupted just after listener.accept()"

async def test____serve_forever____disable_nagle_algorithm(
async def test____serve_forever____client_extra_attributes(
self,
client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]],
request_handler: MyAsyncTCPRequestHandler,
use_ssl: bool,
) -> None:
_ = await client_factory()
all_writers: list[asyncio.StreamWriter] = [(await client_factory())[1] for _ in range(3)]
assert len(request_handler.connected_clients) == 3

for writer in all_writers:
client_address: tuple[Any, ...] = writer.get_extra_info("sockname")
connected_client: AsyncStreamClient[str] = request_handler.connected_clients[client_address]

assert isinstance(connected_client.extra(INETClientAttribute.socket), SocketProxy)
assert connected_client.extra(INETClientAttribute.remote_address) == client_address
assert connected_client.extra(INETClientAttribute.local_address) == writer.get_extra_info("peername")

if use_ssl:
assert connected_client.extra(TLSAttribute.sslcontext, None) is not None
assert connected_client.extra(TLSAttribute.peercert, None) is not None

connected_client: AsyncStreamClient[str] = list(request_handler.connected_clients.values())[0]
async def test____serve_forever____disable_nagle_algorithm(
self,
client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]],
request_handler: MyAsyncTCPRequestHandler,
) -> None:
for _ in range(3):
_ = await client_factory()

tcp_nodelay_state: int = connected_client.extra(INETClientAttribute.socket).getsockopt(IPPROTO_TCP, TCP_NODELAY)
assert len(request_handler.connected_clients) == 3
for connected_client in request_handler.connected_clients.values():
tcp_nodelay_state: int = connected_client.extra(INETClientAttribute.socket).getsockopt(IPPROTO_TCP, TCP_NODELAY)

# Do not test with '== 1', on MacOS it will return 4
# (c.f. https://stackoverflow.com/a/31835137)
assert tcp_nodelay_state != 0
# Do not test with '== 1', on MacOS it will return 4
# (c.f. https://stackoverflow.com/a/31835137)
assert tcp_nodelay_state != 0

async def test____serve_forever____shutdown_during_loop____kill_client_tasks(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from typing import Any

from easynetwork.exceptions import BaseProtocolParseError, ClientClosedError, DatagramProtocolParseError, DeserializeError
from easynetwork.lowlevel._utils import remove_traceback_frames_in_place
from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend
from easynetwork.lowlevel.api_async.backend._asyncio.datagram.endpoint import DatagramEndpoint, create_datagram_endpoint
from easynetwork.lowlevel.api_async.backend._asyncio.datagram.listener import DatagramListenerSocketAdapter
from easynetwork.lowlevel.socket import SocketAddress
from easynetwork.lowlevel.socket import SocketAddress, SocketProxy
from easynetwork.protocol import DatagramProtocol
from easynetwork.servers.async_udp import AsyncUDPNetworkServer
from easynetwork.servers.handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, INETClientAttribute
Expand Down Expand Up @@ -49,16 +50,27 @@ class MyAsyncUDPRequestHandler(AsyncDatagramRequestHandler[str, str]):
request_received: collections.defaultdict[tuple[Any, ...], list[str]]
bad_request_received: collections.defaultdict[tuple[Any, ...], list[BaseProtocolParseError]]
created_clients: set[AsyncDatagramClient[str]]
created_clients_map: dict[tuple[Any, ...], AsyncDatagramClient[str]]
server: AsyncUDPNetworkServer[str, str]

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: AsyncUDPNetworkServer[str, str]) -> None:
await super().service_init(exit_stack, server)
self.server = server
assert isinstance(self.server, AsyncUDPNetworkServer)

self.request_count = collections.Counter()
exit_stack.callback(self.request_count.clear)

self.request_received = collections.defaultdict(list)
exit_stack.callback(self.request_received.clear)

self.bad_request_received = collections.defaultdict(list)
exit_stack.callback(self.bad_request_received.clear)

self.created_clients = set()
self.created_clients_map = dict()
exit_stack.callback(self.created_clients_map.clear)
exit_stack.callback(self.created_clients.clear)

exit_stack.push_async_callback(self.service_quit)

Expand All @@ -69,21 +81,17 @@ async def service_quit(self) -> None:
with pytest.raises(ClientClosedError):
await client.send_packet("something")

del (
self.request_count,
self.request_received,
self.bad_request_received,
self.created_clients,
)

async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]:
self.created_clients.add(client)
self.created_clients_map[client_address(client)] = client
while True:
async with self.handle_bad_requests(client):
request = yield
break
self.request_count[client_address(client)] += 1
match request:
case "__ping__":
await client.send_packet("pong")
case "__error__":
raise RandomError("Sorry man!")
case "__error_excgrp__":
Expand Down Expand Up @@ -120,6 +128,7 @@ async def handle_bad_requests(self, client: AsyncDatagramClient[str]) -> AsyncIt
try:
yield
except DatagramProtocolParseError as exc:
remove_traceback_frames_in_place(exc, 1)
self.bad_request_received[client_address(client)].append(exc)
await client.send_packet("wrong encoding man.")

Expand Down Expand Up @@ -196,6 +205,7 @@ class RequestRefusedHandler(AsyncDatagramRequestHandler[str, str]):

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
self.request_count: collections.Counter[AsyncDatagramClient[str]] = collections.Counter()
exit_stack.callback(self.request_count.clear)

async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]:
if self.request_count[client] >= self.refuse_after and not self.bypass_refusal:
Expand Down Expand Up @@ -340,6 +350,13 @@ async def factory() -> DatagramEndpoint:

yield factory

@staticmethod
async def __ping_server(endpoint: DatagramEndpoint) -> None:
await endpoint.sendto(b"__ping__", None)
async with asyncio.timeout(1):
pong, _ = await endpoint.recvfrom()
assert pong == b"pong"

async def test____serve_forever____empty_listener_list(
self,
request_handler: MyAsyncUDPRequestHandler,
Expand Down Expand Up @@ -377,6 +394,26 @@ async def test____serve_forever____handle_request(

assert request_handler.request_received[client_address] == ["hello, world."]

async def test____serve_forever____client_extra_attributes(
self,
client_factory: Callable[[], Awaitable[DatagramEndpoint]],
request_handler: MyAsyncUDPRequestHandler,
) -> None:
all_endpoints: list[DatagramEndpoint] = [await client_factory() for _ in range(3)]

for endpoint in all_endpoints:
await self.__ping_server(endpoint)

assert len(request_handler.created_clients_map) == 3

for endpoint in all_endpoints:
client_address: tuple[Any, ...] = endpoint.get_extra_info("sockname")
connected_client: AsyncDatagramClient[str] = request_handler.created_clients_map[client_address]

assert isinstance(connected_client.extra(INETClientAttribute.socket), SocketProxy)
assert connected_client.extra(INETClientAttribute.remote_address) == client_address
assert connected_client.extra(INETClientAttribute.local_address) == endpoint.get_extra_info("peername")

async def test____serve_forever____client_equality(
self,
client_factory: Callable[[], Awaitable[DatagramEndpoint]],
Expand Down

0 comments on commit fecf422

Please sign in to comment.