Skip to content

Commit

Permalink
Get rid of maybe_boxed, rename __str__ to storage_type, fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie committed Oct 17, 2023
1 parent 00fccbf commit 9f253b3
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 15 deletions.
41 changes: 26 additions & 15 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ def custom_type(self) -> Text:
return self.field.options.Extensions[extensions_pb2.type]

def is_nullable(self) -> bool:
if self.oneof:
return False
if (
self.field.type in PRIMITIVE_TYPES
and self.is_proto3
Expand Down Expand Up @@ -362,7 +364,10 @@ def set_method(self) -> Tuple[Text, Text]:
elif self.field.type == FieldDescriptorProto.TYPE_ENUM:
return self.rust_type(), "v"
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
return self.rust_type(maybe_boxed=True), "v"
if self.is_boxed():
return "::std::boxed::Box<%s>" % self.rust_type(), "v"
else:
return self.rust_type(), "v"
raise AssertionError("Unexpected field type")

def take_method(self) -> Tuple[Optional[Text], Optional[Text]]:
Expand Down Expand Up @@ -395,7 +400,10 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]:
elif self.field.type == FieldDescriptorProto.TYPE_ENUM:
return self.rust_type(), expr
elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
return self.rust_type(maybe_boxed=True), expr
if self.is_boxed():
return "::std::boxed::Box<%s>" % self.rust_type(), expr
else:
return self.rust_type(), expr
raise AssertionError("Unexpected field type")

def get_method(self) -> Tuple[Text, Text]:
Expand Down Expand Up @@ -446,10 +454,7 @@ def get_method(self) -> Tuple[Text, Text]:
)
raise AssertionError("Unexpected field type")

def rust_type(self, maybe_boxed: bool = False) -> Text:
if maybe_boxed and self.is_boxed():
return "::std::boxed::Box<%s>" % self.rust_type(maybe_boxed=False)

def rust_type(self) -> Text:
typ = self.field.type

if self.has_custom_type():
Expand Down Expand Up @@ -485,15 +490,18 @@ def rust_type(self, maybe_boxed: bool = False) -> Text:
"Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ))
)

def __str__(self) -> str:
rust_type = self.rust_type(maybe_boxed=True)
def storage_type(self) -> str:
rust_type = self.rust_type()

if self.is_boxed():
rust_type = "::std::boxed::Box<%s>" % rust_type

if self.is_repeated():
return "::std::vec::Vec<%s>" % rust_type
rust_type = "::std::vec::Vec<%s>" % rust_type
elif self.is_nullable():
return "::std::option::Option<%s>" % rust_type
else:
return rust_type
rust_type = "::std::option::Option<%s>" % rust_type

return rust_type

def oneof_field_match(self, var: Text) -> Text:
if self.is_empty_oneof_field():
Expand Down Expand Up @@ -606,7 +614,7 @@ def field_iter(
with block(
ctx,
"if self.%s != <%s as ::std::default::Default>::default()"
% (field.name, typ),
% (field.name, typ.storage_type()),
):
if typ.is_boxed():
ctx.write("let %s = &*self.%s;" % (var, field.name))
Expand Down Expand Up @@ -931,7 +939,7 @@ def gen_msg(
if typ.oneof:
oneof_fields[typ.oneof.name].append(field)
else:
self.write("pub %s: %s," % (field.name, typ))
self.write("pub %s: %s," % (field.name, typ.storage_type()))

for oneof in oneof_decls:
if oneof_nullable(oneof):
Expand All @@ -954,7 +962,7 @@ def gen_msg(
for oneof_field in oneof_fields[oneof.name]:
typ = self.rust_type(msg_type, oneof_field)
self.write(
"%s," % typ.oneof_field_match(typ.rust_type(maybe_boxed=True))
"%s," % typ.oneof_field_match(typ.storage_type())
)

if not self.is_proto3:
Expand Down Expand Up @@ -1681,6 +1689,9 @@ def calc_impls(
msg_impls_eq = False
if not self.impls_by_msg[field_fq_msg].Copy:
msg_impls_copy = False

if rust_type.is_boxed():
msg_impls_copy = False
else:
raise RuntimeError(
"Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ))
Expand Down
104 changes: 104 additions & 0 deletions pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected
Original file line number Diff line number Diff line change
Expand Up @@ -5817,6 +5817,110 @@ impl ::pb_jelly::Reflection for TestMessage3_NestedMessage_Dir {
}
}

#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TestBoxedNonnullable {
pub field: ::std::boxed::Box<ForeignMessage3>,
}
impl ::std::default::Default for TestBoxedNonnullable {
fn default() -> Self {
TestBoxedNonnullable {
field: ::std::default::Default::default(),
}
}
}
lazy_static! {
pub static ref TestBoxedNonnullable_default: TestBoxedNonnullable = TestBoxedNonnullable::default();
}
impl ::pb_jelly::Message for TestBoxedNonnullable {
fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> {
Some(::pb_jelly::MessageDescriptor {
name: "TestBoxedNonnullable",
full_name: "pbtest.TestBoxedNonnullable",
fields: &[
::pb_jelly::FieldDescriptor {
name: "field",
full_name: "pbtest.TestBoxedNonnullable.field",
index: 0,
number: 1,
typ: ::pb_jelly::wire_format::Type::LengthDelimited,
label: ::pb_jelly::Label::Optional,
oneof_index: None,
},
],
oneofs: &[
],
})
}
fn compute_size(&self) -> usize {
let mut size = 0;
let mut field_size = 0;
{
let val = &*self.field;
let l = ::pb_jelly::Message::compute_size(val);
field_size += ::pb_jelly::wire_format::serialized_length(1);
field_size += ::pb_jelly::varint::serialized_length(l as u64);
field_size += l;
}
size += field_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
{
let val = &*self.field;
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
{
let val = &*self.field;
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
let l = ::pb_jelly::Message::compute_size(val);
::pb_jelly::varint::write(l as u64, w)?;
::pb_jelly::Message::serialize(val, w)?;
}
Ok(())
}
fn deserialize<B: ::pb_jelly::PbBufferReader>(&mut self, mut buf: &mut B) -> ::std::io::Result<()> {
while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? {
match field_number {
1 => {
::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "TestBoxedNonnullable", 1)?;
let len = ::pb_jelly::varint::ensure_read(&mut buf)?;
let mut next = ::pb_jelly::ensure_split(buf, len as usize)?;
let mut val: ForeignMessage3 = ::std::default::Default::default();
::pb_jelly::Message::deserialize(&mut val, &mut next)?;
self.field = Box::new(val);
}
_ => {
::pb_jelly::skip(typ, &mut buf)?;
}
}
}
Ok(())
}
}
impl ::pb_jelly::Reflection for TestBoxedNonnullable {
fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> {
match oneof_name {
_ => {
panic!("unknown oneof name given");
}
}
}
fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> {
match field_name {
"field" => {
::pb_jelly::reflection::FieldMut::Value(self.field.as_mut())
}
_ => {
panic!("unknown field name given")
}
}
}
}

#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TestMessage3NonNullableOneof {
pub other_field: u64,
Expand Down
4 changes: 4 additions & 0 deletions pb-test/proto/packages/pbtest/pbtest3.proto
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ message TestMessage3 {
repeated bytes zero_or_fixed_length_repeated = 79 [(rust.type)="Option<[u8; 4]>"];
}

message TestBoxedNonnullable {
ForeignMessage3 field = 1 [(rust.box_it)=true, (rust.nullable_field)=false];
}

message TestMessage3NonNullableOneof {
oneof non_nullable_oneof {
option (rust.nullable) = false;
Expand Down

0 comments on commit 9f253b3

Please sign in to comment.