Fix parse_source_type_name #635

Open · wants to merge 20 commits into base: master
Changes from 18 commits
43 changes: 33 additions & 10 deletions src/betterproto/compile/importing.py
@@ -17,6 +17,7 @@


if TYPE_CHECKING:
from ..plugin.models import PluginRequestCompiler
from ..plugin.typing_compiler import TypingCompiler

WRAPPER_TYPES: Dict[str, Type] = {
@@ -32,20 +33,41 @@
}


def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
def parse_source_type_name(
field_type_name: str, request: "PluginRequestCompiler"
) -> Tuple[str, str]:
"""
Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message')
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')

The function goes through the symbols that have been defined (messages, enums, packages) to find the actual
package and name of the referenced object.
"""
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
if package_match:
package = package_match.group(1)
name = package_match.group(2)
else:
package = ""
name = field_type_name.lstrip(".")
return package, name
if field_type_name[0] != ".":
raise RuntimeError("relative names are not supported")
field_type_name = field_type_name[1:]
parts = field_type_name.split(".")

answer = None

# a.b.c:
# i=0: "", "a.b.c"
# i=1: "a", "b.c"
# i=2: "a.b", "c"
for i in range(len(parts)):
package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:])

if package := request.output_packages.get(package_name):
if object_name in package.messages or object_name in package.enums:
if answer:
raise ValueError(f"ambiguous definition: {field_type_name}")
answer = package_name, object_name

if answer:
return answer

raise ValueError(f"can't find type name: {field_type_name}")


def get_type_reference(
@@ -54,6 +76,7 @@ def get_type_reference(
imports: set,
source_type: str,
typing_compiler: TypingCompiler,
request: "PluginRequestCompiler",
unwrap: bool = True,
pydantic: bool = False,
) -> str:
@@ -72,7 +95,7 @@ def get_type_reference(
elif source_type == ".google.protobuf.Timestamp":
return "datetime"

source_package, source_type = parse_source_type_name(source_type)
source_package, source_type = parse_source_type_name(source_type, request)

current_package: List[str] = package.split(".") if package else []
py_package: List[str] = source_package.split(".") if source_package else []
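
To see how the rewritten lookup behaves end to end, here is a minimal, self-contained sketch of the same candidate-split loop run against a hand-written symbol table; the resolve helper, the dict-of-sets shape, and the package names are illustrative stand-ins, not the real PluginRequestCompiler / output_packages API:

# Standalone sketch only; known_symbols and resolve() are made up for illustration.
from typing import Dict, Set, Tuple

# Pretend each generated package exposes the message/enum names it will contain.
known_symbols: Dict[str, Set[str]] = {
    "root": {"Message", "Message.SomeEnum"},
    "root.package": {"Message"},
}

def resolve(field_type_name: str) -> Tuple[str, str]:
    parts = field_type_name.lstrip(".").split(".")
    answer = None
    # Try every split point: ("", "a.b.c"), ("a", "b.c"), ("a.b", "c")
    for i in range(len(parts)):
        package_name, object_name = ".".join(parts[:i]), ".".join(parts[i:])
        if object_name in known_symbols.get(package_name, set()):
            if answer:
                raise ValueError(f"ambiguous definition: {field_type_name}")
            answer = package_name, object_name
    if answer:
        return answer
    raise ValueError(f"can't find type name: {field_type_name}")

assert resolve(".root.package.Message") == ("root.package", "Message")
assert resolve(".root.Message.SomeEnum") == ("root", "Message.SomeEnum")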
85 changes: 56 additions & 29 deletions src/betterproto/plugin/models.py
@@ -65,10 +65,7 @@
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest

from .. import which_one_of
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
)
from ..compile.importing import get_type_reference
from ..compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
@@ -205,6 +202,12 @@ def __post_init__(self) -> None:
if field_val is PLACEHOLDER:
raise ValueError(f"`{field_name}` is a required field.")

def ready(self) -> None:
"""
This function is called after all the compilers are created, but before generating the output code.
"""
pass

@property
def output_file(self) -> "OutputTemplate":
current = self
@@ -214,10 +217,7 @@ def output_file(self) -> "OutputTemplate":

@property
def request(self) -> "PluginRequestCompiler":
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current.parent_request
return self.output_file.parent_request

@property
def comment(self) -> str:
@@ -228,6 +228,10 @@ def comment(self) -> str:
proto_file=self.source_file, path=self.path, indent=self.comment_indent
)

@property
def deprecated(self) -> bool:
return self.proto_obj.options.deprecated


@dataclass
class PluginRequestCompiler:
@@ -244,7 +248,9 @@ def all_messages(self) -> List["MessageCompiler"]:
List of all of the messages in this request.
"""
return [
msg for output in self.output_packages.values() for msg in output.messages
msg
for output in self.output_packages.values()
for msg in output.messages.values()
]


@@ -264,9 +270,9 @@ class OutputTemplate:
datetime_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
messages: Dict[str, "MessageCompiler"] = field(default_factory=dict)
enums: Dict[str, "EnumDefinitionCompiler"] = field(default_factory=dict)
services: Dict[str, "ServiceCompiler"] = field(default_factory=dict)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
@@ -299,13 +305,13 @@ def python_module_imports(self) -> Set[str]:
imports = set()

has_deprecated = False
if any(m.deprecated for m in self.messages):
if any(m.deprecated for m in self.messages.values()):
has_deprecated = True
if any(x for x in self.messages if any(x.deprecated_fields)):
if any(x for x in self.messages.values() if any(x.deprecated_fields)):
has_deprecated = True
if any(
any(m.proto_obj.options.deprecated for m in s.methods)
for s in self.services
for s in self.services.values()
):
has_deprecated = True

@@ -329,17 +335,15 @@ class MessageCompiler(ProtoContentBase):
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
default_factory=list
)
deprecated: bool = field(default=False, init=False)
builtins_types: Set[str] = field(default_factory=set)

def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
self.output_file.enums.append(self)
self.output_file.enums[self.proto_name] = self
else:
self.output_file.messages.append(self)
self.deprecated = self.proto_obj.options.deprecated
self.output_file.messages[self.proto_name] = self
super().__post_init__()

@property
@@ -417,16 +421,24 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:


@dataclass
class FieldCompiler(MessageCompiler):
class FieldCompiler(ProtoContentBase):
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int] = PLACEHOLDER
builtins_types: Set[str] = field(default_factory=set)

parent: MessageCompiler = PLACEHOLDER
proto_obj: FieldDescriptorProto = PLACEHOLDER

def __post_init__(self) -> None:
# Add field to message
self.parent.fields.append(self)
if isinstance(self.parent, MessageCompiler):
self.parent.fields.append(self)
super().__post_init__()

def ready(self) -> None:
# Check for new imports
self.add_imports_to(self.output_file)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__

def get_field_string(self, indent: int = 4) -> str:
"""Construct string representation of this field as a field."""
@@ -544,6 +556,7 @@ def py_type(self) -> str:
imports=self.output_file.imports_end,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
request=self.request,
pydantic=self.output_file.pydantic_dataclasses,
)
else:
@@ -587,12 +600,22 @@ def pydantic_imports(self) -> Set[str]:

@dataclass
class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER
py_v_type: Type = PLACEHOLDER
proto_k_type: str = PLACEHOLDER
proto_v_type: str = PLACEHOLDER
py_k_type: Type = None
py_v_type: Type = None
Collaborator: Please change this to pass type checking

proto_k_type: str = ""
proto_v_type: str = ""

def __post_init__(self) -> None:
def __post_init__(self):
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
for nested in self.parent.proto_obj.nested_type:
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
pass
return super().__post_init__()

def ready(self) -> None:
"""Explore nested types and set k_type and v_type if unset."""
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
for nested in self.parent.proto_obj.nested_type:
@@ -617,7 +640,9 @@ def __post_init__(self) -> None:
# Get proto types
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
return

raise ValueError("can't find enum")

@property
def betterproto_field_args(self) -> List[str]:
@@ -678,7 +703,7 @@ class ServiceCompiler(ProtoContentBase):

def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.services[self.proto_name] = self
super().__post_init__() # check for unset fields

@property
@@ -744,6 +769,7 @@ def py_input_message_type(self) -> str:
imports=self.output_file.imports_end,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
request=self.request,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
@@ -774,6 +800,7 @@ def py_output_message_type(self) -> str:
imports=self.output_file.imports_end,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
request=self.request,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
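
To make the new request plumbing and the ready() hook easier to follow, the sketch below shows the intended two-phase lifecycle with made-up Package/Message classes standing in for OutputTemplate and MessageCompiler: __post_init__ only registers the object in its parent's dict, while ready() runs later, once the whole hierarchy has been parsed and cross-references can safely be resolved.

# Illustrative stand-ins only, not the real compiler dataclasses.
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class Package:
    # Children are now registered into dicts keyed by proto name instead of lists.
    messages: Dict[str, "Message"] = field(default_factory=dict)

@dataclass
class Message:
    parent: Package
    proto_name: str

    def __post_init__(self) -> None:
        # Phase 1: register as soon as the object is constructed.
        self.parent.messages[self.proto_name] = self

    def ready(self) -> None:
        # Phase 2: called after parsing finishes, so lookups across packages
        # (e.g. resolving a field's type name) are safe here.
        pass

pkg = Package()
Message(pkg, "Foo")
Message(pkg, "Bar")
for msg in pkg.messages.values():
    msg.ready()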
24 changes: 22 additions & 2 deletions src/betterproto/plugin/parser.py
@@ -40,7 +40,6 @@
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)

@@ -61,7 +60,13 @@ def _traverse(
for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy.
# Todo: don't change the name, but include full name in returned tuple
item.name = next_prefix = f"{prefix}_{item.name}"
should_rename = (
not isinstance(item, DescriptorProto) or not item.options.map_entry
)

item.name = next_prefix = (
f"{prefix}.{item.name}" if prefix and should_rename else item.name
)
yield item, [*path, i]

if isinstance(item, DescriptorProto):
@@ -145,6 +150,21 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
for index, service in enumerate(proto_input_file.service):
read_protobuf_service(proto_input_file, service, index, output_package)

# The whole hierarchy is ready. We can perform pre-computations before generating the output files
for package in request_data.output_packages.values():
for message in package.messages.values():
for field in message.fields:
field.ready()
message.ready()
for enum in package.enums.values():
for variant in enum.fields:
variant.ready()
enum.ready()
for service in package.services.values():
for method in service.methods:
method.ready()
service.ready()

# Generate output files
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
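
As a rough illustration of the _traverse change (plain dicts below stand in for the descriptor objects): nested names are now joined with dots instead of underscores, and synthetic map-entry messages keep their original name, presumably so MapEntryCompiler can still match them by the "<field>entry" naming convention.

def flatten(prefix, items):
    # Mirror of the renaming rule in _traverse: rename unless the item is a map entry.
    for item in items:
        should_rename = not item.get("map_entry", False)
        name = f"{prefix}.{item['name']}" if prefix and should_rename else item["name"]
        yield name
        yield from flatten(name, item.get("nested", []))

proto_messages = [
    {"name": "Outer", "nested": [
        {"name": "Inner"},
        {"name": "ValuesEntry", "map_entry": True},
    ]},
]
print(list(flatten("", proto_messages)))
# ['Outer', 'Outer.Inner', 'ValuesEntry']; note the map entry keeps its bare name.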
6 changes: 3 additions & 3 deletions src/betterproto/templates/header.py.j2
@@ -4,13 +4,13 @@
# This file has been @generated

__all__ = (
{%- for enum in output_file.enums -%}
{% for _, enum in output_file.enums|dictsort(by="key") %}
"{{ enum.py_name }}",
{%- endfor -%}
{%- for message in output_file.messages -%}
{% for _, message in output_file.messages|dictsort(by="key") %}
"{{ message.py_name }}",
{%- endfor -%}
{%- for service in output_file.services -%}
{% for _, service in output_file.services|dictsort(by="key") %}
Collaborator, on lines -7 to +13: How come these are being sorted now?

"{{ service.py_name }}Stub",
"{{ service.py_name }}Base",
{%- endfor -%}
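
One likely reason for the sorting asked about above: messages, enums, and services are now dicts rather than lists, and the templates iterate over them with Jinja's dictsort filter, which yields (key, value) pairs ordered by key, so the emitted __all__ and class order no longer depend on insertion order. Roughly, the loop behaves like this Python equivalent (the names are made up):

# Approximate Python equivalent of {% for _, m in output_file.messages|dictsort(by="key") %}
messages = {"Zebra": "<MessageCompiler Zebra>", "Apple": "<MessageCompiler Apple>"}
for name, compiler in sorted(messages.items(), key=lambda item: item[0]):
    print(name, compiler)
# Apple <MessageCompiler Apple>
# Zebra <MessageCompiler Zebra>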
8 changes: 4 additions & 4 deletions src/betterproto/templates/template.py.j2
@@ -1,4 +1,4 @@
{% if output_file.enums %}{% for enum in output_file.enums %}
{% if output_file.enums %}{% for _, enum in output_file.enums|dictsort(by="key") %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
{{ enum.comment }}
@@ -22,7 +22,7 @@ class {{ enum.py_name }}(betterproto.Enum):

{% endfor %}
{% endif %}
{% for message in output_file.messages %}
{% for _, message in output_file.messages|dictsort(by="key") %}
{% if output_file.pydantic_dataclasses %}
@dataclass(eq=False, repr=False, config={"extra": "forbid"})
{% else %}
@@ -63,7 +63,7 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %}

{% endfor %}
{% for service in output_file.services %}
{% for _, service in output_file.services|dictsort(by="key") %}
class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %}
{{ service.comment }}
@@ -147,7 +147,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{{ i }}
{% endfor %}

{% for service in output_file.services %}
{% for _, service in output_file.services|dictsort(by="key") %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
{{ service.comment }}
1 change: 0 additions & 1 deletion tests/inputs/config.py
@@ -4,7 +4,6 @@
"namespace_keywords", # 70
"googletypes_struct", # 9
"googletypes_value", # 9
"import_capitalized_package",
"example", # This is the example in the readme. Not a test.
}

7 changes: 7 additions & 0 deletions tests/inputs/import_child_scoping_rules/child.proto
@@ -0,0 +1,7 @@
syntax = "proto3";

package import_child_scoping_rules.aaa.bbb.ccc.ddd;

message ChildMessage {

}