Skip to content

Commit

Permalink
Low-level API: Faster access to extra attributes (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Oct 17, 2024
1 parent 64000ab commit c368ca9
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import warnings
from collections.abc import Callable, Coroutine, Mapping
from types import MappingProxyType
from typing import Any, NoReturn, final

from ..... import _utils, socket as socket_tools
Expand All @@ -42,8 +43,8 @@ class DatagramListenerSocketAdapter(transports.AsyncDatagramListener[tuple[Any,
"__backend",
"__transport",
"__protocol",
"__socket",
"__closing",
"__extra_attributes",
)

def __init__(self, backend: AsyncBackend, transport: asyncio.DatagramTransport, protocol: DatagramListenerProtocol) -> None:
Expand All @@ -56,13 +57,14 @@ def __init__(self, backend: AsyncBackend, transport: asyncio.DatagramTransport,
self.__backend: AsyncBackend = backend
self.__transport: asyncio.DatagramTransport = transport
self.__protocol: DatagramListenerProtocol = protocol
self.__socket: asyncio.trsock.TransportSocket = socket

# asyncio.DatagramTransport.is_closing() can suddently become true if there is something wrong with the socket
# even if transport.close() was never called.
# To bypass this side effect, we use our own flag.
self.__closing: bool = False

self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(socket, wrap_in_proxy=False))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
transport = self.__transport
Expand Down Expand Up @@ -104,8 +106,7 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket = self.__socket
return socket_tools._get_socket_extra(socket, wrap_in_proxy=False)
return self.__extra_attributes


@dataclasses.dataclass(eq=False, frozen=True, slots=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import asyncio.trsock
import warnings
from collections.abc import Callable, Mapping
from types import MappingProxyType
from typing import Any, final

from ..... import _utils, socket as socket_tools
Expand All @@ -36,8 +37,8 @@ class AsyncioTransportDatagramSocketAdapter(AsyncDatagramTransport):
__slots__ = (
"__backend",
"__endpoint",
"__socket",
"__closing",
"__extra_attributes",
)

def __init__(self, backend: AsyncBackend, endpoint: DatagramEndpoint) -> None:
Expand All @@ -49,13 +50,14 @@ def __init__(self, backend: AsyncBackend, endpoint: DatagramEndpoint) -> None:

self.__backend: AsyncBackend = backend
self.__endpoint: DatagramEndpoint = endpoint
self.__socket: asyncio.trsock.TransportSocket = socket

# asyncio.DatagramTransport.is_closing() can suddently become true if there is something wrong with the socket
# even if transport.close() was never called.
# To bypass this side effect, we use our own flag.
self.__closing: bool = False

self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(socket, wrap_in_proxy=False))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
endpoint = self.__endpoint
Expand Down Expand Up @@ -85,5 +87,4 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket = self.__socket
return socket_tools._get_socket_extra(socket, wrap_in_proxy=False)
return self.__extra_attributes
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import warnings
from abc import abstractmethod
from collections.abc import Callable, Coroutine, Mapping
from types import MappingProxyType
from typing import Any, Generic, NoReturn, TypeVar, final

from ..... import _utils, constants, socket as socket_tools
Expand All @@ -49,10 +50,10 @@ class ListenerSocketAdapter(AsyncListener[_T_Stream]):
__slots__ = (
"__backend",
"__socket",
"__trsock",
"__accepted_socket_factory",
"__accept_scope",
"__serve_guard",
"__extra_attributes",
)

def __init__(
Expand All @@ -68,14 +69,16 @@ def __init__(

_utils.check_socket_no_ssl(socket)
socket.setblocking(False)
trsock: asyncio.trsock.TransportSocket = asyncio.trsock.TransportSocket(socket)

self.__socket: _socket.socket | None = socket
self.__trsock: asyncio.trsock.TransportSocket = asyncio.trsock.TransportSocket(socket)
self.__backend: AsyncBackend = backend
self.__accepted_socket_factory = accepted_socket_factory
self.__accept_scope: CancelScope | None = None
self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.")

self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(trsock, wrap_in_proxy=False))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
socket: _socket.socket | None = self.__socket
Expand Down Expand Up @@ -186,7 +189,7 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False)
return self.__extra_attributes


class AbstractAcceptedSocketFactory(Generic[_T_Stream]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import errno as _errno
import warnings
from collections.abc import Callable, Iterable, Mapping
from types import TracebackType
from types import MappingProxyType, TracebackType
from typing import TYPE_CHECKING, Any, final

from ......exceptions import UnsupportedOperation
Expand All @@ -44,8 +44,8 @@ class AsyncioTransportStreamSocketAdapter(AsyncStreamTransport):
"__backend",
"__transport",
"__protocol",
"__socket",
"__closing",
"__extra_attributes",
)

def __init__(
Expand All @@ -63,7 +63,6 @@ def __init__(
if over_ssl:
raise NotImplementedError(f"{self.__class__.__name__} does not support SSL")

self.__socket: asyncio.trsock.TransportSocket = socket
self.__backend: AsyncBackend = backend
self.__transport: asyncio.Transport = transport
self.__protocol: StreamReaderBufferedProtocol = protocol
Expand All @@ -76,6 +75,8 @@ def __init__(
# Disable in-memory byte buffering.
transport.set_write_buffer_limits(0)

self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(socket, wrap_in_proxy=False))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
closing = self.__closing
Expand Down Expand Up @@ -132,7 +133,7 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__socket, wrap_in_proxy=False)
return self.__extra_attributes


class StreamReaderBufferedProtocol(asyncio.BufferedProtocol):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import socket as _socket
import warnings
from collections.abc import Awaitable, Callable, Coroutine, Mapping
from types import MappingProxyType
from typing import Any, NoReturn, final

import trio
Expand All @@ -38,9 +39,9 @@ class TrioDatagramListenerSocketAdapter(AsyncDatagramListener[tuple[Any, ...]]):
__slots__ = (
"__backend",
"__listener",
"__trsock",
"__serve_guard",
"__send_lock",
"__extra_attributes",
"__wait_readable",
"__wait_writable",
)
Expand All @@ -59,9 +60,9 @@ def __init__(self, backend: AsyncBackend, sock: _socket.socket) -> None:

self.__backend: AsyncBackend = backend
self.__listener: _socket.socket = sock
self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock)
self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.")
self.__send_lock: ILock = FastFIFOLock()
self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(sock))

self.__wait_readable: Callable[[_socket.socket], Awaitable[None]] = wait_readable
self.__wait_writable: Callable[[_socket.socket], Awaitable[None]] = wait_writable
Expand Down Expand Up @@ -131,4 +132,4 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False)
return self.__extra_attributes
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import socket as _socket
import warnings
from collections.abc import Callable, Mapping
from types import MappingProxyType
from typing import Any, final

import trio
Expand All @@ -36,7 +37,7 @@ class TrioDatagramSocketAdapter(AsyncDatagramTransport):
__slots__ = (
"__backend",
"__socket",
"__trsock",
"__extra_attributes",
)

from .....constants import MAX_DATAGRAM_BUFSIZE
Expand All @@ -49,7 +50,7 @@ def __init__(self, backend: AsyncBackend, sock: trio.socket.SocketType) -> None:

self.__backend: AsyncBackend = backend
self.__socket: trio.socket.SocketType = sock
self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock)
self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(sock))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
Expand Down Expand Up @@ -78,4 +79,4 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False)
return self.__extra_attributes
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import os
import warnings
from collections.abc import Callable, Coroutine, Mapping
from types import MappingProxyType
from typing import Any, NoReturn, final

import trio
Expand All @@ -41,17 +42,17 @@ class TrioListenerSocketAdapter(AsyncListener[TrioStreamSocketAdapter]):
__slots__ = (
"__backend",
"__listener",
"__trsock",
"__serve_guard",
"__extra_attributes",
)

def __init__(self, backend: AsyncBackend, listener: trio.SocketListener) -> None:
super().__init__()

self.__backend: AsyncBackend = backend
self.__listener: trio.SocketListener = listener
self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(listener.socket)
self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.")
self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(listener.socket))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
Expand Down Expand Up @@ -111,4 +112,4 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False)
return self.__extra_attributes
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import warnings
from collections import deque
from collections.abc import Callable, Iterable, Mapping
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, final

import trio
Expand All @@ -43,15 +44,15 @@ class TrioStreamSocketAdapter(AsyncStreamTransport):
__slots__ = (
"__backend",
"__stream",
"__trsock",
"__extra_attributes",
)

def __init__(self, backend: AsyncBackend, stream: trio.SocketStream) -> None:
super().__init__()

self.__backend: AsyncBackend = backend
self.__stream: trio.SocketStream = stream
self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(stream.socket)
self.__extra_attributes = MappingProxyType(socket_tools._get_socket_extra(stream.socket))

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
Expand Down Expand Up @@ -109,4 +110,4 @@ def backend(self) -> AsyncBackend:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False)
return self.__extra_attributes
17 changes: 13 additions & 4 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ class AsyncTLSStreamTransport(AsyncStreamTransport):
__transport_recv_lock: ILock = dataclasses.field(init=False)
__closing: bool = dataclasses.field(init=False, default=False)
__closed: IEvent = dataclasses.field(init=False)
__tls_extra_atributes: Mapping[Any, Callable[[], Any]] = dataclasses.field(init=False)

def __post_init__(self) -> None:
backend = self._transport.backend()
self.__incoming_reader = _IncomingDataReader(transport=self._transport)
self.__transport_send_lock = backend.create_fair_lock()
self.__transport_recv_lock = backend.create_fair_lock()
self.__closed = backend.create_event()
self.__tls_extra_atributes = socket_tools._get_tls_extra(self._ssl_object, self._standard_compatible)

@classmethod
async def wrap(
Expand Down Expand Up @@ -310,8 +312,7 @@ async def _retry_ssl_method(
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self._transport.extra_attributes,
**socket_tools._get_tls_extra(self._ssl_object),
socket_tools.TLSAttribute.standard_compatible: lambda: self._standard_compatible,
**self.__tls_extra_atributes,
}


Expand All @@ -327,6 +328,7 @@ class AsyncTLSListener(AsyncListener[AsyncTLSStreamTransport]):
"__handshake_timeout",
"__shutdown_timeout",
"__handshake_error_handler",
"__tls_extra_attributes",
)

def __init__(
Expand Down Expand Up @@ -362,6 +364,7 @@ def __init__(
self.__shutdown_timeout: float | None = shutdown_timeout
self.__standard_compatible: bool = standard_compatible
self.__handshake_error_handler: Callable[[Exception], None] | None = handshake_error_handler
self.__tls_extra_attributes = self.__make_tls_extra_attributes(self.__ssl_context, self.__standard_compatible)

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
try:
Expand Down Expand Up @@ -427,13 +430,19 @@ def __default_handshake_error_handler(exc: Exception) -> None:
def backend(self) -> AsyncBackend:
return self.__listener.backend()

@staticmethod
def __make_tls_extra_attributes(ssl_context: SSLContext, standard_compatible: bool) -> dict[Any, Callable[[], Any]]:
return {
socket_tools.TLSAttribute.sslcontext: lambda: ssl_context,
socket_tools.TLSAttribute.standard_compatible: lambda: standard_compatible,
}

@property
@_utils.inherit_doc(AsyncListener)
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self.__listener.extra_attributes,
socket_tools.TLSAttribute.sslcontext: lambda: self.__ssl_context,
socket_tools.TLSAttribute.standard_compatible: lambda: self.__standard_compatible,
**self.__tls_extra_attributes,
}


Expand Down
Loading

0 comments on commit c368ca9

Please sign in to comment.