Skip to content

Commit

Permalink
Added missing functional tests for serializers (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Nov 27, 2023
1 parent 3a6dd4a commit 781739e
Show file tree
Hide file tree
Showing 18 changed files with 195 additions and 88 deletions.
2 changes: 2 additions & 0 deletions src/easynetwork/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(self, message: str, buffer: ReadableBuffer, consumed: int, separato
remaining_data = memoryview(buffer)[consumed:].tobytes()
if separator and remaining_data.startswith(separator):
remaining_data = remaining_data.removeprefix(separator)
else:
remaining_data = remaining_data[1:]

super().__init__(message, remaining_data, error_info=None)

Expand Down
12 changes: 6 additions & 6 deletions src/easynetwork/serializers/base_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,19 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b
"""
reader = GeneratorStreamReader()
data = yield from reader.read_until(self.__separator, limit=self.__limit, keep_end=False)
buffer = reader.read_all()
remainder = reader.read_all()

try:
packet = self.deserialize(data)
except DeserializeError as exc:
raise IncrementalDeserializeError(
f"Error when deserializing data: {exc}",
remaining_data=buffer,
remaining_data=remainder,
error_info=exc.error_info,
) from exc
finally:
del data
return packet, buffer
return packet, remainder

@property
@final
Expand Down Expand Up @@ -231,19 +231,19 @@ def incremental_deserialize(self) -> Generator[None, bytes, tuple[_DTOPacketT, b
"""
reader = GeneratorStreamReader()
data = yield from reader.read_exactly(self.__size)
buffer = reader.read_all()
remainder = reader.read_all()

try:
packet = self.deserialize(data)
except DeserializeError as exc:
raise IncrementalDeserializeError(
f"Error when deserializing data: {exc}",
remaining_data=buffer,
remaining_data=remainder,
error_info=exc.error_info,
) from exc
finally:
del data
return packet, buffer
return packet, remainder

@property
@final
Expand Down
4 changes: 2 additions & 2 deletions src/easynetwork/serializers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,21 +404,21 @@ def raw_parse(*, limit: int) -> Generator[None, bytes, tuple[bytes, bytes]]:
raise LimitOverrunError(
"JSON object's end frame is not found, and chunk exceed the limit",
partial_document,
nprint_idx,
len(partial_document),
)
partial_document += yield

return split_partial_document(partial_document, nprint_idx, limit)

@staticmethod
def _split_partial_document(partial_document: bytes, consumed: int, limit: int) -> tuple[bytes, bytes]:
consumed = _JSONParser._whitespaces_match(partial_document, consumed).end()
if consumed > limit:
raise LimitOverrunError(
"JSON object's end frame is found, but chunk is longer than limit",
partial_document,
consumed,
)
consumed = _JSONParser._whitespaces_match(partial_document, consumed).end()
if consumed == len(partial_document):
# The following bytes are only spaces
# Do not slice the document, the trailing spaces will be ignored by JSONDecoder
Expand Down
3 changes: 3 additions & 0 deletions src/easynetwork/serializers/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def deserialize(self, data):
Raises:
DeserializeError: :class:`UnicodeError` raised when decoding `data`.
DeserializeError: Newline found in `data` (excluding those at the end of the sequence).
Returns:
the string.
Expand All @@ -139,6 +140,8 @@ def deserialize(self, data):
separator: bytes = self.separator
while data.endswith(separator):
data = data.removesuffix(separator)
if separator in data:
raise DeserializeError("Newline found in string")
try:
return data.decode(self.__encoding, self.__unicode_errors)
except UnicodeError as exc:
Expand Down
15 changes: 9 additions & 6 deletions tests/functional_test/test_serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,14 @@ def invalid_partial_data() -> bytes:

@pytest.fixture(scope="class")
@staticmethod
def invalid_partial_data_extra_data() -> bytes | None:
def invalid_partial_data_extra_data() -> bytes:
return b"remaining_data"

@pytest.fixture(scope="class")
@staticmethod
def invalid_partial_data_expected_extra_data(invalid_partial_data_extra_data: bytes) -> bytes:
return invalid_partial_data_extra_data

def test____fixture____consistency____incremental_serializer(
self,
serializer_for_serialization: AbstractIncrementalPacketSerializer[Any],
Expand Down Expand Up @@ -223,7 +228,8 @@ def test____incremental_deserialize____invalid_data(
self,
serializer_for_deserialization: AbstractIncrementalPacketSerializer[Any],
invalid_partial_data: bytes,
invalid_partial_data_extra_data: bytes | None,
invalid_partial_data_extra_data: bytes,
invalid_partial_data_expected_extra_data: bytes,
) -> None:
# Arrange
consumer = serializer_for_deserialization.incremental_deserialize()
Expand All @@ -238,10 +244,7 @@ def test____incremental_deserialize____invalid_data(
exception = exc_info.value

# Assert
if invalid_partial_data_extra_data is not None:
assert exception.remaining_data == invalid_partial_data_extra_data
else:
assert len(exception.remaining_data) > 0
assert exception.remaining_data == invalid_partial_data_expected_extra_data


@final
Expand Down
2 changes: 1 addition & 1 deletion tests/functional_test/test_serializers/samples/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
]


BIG_JSON: Any = [
BIG_JSON: list[Any] = [
{
"_id": "63cd615fa31a400f255ec20c",
"index": 0,
Expand Down
39 changes: 27 additions & 12 deletions tests/functional_test/test_serializers/test_base64.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from __future__ import annotations

import base64
import hashlib
import hmac
import random
from typing import Any, Literal, final

Expand All @@ -14,9 +17,6 @@


def generate_key_from_string(s: str) -> bytes:
import base64
import hashlib

return base64.urlsafe_b64encode(hashlib.sha256(s.encode("utf-8")).digest())


Expand All @@ -30,6 +30,8 @@ def generate_key_from_string(s: str) -> bytes:
class BaseTestBase64EncoderSerializer(BaseTestIncrementalSerializer):
#### Serializers

BUFFER_LIMIT = 1024

@pytest.fixture(scope="class", params=["standard", "urlsafe"])
@staticmethod
def alphabet(request: pytest.FixtureRequest) -> Literal["standard", "urlsafe"]:
Expand All @@ -42,7 +44,7 @@ def serializer(
checksum: bool | bytes,
alphabet: Literal["standard", "urlsafe"],
) -> Base64EncoderSerializer[bytes]:
return Base64EncoderSerializer(NoSerialization(), alphabet=alphabet, checksum=checksum)
return Base64EncoderSerializer(NoSerialization(), alphabet=alphabet, checksum=checksum, limit=cls.BUFFER_LIMIT)

@pytest.fixture(scope="class")
@staticmethod
Expand Down Expand Up @@ -71,10 +73,6 @@ def expected_complete_data(
checksum: bool | bytes,
alphabet: Literal["standard", "urlsafe"],
) -> bytes:
import base64
import hashlib
import hmac

if checksum:
if isinstance(checksum, bytes):
key = base64.urlsafe_b64decode(checksum)
Expand Down Expand Up @@ -114,10 +112,27 @@ def complete_data_for_incremental_deserialize(complete_data: bytes) -> bytes:
def invalid_complete_data(complete_data: bytes) -> bytes:
return complete_data[:-1] # Remove one byte at last will break the padding

@pytest.fixture
@staticmethod
def invalid_partial_data() -> bytes:
pytest.skip("Cannot be tested")
@pytest.fixture(scope="class", params=["missing_data", "limit_overrun_without_newline", "limit_overrun_with_newline"])
@classmethod
def invalid_partial_data(cls, request: pytest.FixtureRequest, alphabet: Literal["standard", "urlsafe"]) -> bytes:
match request.param:
case "missing_data":
if alphabet == "standard":
return base64.standard_b64encode(random.randbytes(255))[:-1] + b"\r\n"
return base64.urlsafe_b64encode(random.randbytes(255))[:-1] + b"\r\n"
case "limit_overrun_without_newline":
return b"4" * (cls.BUFFER_LIMIT + 10)
case "limit_overrun_with_newline":
return b"4" * (cls.BUFFER_LIMIT + 10) + b"\r\n"
case _:
pytest.fail("Invalid fixture parameter")

@pytest.fixture(scope="class")
@classmethod
def invalid_partial_data_extra_data(cls, invalid_partial_data: bytes) -> bytes:
if len(invalid_partial_data) > cls.BUFFER_LIMIT:
return b""
return b"remaining_data"

#### Other

Expand Down
2 changes: 1 addition & 1 deletion tests/functional_test/test_serializers/test_cbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def complete_data_for_incremental_deserialize(complete_data: bytes) -> bytes:
def invalid_complete_data(complete_data: bytes) -> bytes:
return complete_data[:-1] # Missing data error

@pytest.fixture
@pytest.fixture(scope="class")
@staticmethod
def invalid_partial_data() -> bytes:
pytest.skip("Cannot be tested")
17 changes: 13 additions & 4 deletions tests/functional_test/test_serializers/test_compressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
]


def _make_data_invalid(token: bytes) -> bytes:
return token[:-2] + random.randbytes(5) + token[-2:]


class BaseTestCompressorSerializer(BaseTestIncrementalSerializer):
#### Serializers: To be defined in subclass

Expand Down Expand Up @@ -53,12 +57,17 @@ def complete_data_for_incremental_deserialize(complete_data: bytes) -> bytes:
@pytest.fixture(scope="class")
@staticmethod
def invalid_complete_data(complete_data: bytes) -> bytes:
return complete_data[:-1] # Remove one byte at last will break the checksum
return _make_data_invalid(complete_data)

@pytest.fixture(scope="class")
@staticmethod
def invalid_partial_data(invalid_complete_data: bytes) -> bytes:
return invalid_complete_data

@pytest.fixture
@pytest.fixture(scope="class")
@staticmethod
def invalid_partial_data() -> bytes:
pytest.skip("Cannot be tested")
def invalid_partial_data_expected_extra_data() -> bytes:
return b""


@final
Expand Down
29 changes: 24 additions & 5 deletions tests/functional_test/test_serializers/test_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ class TestEncryptorSerializer(BaseTestIncrementalSerializer):

KEY = generate_key_from_string("key")

BUFFER_LIMIT = 1024

@pytest.fixture(scope="class")
@classmethod
def serializer(cls) -> EncryptorSerializer[bytes]:
return EncryptorSerializer(NoSerialization(), key=cls.KEY)
return EncryptorSerializer(NoSerialization(), key=cls.KEY, limit=cls.BUFFER_LIMIT)

@pytest.fixture(scope="class")
@staticmethod
Expand Down Expand Up @@ -96,10 +98,27 @@ def invalid_complete_data(complete_data: bytes) -> bytes:
pytest.skip("empty bytes")
return complete_data[:-1] # Remove one byte at last will break the padding

@pytest.fixture
@staticmethod
def invalid_partial_data() -> bytes:
pytest.skip("Cannot be tested")
@pytest.fixture(scope="class", params=["missing_data", "limit_overrun_without_newline", "limit_overrun_with_newline"])
@classmethod
def invalid_partial_data(cls, request: pytest.FixtureRequest) -> bytes:
match request.param:
case "missing_data":
from cryptography.fernet import Fernet

return Fernet(cls.KEY).encrypt_at_time(b"a", 0)[:-1] + b"\r\n"
case "limit_overrun_without_newline":
return b"4" * (cls.BUFFER_LIMIT + 10)
case "limit_overrun_with_newline":
return b"4" * (cls.BUFFER_LIMIT + 10) + b"\r\n"
case _:
pytest.fail("Invalid fixture parameter")

@pytest.fixture(scope="class")
@classmethod
def invalid_partial_data_extra_data(cls, invalid_partial_data: bytes) -> bytes:
if len(invalid_partial_data) > cls.BUFFER_LIMIT:
return b""
return b"remaining_data"

#### Other

Expand Down
Loading

0 comments on commit 781739e

Please sign in to comment.