Skip to content

Commit

Permalink
fix(pilota-build): cyclic gen code (#205)
Browse files Browse the repository at this point in the history
  • Loading branch information
PureWhiteWu authored Oct 26, 2023
1 parent b9fff2c commit aca6e7e
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pilota-build/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pilota-build"
version = "0.9.0"
version = "0.9.1"
edition = "2021"
description = "Compile thrift and protobuf idl into rust code at compile-time."
documentation = "https://docs.rs/pilota-build"
Expand Down
41 changes: 34 additions & 7 deletions pilota-build/src/codegen/thrift/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,32 @@ impl ThriftBackend {

fn codegen_impl_message(
&self,
def_id: DefId,
name: Symbol,
encode: String,
size: String,
decode: String,
decode_async: String,
) -> String {
let decode_async_fn = if self.cx().db.type_graph().is_cycled(def_id) {
format!(
r#"async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
) -> ::std::result::Result<Self,::pilota::thrift::DecodeError> {{
::std::boxed::Box::pin(async move {{
{decode_async}
}}).await
}}"#
)
} else {
format!(
r#"async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
) -> ::std::result::Result<Self,::pilota::thrift::DecodeError> {{
{decode_async}
}}"#
)
};
format! {r#"
impl ::pilota::thrift::Message for {name} {{
fn encode<T: ::pilota::thrift::TOutputProtocol>(
Expand All @@ -118,11 +138,7 @@ impl ThriftBackend {
{decode}
}}
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
) -> ::std::result::Result<Self,::pilota::thrift::DecodeError> {{
{decode_async}
}}
{decode_async_fn}
fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {{
#[allow(unused_imports)]
Expand All @@ -134,14 +150,22 @@ impl ThriftBackend {

fn codegen_impl_message_with_helper<F: Fn(&DecodeHelper) -> String>(
&self,
def_id: DefId,
name: Symbol,
encode: String,
size: String,
decode: F,
) -> String {
let decode_stream = decode(&DecodeHelper::new(false));
let decode_async_stream = decode(&DecodeHelper::new(true));
self.codegen_impl_message(name, encode, size, decode_stream, decode_async_stream)
self.codegen_impl_message(
def_id,
name,
encode,
size,
decode_stream,
decode_async_stream,
)
}

fn codegen_decode(
Expand Down Expand Up @@ -440,6 +464,7 @@ impl CodegenBackend for ThriftBackend {
encode_fields_size.push_str("self._unknown_fields.size() +");
}
stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
name,
format! {
r#"let struct_ident =::pilota::thrift::TStructIdentifier {{
Expand Down Expand Up @@ -485,6 +510,7 @@ impl CodegenBackend for ThriftBackend {
};
match e.repr {
Some(EnumRepr::I32) => stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
name.clone(),
format! {
r#"protocol.write_i32({v})?;
Expand Down Expand Up @@ -571,7 +597,7 @@ impl CodegenBackend for ThriftBackend {
&*v.name.sym == "Ok" && v.fields.len() == 1 && v.fields[0].kind == TyKind::Void
};

stream.push_str(&self.codegen_impl_message_with_helper(
stream.push_str(&self.codegen_impl_message_with_helper(def_id,
name.clone(),
format! {
r#"protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier {{
Expand Down Expand Up @@ -727,6 +753,7 @@ impl CodegenBackend for ThriftBackend {
let encode_size = self.codegen_ty_size(&t.ty, "&**self".into());

stream.push_str(&self.codegen_impl_message_with_helper(
def_id,
name.clone(),
format! {
r#"{encode}
Expand Down
13 changes: 13 additions & 0 deletions pilota-build/src/middle/type_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,17 @@ impl TypeGraph {
let b = self.node_map[&b];
has_path_connecting(&self.graph, a, b, None)
}

pub fn is_cycled(&self, a: DefId) -> bool {
let a = self.node_map[&a];
for n in self
.graph
.neighbors_directed(a, petgraph::Direction::Outgoing)
{
if has_path_connecting(&self.graph, n, a, None) {
return true;
}
}
false
}
}
86 changes: 45 additions & 41 deletions pilota-build/test_data/thrift/recursive_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,54 +83,58 @@ pub mod recursive_type {
async fn decode_async<T: ::pilota::thrift::TAsyncInputProtocol>(
protocol: &mut T,
) -> ::std::result::Result<Self, ::pilota::thrift::DecodeError> {
let mut a = None;
::std::boxed::Box::pin(async move {
let mut a = None;

let mut __pilota_decoding_field_id = None;
let mut __pilota_decoding_field_id = None;

protocol.read_struct_begin().await?;
if let Err(err) = async {
loop {
let field_ident = protocol.read_field_begin().await?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
break;
} else {
}
__pilota_decoding_field_id = field_ident.id;
match field_ident.id {
Some(1)
if field_ident.field_type == ::pilota::thrift::TType::Struct =>
{
a = Some(::std::boxed::Box::new(
<A as ::pilota::thrift::Message>::decode_async(protocol)
.await?,
));
protocol.read_struct_begin().await?;
if let Err(err) = async {
loop {
let field_ident = protocol.read_field_begin().await?;
if field_ident.field_type == ::pilota::thrift::TType::Stop {
break;
} else {
}
_ => {
protocol.skip(field_ident.field_type).await?;
__pilota_decoding_field_id = field_ident.id;
match field_ident.id {
Some(1)
if field_ident.field_type
== ::pilota::thrift::TType::Struct =>
{
a = Some(::std::boxed::Box::new(
<A as ::pilota::thrift::Message>::decode_async(protocol)
.await?,
));
}
_ => {
protocol.skip(field_ident.field_type).await?;
}
}
}

protocol.read_field_end().await?;
}
Ok::<_, ::pilota::thrift::DecodeError>(())
}
.await
{
if let Some(field_id) = __pilota_decoding_field_id {
return Err(::pilota::thrift::DecodeError::new(
::pilota::thrift::DecodeErrorKind::WithContext(::std::boxed::Box::new(
err,
)),
format!("decode struct `A` field(#{}) failed", field_id),
));
} else {
return Err(err);
protocol.read_field_end().await?;
}
Ok::<_, ::pilota::thrift::DecodeError>(())
}
};
protocol.read_struct_end().await?;
.await
{
if let Some(field_id) = __pilota_decoding_field_id {
return Err(::pilota::thrift::DecodeError::new(
::pilota::thrift::DecodeErrorKind::WithContext(
::std::boxed::Box::new(err),
),
format!("decode struct `A` field(#{}) failed", field_id),
));
} else {
return Err(err);
}
};
protocol.read_struct_end().await?;

let data = Self { a };
Ok(data)
let data = Self { a };
Ok(data)
})
.await
}

fn size<T: ::pilota::thrift::TLengthProtocol>(&self, protocol: &mut T) -> usize {
Expand Down

0 comments on commit aca6e7e

Please sign in to comment.