diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 99a578478..f02872701 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -17,14 +17,16 @@ timedelta, timezone, ) +from io import BytesIO +from itertools import count from typing import ( TYPE_CHECKING, Any, + BinaryIO, Callable, Dict, Generator, Iterable, - List, Mapping, Optional, Set, @@ -46,6 +48,10 @@ from .grpc.grpclib_client import ServiceStub +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + + # Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" @@ -66,7 +72,6 @@ TYPE_MESSAGE = "message" TYPE_MAP = "map" - # Fields that use a fixed amount of space (4 or 8 bytes) FIXED_TYPES = [ TYPE_FLOAT, @@ -129,7 +134,6 @@ def datetime_default_gen() -> datetime: DATETIME_ZERO = datetime_default_gen() - # Special protobuf json doubles INFINITY = "Infinity" NEG_INFINITY = "-Infinity" @@ -343,20 +347,43 @@ def _pack_fmt(proto_type: str) -> str: }[proto_type] -def encode_varint(value: int) -> bytes: - """Encodes a single varint value for serialization.""" - b: List[int] = [] - - if value < 0: +def dump_varint(value: int, stream: BinaryIO) -> None: + """Encodes a single varint and dumps it into the provided stream.""" + if value < -(1 << 63): + raise ValueError( + "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes." + ) + elif value < 0: value += 1 << 64 bits = value & 0x7F value >>= 7 while value: - b.append(0x80 | bits) + stream.write((0x80 | bits).to_bytes(1, "little")) bits = value & 0x7F value >>= 7 - return bytes(b + [bits]) + stream.write(bits.to_bytes(1, "little")) + + +def encode_varint(value: int) -> bytes: + """Encodes a single varint value for serialization.""" + with BytesIO() as stream: + dump_varint(value, stream) + return stream.getvalue() + + +def size_varint(value: int) -> int: + """Calculates the size in bytes that a value would take as a varint.""" + if value < -(1 << 63): + raise ValueError( + "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes." + ) + elif value < 0: + return 10 + elif value == 0: + return 1 + else: + return math.ceil(value.bit_length() / 7) def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: @@ -394,6 +421,41 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: return value +def _len_preprocessed_single(proto_type: str, wraps: str, value: Any) -> int: + """Calculate the size of adjusted values for serialization without fully serializing them.""" + if proto_type in ( + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + ): + return size_varint(value) + elif proto_type in (TYPE_SINT32, TYPE_SINT64): + # Handle zig-zag encoding. + return size_varint(value << 1 if value >= 0 else (value << 1) ^ (~0)) + elif proto_type in FIXED_TYPES: + return len(struct.pack(_pack_fmt(proto_type), value)) + elif proto_type == TYPE_STRING: + return len(value.encode("utf-8")) + elif proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + # Convert the `datetime` to a timestamp message. + value = _Timestamp.from_datetime(value) + elif isinstance(value, timedelta): + # Convert the `timedelta` to a duration message. + value = _Duration.from_timedelta(value) + elif wraps: + if value is None: + return 0 + value = _get_wrapper(wraps)(value=value) + + return len(bytes(value)) + + return len(value) + + def _serialize_single( field_number: int, proto_type: str, @@ -425,6 +487,31 @@ def _serialize_single( return bytes(output) +def _len_single( + field_number: int, + proto_type: str, + value: Any, + *, + serialize_empty: bool = False, + wraps: str = "", +) -> int: + """Calculates the size of a serialized single field and value.""" + size = _len_preprocessed_single(proto_type, wraps, value) + if proto_type in WIRE_VARINT_TYPES: + size += size_varint(field_number << 3) + elif proto_type in WIRE_FIXED_32_TYPES: + size += size_varint((field_number << 3) | 5) + elif proto_type in WIRE_FIXED_64_TYPES: + size += size_varint((field_number << 3) | 1) + elif proto_type in WIRE_LEN_DELIM_TYPES: + if size or serialize_empty or wraps: + size += size_varint((field_number << 3) | 2) + size_varint(size) + else: + raise NotImplementedError(proto_type) + + return size + + def _parse_float(value: Any) -> float: """Parse the given value to a float @@ -469,22 +556,34 @@ def _dump_float(value: float) -> Union[float, str]: return value -def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]: +def load_varint(stream: BinaryIO) -> Tuple[int, bytes]: """ - Decode a single varint value from a byte buffer. Returns the value and the - new position in the buffer. + Load a single varint value from a stream. Returns the value and the raw bytes read. """ result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= (b & 0x7F) << shift - pos += 1 - if not (b & 0x80): - return result, pos - shift += 7 + raw = b"" + for shift in count(0, 7): if shift >= 64: raise ValueError("Too many bytes when decoding varint.") + b = stream.read(1) + if not b: + raise EOFError("Stream ended unexpectedly while attempting to load varint.") + raw += b + b_int = int.from_bytes(b, byteorder="little") + result |= (b_int & 0x7F) << shift + if not (b_int & 0x80): + return result, raw + + +def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]: + """ + Decode a single varint value from a byte buffer. Returns the value and the + new position in the buffer. + """ + with BytesIO(buffer) as stream: + stream.seek(pos) + value, raw = load_varint(stream) + return value, pos + len(raw) @dataclasses.dataclass(frozen=True) @@ -495,6 +594,34 @@ class ParsedField: raw: bytes +def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]: + while True: + try: + num_wire, raw = load_varint(stream) + except EOFError: + return + number = num_wire >> 3 + wire_type = num_wire & 0x7 + + decoded: Any = None + if wire_type == WIRE_VARINT: + decoded, r = load_varint(stream) + raw += r + elif wire_type == WIRE_FIXED_64: + decoded = stream.read(8) + raw += decoded + elif wire_type == WIRE_LEN_DELIM: + length, r = load_varint(stream) + decoded = stream.read(length) + raw += r + raw += decoded + elif wire_type == WIRE_FIXED_32: + decoded = stream.read(4) + raw += decoded + + yield ParsedField(number=number, wire_type=wire_type, value=decoded, raw=raw) + + def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: i = 0 while i < len(value): @@ -775,11 +902,16 @@ def _betterproto(self) -> ProtoClassMetadata: self.__class__._betterproto_meta = meta # type: ignore return meta - def __bytes__(self) -> bytes: + def dump(self, stream: BinaryIO) -> None: """ - Get the binary encoded Protobuf representation of this message instance. + Dumps the binary encoded Protobuf message to the stream. + + Parameters + ----------- + stream: :class:`BinaryIO` + The stream to dump the message to. """ - output = bytearray() + for field_name, meta in self._betterproto.meta_by_field_name.items(): try: value = getattr(self, field_name) @@ -825,10 +957,10 @@ def __bytes__(self) -> bytes: buf = bytearray() for item in value: buf += _preprocess_single(meta.proto_type, "", item) - output += _serialize_single(meta.number, TYPE_BYTES, buf) + stream.write(_serialize_single(meta.number, TYPE_BYTES, buf)) else: for item in value: - output += ( + stream.write( _serialize_single( meta.number, meta.proto_type, @@ -846,7 +978,9 @@ def __bytes__(self) -> bytes: assert meta.map_types sk = _serialize_single(1, meta.map_types[0], k) sv = _serialize_single(2, meta.map_types[1], v) - output += _serialize_single(meta.number, meta.proto_type, sk + sv) + stream.write( + _serialize_single(meta.number, meta.proto_type, sk + sv) + ) else: # If we have an empty string and we're including the default value for # a oneof, make sure we serialize it. This ensures that the byte string @@ -859,7 +993,111 @@ def __bytes__(self) -> bytes: ): serialize_empty = True - output += _serialize_single( + stream.write( + _serialize_single( + meta.number, + meta.proto_type, + value, + serialize_empty=serialize_empty or bool(selected_in_group), + wraps=meta.wraps or "", + ) + ) + + stream.write(self._unknown_fields) + + def __bytes__(self) -> bytes: + """ + Get the binary encoded Protobuf representation of this message instance. + """ + with BytesIO() as stream: + self.dump(stream) + return stream.getvalue() + + def __len__(self) -> int: + """ + Get the size of the encoded Protobuf representation of this message instance. + """ + size = 0 + for field_name, meta in self._betterproto.meta_by_field_name.items(): + try: + value = getattr(self, field_name) + except AttributeError: + continue + + if value is None: + # Optional items should be skipped. This is used for the Google + # wrapper types and proto3 field presence/optional fields. + continue + + # Being selected in a group means this field is the one that is + # currently set in a `oneof` group, so it must be serialized even + # if the value is the default zero value. + # + # Note that proto3 field presence/optional fields are put in a + # synthetic single-item oneof by protoc, which helps us ensure we + # send the value even if the value is the default zero value. + selected_in_group = bool(meta.group) + + # Empty messages can still be sent on the wire if they were + # set (or received empty). + serialize_empty = isinstance(value, Message) and value._serialized_on_wire + + include_default_value_for_oneof = self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + + if value == self._get_field_default(field_name) and not ( + selected_in_group or serialize_empty or include_default_value_for_oneof + ): + # Default (zero) values are not serialized. Two exceptions are + # if this is the selected oneof item or if we know we have to + # serialize an empty message (i.e. zero value was explicitly + # set by the user). + continue + + if isinstance(value, list): + if meta.proto_type in PACKED_TYPES: + # Packed lists look like a length-delimited field. First, + # preprocess/encode each value into a buffer and then + # treat it like a field of raw bytes. + buf = bytearray() + for item in value: + buf += _preprocess_single(meta.proto_type, "", item) + size += _len_single(meta.number, TYPE_BYTES, buf) + else: + for item in value: + size += ( + _len_single( + meta.number, + meta.proto_type, + item, + wraps=meta.wraps or "", + serialize_empty=True, + ) + # if it's an empty message it still needs to be represented + # as an item in the repeated list + or 2 + ) + + elif isinstance(value, dict): + for k, v in value.items(): + assert meta.map_types + sk = _serialize_single(1, meta.map_types[0], k) + sv = _serialize_single(2, meta.map_types[1], v) + size += _len_single(meta.number, meta.proto_type, sk + sv) + else: + # If we have an empty string and we're including the default value for + # a oneof, make sure we serialize it. This ensures that the byte string + # output isn't simply an empty string. This also ensures that round trip + # serialization will keep `which_one_of` calls consistent. + if ( + isinstance(value, str) + and value == "" + and include_default_value_for_oneof + ): + serialize_empty = True + + size += _len_single( meta.number, meta.proto_type, value, @@ -867,8 +1105,8 @@ def __bytes__(self) -> bytes: wraps=meta.wraps or "", ) - output += self._unknown_fields - return bytes(output) + size += len(self._unknown_fields) + return size # For compatibility with other libraries def SerializeToString(self: T) -> bytes: @@ -987,15 +1225,18 @@ def _include_default_value_for_oneof( meta.group is not None and self._group_current.get(meta.group) == field_name ) - def parse(self: T, data: bytes) -> T: + def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T: """ - Parse the binary encoded Protobuf into this message instance. This + Load the binary encoded Protobuf from a stream into this message instance. This returns the instance itself and is therefore assignable and chainable. Parameters ----------- - data: :class:`bytes` - The data to parse the protobuf from. + stream: :class:`bytes` + The stream to load the message from. + size: :class:`Optional[int]` + The size of the message in the stream. + Reads stream until EOF if ``None`` is given. Returns -------- @@ -1005,7 +1246,8 @@ def parse(self: T, data: bytes) -> T: # Got some data over the wire self._serialized_on_wire = True proto_meta = self._betterproto - for parsed in parse_fields(data): + read = 0 + for parsed in load_fields(stream): field_name = proto_meta.field_name_by_number.get(parsed.number) if not field_name: self._unknown_fields += parsed.raw @@ -1051,8 +1293,46 @@ def parse(self: T, data: bytes) -> T: else: setattr(self, field_name, value) + # If we have now loaded the expected length of the message, stop + if size is not None: + prev = read + read += len(parsed.raw) + if read == size: + break + elif read > size: + raise ValueError( + f"Expected message of size {size}, can only read " + f"either {prev} or {read} bytes - there is no " + "message of the expected size in the stream." + ) + + if size is not None and read < size: + raise ValueError( + f"Expected message of size {size}, but was only able to " + f"read {read} bytes - the stream may have ended too soon," + " or the expected size may have been incorrect." + ) + return self + def parse(self: T, data: "ReadableBuffer") -> T: + """ + Parse the binary encoded Protobuf into this message instance. This + returns the instance itself and is therefore assignable and chainable. + + Parameters + ----------- + data: :class:`bytes` + The data to parse the message from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + with BytesIO(data) as stream: + return self.load(stream) + # For compatibility with other libraries. @classmethod def FromString(cls: Type[T], data: bytes) -> T: diff --git a/tests/streams/dump_varint_negative.expected b/tests/streams/dump_varint_negative.expected new file mode 100644 index 000000000..095482297 --- /dev/null +++ b/tests/streams/dump_varint_negative.expected @@ -0,0 +1 @@ +ӝ \ No newline at end of file diff --git a/tests/streams/dump_varint_positive.expected b/tests/streams/dump_varint_positive.expected new file mode 100644 index 000000000..8614b9d7a --- /dev/null +++ b/tests/streams/dump_varint_positive.expected @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tests/streams/load_varint_cutoff.in b/tests/streams/load_varint_cutoff.in new file mode 100644 index 000000000..52b9bf1e1 --- /dev/null +++ b/tests/streams/load_varint_cutoff.in @@ -0,0 +1 @@ +ȁ \ No newline at end of file diff --git a/tests/streams/message_dump_file_multiple.expected b/tests/streams/message_dump_file_multiple.expected new file mode 100644 index 000000000..b5fdf9c30 --- /dev/null +++ b/tests/streams/message_dump_file_multiple.expected @@ -0,0 +1,2 @@ +:bTesting:bTesting +  \ No newline at end of file diff --git a/tests/streams/message_dump_file_single.expected b/tests/streams/message_dump_file_single.expected new file mode 100644 index 000000000..9b7bafb6a --- /dev/null +++ b/tests/streams/message_dump_file_single.expected @@ -0,0 +1 @@ +:bTesting \ No newline at end of file diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 000000000..a1c2bbd98 --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,268 @@ +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Optional + +import pytest + +import betterproto +from tests.output_betterproto import ( + map, + nested, + oneof, + repeated, + repeatedpacked, +) + + +oneof_example = oneof.Test().from_dict( + {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"} +) + +len_oneof = len(oneof_example) + +nested_example = nested.Test().from_dict( + { + "nested": {"count": 1}, + "sibling": {"foo": 2}, + "sibling2": {"foo": 3}, + "msg": nested.TestMsg.THIS, + } +) + +repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]}) + +packed_example = repeatedpacked.Test().from_dict( + {"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]} +) + +map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}}) + +streams_path = Path("tests/streams/") + + +def test_load_varint_too_long(): + with BytesIO( + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01" + ) as stream, pytest.raises(ValueError): + betterproto.load_varint(stream) + + with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream: + # This should not raise a ValueError, as it is within 64 bits + betterproto.load_varint(stream) + + +def test_load_varint_file(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert betterproto.load_varint(stream) == (8, b"\x08") # Single-byte varint + stream.read(2) # Skip until first multi-byte + assert betterproto.load_varint(stream) == ( + 123456789, + b"\x95\x9A\xEF\x3A", + ) # Multi-byte varint + + +def test_load_varint_cutoff(): + with open(streams_path / "load_varint_cutoff.in", "rb") as stream: + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + stream.seek(1) + with pytest.raises(EOFError): + betterproto.load_varint(stream) + + +def test_dump_varint_file(tmp_path): + # Dump test varints to file + with open(tmp_path / "dump_varint_file.out", "wb") as stream: + betterproto.dump_varint(8, stream) # Single-byte varint + betterproto.dump_varint(123456789, stream) # Multi-byte varint + + # Check that file contents are as expected + with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert betterproto.load_varint(test_stream) == betterproto.load_varint( + exp_stream + ) + exp_stream.read(2) + assert betterproto.load_varint(test_stream) == betterproto.load_varint( + exp_stream + ) + + +def test_parse_fields(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_bytes = betterproto.parse_fields(stream.read()) + + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_stream = betterproto.load_fields(stream) + for field in parsed_bytes: + assert field == next(parsed_stream) + + +def test_message_dump_file_single(tmp_path): + # Write the message to the stream + with open(tmp_path / "message_dump_file_single.out", "wb") as stream: + oneof_example.dump(stream) + + # Check that the outputted file is exactly as expected + with open(tmp_path / "message_dump_file_single.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_dump_file_multiple(tmp_path): + # Write the same Message twice and another, different message + with open(tmp_path / "message_dump_file_multiple.out", "wb") as stream: + oneof_example.dump(stream) + oneof_example.dump(stream) + nested_example.dump(stream) + + # Check that all three Messages were outputted to the file correctly + with open(tmp_path / "message_dump_file_multiple.out", "rb") as test_stream, open( + streams_path / "message_dump_file_multiple.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_len(): + assert len_oneof == len(bytes(oneof_example)) + assert len(nested_example) == len(bytes(nested_example)) + + +def test_message_load_file_single(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert oneof.Test().load(stream) == oneof_example + stream.seek(0) + assert oneof.Test().load(stream, len_oneof) == oneof_example + + +def test_message_load_file_multiple(): + with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream: + oneof_size = len_oneof + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert nested.Test().load(stream) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_small(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof - 1) + + +def test_message_too_large(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof + 1) + + +def test_message_len_optional_field(): + @dataclass + class Request(betterproto.Message): + flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) + + assert len(Request()) == len(b"") + assert len(Request(flag=True)) == len(b"\n\x02\x08\x01") + assert len(Request(flag=False)) == len(b"\n\x00") + + +def test_message_len_repeated_field(): + assert len(repeated_example) == len(bytes(repeated_example)) + + +def test_message_len_packed_field(): + assert len(packed_example) == len(bytes(packed_example)) + + +def test_message_len_map_field(): + assert len(map_example) == len(bytes(map_example)) + + +def test_message_len_empty_string(): + @dataclass + class Empty(betterproto.Message): + string: str = betterproto.string_field(1, "group") + integer: int = betterproto.int32_field(2, "group") + + empty = Empty().from_dict({"string": ""}) + assert len(empty) == len(bytes(empty)) + + +def test_calculate_varint_size_negative(): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + assert ( + betterproto.size_varint(single_byte) + == len(betterproto.encode_varint(single_byte)) + == 10 + ) + assert ( + betterproto.size_varint(multi_byte) + == len(betterproto.encode_varint(multi_byte)) + == 10 + ) + assert betterproto.size_varint(edge) == len(betterproto.encode_varint(edge)) == 10 + assert ( + betterproto.size_varint(before) == len(betterproto.encode_varint(before)) == 10 + ) + + with pytest.raises(ValueError): + betterproto.size_varint(beyond) + + +def test_calculate_varint_size_positive(): + single_byte = 1 + multi_byte = 10000000 + + assert betterproto.size_varint(single_byte) == len( + betterproto.encode_varint(single_byte) + ) + assert betterproto.size_varint(multi_byte) == len( + betterproto.encode_varint(multi_byte) + ) + + +def test_dump_varint_negative(tmp_path): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + with open(tmp_path / "dump_varint_negative.out", "wb") as stream: + betterproto.dump_varint(single_byte, stream) + betterproto.dump_varint(multi_byte, stream) + betterproto.dump_varint(edge, stream) + betterproto.dump_varint(before, stream) + + with pytest.raises(ValueError): + betterproto.dump_varint(beyond, stream) + + with open(streams_path / "dump_varint_negative.expected", "rb") as exp_stream, open( + tmp_path / "dump_varint_negative.out", "rb" + ) as test_stream: + assert test_stream.read() == exp_stream.read() + + +def test_dump_varint_positive(tmp_path): + single_byte = 1 + multi_byte = 10000000 + + with open(tmp_path / "dump_varint_positive.out", "wb") as stream: + betterproto.dump_varint(single_byte, stream) + betterproto.dump_varint(multi_byte, stream) + + with open(tmp_path / "dump_varint_positive.out", "rb") as test_stream, open( + streams_path / "dump_varint_positive.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read()