Skip to content

Commit

Permalink
Implement box_it for oneof fields (#150)
Browse files Browse the repository at this point in the history
This fixes codegen for messages that have recursive fields inside of a oneof.
  • Loading branch information
goffrie authored Oct 18, 2023
1 parent 238912e commit 4d8c43f
Show file tree
Hide file tree
Showing 6 changed files with 507 additions and 36 deletions.
60 changes: 27 additions & 33 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
#!/usr/bin/env python3

import itertools
import os
import re
import sys

import google.protobuf

from collections import defaultdict, namedtuple, OrderedDict
from contextlib import contextmanager
from typing import (
Expand Down Expand Up @@ -35,7 +32,6 @@
OneofDescriptorProto,
SourceCodeInfo,
)
from google.protobuf.message import Message

from proto.rust import extensions_pb2

Expand Down Expand Up @@ -280,6 +276,8 @@ def custom_type(self) -> Text:
return self.field.options.Extensions[extensions_pb2.type]

def is_nullable(self) -> bool:
if self.oneof:
return False
if (
self.field.type in PRIMITIVE_TYPES
and self.is_proto3
Expand All @@ -301,7 +299,7 @@ def is_nullable(self) -> bool:

def is_empty_oneof_field(self) -> bool:
assert self.oneof
return self.field.type_name == ".google.protobuf.Empty"
return self.field.type_name == ".google.protobuf.Empty" and not self.is_boxed()

def can_be_packed(self) -> bool:
# Return true if incoming messages could be packed on the wire
Expand Down Expand Up @@ -367,7 +365,7 @@ def set_method(self) -> Tuple[Text, Text]:
return self.rust_type(), "v"
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
if self.is_boxed():
return "Box<%s>" % self.rust_type(), "v"
return "::std::boxed::Box<%s>" % self.rust_type(), "v"
else:
return self.rust_type(), "v"
raise AssertionError("Unexpected field type")
Expand Down Expand Up @@ -403,7 +401,7 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]:
return self.rust_type(), expr
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
if self.is_boxed():
return "Box<%s>" % self.rust_type(), expr
return "::std::boxed::Box<%s>" % self.rust_type(), expr
else:
return self.rust_type(), expr
raise AssertionError("Unexpected field type")
Expand Down Expand Up @@ -492,19 +490,18 @@ def rust_type(self) -> Text:
"Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ))
)

def __str__(self) -> str:
def storage_type(self) -> str:
rust_type = self.rust_type()

if self.is_boxed():
rust_type = "::std::boxed::Box<%s>" % rust_type

if self.is_repeated():
return "::std::vec::Vec<%s>" % rust_type
elif self.is_nullable() and self.is_boxed():
return "::std::option::Option<::std::boxed::Box<%s>>" % str(rust_type)
elif self.is_boxed():
return "::std::boxed::Box<%s>" % rust_type
rust_type = "::std::vec::Vec<%s>" % rust_type
elif self.is_nullable():
return "::std::option::Option<%s>" % rust_type
else:
return str(rust_type)
rust_type = "::std::option::Option<%s>" % rust_type

return rust_type

def oneof_field_match(self, var: Text) -> Text:
if self.is_empty_oneof_field():
Expand Down Expand Up @@ -584,6 +581,11 @@ def field_iter(
"let %s: &%s = &::std::default::Default::default();"
% (var, typ.rust_type())
)
elif typ.is_boxed():
ctx.write(
"let %(var)s: &%(typ)s = &**%(var)s;"
% dict(var=var, typ=typ.rust_type())
)
yield
elif (
field.type == FieldDescriptorProto.TYPE_MESSAGE
Expand Down Expand Up @@ -612,7 +614,7 @@ def field_iter(
with block(
ctx,
"if self.%s != <%s as ::std::default::Default>::default()"
% (field.name, typ),
% (field.name, typ.storage_type()),
):
if typ.is_boxed():
ctx.write("let %s = &*self.%s;" % (var, field.name))
Expand Down Expand Up @@ -937,7 +939,7 @@ def gen_msg(
if typ.oneof:
oneof_fields[typ.oneof.name].append(field)
else:
self.write("pub %s: %s," % (field.name, typ))
self.write("pub %s: %s," % (field.name, typ.storage_type()))

for oneof in oneof_decls:
if oneof_nullable(oneof):
Expand All @@ -959,7 +961,7 @@ def gen_msg(
with block(self, "pub enum " + oneof_msg_name(name, oneof)):
for oneof_field in oneof_fields[oneof.name]:
typ = self.rust_type(msg_type, oneof_field)
self.write("%s," % typ.oneof_field_match(typ.rust_type()))
self.write("%s," % typ.oneof_field_match(typ.storage_type()))

if not self.is_proto3:
with block(self, "impl " + name):
Expand Down Expand Up @@ -1461,6 +1463,8 @@ def gen_msg(
typ.oneof.name,
),
):
if typ.is_boxed():
self.write("let val = &mut **val;")
self.write(
"return ::pb_jelly::reflection::FieldMut::Value(val);"
)
Expand Down Expand Up @@ -1683,6 +1687,9 @@ def calc_impls(
msg_impls_eq = False
if not self.impls_by_msg[field_fq_msg].Copy:
msg_impls_copy = False

if rust_type.is_boxed():
msg_impls_copy = False
else:
raise RuntimeError(
"Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ))
Expand Down Expand Up @@ -1744,20 +1751,7 @@ def _set_boxed_if_recursive(
visited, looking_for, self.find_msg(field.type_name)
)
if need_box or field.type_name == looking_for:
# We only box normal fields, not oneof variants
#
# TODO: We are restricting this case because the codegen
# can't currently box oneof variants. This means there are
# cases won't work with the Rust codegen. Specifically, if
# you have a oneof variant that directly references the
# containing message or is co-recursive to another message,
# the codegen won't box the variant and the resulting code
# won't compile.
if not (
field.HasField("oneof_index")
and pt.typ.oneof_decl[field.oneof_index]
):
field.options.Extensions[extensions_pb2.box_it] = True
field.options.Extensions[extensions_pb2.box_it] = True
any_field_boxed = True
return any_field_boxed

Expand Down
4 changes: 2 additions & 2 deletions pb-test/gen/pb-jelly/proto_pbtest/src/pbtest2.rs.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3085,10 +3085,10 @@ impl TestMessage {
pub fn has_optional_foreign_message_boxed(&self) -> bool {
self.optional_foreign_message_boxed.is_some()
}
pub fn set_optional_foreign_message_boxed(&mut self, v: Box<ForeignMessage>) {
pub fn set_optional_foreign_message_boxed(&mut self, v: ::std::boxed::Box<ForeignMessage>) {
self.optional_foreign_message_boxed = Some(v);
}
pub fn take_optional_foreign_message_boxed(&mut self) -> Box<ForeignMessage> {
pub fn take_optional_foreign_message_boxed(&mut self) -> ::std::boxed::Box<ForeignMessage> {
self.optional_foreign_message_boxed.take().unwrap_or_default()
}
pub fn get_optional_foreign_message_boxed(&self) -> &ForeignMessage {
Expand Down
Loading

0 comments on commit 4d8c43f

Please sign in to comment.