Skip to content

Commit

Permalink
[PR #9839/a9a0d84 backport][3.11] Implement zero copy writes in `Stre…
Browse files Browse the repository at this point in the history
…amWriter` (#9847)
  • Loading branch information
bdraco authored Nov 13, 2024
1 parent c39032b commit 354489d
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGES/9839.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implemented zero-copy writes for ``StreamWriter`` -- by :user:`bdraco`.
68 changes: 49 additions & 19 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

import asyncio
import zlib
from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa
from typing import ( # noqa
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)

from multidict import CIMultiDict

Expand Down Expand Up @@ -76,6 +85,17 @@ def _write(self, chunk: bytes) -> None:
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)

def _writelines(self, chunks: Iterable[bytes]) -> None:
    """Send several byte chunks through the transport in one ``writelines`` call.

    The combined length of ``chunks`` is added to ``buffer_size`` and
    ``output_size`` before the transport is checked, so the counters are
    updated even when the write is then refused.

    Raises:
        ClientConnectionResetError: if the protocol has no transport or the
            transport is closing.
    """
    # ``chunks`` is walked twice (once to measure it, once by the transport),
    # so a one-shot iterator such as a generator has to be materialized
    # first; otherwise its size would be counted but nothing would be sent.
    # Only references are stored, the payload bytes are not copied.
    if not isinstance(chunks, (list, tuple)):
        chunks = tuple(chunks)
    size = sum(map(len, chunks))
    self.buffer_size += size
    self.output_size += size
    transport = self._protocol.transport
    if transport is None or transport.is_closing():
        raise ClientConnectionResetError("Cannot write to closing transport")
    transport.writelines(chunks)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
) -> None:
Expand Down Expand Up @@ -110,10 +130,11 @@ async def write(

if chunk:
if self.chunked:
chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len_pre + chunk + b"\r\n"

self._write(chunk)
self._writelines(
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
)
else:
self._write(chunk)

if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
Expand Down Expand Up @@ -142,22 +163,31 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self._on_chunk_sent(chunk)

if self._compress:
if chunk:
chunk = await self._compress.compress(chunk)
chunks: List[bytes] = []
chunks_len = 0
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
chunks_len = len(compressed_chunk)
chunks.append(compressed_chunk)

chunk += self._compress.flush()
if chunk and self.chunked:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
else:
if self.chunked:
if chunk:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
else:
chunk = b"0\r\n\r\n"
flush_chunk = self._compress.flush()
chunks_len += len(flush_chunk)
chunks.append(flush_chunk)
assert chunks_len

if chunk:
if self.chunked:
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
elif len(chunks) > 1:
self._writelines(chunks)
else:
self._write(chunks[0])
elif self.chunked:
if chunk:
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
else:
self._write(b"0\r\n\r\n")
elif chunk:
self._write(chunk)

await self.drain()
Expand Down
13 changes: 7 additions & 6 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.parse
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Iterable, Optional
from unittest import mock

import pytest
Expand Down Expand Up @@ -67,17 +67,18 @@ def protocol(loop, transport):


@pytest.fixture
def transport(buf):
transport = mock.Mock()
def transport(buf: bytearray) -> mock.Mock:
transport = mock.create_autospec(asyncio.Transport, spec_set=True, instance=True)

def write(chunk):
buf.extend(chunk)

async def write_eof():
pass
def writelines(chunks: Iterable[bytes]) -> None:
for chunk in chunks:
buf.extend(chunk)

transport.write.side_effect = write
transport.write_eof.side_effect = write_eof
transport.writelines.side_effect = writelines
transport.is_closing.return_value = False

return transport
Expand Down
143 changes: 134 additions & 9 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Tests for aiohttp/http_writer.py
import array
import asyncio
import zlib
from typing import Iterable
from unittest import mock

import pytest
Expand All @@ -23,7 +25,12 @@ def transport(buf):
def write(chunk):
buf.extend(chunk)

def writelines(chunks: Iterable[bytes]) -> None:
for chunk in chunks:
buf.extend(chunk)

transport.write.side_effect = write
transport.writelines.side_effect = writelines
transport.is_closing.return_value = False
return transport

Expand Down Expand Up @@ -85,21 +92,53 @@ async def test_write_payload_length(protocol, transport, loop) -> None:
assert b"da" == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_chunked_filter(protocol, transport, loop) -> None:
write = transport.write = mock.Mock()
async def test_write_large_payload_deflate_compression_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate output for a body given to ``write_eof`` goes out via ``writelines``.

    The first ``write`` is expected to use ``transport.write`` only; the
    ``write_eof`` call is expected to use ``transport.writelines`` only. The
    concatenation of everything sent must decompress to the original data.
    """
    head = b"data" * 4096
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")

    await writer.write(head)
    assert transport.write.called  # type: ignore[attr-defined]
    assert not transport.writelines.called  # type: ignore[attr-defined]
    # Collect the single positional argument of every ``write`` call so far.
    sent = [call.args[0] for call in transport.write.mock_calls]  # type: ignore[attr-defined]
    # Forget those calls so the "not called" assertion below only sees write_eof.
    transport.write.reset_mock()  # type: ignore[attr-defined]

    # This payload compresses to 20447 bytes
    ramps = (bytes((*range(i), *range(i, 0, -1))) for i in range(255))
    tail = b"".join(ramp * 64 for ramp in ramps)
    await writer.write_eof(tail)
    assert not transport.write.called  # type: ignore[attr-defined]
    assert transport.writelines.called  # type: ignore[attr-defined]
    # Only the first ``writelines`` call is inspected, as in the assertion below.
    first_writelines = transport.writelines.mock_calls[0]  # type: ignore[attr-defined]
    sent.extend(first_writelines.args[0])
    assert zlib.decompress(b"".join(sent)) == head + tail


async def test_write_payload_chunked_filter(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_chunking()
await msg.write(b"da")
await msg.write(b"ta")
await msg.write_eof()

content = b"".join([c[1][0] for c in list(write.mock_calls)])
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
assert content.endswith(b"2\r\nda\r\n2\r\nta\r\n0\r\n\r\n")


async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport, loop):
write = transport.write = mock.Mock()
async def test_write_payload_chunked_filter_multiple_chunks(
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_chunking()
await msg.write(b"da")
Expand All @@ -108,14 +147,14 @@ async def test_write_payload_chunked_filter_mutiple_chunks(protocol, transport,
await msg.write(b"at")
await msg.write(b"a2")
await msg.write_eof()
content = b"".join([c[1][0] for c in list(write.mock_calls)])
content = b"".join([b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)]) # type: ignore[attr-defined]
content += b"".join([c[1][0] for c in list(transport.write.mock_calls)]) # type: ignore[attr-defined]
assert content.endswith(
b"2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n2\r\na2\r\n0\r\n\r\n"
)


async def test_write_payload_deflate_compression(protocol, transport, loop) -> None:

COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b"
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, loop)
Expand All @@ -129,7 +168,30 @@ async def test_write_payload_deflate_compression(protocol, transport, loop) -> N
assert COMPRESSED == content.split(b"\r\n\r\n", 1)[-1]


async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop):
async def test_write_payload_deflate_compression_chunked(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunking: each ``writelines`` call is non-empty and the
    joined output equals the expected chunked byte stream.
    """
    expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")
    await writer.write_eof()

    # One joined frame per ``writelines`` call (its first positional argument).
    frames = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        frames.append(b"".join(call.args[0]))
    assert all(frames)
    assert b"".join(frames) == expected


async def test_write_payload_deflate_and_chunked(
buf: bytearray,
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
msg.enable_compression("deflate")
msg.enable_chunking()
Expand All @@ -142,8 +204,71 @@ async def test_write_payload_deflate_and_chunked(buf, protocol, transport, loop)
assert thing == buf


async def test_write_payload_bytes_memoryview(buf, protocol, transport, loop):
async def test_write_payload_deflate_compression_chunked_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunking with a final body passed to ``write_eof``: each
    ``writelines`` call is non-empty and the joined output equals the expected
    chunked byte stream.
    """
    expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")
    await writer.write_eof(b"end")

    # One joined frame per ``writelines`` call (its first positional argument).
    frames = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        frames.append(b"".join(call.args[0]))
    assert all(frames)
    assert b"".join(frames) == expected


async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Deflate plus chunking with a large body in ``write_eof``: nothing is sent
    through ``transport.write`` and the chunk bodies decompress to the input.
    """
    head = b"data" * 4096
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()

    await writer.write(head)
    # This payload compresses to 1111 bytes
    tail = b"".join(bytes((*range(i), *range(i, 0, -1))) for i in range(255))
    await writer.write_eof(tail)
    assert not transport.write.called  # type: ignore[attr-defined]

    pieces = []
    for call in transport.writelines.mock_calls:  # type: ignore[attr-defined]
        # Drop the first and last element of each call (the chunk framing),
        # keeping only what lies between them.
        _size_line, *body, _terminator = call.args[0]
        pieces.extend(body)

    assert all(pieces)
    assert zlib.decompress(b"".join(pieces)) == head + tail


async def test_write_payload_deflate_compression_chunked_connection_lost(
    protocol: BaseProtocol,
    transport: asyncio.Transport,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """``write_eof`` raises ``ClientConnectionResetError`` when the transport
    reports that it is closing.
    """
    writer = http.StreamWriter(protocol, loop)
    writer.enable_compression("deflate")
    writer.enable_chunking()
    await writer.write(b"data")

    expect_reset = pytest.raises(
        ClientConnectionResetError, match="Cannot write to closing transport"
    )
    # Make the transport look like it is closing only for the final write.
    closing_transport = mock.patch.object(transport, "is_closing", return_value=True)
    with expect_reset:
        with closing_transport:
            await writer.write_eof(b"end")


async def test_write_payload_bytes_memoryview(
buf: bytearray,
protocol: BaseProtocol,
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)

mv = memoryview(b"abcd")
Expand Down

0 comments on commit 354489d

Please sign in to comment.