Skip to content

Commit

Permalink
Removed BufferedStreamReadTransport class
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Jun 26, 2024
1 parent e9a2f17 commit 7135dbd
Show file tree
Hide file tree
Showing 28 changed files with 325 additions and 477 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ instance/
.scrapy

# Sphinx documentation
docs/_build/
docs/source/_build/
docs/build/

# PyBuilder
target/
Expand Down
11 changes: 7 additions & 4 deletions benchmark_server/servers/easynetwork_tcp_echoserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,16 @@ def create_tcp_server(
if context_reuse:
print("with context reuse")

serializer: BufferedIncrementalPacketSerializer[Any, Any, Any]
protocol: StreamProtocol[Any, Any] | BufferedStreamProtocol[Any, Any, Any]
if readline:
protocol = BufferedStreamProtocol(LineSerializer())
serializer = LineSerializer()
else:
protocol = BufferedStreamProtocol(NoSerializer())
if not buffered:
protocol = protocol.into_data_protocol()
serializer = NoSerializer()
if buffered:
protocol = BufferedStreamProtocol(serializer)
else:
protocol = StreamProtocol(serializer)
return StandaloneTCPNetworkServer(
None,
port,
Expand Down
2 changes: 0 additions & 2 deletions docs/source/howto/advanced/buffered_serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ its :meth:`~.BufferedIncrementalPacketSerializer.create_deserializer_buffer` and
.. note::

You still need to implement :class:`.AbstractIncrementalPacketSerializer` methods.
Writing input directly to an external object that implements the :ref:`buffer protocol <bufferobjects>` is not supported by
all transport layer implementations.

Let's see how we can use it for ``MyJSONSerializer`` (from :doc:`../serializers`):

Expand Down
17 changes: 3 additions & 14 deletions src/easynetwork/lowlevel/api_async/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,10 @@
from typing import Any, Generic, assert_never

from ...._typevars import _T_ReceivedPacket, _T_SentPacket
from ....exceptions import UnsupportedOperation
from ....protocol import AnyStreamProtocolType
from ... import _stream, _utils
from ..backend.abc import AsyncBackend
from ..transports.abc import (
AsyncBaseTransport,
AsyncBufferedStreamReadTransport,
AsyncStreamReadTransport,
AsyncStreamTransport,
AsyncStreamWriteTransport,
)
from ..transports.abc import AsyncBaseTransport, AsyncStreamReadTransport, AsyncStreamTransport, AsyncStreamWriteTransport


class AsyncStreamReceiverEndpoint(AsyncBaseTransport, Generic[_T_ReceivedPacket]):
Expand Down Expand Up @@ -399,7 +392,7 @@ async def receive(self) -> _T_ReceivedPacket:

@dataclasses.dataclass(slots=True)
class _BufferedReceiverImpl(Generic[_T_ReceivedPacket]):
transport: AsyncBufferedStreamReadTransport
transport: AsyncStreamReadTransport
consumer: _stream.BufferedStreamDataConsumer[_T_ReceivedPacket]
_eof_reached: bool = dataclasses.field(init=False, default=False)

Expand Down Expand Up @@ -446,11 +439,7 @@ def _get_receiver(

match protocol:
case BufferedStreamProtocol():
buffered_consumer = _stream.BufferedStreamDataConsumer(protocol, max_recv_size)
if not isinstance(transport, AsyncBufferedStreamReadTransport):
msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface"
raise UnsupportedOperation(msg)
return _BufferedReceiverImpl(transport, buffered_consumer)
return _BufferedReceiverImpl(transport, _stream.BufferedStreamDataConsumer(protocol, max_recv_size))
case StreamProtocol():
return _DataReceiverImpl(transport, _stream.StreamDataConsumer(protocol), max_recv_size)
case _: # pragma: no cover
Expand Down
33 changes: 6 additions & 27 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@

from ...._typevars import _T_Request, _T_Response
from ....protocol import AnyStreamProtocolType
from ....warnings import ManualBufferAllocationWarning
from ... import _stream, _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..backend.abc import AsyncBackend, TaskGroup
from ..transports import utils as transports_utils
from ..transports.abc import (
AsyncBaseTransport,
AsyncBufferedStreamReadTransport,
AsyncListener,
AsyncStreamReadTransport,
AsyncStreamTransport,
Expand Down Expand Up @@ -209,20 +207,11 @@ async def __client_coroutine(
request_receiver: _RequestReceiver[_T_Request] | _BufferedRequestReceiver[_T_Request]
match self.__protocol:
case BufferedStreamProtocol():
if isinstance(transport, AsyncBufferedStreamReadTransport):
consumer = _stream.BufferedStreamDataConsumer(self.__protocol, self.__max_recv_size)
request_receiver = _BufferedRequestReceiver(
transport=transport,
consumer=consumer,
)
else:
self.__manual_buffer_allocation_warning(transport)
consumer = _stream.StreamDataConsumer(self.__protocol.into_data_protocol())
request_receiver = _RequestReceiver(
transport=transport,
consumer=consumer,
max_recv_size=self.__max_recv_size,
)
consumer = _stream.BufferedStreamDataConsumer(self.__protocol, self.__max_recv_size)
request_receiver = _BufferedRequestReceiver(
transport=transport,
consumer=consumer,
)
case StreamProtocol():
consumer = _stream.StreamDataConsumer(self.__protocol)
request_receiver = _RequestReceiver(
Expand Down Expand Up @@ -263,16 +252,6 @@ async def __client_coroutine(
finally:
await request_handler_generator.aclose()

@staticmethod
def __manual_buffer_allocation_warning(transport: AsyncStreamTransport) -> None:
_warn_msg = " ".join(
[
f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface.",
"Consider using StreamProtocol instead of BufferedStreamProtocol.",
]
)
warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=2)

@property
@_utils.inherit_doc(AsyncBaseTransport)
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
Expand Down Expand Up @@ -315,7 +294,7 @@ async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]:

@dataclasses.dataclass(kw_only=True, eq=False, slots=True)
class _BufferedRequestReceiver(Generic[_T_Request]):
transport: AsyncBufferedStreamReadTransport
transport: AsyncStreamReadTransport
consumer: _stream.BufferedStreamDataConsumer[_T_Request]
__null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext)
__backend: AsyncBackend = dataclasses.field(init=False)
Expand Down
21 changes: 10 additions & 11 deletions src/easynetwork/lowlevel/api_async/transports/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

__all__ = [
"AsyncBaseTransport",
"AsyncBufferedStreamReadTransport",
"AsyncDatagramListener",
"AsyncDatagramReadTransport",
"AsyncDatagramTransport",
Expand Down Expand Up @@ -104,7 +103,6 @@ class AsyncStreamReadTransport(AsyncBaseTransport):

__slots__ = ()

@abstractmethod
async def recv(self, bufsize: int) -> bytes:
"""
Read and return up to `bufsize` bytes.
Expand All @@ -120,15 +118,16 @@ async def recv(self, bufsize: int) -> bytes:
If `bufsize` is greater than zero and an empty byte buffer is returned, this indicates an EOF.
"""
raise NotImplementedError


class AsyncBufferedStreamReadTransport(AsyncStreamReadTransport):
"""
An asynchronous continuous stream data reader transport that supports externally allocated buffers.
"""

__slots__ = ()
if bufsize == 0:
return b""
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")

with memoryview(bytearray(bufsize)) as buffer:
nbytes = await self.recv_into(buffer)
if nbytes < 0:
raise RuntimeError("transport.recv_into() returned a negative value")
return bytes(buffer[:nbytes])

@abstractmethod
async def recv_into(self, buffer: WriteableBuffer) -> int:
Expand Down
29 changes: 5 additions & 24 deletions src/easynetwork/lowlevel/api_async/transports/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ....exceptions import UnsupportedOperation
from ... import _utils, constants, socket as socket_tools
from ..backend.abc import AsyncBackend, TaskGroup
from .abc import AsyncBufferedStreamReadTransport, AsyncListener, AsyncStreamReadTransport, AsyncStreamTransport
from .abc import AsyncListener, AsyncStreamReadTransport, AsyncStreamTransport
from .utils import aclose_forcefully

if TYPE_CHECKING:
Expand All @@ -51,7 +51,7 @@


@dataclasses.dataclass(repr=False, eq=False, slots=True, kw_only=True)
class AsyncTLSStreamTransport(AsyncStreamTransport, AsyncBufferedStreamReadTransport):
class AsyncTLSStreamTransport(AsyncStreamTransport):
"""
SSL/TLS wrapper for a continuous stream transport.
"""
Expand All @@ -66,10 +66,7 @@ class AsyncTLSStreamTransport(AsyncStreamTransport, AsyncBufferedStreamReadTrans
__closing: bool = dataclasses.field(init=False, default=False)

def __post_init__(self) -> None:
if isinstance(self._transport, AsyncBufferedStreamReadTransport):
self.__incoming_reader = _BufferedIncomingDataReader(transport=self._transport)
else:
self.__incoming_reader = _IncomingDataReader(transport=self._transport)
self.__incoming_reader = _IncomingDataReader(transport=self._transport)

@classmethod
async def wrap(
Expand Down Expand Up @@ -190,7 +187,7 @@ async def recv(self, bufsize: int) -> bytes:
return b""
raise

@_utils.inherit_doc(AsyncBufferedStreamReadTransport)
@_utils.inherit_doc(AsyncStreamTransport)
async def recv_into(self, buffer: WriteableBuffer) -> int:
assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used
nbytes = memoryview(buffer).nbytes or 1024
Expand Down Expand Up @@ -357,20 +354,6 @@ class _IncomingDataReader:
transport: AsyncStreamReadTransport
max_size: Final[int] = 256 * 1024 # 256KiB

async def readinto(self, read_bio: MemoryBIO) -> int:
data = await self.transport.recv(self.max_size)
if data:
return read_bio.write(data)
read_bio.write_eof()
return 0

def close(self) -> None:
pass


@dataclasses.dataclass(kw_only=True, eq=False, slots=True)
class _BufferedIncomingDataReader(_IncomingDataReader):
transport: AsyncBufferedStreamReadTransport
buffer: bytearray | None = dataclasses.field(init=False)
buffer_view: memoryview = dataclasses.field(init=False)

Expand All @@ -379,9 +362,7 @@ def __post_init__(self) -> None:
self.buffer_view = memoryview(self.buffer)

async def readinto(self, read_bio: MemoryBIO) -> int:
buffer = self.buffer_view
nbytes = await self.transport.recv_into(buffer)
if nbytes:
if nbytes := await self.transport.recv_into(buffer := self.buffer_view):
return read_bio.write(buffer[:nbytes])
read_bio.write_eof()
return 0
Expand Down
17 changes: 3 additions & 14 deletions src/easynetwork/lowlevel/api_sync/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,9 @@
from typing import Any, Generic, assert_never

from ...._typevars import _T_ReceivedPacket, _T_SentPacket
from ....exceptions import UnsupportedOperation
from ....protocol import AnyStreamProtocolType
from ... import _stream, _utils
from ..transports.abc import (
BaseTransport,
BufferedStreamReadTransport,
StreamReadTransport,
StreamTransport,
StreamWriteTransport,
)
from ..transports.abc import BaseTransport, StreamReadTransport, StreamTransport, StreamWriteTransport


class StreamReceiverEndpoint(BaseTransport, Generic[_T_ReceivedPacket]):
Expand Down Expand Up @@ -418,7 +411,7 @@ def receive(self, timeout: float) -> _T_ReceivedPacket:

@dataclasses.dataclass(slots=True)
class _BufferedReceiverImpl(Generic[_T_ReceivedPacket]):
transport: BufferedStreamReadTransport
transport: StreamReadTransport
consumer: _stream.BufferedStreamDataConsumer[_T_ReceivedPacket]
_eof_reached: bool = dataclasses.field(init=False, default=False)

Expand Down Expand Up @@ -472,11 +465,7 @@ def _get_receiver(

match protocol:
case BufferedStreamProtocol():
buffered_consumer = _stream.BufferedStreamDataConsumer(protocol, max_recv_size)
if not isinstance(transport, BufferedStreamReadTransport):
msg = f"The transport implementation {transport!r} does not implement BufferedStreamReadTransport interface"
raise UnsupportedOperation(msg)
return _BufferedReceiverImpl(transport, buffered_consumer)
return _BufferedReceiverImpl(transport, _stream.BufferedStreamDataConsumer(protocol, max_recv_size))
case StreamProtocol():
return _DataReceiverImpl(transport, _stream.StreamDataConsumer(protocol), max_recv_size)
case _: # pragma: no cover
Expand Down
21 changes: 10 additions & 11 deletions src/easynetwork/lowlevel/api_sync/transports/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

__all__ = [
"BaseTransport",
"BufferedStreamReadTransport",
"DatagramReadTransport",
"DatagramTransport",
"DatagramWriteTransport",
Expand Down Expand Up @@ -79,7 +78,6 @@ class StreamReadTransport(BaseTransport):

__slots__ = ()

@abstractmethod
def recv(self, bufsize: int, timeout: float) -> bytes:
"""
Read and return up to `bufsize` bytes.
Expand All @@ -98,15 +96,16 @@ def recv(self, bufsize: int, timeout: float) -> bytes:
If `bufsize` is greater than zero and an empty byte buffer is returned, this indicates an EOF.
"""
raise NotImplementedError


class BufferedStreamReadTransport(StreamReadTransport):
"""
A continuous stream data reader transport that supports externally allocated buffers.
"""

__slots__ = ()
if bufsize == 0:
return b""
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")

with memoryview(bytearray(bufsize)) as buffer:
nbytes = self.recv_into(buffer, timeout)
if nbytes < 0:
raise RuntimeError("transport.recv_into() returned a negative value")
return bytes(buffer[:nbytes])

@abstractmethod
def recv_into(self, buffer: WriteableBuffer, timeout: float) -> int:
Expand Down
22 changes: 10 additions & 12 deletions src/easynetwork/lowlevel/api_sync/transports/base_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

__all__ = [
"SelectorBaseTransport",
"SelectorBufferedStreamReadTransport",
"SelectorDatagramReadTransport",
"SelectorDatagramTransport",
"SelectorDatagramWriteTransport",
Expand Down Expand Up @@ -169,7 +168,6 @@ class SelectorStreamReadTransport(SelectorBaseTransport, transports.StreamReadTr

__slots__ = ()

@abstractmethod
def recv_noblock(self, bufsize: int) -> bytes:
"""
Read and return up to `bufsize` bytes.
Expand All @@ -187,7 +185,16 @@ def recv_noblock(self, bufsize: int) -> bytes:
If `bufsize` is greater than zero and an empty byte buffer is returned, this indicates an EOF.
"""
raise NotImplementedError
if bufsize == 0:
return b""
if bufsize < 0:
raise ValueError("'bufsize' must be a positive or null integer")

with memoryview(bytearray(bufsize)) as buffer:
nbytes = self.recv_noblock_into(buffer)
if nbytes < 0:
raise RuntimeError("transport.recv_noblock_into() returned a negative value")
return bytes(buffer[:nbytes])

def recv(self, bufsize: int, timeout: float) -> bytes:
"""
Expand All @@ -197,15 +204,6 @@ def recv(self, bufsize: int, timeout: float) -> bytes:
"""
return self._retry(lambda: self.recv_noblock(bufsize), timeout)[0]


class SelectorBufferedStreamReadTransport(SelectorStreamReadTransport, transports.BufferedStreamReadTransport):
"""
A continuous stream data reader transport using the :mod:`selectors` module for blocking operations polling
that supports externally allocated buffers.
"""

__slots__ = ()

@abstractmethod
def recv_noblock_into(self, buffer: WriteableBuffer) -> int:
"""
Expand Down
Loading

0 comments on commit 7135dbd

Please sign in to comment.