Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust: make generated views implement std::default::Default #18436

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/codegen_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub trait Message: SealedInternal
pub trait MessageView<'msg>: SealedInternal
+ ViewProxy<'msg, Proxied = Self::Message>
// Read traits:
+ Debug + Serialize
+ Debug + Serialize + Default
// Thread safety:
+ Send + Sync
// Copy/Clone:
Expand Down
7 changes: 7 additions & 0 deletions rust/test/shared/serialization_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

use googletest::prelude::*;
use protobuf::prelude::*;
use protobuf::View;

use edition_unittest_rust_proto::TestAllTypes as TestAllTypesEditions;
use paste::paste;
Expand All @@ -31,6 +32,12 @@ macro_rules! generate_parameterized_serialization_test {
assert_that!(serialized.len(), eq(0));
}

#[gtest]
fn [< serialize_default_view $name_ext>]() {
let default = View::<[< $type >]>::default();
assert_that!(default.serialize().unwrap().len(), eq(0));
}

#[gtest]
fn [< serialize_deserialize_message_ $name_ext>]() {
let mut msg = [< $type >]::new();
Expand Down
25 changes: 25 additions & 0 deletions src/google/protobuf/compiler/rust/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) {
ABSL_CHECK(ctx.is_cpp());
ctx.Emit(
{{"new_thunk", ThunkName(ctx, msg, "new")},
{"default_instance_thunk", ThunkName(ctx, msg, "default_instance")},
{"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
{"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
{"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
Expand All @@ -207,6 +208,7 @@ void CppMessageExterns(Context& ctx, const Descriptor& msg) {
{"map_size_info_thunk", ThunkName(ctx, msg, "size_info")}},
R"rs(
fn $new_thunk$() -> $pbr$::RawMessage;
fn $default_instance_thunk$() -> $pbr$::RawMessage;
fn $repeated_new_thunk$() -> $pbr$::RawRepeatedField;
fn $repeated_free_thunk$(raw: $pbr$::RawRepeatedField);
fn $repeated_len_thunk$(raw: $pbr$::RawRepeatedField) -> usize;
Expand Down Expand Up @@ -733,6 +735,16 @@ void MessageProxiedInMapValue(Context& ctx, const Descriptor& msg) {
}
}

void GenerateDefaultInstanceImpl(Context& ctx, const Descriptor& msg) {
if (ctx.is_upb()) {
ctx.Emit("$pbr$::ScratchSpace::zeroed_block()");
} else {
ctx.Emit(
{{"default_instance_thunk", ThunkName(ctx, msg, "default_instance")}},
"unsafe { $default_instance_thunk$() }");
}
}

} // namespace

void GenerateRs(Context& ctx, const Descriptor& msg) {
Expand All @@ -749,6 +761,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
{"Msg::drop", [&] { MessageDrop(ctx, msg); }},
{"Msg::debug", [&] { MessageDebug(ctx, msg); }},
{"MsgMut::merge_from", [&] { MessageMutMergeFrom(ctx, msg); }},
{"default_instance_impl",
[&] { GenerateDefaultInstanceImpl(ctx, msg); }},
{"accessor_fns",
[&] {
for (int i = 0; i < msg.field_count(); ++i) {
Expand Down Expand Up @@ -948,6 +962,12 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
}
}

impl $std$::default::Default for $Msg$View<'_> {
fn default() -> $Msg$View<'static> {
$Msg$View::new($pbi$::Private, $default_instance_impl$)
}
}

#[allow(dead_code)]
impl<'msg> $Msg$View<'msg> {
#[doc(hidden)]
Expand Down Expand Up @@ -1367,6 +1387,7 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
{"Msg", RsSafeName(msg.name())},
{"QualifiedMsg", cpp::QualifiedClassName(&msg)},
{"new_thunk", ThunkName(ctx, msg, "new")},
{"default_instance_thunk", ThunkName(ctx, msg, "default_instance")},
{"repeated_new_thunk", ThunkName(ctx, msg, "repeated_new")},
{"repeated_free_thunk", ThunkName(ctx, msg, "repeated_free")},
{"repeated_len_thunk", ThunkName(ctx, msg, "repeated_len")},
Expand Down Expand Up @@ -1403,6 +1424,10 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
extern $abi$ {
void* $new_thunk$() { return new $QualifiedMsg$(); }

const google::protobuf::MessageLite* $default_instance_thunk$() {
return &$QualifiedMsg$::default_instance();
}

void* $repeated_new_thunk$() {
return new google::protobuf::RepeatedPtrField<$QualifiedMsg$>();
}
Expand Down
Loading