From cba4d95df525274b5a90917d68faac35c500bec8 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Thu, 24 Oct 2024 18:30:15 +0200 Subject: [PATCH 01/20] Don't make FieldCompiler inherit from MessageCompiler --- src/betterproto/plugin/models.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index e330e6884..3d8a0c294 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -228,6 +228,10 @@ def comment(self) -> str: proto_file=self.source_file, path=self.path, indent=self.comment_indent ) + @property + def deprecated(self) -> bool: + return self.proto_obj.options.deprecated + @dataclass class PluginRequestCompiler: @@ -329,7 +333,6 @@ class MessageCompiler(ProtoContentBase): fields: List[Union["FieldCompiler", "MessageCompiler"]] = field( default_factory=list ) - deprecated: bool = field(default=False, init=False) builtins_types: Set[str] = field(default_factory=set) def __post_init__(self) -> None: @@ -339,7 +342,6 @@ def __post_init__(self) -> None: self.output_file.enums.append(self) else: self.output_file.messages.append(self) - self.deprecated = self.proto_obj.options.deprecated super().__post_init__() @property @@ -417,16 +419,22 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: @dataclass -class FieldCompiler(MessageCompiler): +class FieldCompiler(ProtoContentBase): + source_file: FileDescriptorProto + typing_compiler: TypingCompiler + path: List[int] = PLACEHOLDER + builtins_types: Set[str] = field(default_factory=set) + parent: MessageCompiler = PLACEHOLDER proto_obj: FieldDescriptorProto = PLACEHOLDER def __post_init__(self) -> None: # Add field to message - self.parent.fields.append(self) + if isinstance(self.parent, MessageCompiler): # TODO make this useless + self.parent.fields.append(self) # Check for new imports self.add_imports_to(self.output_file) - super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ + super().__post_init__() def get_field_string(self, indent: int = 4) -> str: """Construct string representation of this field as a field.""" From 4da394eeba57eb6e44e95fcf42df8a7ae8373f52 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Thu, 24 Oct 2024 18:31:20 +0200 Subject: [PATCH 02/20] Remove wrong comment --- src/betterproto/plugin/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 3d8a0c294..1e9aacb6b 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -625,7 +625,7 @@ def __post_init__(self) -> None: # Get proto types self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name - super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ + super().__post_init__() @property def betterproto_field_args(self) -> List[str]: From 7a6b6de927c4585be6b9f484526bbe796b83a6f6 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 11:18:56 +0200 Subject: [PATCH 03/20] Simplify code --- src/betterproto/plugin/models.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 1e9aacb6b..f54f17882 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -214,10 +214,7 @@ def output_file(self) -> "OutputTemplate": @property def request(self) -> "PluginRequestCompiler": - current = self - while not isinstance(current, OutputTemplate): - current = current.parent - return current.parent_request + return self.output_file.parent_request @property def comment(self) -> str: @@ -430,7 +427,7 @@ class FieldCompiler(ProtoContentBase): def __post_init__(self) -> None: # Add field to message - if isinstance(self.parent, MessageCompiler): # TODO make this useless + if isinstance(self.parent, MessageCompiler): self.parent.fields.append(self) # Check for new imports self.add_imports_to(self.output_file) From 2dd8a5e7148128e22d99da677645a5a5d049888e Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 12:55:38 +0200 Subject: [PATCH 04/20] Store messages in a dict --- src/betterproto/plugin/models.py | 10 +++++----- src/betterproto/templates/template.py.j2 | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index f54f17882..74949ad77 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -245,7 +245,7 @@ def all_messages(self) -> List["MessageCompiler"]: List of all of the messages in this request. """ return [ - msg for output in self.output_packages.values() for msg in output.messages + msg for output in self.output_packages.values() for msg in output.messages.values() ] @@ -265,7 +265,7 @@ class OutputTemplate: datetime_imports: Set[str] = field(default_factory=set) pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False - messages: List["MessageCompiler"] = field(default_factory=list) + messages: Dict[str, "MessageCompiler"] = field(default_factory=dict) enums: List["EnumDefinitionCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list) imports_type_checking_only: Set[str] = field(default_factory=set) @@ -300,9 +300,9 @@ def python_module_imports(self) -> Set[str]: imports = set() has_deprecated = False - if any(m.deprecated for m in self.messages): + if any(m.deprecated for m in self.messages.values()): has_deprecated = True - if any(x for x in self.messages if any(x.deprecated_fields)): + if any(x for x in self.messages.values() if any(x.deprecated_fields)): has_deprecated = True if any( any(m.proto_obj.options.deprecated for m in s.methods) @@ -338,7 +338,7 @@ def __post_init__(self) -> None: if isinstance(self, EnumDefinitionCompiler): self.output_file.enums.append(self) else: - self.output_file.messages.append(self) + self.output_file.messages[self.proto_name] = self super().__post_init__() @property diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 4a252aec6..489526e1e 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -22,7 +22,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} -{% for message in output_file.messages %} +{% for _, message in output_file.messages|dictsort(by="key") %} {% if output_file.pydantic_dataclasses %} @dataclass(eq=False, repr=False, config={"extra": "forbid"}) {% else %} From f4fa3598b76850d3c583fe5628d09bcf29622ef7 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 13:00:08 +0200 Subject: [PATCH 05/20] Store services and enums in dict --- src/betterproto/plugin/models.py | 10 +++++----- src/betterproto/templates/template.py.j2 | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 74949ad77..ae5b41196 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -266,8 +266,8 @@ class OutputTemplate: pydantic_imports: Set[str] = field(default_factory=set) builtins_import: bool = False messages: Dict[str, "MessageCompiler"] = field(default_factory=dict) - enums: List["EnumDefinitionCompiler"] = field(default_factory=list) - services: List["ServiceCompiler"] = field(default_factory=list) + enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict) + services: Dict[str, "ServiceCompiler"] = field(default_factory=dict) imports_type_checking_only: Set[str] = field(default_factory=set) pydantic_dataclasses: bool = False output: bool = True @@ -306,7 +306,7 @@ def python_module_imports(self) -> Set[str]: has_deprecated = True if any( any(m.proto_obj.options.deprecated for m in s.methods) - for s in self.services + for s in self.services.values() ): has_deprecated = True @@ -336,7 +336,7 @@ def __post_init__(self) -> None: # Add message to output file if isinstance(self.parent, OutputTemplate): if isinstance(self, EnumDefinitionCompiler): - self.output_file.enums.append(self) + self.output_file.enums[self.proto_name] = self else: self.output_file.messages[self.proto_name] = self super().__post_init__() @@ -683,7 +683,7 @@ class ServiceCompiler(ProtoContentBase): def __post_init__(self) -> None: # Add service to output file - self.output_file.services.append(self) + self.output_file.services[self.proto_name] = self super().__post_init__() # check for unset fields @property diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 489526e1e..8230819a9 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -1,4 +1,4 @@ -{% if output_file.enums %}{% for enum in output_file.enums %} +{% if output_file.enums %}{% for _, enum in output_file.enums|dictsort(by="key") %} class {{ enum.py_name }}(betterproto.Enum): {% if enum.comment %} {{ enum.comment }} @@ -63,7 +63,7 @@ class {{ message.py_name }}(betterproto.Message): {% endif %} {% endfor %} -{% for service in output_file.services %} +{% for _, service in output_file.services|dictsort(by="key") %} class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if service.comment %} {{ service.comment }} @@ -147,7 +147,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ i }} {% endfor %} -{% for service in output_file.services %} +{% for _, service in output_file.services|dictsort(by="key") %} class {{ service.py_name }}Base(ServiceBase): {% if service.comment %} {{ service.comment }} From 0f8e038a68a01002a83838979309668d0fe895a9 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 13:00:47 +0200 Subject: [PATCH 06/20] Switch nested messages delimiter --- src/betterproto/plugin/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 5f7b72c40..4b87f0dc3 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -61,7 +61,7 @@ def _traverse( for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple - item.name = next_prefix = f"{prefix}_{item.name}" + item.name = next_prefix = f"{prefix}.{item.name}" yield item, [*path, i] if isinstance(item, DescriptorProto): From 058fb5d34355cd1b8f7afade04d317c1676b9c4e Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 13:31:28 +0200 Subject: [PATCH 07/20] Improve parse_source_type_name --- src/betterproto/compile/importing.py | 40 ++++++++++++++++++++++++++-- src/betterproto/plugin/models.py | 28 ++++++++++--------- src/betterproto/plugin/parser.py | 2 +- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index b216dfc59..e07c33a8f 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from ..plugin.typing_compiler import TypingCompiler + from ..plugin.models import PluginRequestCompiler WRAPPER_TYPES: Dict[str, Type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, @@ -32,12 +33,46 @@ } -def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: +def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]: """ Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') """ + if field_type_name[0] != ".": + raise RuntimeError + field_type_name = field_type_name[1:] + parts = field_type_name.split(".") + + answer = None + + # a.b.c: + # i=0: "", "a.b.c" + # i=1: "a", "b.c" + # i=2: "a.b", "c" + for i in range(len(parts)): + if i == 0: # TODO + continue + + package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:]) + import sys + print("Trying", package_name, "|", object_name, file=sys.stderr) + + + if package := request.output_packages.get(package_name): + print("->", list(package.messages.keys()), file=sys.stderr) + print("->", list(package.enums.keys()), file=sys.stderr) + if object_name in package.messages or object_name in package.enums: + if answer: + raise ValueError(f"ambiguous definition: {field_type_name}") + answer = package_name, object_name + + if answer: + return answer + + raise ValueError(f"can't find type name: {field_type_name}") + + package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name) if package_match: package = package_match.group(1) @@ -54,6 +89,7 @@ def get_type_reference( imports: set, source_type: str, typing_compiler: TypingCompiler, + request: "PluginRequestCompiler", unwrap: bool = True, pydantic: bool = False, ) -> str: @@ -72,7 +108,7 @@ def get_type_reference( elif source_type == ".google.protobuf.Timestamp": return "datetime" - source_package, source_type = parse_source_type_name(source_type) + source_package, source_type = parse_source_type_name(source_type, request) current_package: List[str] = package.split(".") if package else [] py_package: List[str] = source_package.split(".") if source_package else [] diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index ae5b41196..132ca1260 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -67,7 +67,6 @@ from .. import which_one_of from ..compile.importing import ( get_type_reference, - parse_source_type_name, ) from ..compile.naming import ( pythonize_class_name, @@ -458,14 +457,15 @@ def betterproto_field_args(self) -> List[str]: @property def datetime_imports(self) -> Set[str]: - imports = set() - annotation = self.annotation - # FIXME: false positives - e.g. `MyDatetimedelta` - if "timedelta" in annotation: - imports.add("timedelta") - if "datetime" in annotation: - imports.add("datetime") - return imports + # imports = set() + # annotation = self.annotation + # # FIXME: false positives - e.g. `MyDatetimedelta` + # if "timedelta" in annotation: + # imports.add("timedelta") + # if "datetime" in annotation: + # imports.add("datetime") + # return imports + return {"timedelta", "datetime"} @property def pydantic_imports(self) -> Set[str]: @@ -473,9 +473,10 @@ def pydantic_imports(self) -> Set[str]: @property def use_builtins(self) -> bool: - return self.py_type in self.parent.builtins_types or ( - self.py_type == self.py_name and self.py_name in dir(builtins) - ) + return False + # return self.py_type in self.parent.builtins_types or ( + # self.py_type == self.py_name and self.py_name in dir(builtins) + # ) def add_imports_to(self, output_file: OutputTemplate) -> None: output_file.datetime_imports.update(self.datetime_imports) @@ -549,6 +550,7 @@ def py_type(self) -> str: imports=self.output_file.imports_end, source_type=self.proto_obj.type_name, typing_compiler=self.typing_compiler, + request=self.request, pydantic=self.output_file.pydantic_dataclasses, ) else: @@ -749,6 +751,7 @@ def py_input_message_type(self) -> str: imports=self.output_file.imports_end, source_type=self.proto_obj.input_type, typing_compiler=self.output_file.typing_compiler, + request=self.request, unwrap=False, pydantic=self.output_file.pydantic_dataclasses, ).strip('"') @@ -779,6 +782,7 @@ def py_output_message_type(self) -> str: imports=self.output_file.imports_end, source_type=self.proto_obj.output_type, typing_compiler=self.output_file.typing_compiler, + request=self.request, unwrap=False, pydantic=self.output_file.pydantic_dataclasses, ).strip('"') diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 4b87f0dc3..834fb16a0 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -61,7 +61,7 @@ def _traverse( for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple - item.name = next_prefix = f"{prefix}.{item.name}" + item.name = next_prefix = f"{prefix}.{item.name}" if prefix else item.name yield item, [*path, i] if isinstance(item, DescriptorProto): From e106ee8846581b0b12a3ca4bae05a3d052d70462 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 14:43:10 +0200 Subject: [PATCH 08/20] Fix bugs --- src/betterproto/plugin/models.py | 38 +++++++++++++++++++++++--------- src/betterproto/plugin/parser.py | 20 +++++++++++++++-- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 132ca1260..6c46d784c 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -204,6 +204,9 @@ def __post_init__(self) -> None: if field_val is PLACEHOLDER: raise ValueError(f"`{field_name}` is a required field.") + def ready(self) -> None: + pass + @property def output_file(self) -> "OutputTemplate": current = self @@ -428,9 +431,11 @@ def __post_init__(self) -> None: # Add field to message if isinstance(self.parent, MessageCompiler): self.parent.fields.append(self) + super().__post_init__() + + def ready(self) -> None: # Check for new imports self.add_imports_to(self.output_file) - super().__post_init__() def get_field_string(self, indent: int = 4) -> str: """Construct string representation of this field as a field.""" @@ -473,10 +478,9 @@ def pydantic_imports(self) -> Set[str]: @property def use_builtins(self) -> bool: - return False - # return self.py_type in self.parent.builtins_types or ( - # self.py_type == self.py_name and self.py_name in dir(builtins) - # ) + return self.py_type in self.parent.builtins_types or ( + self.py_type == self.py_name and self.py_name in dir(builtins) + ) def add_imports_to(self, output_file: OutputTemplate) -> None: output_file.datetime_imports.update(self.datetime_imports) @@ -594,12 +598,22 @@ def pydantic_imports(self) -> Set[str]: @dataclass class MapEntryCompiler(FieldCompiler): - py_k_type: Type = PLACEHOLDER - py_v_type: Type = PLACEHOLDER - proto_k_type: str = PLACEHOLDER - proto_v_type: str = PLACEHOLDER + py_k_type: Type = None + py_v_type: Type = None + proto_k_type: str = "" + proto_v_type: str = "" - def __post_init__(self) -> None: + def __post_init__(self): + map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" + for nested in self.parent.proto_obj.nested_type: + if ( + nested.name.replace("_", "").lower() == map_entry + and nested.options.map_entry + ): + pass + return super().__post_init__() + + def ready(self) -> None: """Explore nested types and set k_type and v_type if unset.""" map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" for nested in self.parent.proto_obj.nested_type: @@ -624,7 +638,9 @@ def __post_init__(self) -> None: # Get proto types self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name - super().__post_init__() + return + + raise ValueError("can't find enum") @property def betterproto_field_args(self) -> List[str]: diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 834fb16a0..6eacf703f 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -40,7 +40,6 @@ from .typing_compiler import ( DirectImportTypingCompiler, NoTyping310TypingCompiler, - TypingCompiler, TypingImportTypingCompiler, ) @@ -61,7 +60,9 @@ def _traverse( for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple - item.name = next_prefix = f"{prefix}.{item.name}" if prefix else item.name + should_rename = not isinstance(item, DescriptorProto) or not item.options.map_entry + + item.name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name yield item, [*path, i] if isinstance(item, DescriptorProto): @@ -145,6 +146,21 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: for index, service in enumerate(proto_input_file.service): read_protobuf_service(proto_input_file, service, index, output_package) + # All the hierarchy is ready. We can perform pre-computations before generating the output files + for package in request_data.output_packages.values(): + for message in package.messages.values(): + for field in message.fields: + field.ready() + message.ready() + for enum in package.enums.values(): + for variant in enum.fields: + variant.ready() + enum.ready() + for service in package.services.values(): + for method in service.methods: + method.ready() + service.ready() + # Generate output files output_paths: Set[pathlib.Path] = set() for output_package_name, output_package in request_data.output_packages.items(): From 1457adc940135331817d196113490a392e836012 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 14:46:36 +0200 Subject: [PATCH 09/20] Add new test --- tests/inputs/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 6da1f887d..4fb1565f8 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -4,7 +4,6 @@ "namespace_keywords", # 70 "googletypes_struct", # 9 "googletypes_value", # 9 - "import_capitalized_package", "example", # This is the example in the readme. Not a test. } From 05c2ecf9e2af51a783e6fb12d3e3af1125d5bd1e Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 14:48:28 +0200 Subject: [PATCH 10/20] Remove useless code --- src/betterproto/compile/importing.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index e07c33a8f..b2f07d14f 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -55,13 +55,8 @@ def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler continue package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:]) - import sys - print("Trying", package_name, "|", object_name, file=sys.stderr) - if package := request.output_packages.get(package_name): - print("->", list(package.messages.keys()), file=sys.stderr) - print("->", list(package.enums.keys()), file=sys.stderr) if object_name in package.messages or object_name in package.enums: if answer: raise ValueError(f"ambiguous definition: {field_type_name}") @@ -73,16 +68,6 @@ def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler raise ValueError(f"can't find type name: {field_type_name}") - package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name) - if package_match: - package = package_match.group(1) - name = package_match.group(2) - else: - package = "" - name = field_type_name.lstrip(".") - return package, name - - def get_type_reference( *, package: str, From e00005b3fbc42e4a66cd2aa4b16c0275a6dcf7d8 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 14:50:09 +0200 Subject: [PATCH 11/20] Add back datetime imports --- src/betterproto/plugin/models.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 6c46d784c..d1bbcc71b 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -462,15 +462,14 @@ def betterproto_field_args(self) -> List[str]: @property def datetime_imports(self) -> Set[str]: - # imports = set() - # annotation = self.annotation - # # FIXME: false positives - e.g. `MyDatetimedelta` - # if "timedelta" in annotation: - # imports.add("timedelta") - # if "datetime" in annotation: - # imports.add("datetime") - # return imports - return {"timedelta", "datetime"} + imports = set() + annotation = self.annotation + # FIXME: false positives - e.g. `MyDatetimedelta` + if "timedelta" in annotation: + imports.add("timedelta") + if "datetime" in annotation: + imports.add("datetime") + return imports @property def pydantic_imports(self) -> Set[str]: From 36aea089d0bc6774f7f4cbd5cd9627fade2f569c Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 14:58:00 +0200 Subject: [PATCH 12/20] Allow empty packages --- src/betterproto/compile/importing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index b2f07d14f..3c2d10c10 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -51,9 +51,6 @@ def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler # i=1: "a", "b.c" # i=2: "a.b", "c" for i in range(len(parts)): - if i == 0: # TODO - continue - package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:]) if package := request.output_packages.get(package_name): From bde69dab2de5bf30e2bcf28483d4a512c7fc632e Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 15:13:07 +0200 Subject: [PATCH 13/20] Format --- src/betterproto/compile/importing.py | 6 ++++-- src/betterproto/plugin/models.py | 8 ++++---- src/betterproto/plugin/parser.py | 8 ++++++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 3c2d10c10..828968b7e 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -17,8 +17,8 @@ if TYPE_CHECKING: - from ..plugin.typing_compiler import TypingCompiler from ..plugin.models import PluginRequestCompiler + from ..plugin.typing_compiler import TypingCompiler WRAPPER_TYPES: Dict[str, Type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, @@ -33,7 +33,9 @@ } -def parse_source_type_name(field_type_name: str, request: "PluginRequestCompiler") -> Tuple[str, str]: +def parse_source_type_name( + field_type_name: str, request: "PluginRequestCompiler" +) -> Tuple[str, str]: """ Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index d1bbcc71b..caa85a06f 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -65,9 +65,7 @@ from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest from .. import which_one_of -from ..compile.importing import ( - get_type_reference, -) +from ..compile.importing import get_type_reference from ..compile.naming import ( pythonize_class_name, pythonize_enum_member_name, @@ -247,7 +245,9 @@ def all_messages(self) -> List["MessageCompiler"]: List of all of the messages in this request. """ return [ - msg for output in self.output_packages.values() for msg in output.messages.values() + msg + for output in self.output_packages.values() + for msg in output.messages.values() ] diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 6eacf703f..cf2a8e3eb 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -60,9 +60,13 @@ def _traverse( for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple - should_rename = not isinstance(item, DescriptorProto) or not item.options.map_entry + should_rename = ( + not isinstance(item, DescriptorProto) or not item.options.map_entry + ) - item.name = next_prefix = f"{prefix}.{item.name}" if prefix and should_rename else item.name + item.name = next_prefix = ( + f"{prefix}.{item.name}" if prefix and should_rename else item.name + ) yield item, [*path, i] if isinstance(item, DescriptorProto): From 978f8a45fcd9de65a33b578af0e517ec758f82d1 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 25 Oct 2024 15:30:53 +0200 Subject: [PATCH 14/20] Delete wrong test file --- tests/test_get_ref_type.py | 497 ------------------------------------- 1 file changed, 497 deletions(-) delete mode 100644 tests/test_get_ref_type.py diff --git a/tests/test_get_ref_type.py b/tests/test_get_ref_type.py deleted file mode 100644 index 7b529bd27..000000000 --- a/tests/test_get_ref_type.py +++ /dev/null @@ -1,497 +0,0 @@ -import pytest - -from betterproto.compile.importing import ( - get_type_reference, - parse_source_type_name, -) -from betterproto.plugin.typing_compiler import DirectImportTypingCompiler - - -@pytest.fixture -def typing_compiler() -> DirectImportTypingCompiler: - """ - Generates a simple Direct Import Typing Compiler for testing. - """ - return DirectImportTypingCompiler() - - -@pytest.mark.parametrize( - ["google_type", "expected_name", "expected_import"], - [ - ( - ".google.protobuf.Empty", - '"betterproto_lib_google_protobuf.Empty"', - "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", - ), - ( - ".google.protobuf.Struct", - '"betterproto_lib_google_protobuf.Struct"', - "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", - ), - ( - ".google.protobuf.ListValue", - '"betterproto_lib_google_protobuf.ListValue"', - "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", - ), - ( - ".google.protobuf.Value", - '"betterproto_lib_google_protobuf.Value"', - "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", - ), - ], -) -def test_reference_google_wellknown_types_non_wrappers( - google_type: str, - expected_name: str, - expected_import: str, - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type=google_type, - typing_compiler=typing_compiler, - pydantic=False, - ) - - assert name == expected_name - assert imports.__contains__( - expected_import - ), f"{expected_import} not found in {imports}" - - -@pytest.mark.parametrize( - ["google_type", "expected_name", "expected_import"], - [ - ( - ".google.protobuf.Empty", - '"betterproto_lib_pydantic_google_protobuf.Empty"', - "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf", - ), - ( - ".google.protobuf.Struct", - '"betterproto_lib_pydantic_google_protobuf.Struct"', - "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf", - ), - ( - ".google.protobuf.ListValue", - '"betterproto_lib_pydantic_google_protobuf.ListValue"', - "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf", - ), - ( - ".google.protobuf.Value", - '"betterproto_lib_pydantic_google_protobuf.Value"', - "import betterproto.lib.pydantic.google.protobuf as betterproto_lib_pydantic_google_protobuf", - ), - ], -) -def test_reference_google_wellknown_types_non_wrappers_pydantic( - google_type: str, - expected_name: str, - expected_import: str, - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type=google_type, - typing_compiler=typing_compiler, - pydantic=True, - ) - - assert name == expected_name - assert imports.__contains__( - expected_import - ), f"{expected_import} not found in {imports}" - - -@pytest.mark.parametrize( - ["google_type", "expected_name"], - [ - (".google.protobuf.DoubleValue", "Optional[float]"), - (".google.protobuf.FloatValue", "Optional[float]"), - (".google.protobuf.Int32Value", "Optional[int]"), - (".google.protobuf.Int64Value", "Optional[int]"), - (".google.protobuf.UInt32Value", "Optional[int]"), - (".google.protobuf.UInt64Value", "Optional[int]"), - (".google.protobuf.BoolValue", "Optional[bool]"), - (".google.protobuf.StringValue", "Optional[str]"), - (".google.protobuf.BytesValue", "Optional[bytes]"), - ], -) -def test_referenceing_google_wrappers_unwraps_them( - google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler -): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type=google_type, - typing_compiler=typing_compiler, - ) - - assert name == expected_name - assert imports == set() - - -@pytest.mark.parametrize( - ["google_type", "expected_name"], - [ - ( - ".google.protobuf.DoubleValue", - '"betterproto_lib_google_protobuf.DoubleValue"', - ), - (".google.protobuf.FloatValue", '"betterproto_lib_google_protobuf.FloatValue"'), - (".google.protobuf.Int32Value", '"betterproto_lib_google_protobuf.Int32Value"'), - (".google.protobuf.Int64Value", '"betterproto_lib_google_protobuf.Int64Value"'), - ( - ".google.protobuf.UInt32Value", - '"betterproto_lib_google_protobuf.UInt32Value"', - ), - ( - ".google.protobuf.UInt64Value", - '"betterproto_lib_google_protobuf.UInt64Value"', - ), - (".google.protobuf.BoolValue", '"betterproto_lib_google_protobuf.BoolValue"'), - ( - ".google.protobuf.StringValue", - '"betterproto_lib_google_protobuf.StringValue"', - ), - (".google.protobuf.BytesValue", '"betterproto_lib_google_protobuf.BytesValue"'), - ], -) -def test_referenceing_google_wrappers_without_unwrapping( - google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler -): - name = get_type_reference( - package="", - imports=set(), - source_type=google_type, - typing_compiler=typing_compiler, - unwrap=False, - ) - - assert name == expected_name - - -def test_reference_child_package_from_package( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package", - imports=imports, - source_type="package.child.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from . import child"} - assert name == '"child.Message"' - - -def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type="child.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from . import child"} - assert name == '"child.Message"' - - -def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type="child_package.example_message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from . import child_package"} - assert name == '"child_package.ExampleMessage"' - - -def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type="nested.child.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .nested import child as nested_child"} - assert name == '"nested_child.Message"' - - -def test_reference_deeply_nested_child_from_root( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type="deeply.nested.child.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == '"deeply_nested_child.Message"' - - -def test_reference_deeply_nested_child_from_package( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package", - imports=imports, - source_type="package.deeply.nested.child.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == '"deeply_nested_child.Message"' - - -def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="", - imports=imports, - source_type="Message", - typing_compiler=typing_compiler, - ) - - assert imports == set() - assert name == '"Message"' - - -def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="foo", - imports=imports, - source_type="foo.Message", - typing_compiler=typing_compiler, - ) - - assert imports == set() - assert name == '"Message"' - - -def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="foo.bar", - imports=imports, - source_type="foo.bar.Message", - typing_compiler=typing_compiler, - ) - - assert imports == set() - assert name == '"Message"' - - -def test_reference_parent_package_from_child( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package.child", - imports=imports, - source_type="package.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ... import package as __package__"} - assert name == '"__package__.Message"' - - -def test_reference_parent_package_from_deeply_nested_child( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package.deeply.nested.child", - imports=imports, - source_type="package.deeply.nested.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ... import nested as __nested__"} - assert name == '"__nested__.Message"' - - -def test_reference_ancestor_package_from_nested_child( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package.ancestor.nested.child", - imports=imports, - source_type="package.ancestor.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .... import ancestor as ___ancestor__"} - assert name == '"___ancestor__.Message"' - - -def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="package.child", - imports=imports, - source_type="Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ... import Message as __Message__"} - assert name == '"__Message__"' - - -def test_reference_root_package_from_deeply_nested_child( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="package.deeply.nested.child", - imports=imports, - source_type="Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ..... import Message as ____Message__"} - assert name == '"____Message__"' - - -def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="a", - imports=imports, - source_type="p.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .. import p as _p__"} - assert name == '"_p__.Message"' - - -def test_reference_unrelated_nested_package( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="a.b", - imports=imports, - source_type="p.q.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ...p import q as __p_q__"} - assert name == '"__p_q__.Message"' - - -def test_reference_unrelated_deeply_nested_package( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="a.b.c.d", - imports=imports, - source_type="p.q.r.s.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} - assert name == '"____p_q_r_s__.Message"' - - -def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="a.x", - imports=imports, - source_type="a.y.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from .. import y as _y__"} - assert name == '"_y__.Message"' - - -def test_reference_cousin_package_different_name( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="test.package1", - imports=imports, - source_type="cousin.package2.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ...cousin import package2 as __cousin_package2__"} - assert name == '"__cousin_package2__.Message"' - - -def test_reference_cousin_package_same_name( - typing_compiler: DirectImportTypingCompiler, -): - imports = set() - name = get_type_reference( - package="test.package", - imports=imports, - source_type="cousin.package.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ...cousin import package as __cousin_package__"} - assert name == '"__cousin_package__.Message"' - - -def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="a.x.y", - imports=imports, - source_type="a.b.c.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ...b import c as __b_c__"} - assert name == '"__b_c__.Message"' - - -def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler): - imports = set() - name = get_type_reference( - package="a.x.y.z", - imports=imports, - source_type="a.b.c.d.Message", - typing_compiler=typing_compiler, - ) - - assert imports == {"from ....b.c import d as ___b_c_d__"} - assert name == '"___b_c_d__.Message"' - - -@pytest.mark.parametrize( - ["full_name", "expected_output"], - [ - ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), - (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), - (".service.ExampleRequest", ("service", "ExampleRequest")), - (".package.lower_case_message", ("package", "lower_case_message")), - ], -) -def test_parse_field_type_name(full_name, expected_output): - assert parse_source_type_name(full_name) == expected_output From 28f5894126d9a4311b19275c7c13dd0a198269ce Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 8 Nov 2024 13:14:08 +0100 Subject: [PATCH 15/20] Add documentation --- src/betterproto/compile/importing.py | 5 ++++- src/betterproto/plugin/models.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 828968b7e..0b66bfd5e 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -40,9 +40,12 @@ def parse_source_type_name( Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') + + The function goes through the symbols that have been defined (names, enums, packages) to find the actual package and + name of the object that is referenced. """ if field_type_name[0] != ".": - raise RuntimeError + raise RuntimeError("relative names are not supported") field_type_name = field_type_name[1:] parts = field_type_name.split(".") diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index caa85a06f..d51259809 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -203,6 +203,9 @@ def __post_init__(self) -> None: raise ValueError(f"`{field_name}` is a required field.") def ready(self) -> None: + """ + This function is called after all the compilers are created, but before generating the output code. + """ pass @property From 37cbbdd1092796b498ea8d1bf483fa9a057cf9f9 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 8 Nov 2024 14:02:34 +0100 Subject: [PATCH 16/20] New tests --- tests/inputs/import_child_scoping_rules/child.proto | 7 +++++++ .../import_child_scoping_rules.proto | 9 +++++++++ .../inputs/import_child_scoping_rules/package.proto | 13 +++++++++++++ .../child.proto | 7 +++++++ .../import_nested_child_package_from_root.proto | 9 +++++++++ 5 files changed, 45 insertions(+) create mode 100644 tests/inputs/import_child_scoping_rules/child.proto create mode 100644 tests/inputs/import_child_scoping_rules/import_child_scoping_rules.proto create mode 100644 tests/inputs/import_child_scoping_rules/package.proto create mode 100644 tests/inputs/import_nested_child_package_from_root/child.proto create mode 100644 tests/inputs/import_nested_child_package_from_root/import_nested_child_package_from_root.proto diff --git a/tests/inputs/import_child_scoping_rules/child.proto b/tests/inputs/import_child_scoping_rules/child.proto new file mode 100644 index 000000000..f491e0da9 --- /dev/null +++ b/tests/inputs/import_child_scoping_rules/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_child_scoping_rules.aaa.bbb.ccc.ddd; + +message ChildMessage { + +} diff --git a/tests/inputs/import_child_scoping_rules/import_child_scoping_rules.proto b/tests/inputs/import_child_scoping_rules/import_child_scoping_rules.proto new file mode 100644 index 000000000..272852ccd --- /dev/null +++ b/tests/inputs/import_child_scoping_rules/import_child_scoping_rules.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package import_child_scoping_rules; + +import "package.proto"; + +message Test { + aaa.bbb.Msg msg = 1; +} diff --git a/tests/inputs/import_child_scoping_rules/package.proto b/tests/inputs/import_child_scoping_rules/package.proto new file mode 100644 index 000000000..6b51fe567 --- /dev/null +++ b/tests/inputs/import_child_scoping_rules/package.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package import_child_scoping_rules.aaa.bbb; + +import "child.proto"; + +message Msg { + .import_child_scoping_rules.aaa.bbb.ccc.ddd.ChildMessage a = 1; + import_child_scoping_rules.aaa.bbb.ccc.ddd.ChildMessage b = 2; + aaa.bbb.ccc.ddd.ChildMessage c = 3; + bbb.ccc.ddd.ChildMessage d = 4; + ccc.ddd.ChildMessage e = 5; +} diff --git a/tests/inputs/import_nested_child_package_from_root/child.proto b/tests/inputs/import_nested_child_package_from_root/child.proto new file mode 100644 index 000000000..fcd7e2f6c --- /dev/null +++ b/tests/inputs/import_nested_child_package_from_root/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_nested_child_package_from_root.package.child.otherchild; + +message ChildMessage { + +} diff --git a/tests/inputs/import_nested_child_package_from_root/import_nested_child_package_from_root.proto b/tests/inputs/import_nested_child_package_from_root/import_nested_child_package_from_root.proto new file mode 100644 index 000000000..96da1ace6 --- /dev/null +++ b/tests/inputs/import_nested_child_package_from_root/import_nested_child_package_from_root.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package import_nested_child_package_from_root; + +import "child.proto"; + +message Test { + package.child.otherchild.ChildMessage child = 1; +} From 016a566f91e9530f39ffc7a4b50ebb68348dad39 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 8 Nov 2024 15:23:25 +0100 Subject: [PATCH 17/20] Fix problem with __all__ --- src/betterproto/templates/header.py.j2 | 6 +++--- tests/test_all_definition.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/betterproto/templates/header.py.j2 b/src/betterproto/templates/header.py.j2 index b6d0a6c44..f3be1c6c0 100644 --- a/src/betterproto/templates/header.py.j2 +++ b/src/betterproto/templates/header.py.j2 @@ -4,13 +4,13 @@ # This file has been @generated __all__ = ( - {%- for enum in output_file.enums -%} + {% for _, enum in output_file.enums|dictsort(by="key") %} "{{ enum.py_name }}", {%- endfor -%} - {%- for message in output_file.messages -%} + {% for _, message in output_file.messages|dictsort(by="key") %} "{{ message.py_name }}", {%- endfor -%} - {%- for service in output_file.services -%} + {% for _, service in output_file.services|dictsort(by="key") %} "{{ service.py_name }}Stub", "{{ service.py_name }}Base", {%- endfor -%} diff --git a/tests/test_all_definition.py b/tests/test_all_definition.py index 61abb5f37..ca0b03f5a 100644 --- a/tests/test_all_definition.py +++ b/tests/test_all_definition.py @@ -16,4 +16,4 @@ def test_all_definition(): "TestStub", "TestBase", ) - assert enum.__all__ == ("Choice", "ArithmeticOperator", "Test") + assert enum.__all__ == ("ArithmeticOperator", "Choice", "Test") From e5e99e43a4886eeb186d059e5a56d492387c044f Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Fri, 8 Nov 2024 15:27:04 +0100 Subject: [PATCH 18/20] Format --- src/betterproto/compile/importing.py | 2 +- tests/test_all_definition.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 0b66bfd5e..a3dc1a7d1 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -40,7 +40,7 @@ def parse_source_type_name( Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') - + The function goes through the symbols that have been defined (names, enums, packages) to find the actual package and name of the object that is referenced. """ diff --git a/tests/test_all_definition.py b/tests/test_all_definition.py index ca0b03f5a..01743af77 100644 --- a/tests/test_all_definition.py +++ b/tests/test_all_definition.py @@ -16,4 +16,4 @@ def test_all_definition(): "TestStub", "TestBase", ) - assert enum.__all__ == ("ArithmeticOperator", "Choice", "Test") + assert enum.__all__ == ("ArithmeticOperator", "Choice", "Test") From bede1ac54ad44bc758e8e1ed0c064c65bc156782 Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Mon, 11 Nov 2024 10:58:04 +0100 Subject: [PATCH 19/20] Update typing --- src/betterproto/plugin/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index d51259809..826fc7813 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -600,8 +600,8 @@ def pydantic_imports(self) -> Set[str]: @dataclass class MapEntryCompiler(FieldCompiler): - py_k_type: Type = None - py_v_type: Type = None + py_k_type: Optional[Type] = None + py_v_type: Optional[Type] = None proto_k_type: str = "" proto_v_type: str = "" From ef271df3378a46c9252635289c3a2055b351bfad Mon Sep 17 00:00:00 2001 From: Adrien Vannson Date: Wed, 27 Nov 2024 18:43:11 +0100 Subject: [PATCH 20/20] Add comment --- src/betterproto/compile/importing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index a3dc1a7d1..0998b8c03 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -61,6 +61,7 @@ def parse_source_type_name( if package := request.output_packages.get(package_name): if object_name in package.messages or object_name in package.enums: if answer: + # This should have already been handeled by protoc raise ValueError(f"ambiguous definition: {field_type_name}") answer = package_name, object_name