Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BufferedIncrementalPacketSerializer.create_deserializer_buffer() can now return a buffer with a different item size #184

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/easynetwork/lowlevel/_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def get_write_buffer(self) -> WriteableBuffer:
raise RuntimeError("protocol.build_packet_from_buffer() crashed") from exc
self.__consumer = consumer

buffer: memoryview = memoryview(self.__buffer)
buffer: memoryview = memoryview(self.__buffer).cast("B")

match self.__buffer_start:
case None | 0:
Expand Down Expand Up @@ -309,7 +309,7 @@ def get_value(self, *, full: bool = False) -> bytes | None:
return None
if full:
return bytes(self.__buffer)
buffer = memoryview(self.__buffer)
buffer = memoryview(self.__buffer).cast("B")
if self.__buffer_start is None:
nbytes = self.__already_written
elif self.__buffer_start < 0:
Expand Down Expand Up @@ -344,8 +344,6 @@ def __validate_created_buffer(buffer: WriteableBuffer) -> None:
with memoryview(buffer) as buffer:
if buffer.readonly:
raise ValueError("protocol.create_buffer() returned a read-only buffer")
if buffer.itemsize != 1:
raise ValueError("protocol.create_buffer() must return a byte buffer")
if not len(buffer):
raise ValueError("protocol.create_buffer() returned a null buffer")

Expand All @@ -354,7 +352,7 @@ def buffer_size(self) -> int:
if self.__buffer is None:
return 0
with memoryview(self.__buffer) as buffer:
return len(buffer)
return buffer.nbytes


def _check_protocol(p: StreamProtocol[Any, Any]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/easynetwork/serializers/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def create_deserializer_buffer(self, sizehint: int, /) -> _BufferT:
Called to allocate a new receive buffer.

Parameters:
sizehint: the recommended size for the returned buffer.
sizehint: the recommended size (in bytes) for the returned buffer.
It is acceptable to return smaller or larger buffers than what `sizehint` suggests.

Returns:
Expand Down
57 changes: 40 additions & 17 deletions tests/unit_test/test_tools/test_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import struct
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Literal, assert_never

Expand Down Expand Up @@ -603,23 +604,6 @@ def test____get_write_buffer____protocol_create_buffer_validation____readonly_bu
mock_buffered_stream_receiver.build_packet_from_buffer.assert_not_called()
assert consumer.get_value() is None

def test____get_write_buffer____protocol_create_buffer_validation____not_a_byte_buffer(
self,
consumer: BufferedStreamDataConsumer[Any],
mock_buffered_stream_receiver: MagicMock,
) -> None:
# Arrange
from array import array

mock_buffered_stream_receiver.create_buffer.side_effect = [memoryview(array("I", itertools.repeat(0, 5)))]

# Act & Assert
with pytest.raises(ValueError, match=r"^protocol\.create_buffer\(\) must return a byte buffer$"):
_ = consumer.get_write_buffer()

mock_buffered_stream_receiver.build_packet_from_buffer.assert_not_called()
assert consumer.get_value() is None

def test____get_write_buffer____protocol_create_buffer_validation____empty_byte_buffer(
self,
consumer: BufferedStreamDataConsumer[Any],
Expand Down Expand Up @@ -1015,6 +999,45 @@ def side_effect(buffer: memoryview) -> Generator[int | None, int, tuple[Any, Rea
assert full_buffer_value[-10:-7] == b"Bye"
assert truncated_buffer_value.endswith(b"Bye")

def test____next____not_a_byte_buffer(
self,
consumer: BufferedStreamDataConsumer[Any],
mock_buffered_stream_receiver: MagicMock,
sizehint: int,
) -> None:
# Arrange
from array import array

itemsize = struct.calcsize("@I")

mock_buffered_stream_receiver.create_buffer.side_effect = lambda sizehint: array(
"I", itertools.repeat(0, sizehint // itemsize)
)

def side_effect(buffer: array[int]) -> Generator[int, int, tuple[Any, bytes]]:
nbytes = yield 0
assert nbytes == itemsize
nbytes = yield 1 * itemsize
assert nbytes == itemsize
return (buffer[0], buffer[1]), b""

mock_build_packet_from_buffer_func: MagicMock = mock_buffered_stream_receiver.build_packet_from_buffer
mock_build_packet_from_buffer_func.side_effect = side_effect

# Act & Assert
buffer = consumer.get_write_buffer()
assert memoryview(buffer).format == "B"
assert memoryview(buffer).itemsize == 1
assert memoryview(buffer).nbytes == sizehint
assert consumer.buffer_size == sizehint
self.write_in_consumer(consumer, struct.pack("@I", 42))
with pytest.raises(StopIteration):
next(consumer)
self.write_in_consumer(consumer, struct.pack("@I", 987))
packet = next(consumer)
assert isinstance(packet, tuple)
assert packet == (42, 987)

@pytest.mark.parametrize("remainder_type", ["buffer_view", "external"])
def test____next____protocol_parse_error(
self,
Expand Down