From 3e995fb4f8e9759b93dba3b468db3f2537577300 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Thu, 28 Sep 2023 13:02:59 +0200 Subject: [PATCH 1/6] [UPDATE] Improved stream tools --- src/easynetwork/api_async/server/tcp.py | 2 +- src/easynetwork/tools/_stream.py | 21 ++++---- tests/unit_test/test_tools/test_stream.py | 64 +++++++++++++++++++---- 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index 6c793394..6e7122bf 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -656,7 +656,7 @@ async def send_packet(self, packet: _ResponseT, /) -> None: self.__check_closed() self.__logger.debug("A response will be sent to %s", self.address) producer = self.__producer - producer.queue(packet) + producer.enqueue(packet) del packet async with self.__send_lock: socket = self.__check_closed() diff --git a/src/easynetwork/tools/_stream.py b/src/easynetwork/tools/_stream.py index ac569041..4366b792 100644 --- a/src/easynetwork/tools/_stream.py +++ b/src/easynetwork/tools/_stream.py @@ -23,12 +23,15 @@ from collections import deque from collections.abc import Generator, Iterator -from typing import Any, Generic, final +from typing import TYPE_CHECKING, Any, Generic, final from .._typevars import _ReceivedPacketT, _SentPacketT from ..exceptions import StreamProtocolParseError from ..protocol import StreamProtocol +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + @final @Iterator.register @@ -69,11 +72,10 @@ def __next__(self) -> bytes: else: self.__g = None try: - chunk = next(filter(None, generator)) + chunk = next(filter(None, map(_ensure_bytes, generator))) except StopIteration: pass else: - _check_bytes(chunk) self.__g = generator return chunk finally: @@ -83,7 +85,7 @@ def __next__(self) -> bytes: def pending_packets(self) -> bool: return self.__g is not None or bool(self.__q) - def queue(self, *packets: _SentPacketT) -> None: + def enqueue(self, *packets: _SentPacketT) -> None: self.__q.extend(packets) def clear(self) -> None: @@ -140,9 +142,10 @@ def __next__(self) -> _ReceivedPacketT: consumer.send(chunk) except StopIteration as exc: packet, remaining = exc.value + remaining = _ensure_bytes(remaining) except StreamProtocolParseError as exc: remaining, exc.remaining_data = exc.remaining_data, b"" - _check_bytes(remaining) + remaining = _ensure_bytes(remaining) self.__b = remaining raise else: @@ -150,12 +153,11 @@ def __next__(self) -> _ReceivedPacketT: raise StopIteration finally: del consumer, chunk - _check_bytes(remaining) self.__b = remaining return packet def feed(self, chunk: bytes) -> None: - _check_bytes(chunk) + chunk = _ensure_bytes(chunk) if not chunk: return if self.__b: @@ -178,6 +180,7 @@ def _check_protocol(p: StreamProtocol[Any, Any]) -> None: raise TypeError(f"Expected a StreamProtocol object, got {p!r}") -def _check_bytes(b: bytes) -> None: +def _ensure_bytes(b: ReadableBuffer) -> bytes: if type(b) is not bytes: - raise AssertionError(f"Expected bytes, got {b!r}") + b = memoryview(b).tobytes() + return b diff --git a/tests/unit_test/test_tools/test_stream.py b/tests/unit_test/test_tools/test_stream.py index c0fcda2a..847b22d5 100644 --- a/tests/unit_test/test_tools/test_stream.py +++ b/tests/unit_test/test_tools/test_stream.py @@ -65,7 +65,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks 
mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) # Act chunk: bytes = next(producer) @@ -87,7 +87,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) # Act chunk: bytes = next(producer) @@ -111,7 +111,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) # Act chunk: bytes = next(producer) @@ -132,7 +132,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet_for_test_arrange, mocker.sentinel.second_packet) + producer.enqueue(mocker.sentinel.packet_for_test_arrange, mocker.sentinel.second_packet) next(producer) mock_generate_chunks_func.reset_mock() # Needed to call assert_called_once() later @@ -155,7 +155,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet_for_test_arrange) + producer.enqueue(mocker.sentinel.packet_for_test_arrange) next(producer) mock_generate_chunks_func.reset_mock() # Needed to call assert_not_called() later @@ -166,6 +166,27 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: # Assert mock_generate_chunks_func.assert_not_called() + def test____next____convert_bytearrays( + self, + producer: StreamDataProducer[Any], + mock_stream_protocol: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + def side_effect(_: Any) -> Generator[bytes, None, None]: + yield bytearray(b"chunk") + + mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks + mock_generate_chunks_func.side_effect = side_effect + producer.enqueue(mocker.sentinel.packet_for_test_arrange) + + # Act + chunk = next(producer) + + # Assert + assert isinstance(chunk, bytes) + assert chunk == b"chunk" + def test____pending_packets____empty_producer(self, producer: StreamDataProducer[Any]) -> None: # Arrange @@ -185,7 +206,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) # Act & Assert assert producer.pending_packets() @@ -197,11 +218,11 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: next(producer) assert not producer.pending_packets() - def test____queue____no_args(self, producer: StreamDataProducer[Any]) -> None: + def test____enqueue____no_args(self, producer: StreamDataProducer[Any]) -> None: # Arrange # Act - producer.queue() + producer.enqueue() # Assert ## There is no exceptions ? Nice ! 
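The switch from the _check_bytes() assertion to _ensure_bytes() means the producer and consumer now accept any bytes-like chunk and silently copy it into an immutable bytes object, which is what the new convert_bytearrays tests in this patch exercise. A minimal sketch of the helper as added here (standard library only):

    def _ensure_bytes(b) -> bytes:
        # bytes objects pass through untouched...
        if type(b) is not bytes:
            # ...anything else bytes-like (bytearray, memoryview, ...) is copied.
            b = memoryview(b).tobytes()
        return b

    assert _ensure_bytes(b"chunk") == b"chunk"
    assert _ensure_bytes(bytearray(b"chunk")) == b"chunk"
    assert _ensure_bytes(memoryview(b"chunk")) == b"chunk"
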
@@ -212,7 +233,7 @@ def test____clear____remove_queued_packets( mocker: MockerFixture, ) -> None: # Arrange - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) assert producer.pending_packets() # Act @@ -237,7 +258,7 @@ def side_effect(_: Any) -> Generator[bytes, None, None]: mock_generate_chunks_func: MagicMock = mock_stream_protocol.generate_chunks mock_generate_chunks_func.side_effect = side_effect - producer.queue(mocker.sentinel.packet) + producer.enqueue(mocker.sentinel.packet) assert next(producer) == b"chunk" assert producer.pending_packets() @@ -273,6 +294,29 @@ def test____dunder_iter____return_self(self, consumer: StreamDataConsumer[Any]) # Assert assert iterator is consumer + def test____feed____convert_bytearrays( + self, + consumer: StreamDataConsumer[Any], + mock_stream_protocol: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + def side_effect() -> Generator[None, bytes, tuple[Any, bytes]]: + data = yield + assert isinstance(data, bytes) + assert data == b"Hello" + return mocker.sentinel.packet, bytearray(b"World") + + mock_build_packet_from_chunks_func: MagicMock = mock_stream_protocol.build_packet_from_chunks + mock_build_packet_from_chunks_func.side_effect = side_effect + consumer.feed(bytearray(b"Hello")) # 0 + 0 == 0 + + # Act + packet = next(consumer) + + # Assert + assert packet is mocker.sentinel.packet + def test____next____no_buffer( self, consumer: StreamDataConsumer[Any], From d6682dfadec05701a4866d9557a1cd4586d4cde2 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Fri, 29 Sep 2023 11:52:59 +0200 Subject: [PATCH 2/6] [UPDATE] AutoSeparatedPacketSerializer: Added buffer limitation --- src/easynetwork/exceptions.py | 21 +++++ src/easynetwork/serializers/base_stream.py | 69 +++++++++++----- src/easynetwork/serializers/line.py | 5 +- src/easynetwork/serializers/wrapper/base64.py | 5 +- .../serializers/wrapper/encryptor.py | 5 +- src/easynetwork/tools/constants.py | 7 ++ tests/unit_test/test_serializers/test_abc.py | 78 +++++++++++++++---- 7 files changed, 156 insertions(+), 34 deletions(-) diff --git a/src/easynetwork/exceptions.py b/src/easynetwork/exceptions.py index 32261f12..8a72332c 100644 --- a/src/easynetwork/exceptions.py +++ b/src/easynetwork/exceptions.py @@ -34,6 +34,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from _typeshed import ReadableBuffer + from .tools.socket import SocketAddress @@ -82,6 +84,25 @@ def __init__(self, message: str, remaining_data: bytes, error_info: Any = None) """Unused trailing data.""" +class LimitOverrunError(IncrementalDeserializeError): + """Reached the buffer size limit while looking for a separator.""" + + def __init__(self, message: str, buffer: ReadableBuffer, consumed: int, separator: bytes = b"") -> None: + """ + Parameters: + message: Error message. + buffer: Currently too big buffer. + consumed: Total number of to be consumed bytes. + separator: Searched separator. 
+ """ + + remaining_data = memoryview(buffer)[consumed:].tobytes() + if separator and remaining_data.startswith(separator): + remaining_data = remaining_data.removeprefix(separator) + + super().__init__(message, remaining_data, error_info=None) + + class PacketConversionError(Exception): """The deserialized :term:`packet` is invalid.""" diff --git a/src/easynetwork/serializers/base_stream.py b/src/easynetwork/serializers/base_stream.py index c01af6e8..f5a18f6d 100644 --- a/src/easynetwork/serializers/base_stream.py +++ b/src/easynetwork/serializers/base_stream.py @@ -28,7 +28,8 @@ from typing import IO, Any, final from .._typevars import _DTOPacketT -from ..exceptions import DeserializeError, IncrementalDeserializeError +from ..exceptions import DeserializeError, IncrementalDeserializeError, LimitOverrunError +from ..tools.constants import _DEFAULT_LIMIT from .abc import AbstractIncrementalPacketSerializer, AbstractPacketSerializer @@ -37,26 +38,38 @@ class AutoSeparatedPacketSerializer(AbstractIncrementalPacketSerializer[_DTOPack Base class for stream protocols that separates sent information by a byte sequence. """ - __slots__ = ("__separator", "__incremental_serialize_check_separator") + __slots__ = ("__separator", "__limit", "__incremental_serialize_check_separator") - def __init__(self, separator: bytes, *, incremental_serialize_check_separator: bool = True, **kwargs: Any) -> None: + def __init__( + self, + separator: bytes, + *, + incremental_serialize_check_separator: bool = True, + limit: int = _DEFAULT_LIMIT, + **kwargs: Any, + ) -> None: """ Parameters: separator: Byte sequence that indicates the end of the token. incremental_serialize_check_separator: If `True` (the default), checks that the data returned by :meth:`serialize` does not contain `separator`, and removes superfluous `separator` added at the end. + limit: Maximum buffer size. kwargs: Extra options given to ``super().__init__()``. Raises: TypeError: Invalid arguments. - ValueError: Empty separator sequence. + ValueError: Empty `separator` sequence. + ValueError: `limit` must be a positive integer. 
""" super().__init__(**kwargs) separator = bytes(separator) if len(separator) < 1: raise ValueError("Empty separator") + if limit <= 0: + raise ValueError("limit must be a positive integer") self.__separator: bytes = separator + self.__limit: int = limit self.__incremental_serialize_check_separator = bool(incremental_serialize_check_separator) @abstractmethod @@ -84,9 +97,15 @@ def incremental_serialize(self, packet: _DTOPacketT, /) -> Generator[bytes, None data = data.removesuffix(separator) if separator in data: raise ValueError(f"{separator!r} separator found in serialized packet {packet!r} which was not at the end") - yield data - del data - yield separator + if not data: + return + if len(data) + len(separator) <= self.__limit // 2: + data += separator + yield data + else: + yield data + del data + yield separator @abstractmethod def deserialize(self, data: bytes, /) -> _DTOPacketT: @@ -108,18 +127,32 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b """ buffer: bytes = yield separator: bytes = self.__separator - separator_length: int = len(separator) + seplen: int = len(separator) + limit: int = self.__limit + offset: int = 0 + sepidx: int = -1 + while True: - data, found_separator, buffer = buffer.partition(separator) - if found_separator: - del found_separator - if not data: # There was successive separators - continue - break - assert not buffer # nosec assert_used - buffer = data + (yield) - while buffer.startswith(separator): # Remove successive separators which can already be eliminated - buffer = buffer[separator_length:] + buflen = len(buffer) + + if buflen - offset >= seplen: + sepidx = buffer.find(separator, offset) + + if sepidx != -1: + break + + offset = buflen + 1 - seplen + if offset > limit: + raise LimitOverrunError("Separator is not found, and chunk exceed the limit", buffer, offset, separator) + + buffer += yield + + if sepidx > limit: + raise LimitOverrunError("Separator is found, but chunk is longer than limit", buffer, sepidx, separator) + + data = buffer[:sepidx] + buffer = buffer[sepidx + seplen :] + try: packet = self.deserialize(data) except DeserializeError as exc: diff --git a/src/easynetwork/serializers/line.py b/src/easynetwork/serializers/line.py index f6e7e711..6c839668 100644 --- a/src/easynetwork/serializers/line.py +++ b/src/easynetwork/serializers/line.py @@ -21,6 +21,7 @@ from typing import Literal, assert_never, final from ..exceptions import DeserializeError +from ..tools.constants import _DEFAULT_LIMIT from .base_stream import AutoSeparatedPacketSerializer @@ -37,6 +38,7 @@ def __init__( *, encoding: str = "ascii", unicode_errors: str = "strict", + limit: int = _DEFAULT_LIMIT, ) -> None: r""" Parameters: @@ -50,6 +52,7 @@ def __init__( - ``"CRLF"``: Carriage return + line feed character sequence (``"\r\n"``). encoding: String encoding. Defaults to ``"ascii"``. unicode_errors: Controls how encoding errors are handled. + limit: Maximum buffer size. Used in incremental serialization context. See Also: :ref:`standard-encodings` and :ref:`error-handlers`. 
@@ -64,7 +67,7 @@ def __init__( separator = b"\r\n" case _: assert_never(newline) - super().__init__(separator=separator, incremental_serialize_check_separator=False) + super().__init__(separator=separator, incremental_serialize_check_separator=False, limit=limit) self.__encoding: str = encoding self.__unicode_errors: str = unicode_errors diff --git a/src/easynetwork/serializers/wrapper/base64.py b/src/easynetwork/serializers/wrapper/base64.py index 163cd65c..1cbffbe1 100644 --- a/src/easynetwork/serializers/wrapper/base64.py +++ b/src/easynetwork/serializers/wrapper/base64.py @@ -26,6 +26,7 @@ from ..._typevars import _DTOPacketT from ...exceptions import DeserializeError +from ...tools.constants import _DEFAULT_LIMIT from ..abc import AbstractPacketSerializer from ..base_stream import AutoSeparatedPacketSerializer @@ -44,6 +45,7 @@ def __init__( alphabet: Literal["standard", "urlsafe"] = "urlsafe", checksum: bool | str | bytes = False, separator: bytes = b"\r\n", + limit: int = _DEFAULT_LIMIT, ) -> None: """ Parameters: @@ -58,12 +60,13 @@ def __init__( checksum: If `True`, appends a sha256 checksum to the serialized data. `checksum` can also be a URL-safe base64-encoded 32-byte key for a signed checksum. separator: Token for :class:`AutoSeparatedPacketSerializer`. Used in incremental serialization context. + limit: Maximum buffer size. Used in incremental serialization context. """ import base64 import binascii from hmac import compare_digest - super().__init__(separator=separator, incremental_serialize_check_separator=not separator.isspace()) + super().__init__(separator=separator, incremental_serialize_check_separator=not separator.isspace(), limit=limit) if not isinstance(serializer, AbstractPacketSerializer): raise TypeError(f"Expected a serializer instance, got {serializer!r}") self.__serializer: AbstractPacketSerializer[_DTOPacketT] = serializer diff --git a/src/easynetwork/serializers/wrapper/encryptor.py b/src/easynetwork/serializers/wrapper/encryptor.py index d43c44b0..adf9873e 100644 --- a/src/easynetwork/serializers/wrapper/encryptor.py +++ b/src/easynetwork/serializers/wrapper/encryptor.py @@ -24,6 +24,7 @@ from ..._typevars import _DTOPacketT from ...exceptions import DeserializeError +from ...tools.constants import _DEFAULT_LIMIT from ..abc import AbstractPacketSerializer from ..base_stream import AutoSeparatedPacketSerializer @@ -44,6 +45,7 @@ def __init__( *, token_ttl: int | None = None, separator: bytes = b"\r\n", + limit: int = _DEFAULT_LIMIT, ) -> None: """ Parameters: @@ -51,13 +53,14 @@ def __init__( key: A URL-safe base64-encoded 32-byte key. token_ttl: Token time-to-live. See :meth:`cryptography.fernet.Fernet.decrypt` for details. separator: Token for :class:`AutoSeparatedPacketSerializer`. Used in incremental serialization context. + limit: Maximum buffer size. Used in incremental serialization context. """ try: import cryptography.fernet except ModuleNotFoundError as exc: # pragma: no cover raise ModuleNotFoundError("encryption dependencies are missing. 
Consider adding 'encryption' extra") from exc - super().__init__(separator=separator, incremental_serialize_check_separator=not separator.isspace()) + super().__init__(separator=separator, incremental_serialize_check_separator=not separator.isspace(), limit=limit) if not isinstance(serializer, AbstractPacketSerializer): raise TypeError(f"Expected a serializer instance, got {serializer!r}") self.__serializer: AbstractPacketSerializer[_DTOPacketT] = serializer diff --git a/src/easynetwork/tools/constants.py b/src/easynetwork/tools/constants.py index a6bc7f03..8ea7dae3 100644 --- a/src/easynetwork/tools/constants.py +++ b/src/easynetwork/tools/constants.py @@ -23,12 +23,16 @@ "MAX_STREAM_BUFSIZE", "SSL_HANDSHAKE_TIMEOUT", "SSL_SHUTDOWN_TIMEOUT", + "_DEFAULT_LIMIT", ] import errno as _errno from typing import Final +# Buffer size for a recv(2) operation MAX_STREAM_BUFSIZE: Final[int] = 256 * 1024 # 256KiB + +# Buffer size for a recvfrom(2) operation MAX_DATAGRAM_BUFSIZE: Final[int] = 64 * 1024 # 64KiB # Errors that socket operations can return if the socket is closed @@ -62,3 +66,6 @@ # Number of seconds to wait for SSL shutdown to complete # The default timeout mimics lingering_time SSL_SHUTDOWN_TIMEOUT: Final[float] = 30.0 + +# Buffer size limit when waiting for a byte sequence +_DEFAULT_LIMIT: Final[int] = 64 * 1024 # 64 KiB diff --git a/tests/unit_test/test_serializers/test_abc.py b/tests/unit_test/test_serializers/test_abc.py index 44656ce2..eca63da4 100644 --- a/tests/unit_test/test_serializers/test_abc.py +++ b/tests/unit_test/test_serializers/test_abc.py @@ -5,7 +5,7 @@ from collections.abc import Generator from typing import IO, TYPE_CHECKING, Any, final -from easynetwork.exceptions import DeserializeError, IncrementalDeserializeError +from easynetwork.exceptions import DeserializeError, IncrementalDeserializeError, LimitOverrunError from easynetwork.serializers.abc import AbstractIncrementalPacketSerializer from easynetwork.serializers.base_stream import ( AutoSeparatedPacketSerializer, @@ -186,8 +186,38 @@ def test____dunder_init____empty_separator_bytes(self) -> None: with pytest.raises(ValueError, match=r"^Empty separator$"): _ = _AutoSeparatedPacketSerializerForTest(b"") + @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}") + def test____dunder_init____invalid_limit(self, limit: int) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + _ = _AutoSeparatedPacketSerializerForTest(b"\n", limit=limit) + + def test____incremental_serialize____empty_bytes( + self, + check_separator: bool, + mock_serialize_func: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + serializer = _AutoSeparatedPacketSerializerForTest( + separator=b"\r\n", + incremental_serialize_check_separator=check_separator, + ) + mock_serialize_func.return_value = b"" + + # Act + data = list(serializer.incremental_serialize(mocker.sentinel.packet)) + + # Assert + mock_serialize_func.assert_called_once_with(mocker.sentinel.packet) + assert data == [] + + @pytest.mark.parametrize("limit_reached", [False, True], ids=lambda p: f"limit_reached=={p}") def test____incremental_serialize____append_separator( self, + limit_reached: bool, check_separator: bool, mock_serialize_func: MagicMock, mocker: MockerFixture, @@ -196,6 +226,7 @@ def test____incremental_serialize____append_separator( serializer = _AutoSeparatedPacketSerializerForTest( separator=b"\r\n", incremental_serialize_check_separator=check_separator, + limit=4 if 
limit_reached else 42, ) mock_serialize_func.return_value = b"data" @@ -204,7 +235,7 @@ def test____incremental_serialize____append_separator( # Assert mock_serialize_func.assert_called_once_with(mocker.sentinel.packet) - assert data == [b"data", b"\r\n"] + assert data == ([b"data", b"\r\n"] if limit_reached else [b"data\r\n"]) def test____incremental_serialize____keep_already_present_separator( self, @@ -224,7 +255,7 @@ def test____incremental_serialize____keep_already_present_separator( # Assert mock_serialize_func.assert_called_once_with(mocker.sentinel.packet) - assert data == [b"data" if check_separator else b"data\r\n", b"\r\n"] + assert data == [b"data\r\n" if check_separator else b"data\r\n\r\n"] def test____incremental_serialize____remove_useless_trailing_separators( self, @@ -244,7 +275,7 @@ def test____incremental_serialize____remove_useless_trailing_separators( # Assert mock_serialize_func.assert_called_once_with(mocker.sentinel.packet) - assert data == [b"data" if check_separator else b"data\r\n\r\n\r\n\r\n", b"\r\n"] + assert data == [b"data\r\n" if check_separator else b"data\r\n\r\n\r\n\r\n\r\n"] def test____incremental_serialize____does_not_remove_partial_separator_at_end( self, @@ -264,7 +295,7 @@ def test____incremental_serialize____does_not_remove_partial_separator_at_end( # Assert mock_serialize_func.assert_called_once_with(mocker.sentinel.packet) - assert data == [b"data\r" if check_separator else b"data\r\r\n", b"\r\n"] + assert data == [b"data\r\r\n" if check_separator else b"data\r\r\n\r\n"] def test____incremental_serialize____error_if_separator_is_within_output( self, @@ -294,13 +325,9 @@ def test____incremental_serialize____error_if_separator_is_within_output( pytest.param(b"remaining\r\nother", id="with remaining data including separator"), ], ) - @pytest.mark.parametrize("several_trailing_separators", [False, True], ids=lambda b: f"several_trailing_separators=={b}") - @pytest.mark.parametrize("several_leading_separators", [False, True], ids=lambda b: f"several_leading_separators=={b}") def test____incremental_deserialize____one_shot_chunk( self, expected_remaining_data: bytes, - several_trailing_separators: bool, - several_leading_separators: bool, mock_deserialize_func: MagicMock, mocker: MockerFixture, ) -> None: @@ -308,10 +335,6 @@ def test____incremental_deserialize____one_shot_chunk( serializer = _AutoSeparatedPacketSerializerForTest(separator=b"\r\n") mock_deserialize_func.return_value = mocker.sentinel.packet data_to_test: bytes = b"data\r\n" - if several_trailing_separators: - data_to_test = data_to_test + b"\r\n\r\n\r\n" - if several_leading_separators: - data_to_test = b"\r\n\r\n\r\n" + data_to_test # Act consumer = serializer.incremental_deserialize() @@ -387,6 +410,35 @@ def test____incremental_deserialize____translate_deserialize_errors( assert exception.remaining_data == expected_remaining_data assert exception.error_info is mocker.sentinel.error_info + @pytest.mark.parametrize("separator_found", [False, True], ids=lambda p: f"separator_found=={p}") + def test____incremental_deserialize____reached_limit( + self, + separator_found: bytes, + mock_deserialize_func: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + serializer = _AutoSeparatedPacketSerializerForTest(separator=b"\r\n", limit=1) + mock_deserialize_func.return_value = mocker.sentinel.packet + data_to_test: bytes = b"data\r" + if separator_found: + data_to_test += b"\n" + + # Act + consumer = serializer.incremental_deserialize() + next(consumer) + with 
pytest.raises(LimitOverrunError) as exc_info: + consumer.send(data_to_test) + + # Assert + mock_deserialize_func.assert_not_called() + if separator_found: + assert str(exc_info.value) == "Separator is found, but chunk is longer than limit" + assert exc_info.value.remaining_data == b"" + else: + assert str(exc_info.value) == "Separator is not found, and chunk exceed the limit" + assert exc_info.value.remaining_data == b"\r" + class _FixedSizePacketSerializerForTest(FixedSizePacketSerializer[Any]): def serialize(self, packet: Any) -> bytes: From 6e41a0c7563be6dd81e2baa5918dabd5e0fe7b64 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Fri, 29 Sep 2023 13:18:30 +0200 Subject: [PATCH 3/6] [UPDATE] JSONSerializer: Removed indent and separators options from JSONEncoderConfig --- src/easynetwork/serializers/json.py | 175 +++++++++--------- .../test_serializers/test_json.py | 2 +- tests/unit_test/test_serializers/test_json.py | 6 +- 3 files changed, 91 insertions(+), 92 deletions(-) diff --git a/src/easynetwork/serializers/json.py b/src/easynetwork/serializers/json.py index a7b764de..4ac25353 100644 --- a/src/easynetwork/serializers/json.py +++ b/src/easynetwork/serializers/json.py @@ -33,11 +33,6 @@ from ..tools._utils import iter_bytes from .abc import AbstractIncrementalPacketSerializer -_JSON_VALUE_BYTES: frozenset[int] = frozenset(bytes(string.digits + string.ascii_letters + string.punctuation, "ascii")) -_ESCAPE_BYTE: int = b"\\"[0] - -_whitespaces_match: Callable[[bytes, int], re.Match[bytes]] = re.compile(rb"[ \t\n\r]*", re.MULTILINE | re.DOTALL).match # type: ignore[assignment] - @dataclass(kw_only=True) class JSONEncoderConfig: @@ -51,8 +46,6 @@ class JSONEncoderConfig: check_circular: bool = True ensure_ascii: bool = True allow_nan: bool = True - indent: int | None = None - separators: tuple[str, str] | None = (",", ":") # Compact JSON (w/o whitespaces) default: Callable[..., Any] | None = None @@ -72,85 +65,6 @@ class JSONDecoderConfig: strict: bool = True -class _JSONParser: - class _PlainValueError(Exception): - pass - - @staticmethod - def _escaped(partial_document_view: memoryview) -> bool: - escaped = False - for byte in reversed(partial_document_view): - if byte == _ESCAPE_BYTE: - escaped = not escaped - else: - break - return escaped - - @staticmethod - def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: - escaped = _JSONParser._escaped - split_partial_document = _JSONParser._split_partial_document - enclosure_counter: Counter[bytes] = Counter() - partial_document: bytes = yield - first_enclosure: bytes = b"" - start: int = 0 - try: - while True: - with memoryview(partial_document) as partial_document_view: - for nb_chars, char in enumerate(iter_bytes(partial_document_view[start:]), start=start + 1): - match char: - case b'"' if not escaped(partial_document_view[: nb_chars - 1]): - enclosure_counter[b'"'] = 0 if enclosure_counter[b'"'] == 1 else 1 - case _ if enclosure_counter[b'"'] > 0: # We are within a JSON string, move on. 
- continue - case b"{" | b"[": - enclosure_counter[char] += 1 - case b"}": - enclosure_counter[b"{"] -= 1 - case b"]": - enclosure_counter[b"["] -= 1 - case b" " | b"\t" | b"\n" | b"\r": # Optimization: Skip spaces - continue - case _ if len(enclosure_counter) == 0: # No enclosure, only value - partial_document = partial_document[nb_chars - 1 :] if nb_chars > 1 else partial_document - del char, nb_chars - raise _JSONParser._PlainValueError - case _: # JSON character, quickly go to next character - continue - assert len(enclosure_counter) > 0 # nosec assert_used - if not first_enclosure: - first_enclosure = next(iter(enclosure_counter)) - if enclosure_counter[first_enclosure] <= 0: # 1st found is closed - return split_partial_document(partial_document, nb_chars) - - # partial_document not complete - start = partial_document_view.nbytes - - # yield outside view scope - partial_document += yield - - except _JSONParser._PlainValueError: - pass - - # The document is a plain value (null, true, false, or a number) - - del enclosure_counter, first_enclosure - - while (nprint_idx := next((idx for idx, byte in enumerate(partial_document) if byte not in _JSON_VALUE_BYTES), -1)) < 0: - partial_document += yield - - return split_partial_document(partial_document, nprint_idx) - - @staticmethod - def _split_partial_document(partial_document: bytes, index: int) -> tuple[bytes, bytes]: - index = _whitespaces_match(partial_document, index).end() - if index == len(partial_document): - # The following bytes are only spaces - # Do not slice the document, the trailing spaces will be ignored by JSONDecoder - return partial_document, b"" - return partial_document[:index], partial_document[index:] - - class JSONSerializer(AbstractIncrementalPacketSerializer[Any]): """ A :term:`serializer` built on top of the :mod:`json` module. 
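Dropping indent and separators from JSONEncoderConfig makes the wire format fixed: the serializer now always constructs its json.JSONEncoder with indent=None and separators=(",", ":") (see the __init__ hunk just below), so each packet is emitted as a single compact document. For reference, with the standard json module alone:

    import json

    packet = {"key": [1, 2, 3], "ok": True}

    # The compact form the serializer now always emits:
    assert json.dumps(packet, indent=None, separators=(",", ":")) == '{"key":[1,2,3],"ok":true}'

    # An indent override, which is no longer reachable through JSONEncoderConfig:
    print(json.dumps(packet, indent=2))
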
@@ -193,7 +107,7 @@ def __init__( elif not isinstance(decoder_config, JSONDecoderConfig): raise TypeError(f"Invalid decoder config: expected {JSONDecoderConfig.__name__}, got {type(decoder_config).__name__}") - self.__encoder = JSONEncoder(**dataclass_asdict(encoder_config)) + self.__encoder = JSONEncoder(**dataclass_asdict(encoder_config), indent=None, separators=(",", ":")) self.__decoder = JSONDecoder(**dataclass_asdict(decoder_config)) self.__decoder_error_cls = JSONDecodeError @@ -342,3 +256,90 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[Any, bytes]]: }, ) from exc return packet, remaining_data + + +class _JSONParser: + _JSON_VALUE_BYTES: frozenset[int] = frozenset(bytes(string.digits + string.ascii_letters + string.punctuation, "ascii")) + _ESCAPE_BYTE: int = ord(b"\\") + + _whitespaces_match: Callable[[bytes, int], re.Match[bytes]] = re.compile(rb"[ \t\n\r]*", re.MULTILINE | re.DOTALL).match # type: ignore[assignment] + + class _PlainValueError(Exception): + pass + + @staticmethod + def _escaped(partial_document_view: memoryview) -> bool: + escaped = False + _ESCAPE_BYTE = _JSONParser._ESCAPE_BYTE + for byte in reversed(partial_document_view): + if byte == _ESCAPE_BYTE: + escaped = not escaped + else: + break + return escaped + + @staticmethod + def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: + escaped = _JSONParser._escaped + split_partial_document = _JSONParser._split_partial_document + enclosure_counter: Counter[bytes] = Counter() + partial_document: bytes = yield + first_enclosure: bytes = b"" + start: int = 0 + try: + while True: + with memoryview(partial_document) as partial_document_view: + for nb_chars, char in enumerate(iter_bytes(partial_document_view[start:]), start=start + 1): + match char: + case b'"' if not escaped(partial_document_view[: nb_chars - 1]): + enclosure_counter[b'"'] = 0 if enclosure_counter[b'"'] == 1 else 1 + case _ if enclosure_counter[b'"'] > 0: # We are within a JSON string, move on. 
+ continue + case b"{" | b"[": + enclosure_counter[char] += 1 + case b"}": + enclosure_counter[b"{"] -= 1 + case b"]": + enclosure_counter[b"["] -= 1 + case b" " | b"\t" | b"\n" | b"\r": # Optimization: Skip spaces + continue + case _ if len(enclosure_counter) == 0: # No enclosure, only value + partial_document = partial_document[nb_chars - 1 :] if nb_chars > 1 else partial_document + del char, nb_chars + raise _JSONParser._PlainValueError + case _: # JSON character, quickly go to next character + continue + assert len(enclosure_counter) > 0 # nosec assert_used + if not first_enclosure: + first_enclosure = next(iter(enclosure_counter)) + if enclosure_counter[first_enclosure] <= 0: # 1st found is closed + return split_partial_document(partial_document, nb_chars) + + # partial_document not complete + start = partial_document_view.nbytes + + # yield outside view scope + partial_document += yield + + except _JSONParser._PlainValueError: + pass + + # The document is a plain value (null, true, false, or a number) + + del enclosure_counter, first_enclosure + + _JSON_VALUE_BYTES = _JSONParser._JSON_VALUE_BYTES + + while (nprint_idx := next((idx for idx, byte in enumerate(partial_document) if byte not in _JSON_VALUE_BYTES), -1)) < 0: + partial_document += yield + + return split_partial_document(partial_document, nprint_idx) + + @staticmethod + def _split_partial_document(partial_document: bytes, index: int) -> tuple[bytes, bytes]: + index = _JSONParser._whitespaces_match(partial_document, index).end() + if index == len(partial_document): + # The following bytes are only spaces + # Do not slice the document, the trailing spaces will be ignored by JSONDecoder + return partial_document, b"" + return partial_document[:index], partial_document[index:] diff --git a/tests/functional_test/test_serializers/test_json.py b/tests/functional_test/test_serializers/test_json.py index 4d173b6b..3d682eac 100644 --- a/tests/functional_test/test_serializers/test_json.py +++ b/tests/functional_test/test_serializers/test_json.py @@ -43,7 +43,7 @@ def packet_to_serialize(request: Any) -> Any: def expected_complete_data(cls, packet_to_serialize: Any) -> bytes: import json - return json.dumps(packet_to_serialize, **dataclasses.asdict(cls.ENCODER_CONFIG)).encode("utf-8") + return json.dumps(packet_to_serialize, **dataclasses.asdict(cls.ENCODER_CONFIG), separators=(",", ":")).encode("utf-8") #### Incremental Serialize diff --git a/tests/unit_test/test_serializers/test_json.py b/tests/unit_test/test_serializers/test_json.py index 0fc0eaf0..9d3ac7b2 100644 --- a/tests/unit_test/test_serializers/test_json.py +++ b/tests/unit_test/test_serializers/test_json.py @@ -66,8 +66,6 @@ def encoder_config(request: Any, mocker: MockerFixture) -> JSONEncoderConfig | N check_circular=mocker.sentinel.check_circular, ensure_ascii=mocker.sentinel.ensure_ascii, allow_nan=mocker.sentinel.allow_nan, - indent=mocker.sentinel.indent, - separators=mocker.sentinel.separators, default=mocker.sentinel.object_default, ) @@ -108,8 +106,8 @@ def test____dunder_init____with_encoder_config( check_circular=mocker.sentinel.check_circular if encoder_config is not None else True, ensure_ascii=mocker.sentinel.ensure_ascii if encoder_config is not None else True, allow_nan=mocker.sentinel.allow_nan if encoder_config is not None else True, - indent=mocker.sentinel.indent if encoder_config is not None else None, - separators=mocker.sentinel.separators if encoder_config is not None else (",", ":"), + indent=None, + separators=(",", ":"), 
default=mocker.sentinel.object_default if encoder_config is not None else None, ) From 9f0d006a0ac2f366943cac5409e8edd9dc01e1b9 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 30 Sep 2023 13:10:05 +0200 Subject: [PATCH 4/6] [UPDATE] JSONSerializer: Added buffer limit --- docs/source/api/serializers/index.rst | 8 + docs/source/api/serializers/tools.rst | 7 + src/easynetwork/exceptions.py | 4 + src/easynetwork/serializers/base_stream.py | 48 +-- src/easynetwork/serializers/json.py | 112 +++++-- src/easynetwork/serializers/tools.py | 200 +++++++++++++ .../test_serializers/test_json.py | 45 ++- tests/tools.py | 6 + tests/unit_test/test_serializers/test_json.py | 106 ++++++- .../unit_test/test_serializers/test_tools.py | 274 ++++++++++++++++++ 10 files changed, 724 insertions(+), 86 deletions(-) create mode 100644 docs/source/api/serializers/tools.rst create mode 100644 src/easynetwork/serializers/tools.py create mode 100644 tests/unit_test/test_serializers/test_tools.py diff --git a/docs/source/api/serializers/index.rst b/docs/source/api/serializers/index.rst index 7d623319..a38059f5 100644 --- a/docs/source/api/serializers/index.rst +++ b/docs/source/api/serializers/index.rst @@ -30,3 +30,11 @@ Serializers wrappers/base64 wrappers/compressor wrappers/encryptor + +----- + +.. toctree:: + :caption: Miscellaneous + :maxdepth: 2 + + tools diff --git a/docs/source/api/serializers/tools.rst b/docs/source/api/serializers/tools.rst new file mode 100644 index 00000000..39bdb70b --- /dev/null +++ b/docs/source/api/serializers/tools.rst @@ -0,0 +1,7 @@ +******************************* +Serializer implementation tools +******************************* + +.. automodule:: easynetwork.serializers.tools + :members: + :no-docstring: diff --git a/src/easynetwork/exceptions.py b/src/easynetwork/exceptions.py index 8a72332c..dfb83463 100644 --- a/src/easynetwork/exceptions.py +++ b/src/easynetwork/exceptions.py @@ -25,6 +25,7 @@ "DatagramProtocolParseError", "DeserializeError", "IncrementalDeserializeError", + "LimitOverrunError", "PacketConversionError", "ServerAlreadyRunning", "ServerClosedError", @@ -102,6 +103,9 @@ def __init__(self, message: str, buffer: ReadableBuffer, consumed: int, separato super().__init__(message, remaining_data, error_info=None) + self.consumed: int = consumed + """Total number of to be consumed bytes.""" + class PacketConversionError(Exception): """The deserialized :term:`packet` is invalid.""" diff --git a/src/easynetwork/serializers/base_stream.py b/src/easynetwork/serializers/base_stream.py index f5a18f6d..587067eb 100644 --- a/src/easynetwork/serializers/base_stream.py +++ b/src/easynetwork/serializers/base_stream.py @@ -28,9 +28,10 @@ from typing import IO, Any, final from .._typevars import _DTOPacketT -from ..exceptions import DeserializeError, IncrementalDeserializeError, LimitOverrunError +from ..exceptions import DeserializeError, IncrementalDeserializeError from ..tools.constants import _DEFAULT_LIMIT from .abc import AbstractIncrementalPacketSerializer, AbstractPacketSerializer +from .tools import GeneratorStreamReader class AutoSeparatedPacketSerializer(AbstractIncrementalPacketSerializer[_DTOPacketT]): @@ -122,36 +123,13 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b See :meth:`.AbstractIncrementalPacketSerializer.incremental_deserialize` documentation for details. Raises: + LimitOverrunError: Reached buffer size limit. IncrementalDeserializeError: :meth:`deserialize` raised :exc:`.DeserializeError`. 
Exception: Any error raised by :meth:`deserialize`. """ - buffer: bytes = yield - separator: bytes = self.__separator - seplen: int = len(separator) - limit: int = self.__limit - offset: int = 0 - sepidx: int = -1 - - while True: - buflen = len(buffer) - - if buflen - offset >= seplen: - sepidx = buffer.find(separator, offset) - - if sepidx != -1: - break - - offset = buflen + 1 - seplen - if offset > limit: - raise LimitOverrunError("Separator is not found, and chunk exceed the limit", buffer, offset, separator) - - buffer += yield - - if sepidx > limit: - raise LimitOverrunError("Separator is found, but chunk is longer than limit", buffer, sepidx, separator) - - data = buffer[:sepidx] - buffer = buffer[sepidx + seplen :] + reader = GeneratorStreamReader() + data = yield from reader.read_until(self.__separator, limit=self.__limit, keep_end=False) + buffer = reader.read_all() try: packet = self.deserialize(data) @@ -238,18 +216,10 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b IncrementalDeserializeError: :meth:`deserialize` raised :exc:`.DeserializeError`. Exception: Any error raised by :meth:`deserialize`. """ - buffer: bytes = yield - packet_size: int = self.__size - while (buffer_size := len(buffer)) < packet_size: - buffer += yield + reader = GeneratorStreamReader() + data = yield from reader.read_exactly(self.__size) + buffer = reader.read_all() - # Do not copy if the size is *exactly* as expected - if buffer_size == packet_size: - data = buffer - buffer = b"" - else: - data = buffer[:packet_size] - buffer = buffer[packet_size:] try: packet = self.deserialize(data) except DeserializeError as exc: diff --git a/src/easynetwork/serializers/json.py b/src/easynetwork/serializers/json.py index 4ac25353..90aa06b1 100644 --- a/src/easynetwork/serializers/json.py +++ b/src/easynetwork/serializers/json.py @@ -22,6 +22,7 @@ "JSONSerializer", ] + import re import string from collections import Counter @@ -29,9 +30,11 @@ from dataclasses import asdict as dataclass_asdict, dataclass from typing import Any, final -from ..exceptions import DeserializeError, IncrementalDeserializeError +from ..exceptions import DeserializeError, IncrementalDeserializeError, LimitOverrunError from ..tools._utils import iter_bytes +from ..tools.constants import _DEFAULT_LIMIT from .abc import AbstractIncrementalPacketSerializer +from .tools import GeneratorStreamReader @dataclass(kw_only=True) @@ -70,7 +73,15 @@ class JSONSerializer(AbstractIncrementalPacketSerializer[Any]): A :term:`serializer` built on top of the :mod:`json` module. """ - __slots__ = ("__encoder", "__decoder", "__decoder_error_cls", "__encoding", "__unicode_errors") + __slots__ = ( + "__encoder", + "__decoder", + "__decoder_error_cls", + "__encoding", + "__unicode_errors", + "__limit", + "__use_lines", + ) def __init__( self, @@ -79,6 +90,8 @@ def __init__( *, encoding: str = "utf-8", unicode_errors: str = "strict", + limit: int = _DEFAULT_LIMIT, + use_lines: bool = True, ) -> None: """ Parameters: @@ -86,6 +99,8 @@ def __init__( decoder_config: Parameter object to configure the :class:`~json.JSONDecoder`. encoding: String encoding. unicode_errors: Controls how encoding errors are handled. + limit: Maximum buffer size. Used in incremental serialization context. + use_lines: If :data:`True` (the default), each ASCII lines is considered a JSON object. See Also: :ref:`standard-encodings` and :ref:`error-handlers`. 
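Taken together, use_lines and limit decide how a JSON document is framed on the wire; the expected chunks follow from the incremental_serialize hunk further below. A rough sketch, assuming the default encoder settings from this patch series (compact separators, utf-8):

    from easynetwork.serializers.json import JSONSerializer

    packet = {"key": [1, 2, 3]}

    # Default (use_lines=True): one newline-terminated chunk per packet.
    assert b"".join(JSONSerializer().incremental_serialize(packet)) == b'{"key":[1,2,3]}\n'

    # use_lines=False: the JSON document itself is the frame, no newline is appended.
    assert b"".join(JSONSerializer(use_lines=False).incremental_serialize(packet)) == b'{"key":[1,2,3]}'

    # Plain values cannot be framed structurally, so they always get the newline.
    assert b"".join(JSONSerializer(use_lines=False).incremental_serialize(42)) == b"42\n"
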
@@ -107,6 +122,9 @@ def __init__( elif not isinstance(decoder_config, JSONDecoderConfig): raise TypeError(f"Invalid decoder config: expected {JSONDecoderConfig.__name__}, got {type(decoder_config).__name__}") + if limit <= 0: + raise ValueError("limit must be a positive integer") + self.__encoder = JSONEncoder(**dataclass_asdict(encoder_config), indent=None, separators=(",", ":")) self.__decoder = JSONDecoder(**dataclass_asdict(decoder_config)) self.__decoder_error_cls = JSONDecodeError @@ -114,6 +132,9 @@ def __init__( self.__encoding: str = encoding self.__unicode_errors: str = unicode_errors + self.__limit: int = limit + self.__use_lines: bool = bool(use_lines) + @final def serialize(self, packet: Any) -> bytes: """ @@ -153,8 +174,22 @@ def incremental_serialize(self, packet: Any) -> Generator[bytes, None, None]: Yields: all the parts of the JSON :term:`packet`. """ - yield self.__encoder.encode(packet).encode(self.__encoding, self.__unicode_errors) - yield b"\n" + data = self.__encoder.encode(packet).encode(self.__encoding, self.__unicode_errors) + newline = b"\n" + if not data.startswith((b"{", b"[", b'"')): + data += newline + yield data + return + if not self.__use_lines: + yield data + return + if len(data) + len(newline) <= self.__limit // 2: + data += newline + yield data + else: + yield data + del data + yield newline @final def deserialize(self, data: bytes) -> Any: @@ -206,7 +241,7 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[Any, bytes]]: Creates a Python object representing the raw JSON :term:`packet`. Example: - >>> s = JSONSerializer() + >>> s = JSONSerializer(use_lines=False) >>> consumer = s.incremental_deserialize() >>> next(consumer) >>> consumer.send(b'{"key":[1,2,3]') @@ -219,17 +254,21 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[Any, bytes]]: :data:`None` until the whole :term:`packet` has been deserialized. Raises: + LimitOverrunError: Reached buffer size limit. IncrementalDeserializeError: A :class:`UnicodeError` or :class:`~json.JSONDecodeError` have been raised. Returns: a tuple with the deserialized Python object and the unused trailing data. 
""" - complete_document, remaining_data = yield from _JSONParser.raw_parse() - - if not complete_document: - # If this condition is verified, decoder.decode() will most likely raise JSONDecodeError - complete_document = remaining_data - remaining_data = b"" + complete_document: bytes + remaining_data: bytes + if self.__use_lines: + reader = GeneratorStreamReader() + complete_document = yield from reader.read_until(b"\n", limit=self.__limit) + remaining_data = reader.read_all() + del reader + else: + complete_document, remaining_data = yield from _JSONParser.raw_parse(limit=self.__limit) packet: Any try: @@ -279,19 +318,21 @@ def _escaped(partial_document_view: memoryview) -> bool: return escaped @staticmethod - def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: + def raw_parse(*, limit: int) -> Generator[None, bytes, tuple[bytes, bytes]]: + if limit <= 0: + raise ValueError("limit must be a positive integer") escaped = _JSONParser._escaped split_partial_document = _JSONParser._split_partial_document enclosure_counter: Counter[bytes] = Counter() partial_document: bytes = yield first_enclosure: bytes = b"" - start: int = 0 try: + offset: int = 0 while True: with memoryview(partial_document) as partial_document_view: - for nb_chars, char in enumerate(iter_bytes(partial_document_view[start:]), start=start + 1): + for offset, char in enumerate(iter_bytes(partial_document_view[offset:]), start=offset): match char: - case b'"' if not escaped(partial_document_view[: nb_chars - 1]): + case b'"' if not escaped(partial_document_view[:offset]): enclosure_counter[b'"'] = 0 if enclosure_counter[b'"'] == 1 else 1 case _ if enclosure_counter[b'"'] > 0: # We are within a JSON string, move on. continue @@ -304,8 +345,8 @@ def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: case b" " | b"\t" | b"\n" | b"\r": # Optimization: Skip spaces continue case _ if len(enclosure_counter) == 0: # No enclosure, only value - partial_document = partial_document[nb_chars - 1 :] if nb_chars > 1 else partial_document - del char, nb_chars + partial_document = partial_document[offset:] if offset > 0 else partial_document + del char, offset raise _JSONParser._PlainValueError case _: # JSON character, quickly go to next character continue @@ -313,10 +354,16 @@ def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: if not first_enclosure: first_enclosure = next(iter(enclosure_counter)) if enclosure_counter[first_enclosure] <= 0: # 1st found is closed - return split_partial_document(partial_document, nb_chars) + return split_partial_document(partial_document, offset + 1, limit) # partial_document not complete - start = partial_document_view.nbytes + offset = partial_document_view.nbytes + if offset > limit: + raise LimitOverrunError( + "JSON object's end frame is not found, and chunk exceed the limit", + partial_document, + offset, + ) # yield outside view scope partial_document += yield @@ -331,15 +378,32 @@ def raw_parse() -> Generator[None, bytes, tuple[bytes, bytes]]: _JSON_VALUE_BYTES = _JSONParser._JSON_VALUE_BYTES while (nprint_idx := next((idx for idx, byte in enumerate(partial_document) if byte not in _JSON_VALUE_BYTES), -1)) < 0: + if len(partial_document) > limit: + raise LimitOverrunError( + "JSON object's end frame is not found, and chunk exceed the limit", + partial_document, + nprint_idx, + ) partial_document += yield - return split_partial_document(partial_document, nprint_idx) + return split_partial_document(partial_document, nprint_idx, limit) @staticmethod - def 
_split_partial_document(partial_document: bytes, index: int) -> tuple[bytes, bytes]: - index = _JSONParser._whitespaces_match(partial_document, index).end() - if index == len(partial_document): + def _split_partial_document(partial_document: bytes, consumed: int, limit: int) -> tuple[bytes, bytes]: + if consumed > limit: + raise LimitOverrunError( + "JSON object's end frame is found, but chunk is longer than limit", + partial_document, + consumed, + ) + consumed = _JSONParser._whitespaces_match(partial_document, consumed).end() + if consumed == len(partial_document): # The following bytes are only spaces # Do not slice the document, the trailing spaces will be ignored by JSONDecoder return partial_document, b"" - return partial_document[:index], partial_document[index:] + complete_document, partial_document = partial_document[:consumed], partial_document[consumed:] + if not complete_document: + # If this condition is verified, decoder.decode() will most likely raise JSONDecodeError + complete_document = partial_document + partial_document = b"" + return complete_document, partial_document diff --git a/src/easynetwork/serializers/tools.py b/src/easynetwork/serializers/tools.py new file mode 100644 index 00000000..d4f06844 --- /dev/null +++ b/src/easynetwork/serializers/tools.py @@ -0,0 +1,200 @@ +# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Serializer implementation tools module""" + +from __future__ import annotations + +__all__ = ["GeneratorStreamReader"] + +from collections.abc import Generator + +from ..exceptions import LimitOverrunError + + +class GeneratorStreamReader: + """ + A binary stream-like object using an in-memory bytes buffer. + + The "blocking" operation is done with the generator's :keyword:`yield` statement. It is an helper for + :term:`incremental serializer` implementations. + """ + + __slots__ = ("__buffer",) + + def __init__(self, initial_bytes: bytes = b"") -> None: + """ + Parameters: + initial_bytes: a :class:`bytes` object that contains initial data. + """ + self.__buffer: bytes = initial_bytes + + def read_all(self) -> bytes: + """ + Read and return all the bytes currently in the reader. + + Returns: + a :class:`bytes` object. + """ + + data, self.__buffer = self.__buffer, b"" + return data + + def read(self, size: int) -> Generator[None, bytes, bytes]: + """ + Read and return up to `size` bytes. + + Example:: + + def incremental_deserialize(self) -> Generator[None, bytes, tuple[Packet, bytes]]: + reader = GeneratorStreamReader() + + data: bytes = yield from reader.read(1024) # Get at most 1024 bytes. + + ... + + Yields: + if there is no data in buffer. + + Returns: + a :class:`bytes` object. 
+ """ + + if size < 0: + raise ValueError("size must not be < 0") + if size == 0: + return b"" + + while not self.__buffer: + self.__buffer = yield + + if len(self.__buffer) <= size: + data, self.__buffer = self.__buffer, b"" + else: + data = self.__buffer[:size] + self.__buffer = self.__buffer[size:] + + return data + + def read_exactly(self, n: int) -> Generator[None, bytes, bytes]: + """ + Read exactly `n` bytes. + + Example:: + + def incremental_deserialize(self) -> Generator[None, bytes, tuple[Packet, bytes]]: + reader = GeneratorStreamReader() + + header: bytes = yield from reader.read_exactly(32) + assert len(header) == 32 + + ... + + Yields: + until `n` bytes is in the buffer. + + Returns: + a :class:`bytes` object. + """ + + if n < 0: + raise ValueError("n must not be < 0") + if n == 0: + return b"" + + if not self.__buffer: + self.__buffer = yield + while (buflen := len(self.__buffer)) < n: + self.__buffer += yield + + if buflen == n: + data, self.__buffer = self.__buffer, b"" + else: + data = self.__buffer[:n] + self.__buffer = self.__buffer[n:] + + return data + + def read_until(self, separator: bytes, limit: int, *, keep_end: bool = True) -> Generator[None, bytes, bytes]: + r""" + Read data from the stream until `separator` is found. + + On success, the data and separator will be removed from the internal buffer (consumed). + + If the amount of data read exceeds `limit`, a :exc:`LimitOverrunError` exception is raised, + and the data is left in the internal buffer and can be read again. + + Example:: + + def incremental_deserialize(self) -> Generator[None, bytes, tuple[Packet, bytes]]: + reader = GeneratorStreamReader() + + line: bytes = yield from reader.read_until(b"\n", limit=65535) + assert line.endswith(b"\n") + + ... + + Parameters: + separator: The byte sequence to find. + limit: The maximum buffer size. + keep_end: If :data:`True` (the default), returned data will include the separator at the end. + + Raises: + LimitOverrunError: Reached buffer size limit. + + Yields: + until `separator` is found in the buffer. + + Returns: + a :class:`bytes` object. 
+ """ + + if limit <= 0: + raise ValueError("limit must be a positive integer") + seplen: int = len(separator) + if seplen < 1: + raise ValueError("Empty separator") + + if not self.__buffer: + self.__buffer = yield + + offset: int = 0 + sepidx: int = -1 + while True: + buflen = len(self.__buffer) + + if buflen - offset >= seplen: + sepidx = self.__buffer.find(separator, offset) + + if sepidx != -1: + break + + offset = buflen + 1 - seplen + if offset > limit: + msg = "Separator is not found, and chunk exceed the limit" + raise LimitOverrunError(msg, self.__buffer, offset, separator) + + self.__buffer += yield + + if sepidx > limit: + msg = "Separator is found, but chunk is longer than limit" + raise LimitOverrunError(msg, self.__buffer, sepidx, separator) + + if keep_end: + data = self.__buffer[: sepidx + seplen] + else: + data = self.__buffer[:sepidx] + self.__buffer = self.__buffer[sepidx + seplen :] + + return data diff --git a/tests/functional_test/test_serializers/test_json.py b/tests/functional_test/test_serializers/test_json.py index 3d682eac..0b1b7f14 100644 --- a/tests/functional_test/test_serializers/test_json.py +++ b/tests/functional_test/test_serializers/test_json.py @@ -19,15 +19,20 @@ class TestJSONSerializer(BaseTestIncrementalSerializer): ENCODER_CONFIG = JSONEncoderConfig(ensure_ascii=False) + @pytest.fixture(scope="class", params=[False, True], ids=lambda p: f"use_lines=={p}") + @staticmethod + def use_lines(request: Any) -> bool: + return request.param + @pytest.fixture(scope="class") @classmethod - def serializer_for_serialization(cls) -> JSONSerializer: - return JSONSerializer(encoder_config=cls.ENCODER_CONFIG) + def serializer_for_serialization(cls, use_lines: bool) -> JSONSerializer: + return JSONSerializer(encoder_config=cls.ENCODER_CONFIG, use_lines=use_lines) @pytest.fixture(scope="class") @staticmethod - def serializer_for_deserialization() -> JSONSerializer: - return JSONSerializer() + def serializer_for_deserialization(use_lines: bool) -> JSONSerializer: + return JSONSerializer(use_lines=use_lines) #### Packets to test @@ -49,18 +54,24 @@ def expected_complete_data(cls, packet_to_serialize: Any) -> bytes: @pytest.fixture(scope="class") @staticmethod - def expected_joined_data(expected_complete_data: bytes) -> bytes: - return expected_complete_data + b"\n" + def expected_joined_data(expected_complete_data: bytes, use_lines: bool) -> bytes: + if use_lines or not expected_complete_data.startswith((b"{", b"[", b'"')): + return expected_complete_data + b"\n" + return expected_complete_data #### One-shot Deserialize @pytest.fixture(scope="class") @staticmethod - def complete_data(packet_to_serialize: Any) -> bytes: + def complete_data(packet_to_serialize: Any, use_lines: bool) -> bytes: import json - # Test with indentation to see whitespace handling - return json.dumps(packet_to_serialize, ensure_ascii=False, indent=2).encode("utf-8") + indent: int | None = None + if not use_lines: + # Test with indentation to see whitespace handling + indent = 4 + + return json.dumps(packet_to_serialize, ensure_ascii=False, indent=indent).encode("utf-8") #### Incremental Deserialize @@ -73,17 +84,23 @@ def complete_data_for_incremental_deserialize(complete_data: bytes) -> bytes: @pytest.fixture(scope="class", params=[b"invalid", b"\0"]) @staticmethod - def invalid_complete_data(request: Any) -> bytes: - return getattr(request, "param") + def invalid_complete_data(request: Any, use_lines: bool) -> bytes: + data: bytes = getattr(request, "param") + if use_lines: + data += b"\n" + 
return data @pytest.fixture(scope="class", params=[b"[ invalid ]", b"\0"]) @staticmethod - def invalid_partial_data(request: Any) -> bytes: - return getattr(request, "param") + def invalid_partial_data(request: Any, use_lines: bool) -> bytes: + data: bytes = getattr(request, "param") + if use_lines: + data += b"\n" + return data @pytest.fixture @staticmethod def invalid_partial_data_extra_data(invalid_partial_data: bytes) -> bytes: - if invalid_partial_data == b"\0": + if invalid_partial_data.startswith(b"\0"): return b"" return b"remaining_data" diff --git a/tests/tools.py b/tests/tools.py index 513501d4..b4234ab6 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -28,6 +28,12 @@ def send_return(gen: Generator[Any, _T_contra, _V_co], value: _T_contra, /) -> _ return exc_info.value.value +def next_return(gen: Generator[Any, Any, _V_co], /) -> _V_co: + with pytest.raises(StopIteration) as exc_info: + gen.send(None) + return exc_info.value.value + + @final class TimeTest: def __init__(self, expected_time: float, approx: float | None = None) -> None: diff --git a/tests/unit_test/test_serializers/test_json.py b/tests/unit_test/test_serializers/test_json.py index 9d3ac7b2..1cedf53e 100644 --- a/tests/unit_test/test_serializers/test_json.py +++ b/tests/unit_test/test_serializers/test_json.py @@ -5,6 +5,8 @@ from easynetwork.exceptions import DeserializeError, IncrementalDeserializeError from easynetwork.serializers.json import JSONDecoderConfig, JSONEncoderConfig, JSONSerializer, _JSONParser +from easynetwork.serializers.tools import GeneratorStreamReader +from easynetwork.tools.constants import _DEFAULT_LIMIT import pytest @@ -84,11 +86,26 @@ def decoder_config(request: Any, mocker: MockerFixture) -> JSONDecoderConfig | N strict=mocker.sentinel.strict, ) + @pytest.fixture(params=[True, False], ids=lambda boolean: f"use_lines=={boolean}") + @staticmethod + def use_lines(request: pytest.FixtureRequest) -> bool: + return getattr(request, "param") + @pytest.fixture @staticmethod def mock_json_parser(mocker: MockerFixture) -> MagicMock: return mocker.patch.object(_JSONParser, "raw_parse", autospec=True) + @pytest.fixture + @staticmethod + def mock_generator_stream_reader(mocker: MockerFixture) -> MagicMock: + return mocker.NonCallableMagicMock(spec=GeneratorStreamReader) + + @pytest.fixture + @staticmethod + def mock_generator_stream_reader_cls(mock_generator_stream_reader: MagicMock, mocker: MockerFixture) -> MagicMock: + return mocker.patch(f"{JSONSerializer.__module__}.GeneratorStreamReader", return_value=mock_generator_stream_reader) + def test____dunder_init____with_encoder_config( self, encoder_config: JSONEncoderConfig | None, @@ -132,6 +149,14 @@ def test____dunder_init____with_decoder_config( strict=mocker.sentinel.strict if decoder_config is not None else True, ) + @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}") + def test____dunder_init____invalid_limit(self, limit: int) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + _ = JSONSerializer(limit=limit) + def test____serialize____encode_packet( self, mock_encoder: MagicMock, @@ -143,7 +168,7 @@ def test____serialize____encode_packet( unicode_errors=mocker.sentinel.str_errors, ) mock_string = mock_encoder.encode.return_value = mocker.NonCallableMagicMock() - mock_string.encode.return_value = mocker.sentinel.data + mock_string.encode.return_value = b'{"data":42}' # Act data = serializer.serialize(mocker.sentinel.packet) @@ -151,10 +176,15 @@ def 
test____serialize____encode_packet( # Assert mock_encoder.encode.assert_called_once_with(mocker.sentinel.packet) mock_string.encode.assert_called_once_with(mocker.sentinel.encoding, mocker.sentinel.str_errors) - assert data is mocker.sentinel.data + assert data == b'{"data":42}' - def test____incremental_serialize____encode_packet( + @pytest.mark.parametrize("value", [b'{"data":42}', b"[4]", b'"string"']) + @pytest.mark.parametrize("limit_reached", [False, True], ids=lambda p: f"limit_reached=={p}") + def test____incremental_serialize____encode_packet____with_frames( self, + value: bytes, + use_lines: bool, + limit_reached: bool, mock_encoder: MagicMock, mocker: MockerFixture, ) -> None: @@ -162,9 +192,11 @@ def test____incremental_serialize____encode_packet( serializer: JSONSerializer = JSONSerializer( encoding=mocker.sentinel.encoding, unicode_errors=mocker.sentinel.str_errors, + limit=2 if limit_reached else 1024, + use_lines=use_lines, ) mock_string = mock_encoder.encode.return_value = mocker.NonCallableMagicMock() - mock_string.encode.return_value = mocker.sentinel.data + mock_string.encode.return_value = value # Act chunks = list(serializer.incremental_serialize(mocker.sentinel.packet)) @@ -172,7 +204,41 @@ def test____incremental_serialize____encode_packet( # Assert mock_encoder.encode.assert_called_once_with(mocker.sentinel.packet) mock_string.encode.assert_called_once_with(mocker.sentinel.encoding, mocker.sentinel.str_errors) - assert chunks == [mocker.sentinel.data, b"\n"] + if use_lines: + if limit_reached: + assert chunks == [value, b"\n"] + else: + assert chunks == [value + b"\n"] + else: + assert chunks == [value] + + @pytest.mark.parametrize("value", [b"12345", b"true", b"false", b"null"]) + @pytest.mark.parametrize("limit_reached", [False, True], ids=lambda p: f"limit_reached=={p}") + def test____incremental_serialize____encode_packet____plain_value( + self, + value: bytes, + use_lines: bool, + limit_reached: bool, + mock_encoder: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + serializer: JSONSerializer = JSONSerializer( + encoding=mocker.sentinel.encoding, + unicode_errors=mocker.sentinel.str_errors, + limit=2 if limit_reached else 1024, + use_lines=use_lines, + ) + mock_string = mock_encoder.encode.return_value = mocker.NonCallableMagicMock() + mock_string.encode.return_value = value + + # Act + chunks = list(serializer.incremental_serialize(mocker.sentinel.packet)) + + # Assert + mock_encoder.encode.assert_called_once_with(mocker.sentinel.packet) + mock_string.encode.assert_called_once_with(mocker.sentinel.encoding, mocker.sentinel.str_errors) + assert chunks == [value + b"\n"] def test____deserialize____decode_data( self, @@ -247,24 +313,35 @@ def test____deserialize____translate_json_decode_errors( def test____incremental_deserialize____parse_and_decode_data( self, + use_lines: bool, mock_decoder: MagicMock, mock_json_parser: MagicMock, + mock_generator_stream_reader_cls: MagicMock, + mock_generator_stream_reader: MagicMock, mocker: MockerFixture, ) -> None: # Arrange - def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: + def raw_parse_side_effect(*args: Any, **kwargs: Any) -> Generator[None, bytes, tuple[bytes, bytes]]: data = yield assert data is mocker.sentinel.data return mock_bytes, b"Hello World !" 
+ def reader_read_until_side_effect(*args: Any, **kwargs: Any) -> Generator[None, bytes, bytes]: + data = yield + assert data is mocker.sentinel.data + return mock_bytes + serializer: JSONSerializer = JSONSerializer( encoding=mocker.sentinel.encoding, unicode_errors=mocker.sentinel.str_errors, + use_lines=use_lines, ) mock_bytes = mocker.NonCallableMagicMock() mock_string_document = mock_bytes.decode.return_value = mocker.NonCallableMagicMock() mock_decoder.decode.return_value = mocker.sentinel.packet mock_json_parser.side_effect = raw_parse_side_effect + mock_generator_stream_reader.read_until.side_effect = reader_read_until_side_effect + mock_generator_stream_reader.read_all.return_value = b"Hello World !" # Act consumer = serializer.incremental_deserialize() @@ -272,7 +349,16 @@ def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: packet, remaining_data = send_return(consumer, mocker.sentinel.data) # Assert - mock_json_parser.assert_called_once_with() + if use_lines: + mock_json_parser.assert_not_called() + mock_generator_stream_reader_cls.assert_called_once_with() + mock_generator_stream_reader.read_until.assert_called_once_with(b"\n", limit=_DEFAULT_LIMIT) + mock_generator_stream_reader.read_all.assert_called_once_with() + else: + mock_json_parser.assert_called_once_with(limit=_DEFAULT_LIMIT) + mock_generator_stream_reader_cls.assert_not_called() + mock_generator_stream_reader.read_until.assert_not_called() + mock_generator_stream_reader.read_all.assert_not_called() mock_bytes.decode.assert_called_once_with(mocker.sentinel.encoding, mocker.sentinel.str_errors) mock_decoder.decode.assert_called_once_with(mock_string_document) assert packet is mocker.sentinel.packet @@ -285,7 +371,7 @@ def test____incremental_deserialize____translate_unicode_decode_errors( mocker: MockerFixture, ) -> None: # Arrange - def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: + def raw_parse_side_effect(*args: Any, **kwargs: Any) -> Generator[None, bytes, tuple[bytes, bytes]]: data = yield assert data is mocker.sentinel.data return mock_bytes, mocker.sentinel.remaining_data @@ -293,6 +379,7 @@ def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: serializer: JSONSerializer = JSONSerializer( encoding=mocker.sentinel.encoding, unicode_errors=mocker.sentinel.str_errors, + use_lines=False, ) mock_bytes = mocker.NonCallableMagicMock() mock_bytes.decode.side_effect = UnicodeDecodeError("some encoding", b"invalid data", 0, 2, "Bad encoding ?") @@ -321,7 +408,7 @@ def test____incremental_deserialize____translate_json_decode_errors( # Arrange from json import JSONDecodeError - def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: + def raw_parse_side_effect(*args: Any, **kwargs: Any) -> Generator[None, bytes, tuple[bytes, bytes]]: data = yield assert data is mocker.sentinel.data return mock_bytes, mocker.sentinel.remaining_data @@ -329,6 +416,7 @@ def raw_parse_side_effect() -> Generator[None, bytes, tuple[bytes, bytes]]: serializer: JSONSerializer = JSONSerializer( encoding=mocker.sentinel.encoding, unicode_errors=mocker.sentinel.str_errors, + use_lines=False, ) mock_bytes = mocker.NonCallableMagicMock() mock_decoder.decode.side_effect = JSONDecodeError("Invalid payload", "invalid\ndocument", 8) diff --git a/tests/unit_test/test_serializers/test_tools.py b/tests/unit_test/test_serializers/test_tools.py new file mode 100644 index 00000000..649f39a3 --- /dev/null +++ b/tests/unit_test/test_serializers/test_tools.py @@ -0,0 +1,274 
@@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from easynetwork.exceptions import LimitOverrunError +from easynetwork.serializers.tools import GeneratorStreamReader + +import pytest + +from ...tools import next_return, send_return + +if TYPE_CHECKING: + from pytest_mock import MockerFixture + + +class TestGeneratorStreamReader: + def test____read_all____pop_buffer_without_copying(self, mocker: MockerFixture) -> None: + # Arrange + reader = GeneratorStreamReader(mocker.sentinel.initial_buffer) + + # Act + data = reader.read_all() + + # Assert + assert data is mocker.sentinel.initial_buffer + assert reader.read_all() == b"" + + def test____read____yield_if_buffer_is_empty(self) -> None: + # Arrange + reader = GeneratorStreamReader() + + # Act + consumer = reader.read(1024) + next(consumer) + data = send_return(consumer, b"data") + + # Assert + assert data == b"data" + + def test____read____do_not_yield_if_buffer_is_not_empty(self) -> None: + # Arrange + reader = GeneratorStreamReader(b"data") + + # Act + consumer = reader.read(1024) + data = next_return(consumer) + + # Assert + assert data == b"data" + + def test____read____no_data_copy____in_use_buffer(self) -> None: + # Arrange + value = b"data" + reader = GeneratorStreamReader(value) + + # Act + consumer = reader.read(1024) + data = next_return(consumer) + + # Assert + assert data is value + + def test____read____no_data_copy____sent_value(self) -> None: + # Arrange + value = b"data" + reader = GeneratorStreamReader() + + # Act + consumer = reader.read(1024) + next(consumer) + data = send_return(consumer, value) + + # Assert + assert data is value + + @pytest.mark.parametrize("initial", [False, True], ids=lambda p: f"initial=={p}") + def test____read____shrink_too_big_buffer(self, initial: bool) -> None: + # Arrange + value = b"data" + reader = GeneratorStreamReader(value if initial else b"") + + # Act + consumer = reader.read(2) + if initial: + data = next_return(consumer) + else: + next(consumer) + data = send_return(consumer, value) + + # Assert + assert data == b"da" + assert reader.read_all() == b"ta" + + def test____read____null_nbytes(self) -> None: + # Arrange + reader = GeneratorStreamReader() + + # Act + consumer = reader.read(0) + data = next_return(consumer) + + # Assert + assert data == b"" + + def test____read____negative_nbytes(self) -> None: + # Arrange + reader = GeneratorStreamReader() + + # Act & Assert + consumer = reader.read(-1) + with pytest.raises(ValueError, match=r"^size must not be < 0$"): + next_return(consumer) + + def test____read_exactly____yield_until_size_has_been_reached(self) -> None: + # Arrange + reader = GeneratorStreamReader() + + # Act + consumer = reader.read_exactly(4) + next(consumer) + consumer.send(b"d") + consumer.send(b"a") + consumer.send(b"t") + data = send_return(consumer, b"a") + + # Assert + assert data == b"data" + + @pytest.mark.parametrize("initial", [False, True], ids=lambda p: f"initial=={p}") + def test____read_exactly____size_already_fit(self, initial: bool) -> None: + # Arrange + value = b"data" + reader = GeneratorStreamReader(value if initial else b"") + + # Act + consumer = reader.read_exactly(4) + if initial: + data = next_return(consumer) + else: + next(consumer) + data = send_return(consumer, value) + + # Assert + assert data is value + + @pytest.mark.parametrize("initial", [False, True], ids=lambda p: f"initial=={p}") + def test____read_exactly____shrink_too_big_buffer(self, initial: bool) -> None: + # Arrange + prefix = b"dat" + reader = 
GeneratorStreamReader(prefix if initial else b"")
+
+        # Act
+        consumer = reader.read_exactly(4)
+        next(consumer)
+        if not initial:
+            consumer.send(prefix)
+        data = send_return(consumer, b"abc")
+
+        # Assert
+        assert data == b"data"
+        assert reader.read_all() == b"bc"
+
+    def test____read_exactly____null_nbytes(self) -> None:
+        # Arrange
+        reader = GeneratorStreamReader()
+
+        # Act
+        consumer = reader.read_exactly(0)
+        data = next_return(consumer)
+
+        # Assert
+        assert data == b""
+
+    def test____read_exactly____negative_nbytes(self) -> None:
+        # Arrange
+        reader = GeneratorStreamReader()
+
+        # Act & Assert
+        consumer = reader.read_exactly(-1)
+        with pytest.raises(ValueError, match=r"^n must not be < 0$"):
+            next_return(consumer)
+
+    @pytest.mark.parametrize(
+        "expected_remaining_data",
+        [
+            pytest.param(b"", id="without remaining data"),
+            pytest.param(b"remaining", id="with remaining data"),
+            pytest.param(b"remaining\r\nother", id="with remaining data including separator"),
+        ],
+    )
+    @pytest.mark.parametrize("keep_end", [False, True], ids=lambda p: f"keep_end=={p}")
+    def test____read_until____one_shot_chunk(self, expected_remaining_data: bytes, keep_end: bool) -> None:
+        # Arrange
+        reader = GeneratorStreamReader()
+        data_to_test: bytes = b"data\r\n"
+
+        # Act
+        consumer = reader.read_until(b"\r\n", limit=1024, keep_end=keep_end)
+        next(consumer)
+        data = send_return(consumer, data_to_test + expected_remaining_data)
+
+        # Assert
+        if keep_end:
+            assert data == b"data\r\n"
+        else:
+            assert data == b"data"
+        assert reader.read_all() == expected_remaining_data
+
+    @pytest.mark.parametrize(
+        "expected_remaining_data",
+        [
+            pytest.param(b"", id="without remaining data"),
+            pytest.param(b"remaining", id="with remaining data"),
+            pytest.param(b"remaining\r\nother", id="with remaining data including separator"),
+        ],
+    )
+    @pytest.mark.parametrize("keep_end", [False, True], ids=lambda p: f"keep_end=={p}")
+    def test____read_until____several_chunks(self, expected_remaining_data: bytes, keep_end: bool) -> None:
+        # Arrange
+        reader = GeneratorStreamReader(b"d")
+
+        # Act
+        consumer = reader.read_until(b"\r\n", limit=1024, keep_end=keep_end)
+        next(consumer)
+        consumer.send(b"ata\r")
+        data = send_return(consumer, b"\n" + expected_remaining_data)
+
+        # Assert
+        if keep_end:
+            assert data == b"data\r\n"
+        else:
+            assert data == b"data"
+        assert reader.read_all() == expected_remaining_data
+
+    @pytest.mark.parametrize("separator_found", [False, True], ids=lambda p: f"separator_found=={p}")
+    def test____read_until____reached_limit(self, separator_found: bool) -> None:
+        # Arrange
+        reader = GeneratorStreamReader()
+        data_to_test: bytes = b"data\r"
+        if separator_found:
+            data_to_test += b"\n"
+
+        # Act
+        consumer = reader.read_until(b"\r\n", limit=1)
+        next(consumer)
+        with pytest.raises(LimitOverrunError) as exc_info:
+            consumer.send(data_to_test)
+
+        # Assert
+        if separator_found:
+            assert str(exc_info.value) == "Separator is found, but chunk is longer than limit"
+            assert exc_info.value.remaining_data == b""
+        else:
+            assert str(exc_info.value) == "Separator is not found, and chunk exceed the limit"
+            assert exc_info.value.remaining_data == b"\r"
+
+    def test____read_until____empty_separator(self) -> None:
+        # Arrange
+        reader = GeneratorStreamReader()
+
+        # Act & Assert
+        consumer = reader.read_until(b"", limit=1024)
+        with pytest.raises(ValueError, match=r"^Empty separator$"):
+            next_return(consumer)
+
+    @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}")
+    def 
test____read_until____invalid_limit(self, limit: int) -> None: + # Arrange + reader = GeneratorStreamReader() + + # Act & Assert + consumer = reader.read_until(b"\n", limit) + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + next_return(consumer) From 91eda2462af5c44d6e431fad5b48dad775ee5876 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 30 Sep 2023 14:17:28 +0200 Subject: [PATCH 5/6] [UPDATE] JSONSerializer: Added unit tests --- src/easynetwork/serializers/json.py | 13 +- tests/unit_test/test_serializers/test_json.py | 195 +++++++++++++++++- 2 files changed, 196 insertions(+), 12 deletions(-) diff --git a/src/easynetwork/serializers/json.py b/src/easynetwork/serializers/json.py index 90aa06b1..c4c64a61 100644 --- a/src/easynetwork/serializers/json.py +++ b/src/easynetwork/serializers/json.py @@ -175,21 +175,18 @@ def incremental_serialize(self, packet: Any) -> Generator[bytes, None, None]: all the parts of the JSON :term:`packet`. """ data = self.__encoder.encode(packet).encode(self.__encoding, self.__unicode_errors) - newline = b"\n" if not data.startswith((b"{", b"[", b'"')): - data += newline + data += b"\n" yield data - return - if not self.__use_lines: + elif not self.__use_lines: yield data - return - if len(data) + len(newline) <= self.__limit // 2: - data += newline + elif len(data) + 1 <= self.__limit // 2: + data += b"\n" yield data else: yield data del data - yield newline + yield b"\n" @final def deserialize(self, data: bytes) -> Any: diff --git a/tests/unit_test/test_serializers/test_json.py b/tests/unit_test/test_serializers/test_json.py index 1cedf53e..37ac3edb 100644 --- a/tests/unit_test/test_serializers/test_json.py +++ b/tests/unit_test/test_serializers/test_json.py @@ -3,10 +3,10 @@ from collections.abc import Generator from typing import TYPE_CHECKING, Any -from easynetwork.exceptions import DeserializeError, IncrementalDeserializeError +from easynetwork.exceptions import DeserializeError, IncrementalDeserializeError, LimitOverrunError from easynetwork.serializers.json import JSONDecoderConfig, JSONEncoderConfig, JSONSerializer, _JSONParser from easynetwork.serializers.tools import GeneratorStreamReader -from easynetwork.tools.constants import _DEFAULT_LIMIT +from easynetwork.tools.constants import _DEFAULT_LIMIT as DEFAULT_LIMIT import pytest @@ -352,10 +352,10 @@ def reader_read_until_side_effect(*args: Any, **kwargs: Any) -> Generator[None, if use_lines: mock_json_parser.assert_not_called() mock_generator_stream_reader_cls.assert_called_once_with() - mock_generator_stream_reader.read_until.assert_called_once_with(b"\n", limit=_DEFAULT_LIMIT) + mock_generator_stream_reader.read_until.assert_called_once_with(b"\n", limit=DEFAULT_LIMIT) mock_generator_stream_reader.read_all.assert_called_once_with() else: - mock_json_parser.assert_called_once_with(limit=_DEFAULT_LIMIT) + mock_json_parser.assert_called_once_with(limit=DEFAULT_LIMIT) mock_generator_stream_reader_cls.assert_not_called() mock_generator_stream_reader.read_until.assert_not_called() mock_generator_stream_reader.read_all.assert_not_called() @@ -440,3 +440,190 @@ def raw_parse_side_effect(*args: Any, **kwargs: Any) -> Generator[None, bytes, t "lineno": 2, "colno": 1, } + + +class TestJSONParser: + def test____raw_parse____object_frame(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'{"data"') + complete, remainder = send_return(consumer, b":42}remainder") + + # Assert 
+ assert complete == b'{"data":42}' + assert remainder == b"remainder" + + def test____raw_parse____object_frame____skip_bracket_in_strings(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + complete, remainder = send_return(consumer, b'{"data}": "something}"}remainder') + + # Assert + assert complete == b'{"data}": "something}"}' + assert remainder == b"remainder" + + def test____raw_parse____object_frame____whitespaces(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'{"data": 42,\n') + consumer.send(b'"list": [true, false, null]\n') + complete, remainder = send_return(consumer, b"}\n") + + # Assert + assert complete == b'{"data": 42,\n"list": [true, false, null]\n}\n' + assert remainder == b"" + + def test____raw_parse____list_frame(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'[{"data"') + consumer.send(b":42}") + complete, remainder = send_return(consumer, b"]remainder") + + # Assert + assert complete == b'[{"data":42}]' + assert remainder == b"remainder" + + def test____raw_parse___list_frame____skip_bracket_in_strings(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + complete, remainder = send_return(consumer, b'["string]", "second]"]remainder') + + # Assert + assert complete == b'["string]", "second]"]' + assert remainder == b"remainder" + + def test____raw_parse____list_frame____whitespaces(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'[{\n"data"') + consumer.send(b': 42,\n "test": true},\n') + consumer.send(b"null,\n") + consumer.send(b'"string"\n') + complete, remainder = send_return(consumer, b"]\n") + + # Assert + assert complete == b'[{\n"data": 42,\n "test": true},\nnull,\n"string"\n]\n' + assert remainder == b"" + + def test____raw_parse____string_frame(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'"data{') + consumer.send(b"}") + complete, remainder = send_return(consumer, b'"remainder') + + # Assert + assert complete == b'"data{}"' + assert remainder == b"remainder" + + def test____raw_parse____string_frame____escaped_quote(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'"data') + consumer.send(b'\\"') + complete, remainder = send_return(consumer, b'"remainder') + + # Assert + assert complete == b'"data\\""' + assert remainder == b"remainder" + + def test____raw_parse____string_frame____escape_character(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b'"data') + consumer.send(b"\\\\") + complete, remainder = send_return(consumer, b'"remainder') + + # Assert + assert complete == b'"data\\\\"' + assert remainder == b"remainder" + + def test____raw_parse____plain_value(self) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + consumer.send(b"tr") + complete, remainder = send_return(consumer, b"ue\nremainder") + + # Assert + assert complete == b"true\n" + assert remainder == b"remainder" + + def test____raw_parse____plain_value____first_character_is_invalid(self) -> None: + # Arrange + consumer = 
_JSONParser.raw_parse(limit=DEFAULT_LIMIT) + next(consumer) + + # Act + complete, remainder = send_return(consumer, b"\0") + + # Assert + assert complete == b"\0" + assert remainder == b"" + + @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}") + def test____raw_parse____invalid_limit(self, limit: int) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=limit) + + # Act & Assert + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + next(consumer) + + @pytest.mark.parametrize( + ["start_frame", "end_frame"], + [ + pytest.param(b'{"data":', b'"something"}', id="object frame"), + pytest.param(b'["data",', b'"something"]', id="list frame"), + pytest.param(b'"data', b' something"', id="string frame"), + pytest.param(b"123", b"45\n", id="plain value"), + ], + ) + @pytest.mark.parametrize("end_frame_found", [False, True], ids=lambda p: f"end_frame_found=={p}") + def test____raw_parse____reached_limit(self, start_frame: bytes, end_frame: bytes, end_frame_found: bool) -> None: + # Arrange + consumer = _JSONParser.raw_parse(limit=2) + next(consumer) + data_to_test = start_frame + if end_frame_found: + data_to_test += end_frame + + # Act + with pytest.raises(LimitOverrunError) as exc_info: + consumer.send(data_to_test) + + # Assert + if end_frame_found: + assert str(exc_info.value) == "JSON object's end frame is found, but chunk is longer than limit" + else: + assert str(exc_info.value) == "JSON object's end frame is not found, and chunk exceed the limit" From 190888536573bc1e7adcce74ff96332b8b6c4bfd Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 30 Sep 2023 14:29:36 +0200 Subject: [PATCH 6/6] [FIX] Compressor serializers: Minor fix --- src/easynetwork/serializers/wrapper/compressor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/easynetwork/serializers/wrapper/compressor.py b/src/easynetwork/serializers/wrapper/compressor.py index e8f67c7f..b982f70b 100644 --- a/src/easynetwork/serializers/wrapper/compressor.py +++ b/src/easynetwork/serializers/wrapper/compressor.py @@ -186,7 +186,10 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b results.append(chunk) del chunk - data = b"".join(results) + if len(results) == 1: + data = results[0] + else: + data = b"".join(results) unused_data: bytes = decompressor.unused_data del results, decompressor
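
For illustration, a minimal hand-driven sketch of the GeneratorStreamReader API exercised by the tests above (read_until(), read_all(), and the StopIteration return protocol); the newline-delimited frame and the chunk boundaries are arbitrary, and only calls shown in this series are assumed::

    from collections.abc import Generator

    from easynetwork.serializers.tools import GeneratorStreamReader


    def incremental_deserialize() -> Generator[None, bytes, tuple[bytes, bytes]]:
        # Frame packets on b"\n"; whatever follows the separator is returned as remainder.
        reader = GeneratorStreamReader()
        line = yield from reader.read_until(b"\n", limit=65536, keep_end=False)
        return line, reader.read_all()


    consumer = incremental_deserialize()
    next(consumer)  # run until the reader asks for its first chunk
    consumer.send(b"hello ")  # no separator yet: the reader keeps buffering
    try:
        consumer.send(b"world\nextra")
    except StopIteration as exc:
        packet, remainder = exc.value
        assert packet == b"hello world"
        assert remainder == b"extra"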
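
The `use_lines` flag added to JSONSerializer only changes how the emitted chunks are framed. A hedged usage sketch, assuming the default encoder and limit configuration (the exact byte layout of the JSON document is deliberately not asserted)::

    from easynetwork.serializers.json import JSONSerializer

    packet = {"data": 42}

    # With use_lines=True every JSON document is terminated by the b"\n" delimiter...
    framed = b"".join(JSONSerializer(use_lines=True).incremental_serialize(packet))
    assert framed.endswith(b"\n")

    # ...while with use_lines=False an object/array/string frame is emitted as-is,
    # and only values without their own frame (numbers, true/false/null) get a newline.
    unframed = b"".join(JSONSerializer(use_lines=False).incremental_serialize(packet))
    assert not unframed.endswith(b"\n")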