Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced AsyncioTransportStreamSocketAdapter by AsyncioTransportBufferedStreamSocketAdapter #252

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/easynetwork/lowlevel/std_asyncio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .datagram.listener import DatagramListenerSocketAdapter
from .datagram.socket import AsyncioTransportDatagramSocketAdapter
from .stream.listener import AcceptedSocketFactory, ListenerSocketAdapter
from .stream.socket import AsyncioTransportBufferedStreamSocketAdapter, AsyncioTransportStreamSocketAdapter
from .stream.socket import AsyncioTransportStreamSocketAdapter, StreamReaderBufferedProtocol
from .tasks import CancelScope, TaskGroup, TaskUtils
from .threads import ThreadsPortal

Expand Down Expand Up @@ -127,8 +127,12 @@ async def create_tcp_connection(

async def wrap_stream_socket(self, socket: _socket.socket) -> AsyncioTransportStreamSocketAdapter:
socket.setblocking(False)
reader, writer = await asyncio.open_connection(sock=socket)
return AsyncioTransportStreamSocketAdapter(reader, writer)
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_connection(
_utils.make_callback(StreamReaderBufferedProtocol, loop=loop),
sock=socket,
)
return AsyncioTransportStreamSocketAdapter(transport, protocol)

async def create_tcp_listeners(
self,
Expand All @@ -137,7 +141,7 @@ async def create_tcp_listeners(
backlog: int,
*,
reuse_port: bool = False,
) -> Sequence[ListenerSocketAdapter[AsyncioTransportBufferedStreamSocketAdapter]]:
) -> Sequence[ListenerSocketAdapter[AsyncioTransportStreamSocketAdapter]]:
if not isinstance(backlog, int):
raise TypeError("backlog: Expected an integer")
loop = asyncio.get_running_loop()
Expand Down
8 changes: 4 additions & 4 deletions src/easynetwork/lowlevel/std_asyncio/stream/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ...api_async.transports import abc as transports
from ..socket import AsyncSocket
from ..tasks import TaskGroup as AsyncIOTaskGroup
from .socket import AsyncioTransportBufferedStreamSocketAdapter, StreamReaderBufferedProtocol
from .socket import AsyncioTransportStreamSocketAdapter, StreamReaderBufferedProtocol

if TYPE_CHECKING:
import asyncio.trsock
Expand Down Expand Up @@ -148,17 +148,17 @@ async def connect(self, socket: _socket.socket, loop: asyncio.AbstractEventLoop)

@final
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class AcceptedSocketFactory(AbstractAcceptedSocketFactory[AsyncioTransportBufferedStreamSocketAdapter]):
class AcceptedSocketFactory(AbstractAcceptedSocketFactory[AsyncioTransportStreamSocketAdapter]):
def log_connection_error(self, logger: logging.Logger, exc: BaseException) -> None:
logger.error("Error in client task", exc_info=exc)

async def connect(
self,
socket: _socket.socket,
loop: asyncio.AbstractEventLoop,
) -> AsyncioTransportBufferedStreamSocketAdapter:
) -> AsyncioTransportStreamSocketAdapter:
transport, protocol = await loop.connect_accepted_socket(
_utils.make_callback(StreamReaderBufferedProtocol, loop=loop),
socket,
)
return AsyncioTransportBufferedStreamSocketAdapter(transport, protocol)
return AsyncioTransportStreamSocketAdapter(transport, protocol)
125 changes: 8 additions & 117 deletions src/easynetwork/lowlevel/std_asyncio/stream/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

from __future__ import annotations

__all__ = ["AsyncioTransportBufferedStreamSocketAdapter", "AsyncioTransportStreamSocketAdapter"]
__all__ = ["AsyncioTransportStreamSocketAdapter"]

import asyncio
import errno as _errno
from collections import ChainMap
from collections.abc import Callable, Iterable, Mapping
from types import TracebackType
from typing import TYPE_CHECKING, Any, final
Expand All @@ -35,111 +34,17 @@

if TYPE_CHECKING:
import asyncio.trsock
import ssl as _typing_ssl

from _typeshed import WriteableBuffer


@final
class AsyncioTransportStreamSocketAdapter(transports.AsyncStreamTransport):
__slots__ = (
"__reader",
"__writer",
"__socket",
"__closing",
"__over_ssl",
)

def __init__(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
super().__init__()
self.__reader: asyncio.StreamReader = reader
self.__writer: asyncio.StreamWriter = writer
self.__over_ssl: bool = writer.get_extra_info("sslcontext") is not None

socket: asyncio.trsock.TransportSocket | None = writer.get_extra_info("socket")
assert socket is not None, "Writer transport must be a socket transport" # nosec assert_used
self.__socket: asyncio.trsock.TransportSocket = socket

# asyncio.Transport.is_closing() can suddently become true if there is something wrong with the socket
# even if transport.close() was never called.
# To bypass this side effect, we use our own flag.
self.__closing: bool = False

async def aclose(self) -> None:
self.__closing = True
if self.__writer.is_closing():
# Only wait for it.
try:
await self.__writer.wait_closed()
except OSError:
pass
return

try:
if self.__writer.can_write_eof():
self.__writer.write_eof()
except OSError:
pass
finally:
self.__writer.close()
try:
await self.__writer.wait_closed()
except OSError:
pass
except asyncio.CancelledError:
if self.__over_ssl:
self.__writer.transport.abort()
raise

def is_closing(self) -> bool:
return self.__closing

async def recv(self, bufsize: int) -> bytes:
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")
return await self.__reader.read(bufsize)

async def send_all(self, data: bytes | bytearray | memoryview) -> None:
self.__writer.write(data)
await self.__writer.drain()

async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None:
self.__writer.writelines(iterable_of_data)
await self.__writer.drain()

async def send_eof(self) -> None:
if not self.__writer.can_write_eof():
raise UnsupportedOperation("transport does not support sending EOF")
self.__writer.write_eof()
await TaskUtils.coro_yield()

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket = self.__socket
socket_extra: dict[Any, Callable[[], Any]] = socket_tools._get_socket_extra(socket, wrap_in_proxy=False)

ssl_obj: _typing_ssl.SSLObject | _typing_ssl.SSLSocket | None = self.__writer.get_extra_info("ssl_object")
if ssl_obj is None:
return socket_extra
return ChainMap(
socket_extra,
socket_tools._get_tls_extra(ssl_obj),
{socket_tools.TLSAttribute.standard_compatible: lambda: True},
)


@final
class AsyncioTransportBufferedStreamSocketAdapter(transports.AsyncStreamTransport, transports.AsyncBufferedStreamReadTransport):
class AsyncioTransportStreamSocketAdapter(transports.AsyncStreamTransport, transports.AsyncBufferedStreamReadTransport):
__slots__ = (
"__transport",
"__protocol",
"__socket",
"__closing",
"__over_ssl",
)

def __init__(
Expand All @@ -150,22 +55,22 @@ def __init__(
super().__init__()
self.__transport: asyncio.Transport = transport
self.__protocol: StreamReaderBufferedProtocol = protocol
self.__over_ssl: bool = transport.get_extra_info("sslcontext") is not None
over_ssl: bool = transport.get_extra_info("sslcontext") is not None

socket: asyncio.trsock.TransportSocket | None = transport.get_extra_info("socket")
assert socket is not None, "Writer transport must be a socket transport" # nosec assert_used
self.__socket: asyncio.trsock.TransportSocket = socket

if over_ssl:
raise NotImplementedError(f"{self.__class__.__name__} does not support SSL")

# asyncio.Transport.is_closing() can suddently become true if there is something wrong with the socket
# even if transport.close() was never called.
# To bypass this side effect, we use our own flag.
self.__closing: bool = False

# Disable in-memory byte buffering.
if self.__over_ssl:
transport.set_write_buffer_limits(1)
else:
transport.set_write_buffer_limits(0)
transport.set_write_buffer_limits(0)

async def aclose(self) -> None:
self.__closing = True
Expand All @@ -188,10 +93,6 @@ async def aclose(self) -> None:
await asyncio.shield(self.__protocol._get_close_waiter())
except OSError:
pass
except asyncio.CancelledError:
if self.__over_ssl:
self.__transport.abort()
raise

def is_closing(self) -> bool:
return self.__closing
Expand All @@ -218,17 +119,7 @@ async def send_eof(self) -> None:

@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
socket = self.__socket
socket_extra: dict[Any, Callable[[], Any]] = socket_tools._get_socket_extra(socket, wrap_in_proxy=False)

ssl_obj: _typing_ssl.SSLObject | _typing_ssl.SSLSocket | None = self.__transport.get_extra_info("ssl_object")
if ssl_obj is None:
return socket_extra
return ChainMap(
socket_extra,
socket_tools._get_tls_extra(ssl_obj),
{socket_tools.TLSAttribute.standard_compatible: lambda: True},
)
return socket_tools._get_socket_extra(self.__socket, wrap_in_proxy=False)


class StreamReaderBufferedProtocol(asyncio.BufferedProtocol):
Expand Down
49 changes: 27 additions & 22 deletions tests/unit_test/test_async/test_asyncio_backend/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from easynetwork.lowlevel.std_asyncio import AsyncIOBackend
from easynetwork.lowlevel.std_asyncio.datagram.listener import DatagramListenerProtocol
from easynetwork.lowlevel.std_asyncio.stream.listener import AbstractAcceptedSocketFactory, AcceptedSocketFactory
from easynetwork.lowlevel.std_asyncio.stream.socket import StreamReaderBufferedProtocol

import pytest

Expand Down Expand Up @@ -206,28 +207,27 @@ async def test____get_current_task____compute_task_info(
assert task_info.coro is current_task.get_coro()

@pytest.mark.parametrize("happy_eyeballs_delay", [None, 42], ids=lambda p: f"happy_eyeballs_delay=={p}")
async def test____create_tcp_connection____use_asyncio_open_connection(
async def test____create_tcp_connection____use_loop_create_connection(
self,
happy_eyeballs_delay: float | None,
event_loop: asyncio.AbstractEventLoop,
local_address: tuple[str, int] | None,
remote_address: tuple[str, int],
backend: AsyncIOBackend,
mock_asyncio_stream_reader_factory: Callable[[], MagicMock],
mock_asyncio_stream_writer_factory: Callable[[], MagicMock],
mock_tcp_socket: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
mock_asyncio_reader = mock_asyncio_stream_reader_factory()
mock_asyncio_writer = mock_asyncio_stream_writer_factory()
mock_StreamSocketAdapter: MagicMock = mocker.patch(
mock_asyncio_transport = mocker.NonCallableMagicMock(spec=asyncio.Transport)
mock_protocol = mocker.NonCallableMagicMock(spec=StreamReaderBufferedProtocol)
mock_AsyncioTransportStreamSocketAdapter: MagicMock = mocker.patch(
"easynetwork.lowlevel.std_asyncio.backend.AsyncioTransportStreamSocketAdapter", return_value=mocker.sentinel.socket
)
mock_open_connection: AsyncMock = mocker.patch(
"asyncio.open_connection",
mock_event_loop_create_connection: AsyncMock = mocker.patch.object(
event_loop,
"create_connection",
new_callable=mocker.AsyncMock,
return_value=(mock_asyncio_reader, mock_asyncio_writer),
return_value=(mock_asyncio_transport, mock_protocol),
)
mock_own_create_connection: AsyncMock = mocker.patch(
"easynetwork.lowlevel.std_asyncio.backend.create_connection",
Expand All @@ -253,37 +253,42 @@ async def test____create_tcp_connection____use_asyncio_open_connection(
happy_eyeballs_delay=expected_happy_eyeballs_delay,
local_address=local_address,
)
mock_open_connection.assert_awaited_once_with(sock=mock_tcp_socket)
mock_StreamSocketAdapter.assert_called_once_with(mock_asyncio_reader, mock_asyncio_writer)
mock_event_loop_create_connection.assert_awaited_once_with(
partial_eq(StreamReaderBufferedProtocol, loop=event_loop),
sock=mock_tcp_socket,
)
mock_AsyncioTransportStreamSocketAdapter.assert_called_once_with(mock_asyncio_transport, mock_protocol)
assert socket is mocker.sentinel.socket

async def test____wrap_stream_socket____use_asyncio_open_connection(
self,
event_loop: asyncio.AbstractEventLoop,
backend: AsyncIOBackend,
mock_tcp_socket: MagicMock,
mock_asyncio_stream_reader_factory: Callable[[], MagicMock],
mock_asyncio_stream_writer_factory: Callable[[], MagicMock],
mocker: MockerFixture,
) -> None:
# Arrange
mock_asyncio_reader = mock_asyncio_stream_reader_factory()
mock_asyncio_writer = mock_asyncio_stream_writer_factory()
mock_asyncio_transport = mocker.NonCallableMagicMock(spec=asyncio.Transport)
mock_protocol = mocker.NonCallableMagicMock(spec=StreamReaderBufferedProtocol)
mock_AsyncioTransportStreamSocketAdapter: MagicMock = mocker.patch(
"easynetwork.lowlevel.std_asyncio.backend.AsyncioTransportStreamSocketAdapter",
return_value=mocker.sentinel.socket,
"easynetwork.lowlevel.std_asyncio.backend.AsyncioTransportStreamSocketAdapter", return_value=mocker.sentinel.socket
)
mock_open_connection: AsyncMock = mocker.patch(
"asyncio.open_connection",
mock_event_loop_create_connection: AsyncMock = mocker.patch.object(
event_loop,
"create_connection",
new_callable=mocker.AsyncMock,
return_value=(mock_asyncio_reader, mock_asyncio_writer),
return_value=(mock_asyncio_transport, mock_protocol),
)

# Act
socket = await backend.wrap_stream_socket(mock_tcp_socket)

# Assert
mock_open_connection.assert_awaited_once_with(sock=mock_tcp_socket)
mock_AsyncioTransportStreamSocketAdapter.assert_called_once_with(mock_asyncio_reader, mock_asyncio_writer)
mock_event_loop_create_connection.assert_awaited_once_with(
partial_eq(StreamReaderBufferedProtocol, loop=event_loop),
sock=mock_tcp_socket,
)
mock_AsyncioTransportStreamSocketAdapter.assert_called_once_with(mock_asyncio_transport, mock_protocol)
assert socket is mocker.sentinel.socket
mock_tcp_socket.setblocking.assert_called_with(False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def test____create_datagram_endpoint____return_DatagramEndpoint_instance(
mock_loop_create_datagram_endpoint: AsyncMock = cast(
"AsyncMock",
mocker.patch.object(
asyncio.get_running_loop(),
event_loop,
"create_datagram_endpoint",
new_callable=mocker.AsyncMock,
return_value=(
Expand Down
Loading