Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement zero copy writes for the WebSocket writer #9634

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGES/9634.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implemented zero copy writes for WebSockets when using Python 3.12+ -- by :user:`bdraco`.
2 changes: 2 additions & 0 deletions aiohttp/_websocket/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from ..helpers import NO_EXTENSIONS

DEFAULT_LIMIT = 2**18

if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
Expand Down
10 changes: 7 additions & 3 deletions aiohttp/_websocket/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,16 @@ async def send_frame(
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
websocket_mask(mask, message)
self.transport.write(header + mask + message)
self.transport.writelines((header, mask, message))
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
# For large messages, we use writelines to avoid copying the
# entire message into a new buffer. This is a performance
# optimization to avoid unnecessary memory allocations.
self.transport.writelines((header, message))
else:
# If the message is small, its faster to copy it into a new
# buffer and send it all at once.
self.transport.write(header + message)

self._output_size += header_len + msg_length
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from yarl import URL

from . import hdrs, http, payload
from ._websocket.reader import WebSocketDataQueue
from ._websocket.reader import DEFAULT_LIMIT, WebSocketDataQueue
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
Expand Down Expand Up @@ -1035,7 +1035,7 @@ async def _ws_connect(

transport = conn.transport
assert transport is not None
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
reader = WebSocketDataQueue(conn_proto, DEFAULT_LIMIT, loop=self._loop)
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
writer = WebSocketWriter(
conn_proto,
Expand Down
10 changes: 6 additions & 4 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from multidict import CIMultiDict

from . import hdrs
from ._websocket.reader import WebSocketDataQueue
from ._websocket.writer import DEFAULT_LIMIT
from ._websocket.reader import DEFAULT_LIMIT as DEFAULT_READER_LIMIT, WebSocketDataQueue
from ._websocket.writer import DEFAULT_LIMIT as DEFAULT_WRITER_LIMIT
from .abc import AbstractStreamWriter
from .client_exceptions import WSMessageTypeError
from .helpers import calculate_timeout_when, set_exception, set_result
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
protocols: Iterable[str] = (),
compress: bool = True,
max_msg_size: int = 4 * 1024 * 1024,
writer_limit: int = DEFAULT_LIMIT,
writer_limit: int = DEFAULT_WRITER_LIMIT,
) -> None:
super().__init__(status=101)
self._length_check = False
Expand Down Expand Up @@ -356,7 +356,9 @@ def _post_start(

loop = self._loop
assert loop is not None
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
self._reader = WebSocketDataQueue(
request._protocol, DEFAULT_READER_LIMIT, loop=loop
)
request.protocol.set_parser(
WebSocketReader(
self._reader, self._max_msg_size, compress=bool(self._compress)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_benchmarks_http_websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""codspeed benchmarks for http websocket."""

import asyncio
from typing import Union
from typing import Iterable, Union

from pytest_codspeed import BenchmarkFixture

Expand Down Expand Up @@ -62,6 +62,9 @@ def is_closing(self) -> bool:
def write(self, data: Union[bytes, bytearray, memoryview]) -> None:
"""Swallow writes."""

def writelines(self, data: Iterable[Union[bytes, bytearray, memoryview]]) -> None:
"""Swallow writes."""


class MockProtocol(BaseProtocol):

Expand Down
11 changes: 6 additions & 5 deletions tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ async def test_send_binary_long(writer: WebSocketWriter) -> None:

async def test_send_binary_very_long(writer: WebSocketWriter) -> None:
await writer.send_frame(b"b" * 65537, WSMsgType.BINARY)
assert (
writer.transport.write.call_args_list[0][0][0] # type: ignore[attr-defined]
== b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01"
assert writer.transport.writelines.call_args_list[0][0][ # type: ignore[attr-defined]
0
] == (
b"\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01",
b"b" * 65537,
)
assert writer.transport.write.call_args_list[1][0][0] == b"b" * 65537 # type: ignore[attr-defined]


async def test_close(writer: WebSocketWriter) -> None:
Expand All @@ -84,7 +85,7 @@ async def test_send_text_masked(
protocol, transport, use_mask=True, random=random.Random(123)
)
await writer.send_frame(b"text", WSMsgType.TEXT)
writer.transport.write.assert_called_with(b"\x81\x84\rg\xb3fy\x02\xcb\x12") # type: ignore[attr-defined]
writer.transport.writelines.assert_called_with((b"\x81\x84", b"\rg\xb3f", bytearray(b"y\x02\xcb\x12"))) # type: ignore[attr-defined]


async def test_send_compress_text(
Expand Down
Loading