From e2e114bfead6dc854d90d3fd2aeddfeb3056bd5d Mon Sep 17 00:00:00 2001 From: "evgeny.bovykin" Date: Mon, 23 Sep 2024 17:21:30 +0200 Subject: [PATCH] Support splitting a single item into multiple files --- pilota-build/src/codegen/mod.rs | 26 +- pilota-build/src/codegen/workspace.rs | 1 + pilota-build/src/lib.rs | 13 + pilota-build/src/middle/context.rs | 4 + pilota-build/src/test/mod.rs | 100 +++++ .../thrift_with_split/wrapper_arc.rs | 20 + .../thrift_with_split/wrapper_arc.thrift | 13 + .../thrift_with_split/wrapper_arc/A.rs | 119 ++++++ .../thrift_with_split/wrapper_arc/TEST.rs | 345 ++++++++++++++++++ .../wrapper_arc/TestService.rs | 2 + .../wrapper_arc/TestServiceTestArgsRecv.rs | 152 ++++++++ .../wrapper_arc/TestServiceTestArgsSend.rs | 154 ++++++++ .../wrapper_arc/TestServiceTestResultRecv.rs | 140 +++++++ .../wrapper_arc/TestServiceTestResultSend.rs | 142 +++++++ 14 files changed, 1229 insertions(+), 2 deletions(-) create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc.thrift create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/A.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TEST.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TestService.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsRecv.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsSend.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultRecv.rs create mode 100644 pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultSend.rs diff --git a/pilota-build/src/codegen/mod.rs b/pilota-build/src/codegen/mod.rs index 0e67582e..a5e668a1 100644 --- a/pilota-build/src/codegen/mod.rs +++ b/pilota-build/src/codegen/mod.rs @@ -28,6 +28,7 @@ use crate::{ symbol::{DefId, EnumRepr, FileId}, Context, Symbol, }; +use crate::rir::NodeKind; pub(crate) mod pkg_tree; pub mod toml; @@ -447,7 +448,7 @@ where ws.write_crates() } - pub fn write_items(&self, stream: &mut String, items: impl Iterator) + pub fn write_items(&self, stream: &mut String, items: impl Iterator, base_dir: &Path) where B: Send, { @@ -474,7 +475,26 @@ where let _enter = span.enter(); let mut dup = AHashMap::default(); for def_id in def_ids.iter() { - this.write_item(&mut stream, *def_id, &mut dup) + if this.split { + let mut item_stream = String::new(); + let node = this.db.node(def_id.def_id).unwrap(); + let file_name = format!("{}.rs", node.name()); + this.write_item(&mut item_stream, *def_id, &mut dup); + + let full_path = base_dir.join(file_name.clone()); + std::fs::create_dir_all(base_dir).unwrap(); + let mut file = std::io::BufWriter::new(std::fs::File::create(full_path.clone()).unwrap()); + file.write_all(item_stream.as_bytes()).unwrap(); + file.flush().unwrap(); + fmt_file(full_path); + + let base_dir_local_path = base_dir.iter().last().unwrap().to_str().unwrap(); + + stream.push_str(format!("\ninclude!(\"{}/{}\");\n", base_dir_local_path, file_name).as_str()); + } else { + this.write_item(&mut stream, *def_id, &mut dup) + } + } }); @@ -515,10 +535,12 @@ where } pub fn write_file(self, ns_name: Symbol, file_name: impl AsRef) { + let base_dir = file_name.as_ref().parent().unwrap(); let mut stream = String::default(); self.write_items( &mut stream, self.codegen_items.iter().map(|def_id| (*def_id).into()), + base_dir.join(ns_name.to_string()).as_path() ); stream = format! {r#"pub mod {ns_name} {{ diff --git a/pilota-build/src/codegen/workspace.rs b/pilota-build/src/codegen/workspace.rs index 2f76b331..72f929a9 100644 --- a/pilota-build/src/codegen/workspace.rs +++ b/pilota-build/src/codegen/workspace.rs @@ -246,6 +246,7 @@ where def_id, kind: super::CodegenKind::RePub, })), + base_dir.as_ref(), ); if let Some(main_mod_path) = info.main_mod_path { gen_rs_stream.push_str(&format!( diff --git a/pilota-build/src/lib.rs b/pilota-build/src/lib.rs index 2d015fec..963e213d 100644 --- a/pilota-build/src/lib.rs +++ b/pilota-build/src/lib.rs @@ -78,6 +78,7 @@ pub struct Builder { parser: P, plugins: Vec>, ignore_unused: bool, + split: bool, touches: Vec<(std::path::PathBuf, Vec)>, change_case: bool, keep_unknown_fields: Vec, @@ -103,6 +104,7 @@ impl Builder { dedups: Vec::default(), special_namings: Vec::default(), common_crate_name: "common".into(), + split: false, } } } @@ -124,6 +126,7 @@ impl Builder { dedups: Vec::default(), special_namings: Vec::default(), common_crate_name: "common".into(), + split: false, } } } @@ -152,6 +155,7 @@ impl Builder { dedups: self.dedups, special_namings: self.special_namings, common_crate_name: self.common_crate_name, + split: self.split, } } @@ -161,6 +165,11 @@ impl Builder { self } + pub fn with_split(mut self) -> Self { + self.split = true; + self + } + pub fn change_case(mut self, change_case: bool) -> Self { self.change_case = change_case; self @@ -266,6 +275,7 @@ where dedups: Vec, special_namings: Vec, common_crate_name: FastStr, + split: bool, ) -> Context { let mut db = RootDatabase::default(); parser.inputs(services.iter().map(|s| &s.path)); @@ -341,6 +351,7 @@ where dedups, special_namings, common_crate_name, + split, ) } @@ -359,6 +370,7 @@ where self.dedups, self.special_namings, self.common_crate_name, + self.split, ); cx.exec_plugin(BoxedPlugin); @@ -441,6 +453,7 @@ where self.dedups, self.special_namings, self.common_crate_name, + self.split, ); std::thread::scope(|_scope| { diff --git a/pilota-build/src/middle/context.rs b/pilota-build/src/middle/context.rs index 69019884..a7de9c92 100644 --- a/pilota-build/src/middle/context.rs +++ b/pilota-build/src/middle/context.rs @@ -67,6 +67,7 @@ pub struct Context { pub(crate) codegen_items: Arc<[DefId]>, pub(crate) path_resolver: Arc, pub(crate) mode: Arc, + pub(crate) split: bool, pub(crate) keep_unknown_fields: Arc>, pub location_map: Arc>, pub entry_map: Arc>>, @@ -86,6 +87,7 @@ impl Clone for Context { codegen_items: self.codegen_items.clone(), path_resolver: self.path_resolver.clone(), mode: self.mode.clone(), + split: self.split, services: self.services.clone(), keep_unknown_fields: self.keep_unknown_fields.clone(), location_map: self.location_map.clone(), @@ -327,6 +329,7 @@ impl ContextBuilder { dedups: Vec, special_namings: Vec, common_crate_name: FastStr, + split: bool, ) -> Context { SPECIAL_NAMINGS.get_or_init(|| special_namings); let mut cx = Context { @@ -341,6 +344,7 @@ impl ContextBuilder { Mode::SingleFile { .. } => Arc::new(DefaultPathResolver), }, mode: Arc::new(self.mode), + split, keep_unknown_fields: Arc::new(self.keep_unknown_fields), location_map: Arc::new(self.location_map), entry_map: Arc::new(self.entry_map), diff --git a/pilota-build/src/test/mod.rs b/pilota-build/src/test/mod.rs index b354e0c6..c3051219 100644 --- a/pilota-build/src/test/mod.rs +++ b/pilota-build/src/test/mod.rs @@ -1,5 +1,6 @@ #![cfg(test)] +use std::fs; use std::path::Path; use tempfile::tempdir; @@ -19,6 +20,36 @@ fn diff_file(old: impl AsRef, new: impl AsRef) { } } +fn diff_dir(old: impl AsRef, new: impl AsRef) { + let old_files: Vec<_> = fs::read_dir(old.as_ref()) + .unwrap() + .map(|res| res.unwrap().path()) + .collect(); + let new_files: Vec<_> = fs::read_dir(new.as_ref()) + .unwrap() + .map(|res| res.unwrap().path()) + .collect(); + + if old_files.len() != new_files.len() { + panic!( + "Number of files are different between {} and {}: {} vs {}", + old.as_ref().to_str().unwrap(), + new.as_ref().to_str().unwrap(), + old_files.len(), + new_files.len() + ); + } + + for old_file in old_files { + let file_name = old_file.file_name().unwrap(); + let corresponding_new_file = new.as_ref().join(file_name); + if !corresponding_new_file.exists() { + panic!("File {:?} does not exist in the new directory", file_name); + } + diff_file(old_file, corresponding_new_file); + } +} + fn test_protobuf(source: impl AsRef, target: impl AsRef) { test_with_builder(source, target, |source, target| { crate::Builder::protobuf() @@ -55,6 +86,35 @@ fn test_with_builder( } } +fn test_with_split_builder( + source: impl AsRef, + target: impl AsRef, + gen_dir: impl AsRef, + f: F, +) { + if std::env::var("UPDATE_TEST_DATA").as_deref() == Ok("1") { + f(source.as_ref(), target.as_ref()); + } else { + let dir = tempdir().unwrap(); + let path = dir.path().join( + target + .as_ref() + .file_name() + .and_then(|s| s.to_str()) + .unwrap() + ); + let mut base_dir_tmp = path.clone(); + base_dir_tmp.pop(); + base_dir_tmp.push(path.file_stem().unwrap()); + println!("{path:?}"); + + f(source.as_ref(), &path); + diff_file(target, path); + + diff_dir(gen_dir, base_dir_tmp); + } +} + fn test_thrift(source: impl AsRef, target: impl AsRef) { test_with_builder(source, target, |source, target| { crate::Builder::thrift() @@ -66,6 +126,19 @@ fn test_thrift(source: impl AsRef, target: impl AsRef) { }); } +fn test_thrift_with_split(source: impl AsRef, target: impl AsRef, gen_dir: impl AsRef) { + test_with_split_builder(source, target, gen_dir, |source, target| { + crate::Builder::thrift() + .ignore_unused(false) + .with_split() + .compile_with_config( + vec![IdlService::from_path(source.to_owned())], + crate::Output::File(target.into()), + ) + }); +} + + fn test_plugin_thrift(source: impl AsRef, target: impl AsRef) { test_with_builder(source, target, |source, target| { crate::Builder::thrift() @@ -111,6 +184,33 @@ fn test_thrift_gen() { }); } +#[test] +fn test_thrift_gen_with_split() { + let test_data_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test_data") + .join("thrift_with_split"); + + test_data_dir.read_dir().unwrap().for_each(|f| { + let f = f.unwrap(); + + let path = f.path(); + + if let Some(ext) = path.extension() { + if ext == "thrift" { + let mut rs_path = path.clone(); + rs_path.set_extension("rs"); + + let mut gen_dir = path.clone(); + gen_dir.pop(); + gen_dir.push(rs_path.file_stem().unwrap()); + + test_thrift_with_split(path, rs_path, gen_dir.as_path()); + } + } + }); +} + + #[test] fn test_protobuf_gen() { let test_data_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc.rs new file mode 100644 index 00000000..34fdaaa3 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc.rs @@ -0,0 +1,20 @@ +pub mod wrapper_arc { + #![allow(warnings, clippy::all)] + + pub mod wrapper_arc { + + include!("wrapper_arc/A.rs"); + + include!("wrapper_arc/TestService.rs"); + + include!("wrapper_arc/TestServiceTestResultRecv.rs"); + + include!("wrapper_arc/TestServiceTestArgsRecv.rs"); + + include!("wrapper_arc/TestServiceTestResultSend.rs"); + + include!("wrapper_arc/TEST.rs"); + + include!("wrapper_arc/TestServiceTestArgsSend.rs"); + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc.thrift b/pilota-build/test_data/thrift_with_split/wrapper_arc.thrift new file mode 100644 index 00000000..6bb95526 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc.thrift @@ -0,0 +1,13 @@ +struct A { + +} + +struct TEST { + 1: required string ID, + 2: required list> Name2(pilota.rust_wrapper_arc="true"), + 3: required map> Name3(pilota.rust_wrapper_arc="true"), +} + +service TestService { + TEST(pilota.rust_wrapper_arc="true") test(1: TEST req(pilota.rust_wrapper_arc="true")); +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/A.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/A.rs new file mode 100644 index 00000000..b0c48761 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/A.rs @@ -0,0 +1,119 @@ +#[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] +pub struct A {} +impl ::pilota::thrift::Message for A { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { name: "A" }; + + __protocol.write_struct_begin(&struct_ident)?; + + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `A` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let data = Self {}; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `A` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let data = Self {}; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "A" }) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TEST.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TEST.rs new file mode 100644 index 00000000..2b7c05e9 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TEST.rs @@ -0,0 +1,345 @@ +#[derive(Debug, Default, Clone, PartialEq)] +pub struct Test { + pub id: ::pilota::FastStr, + + pub name2: ::std::vec::Vec<::std::vec::Vec<::std::sync::Arc>>, + + pub name3: ::pilota::AHashMap>>, +} +impl ::pilota::thrift::Message for Test { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { name: "TEST" }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_faststr_field(1, (&self.id).clone())?; + __protocol.write_list_field( + 2, + ::pilota::thrift::TType::List, + &&self.name2, + |__protocol, val| { + __protocol.write_list( + ::pilota::thrift::TType::Struct, + &val, + |__protocol, val| { + __protocol.write_struct(val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + )?; + __protocol.write_map_field( + 3, + ::pilota::thrift::TType::I32, + ::pilota::thrift::TType::List, + &&self.name3, + |__protocol, key| { + __protocol.write_i32(*key)?; + ::std::result::Result::Ok(()) + }, + |__protocol, val| { + __protocol.write_list( + ::pilota::thrift::TType::Struct, + &val, + |__protocol, val| { + __protocol.write_struct(val)?; + ::std::result::Result::Ok(()) + }, + )?; + ::std::result::Result::Ok(()) + }, + )?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + let mut var_2 = None; + let mut var_3 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Binary => { + var_1 = Some(__protocol.read_faststr()?); + } + Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => { + var_2 = Some(unsafe { + let list_ident = __protocol.read_list_begin()?; + let mut val: Vec<::std::vec::Vec<::std::sync::Arc>> = + Vec::with_capacity(list_ident.size); + for i in 0..list_ident.size { + val.as_mut_ptr().offset(i as isize).write(unsafe { + let list_ident = __protocol.read_list_begin()?; + let mut val: Vec<::std::sync::Arc> = + Vec::with_capacity(list_ident.size); + for i in 0..list_ident.size { + val.as_mut_ptr().offset(i as isize).write( + ::std::sync::Arc::new( + ::pilota::thrift::Message::decode(__protocol)?, + ), + ); + } + val.set_len(list_ident.size); + __protocol.read_list_end()?; + val + }); + } + val.set_len(list_ident.size); + __protocol.read_list_end()?; + val + }); + } + Some(3) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_3 = Some({ + let map_ident = __protocol.read_map_begin()?; + let mut val = ::pilota::AHashMap::with_capacity(map_ident.size); + for _ in 0..map_ident.size { + val.insert(__protocol.read_i32()?, unsafe { + let list_ident = __protocol.read_list_begin()?; + let mut val: Vec<::std::sync::Arc> = + Vec::with_capacity(list_ident.size); + for i in 0..list_ident.size { + val.as_mut_ptr().offset(i as isize).write( + ::std::sync::Arc::new( + ::pilota::thrift::Message::decode(__protocol)?, + ), + ); + } + val.set_len(list_ident.size); + __protocol.read_list_end()?; + val + }); + } + __protocol.read_map_end()?; + val + }); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TEST` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field id is required".to_string(), + )); + }; + let Some(var_2) = var_2 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field name2 is required".to_string(), + )); + }; + let Some(var_3) = var_3 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field name3 is required".to_string(), + )); + }; + + let data = Self { + id: var_1, + name2: var_2, + name3: var_3, + }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + let mut var_2 = None; + let mut var_3 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Binary => { + var_1 = Some(__protocol.read_faststr().await?); + } + Some(2) if field_ident.field_type == ::pilota::thrift::TType::List => { + var_2 = Some({ + let list_ident = __protocol.read_list_begin().await?; + let mut val = Vec::with_capacity(list_ident.size); + for _ in 0..list_ident.size { + val.push({ + let list_ident = __protocol.read_list_begin().await?; + let mut val = Vec::with_capacity(list_ident.size); + for _ in 0..list_ident.size { + val.push(::std::sync::Arc::new( + ::decode_async( + __protocol, + ) + .await?, + )); + } + __protocol.read_list_end().await?; + val + }); + } + __protocol.read_list_end().await?; + val + }); + } + Some(3) if field_ident.field_type == ::pilota::thrift::TType::Map => { + var_3 = Some({ + let map_ident = __protocol.read_map_begin().await?; + let mut val = ::pilota::AHashMap::with_capacity(map_ident.size); + for _ in 0..map_ident.size { + val.insert(__protocol.read_i32().await?, { + let list_ident = __protocol.read_list_begin().await?; + let mut val = Vec::with_capacity(list_ident.size); + for _ in 0..list_ident.size { + val.push(::std::sync::Arc::new( + ::decode_async( + __protocol, + ) + .await?, + )); + } + __protocol.read_list_end().await?; + val + }); + } + __protocol.read_map_end().await?; + val + }); + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TEST` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field id is required".to_string(), + )); + }; + let Some(var_2) = var_2 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field name2 is required".to_string(), + )); + }; + let Some(var_3) = var_3 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field name3 is required".to_string(), + )); + }; + + let data = Self { + id: var_1, + name2: var_2, + name3: var_3, + }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { name: "TEST" }) + + __protocol.faststr_field_len(Some(1), &self.id) + + __protocol.list_field_len( + Some(2), + ::pilota::thrift::TType::List, + &self.name2, + |__protocol, el| { + __protocol.list_len(::pilota::thrift::TType::Struct, el, |__protocol, el| { + __protocol.struct_len(el) + }) + }, + ) + + __protocol.map_field_len( + Some(3), + ::pilota::thrift::TType::I32, + ::pilota::thrift::TType::List, + &self.name3, + |__protocol, key| __protocol.i32_len(*key), + |__protocol, val| { + __protocol.list_len(::pilota::thrift::TType::Struct, val, |__protocol, el| { + __protocol.struct_len(el) + }) + }, + ) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TestService.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestService.rs new file mode 100644 index 00000000..bcfc2708 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestService.rs @@ -0,0 +1,2 @@ + +pub trait TestService {} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsRecv.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsRecv.rs new file mode 100644 index 00000000..71a1d9b1 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsRecv.rs @@ -0,0 +1,152 @@ +#[derive(Debug, Default, Clone, PartialEq)] +pub struct TestServiceTestArgsRecv { + pub req: Test, +} +impl ::pilota::thrift::Message for TestServiceTestArgsRecv { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { + name: "TestServiceTestArgsRecv", + }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_struct_field(1, &self.req, ::pilota::thrift::TType::Struct)?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Struct => { + var_1 = Some(::pilota::thrift::Message::decode(__protocol)?); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TestServiceTestArgsRecv` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field req is required".to_string(), + )); + }; + + let data = Self { req: var_1 }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Struct => { + var_1 = Some( + ::decode_async(__protocol) + .await?, + ); + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TestServiceTestArgsRecv` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field req is required".to_string(), + )); + }; + + let data = Self { req: var_1 }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestArgsRecv", + }) + __protocol.struct_field_len(Some(1), &self.req) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsSend.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsSend.rs new file mode 100644 index 00000000..12e37b2a --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestArgsSend.rs @@ -0,0 +1,154 @@ +#[derive(Debug, Default, Clone, PartialEq)] +pub struct TestServiceTestArgsSend { + pub req: ::std::sync::Arc, +} +impl ::pilota::thrift::Message for TestServiceTestArgsSend { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { + name: "TestServiceTestArgsSend", + }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_struct_field(1, &self.req, ::pilota::thrift::TType::Struct)?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Struct => { + var_1 = Some(::std::sync::Arc::new(::pilota::thrift::Message::decode( + __protocol, + )?)); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TestServiceTestArgsSend` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field req is required".to_string(), + )); + }; + + let data = Self { req: var_1 }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) if field_ident.field_type == ::pilota::thrift::TType::Struct => { + var_1 = Some(::std::sync::Arc::new( + ::decode_async(__protocol) + .await?, + )); + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!( + "decode struct `TestServiceTestArgsSend` field(#{}) failed, caused by: ", + field_id + )); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field req is required".to_string(), + )); + }; + + let data = Self { req: var_1 }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestArgsSend", + }) + __protocol.struct_field_len(Some(1), &self.req) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultRecv.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultRecv.rs new file mode 100644 index 00000000..eb02d1a1 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultRecv.rs @@ -0,0 +1,140 @@ +#[derive(Debug, ::pilota::derivative::Derivative)] +#[derivative(Default)] +#[derive(Clone, PartialEq)] +pub enum TestServiceTestResultRecv { + #[derivative(Default)] + Ok(Test), +} + +impl ::pilota::thrift::Message for TestServiceTestResultRecv { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestResultRecv", + })?; + match self { + TestServiceTestResultRecv::Ok(ref value) => { + __protocol.write_struct_field(0, value, ::pilota::thrift::TType::Struct)?; + } + } + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + let mut ret = None; + __protocol.read_struct_begin()?; + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = ::pilota::thrift::Message::decode(__protocol)?; + __protocol.struct_len(&field_ident); + ret = Some(TestServiceTestResultRecv::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + } + __protocol.read_field_end()?; + __protocol.read_struct_end()?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut ret = None; + __protocol.read_struct_begin().await?; + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = + ::decode_async(__protocol) + .await?; + + ret = Some(TestServiceTestResultRecv::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + } + __protocol.read_field_end().await?; + __protocol.read_struct_end().await?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestResultRecv", + }) + match self { + TestServiceTestResultRecv::Ok(ref value) => __protocol.struct_field_len(Some(0), value), + } + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +} diff --git a/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultSend.rs b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultSend.rs new file mode 100644 index 00000000..97b0f8b5 --- /dev/null +++ b/pilota-build/test_data/thrift_with_split/wrapper_arc/TestServiceTestResultSend.rs @@ -0,0 +1,142 @@ +#[derive(Debug, ::pilota::derivative::Derivative)] +#[derivative(Default)] +#[derive(Clone, PartialEq)] +pub enum TestServiceTestResultSend { + #[derivative(Default)] + Ok(::std::sync::Arc), +} + +impl ::pilota::thrift::Message for TestServiceTestResultSend { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestResultSend", + })?; + match self { + TestServiceTestResultSend::Ok(ref value) => { + __protocol.write_struct_field(0, value, ::pilota::thrift::TType::Struct)?; + } + } + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + let mut ret = None; + __protocol.read_struct_begin()?; + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = + ::std::sync::Arc::new(::pilota::thrift::Message::decode(__protocol)?); + __protocol.struct_len(&field_ident); + ret = Some(TestServiceTestResultSend::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + } + __protocol.read_field_end()?; + __protocol.read_struct_end()?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut ret = None; + __protocol.read_struct_begin().await?; + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = ::std::sync::Arc::new( + ::decode_async(__protocol) + .await?, + ); + + ret = Some(TestServiceTestResultSend::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + } + __protocol.read_field_end().await?; + __protocol.read_struct_end().await?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "TestServiceTestResultSend", + }) + match self { + TestServiceTestResultSend::Ok(ref value) => __protocol.struct_field_len(Some(0), value), + } + __protocol.field_stop_len() + + __protocol.struct_end_len() + } +}