From 6faac1d1ca4b979704c145d968dfed8e38fe159c Mon Sep 17 00:00:00 2001 From: Alexander Khabarov Date: Fri, 21 Jul 2023 13:26:30 +0100 Subject: [PATCH] Raise `AttributeError` on attempts to access unset `oneof` fields (#510) --- .pre-commit-config.yaml | 3 +- poetry.lock | 2 +- pyproject.toml | 4 +- src/betterproto/__init__.py | 46 +++++++++++++++---- src/betterproto/grpc/grpclib_server.py | 1 - src/betterproto/plugin/compiler.py | 1 - src/betterproto/plugin/parser.py | 1 - .../test_google_impl_behavior_equivalence.py | 6 ++- tests/inputs/oneof_enum/test_oneof_enum.py | 13 +++--- tests/oneof_pattern_matching.py | 46 +++++++++++++++++++ tests/test_features.py | 22 +++++++-- 11 files changed, 116 insertions(+), 29 deletions(-) create mode 100644 tests/oneof_pattern_matching.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26087961f..fa461c82c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,9 +8,10 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.1.0 hooks: - id: black + args: ["--target-version", "py310"] - repo: https://github.com/PyCQA/doc8 rev: 0.10.1 diff --git a/poetry.lock b/poetry.lock index 6231cf93c..00d196612 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1858,4 +1858,4 @@ compiler = ["black", "isort", "jinja2"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "62d298634665ebd06f69ec8ea543c3d7720184ec9d833c32575de8d965332aec" +content-hash = "8f733a72705d31633a7f198a7a7dd6e3170876a1ccb8ca75b7d94b6379384a8f" diff --git a/pyproject.toml b/pyproject.toml index 6487726e8..4c43d7806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.7" -black = { version = ">=19.3b0", optional = true } +black = { version = ">=23.1.0", optional = true } grpclib = "^0.4.1" importlib-metadata = { version = ">=1.6.0", python = "<3.8" } jinja2 = { version = ">=3.0.3", optional = true } @@ -62,7 +62,7 @@ cmd = "mypy src --ignore-missing-imports" help = "Check types with mypy" [tool.poe.tasks.format] -cmd = "black . --exclude tests/output_" +cmd = "black . --exclude tests/output_ --target-version py310" help = "Apply black formatting to source code" [tool.poe.tasks.docs] diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f22a8f7cb..be06c3e4b 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -693,8 +693,28 @@ def __repr__(self) -> str: def __getattribute__(self, name: str) -> Any: """ Lazily initialize default values to avoid infinite recursion for recursive - message types + message types. + Raise :class:`AttributeError` on attempts to access unset ``oneof`` fields. """ + try: + group_current = super().__getattribute__("_group_current") + except AttributeError: + pass + else: + if name not in {"__class__", "_betterproto"}: + group = self._betterproto.oneof_group_by_field.get(name) + if group is not None and group_current[group] != name: + if sys.version_info < (3, 10): + raise AttributeError( + f"{group!r} is set to {group_current[group]!r}, not {name!r}" + ) + else: + raise AttributeError( + f"{group!r} is set to {group_current[group]!r}, not {name!r}", + name=name, + obj=self, + ) + value = super().__getattribute__(name) if value is not PLACEHOLDER: return value @@ -761,7 +781,10 @@ def __bytes__(self) -> bytes: """ output = bytearray() for field_name, meta in self._betterproto.meta_by_field_name.items(): - value = getattr(self, field_name) + try: + value = getattr(self, field_name) + except AttributeError: + continue if value is None: # Optional items should be skipped. This is used for the Google @@ -775,9 +798,7 @@ def __bytes__(self) -> bytes: # Note that proto3 field presence/optional fields are put in a # synthetic single-item oneof by protoc, which helps us ensure we # send the value even if the value is the default zero value. - selected_in_group = ( - meta.group and self._group_current[meta.group] == field_name - ) + selected_in_group = bool(meta.group) # Empty messages can still be sent on the wire if they were # set (or received empty). @@ -1016,7 +1037,12 @@ def parse(self: T, data: bytes) -> T: parsed.wire_type, meta, field_name, parsed.value ) - current = getattr(self, field_name) + try: + current = getattr(self, field_name) + except AttributeError: + current = self._get_field_default(field_name) + setattr(self, field_name, current) + if meta.proto_type == TYPE_MAP: # Value represents a single key/value pair entry in the map. current[value.key] = value.value @@ -1077,7 +1103,10 @@ def to_dict( defaults = self._betterproto.default_gen for field_name, meta in self._betterproto.meta_by_field_name.items(): field_is_repeated = defaults[field_name] is list - value = getattr(self, field_name) + try: + value = getattr(self, field_name) + except AttributeError: + value = self._get_field_default(field_name) cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == TYPE_MESSAGE: if isinstance(value, datetime): @@ -1209,7 +1238,7 @@ def from_dict(self: T, value: Mapping[str, Any]) -> T: if value[key] is not None: if meta.proto_type == TYPE_MESSAGE: - v = getattr(self, field_name) + v = self._get_field_default(field_name) cls = self._betterproto.cls_by_field[field_name] if isinstance(v, list): if cls == datetime: @@ -1486,7 +1515,6 @@ def _validate_field_groups(cls, values): field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore for group, field_set in group_to_one_ofs.items(): - if len(field_set) == 1: (field,) = field_set field_name = field.name diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py index 5c1f93452..3e2803113 100644 --- a/src/betterproto/grpc/grpclib_server.py +++ b/src/betterproto/grpc/grpclib_server.py @@ -21,7 +21,6 @@ async def _call_rpc_handler_server_stream( stream: grpclib.server.Stream, request: Any, ) -> None: - response_iter = handler(request) # check if response is actually an AsyncIterator # this might be false if the method just returns without diff --git a/src/betterproto/plugin/compiler.py b/src/betterproto/plugin/compiler.py index 7542432b2..510d64857 100644 --- a/src/betterproto/plugin/compiler.py +++ b/src/betterproto/plugin/compiler.py @@ -21,7 +21,6 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: - templates_folder = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "templates") ) diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 358cc20e4..f48533338 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -159,7 +159,6 @@ def _make_one_of_field_compiler( proto_obj: "FieldDescriptorProto", path: List[int], ) -> FieldCompiler: - pydantic = output_package.pydantic_dataclasses Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler return Cls( diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py index 476d20e3b..dd2a9f53e 100644 --- a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py +++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -50,8 +50,10 @@ def test_bytes_are_the_same_for_oneof(): # None of these fields were explicitly set BUT they should not actually be null # themselves - assert isinstance(message.foo, Foo) - assert isinstance(message2.foo, Foo) + assert not hasattr(message, "foo") + assert object.__getattribute__(message, "foo") == betterproto.PLACEHOLDER + assert not hasattr(message2, "foo") + assert object.__getattribute__(message2, "foo") == betterproto.PLACEHOLDER assert isinstance(message_reference.foo, ReferenceFoo) assert isinstance(message_reference2.foo, ReferenceFoo) diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 7e287d4a4..e54fa3859 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -18,9 +18,8 @@ def test_which_one_of_returns_enum_with_default_value(): get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json ) - assert message.move == Move( - x=0, y=0 - ) # Proto3 will default this as there is no null + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) @@ -33,9 +32,8 @@ def test_which_one_of_returns_enum_with_non_default_value(): message.from_json( get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json ) - assert message.move == Move( - x=0, y=0 - ) # Proto3 will default this as there is no null + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == betterproto.PLACEHOLDER assert message.signal == Signal.RESIGN assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) @@ -44,5 +42,6 @@ def test_which_one_of_returns_second_field_when_set(): message = Test() message.from_json(get_test_case_json_data("oneof_enum")[0].json) assert message.move == Move(x=2, y=3) - assert message.signal == Signal.PASS + assert not hasattr(message, "signal") + assert object.__getattribute__(message, "signal") == betterproto.PLACEHOLDER assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/oneof_pattern_matching.py b/tests/oneof_pattern_matching.py new file mode 100644 index 000000000..d4f18aab2 --- /dev/null +++ b/tests/oneof_pattern_matching.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import pytest + +import betterproto + + +def test_oneof_pattern_matching(): + @dataclass + class Sub(betterproto.Message): + val: int = betterproto.int32_field(1) + + @dataclass + class Foo(betterproto.Message): + bar: int = betterproto.int32_field(1, group="group1") + baz: str = betterproto.string_field(2, group="group1") + sub: Sub = betterproto.message_field(3, group="group2") + abc: str = betterproto.string_field(4, group="group2") + + foo = Foo(baz="test1", abc="test2") + + match foo: + case Foo(bar=_): + pytest.fail("Matched 'bar' instead of 'baz'") + case Foo(baz=v): + assert v == "test1" + case _: + pytest.fail("Matched neither 'bar' nor 'baz'") + + match foo: + case Foo(sub=_): + pytest.fail("Matched 'sub' instead of 'abc'") + case Foo(abc=v): + assert v == "test2" + case _: + pytest.fail("Matched neither 'sub' nor 'abc'") + + foo.sub = Sub(val=1) + + match foo: + case Foo(sub=Sub(val=v)): + assert v == 1 + case Foo(abc=v): + pytest.fail("Matched 'abc' instead of 'sub'") + case _: + pytest.fail("Matched neither 'sub' nor 'abc'") diff --git a/tests/test_features.py b/tests/test_features.py index 940cd51c8..322a310f4 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,4 +1,5 @@ import json +import sys from copy import ( copy, deepcopy, @@ -18,6 +19,8 @@ Optional, ) +import pytest + import betterproto @@ -151,17 +154,18 @@ class Foo(betterproto.Message): foo.baz = "test" # Other oneof fields should now be unset - assert foo.bar == 0 + assert not hasattr(foo, "bar") + assert object.__getattribute__(foo, "bar") == betterproto.PLACEHOLDER assert betterproto.which_one_of(foo, "group1")[0] == "baz" - foo.sub.val = 1 + foo.sub = Sub(val=1) assert betterproto.serialized_on_wire(foo.sub) foo.abc = "test" # Group 1 shouldn't be touched, group 2 should have reset - assert foo.sub.val == 0 - assert betterproto.serialized_on_wire(foo.sub) is False + assert not hasattr(foo, "sub") + assert object.__getattribute__(foo, "sub") == betterproto.PLACEHOLDER assert betterproto.which_one_of(foo, "group2")[0] == "abc" # Zero value should always serialize for one-of @@ -176,6 +180,16 @@ class Foo(betterproto.Message): assert betterproto.which_one_of(foo2, "group2")[0] == "" +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="pattern matching is only supported in python3.10+", +) +def test_oneof_pattern_matching(): + from .oneof_pattern_matching import test_oneof_pattern_matching + + test_oneof_pattern_matching() + + def test_json_casing(): @dataclass class CasingTest(betterproto.Message):