Skip to content

Commit

Permalink
Only generate compute_grpc_slices_size for types that use grpc_slices
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie committed Dec 14, 2023
1 parent 1111bf7 commit 77495ba
Show file tree
Hide file tree
Showing 10 changed files with 275 additions and 1,264 deletions.
37 changes: 29 additions & 8 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,18 @@ def get_method(self) -> Tuple[Text, Text]:
)
raise AssertionError("Unexpected field type")

def may_use_grpc_slices(self) -> bool:
if (
self.has_custom_type()
or self.is_blob()
or self.is_grpc_slices()
or self.is_lazy_bytes()
):
return True
if self.field.type == FieldDescriptorProto.TYPE_MESSAGE:
return self.ctx.impls_by_msg[self.field.type_name].may_use_grpc_slices
return False

def rust_type(self) -> Text:
typ = self.field.type

Expand Down Expand Up @@ -1267,17 +1279,17 @@ def gen_msg(
else:
self.write("0")

with block(self, "fn compute_grpc_slices_size(&self) -> usize"):
if len(msg_type.field) > 0:
if impls.may_use_grpc_slices:
with block(self, "fn compute_grpc_slices_size(&self) -> usize"):
self.write("let mut size = 0;")
for field in msg_type.field:
with field_iter(self, "val", name, msg_type, field):
self.write(
"size += ::pb_jelly::Message::compute_grpc_slices_size(val);"
)
rust_type = RustType(self.ctx, self.proto_file, msg_type, field)
if rust_type.may_use_grpc_slices():
with field_iter(self, "val", name, msg_type, field):
self.write(
"size += ::pb_jelly::Message::compute_grpc_slices_size(val);"
)
self.write("size")
else:
self.write("0")

with block(
self,
Expand Down Expand Up @@ -1672,6 +1684,7 @@ def rust_name(self, other_crate: Text, other_mod_parts: List[Text]) -> Text:
class Impls(NamedTuple):
impls_eq: bool
impls_copy: bool
may_use_grpc_slices: bool


def box_recursive_fields(types: Dict[Text, ProtoType[DescriptorProto]]) -> None:
Expand Down Expand Up @@ -1739,6 +1752,7 @@ def calc_impls(
) -> None:
impls_eq = True
impls_copy = True
may_use_grpc_slices = False

for type_name in types:
msg_type = self.find(type_name)
Expand All @@ -1756,6 +1770,7 @@ def calc_impls(
self.extra_crates[crate].update(
CRATE_NAME_REGEX.findall(rust_type.custom_type())
)
may_use_grpc_slices = True

if field.type_name:
field_type = self.find(field.type_name)
Expand All @@ -1776,13 +1791,15 @@ def calc_impls(
):
(impls_eq, impls_copy) = (False, False) # Blob is not eq/copy
self.extra_crates[crate].add("blob_pb")
may_use_grpc_slices = True
# If we use a Bytes type
elif (
typ == FieldDescriptorProto.TYPE_BYTES
and field.options.Extensions[extensions_pb2.zero_copy]
):
(impls_eq, impls_copy) = (False, False)
self.extra_crates[crate].add("bytes")
may_use_grpc_slices = True
elif typ in PRIMITIVE_TYPES:
if not PRIMITIVE_TYPES[typ][1]:
impls_eq = False
Expand All @@ -1808,6 +1825,9 @@ def calc_impls(
field_impls = self.impls_by_msg[field.type_name]
impls_eq = impls_eq and field_impls.impls_eq
impls_copy = impls_copy and field_impls.impls_copy
may_use_grpc_slices = (
may_use_grpc_slices or field_impls.may_use_grpc_slices
)

if rust_type.is_boxed():
impls_copy = False
Expand All @@ -1822,6 +1842,7 @@ def calc_impls(
self.impls_by_msg[type_name] = Impls(
impls_eq=impls_eq,
impls_copy=impls_copy,
may_use_grpc_slices=may_use_grpc_slices,
)

def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ impl ::pb_jelly::Message for Empty {
fn compute_size(&self) -> usize {
0
}
fn compute_grpc_slices_size(&self) -> usize {
0
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@ impl ::pb_jelly::Message for NoPackage {
size += ::pb_jelly::helpers::compute_size_scalar::<::std::string::String>(&self.field, 1, ::pb_jelly::wire_format::Type::LengthDelimited);
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if self.field != <::std::string::String as ::std::default::Default>::default() {
let val = &self.field;
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
::pb_jelly::helpers::serialize_scalar::<W, ::std::string::String>(w, &self.field, 1, ::pb_jelly::wire_format::Type::LengthDelimited)?;
Ok(())
Expand Down
85 changes: 0 additions & 85 deletions pb-test/gen/pb-jelly/proto_pbtest/src/bench.rs.expected
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,6 @@ impl ::pb_jelly::Message for VecData {
size += data_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if let Some(ref val) = self.data {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
if let Some(ref val) = self.data {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -280,13 +273,6 @@ impl ::pb_jelly::Message for StringMessage {
size += data_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if let Some(ref val) = self.data {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
if let Some(ref val) = self.data {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -391,13 +377,6 @@ impl ::pb_jelly::Message for StringMessageSSO {
size += data_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if let Some(ref val) = self.data {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
if let Some(ref val) = self.data {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -502,13 +481,6 @@ impl ::pb_jelly::Message for Cities {
size += cities_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
for val in &self.cities {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
for val in &self.cities {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -791,31 +763,6 @@ impl ::pb_jelly::Message for City {
size += state_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if let Some(ref val) = self.city {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.growth_from_2000_to_2013 {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.latitude {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.longitude {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.population {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.rank {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.state {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
if let Some(ref val) = self.city {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -994,13 +941,6 @@ impl ::pb_jelly::Message for CitiesSSO {
size += cities_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
for val in &self.cities {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
for val in &self.cities {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down Expand Up @@ -1283,31 +1223,6 @@ impl ::pb_jelly::Message for CitySSO {
size += state_size;
size
}
fn compute_grpc_slices_size(&self) -> usize {
let mut size = 0;
if let Some(ref val) = self.city {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.growth_from_2000_to_2013 {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.latitude {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.longitude {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.population {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.rank {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
if let Some(ref val) = self.state {
size += ::pb_jelly::Message::compute_grpc_slices_size(val);
}
size
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
if let Some(ref val) = self.city {
::pb_jelly::wire_format::write(1, ::pb_jelly::wire_format::Type::LengthDelimited, w)?;
Expand Down
3 changes: 0 additions & 3 deletions pb-test/gen/pb-jelly/proto_pbtest/src/mod/struct.rs.expected
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ impl ::pb_jelly::Message for Message {
fn compute_size(&self) -> usize {
0
}
fn compute_grpc_slices_size(&self) -> usize {
0
}
fn serialize<W: ::pb_jelly::PbBufferWriter>(&self, w: &mut W) -> ::std::io::Result<()> {
Ok(())
}
Expand Down
Loading

0 comments on commit 77495ba

Please sign in to comment.