Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Servers: Fixed slow access to client's extra attributes provided #355

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions src/easynetwork/servers/async_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import logging
import weakref
from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NoReturn, final
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Generic, NoReturn, TypeVar, final

from .._typevars import _T_Request, _T_Response
from ..exceptions import ClientClosedError
Expand Down Expand Up @@ -366,6 +367,9 @@ def get_sockets(self) -> Sequence[SocketProxy]:
)


_T_Value = TypeVar("_T_Value")


@final
@runtime_final_class
class _ConnectedClientAPI(AsyncStreamClient[_T_Response]):
Expand All @@ -388,7 +392,17 @@ def __init__(
self.__send_lock = client.backend().create_fair_lock()
self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket))
self.__address: SocketAddress = address
self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None

local_address = new_socket_address(client.extra(INETSocketAttribute.sockname), client.extra(INETSocketAttribute.family))

self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] = MappingProxyType(
{
**client.extra_attributes,
INETClientAttribute.socket: _utils.make_callback(self.__simple_attribute_return, self.__proxy),
INETClientAttribute.local_address: _utils.make_callback(self.__simple_attribute_return, local_address),
INETClientAttribute.remote_address: _utils.make_callback(self.__simple_attribute_return, self.__address),
}
)

with contextlib.suppress(OSError):
set_tcp_nodelay(self.__proxy, True)
Expand Down Expand Up @@ -425,18 +439,10 @@ async def send_packet(self, packet: _T_Response, /) -> None:
def backend(self) -> AsyncBackend:
return self.__client.backend()

@staticmethod
def __simple_attribute_return(value: _T_Value) -> _T_Value:
return value

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
if (extra_attributes_cache := self.__extra_attributes_cache) is not None:
return extra_attributes_cache
client = self.__client
self.__extra_attributes_cache = extra_attributes_cache = {
**client.extra_attributes,
INETClientAttribute.socket: lambda: self.__proxy,
INETClientAttribute.local_address: lambda: new_socket_address(
client.extra(INETSocketAttribute.sockname),
client.extra(INETSocketAttribute.family),
),
INETClientAttribute.remote_address: lambda: self.__address,
}
return extra_attributes_cache
return self.__extra_attributes_cache
34 changes: 17 additions & 17 deletions src/easynetwork/servers/async_udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class _ClientAPI(AsyncDatagramClient[_T_Response]):
"__context",
"__service_available",
"__h",
"__extra_attributes_cache",
)

def __init__(
Expand All @@ -185,7 +184,6 @@ def __init__(
super().__init__()
self.__context: _datagram_server.DatagramClientContext[_T_Response, tuple[Any, ...]] = context
self.__h: int | None = None
self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None
self.__service_available: _utils.Flag = service_available

def __repr__(self) -> str:
Expand Down Expand Up @@ -221,24 +219,26 @@ async def send_packet(self, packet: _T_Response, /) -> None:
def backend(self) -> AsyncBackend:
return self.__context.backend()

def __get_server_socket(self) -> SocketProxy:
server = self.__context.server
return SocketProxy(server.extra(INETSocketAttribute.socket))

def __get_server_address(self) -> SocketAddress:
server = self.__context.server
return new_socket_address(server.extra(INETSocketAttribute.sockname), server.extra(INETSocketAttribute.family))

def __get_remote_address(self) -> SocketAddress:
server = self.__context.server
address = self.__context.address
return new_socket_address(address, server.extra(INETSocketAttribute.family))

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
if (extra_attributes_cache := self.__extra_attributes_cache) is not None:
return extra_attributes_cache
server = self.__context.server
self.__extra_attributes_cache = extra_attributes_cache = {
**server.extra_attributes,
INETClientAttribute.socket: lambda: SocketProxy(server.extra(INETSocketAttribute.socket)),
INETClientAttribute.local_address: lambda: new_socket_address(
server.extra(INETSocketAttribute.sockname),
server.extra(INETSocketAttribute.family),
),
INETClientAttribute.remote_address: lambda: new_socket_address(
self.__context.address,
server.extra(INETSocketAttribute.family),
),
return {
INETClientAttribute.socket: self.__get_server_socket,
INETClientAttribute.local_address: self.__get_server_address,
INETClientAttribute.remote_address: self.__get_remote_address,
}
return extra_attributes_cache


@final
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
IncrementalDeserializeError,
StreamProtocolParseError,
)
from easynetwork.lowlevel._utils import remove_traceback_frames_in_place
from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend
from easynetwork.lowlevel.api_async.backend._asyncio.dns_resolver import AsyncIODNSResolver
from easynetwork.lowlevel.api_async.backend._asyncio.stream.listener import ListenerSocketAdapter
from easynetwork.lowlevel.api_async.transports.utils import aclose_forcefully
from easynetwork.lowlevel.socket import SocketAddress, enable_socket_linger
from easynetwork.lowlevel.socket import SocketAddress, SocketProxy, TLSAttribute, enable_socket_linger
from easynetwork.protocol import AnyStreamProtocolType
from easynetwork.servers.async_tcp import AsyncTCPNetworkServer
from easynetwork.servers.handlers import AsyncStreamClient, AsyncStreamRequestHandler, INETClientAttribute
Expand Down Expand Up @@ -56,7 +57,7 @@ class RandomError(Exception):


class MyAsyncTCPRequestHandler(AsyncStreamRequestHandler[str, str]):
connected_clients: WeakValueDictionary[SocketAddress, AsyncStreamClient[str]]
connected_clients: WeakValueDictionary[tuple[Any, ...], AsyncStreamClient[str]]
request_received: collections.defaultdict[tuple[Any, ...], list[str]]
request_count: collections.Counter[tuple[Any, ...]]
bad_request_received: collections.defaultdict[tuple[Any, ...], list[BaseProtocolParseError]]
Expand All @@ -70,19 +71,23 @@ async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Asyn
await super().service_init(exit_stack, server)
self.server = server
assert isinstance(self.server, AsyncTCPNetworkServer)

self.connected_clients = WeakValueDictionary()
exit_stack.callback(self.connected_clients.clear)

self.request_received = collections.defaultdict(list)
exit_stack.callback(self.request_received.clear)

self.request_count = collections.Counter()
exit_stack.callback(self.request_count.clear)

self.bad_request_received = collections.defaultdict(list)
exit_stack.callback(self.bad_request_received.clear)

exit_stack.push_async_callback(self.service_quit)

async def service_quit(self) -> None:
del (
self.connected_clients,
self.request_received,
self.request_count,
self.bad_request_received,
)
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
assert client_address(client) not in self.connected_clients
Expand Down Expand Up @@ -163,6 +168,7 @@ async def handle_bad_requests(self, client: AsyncStreamClient[str]) -> AsyncIter
try:
yield
except StreamProtocolParseError as exc:
remove_traceback_frames_in_place(exc, 1)
self.bad_request_received[client_address(client)].append(exc)
await client.send_packet("wrong encoding man.")

Expand All @@ -171,9 +177,6 @@ class TimeoutYieldedRequestHandler(AsyncStreamRequestHandler[str, str]):
request_timeout: float = 1.0
timeout_on_second_yield: bool = False

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
await client.send_packet("milk")

Expand Down Expand Up @@ -222,9 +225,6 @@ class InitialHandshakeRequestHandler(AsyncStreamRequestHandler[str, str]):
bypass_handshake: bool = False
handshake_2fa: bool = False

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
pass

async def on_connection(self, client: AsyncStreamClient[str]) -> AsyncGenerator[float | None, str]:
await client.send_packet("milk")
if self.bypass_handshake:
Expand Down Expand Up @@ -263,6 +263,7 @@ class RequestRefusedHandler(AsyncStreamRequestHandler[str, str]):

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
self.request_count: collections.Counter[AsyncStreamClient[str]] = collections.Counter()
exit_stack.callback(self.request_count.clear)

async def on_connection(self, client: AsyncStreamClient[str]) -> None:
await client.send_packet("milk")
Expand Down Expand Up @@ -643,20 +644,42 @@ async def test____serve_forever____accept_client____client_sent_RST_packet_right
assert caplog.records[0].levelno == logging.WARNING
assert caplog.records[0].message == "A client connection was interrupted just after listener.accept()"

async def test____serve_forever____disable_nagle_algorithm(
async def test____serve_forever____client_extra_attributes(
self,
client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]],
request_handler: MyAsyncTCPRequestHandler,
use_ssl: bool,
) -> None:
_ = await client_factory()
all_writers: list[asyncio.StreamWriter] = [(await client_factory())[1] for _ in range(3)]
assert len(request_handler.connected_clients) == 3

for writer in all_writers:
client_address: tuple[Any, ...] = writer.get_extra_info("sockname")
connected_client: AsyncStreamClient[str] = request_handler.connected_clients[client_address]

assert isinstance(connected_client.extra(INETClientAttribute.socket), SocketProxy)
assert connected_client.extra(INETClientAttribute.remote_address) == client_address
assert connected_client.extra(INETClientAttribute.local_address) == writer.get_extra_info("peername")

if use_ssl:
assert connected_client.extra(TLSAttribute.sslcontext, None) is not None
assert connected_client.extra(TLSAttribute.peercert, None) is not None

connected_client: AsyncStreamClient[str] = list(request_handler.connected_clients.values())[0]
async def test____serve_forever____disable_nagle_algorithm(
self,
client_factory: Callable[[], Awaitable[tuple[asyncio.StreamReader, asyncio.StreamWriter]]],
request_handler: MyAsyncTCPRequestHandler,
) -> None:
for _ in range(3):
_ = await client_factory()

tcp_nodelay_state: int = connected_client.extra(INETClientAttribute.socket).getsockopt(IPPROTO_TCP, TCP_NODELAY)
assert len(request_handler.connected_clients) == 3
for connected_client in request_handler.connected_clients.values():
tcp_nodelay_state: int = connected_client.extra(INETClientAttribute.socket).getsockopt(IPPROTO_TCP, TCP_NODELAY)

# Do not test with '== 1', on MacOS it will return 4
# (c.f. https://stackoverflow.com/a/31835137)
assert tcp_nodelay_state != 0
# Do not test with '== 1', on MacOS it will return 4
# (c.f. https://stackoverflow.com/a/31835137)
assert tcp_nodelay_state != 0

async def test____serve_forever____shutdown_during_loop____kill_client_tasks(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from typing import Any

from easynetwork.exceptions import BaseProtocolParseError, ClientClosedError, DatagramProtocolParseError, DeserializeError
from easynetwork.lowlevel._utils import remove_traceback_frames_in_place
from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend
from easynetwork.lowlevel.api_async.backend._asyncio.datagram.endpoint import DatagramEndpoint, create_datagram_endpoint
from easynetwork.lowlevel.api_async.backend._asyncio.datagram.listener import DatagramListenerSocketAdapter
from easynetwork.lowlevel.socket import SocketAddress
from easynetwork.lowlevel.socket import SocketAddress, SocketProxy
from easynetwork.protocol import DatagramProtocol
from easynetwork.servers.async_udp import AsyncUDPNetworkServer
from easynetwork.servers.handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, INETClientAttribute
Expand Down Expand Up @@ -49,16 +50,27 @@ class MyAsyncUDPRequestHandler(AsyncDatagramRequestHandler[str, str]):
request_received: collections.defaultdict[tuple[Any, ...], list[str]]
bad_request_received: collections.defaultdict[tuple[Any, ...], list[BaseProtocolParseError]]
created_clients: set[AsyncDatagramClient[str]]
created_clients_map: dict[tuple[Any, ...], AsyncDatagramClient[str]]
server: AsyncUDPNetworkServer[str, str]

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: AsyncUDPNetworkServer[str, str]) -> None:
await super().service_init(exit_stack, server)
self.server = server
assert isinstance(self.server, AsyncUDPNetworkServer)

self.request_count = collections.Counter()
exit_stack.callback(self.request_count.clear)

self.request_received = collections.defaultdict(list)
exit_stack.callback(self.request_received.clear)

self.bad_request_received = collections.defaultdict(list)
exit_stack.callback(self.bad_request_received.clear)

self.created_clients = set()
self.created_clients_map = dict()
exit_stack.callback(self.created_clients_map.clear)
exit_stack.callback(self.created_clients.clear)

exit_stack.push_async_callback(self.service_quit)

Expand All @@ -69,21 +81,17 @@ async def service_quit(self) -> None:
with pytest.raises(ClientClosedError):
await client.send_packet("something")

del (
self.request_count,
self.request_received,
self.bad_request_received,
self.created_clients,
)

async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]:
self.created_clients.add(client)
self.created_clients_map[client_address(client)] = client
while True:
async with self.handle_bad_requests(client):
request = yield
break
self.request_count[client_address(client)] += 1
match request:
case "__ping__":
await client.send_packet("pong")
case "__error__":
raise RandomError("Sorry man!")
case "__error_excgrp__":
Expand Down Expand Up @@ -120,6 +128,7 @@ async def handle_bad_requests(self, client: AsyncDatagramClient[str]) -> AsyncIt
try:
yield
except DatagramProtocolParseError as exc:
remove_traceback_frames_in_place(exc, 1)
self.bad_request_received[client_address(client)].append(exc)
await client.send_packet("wrong encoding man.")

Expand Down Expand Up @@ -196,6 +205,7 @@ class RequestRefusedHandler(AsyncDatagramRequestHandler[str, str]):

async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Any) -> None:
self.request_count: collections.Counter[AsyncDatagramClient[str]] = collections.Counter()
exit_stack.callback(self.request_count.clear)

async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]:
if self.request_count[client] >= self.refuse_after and not self.bypass_refusal:
Expand Down Expand Up @@ -340,6 +350,13 @@ async def factory() -> DatagramEndpoint:

yield factory

@staticmethod
async def __ping_server(endpoint: DatagramEndpoint) -> None:
await endpoint.sendto(b"__ping__", None)
async with asyncio.timeout(1):
pong, _ = await endpoint.recvfrom()
assert pong == b"pong"

async def test____serve_forever____empty_listener_list(
self,
request_handler: MyAsyncUDPRequestHandler,
Expand Down Expand Up @@ -377,6 +394,26 @@ async def test____serve_forever____handle_request(

assert request_handler.request_received[client_address] == ["hello, world."]

async def test____serve_forever____client_extra_attributes(
self,
client_factory: Callable[[], Awaitable[DatagramEndpoint]],
request_handler: MyAsyncUDPRequestHandler,
) -> None:
all_endpoints: list[DatagramEndpoint] = [await client_factory() for _ in range(3)]

for endpoint in all_endpoints:
await self.__ping_server(endpoint)

assert len(request_handler.created_clients_map) == 3

for endpoint in all_endpoints:
client_address: tuple[Any, ...] = endpoint.get_extra_info("sockname")
connected_client: AsyncDatagramClient[str] = request_handler.created_clients_map[client_address]

assert isinstance(connected_client.extra(INETClientAttribute.socket), SocketProxy)
assert connected_client.extra(INETClientAttribute.remote_address) == client_address
assert connected_client.extra(INETClientAttribute.local_address) == endpoint.get_extra_info("peername")

async def test____serve_forever____client_equality(
self,
client_factory: Callable[[], Awaitable[DatagramEndpoint]],
Expand Down