From e90c19319f912491910428b19439885fb4da1064 Mon Sep 17 00:00:00 2001 From: Nugraha Date: Sat, 13 Apr 2024 04:40:26 +0700 Subject: [PATCH] feat: preserve original proto field name roundtrip --- src/betterproto/__init__.py | 251 +++++++++++++----- src/betterproto/plugin/models.py | 6 + src/betterproto/plugin/parser.py | 14 +- .../casing_message_field_uppercase.py | 14 - .../test_casing_message_field_uppercase.py | 46 ++++ 5 files changed, 254 insertions(+), 77 deletions(-) delete mode 100644 tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py create mode 100644 tests/inputs/casing_message_field_uppercase/test_casing_message_field_uppercase.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 1466e7615..850a53a5f 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -181,6 +181,9 @@ class FieldMetadata: # Is the field optional optional: Optional[bool] = False + # Holding the original field name on proto file + name: Optional[string] = None + @staticmethod def get(field: dataclasses.Field) -> "FieldMetadata": """Returns the field metadata for a dataclass field.""" @@ -195,13 +198,20 @@ def dataclass_field( group: Optional[str] = None, wraps: Optional[str] = None, optional: bool = False, + name: Optional[str] = None, ) -> dataclasses.Field: """Creates a dataclass field with attached protobuf metadata.""" return dataclasses.field( default=None if optional else PLACEHOLDER, metadata={ "betterproto": FieldMetadata( - number, proto_type, map_types, group, wraps, optional + number, + proto_type, + map_types, + group, + wraps, + optional, + name, ) }, ) @@ -212,96 +222,192 @@ def dataclass_field( # out at runtime. The generated dataclass variables are still typed correctly. -def enum_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: - return dataclass_field(number, TYPE_ENUM, group=group, optional=optional) +def enum_field( + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, +) -> Any: + return dataclass_field(number, TYPE_ENUM, group=group, optional=optional, name=name) -def bool_field(number: int, group: Optional[str] = None, optional: bool = False) -> Any: - return dataclass_field(number, TYPE_BOOL, group=group, optional=optional) +def bool_field( + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, +) -> Any: + return dataclass_field(number, TYPE_BOOL, group=group, optional=optional, name=name) def int32_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_INT32, group=group, optional=optional) + return dataclass_field( + number, TYPE_INT32, group=group, optional=optional, name=name + ) def int64_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_INT64, group=group, optional=optional) + return dataclass_field( + number, TYPE_INT64, group=group, optional=optional, name=name + ) def uint32_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_UINT32, group=group, optional=optional) + return dataclass_field( + number, TYPE_UINT32, group=group, optional=optional, name=name + ) def uint64_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_UINT64, group=group, optional=optional) + return dataclass_field( + number, TYPE_UINT64, group=group, optional=optional, name=name + ) def sint32_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_SINT32, group=group, optional=optional) + return dataclass_field( + number, TYPE_SINT32, group=group, optional=optional, name=name + ) def sint64_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_SINT64, group=group, optional=optional) + return dataclass_field( + number, TYPE_SINT64, group=group, optional=optional, name=name + ) def float_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_FLOAT, group=group, optional=optional) + return dataclass_field( + number, TYPE_FLOAT, group=group, optional=optional, name=name + ) def double_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_DOUBLE, group=group, optional=optional) + return dataclass_field( + number, TYPE_DOUBLE, group=group, optional=optional, name=name + ) def fixed32_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_FIXED32, group=group, optional=optional) + return dataclass_field( + number, + TYPE_FIXED32, + group=group, + optional=optional, + name=name, + ) def fixed64_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_FIXED64, group=group, optional=optional) + return dataclass_field( + number, + TYPE_FIXED64, + group=group, + optional=optional, + name=name, + ) def sfixed32_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_SFIXED32, group=group, optional=optional) + return dataclass_field( + number, + TYPE_SFIXED32, + group=group, + optional=optional, + name=name, + ) def sfixed64_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_SFIXED64, group=group, optional=optional) + return dataclass_field( + number, + TYPE_SFIXED64, + group=group, + optional=optional, + name=name, + ) def string_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_STRING, group=group, optional=optional) + return dataclass_field( + number, TYPE_STRING, group=group, optional=optional, name=name + ) def bytes_field( - number: int, group: Optional[str] = None, optional: bool = False + number: int, + group: Optional[str] = None, + optional: bool = False, + name: Optional[str] = None, ) -> Any: - return dataclass_field(number, TYPE_BYTES, group=group, optional=optional) + return dataclass_field( + number, TYPE_BYTES, group=group, optional=optional, name=name + ) def message_field( @@ -309,17 +415,31 @@ def message_field( group: Optional[str] = None, wraps: Optional[str] = None, optional: bool = False, + name: Optional[str] = None, ) -> Any: return dataclass_field( - number, TYPE_MESSAGE, group=group, wraps=wraps, optional=optional + number, + TYPE_MESSAGE, + group=group, + wraps=wraps, + optional=optional, + name=name, ) def map_field( - number: int, key_type: str, value_type: str, group: Optional[str] = None + number: int, + key_type: str, + value_type: str, + group: Optional[str] = None, + name: Optional[str] = None, ) -> Any: return dataclass_field( - number, TYPE_MAP, map_types=(key_type, value_type), group=group + number, + TYPE_MAP, + map_types=(key_type, value_type), + group=group, + name=name, ) @@ -1384,14 +1504,19 @@ def FromString(cls: Type[T], data: bytes) -> T: return cls().parse(data) def to_dict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False + self, + casing: Optional[Casing] = Casing.CAMEL, + include_default_values: bool = False, ) -> Dict[str, Any]: """ Returns a JSON serializable dict representation of this object. Parameters ----------- - casing: :class:`Casing` + casing: Optional[:class:`Casing`] + If set to None, it will check field metadata to see if original proto field name + is preserved, if it is empty, it will use ``Casing.CAMEL`` + The casing to use for key values. Default is :attr:`Casing.CAMEL` for compatibility purposes. include_default_values: :class:`bool` @@ -1413,7 +1538,11 @@ def to_dict( value = getattr(self, field_name) except AttributeError: value = self._get_field_default(field_name) - cased_name = casing(field_name).rstrip("_") # type: ignore + + memb_key = meta.name or field_name + if casing: + memb_key = casing(field_name).rstrip("_") # type: ignore + if meta.proto_type == TYPE_MESSAGE: if isinstance(value, datetime): if ( @@ -1423,7 +1552,7 @@ def to_dict( field_name=field_name, meta=meta ) ): - output[cased_name] = _Timestamp.timestamp_to_json(value) + output[memb_key] = _Timestamp.timestamp_to_json(value) elif isinstance(value, timedelta): if ( value != timedelta(0) @@ -1432,10 +1561,10 @@ def to_dict( field_name=field_name, meta=meta ) ): - output[cased_name] = _Duration.delta_to_json(value) + output[memb_key] = _Duration.delta_to_json(value) elif meta.wraps: if value is not None or include_default_values: - output[cased_name] = value + output[memb_key] = value elif field_is_repeated: # Convert each item. cls = self._betterproto.cls_by_field[field_name] @@ -1450,10 +1579,10 @@ def to_dict( i.to_dict(casing, include_default_values) for i in value ] if value or include_default_values: - output[cased_name] = value + output[memb_key] = value elif value is None: if include_default_values: - output[cased_name] = value + output[memb_key] = value elif ( value._serialized_on_wire or include_default_values @@ -1461,7 +1590,7 @@ def to_dict( field_name=field_name, meta=meta ) ): - output[cased_name] = value.to_dict(casing, include_default_values) + output[memb_key] = value.to_dict(casing, include_default_values) elif meta.proto_type == TYPE_MAP: output_map = {**value} for k in value: @@ -1469,7 +1598,7 @@ def to_dict( output_map[k] = value[k].to_dict(casing, include_default_values) if value or include_default_values: - output[cased_name] = output_map + output[memb_key] = output_map elif ( value != self._get_field_default(field_name) or include_default_values @@ -1479,47 +1608,45 @@ def to_dict( ): if meta.proto_type in INT_64_TYPES: if field_is_repeated: - output[cased_name] = [str(n) for n in value] + output[memb_key] = [str(n) for n in value] elif value is None: if include_default_values: - output[cased_name] = value + output[memb_key] = value else: - output[cased_name] = str(value) + output[memb_key] = str(value) elif meta.proto_type == TYPE_BYTES: if field_is_repeated: - output[cased_name] = [ - b64encode(b).decode("utf8") for b in value - ] + output[memb_key] = [b64encode(b).decode("utf8") for b in value] elif value is None and include_default_values: - output[cased_name] = value + output[memb_key] = value else: - output[cased_name] = b64encode(value).decode("utf8") + output[memb_key] = b64encode(value).decode("utf8") elif meta.proto_type == TYPE_ENUM: if field_is_repeated: enum_class = field_types[field_name].__args__[0] if isinstance(value, typing.Iterable) and not isinstance( value, str ): - output[cased_name] = [enum_class(el).name for el in value] + output[memb_key] = [enum_class(el).name for el in value] else: # transparently upgrade single value to repeated - output[cased_name] = [enum_class(value).name] + output[memb_key] = [enum_class(value).name] elif value is None: if include_default_values: - output[cased_name] = value + output[memb_key] = value elif meta.optional: enum_class = field_types[field_name].__args__[0] - output[cased_name] = enum_class(value).name + output[memb_key] = enum_class(value).name else: enum_class = field_types[field_name] # noqa - output[cased_name] = enum_class(value).name + output[memb_key] = enum_class(value).name elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): if field_is_repeated: - output[cased_name] = [_dump_float(n) for n in value] + output[memb_key] = [_dump_float(n) for n in value] else: - output[cased_name] = _dump_float(value) + output[memb_key] = _dump_float(value) else: - output[cased_name] = value + output[memb_key] = value return output @classmethod diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index defacde5a..5fb6a93e9 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -257,6 +257,7 @@ class OutputTemplate: imports_type_checking_only: Set[str] = field(default_factory=set) pydantic_dataclasses: bool = False use_optionals: Optional[Literal["all"]] = None + include_original_field_name: bool = True output: bool = True @property @@ -428,6 +429,11 @@ def betterproto_field_args(self) -> List[str]: args.append(f"wraps={self.field_wraps}") if self.optional: args.append(f"optional=True") + if ( + self.proto_name != self.py_name + and self.output_file.include_original_field_name + ): + args.append(f'name="{self.proto_name}"') return args @property diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index a23082ae2..fddc00e5b 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -98,9 +98,21 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: output_package_name ].pydantic_dataclasses = True - if "useOptionals=all" in plugin_options: + if ( + "useOptionals=all" in plugin_options + or "use_optionals=all" in plugin_options + ): request_data.output_packages[output_package_name].use_optionals = "all" + if ( + "include_original_field_name=false" in plugin_options + or "include_original_field_name=0" in plugin_options + or "include_original_field_name=off" in plugin_options + ): + request_data.output_packages[ + output_package_name + ].include_original_field_name = False + # Read Messages and Enums # We need to read Messages before Services in so that we can # get the references to input/output messages for each service diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py deleted file mode 100644 index 2b32b5308..000000000 --- a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py +++ /dev/null @@ -1,14 +0,0 @@ -from tests.output_betterproto.casing_message_field_uppercase import Test - - -def test_message_casing(): - message = Test() - assert hasattr( - message, "uppercase" - ), "UPPERCASE attribute is converted to 'uppercase' in python" - assert hasattr( - message, "uppercase_v2" - ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" - assert hasattr( - message, "upper_camel_case" - ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python" diff --git a/tests/inputs/casing_message_field_uppercase/test_casing_message_field_uppercase.py b/tests/inputs/casing_message_field_uppercase/test_casing_message_field_uppercase.py new file mode 100644 index 000000000..4e06b09ab --- /dev/null +++ b/tests/inputs/casing_message_field_uppercase/test_casing_message_field_uppercase.py @@ -0,0 +1,46 @@ +from betterproto import Casing +from tests.output_betterproto.casing_message_field_uppercase import Test + + +def test_message_casing(): + message = Test() + + assert hasattr( + message, "uppercase" + ), "UPPERCASE attribute is converted to 'uppercase' in python" + assert hasattr( + message, "uppercase_v2" + ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" + assert hasattr( + message, "upper_camel_case" + ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python" + + +def test_message_casing_roundtrip(): + snake_case_dict = { + "uppercase": 1, + "uppercase_v2": 2, + "upper_camel_case": 3, + } + original_case_dict = { + "UPPERCASE": 1, + "UPPERCASE_V2": 2, + "UPPER_CAMEL_CASE": 3, + } + camel_case_dict = { + "uppercase": 1, + "uppercaseV2": 2, + "upperCamelCase": 3, + } + + def compare_expected(message: Test): + message_dict = message.to_dict(casing=None) + assert message_dict == original_case_dict, message_dict + message_dict = message.to_dict(casing=Casing.CAMEL) + assert message_dict == camel_case_dict, message_dict + message_dict = message.to_dict(casing=Casing.SNAKE) + assert message_dict == snake_case_dict, message_dict + + compare_expected(Test.from_dict(snake_case_dict)) + compare_expected(Test.from_dict(original_case_dict)) + compare_expected(Test.from_dict(camel_case_dict))