From 9f253b37a0cdedb8cc7d595e1f6b1eb91cfeeb77 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Tue, 17 Oct 2023 13:52:16 -0700 Subject: [PATCH] Get rid of maybe_boxed, rename __str__ to storage_type, fix a bug --- pb-jelly-gen/codegen/codegen.py | 41 ++++--- .../proto_pbtest/src/pbtest3.rs.expected | 104 ++++++++++++++++++ pb-test/proto/packages/pbtest/pbtest3.proto | 4 + 3 files changed, 134 insertions(+), 15 deletions(-) diff --git a/pb-jelly-gen/codegen/codegen.py b/pb-jelly-gen/codegen/codegen.py index 762c1d7..a07965b 100755 --- a/pb-jelly-gen/codegen/codegen.py +++ b/pb-jelly-gen/codegen/codegen.py @@ -276,6 +276,8 @@ def custom_type(self) -> Text: return self.field.options.Extensions[extensions_pb2.type] def is_nullable(self) -> bool: + if self.oneof: + return False if ( self.field.type in PRIMITIVE_TYPES and self.is_proto3 @@ -362,7 +364,10 @@ def set_method(self) -> Tuple[Text, Text]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), "v" elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - return self.rust_type(maybe_boxed=True), "v" + if self.is_boxed(): + return "::std::boxed::Box<%s>" % self.rust_type(), "v" + else: + return self.rust_type(), "v" raise AssertionError("Unexpected field type") def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: @@ -395,7 +400,10 @@ def take_method(self) -> Tuple[Optional[Text], Optional[Text]]: elif self.field.type == FieldDescriptorProto.TYPE_ENUM: return self.rust_type(), expr elif self.field.type == FieldDescriptorProto.TYPE_MESSAGE: - return self.rust_type(maybe_boxed=True), expr + if self.is_boxed(): + return "::std::boxed::Box<%s>" % self.rust_type(), expr + else: + return self.rust_type(), expr raise AssertionError("Unexpected field type") def get_method(self) -> Tuple[Text, Text]: @@ -446,10 +454,7 @@ def get_method(self) -> Tuple[Text, Text]: ) raise AssertionError("Unexpected field type") - def rust_type(self, maybe_boxed: bool = False) -> Text: - if maybe_boxed and self.is_boxed(): - return "::std::boxed::Box<%s>" % self.rust_type(maybe_boxed=False) - + def rust_type(self) -> Text: typ = self.field.type if self.has_custom_type(): @@ -485,15 +490,18 @@ def rust_type(self, maybe_boxed: bool = False) -> Text: "Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ)) ) - def __str__(self) -> str: - rust_type = self.rust_type(maybe_boxed=True) + def storage_type(self) -> str: + rust_type = self.rust_type() + + if self.is_boxed(): + rust_type = "::std::boxed::Box<%s>" % rust_type if self.is_repeated(): - return "::std::vec::Vec<%s>" % rust_type + rust_type = "::std::vec::Vec<%s>" % rust_type elif self.is_nullable(): - return "::std::option::Option<%s>" % rust_type - else: - return rust_type + rust_type = "::std::option::Option<%s>" % rust_type + + return rust_type def oneof_field_match(self, var: Text) -> Text: if self.is_empty_oneof_field(): @@ -606,7 +614,7 @@ def field_iter( with block( ctx, "if self.%s != <%s as ::std::default::Default>::default()" - % (field.name, typ), + % (field.name, typ.storage_type()), ): if typ.is_boxed(): ctx.write("let %s = &*self.%s;" % (var, field.name)) @@ -931,7 +939,7 @@ def gen_msg( if typ.oneof: oneof_fields[typ.oneof.name].append(field) else: - self.write("pub %s: %s," % (field.name, typ)) + self.write("pub %s: %s," % (field.name, typ.storage_type())) for oneof in oneof_decls: if oneof_nullable(oneof): @@ -954,7 +962,7 @@ def gen_msg( for oneof_field in oneof_fields[oneof.name]: typ = self.rust_type(msg_type, oneof_field) self.write( - "%s," % typ.oneof_field_match(typ.rust_type(maybe_boxed=True)) + "%s," % typ.oneof_field_match(typ.storage_type()) ) if not self.is_proto3: @@ -1681,6 +1689,9 @@ def calc_impls( msg_impls_eq = False if not self.impls_by_msg[field_fq_msg].Copy: msg_impls_copy = False + + if rust_type.is_boxed(): + msg_impls_copy = False else: raise RuntimeError( "Unsupported type: {!r}".format(FieldDescriptorProto.Type.Name(typ)) diff --git a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected index 54ba73f..571b2fe 100644 --- a/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected +++ b/pb-test/gen/pb-jelly/proto_pbtest/src/pbtest3.rs.expected @@ -5817,6 +5817,110 @@ impl ::pb_jelly::Reflection for TestMessage3_NestedMessage_Dir { } } +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub struct TestBoxedNonnullable { + pub field: ::std::boxed::Box, +} +impl ::std::default::Default for TestBoxedNonnullable { + fn default() -> Self { + TestBoxedNonnullable { + field: ::std::default::Default::default(), + } + } +} +lazy_static! { + pub static ref TestBoxedNonnullable_default: TestBoxedNonnullable = TestBoxedNonnullable::default(); +} +impl ::pb_jelly::Message for TestBoxedNonnullable { + fn descriptor(&self) -> ::std::option::Option<::pb_jelly::MessageDescriptor> { + Some(::pb_jelly::MessageDescriptor { + name: "TestBoxedNonnullable", + full_name: "pbtest.TestBoxedNonnullable", + fields: &[ + ::pb_jelly::FieldDescriptor { + name: "field", + full_name: "pbtest.TestBoxedNonnullable.field", + index: 0, + number: 1, + typ: ::pb_jelly::wire_format::Type::LengthDelimited, + label: ::pb_jelly::Label::Optional, + oneof_index: None, + }, + ], + oneofs: &[ + ], + }) + } + fn compute_size(&self) -> usize { + let mut size = 0; + let mut field_size = 0; + { + let val = &*self.field; + let l = ::pb_jelly::Message::compute_size(val); + field_size += ::pb_jelly::wire_format::serialized_length(1); + field_size += ::pb_jelly::varint::serialized_length(l as u64); + field_size += l; + } + size += field_size; + size + } + fn compute_grpc_slices_size(&self) -> usize { + let mut size = 0; + { + let val = &*self.field; + size += ::pb_jelly::Message::compute_grpc_slices_size(val); + } + size + } + fn serialize(&self, w: &mut W) -> ::std::io::Result<()> { + { + let val = &*self.field; + ::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?; + let l = ::pb_jelly::Message::compute_size(val); + ::pb_jelly::varint::write(l as u64, w)?; + ::pb_jelly::Message::serialize(val, w)?; + } + Ok(()) + } + fn deserialize(&mut self, mut buf: &mut B) -> ::std::io::Result<()> { + while let Some((field_number, typ)) = ::pb_jelly::wire_format::read(&mut buf)? { + match field_number { + 1 => { + ::pb_jelly::ensure_wire_format(typ, ::pb_jelly::wire_format::Type::LengthDelimited, "TestBoxedNonnullable", 1)?; + let len = ::pb_jelly::varint::ensure_read(&mut buf)?; + let mut next = ::pb_jelly::ensure_split(buf, len as usize)?; + let mut val: ForeignMessage3 = ::std::default::Default::default(); + ::pb_jelly::Message::deserialize(&mut val, &mut next)?; + self.field = Box::new(val); + } + _ => { + ::pb_jelly::skip(typ, &mut buf)?; + } + } + } + Ok(()) + } +} +impl ::pb_jelly::Reflection for TestBoxedNonnullable { + fn which_one_of(&self, oneof_name: &str) -> ::std::option::Option<&'static str> { + match oneof_name { + _ => { + panic!("unknown oneof name given"); + } + } + } + fn get_field_mut(&mut self, field_name: &str) -> ::pb_jelly::reflection::FieldMut<'_> { + match field_name { + "field" => { + ::pb_jelly::reflection::FieldMut::Value(self.field.as_mut()) + } + _ => { + panic!("unknown field name given") + } + } + } +} + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct TestMessage3NonNullableOneof { pub other_field: u64, diff --git a/pb-test/proto/packages/pbtest/pbtest3.proto b/pb-test/proto/packages/pbtest/pbtest3.proto index b8099d9..a05b619 100644 --- a/pb-test/proto/packages/pbtest/pbtest3.proto +++ b/pb-test/proto/packages/pbtest/pbtest3.proto @@ -208,6 +208,10 @@ message TestMessage3 { repeated bytes zero_or_fixed_length_repeated = 79 [(rust.type)="Option<[u8; 4]>"]; } +message TestBoxedNonnullable { + ForeignMessage3 field = 1 [(rust.box_it)=true, (rust.nullable_field)=false]; +} + message TestMessage3NonNullableOneof { oneof non_nullable_oneof { option (rust.nullable) = false;