diff --git a/docs/source/api/serializers/cbor.rst b/docs/source/api/serializers/cbor.rst index b553e897..3e0bfca5 100644 --- a/docs/source/api/serializers/cbor.rst +++ b/docs/source/api/serializers/cbor.rst @@ -12,6 +12,8 @@ Configuration .. autoclass:: CBOREncoderConfig :members: + :undoc-members: .. autoclass:: CBORDecoderConfig :members: + :undoc-members: diff --git a/docs/source/api/serializers/json.rst b/docs/source/api/serializers/json.rst index 8ac3774c..06e5dab3 100644 --- a/docs/source/api/serializers/json.rst +++ b/docs/source/api/serializers/json.rst @@ -12,6 +12,8 @@ Configuration .. autoclass:: JSONEncoderConfig :members: + :undoc-members: .. autoclass:: JSONDecoderConfig :members: + :undoc-members: diff --git a/docs/source/api/serializers/msgpack.rst b/docs/source/api/serializers/msgpack.rst index 1b667986..4cac8243 100644 --- a/docs/source/api/serializers/msgpack.rst +++ b/docs/source/api/serializers/msgpack.rst @@ -12,6 +12,8 @@ Configuration .. autoclass:: MessagePackerConfig :members: + :undoc-members: .. autoclass:: MessageUnpackerConfig :members: + :undoc-members: diff --git a/docs/source/api/serializers/pickle.rst b/docs/source/api/serializers/pickle.rst index 83704b67..277450db 100644 --- a/docs/source/api/serializers/pickle.rst +++ b/docs/source/api/serializers/pickle.rst @@ -16,6 +16,8 @@ Configuration .. autoclass:: PicklerConfig :members: + :undoc-members: .. autoclass:: UnpicklerConfig :members: + :undoc-members: diff --git a/micro_benchmarks/serializers/bench_cbor.py b/micro_benchmarks/serializers/bench_cbor.py index f01de370..d404b0b1 100644 --- a/micro_benchmarks/serializers/bench_cbor.py +++ b/micro_benchmarks/serializers/bench_cbor.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, assert_type +from typing import TYPE_CHECKING, Any from easynetwork.serializers.cbor import CBORSerializer @@ -57,9 +57,8 @@ def bench_CBORSerializer_incremental_deserialize( def deserialize() -> Any: consumer = serializer.buffered_incremental_deserialize(buffer) - assert_type(next(consumer), None) - with buffer.cast("B") as view: - view[:nbytes] = cbor_data + next(consumer) + buffer[:nbytes] = cbor_data try: consumer.send(nbytes) except StopIteration as exc: diff --git a/micro_benchmarks/serializers/bench_msgpack.py b/micro_benchmarks/serializers/bench_msgpack.py index cf82769b..bba69137 100644 --- a/micro_benchmarks/serializers/bench_msgpack.py +++ b/micro_benchmarks/serializers/bench_msgpack.py @@ -6,6 +6,8 @@ from easynetwork.serializers.msgpack import MessagePackSerializer +import pytest + if TYPE_CHECKING: from pytest_benchmark.fixture import BenchmarkFixture @@ -29,3 +31,53 @@ def bench_MessagePackSerializer_deserialize( result = benchmark(serializer.deserialize, msgpack_data) assert result == json_object + + +def bench_MessagePackSerializer_incremental_serialize( + benchmark: BenchmarkFixture, + json_object: Any, +) -> None: + serializer = MessagePackSerializer() + + benchmark(lambda: b"".join(serializer.incremental_serialize(json_object))) + + +@pytest.mark.parametrize("buffered", [False, True], ids=lambda p: f"buffered=={p}") +def bench_MessagePackSerializer_incremental_deserialize( + buffered: bool, + benchmark: BenchmarkFixture, + msgpack_data: bytes, + json_object: Any, +) -> None: + serializer = MessagePackSerializer() + + if buffered: + nbytes = len(msgpack_data) + buffer: memoryview = serializer.create_deserializer_buffer(nbytes) + + def deserialize() -> Any: + consumer = serializer.buffered_incremental_deserialize(buffer) 
+ next(consumer) + buffer[:nbytes] = msgpack_data + try: + consumer.send(nbytes) + except StopIteration as exc: + return exc.value + else: + raise RuntimeError("consumer yielded") + + else: + + def deserialize() -> Any: + consumer = serializer.incremental_deserialize() + next(consumer) + try: + consumer.send(msgpack_data) + except StopIteration as exc: + return exc.value + else: + raise RuntimeError("consumer yielded") + + result, _ = benchmark(deserialize) + + assert result == json_object diff --git a/pdm.lock b/pdm.lock index bad0eb96..37593186 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "bandit", "benchmark-servers", "benchmark-servers-deps", "build", "cbor", "coverage", "dev", "doc", "flake8", "format", "micro-benchmark", "msgpack", "mypy", "pre-commit", "test", "test-trio", "tox", "trio", "types-msgpack", "uvloop"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1b80719e2b5aeccbd3fc7688f2139ec1335e93b13057431aa7e47cd0c1a66ab2" +content_hash = "sha256:696499065d91cf2d7c26ecdd12eb68179cb0e0972ba0a3847690fbefc8e1d9d5" [[metadata.targets]] requires_python = ">=3.11" diff --git a/pyproject.toml b/pyproject.toml index 1b61cc1f..2404a68f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ cbor = [ "cbor2>=5.5,<6", ] msgpack = [ - "msgpack>=1.0.7,<2", + "msgpack>=1.0.8,<2", ] trio = [ "trio>=0.26,<1", diff --git a/src/easynetwork/exceptions.py b/src/easynetwork/exceptions.py index 8e88c6f0..c8d2fa48 100644 --- a/src/easynetwork/exceptions.py +++ b/src/easynetwork/exceptions.py @@ -115,7 +115,7 @@ def __init__(self, message: str, buffer: ReadableBuffer, consumed: int, separato while remaining_data.nbytes and remaining_data[:seplen] != separator[: remaining_data.nbytes]: remaining_data = remaining_data[1:] - super().__init__(message, remaining_data, error_info=None) + super().__init__(message, bytes(remaining_data), error_info=None) self.consumed: int = consumed """Total number of to be consumed bytes.""" diff --git a/src/easynetwork/serializers/base_stream.py b/src/easynetwork/serializers/base_stream.py index 66ab6638..13df839b 100644 --- a/src/easynetwork/serializers/base_stream.py +++ b/src/easynetwork/serializers/base_stream.py @@ -361,12 +361,13 @@ class FileBasedPacketSerializer(BufferedIncrementalPacketSerializer[_T_SentDTOPa Base class for APIs requiring a :std:term:`file object` for serialization/deserialization. """ - __slots__ = ("__expected_errors", "__debug") + __slots__ = ("__expected_errors", "__limit", "__debug") def __init__( self, expected_load_error: type[Exception] | tuple[type[Exception], ...], *, + limit: int = DEFAULT_SERIALIZER_LIMIT, debug: bool = False, **kwargs: Any, ) -> None: @@ -374,6 +375,7 @@ def __init__( Parameters: expected_load_error: Errors that can be raised by :meth:`load_from_file` implementation, which must be considered as deserialization errors. + limit: Maximum buffer size. debug: If :data:`True`, add information to :exc:`.DeserializeError` via the ``error_info`` attribute. kwargs: Extra options given to ``super().__init__()``. """ @@ -381,6 +383,11 @@ def __init__( if not isinstance(expected_load_error, tuple): expected_load_error = (expected_load_error,) assert all(issubclass(e, Exception) for e in expected_load_error) # nosec assert_used + + if limit <= 0: + raise ValueError("limit must be a positive integer") + + self.__limit: int = limit self.__expected_errors: tuple[type[Exception], ...] 
= expected_load_error
         self.__debug: bool = bool(debug)
@@ -412,24 +419,26 @@ def load_from_file(self, file: IO[bytes], /) -> _T_ReceivedDTOPacket:
         """
         raise NotImplementedError

-    @final
     def serialize(self, packet: _T_SentDTOPacket, /) -> bytes:
         """
-        Calls :meth:`dump_to_file` and returns the result.
+        Returns the byte representation of the Python object `packet`.
+
+        By default, this method uses :meth:`dump_to_file`, but it can be overridden for specific usage.

         See :meth:`.AbstractPacketSerializer.serialize` documentation for details.

         Raises:
             Exception: Any error raised by :meth:`dump_to_file`.
         """
         with BytesIO() as buffer:
             self.dump_to_file(packet, buffer)
             return buffer.getvalue()

-    @final
     def deserialize(self, data: bytes, /) -> _T_ReceivedDTOPacket:
         """
-        Calls :meth:`load_from_file` and returns the result.
+        Creates a Python object representing the raw :term:`packet` from `data`.
+
+        By default, this method uses :meth:`load_from_file`, but it can be overridden for specific usage.

         See :meth:`.AbstractPacketSerializer.deserialize` documentation for details.

@@ -499,6 +508,7 @@ def create_deserializer_buffer(self, sizehint: int, /) -> memoryview:
         """
         See :meth:`.BufferedIncrementalPacketSerializer.create_deserializer_buffer` documentation for details.
         """
+        sizehint = min(sizehint, self.__limit)
         return memoryview(bytearray(sizehint))

     @final
@@ -528,10 +538,11 @@ def __generic_incremental_deserialize(self) -> Generator[None, ReadableBuffer, t
                 if not initial:
                     buffer.write((yield))
                 buffer.seek(0)
+                self.__check_file_buffer_limit(buffer)
                 try:
                     packet: _T_ReceivedDTOPacket = self.load_from_file(buffer)
                 except EOFError:
-                    continue
+                    pass
                 except self.__expected_errors as exc:
                     msg = f"Deserialize error: {exc}"
                     if self.debug:
@@ -546,6 +557,11 @@ def __generic_incremental_deserialize(self) -> Generator[None, ReadableBuffer, t
                 finally:
                     initial = False

+    def __check_file_buffer_limit(self, file: BytesIO) -> None:
+        with file.getbuffer() as buffer_view:
+            if buffer_view.nbytes > self.__limit:
+                raise LimitOverrunError("chunk exceeded buffer limit", buffer_view, consumed=buffer_view.nbytes)
+
     @property
     @final
     def debug(self) -> bool:
@@ -554,6 +570,14 @@ def debug(self) -> bool:
         """
         return self.__debug

+    @property
+    @final
+    def buffer_limit(self) -> int:
+        """
+        Maximum buffer size. Read-only attribute.
+        """
+        return self.__limit
+

 def _wrap_generic_incremental_deserialize(
     func: Callable[[], Generator[None, ReadableBuffer, tuple[_T_ReceivedDTOPacket, ReadableBuffer]]],
diff --git a/src/easynetwork/serializers/cbor.py b/src/easynetwork/serializers/cbor.py
index 755fdf23..fc9b128a 100644
--- a/src/easynetwork/serializers/cbor.py
+++ b/src/easynetwork/serializers/cbor.py
@@ -36,6 +36,7 @@
 from typing import IO, TYPE_CHECKING, Any, final

 from ..lowlevel import _utils
+from ..lowlevel.constants import DEFAULT_SERIALIZER_LIMIT
 from .base_stream import FileBasedPacketSerializer

 if TYPE_CHECKING:
@@ -86,12 +87,14 @@ def __init__(
         encoder_config: CBOREncoderConfig | None = None,
         decoder_config: CBORDecoderConfig | None = None,
         *,
+        limit: int = DEFAULT_SERIALIZER_LIMIT,
         debug: bool = False,
     ) -> None:
         """
         Parameters:
             encoder_config: Parameter object to configure the :class:`~cbor.encoder.CBOREncoder`.
             decoder_config: Parameter object to configure the :class:`~cbor.decoder.CBORDecoder`.
+            limit: Maximum buffer size. Used in incremental serialization context.
debug: If :data:`True`, add information to :exc:`.DeserializeError` via the ``error_info`` attribute. """ try: @@ -99,7 +102,7 @@ def __init__( except ModuleNotFoundError as exc: raise _utils.missing_extra_deps("cbor") from exc - super().__init__(expected_load_error=(cbor2.CBORDecodeError, UnicodeError), debug=debug) + super().__init__(expected_load_error=(cbor2.CBORDecodeError, UnicodeError), limit=limit, debug=debug) self.__encoder_cls: Callable[[IO[bytes]], cbor2.CBOREncoder] self.__decoder_cls: Callable[[IO[bytes]], cbor2.CBORDecoder] diff --git a/src/easynetwork/serializers/msgpack.py b/src/easynetwork/serializers/msgpack.py index e7e72e40..b6024f8d 100644 --- a/src/easynetwork/serializers/msgpack.py +++ b/src/easynetwork/serializers/msgpack.py @@ -31,21 +31,14 @@ ] from collections.abc import Callable -from dataclasses import asdict as dataclass_asdict, dataclass, field +from dataclasses import asdict as dataclass_asdict, dataclass from functools import partial -from typing import Any, final +from typing import IO, Any, final from ..exceptions import DeserializeError from ..lowlevel import _utils -from .abc import AbstractPacketSerializer - - -def _get_default_ext_hook() -> Callable[[int, bytes], Any]: - try: - from msgpack import ExtType - except ModuleNotFoundError as exc: - raise _utils.missing_extra_deps("msgpack", feature_name="message-pack") from exc - return ExtType +from ..lowlevel.constants import DEFAULT_SERIALIZER_LIMIT +from .base_stream import FileBasedPacketSerializer @dataclass(kw_only=True) @@ -79,29 +72,38 @@ class MessageUnpackerConfig: unicode_errors: str = "strict" object_hook: Callable[[dict[Any, Any]], Any] | None = None object_pairs_hook: Callable[[list[tuple[Any, Any]]], Any] | None = None - ext_hook: Callable[[int, bytes], Any] = field(default_factory=_get_default_ext_hook) + ext_hook: Callable[[int, bytes], Any] | None = None -class MessagePackSerializer(AbstractPacketSerializer[Any, Any]): +class MessagePackSerializer(FileBasedPacketSerializer[Any, Any]): """ - A :term:`one-shot serializer` built on top of the :mod:`msgpack` module. + A :term:`serializer` built on top of the :mod:`msgpack` module. Needs ``msgpack`` extra dependencies. """ - __slots__ = ("__packb", "__unpackb", "__unpack_out_of_data_cls", "__unpack_extra_data_cls", "__debug") + __slots__ = ( + "__packb", + "__unpackb", + "__incremental_packer", + "__incremental_unpacker", + "__unpack_out_of_data_cls", + "__unpack_extra_data_cls", + ) def __init__( self, packer_config: MessagePackerConfig | None = None, unpacker_config: MessageUnpackerConfig | None = None, *, + limit: int = DEFAULT_SERIALIZER_LIMIT, debug: bool = False, ) -> None: """ Parameters: packer_config: Parameter object to configure the :class:`~msgpack.Packer`. unpacker_config: Parameter object to configure the :class:`~msgpack.Unpacker`. + limit: Maximum buffer size. Used in incremental serialization context. debug: If :data:`True`, add information to :exc:`.DeserializeError` via the ``error_info`` attribute. 
""" try: @@ -109,9 +111,16 @@ def __init__( except ModuleNotFoundError as exc: raise _utils.missing_extra_deps("msgpack", feature_name="message-pack") from exc - super().__init__() + super().__init__( + expected_load_error=Exception, # The documentation says to catch all exceptions :) + limit=limit, + debug=debug, + ) + limit = self.buffer_limit self.__packb: Callable[[Any], bytes] self.__unpackb: Callable[[bytes], Any] + self.__incremental_packer: Callable[[], msgpack.Packer] + self.__incremental_unpacker: Callable[[IO[bytes]], msgpack.Unpacker] if packer_config is None: packer_config = MessagePackerConfig() @@ -125,13 +134,21 @@ def __init__( f"Invalid unpacker config: expected {MessageUnpackerConfig.__name__}, got {type(unpacker_config).__name__}" ) - self.__packb = partial(msgpack.packb, **dataclass_asdict(packer_config), autoreset=True) - self.__unpackb = partial(msgpack.unpackb, **dataclass_asdict(unpacker_config)) + packer_options = dataclass_asdict(packer_config) + unpacker_options = dataclass_asdict(unpacker_config) + + del packer_config, unpacker_config + + if unpacker_options.get("ext_hook") is None: + unpacker_options["ext_hook"] = msgpack.ExtType + + self.__packb = partial(msgpack.packb, **packer_options, autoreset=True) + self.__unpackb = partial(msgpack.unpackb, **unpacker_options) + self.__incremental_packer = partial(msgpack.Packer, **packer_options, autoreset=True) + self.__incremental_unpacker = partial(msgpack.Unpacker, **unpacker_options, max_buffer_size=limit) self.__unpack_out_of_data_cls = msgpack.OutOfData self.__unpack_extra_data_cls = msgpack.ExtraData - self.__debug: bool = bool(debug) - @final def serialize(self, packet: Any) -> bytes: """ @@ -172,11 +189,6 @@ def deserialize(self, data): """ try: return self.__unpackb(data) - except self.__unpack_out_of_data_cls as exc: - msg = "Missing data to create packet" - if self.debug: - raise DeserializeError(msg, error_info={"data": data}) from exc - raise DeserializeError(msg) from exc except self.__unpack_extra_data_cls as exc: msg = "Extra data caught" if self.debug: @@ -184,16 +196,54 @@ def deserialize(self, data): raise DeserializeError(msg) from exc except Exception as exc: # The documentation says to catch all exceptions :) msg = str(exc) or "Invalid token" + if isinstance(exc, ValueError) and "incomplete input" in msg: # <- And here is our "OutOfData" :) + msg = "Missing data to create packet" if self.debug: raise DeserializeError(msg, error_info={"data": data}) from exc raise DeserializeError(msg) from exc finally: del data - @property @final - def debug(self) -> bool: + def dump_to_file(self, packet: Any, file: IO[bytes]) -> None: """ - The debug mode flag. Read-only attribute. + Write the MessagePack representation of `packet` to `file`. + + Roughly equivalent to:: + + def dump_to_file(self, packet, file): + msgpack.pack(packet, file) + + Parameters: + packet: The Python object to serialize. + file: The :std:term:`binary file` to write to. """ - return self.__debug + file.write(self.__incremental_packer().pack(packet)) + + @final + def load_from_file(self, file: IO[bytes]) -> Any: + """ + Read from `file` to deserialize the raw MessagePack :term:`packet`. + + Roughly equivalent to:: + + def load_from_file(self, file): + return msgpack.unpack(file) + + Parameters: + file: The :std:term:`binary file` to read from. + + Returns: + the deserialized Python object. 
+ """ + current_position = file.tell() + unpacker = self.__incremental_unpacker(file) + try: + packet = unpacker.unpack() + except self.__unpack_out_of_data_cls: + raise EOFError from None + else: + file.seek(current_position + unpacker.tell()) + finally: + del unpacker + return packet diff --git a/tests/functional_test/test_serializers/test_cbor.py b/tests/functional_test/test_serializers/test_cbor.py index d6d949d7..e023ecea 100644 --- a/tests/functional_test/test_serializers/test_cbor.py +++ b/tests/functional_test/test_serializers/test_cbor.py @@ -18,12 +18,15 @@ class TestCBORSerializer(BaseTestBufferedIncrementalSerializer, BaseTestSerializerExtraData): #### Serializers - ENCODER_CONFIG = CBOREncoderConfig() + @pytest.fixture(scope="class") + @staticmethod + def encoder_config() -> CBOREncoderConfig: + return CBOREncoderConfig() @pytest.fixture(scope="class") - @classmethod - def serializer_for_serialization(cls) -> CBORSerializer: - return CBORSerializer(encoder_config=cls.ENCODER_CONFIG) + @staticmethod + def serializer_for_serialization(encoder_config: CBOREncoderConfig) -> CBORSerializer: + return CBORSerializer(encoder_config=encoder_config) @pytest.fixture(scope="class") @staticmethod @@ -41,10 +44,10 @@ def packet_to_serialize(request: Any) -> Any: @pytest.fixture(scope="class") @classmethod - def expected_complete_data(cls, packet_to_serialize: Any) -> bytes: + def expected_complete_data(cls, packet_to_serialize: Any, encoder_config: CBOREncoderConfig) -> bytes: import cbor2 - return cbor2.dumps(packet_to_serialize, **dataclasses.asdict(cls.ENCODER_CONFIG)) + return cbor2.dumps(packet_to_serialize, **dataclasses.asdict(encoder_config)) #### Incremental Serialize diff --git a/tests/functional_test/test_serializers/test_json.py b/tests/functional_test/test_serializers/test_json.py index 06f2abd6..733fc605 100644 --- a/tests/functional_test/test_serializers/test_json.py +++ b/tests/functional_test/test_serializers/test_json.py @@ -23,12 +23,15 @@ class TestJSONSerializer(BaseTestIncrementalSerializer, BaseTestSerializerExtraData): #### Serializers - ENCODER_CONFIG = JSONEncoderConfig(ensure_ascii=False) - BUFFER_LIMIT = 14 * 1024 # 14KiB assert len(TOO_BIG_JSON_SERIALIZED) > BUFFER_LIMIT + @pytest.fixture(scope="class") + @staticmethod + def encoder_config() -> JSONEncoderConfig: + return JSONEncoderConfig(ensure_ascii=False) + @pytest.fixture(scope="class", params=[False, True], ids=lambda p: f"use_lines=={p}") @staticmethod def use_lines(request: Any) -> bool: @@ -36,8 +39,8 @@ def use_lines(request: Any) -> bool: @pytest.fixture(scope="class") @classmethod - def serializer_for_serialization(cls, use_lines: bool) -> JSONSerializer: - return JSONSerializer(encoder_config=cls.ENCODER_CONFIG, use_lines=use_lines, limit=cls.BUFFER_LIMIT) + def serializer_for_serialization(cls, encoder_config: JSONEncoderConfig, use_lines: bool) -> JSONSerializer: + return JSONSerializer(encoder_config=encoder_config, use_lines=use_lines, limit=cls.BUFFER_LIMIT) @pytest.fixture(scope="class") @classmethod @@ -55,8 +58,8 @@ def packet_to_serialize(request: Any) -> Any: @pytest.fixture(scope="class") @classmethod - def expected_complete_data(cls, packet_to_serialize: Any) -> bytes: - return json.dumps(packet_to_serialize, **dataclasses.asdict(cls.ENCODER_CONFIG), separators=(",", ":")).encode("utf-8") + def expected_complete_data(cls, packet_to_serialize: Any, encoder_config: JSONEncoderConfig) -> bytes: + return json.dumps(packet_to_serialize, **dataclasses.asdict(encoder_config), separators=(",", 
":")).encode("utf-8") #### Incremental Serialize diff --git a/tests/functional_test/test_serializers/test_msgpack.py b/tests/functional_test/test_serializers/test_msgpack.py index 4b1a5cdb..d2914a3a 100644 --- a/tests/functional_test/test_serializers/test_msgpack.py +++ b/tests/functional_test/test_serializers/test_msgpack.py @@ -9,13 +9,13 @@ import pytest -from .base import BaseTestSerializerExtraData +from .base import BaseTestBufferedIncrementalSerializer, BaseTestSerializerExtraData from .samples.json import SAMPLES @final @pytest.mark.feature_msgpack -class TestMessagePackSerializer(BaseTestSerializerExtraData): +class TestMessagePackSerializer(BaseTestBufferedIncrementalSerializer, BaseTestSerializerExtraData): #### Serializers @pytest.fixture(scope="class") @@ -49,6 +49,13 @@ def expected_complete_data(cls, packet_to_serialize: Any, packer_config: Message return msgpack.packb(packet_to_serialize, **dataclasses.asdict(packer_config), autoreset=True) + #### Incremental Serialize + + @pytest.fixture(scope="class") + @staticmethod + def expected_joined_data(expected_complete_data: bytes) -> bytes: + return expected_complete_data + #### One-shot Deserialize @pytest.fixture(scope="class") @@ -58,9 +65,21 @@ def complete_data(packet_to_serialize: Any) -> bytes: return msgpack.packb(packet_to_serialize) + #### Incremental Deserialize + + @pytest.fixture(scope="class") + @staticmethod + def complete_data_for_incremental_deserialize(complete_data: bytes) -> bytes: + return complete_data + #### Invalid data @pytest.fixture(scope="class") @staticmethod def invalid_complete_data(complete_data: bytes) -> bytes: - return complete_data[:-1] # Missing data error + return complete_data[:-1] # Extra data error + + @pytest.fixture(scope="class") + @staticmethod + def invalid_partial_data_extra_data() -> tuple[bytes, bytes]: + return (b"remaining_data", b"") diff --git a/tests/unit_test/test_serializers/test_abc.py b/tests/unit_test/test_serializers/test_abc.py index 8fbaba95..49822119 100644 --- a/tests/unit_test/test_serializers/test_abc.py +++ b/tests/unit_test/test_serializers/test_abc.py @@ -916,10 +916,19 @@ def test____properties____right_values(self, debug_mode: bool) -> None: # Arrange # Act - serializer = _FileBasedPacketSerializerForTest(expected_load_error=(), debug=debug_mode) + serializer = _FileBasedPacketSerializerForTest(expected_load_error=(), debug=debug_mode, limit=123456789) # Assert assert serializer.debug is debug_mode + assert serializer.buffer_limit == 123456789 + + @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}") + def test____dunder_init____invalid_limit(self, limit: int) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + _ = _FileBasedPacketSerializerForTest(expected_load_error=(), limit=limit) def test____serialize____dump_to_file( self, @@ -1081,6 +1090,27 @@ def side_effect(_: Any, file: IO[bytes]) -> None: mock_dump_to_file_func.assert_called_once_with(mocker.sentinel.packet, mocker.ANY) assert data == [] + @pytest.mark.parametrize( + ["sizehint", "limit"], + [ + pytest.param(1024, 65536, id="lower_than_limit"), + pytest.param(65536, 65536, id="equal_to_limit"), + pytest.param(65536, 1024, id="greater_than_limit"), + ], + ) + def test____create_deserializer____sizehint_limit(self, sizehint: int, limit: int) -> None: + # Arrange + serializer = _FileBasedPacketSerializerForTest(expected_load_error=(), limit=limit) + + # Act + buffer = 
serializer.create_deserializer_buffer(sizehint) + + # Assert + if sizehint <= limit: + assert len(buffer) == sizehint + else: + assert len(buffer) == limit + def test____incremental_deserialize____load_from_file____read_all( self, incremental_deserialize_mode: Literal["data", "buffer"], @@ -1254,3 +1284,55 @@ def side_effect(file: IO[bytes]) -> Any: assert exception.error_info == {"data": b"data"} else: assert exception.error_info is None + + @pytest.mark.parametrize("at_first_chunk", [False, True], ids=lambda p: f"at_first_chunk=={p}") + def test____incremental_deserialize____load_from_file____buffer_limit_overrun( + self, + at_first_chunk: bool, + incremental_deserialize_mode: Literal["data", "buffer"], + mock_load_from_file_func: MagicMock, + ) -> None: + # Arrange + def side_effect(file: IO[bytes]) -> Any: + file.read() + raise EOFError + + serializer = _FileBasedPacketSerializerForTest(expected_load_error=(), limit=3) + mock_load_from_file_func.side_effect = side_effect + + # Act + match incremental_deserialize_mode: + case "data": + data_consumer = serializer.incremental_deserialize() + next(data_consumer) + if at_first_chunk: + with pytest.raises(LimitOverrunError) as exc_info: + data_consumer.send(b"data") + else: + data_consumer.send(b"d") + data_consumer.send(b"a") + data_consumer.send(b"t") + with pytest.raises(LimitOverrunError) as exc_info: + data_consumer.send(b"a") + case "buffer": + buffer = serializer.create_deserializer_buffer(1024) + buffered_consumer = serializer.buffered_incremental_deserialize(buffer) + next(buffered_consumer) + if at_first_chunk: + buffered_consumer.send(write_in_buffer(buffer, b"dat")) # <- The buffer size is set to 3 (because of "limit") + with pytest.raises(LimitOverrunError) as exc_info: + buffered_consumer.send(write_in_buffer(buffer, b"a")) + else: + buffered_consumer.send(write_in_buffer(buffer, b"d")) + buffered_consumer.send(write_in_buffer(buffer, b"a")) + buffered_consumer.send(write_in_buffer(buffer, b"t")) + with pytest.raises(LimitOverrunError) as exc_info: + buffered_consumer.send(write_in_buffer(buffer, b"a")) + case _: + pytest.fail("Invalid fixture argument") + + exception = exc_info.value + + # Assert + assert bytes(exception.remaining_data) == b"" + assert exception.consumed == 4 diff --git a/tests/unit_test/test_serializers/test_cbor.py b/tests/unit_test/test_serializers/test_cbor.py index f012a01e..21af3141 100644 --- a/tests/unit_test/test_serializers/test_cbor.py +++ b/tests/unit_test/test_serializers/test_cbor.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, final +from easynetwork.lowlevel.constants import DEFAULT_SERIALIZER_LIMIT from easynetwork.serializers.cbor import CBORDecoderConfig, CBOREncoderConfig, CBORSerializer import pytest @@ -88,7 +89,17 @@ def decoder_config(request: Any, mocker: MockerFixture) -> CBORDecoderConfig | N str_errors=mocker.sentinel.str_errors, ) - @pytest.mark.parametrize("method", ["serialize", "incremental_serialize", "deserialize", "incremental_deserialize"]) + @pytest.mark.parametrize( + "method", + [ + "serialize", + "incremental_serialize", + "deserialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) def test____base_class____implements_default_methods(self, method: str) -> None: # Arrange from easynetwork.serializers.base_stream import FileBasedPacketSerializer @@ -96,14 +107,30 @@ def test____base_class____implements_default_methods(self, method: str) -> None: # Act & Assert assert getattr(CBORSerializer, method) is 
getattr(FileBasedPacketSerializer, method) - def test____properties____right_values(self, debug_mode: bool) -> None: + @pytest.mark.parametrize("limit", [147258369, None], ids=lambda p: f"limit=={p}") + def test____properties____right_values(self, debug_mode: bool, limit: int | None) -> None: # Arrange # Act - serializer = CBORSerializer(debug=debug_mode) + if limit is None: + serializer = CBORSerializer(debug=debug_mode) + else: + serializer = CBORSerializer(debug=debug_mode, limit=limit) # Assert assert serializer.debug is debug_mode + if limit is None: + assert serializer.buffer_limit == DEFAULT_SERIALIZER_LIMIT + else: + assert serializer.buffer_limit == limit + + @pytest.mark.parametrize("limit", [0, -42], ids=lambda p: f"limit=={p}") + def test____dunder_init____invalid_limit(self, limit: int) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^limit must be a positive integer$"): + CBORSerializer(limit=limit) def test____dump_to_file____with_config( self, diff --git a/tests/unit_test/test_serializers/test_compressor.py b/tests/unit_test/test_serializers/test_compressor.py index c4c6ada8..d3686bec 100644 --- a/tests/unit_test/test_serializers/test_compressor.py +++ b/tests/unit_test/test_serializers/test_compressor.py @@ -457,7 +457,17 @@ def test____incremental_deserialize____translate_deserialize_errors( class BaseTestCompressorSerializerImplementation: - @pytest.mark.parametrize("method", ["serialize", "incremental_serialize", "deserialize", "incremental_deserialize"]) + @pytest.mark.parametrize( + "method", + [ + "serialize", + "incremental_serialize", + "deserialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) def test____base_class____implements_default_methods( self, serializer_cls: type[AbstractCompressorSerializer[Any, Any]], diff --git a/tests/unit_test/test_serializers/test_msgpack.py b/tests/unit_test/test_serializers/test_msgpack.py index 0ffae5a0..71cfcc2f 100644 --- a/tests/unit_test/test_serializers/test_msgpack.py +++ b/tests/unit_test/test_serializers/test_msgpack.py @@ -1,8 +1,10 @@ from __future__ import annotations +from io import BytesIO from typing import TYPE_CHECKING, Any, final from easynetwork.exceptions import DeserializeError +from easynetwork.lowlevel.constants import DEFAULT_SERIALIZER_LIMIT from easynetwork.serializers.msgpack import MessagePackerConfig, MessagePackSerializer, MessageUnpackerConfig import pytest @@ -40,6 +42,30 @@ def mock_packb(mocker: MockerFixture) -> MagicMock: def mock_unpackb(mocker: MockerFixture) -> MagicMock: return mocker.patch("msgpack.unpackb", autospec=True) + @pytest.fixture + @staticmethod + def mock_packer(mocker: MockerFixture) -> MagicMock: + from msgpack import Packer + + return mocker.NonCallableMagicMock(spec=Packer) + + @pytest.fixture(autouse=True) + @staticmethod + def mock_packer_cls(mock_packer: MagicMock, mocker: MockerFixture) -> MagicMock: + return mocker.patch("msgpack.Packer", return_value=mock_packer) + + @pytest.fixture + @staticmethod + def mock_unpacker(mocker: MockerFixture) -> MagicMock: + from msgpack import Unpacker + + return mocker.NonCallableMagicMock(spec=Unpacker, **{"tell.return_value": 0}) + + @pytest.fixture(autouse=True) + @staticmethod + def mock_unpacker_cls(mock_unpacker: MagicMock, mocker: MockerFixture) -> MagicMock: + return mocker.patch("msgpack.Unpacker", return_value=mock_unpacker) + @pytest.fixture(params=[True, False], ids=lambda boolean: f"default_packer_config=={boolean}") 
@staticmethod def packer_config(request: Any, mocker: MockerFixture) -> MessagePackerConfig | None: @@ -72,19 +98,44 @@ def unpacker_config(request: Any, mocker: MockerFixture) -> MessageUnpackerConfi ext_hook=mocker.sentinel.ext_hook, ) - def test____properties____right_values(self, debug_mode: bool) -> None: + @pytest.mark.parametrize( + "method", + [ + "incremental_serialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) + def test____base_class____implements_default_methods(self, method: str) -> None: + # Arrange + from easynetwork.serializers.base_stream import FileBasedPacketSerializer + + # Act & Assert + assert getattr(MessagePackSerializer, method) is getattr(FileBasedPacketSerializer, method) + + @pytest.mark.parametrize("limit", [147258369, None], ids=lambda p: f"limit=={p}") + def test____properties____right_values(self, debug_mode: bool, limit: int | None) -> None: # Arrange # Act - serializer = MessagePackSerializer(debug=debug_mode) + if limit is None: + serializer = MessagePackSerializer(debug=debug_mode) + else: + serializer = MessagePackSerializer(debug=debug_mode, limit=limit) # Assert assert serializer.debug is debug_mode + if limit is None: + assert serializer.buffer_limit == DEFAULT_SERIALIZER_LIMIT + else: + assert serializer.buffer_limit == limit def test____serialize____with_config( self, packer_config: MessagePackerConfig | None, mock_packb: MagicMock, + mock_packer_cls: MagicMock, mocker: MockerFixture, ) -> None: # Arrange @@ -106,11 +157,13 @@ def test____serialize____with_config( unicode_errors=mocker.sentinel.unicode_errors if packer_config is not None else "strict", autoreset=True, ) + mock_packer_cls.assert_not_called() def test____deserialize____with_config( self, unpacker_config: MessageUnpackerConfig | None, mock_unpackb: MagicMock, + mock_unpacker_cls: MagicMock, mocker: MockerFixture, ) -> None: # Arrange @@ -135,6 +188,7 @@ def test____deserialize____with_config( object_pairs_hook=mocker.sentinel.object_pairs_hook if unpacker_config is not None else None, ext_hook=mocker.sentinel.ext_hook if unpacker_config is not None else msgpack.ExtType, ) + mock_unpacker_cls.assert_not_called() def test____deserialize____missing_data( self, @@ -143,17 +197,15 @@ def test____deserialize____missing_data( mocker: MockerFixture, ) -> None: # Arrange - import msgpack - serializer: MessagePackSerializer = MessagePackSerializer(debug=debug_mode) - mock_unpackb.side_effect = msgpack.OutOfData + mock_unpackb.side_effect = ValueError("Unpack failed: incomplete input") # Act & Assert with pytest.raises(DeserializeError, match=r"^Missing data to create packet$") as exc_info: serializer.deserialize(mocker.sentinel.data) # Assert - assert isinstance(exc_info.value.__cause__, msgpack.OutOfData) + assert isinstance(exc_info.value.__cause__, ValueError) if debug_mode: assert exc_info.value.error_info == {"data": mocker.sentinel.data} else: @@ -203,29 +255,132 @@ def test____deserialize____any_exception_occurs( else: assert exc_info.value.error_info is None + def test____dump_to_file____with_config( + self, + packer_config: MessagePackerConfig | None, + mock_packb: MagicMock, + mock_packer_cls: MagicMock, + mock_packer: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + serializer: MessagePackSerializer = MessagePackSerializer(packer_config=packer_config) + mock_packer.pack.side_effect = [b"msgpack-data"] + mock_packb.side_effect = AssertionError -class TestMessagePackSerializerDependencies: - def 
test____dunder_init____msgpack_missing( + # Act + file = BytesIO() + serializer.dump_to_file(mocker.sentinel.packet, file) + + # Assert + assert file.getvalue() == b"msgpack-data" + mock_packer_cls.assert_called_once_with( + default=mocker.sentinel.default if packer_config is not None else None, + use_single_float=mocker.sentinel.use_single_float if packer_config is not None else False, + use_bin_type=mocker.sentinel.use_bin_type if packer_config is not None else True, + datetime=mocker.sentinel.datetime if packer_config is not None else False, + strict_types=mocker.sentinel.strict_types if packer_config is not None else False, + unicode_errors=mocker.sentinel.unicode_errors if packer_config is not None else "strict", + autoreset=True, + ) + mock_packer.pack.assert_called_once_with(mocker.sentinel.packet) + mock_packb.assert_not_called() + + @pytest.mark.parametrize("limit", [147258369, DEFAULT_SERIALIZER_LIMIT], ids=lambda p: f"limit=={p}") + def test____load_from_file____with_config( self, + limit: int, + unpacker_config: MessageUnpackerConfig | None, + mock_unpackb: MagicMock, + mock_unpacker_cls: MagicMock, + mock_unpacker: MagicMock, mocker: MockerFixture, ) -> None: # Arrange - mock_import: MagicMock = mock_import_module_not_found({"msgpack"}, mocker) + import msgpack + + serializer: MessagePackSerializer = MessagePackSerializer(unpacker_config=unpacker_config, limit=limit) + mock_unpackb.side_effect = AssertionError + + def unpack_side_effect() -> Any: + file.read() + mock_unpacker.tell.return_value = file.tell() + return mocker.sentinel.packet + + mock_unpacker.unpack.side_effect = unpack_side_effect # Act - with pytest.raises(ModuleNotFoundError) as exc_info: - try: - _ = MessagePackSerializer() - finally: - mocker.stop(mock_import) + file = BytesIO(b"msgpack-data") + packet = serializer.load_from_file(file) # Assert - mock_import.assert_any_call("msgpack", mocker.ANY, mocker.ANY, None, 0) - assert exc_info.value.args[0] == "message-pack dependencies are missing. 
Consider adding 'msgpack' extra" - assert exc_info.value.__notes__ == ['example: pip install "easynetwork[msgpack]"'] - assert isinstance(exc_info.value.__cause__, ModuleNotFoundError) + assert packet is mocker.sentinel.packet + assert file.read() == b"" + mock_unpacker_cls.assert_called_once_with( + file, + raw=mocker.sentinel.raw if unpacker_config is not None else False, + use_list=mocker.sentinel.use_list if unpacker_config is not None else True, + timestamp=mocker.sentinel.timestamp if unpacker_config is not None else 0, + strict_map_key=mocker.sentinel.strict_map_key if unpacker_config is not None else True, + unicode_errors=mocker.sentinel.unicode_errors if unpacker_config is not None else "strict", + object_hook=mocker.sentinel.object_hook if unpacker_config is not None else None, + object_pairs_hook=mocker.sentinel.object_pairs_hook if unpacker_config is not None else None, + ext_hook=mocker.sentinel.ext_hook if unpacker_config is not None else msgpack.ExtType, + max_buffer_size=limit, + ) + mock_unpacker.unpack.assert_called_once_with() + mock_unpackb.assert_not_called() - def test____MessageUnpackerConfig____msgpack_missing( + def test____load_from_file____missing_data( + self, + debug_mode: bool, + mock_unpacker: MagicMock, + ) -> None: + # Arrange + import msgpack + + serializer: MessagePackSerializer = MessagePackSerializer(debug=debug_mode) + + def unpack_side_effect() -> Any: + file.read() + mock_unpacker.tell.return_value = 0 + raise msgpack.OutOfData + + mock_unpacker.unpack.side_effect = unpack_side_effect + + # Act & Assert + file = BytesIO(b"msgpack-data") + with pytest.raises(EOFError): + serializer.load_from_file(file) + assert file.read() == b"" + + def test____load_from_file____extra_data( + self, + debug_mode: bool, + mock_unpacker: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + serializer: MessagePackSerializer = MessagePackSerializer(debug=debug_mode) + + def unpack_side_effect() -> Any: + file.read() + mock_unpacker.tell.return_value = len(b"msgpack-data") + return mocker.sentinel.packet + + mock_unpacker.unpack.side_effect = unpack_side_effect + + # Act + file = BytesIO(b"".join([b"msgpack-data", b"remaining_data"])) + packet = serializer.load_from_file(file) + + # Assert + assert packet is mocker.sentinel.packet + assert file.read() == b"remaining_data" + + +class TestMessagePackSerializerDependencies: + def test____dunder_init____msgpack_missing( self, mocker: MockerFixture, ) -> None: @@ -235,12 +390,12 @@ def test____MessageUnpackerConfig____msgpack_missing( # Act with pytest.raises(ModuleNotFoundError) as exc_info: try: - _ = MessageUnpackerConfig() + _ = MessagePackSerializer() finally: mocker.stop(mock_import) # Assert - mock_import.assert_any_call("msgpack", mocker.ANY, mocker.ANY, ("ExtType",), 0) + mock_import.assert_any_call("msgpack", mocker.ANY, mocker.ANY, None, 0) assert exc_info.value.args[0] == "message-pack dependencies are missing. 
Consider adding 'msgpack' extra" assert exc_info.value.__notes__ == ['example: pip install "easynetwork[msgpack]"'] assert isinstance(exc_info.value.__cause__, ModuleNotFoundError) diff --git a/tests/unit_test/test_serializers/test_struct.py b/tests/unit_test/test_serializers/test_struct.py index 069299e7..2d883e2a 100644 --- a/tests/unit_test/test_serializers/test_struct.py +++ b/tests/unit_test/test_serializers/test_struct.py @@ -56,7 +56,15 @@ def mock_serializer_iter_values(mocker: MockerFixture) -> MagicMock: def mock_serializer_from_tuple(mocker: MockerFixture) -> MagicMock: return mocker.patch.object(_StructSerializerForTest, "from_tuple") - @pytest.mark.parametrize("method", ["incremental_serialize", "incremental_deserialize"]) + @pytest.mark.parametrize( + "method", + [ + "incremental_serialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) def test____base_class____implements_default_methods(self, method: str) -> None: # Arrange from easynetwork.serializers.base_stream import FixedSizePacketSerializer @@ -157,7 +165,17 @@ def test____deserialize____translate_struct_errors( class TestNamedTupleStructSerializer(BaseTestStructBasedSerializer): - @pytest.mark.parametrize("method", ["serialize", "incremental_serialize", "deserialize", "incremental_deserialize"]) + @pytest.mark.parametrize( + "method", + [ + "serialize", + "incremental_serialize", + "deserialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) def test____base_class____implements_default_methods(self, method: str) -> None: # Arrange @@ -471,7 +489,17 @@ def test____from_tuple____construct_namedtuple____string_padding( class TestStructSerializer(BaseTestStructBasedSerializer): - @pytest.mark.parametrize("method", ["serialize", "incremental_serialize", "deserialize", "incremental_deserialize"]) + @pytest.mark.parametrize( + "method", + [ + "serialize", + "incremental_serialize", + "deserialize", + "incremental_deserialize", + "create_deserializer_buffer", + "buffered_incremental_deserialize", + ], + ) def test____base_class____implements_default_methods(self, method: str) -> None: # Arrange