diff --git a/compiler/src/codegen.rs b/compiler/src/codegen.rs index 1f90d726..f9ec2b06 100644 --- a/compiler/src/codegen.rs +++ b/compiler/src/codegen.rs @@ -55,7 +55,6 @@ struct MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, } @@ -65,7 +64,6 @@ impl<'a> MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, ) -> MethodGen<'a> { @@ -73,7 +71,6 @@ impl<'a> MethodGen<'a> { proto, package_name, service_name, - service_path, root_scope, customize, } @@ -127,10 +124,6 @@ impl<'a> MethodGen<'a> { to_camel_case(self.proto.get_name()) } - fn fq_name(&self) -> String { - format!("\"{}/{}\"", self.service_path, &self.proto.get_name()) - } - fn const_method_name(&self) -> String { format!( "METHOD_{}_{}", @@ -139,36 +132,13 @@ impl<'a> MethodGen<'a> { ) } - fn write_definition(&self, w: &mut CodeWriter) { - let head = format!( - "const {}: {}<{}, {}> = {} {{", - self.const_method_name(), - fq_grpc("Method"), - self.input(), - self.output(), - fq_grpc("Method") - ); - let pb_mar = format!( - "{} {{ ser: {}, de: {} }}", - fq_grpc("Marshaller"), - fq_grpc("pb_ser"), - fq_grpc("pb_de") - ); - w.block(&head, "};", |w| { - w.field_entry("ty", &self.method_type().1); - w.field_entry("name", &self.fq_name()); - w.field_entry("req_mar", &pb_mar); - w.field_entry("resp_mar", &pb_mar); - }); - } - fn write_handler(&self, w: &mut CodeWriter) { w.block( &format!("struct {}Method {{", self.struct_name()), "}", |w| { w.write_line(&format!( - "service: Arc>,", + "service: Arc>,", self.service_name )); }, @@ -197,16 +167,55 @@ impl<'a> MethodGen<'a> { fn write_handler_impl_async(&self, w: &mut CodeWriter) { w.write_line("#[async_trait]"); - w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", - |w| { - w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<(u32, Vec)> {", "}", - |w| { - w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", + match self.method_type().0 { + MethodType::Unary => { + w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), self.root_scope.find_message(self.proto.get_input_type()).rust_name(), self.name())); + }); }); - }); + } + // only receive + MethodType::ClientStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_client_streamimg_handler!(self, ctx, inner, {});", + self.name())); + }); + }); + } + // only send + MethodType::ServerStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, mut inner: 
::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_server_streamimg_handler!(self, ctx, inner, {}, {}, {});", + proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), + self.root_scope.find_message(self.proto.get_input_type()).rust_name(), + self.name())); + }); + }); + } + // receive and send + MethodType::Duplex => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, {});", + self.name())); + }); + }); + } + } } // Method signatures @@ -222,73 +231,33 @@ impl<'a> MethodGen<'a> { fn client_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), + fq_grpc("r#async::ClientStreamSender"), self.input(), - fq_grpc("ClientCStreamReceiver"), - self.output() - ) - } - - fn client_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), - self.input(), - fq_grpc("ClientCStreamReceiver"), self.output() ) } fn server_streaming(&self, method_name: &str) -> String { format!( - "{}(&self, req: &{}) -> {}<{}<{}>>", + "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}<{}>>", method_name, self.input(), fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), - self.output() - ) - } - - fn server_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, req: &{}, opt: {}) -> {}<{}<{}>>", - method_name, - self.input(), - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), + fq_grpc("r#async::ClientStreamReceiver"), self.output() ) } fn duplex_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), + fq_grpc("r#async::ClientStream"), self.input(), - fq_grpc("ClientDuplexReceiver"), - self.output() - ) - } - - fn duplex_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), - self.input(), - fq_grpc("ClientDuplexReceiver"), self.output() ) } @@ -317,7 +286,7 @@ impl<'a> MethodGen<'a> { fn write_async_client(&self, w: &mut CodeWriter) { let method_name = self.name(); match self.method_type().0 { - // Unary + // Unary RPC MethodType::Unary => { pub_async_fn(w, &self.unary(&method_name), |w| { w.write_line(&format!("let mut cres = {}::new();", self.output())); @@ -329,28 +298,79 @@ impl<'a> MethodGen<'a> { )); }); } - - _ => {} + // Client Streaming RPC + MethodType::ClientStreaming => { + pub_async_fn(w, &self.client_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_send!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // Server Streaming RPC + MethodType::ServerStreaming => { + pub_async_fn(w, 
&self.server_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_receive!(self, ctx, req, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // Bidirectional streaming RPC + MethodType::Duplex => { + pub_async_fn(w, &self.duplex_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } }; } fn write_service(&self, w: &mut CodeWriter) { - let req_stream_type = format!("{}<{}>", fq_grpc("RequestStream"), self.input()); - let (req, req_type, _resp_type) = match self.method_type().0 { - MethodType::Unary => ("req", self.input(), "UnarySink"), - MethodType::ClientStreaming => ("stream", req_stream_type, "ClientStreamingSink"), - MethodType::ServerStreaming => ("req", self.input(), "ServerStreamingSink"), - MethodType::Duplex => ("stream", req_stream_type, "DuplexSink"), + let (_req, req_type, resp_type) = match self.method_type().0 { + MethodType::Unary => ("req", self.input(), self.output()), + MethodType::ClientStreaming => ( + "stream", + format!("::ttrpc::r#async::ServerStreamReceiver<{}>", self.input()), + self.output(), + ), + MethodType::ServerStreaming => ( + "req", + format!( + "{}, _: {}<{}>", + self.input(), + "::ttrpc::r#async::ServerStreamSender", + self.output() + ), + "()".to_string(), + ), + MethodType::Duplex => ( + "stream", + format!( + "{}<{}, {}>", + "::ttrpc::r#async::ServerStream", + self.output(), + self.input(), + ), + "()".to_string(), + ), }; let get_sig = |context_name| { format!( - "{}(&self, _ctx: &{}, _{}: {}) -> ::ttrpc::Result<{}>", + "{}(&self, _ctx: &{}, _: {}) -> ::ttrpc::Result<{}>", self.name(), fq_grpc(context_name), - req, req_type, - self.output() + resp_type, ) }; @@ -370,20 +390,38 @@ impl<'a> MethodGen<'a> { } fn write_bind(&self, w: &mut CodeWriter) { - let mut method_handler_name = "::ttrpc::MethodHandler"; + let method_handler_name = "::ttrpc::MethodHandler"; - if async_on(self.customize, "server") { - method_handler_name = "::ttrpc::r#async::MethodHandler"; - } + let s = format!( + "methods.insert(\"/{}.{}/{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as Box);", + self.package_name, + self.service_name, + self.proto.get_name(), + self.struct_name(), + method_handler_name, + ); + w.write_line(&s); + } - let s = format!("methods.insert(\"/{}.{}/{}\".to_string(), - std::boxed::Box::new({}Method{{service: service.clone()}}) as std::boxed::Box);", - self.package_name, - self.service_name, - self.proto.get_name(), - self.struct_name(), - method_handler_name, - ); + fn write_async_bind(&self, w: &mut CodeWriter) { + let s = if matches!(self.method_type().0, MethodType::Unary) { + format!( + "methods.insert(\"{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Box" + ) + } else { + format!( + "streams.insert(\"{}\".to_string(), + Arc::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Arc" + ) + }; w.write_line(&s); } } @@ -392,6 +430,7 @@ struct ServiceGen<'a> { proto: &'a ServiceDescriptorProto, methods: Vec>, customize: &'a Customize, + package_name: String, } impl<'a> ServiceGen<'a> { @@ -401,11 +440,6 @@ impl<'a> ServiceGen<'a> { root_scope: &'a RootScope, customize: &'a Customize, ) -> ServiceGen<'a> { - let service_path = if file.get_package().is_empty() { - 
format!("{}", proto.get_name()) - } else { - format!("{}.{}", file.get_package(), proto.get_name()) - }; let methods = proto .get_method() .iter() @@ -414,7 +448,6 @@ impl<'a> ServiceGen<'a> { m, file.get_package().to_string(), util::to_camel_case(proto.get_name()), - service_path.clone(), root_scope, &customize, ) @@ -425,6 +458,7 @@ impl<'a> ServiceGen<'a> { proto, methods, customize, + package_name: file.get_package().to_string(), } } @@ -432,10 +466,20 @@ impl<'a> ServiceGen<'a> { util::to_camel_case(self.proto.get_name()) } + fn service_path(&self) -> String { + format!("{}.{}", self.package_name, self.service_name()) + } + fn client_name(&self) -> String { format!("{}Client", self.service_name()) } + fn has_stream_method(&self) -> bool { + self.methods + .iter() + .any(|method| !matches!(method.method_type().0, MethodType::Unary)) + } + fn write_client(&self, w: &mut CodeWriter) { if async_on(self.customize, "client") { self.write_async_client(w) @@ -490,11 +534,9 @@ impl<'a> ServiceGen<'a> { fn write_server(&self, w: &mut CodeWriter) { let mut trait_name = self.service_name(); - let mut method_handler_name = "::ttrpc::MethodHandler"; if async_on(self.customize, "server") { w.write_line("#[async_trait]"); trait_name = format!("{}: Sync", &self.service_name()); - method_handler_name = "::ttrpc::r#async::MethodHandler"; } w.pub_trait(&trait_name.to_owned(), |w| { @@ -504,9 +546,17 @@ impl<'a> ServiceGen<'a> { }); w.write_line(""); + if async_on(self.customize, "server") { + self.write_async_server_create(w); + } else { + self.write_sync_server_create(w); + } + } + fn write_sync_server_create(&self, w: &mut CodeWriter) { + let method_handler_name = "::ttrpc::MethodHandler"; let s = format!( - "create_{}(service: Arc>) -> HashMap >", + "create_{}(service: Arc>) -> HashMap>", to_snake_case(&self.service_name()), self.service_name(), method_handler_name, @@ -523,14 +573,35 @@ impl<'a> ServiceGen<'a> { }); } - fn write_method_definitions(&self, w: &mut CodeWriter) { - for (i, method) in self.methods.iter().enumerate() { - if i != 0 { + fn write_async_server_create(&self, w: &mut CodeWriter) { + let s = format!( + "create_{}(service: Arc>) -> HashMap", + to_snake_case(&self.service_name()), + self.service_name(), + "::ttrpc::r#async::Service" + ); + + let has_stream_method = self.has_stream_method(); + w.pub_fn(&s, |w| { + w.write_line("let mut ret = HashMap::new();"); + w.write_line("let mut methods = HashMap::new();"); + if has_stream_method { + w.write_line("let mut streams = HashMap::new();"); + } else { + w.write_line("let streams = HashMap::new();"); + } + for method in &self.methods[0..self.methods.len()] { w.write_line(""); + method.write_async_bind(w); } - - method.write_definition(w); - } + w.write_line(""); + w.write_line(format!( + "ret.insert(\"{}\".to_string(), {});", + self.service_path(), + "::ttrpc::r#async::Service{ methods, streams }" + )); + w.write_line("ret"); + }); } fn write_method_handlers(&self, w: &mut CodeWriter) { diff --git a/deny.toml b/deny.toml index c7f46580..6223c48d 100644 --- a/deny.toml +++ b/deny.toml @@ -72,6 +72,7 @@ unlicensed = "deny" allow = [ "MIT", "Apache-2.0", + "Unicode-DFS-2016", #"Apache-2.0 WITH LLVM-exception", ] # List of explictly disallowed licenses diff --git a/example/Cargo.toml b/example/Cargo.toml index fe9a099f..c873be51 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -22,6 +22,7 @@ ttrpc = { path = "../", features = ["async"] } ctrlc = { version = "3.0", features = ["termination"] } tokio = { version = "1.0.1", 
features = ["signal", "time"] } async-trait = "0.1.42" +rand = "0.8.5" [[example]] @@ -40,5 +41,13 @@ path = "./async-server.rs" name = "async-client" path = "./async-client.rs" +[[example]] +name = "async-stream-server" +path = "./async-stream-server.rs" + +[[example]] +name = "async-stream-client" +path = "./async-stream-client.rs" + [build-dependencies] ttrpc-codegen = { path = "../ttrpc-codegen"} diff --git a/example/Makefile b/example/Makefile index c2d695c5..518a45bf 100644 --- a/example/Makefile +++ b/example/Makefile @@ -8,6 +8,8 @@ build: cargo build --example client cargo build --example async-server cargo build --example async-client + cargo build --example async-stream-server + cargo build --example async-stream-client .PHONY: deps deps: diff --git a/example/async-stream-client.rs b/example/async-stream-client.rs new file mode 100644 index 00000000..ba953596 --- /dev/null +++ b/example/async-stream-client.rs @@ -0,0 +1,174 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +mod protocols; +mod utils; + +use protocols::r#async::{empty, streaming, streaming_ttrpc}; +use ttrpc::context::{self, Context}; +use ttrpc::r#async::Client; + +#[tokio::main(flavor = "current_thread")] +async fn main() { + simple_logging::log_to_stderr(log::LevelFilter::Info); + + let c = Client::connect(utils::SOCK_ADDR).unwrap(); + let sc = streaming_ttrpc::StreamingClient::new(c); + + let _now = std::time::Instant::now(); + + let sc1 = sc.clone(); + let t1 = tokio::spawn(echo_request(sc1)); + + let sc1 = sc.clone(); + let t2 = tokio::spawn(echo_stream(sc1)); + + let sc1 = sc.clone(); + let t3 = tokio::spawn(sum_stream(sc1)); + + let sc1 = sc.clone(); + let t4 = tokio::spawn(divide_stream(sc1)); + + let sc1 = sc.clone(); + let t5 = tokio::spawn(echo_null(sc1)); + + let t6 = tokio::spawn(echo_null_stream(sc)); + + let _ = tokio::join!(t1, t2, t3, t4, t5, t6); +} + +fn default_ctx() -> Context { + let mut ctx = context::with_timeout(0); + ctx.add("key-1".to_string(), "value-1-1".to_string()); + ctx.add("key-1".to_string(), "value-1-2".to_string()); + ctx.set("key-2".to_string(), vec!["value-2".to_string()]); + + ctx +} + +async fn echo_request(cli: streaming_ttrpc::StreamingClient) { + let echo1 = streaming::EchoPayload { + seq: 1, + msg: "Echo Me".to_string(), + ..Default::default() + }; + let resp = cli.echo(default_ctx(), &echo1).await.unwrap(); + assert_eq!(resp.msg, echo1.msg); + assert_eq!(resp.seq, echo1.seq + 1); +} + +async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.echo_stream(default_ctx()).await.unwrap(); + + let mut i = 0; + while i < 100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: format!("{}: Echo in a stream", i), + ..Default::default() + }; + stream.send(&echo).await.unwrap(); + let resp = stream.recv().await.unwrap(); + assert_eq!(resp.msg, echo.msg); + assert_eq!(resp.seq, echo.seq + 1); + + i += 2; + } + stream.close_send().await.unwrap(); + let ret = stream.recv().await; + assert!(matches!(ret, Err(ttrpc::Error::Eof))); +} + +async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.sum_stream(default_ctx()).await.unwrap(); + + let mut sum = streaming::Sum::new(); + stream.send(&streaming::Part::new()).await.unwrap(); + + sum.num += 1; + let mut i = -99i32; + while i <= 100 { + let addi = streaming::Part { + add: i, + ..Default::default() + }; + stream.send(&addi).await.unwrap(); + sum.sum += i; + sum.num += 1; + + i += 1; + } + 
stream.send(&streaming::Part::new()).await.unwrap(); + sum.num += 1; + + let ssum = stream.close_and_recv().await.unwrap(); + assert_eq!(ssum.sum, sum.sum); + assert_eq!(ssum.num, sum.num); +} + +async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { + let expected = streaming::Sum { + sum: 392, + num: 4, + ..Default::default() + }; + let mut stream = cli.divide_stream(default_ctx(), &expected).await.unwrap(); + + let mut actual = streaming::Sum::new(); + + // NOTE: `for part in stream.recv().await.unwrap()` can't work. + while let Some(part) = stream.recv().await.unwrap() { + actual.sum += part.add; + actual.num += 1; + } + assert_eq!(actual.sum, expected.sum); + assert_eq!(actual.num, expected.num); +} + +async fn echo_null(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.echo_null(default_ctx()).await.unwrap(); + + for i in 0..100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: "non-empty empty".to_string(), + ..Default::default() + }; + stream.send(&echo).await.unwrap(); + } + let res = stream.close_and_recv().await.unwrap(); + assert_eq!(res, empty::Empty::new()); +} + +async fn echo_null_stream(cli: streaming_ttrpc::StreamingClient) { + let stream = cli.echo_null_stream(default_ctx()).await.unwrap(); + + let (tx, mut rx) = stream.split(); + + let task = tokio::spawn(async move { + loop { + let ret = rx.recv().await; + if matches!(ret, Err(ttrpc::Error::Eof)) { + break; + } + } + }); + + for i in 0..100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: "non-empty empty".to_string(), + ..Default::default() + }; + tx.send(&echo).await.unwrap(); + } + + tx.close_send().await.unwrap(); + + tokio::time::timeout(tokio::time::Duration::from_secs(10), task) + .await + .unwrap() + .unwrap(); +} diff --git a/example/async-stream-server.rs b/example/async-stream-server.rs new file mode 100644 index 00000000..d828e2d3 --- /dev/null +++ b/example/async-stream-server.rs @@ -0,0 +1,170 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +mod protocols; +mod utils; + +use std::sync::Arc; + +use log::{info, LevelFilter}; + +use protocols::r#async::{empty, streaming, streaming_ttrpc}; +use ttrpc::asynchronous::Server; + +use async_trait::async_trait; +use tokio::signal::unix::{signal, SignalKind}; +use tokio::time::sleep; + +struct StreamingService; + +#[async_trait] +impl streaming_ttrpc::Streaming for StreamingService { + async fn echo( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut e: streaming::EchoPayload, + ) -> ::ttrpc::Result { + e.seq += 1; + Ok(e) + } + + async fn echo_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStream, + ) -> ::ttrpc::Result<()> { + while let Some(mut e) = s.recv().await? { + e.seq += 1; + s.send(&e).await?; + } + + Ok(()) + } + + async fn sum_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStreamReceiver, + ) -> ::ttrpc::Result { + let mut sum = streaming::Sum::new(); + while let Some(part) = s.recv().await? 
{ + sum.sum += part.add; + sum.num += 1; + } + + Ok(sum) + } + + async fn divide_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + sum: streaming::Sum, + s: ::ttrpc::r#async::ServerStreamSender, + ) -> ::ttrpc::Result<()> { + let mut parts = vec![streaming::Part::new(); sum.num as usize]; + + let mut total = 0i32; + for i in 1..(sum.num - 2) { + let add = (rand::random::() % 1000) as i32 - 500; + parts[i as usize].add = add; + total += add; + } + + parts[sum.num as usize - 2].add = sum.sum - total; + + for part in parts { + s.send(&part).await.unwrap(); + } + + Ok(()) + } + + async fn echo_null( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStreamReceiver, + ) -> ::ttrpc::Result { + let mut seq = 0; + while let Some(e) = s.recv().await? { + assert_eq!(e.seq, seq); + assert_eq!(e.msg.as_str(), "non-empty empty"); + seq += 1; + } + Ok(empty::Empty::new()) + } + + async fn echo_null_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + s: ::ttrpc::r#async::ServerStream, + ) -> ::ttrpc::Result<()> { + let msg = "non-empty empty".to_string(); + + let mut tasks = Vec::new(); + + let (tx, mut rx) = s.split(); + let mut seq = 0u32; + while let Some(e) = rx.recv().await? { + assert_eq!(e.seq, seq); + assert_eq!(e.msg, msg); + seq += 1; + + for _i in 0..10 { + let tx = tx.clone(); + tasks.push(tokio::spawn( + async move { tx.send(&empty::Empty::new()).await }, + )); + } + } + + for t in tasks { + t.await.unwrap().map_err(|e| { + ::ttrpc::Error::RpcStatus(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + e.to_string(), + )) + })?; + } + Ok(()) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + simple_logging::log_to_stderr(LevelFilter::Info); + + let s = Box::new(StreamingService {}) as Box; + let s = Arc::new(s); + let service = streaming_ttrpc::create_streaming(s); + + utils::remove_if_sock_exist(utils::SOCK_ADDR).unwrap(); + + let mut server = Server::new() + .bind(utils::SOCK_ADDR) + .unwrap() + .register_service(service); + + let mut hangup = signal(SignalKind::hangup()).unwrap(); + let mut interrupt = signal(SignalKind::interrupt()).unwrap(); + server.start().await.unwrap(); + + tokio::select! { + _ = hangup.recv() => { + // test stop_listen -> start + info!("stop listen"); + server.stop_listen().await; + info!("start listen"); + server.start().await.unwrap(); + + // hold some time for the new test connection. + sleep(std::time::Duration::from_secs(100)).await; + } + _ = interrupt.recv() => { + // test graceful shutdown + info!("graceful shutdown"); + server.shutdown().await.unwrap(); + } + }; +} diff --git a/example/build.rs b/example/build.rs index b5550e2b..7ebb9fb5 100644 --- a/example/build.rs +++ b/example/build.rs @@ -9,7 +9,7 @@ use ttrpc_codegen::Codegen; use ttrpc_codegen::Customize; fn main() { - let protos = vec![ + let mut protos = vec![ "protocols/protos/github.com/kata-containers/agent/pkg/types/types.proto", "protocols/protos/agent.proto", "protocols/protos/health.proto", @@ -28,6 +28,9 @@ fn main() { .run() .expect("Gen sync code failed."); + // Only async support stream currently. 
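+    // streaming.proto is appended after the sync Codegen run above, so the
+    // streaming definitions are only generated for the async target below.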
+    protos.push("protocols/protos/streaming.proto");
+
     Codegen::new()
         .out_dir("protocols/asynchronous")
         .inputs(&protos)
diff --git a/example/protocols/asynchronous/mod.rs b/example/protocols/asynchronous/mod.rs
index fd7082cc..34df78b2 100644
--- a/example/protocols/asynchronous/mod.rs
+++ b/example/protocols/asynchronous/mod.rs
@@ -10,3 +10,5 @@ pub mod health;
 pub mod health_ttrpc;
 mod oci;
 pub mod types;
+pub mod streaming;
+pub mod streaming_ttrpc;
diff --git a/example/protocols/protos/streaming.proto b/example/protocols/protos/streaming.proto
new file mode 100644
index 00000000..fce29dd6
--- /dev/null
+++ b/example/protocols/protos/streaming.proto
@@ -0,0 +1,49 @@
+/*
+   Copyright The containerd Authors.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+syntax = "proto3";
+
+package ttrpc.test.streaming;
+
+import "google/protobuf/empty.proto";
+
+// Streaming is a test service exercising each ttrpc streaming shape: a unary
+// echo, a bidirectional echo stream, client streaming (SumStream), server
+// streaming (DivideStream), and client/bidirectional streams that return
+// google.protobuf.Empty.
+
+service Streaming {
+    rpc Echo(EchoPayload) returns (EchoPayload);
+    rpc EchoStream(stream EchoPayload) returns (stream EchoPayload);
+    rpc SumStream(stream Part) returns (Sum);
+    rpc DivideStream(Sum) returns (stream Part);
+    rpc EchoNull(stream EchoPayload) returns (google.protobuf.Empty);
+    rpc EchoNullStream(stream EchoPayload) returns (stream google.protobuf.Empty);
+}
+
+message EchoPayload {
+    uint32 seq = 1;
+    string msg = 2;
+}
+
+message Part {
+    int32 add = 1;
+}
+
+message Sum {
+    int32 sum = 1;
+    int32 num = 2;
+}
diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs
index c5668694..bf6e44ed 100644
--- a/src/asynchronous/client.rs
+++ b/src/asynchronous/client.rs
@@ -1,37 +1,38 @@
+// Copyright 2022 Alibaba Cloud. All rights reserved.
// Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // -use nix::unistd::close; -use protobuf::{CodedInputStream, CodedOutputStream, Message}; use std::collections::HashMap; +use std::convert::TryInto; use std::os::unix::io::RawFd; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; -use crate::common::{client_connect, MESSAGE_TYPE_RESPONSE}; -use crate::error::{Error, Result}; -use crate::proto::{Code, Request, Response}; +use async_trait::async_trait; +use nix::unistd::close; +use tokio::{self, sync::mpsc, task}; -use crate::asynchronous::stream::{receive, to_req_buf}; -use crate::r#async::utils; -use tokio::{ - self, - io::{split, AsyncWriteExt}, - sync::mpsc::{channel, Receiver, Sender}, - sync::Notify, +use crate::common::client_connect; +use crate::error::{Error, Result}; +use crate::proto::{ + Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, + MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE, }; - -type RequestSender = Sender<(Vec, Sender>>)>; -type RequestReceiver = Receiver<(Vec, Sender>>)>; - -type ResponseSender = Sender>>; -type ResponseReceiver = Receiver>>; +use crate::r#async::connection::*; +use crate::r#async::shutdown; +use crate::r#async::stream::{ + Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner, +}; +use crate::r#async::utils; /// A ttrpc Client (async). #[derive(Clone)] pub struct Client { - req_tx: RequestSender, + req_tx: MessageSender, + next_stream_id: Arc, + streams: Arc>>, } impl Client { @@ -44,146 +45,50 @@ impl Client { pub fn new(fd: RawFd) -> Client { let stream = utils::new_unix_stream_from_raw_fd(fd); - let (mut reader, mut writer) = split(stream); - let (req_tx, mut rx): (RequestSender, RequestReceiver) = channel(100); + let (req_tx, rx): (MessageSender, MessageReceiver) = mpsc::channel(100); let req_map = Arc::new(Mutex::new(HashMap::new())); - let req_map2 = req_map.clone(); - - let notify = Arc::new(Notify::new()); - let notify2 = notify.clone(); - - // Request sender - let request_sender = tokio::spawn(async move { - let mut stream_id: u32 = 1; - - while let Some((body, resp_tx)) = rx.recv().await { - let current_stream_id = stream_id; - stream_id += 2; - - { - let mut map = req_map2.lock().unwrap(); - map.insert(current_stream_id, resp_tx.clone()); - } - - let buf = to_req_buf(current_stream_id, body); - if let Err(e) = writer.write_all(&buf).await { - error!("write_message got error: {:?}", e); - - { - let mut map = req_map2.lock().unwrap(); - map.remove(¤t_stream_id); - } + let delegate = ClientBuilder { + rx: Some(rx), + streams: req_map.clone(), + }; - let e = Error::Socket(format!("{:?}", e)); - resp_tx - .send(Err(e)) - .await - .unwrap_or_else(|_e| error!("The request has returned")); + let conn = Connection::new(stream, delegate); + tokio::spawn(async move { conn.run().await }); - break; // The stream is dead, exit the loop. - } - } + Client { + req_tx, + next_stream_id: Arc::new(AtomicU32::new(1)), + streams: req_map, + } + } - // rx.recv will abort when client.req_tx and client is dropped. - // notify the response-receiver to quit at this time. - notify.notify_one(); - }); + /// Requsts a unary request and returns with response. + pub async fn request(&self, req: Request) -> Result { + let timeout_nano = req.timeout_nano; + let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed); - // Response receiver - tokio::spawn(async move { - loop { - tokio::select! 
{ - _ = notify2.notified() => { - break; - } - res = receive(&mut reader) => { - match res { - Ok((header, body)) => { - let req_map = req_map.clone(); - tokio::spawn(async move { - let resp_tx2; - { - let mut map = req_map.lock().unwrap(); - let resp_tx = match map.get(&header.stream_id) { - Some(tx) => tx, - None => { - debug!( - "Receiver got unknown packet {:?} {:?}", - header, body - ); - return; - } - }; - - resp_tx2 = resp_tx.clone(); - map.remove(&header.stream_id); // Forget the result, just remove. - } - - if header.type_ != MESSAGE_TYPE_RESPONSE { - resp_tx2 - .send(Err(Error::Others(format!( - "Recver got malformed packet {:?} {:?}", - header, body - )))) - .await - .unwrap_or_else(|_e| error!("The request has returned")); - return; - } - - resp_tx2.send(Ok(body)).await.unwrap_or_else(|_e| error!("The request has returned")); - }); - } - Err(e) => { - debug!("Connection closed by the ttRPC server: {}", e); - - // Abort the request sender task to prevent incoming RPC requests - // from being processed. - request_sender.abort(); - let _ = request_sender.await; - - // Take all items out of `req_map`. - let mut map = std::mem::take(&mut *req_map.lock().unwrap()); - // Terminate outstanding RPC requests with the error. - for (_stream_id, resp_tx) in map.drain() { - if let Err(_e) = resp_tx.send(Err(e.clone())).await { - warn!("Failed to terminate pending RPC: \ - the request has returned"); - } - } - - break; - } - } - } - }; - } - }); + let msg: GenMessage = Message::new_request(stream_id, req) + .try_into() + .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; - Client { req_tx } - } + let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100); - pub async fn request(&self, req: Request) -> Result { - let mut buf = Vec::with_capacity(req.compute_size() as usize); - { - let mut s = CodedOutputStream::vec(&mut buf); - req.write_to(&mut s).map_err(err_to_others_err!(e, ""))?; - s.flush().map_err(err_to_others_err!(e, ""))?; - } + // TODO: check return. + self.streams.lock().unwrap().insert(stream_id, tx); - let (tx, mut rx): (ResponseSender, ResponseReceiver) = channel(100); self.req_tx - .send((buf, tx)) + .send(msg) .await .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; - let result = if req.timeout_nano == 0 { + let result = if timeout_nano == 0 { rx.recv() .await .ok_or_else(|| Error::Others("Receive packet from receiver error".to_string()))? } else { tokio::time::timeout( - std::time::Duration::from_nanos(req.timeout_nano as u64), + std::time::Duration::from_nanos(timeout_nano as u64), rx.recv(), ) .await @@ -191,10 +96,8 @@ impl Client { .ok_or_else(|| Error::Others("Receive packet from receiver error".to_string()))? }; - let buf = result?; - let mut s = CodedInputStream::from_bytes(&buf); - let mut res = Response::new(); - res.merge_from(&mut s) + let msg = result?; + let res = Response::decode(&msg.payload) .map_err(err_to_others_err!(e, "Unpack response error "))?; let status = res.get_status(); @@ -204,6 +107,44 @@ impl Client { Ok(res) } + + /// Creates a StreamInner instance. 
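+    ///
+    /// `streaming_client` marks whether the local send half stays open; it is
+    /// announced to the peer as FLAG_REMOTE_OPEN, otherwise FLAG_REMOTE_CLOSED.
+    /// `streaming_server` marks whether a stream of responses is expected.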
+ pub async fn new_stream( + &self, + req: Request, + streaming_client: bool, + streaming_server: bool, + ) -> Result { + let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed); + + let mut msg: GenMessage = Message::new_request(stream_id, req) + .try_into() + .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; + + if streaming_client { + msg.header.add_flags(FLAG_REMOTE_OPEN); + } else { + msg.header.add_flags(FLAG_REMOTE_CLOSED); + } + + let (tx, rx): (ResultSender, ResultReceiver) = mpsc::channel(100); + // TODO: check return + self.streams.lock().unwrap().insert(stream_id, tx); + self.req_tx + .send(msg) + .await + .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; + + Ok(StreamInner::new( + stream_id, + self.req_tx.clone(), + rx, + streaming_client, + streaming_server, + Kind::Client, + self.streams.clone(), + )) + } } struct ClientClose { @@ -218,3 +159,153 @@ impl Drop for ClientClose { trace!("All client is droped"); } } + +#[derive(Debug)] +struct ClientBuilder { + rx: Option, + streams: Arc>>, +} + +impl Builder for ClientBuilder { + type Reader = ClientReader; + type Writer = ClientWriter; + + fn build(&mut self) -> (Self::Reader, Self::Writer) { + let (notifier, waiter) = shutdown::new(); + ( + ClientReader { + shutdown_waiter: waiter, + streams: self.streams.clone(), + }, + ClientWriter { + rx: self.rx.take().unwrap(), + shutdown_notifier: notifier, + + streams: self.streams.clone(), + }, + ) + } +} + +struct ClientWriter { + rx: MessageReceiver, + shutdown_notifier: shutdown::Notifier, + + streams: Arc>>, +} + +#[async_trait] +impl WriterDelegate for ClientWriter { + async fn recv(&mut self) -> Option { + self.rx.recv().await + } + + async fn disconnect(&self, msg: &GenMessage, e: Error) { + // TODO: + // At this point, a new request may have been received. + let resp_tx = { + let mut map = self.streams.lock().unwrap(); + map.remove(&msg.header.stream_id) + }; + + // TODO: if None + if let Some(resp_tx) = resp_tx { + let e = Error::Socket(format!("{:?}", e)); + resp_tx + .send(Err(e)) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + } + } + + async fn exit(&self) { + self.shutdown_notifier.shutdown(); + } +} + +struct ClientReader { + streams: Arc>>, + shutdown_waiter: shutdown::Waiter, +} + +#[async_trait] +impl ReaderDelegate for ClientReader { + async fn wait_shutdown(&self) { + self.shutdown_waiter.wait_shutdown().await + } + + async fn disconnect(&self, e: Error, sender: &mut task::JoinHandle<()>) { + // Abort the request sender task to prevent incoming RPC requests + // from being processed. + sender.abort(); + let _ = sender.await; + + // Take all items out of `req_map`. + let mut map = std::mem::take(&mut *self.streams.lock().unwrap()); + // Terminate undone RPC requests with the error. 
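+        // The map was taken out of the mutex above, so no lock is held across
+        // the `.await` points in the loop below.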
+ for (_stream_id, resp_tx) in map.drain() { + if let Err(_e) = resp_tx.send(Err(e.clone())).await { + warn!("Failed to terminate pending RPC: the request has returned"); + } + } + } + + async fn exit(&self) {} + + async fn handle_msg(&self, msg: GenMessage) { + let req_map = self.streams.clone(); + tokio::spawn(async move { + let resp_tx = match msg.header.type_ { + MESSAGE_TYPE_RESPONSE => { + match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx, + None => { + debug!("Receiver got unknown response packet {:?}", msg); + return; + } + } + } + MESSAGE_TYPE_DATA => { + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED { + match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx.clone(), + None => { + debug!("Receiver got unknown data packet {:?}", msg); + return; + } + } + } else { + match req_map.lock().unwrap().get(&msg.header.stream_id) { + Some(tx) => tx.clone(), + None => { + debug!("Receiver got unknown data packet {:?}", msg); + return; + } + } + } + } + _ => { + let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx, + None => { + debug!("Receiver got unknown packet {:?}", msg); + return; + } + }; + resp_tx + .send(Err(Error::Others(format!( + "Recver got malformed packet {:?}", + msg + )))) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + return; + } + }; + resp_tx + .send(Ok(msg)) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + }); + } +} diff --git a/src/asynchronous/connection.rs b/src/asynchronous/connection.rs new file mode 100644 index 00000000..8de87d3b --- /dev/null +++ b/src/asynchronous/connection.rs @@ -0,0 +1,110 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// Copyright (c) 2020 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::os::unix::io::AsRawFd; + +use async_trait::async_trait; +use log::{error, trace}; +use tokio::{ + io::{split, AsyncRead, AsyncWrite, ReadHalf}, + select, task, +}; + +use crate::error::Error; +use crate::proto::GenMessage; + +pub trait Builder { + type Reader; + type Writer; + + fn build(&mut self) -> (Self::Reader, Self::Writer); +} + +#[async_trait] +pub trait WriterDelegate { + async fn recv(&mut self) -> Option; + async fn disconnect(&self, msg: &GenMessage, e: Error); + async fn exit(&self); +} + +#[async_trait] +pub trait ReaderDelegate { + async fn wait_shutdown(&self); + async fn disconnect(&self, e: Error, task: &mut task::JoinHandle<()>); + async fn exit(&self); + async fn handle_msg(&self, msg: GenMessage); +} + +pub struct Connection { + reader: ReadHalf, + writer_task: task::JoinHandle<()>, + reader_delegate: B::Reader, +} + +impl Connection +where + S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, + B: Builder, + B::Reader: ReaderDelegate + Send + Sync + 'static, + B::Writer: WriterDelegate + Send + Sync + 'static, +{ + pub fn new(conn: S, mut builder: B) -> Self { + let (reader, mut writer) = split(conn); + + let (reader_delegate, mut writer_delegate) = builder.build(); + + let writer_task = tokio::spawn(async move { + while let Some(msg) = writer_delegate.recv().await { + trace!("write message: {:?}", msg); + if let Err(e) = msg.write_to(&mut writer).await { + error!("write_message got error: {:?}", e); + writer_delegate.disconnect(&msg, e).await; + } + } + writer_delegate.exit().await; + trace!("Writer task exit."); + }); + + Self { + reader, + writer_task, + reader_delegate, + } + } + + pub async fn run(self) -> std::io::Result<()> { + let 
Connection { + mut reader, + mut writer_task, + reader_delegate, + } = self; + loop { + select! { + res = GenMessage::read_from(&mut reader) => { + match res { + Ok(msg) => { + trace!("Got Message {:?}", msg); + reader_delegate.handle_msg(msg).await; + } + Err(e) => { + trace!("Read msg err: {:?}", e); + reader_delegate.disconnect(e, &mut writer_task).await; + break; + } + } + } + _v = reader_delegate.wait_shutdown() => { + trace!("Receive shutdown."); + break; + } + } + } + reader_delegate.exit().await; + trace!("Reader task exit."); + + Ok(()) + } +} diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index f63aae86..fb6d39cb 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -11,11 +11,17 @@ mod stream; #[macro_use] #[doc(hidden)] mod utils; +mod connection; +pub mod shutdown; mod unix_incoming; +pub use self::stream::{ + ClientStream, ClientStreamReceiver, ClientStreamSender, Kind, ServerStream, + ServerStreamReceiver, ServerStreamSender, StreamInner, +}; #[doc(inline)] pub use crate::r#async::client::Client; #[doc(inline)] -pub use crate::r#async::server::Server; +pub use crate::r#async::server::{Server, Service}; #[doc(inline)] -pub use utils::{MethodHandler, TtrpcContext}; +pub use utils::{MethodHandler, StreamHandler, TtrpcContext}; diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 9afd1836..a19dac30 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -1,49 +1,76 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // -use crate::r#async::utils; -use nix::unistd; use std::collections::HashMap; +use std::convert::TryFrom; +use std::marker::Unpin; use std::os::unix::io::RawFd; +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::net::UnixListener as SysUnixListener; use std::result::Result as StdResult; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; -use crate::asynchronous::stream::{receive, respond, respond_with_status}; -use crate::asynchronous::unix_incoming::UnixIncoming; -use crate::common::{self, Domain, MESSAGE_TYPE_REQUEST}; -use crate::context; -use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, Status}; -use crate::r#async::{MethodHandler, TtrpcContext}; -use crate::MessageHeader; +use async_trait::async_trait; use futures::stream::Stream; use futures::StreamExt as _; -use std::marker::Unpin; -use std::os::unix::io::{AsRawFd, FromRawFd}; -use std::os::unix::net::UnixListener as SysUnixListener; +use nix::unistd; use tokio::{ self, - io::{split, AsyncRead, AsyncWrite, AsyncWriteExt}, + io::{AsyncRead, AsyncWrite}, net::UnixListener, select, spawn, - sync::mpsc::{channel, Receiver, Sender}, - sync::watch, + sync::mpsc::{channel, Sender}, + task, time::timeout, }; - #[cfg(target_os = "linux")] use tokio_vsock::VsockListener; +use crate::asynchronous::unix_incoming::UnixIncoming; +use crate::common::{self, Domain}; +use crate::context; +use crate::error::{get_status, Error, Result}; +use crate::proto::{ + Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA, + FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST, +}; +use crate::r#async::connection::*; +use crate::r#async::shutdown; +use crate::r#async::stream::{ + Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner, +}; +use crate::r#async::utils; +use crate::r#async::{MethodHandler, StreamHandler, TtrpcContext}; + 
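+// Grace periods for draining in-flight handlers: per connection when a client
+// disconnects, and server-wide for shutdown().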
+const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000); +const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000); + +pub struct Service { + pub methods: HashMap>, + pub streams: HashMap>, +} + +impl Service { + pub(crate) fn get_method(&self, name: &str) -> Option<&(dyn MethodHandler + Send + Sync)> { + self.methods.get(name).map(|b| b.as_ref()) + } + + pub(crate) fn get_stream(&self, name: &str) -> Option> { + self.streams.get(name).cloned() + } +} + /// A ttrpc Server (async). pub struct Server { listeners: Vec, - methods: Arc>>, + services: Arc>, domain: Option, - disconnect_tx: Option>, - all_conn_done_rx: Option>, + + shutdown: shutdown::Notifier, stop_listen_tx: Option>>, } @@ -51,10 +78,9 @@ impl Default for Server { fn default() -> Self { Server { listeners: Vec::with_capacity(1), - methods: Arc::new(HashMap::new()), + services: Arc::new(HashMap::new()), domain: None, - disconnect_tx: None, - all_conn_done_rx: None, + shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0, stop_listen_tx: None, } } @@ -97,12 +123,9 @@ impl Server { Ok(self) } - pub fn register_service( - mut self, - methods: HashMap>, - ) -> Server { - let mut_methods = Arc::get_mut(&mut self.methods).unwrap(); - mut_methods.extend(methods); + pub fn register_service(mut self, new: HashMap) -> Server { + let services = Arc::get_mut(&mut self.services).unwrap(); + services.extend(new); self } @@ -150,13 +173,9 @@ impl Server { I: Stream> + Unpin + Send + 'static + AsRawFd, S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { - let methods = self.methods.clone(); - - let (disconnect_tx, close_conn_rx) = watch::channel(0); - self.disconnect_tx = Some(disconnect_tx); + let services = self.services.clone(); - let (conn_done_tx, all_conn_done_rx) = channel::(1); - self.all_conn_done_rx = Some(all_conn_done_rx); + let shutdown_waiter = self.shutdown.subscribe(); let (stop_listen_tx, mut stop_listen_rx) = channel(1); self.stop_listen_tx = Some(stop_listen_tx); @@ -168,15 +187,14 @@ impl Server { if let Some(conn) = conn { // Accept a new connection match conn { - Ok(stream) => { - let fd = stream.as_raw_fd(); + Ok(conn) => { + let fd = conn.as_raw_fd(); // spawn a connection handler, would not block spawn_connection_handler( fd, - stream, - methods.clone(), - close_conn_rx.clone(), - conn_done_tx.clone() + conn, + services.clone(), + shutdown_waiter.clone(), ).await; } Err(e) => { @@ -202,7 +220,6 @@ impl Server { } } } - drop(conn_done_tx); }); Ok(()) } @@ -215,13 +232,16 @@ impl Server { } pub async fn disconnect(&mut self) { - if let Some(tx) = self.disconnect_tx.take() { - tx.send(1).ok(); - } + self.shutdown.shutdown(); - if let Some(mut rx) = self.all_conn_done_rx.take() { - rx.recv().await; - } + self.shutdown + .wait_all_exit() + .await + .map_err(|e| { + trace!("wait connection exit error: {}", e); + }) + .ok(); + trace!("wait connection exit."); } pub async fn stop_listen(&mut self) { @@ -236,155 +256,394 @@ impl Server { } } -async fn spawn_connection_handler( +async fn spawn_connection_handler( fd: RawFd, - stream: S, - methods: Arc>>, - mut close_conn_rx: watch::Receiver, - conn_done_tx: Sender, + conn: C, + services: Arc>, + shutdown_waiter: shutdown::Waiter, ) where - S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, + C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { - let (req_done_tx, mut all_req_done_rx) = channel::(1); - + let delegate = ServerBuilder { + fd, + services, + streams: Arc::new(Mutex::new(HashMap::new())), + 
shutdown_waiter, + }; + let conn = Connection::new(conn, delegate); spawn(async move { - let (mut reader, mut writer) = split(stream); - let (tx, mut rx): (Sender>, Receiver>) = channel(100); - let (client_disconnected_tx, client_disconnected_rx) = watch::channel(false); + conn.run() + .await + .map_err(|e| { + trace!("connection run error. {}", e); + }) + .ok(); + }); +} + +impl FromRawFd for Server { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self::default().add_listener(fd).unwrap() + } +} + +impl AsRawFd for Server { + fn as_raw_fd(&self) -> RawFd { + self.listeners[0] + } +} + +struct ServerBuilder { + fd: RawFd, + services: Arc>, + streams: Arc>>, + shutdown_waiter: shutdown::Waiter, +} + +impl Builder for ServerBuilder { + type Reader = ServerReader; + type Writer = ServerWriter; + + fn build(&mut self) -> (Self::Reader, Self::Writer) { + let (tx, rx): (MessageSender, MessageReceiver) = channel(100); + let (disconnect_notifier, _disconnect_waiter) = + shutdown::with_timeout(DEFAULT_CONN_SHUTDOWN_TIMEOUT); + + ( + ServerReader { + fd: self.fd, + tx, + services: self.services.clone(), + streams: self.streams.clone(), + server_shutdown: self.shutdown_waiter.clone(), + handler_shutdown: disconnect_notifier, + }, + ServerWriter { rx }, + ) + } +} + +struct ServerWriter { + rx: MessageReceiver, +} +#[async_trait] +impl WriterDelegate for ServerWriter { + async fn recv(&mut self) -> Option { + self.rx.recv().await + } + async fn disconnect(&self, _msg: &GenMessage, _: Error) {} + async fn exit(&self) {} +} + +struct ServerReader { + fd: RawFd, + tx: MessageSender, + services: Arc>, + streams: Arc>>, + server_shutdown: shutdown::Waiter, + handler_shutdown: shutdown::Notifier, +} + +#[async_trait] +impl ReaderDelegate for ServerReader { + async fn wait_shutdown(&self) { + self.server_shutdown.wait_shutdown().await + } + + async fn disconnect(&self, _: Error, _: &mut task::JoinHandle<()>) { + self.handler_shutdown.shutdown(); + // TODO: Don't wait for all requests to complete? when the connection is disconnected. + } + + async fn exit(&self) { + // TODO: Don't self.conn_shutdown.shutdown(); + // Wait pedding request/stream to exit. + self.handler_shutdown + .wait_all_exit() + .await + .map_err(|e| { + trace!("wait handler exit error: {}", e); + }) + .ok(); + } + + async fn handle_msg(&self, msg: GenMessage) { + let handler_shutdown_waiter = self.handler_shutdown.subscribe(); + let context = self.context(); spawn(async move { - while let Some(buf) = rx.recv().await { - if let Err(e) = writer.write_all(&buf).await { - error!("write_message got error: {:?}", e); - } + select! { + _ = context.handle_msg(msg) => {} + _ = handler_shutdown_waiter.wait_shutdown() => {} } }); + } +} - loop { - let tx = tx.clone(); - let methods = methods.clone(); - let req_done_tx2 = req_done_tx.clone(); - let mut client_disconnected_rx2 = client_disconnected_rx.clone(); +impl ServerReader { + fn context(&self) -> HandlerContext { + HandlerContext { + fd: self.fd, + tx: self.tx.clone(), + services: self.services.clone(), + streams: self.streams.clone(), + _handler_shutdown_waiter: self.handler_shutdown.subscribe(), + } + } +} - select! { - resp = receive(&mut reader) => { - match resp { - Ok(message) => { - spawn(async move { - select! { - _ = handle_request(tx, fd, methods, message) => {} - _ = client_disconnected_rx2.changed() => {} - } +struct HandlerContext { + fd: RawFd, + tx: MessageSender, + services: Arc>, + streams: Arc>>, + // Used for waiting handler exit. 
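+    // Holding this waiter keeps the context counted in wait_all_exit() until
+    // the handler finishes and drops it.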
+ _handler_shutdown_waiter: shutdown::Waiter, +} - drop(req_done_tx2); - }); - } - Err(e) => { - let _ = client_disconnected_tx.send(true); - trace!("error {:?}", e); - break; - } +impl HandlerContext { + async fn handle_msg(&self, msg: GenMessage) { + let stream_id = msg.header.stream_id; + + if (stream_id % 2) != 1 { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status(Code::INVALID_ARGUMENT, "stream id must be odd"), + ) + .await; + return; + } + + match msg.header.type_ { + MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await { + Ok(opt_msg) => match opt_msg { + Some(msg) => { + Self::respond(self.tx.clone(), stream_id, msg) + .await + .map_err(|e| { + error!("respond got error {:?}", e); + }) + .ok(); + } + None => { + let mut header = MessageHeader::new_data(stream_id, 0); + header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA); + let msg = GenMessage { + header, + payload: Vec::new(), + }; + + self.tx + .send(msg) + .await + .map_err(err_to_others_err!(e, "Send packet to sender error ")) + .ok(); } + }, + Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, + }, + MESSAGE_TYPE_DATA => { + // TODO(wllenyj): Compatible with golang behavior. + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED + && !msg.payload.is_empty() + { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!( + "Stream id {}: data close message connot include data", + stream_id + ), + ), + ) + .await; + return; } - v = close_conn_rx.changed() => { - // 0 is the init value of this watch, not a valid signal - // is_err means the tx was dropped. - if v.is_err() || *close_conn_rx.borrow() != 0 { - info!("Stop accepting new connections."); - break; + let stream_tx = self.streams.lock().unwrap().get(&stream_id).cloned(); + if let Some(stream_tx) = stream_tx { + if let Err(e) = stream_tx.send(Ok(msg)).await { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!("Stream id {}: handling data error: {}", stream_id, e), + ), + ) + .await; } + } else { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status(Code::INVALID_ARGUMENT, "Stream is no longer active"), + ) + .await; } } + _ => { + // TODO: else we must ignore this for future compat. log this? + // TODO(wllenyj): Compatible with golang behavior. + error!("Unknown message type. 
{:?}", msg.header); + } } + } - drop(req_done_tx); - all_req_done_rx.recv().await; - drop(conn_done_tx); - }); -} + async fn handle_request(&self, msg: GenMessage) -> StdResult, Status> { + //TODO: + //if header.stream_id <= self.last_stream_id { + // return Err; + //} + // self.last_stream_id = header.stream_id; -async fn do_handle_request( - fd: RawFd, - methods: Arc>>, - header: MessageHeader, - body: &[u8], -) -> StdResult<(u32, Vec), Status> { - let req = utils::body_to_request(body)?; - let path = utils::get_path(&req.service, &req.method); - let method = methods - .get(&path) - .ok_or_else(|| get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path)))?; - - let ctx = TtrpcContext { - fd, - mh: header, - metadata: context::from_pb(&req.metadata), - timeout_nano: req.timeout_nano, - }; + let req_msg = Message::::try_from(msg) + .map_err(|e| get_status(Code::INVALID_ARGUMENT, e.to_string()))?; - let get_unknown_status_and_log_err = |e| { - error!("method handle {} got error {:?}", path, &e); - get_status(Code::UNKNOWN, e) - }; + let req = &req_msg.payload; + trace!("Got Message request {} {}", req.service, req.method); - if req.timeout_nano == 0 { - method - .handler(ctx, req) - .await - .map_err(get_unknown_status_and_log_err) - } else { - timeout( - Duration::from_nanos(req.timeout_nano as u64), - method.handler(ctx, req), - ) - .await - .map_err(|_| { - // Timed out - error!("method handle {} got error timed out", path); - get_status(Code::DEADLINE_EXCEEDED, "timeout") - }) - .and_then(|r| { - // Handler finished - r.map_err(get_unknown_status_and_log_err) - }) - } -} + let srv = self.services.get(&req.service).ok_or_else(|| { + get_status( + Code::INVALID_ARGUMENT, + format!("{} service does not exist", &req.service), + ) + })?; -async fn handle_request( - tx: Sender>, - fd: RawFd, - methods: Arc>>, - message: (MessageHeader, Vec), -) { - let (header, body) = message; - let stream_id = header.stream_id; - - if header.type_ != MESSAGE_TYPE_REQUEST { - return; + if let Some(method) = srv.get_method(&req.method) { + return self.handle_method(method, req_msg).await; + } + if let Some(stream) = srv.get_stream(&req.method) { + return self.handle_stream(stream, req_msg).await; + } + Err(get_status( + Code::UNIMPLEMENTED, + format!("{} method", &req.method), + )) } - match do_handle_request(fd, methods, header, &body).await { - Ok((stream_id, resp_body)) => { - if let Err(x) = respond(tx.clone(), stream_id, resp_body).await { - error!("respond got error {:?}", x); - } + async fn handle_method( + &self, + method: &(dyn MethodHandler + Send + Sync), + req_msg: Message, + ) -> StdResult, Status> { + let req = req_msg.payload; + let path = utils::get_path(&req.service, &req.method); + + let ctx = TtrpcContext { + fd: self.fd, + mh: req_msg.header, + metadata: context::from_pb(&req.metadata), + timeout_nano: req.timeout_nano, + }; + + let get_unknown_status_and_log_err = |e| { + error!("method handle {} got error {:?}", path, &e); + get_status(Code::UNKNOWN, e) + }; + if req.timeout_nano == 0 { + method + .handler(ctx, req) + .await + .map_err(get_unknown_status_and_log_err) + .map(Some) + } else { + timeout( + Duration::from_nanos(req.timeout_nano as u64), + method.handler(ctx, req), + ) + .await + .map_err(|_| { + // Timed out + error!("method handle {} got error timed out", path); + get_status(Code::DEADLINE_EXCEEDED, "timeout") + }) + .and_then(|r| { + // Handler finished + r.map_err(get_unknown_status_and_log_err) + }) + .map(Some) } - Err(status) => { - if let Err(x) = 
respond_with_status(tx.clone(), stream_id, status).await { - error!("respond got error {:?}", x); - } + } + + async fn handle_stream( + &self, + stream: Arc, + req_msg: Message, + ) -> StdResult, Status> { + let stream_id = req_msg.header.stream_id; + let req = req_msg.payload; + let path = utils::get_path(&req.service, &req.method); + + let (tx, rx): (ResultSender, ResultReceiver) = channel(100); + let stream_tx = tx.clone(); + self.streams.lock().unwrap().insert(stream_id, tx); + + let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED; + let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN; + let si = StreamInner::new( + stream_id, + self.tx.clone(), + rx, + true, // TODO + true, + Kind::Server, + self.streams.clone(), + ); + + let ctx = TtrpcContext { + fd: self.fd, + mh: req_msg.header, + metadata: context::from_pb(&req.metadata), + timeout_nano: req.timeout_nano, + }; + + let task = spawn(async move { stream.handler(ctx, si).await }); + + if !req.payload.is_empty() { + // Fake the first data message. + let msg = GenMessage { + header: MessageHeader::new_data(stream_id, req.payload.len() as u32), + payload: req.payload, + }; + stream_tx.send(Ok(msg)).await.map_err(|e| { + error!("send stream data {} got error {:?}", path, &e); + get_status(Code::UNKNOWN, e) + })?; } + task.await + .unwrap_or_else(|e| { + Err(Error::Others(format!( + "stream {} task got error {:?}", + path, e + ))) + }) + .map_err(|e| get_status(Code::UNKNOWN, e)) } -} -impl FromRawFd for Server { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - Self::default().add_listener(fd).unwrap() + async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { + let payload = resp + .encode() + .map_err(err_to_others_err!(e, "Encode Response failed."))?; + let msg = GenMessage { + header: MessageHeader::new_response(stream_id, payload.len() as u32), + payload, + }; + tx.send(msg) + .await + .map_err(err_to_others_err!(e, "Send packet to sender error ")) } -} -impl AsRawFd for Server { - fn as_raw_fd(&self) -> RawFd { - self.listeners[0] + async fn respond_with_status(tx: MessageSender, stream_id: u32, status: Status) { + let mut resp = Response::new(); + resp.set_status(status); + Self::respond(tx, stream_id, resp) + .await + .map_err(|e| { + error!("respond with status got error {:?}", e); + }) + .ok(); } } diff --git a/src/asynchronous/shutdown.rs b/src/asynchronous/shutdown.rs new file mode 100644 index 00000000..9d1bb136 --- /dev/null +++ b/src/asynchronous/shutdown.rs @@ -0,0 +1,311 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use tokio::sync::Notify; +use tokio::time::{error::Elapsed, timeout, Duration}; + +#[derive(Debug)] +struct Shared { + shutdown: AtomicBool, + notify_shutdown: Notify, + + waiters: AtomicUsize, + notify_exit: Notify, +} + +impl Shared { + fn is_shutdown(&self) -> bool { + self.shutdown.load(Ordering::Relaxed) + } +} + +/// Wait for the shutdown notification. +#[derive(Debug)] +pub struct Waiter { + shared: Arc, +} + +/// Used to Notify all [`Waiter`s](Waiter) shutdown. +/// +/// No `Clone` is provided. If you want multiple instances, you can use Arc. +/// Notifier will automatically call shutdown when dropping. +#[derive(Debug)] +pub struct Notifier { + shared: Arc, + wait_time: Option, +} + +/// Create a new shutdown pair([`Notifier`], [`Waiter`]) without timeout. 
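+/// Without a timeout, [`Notifier::wait_all_exit()`] blocks until every
+/// [`Waiter`] has been dropped.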
diff --git a/src/asynchronous/shutdown.rs b/src/asynchronous/shutdown.rs
new file mode 100644
index 00000000..9d1bb136
--- /dev/null
+++ b/src/asynchronous/shutdown.rs
@@ -0,0 +1,311 @@
+// Copyright 2022 Alibaba Cloud. All rights reserved.
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use std::sync::Arc;
+
+use tokio::sync::Notify;
+use tokio::time::{error::Elapsed, timeout, Duration};
+
+#[derive(Debug)]
+struct Shared {
+    shutdown: AtomicBool,
+    notify_shutdown: Notify,
+
+    waiters: AtomicUsize,
+    notify_exit: Notify,
+}
+
+impl Shared {
+    fn is_shutdown(&self) -> bool {
+        self.shutdown.load(Ordering::Relaxed)
+    }
+}
+
+/// Wait for the shutdown notification.
+#[derive(Debug)]
+pub struct Waiter {
+    shared: Arc<Shared>,
+}
+
+/// Used to notify all [`Waiter`s](Waiter) of shutdown.
+///
+/// No `Clone` is provided. If you want multiple instances, wrap it in an `Arc`.
+/// A `Notifier` automatically calls shutdown when it is dropped.
+#[derive(Debug)]
+pub struct Notifier {
+    shared: Arc<Shared>,
+    wait_time: Option<Duration>,
+}
+
+/// Create a new shutdown pair ([`Notifier`], [`Waiter`]) without a timeout.
+///
+/// The [`Notifier`] will wait indefinitely in [`Notifier::wait_all_exit()`].
+pub fn new() -> (Notifier, Waiter) {
+    _with_timeout(None)
+}
+
+/// Create a new shutdown pair with the specified [`Duration`].
+///
+/// The [`Duration`] is used to specify the timeout of [`Notifier::wait_all_exit()`].
+///
+/// [`Duration`]: tokio::time::Duration
+pub fn with_timeout(wait_time: Duration) -> (Notifier, Waiter) {
+    _with_timeout(Some(wait_time))
+}
+
+fn _with_timeout(wait_time: Option<Duration>) -> (Notifier, Waiter) {
+    let shared = Arc::new(Shared {
+        shutdown: AtomicBool::new(false),
+        waiters: AtomicUsize::new(1),
+        notify_shutdown: Notify::new(),
+        notify_exit: Notify::new(),
+    });
+
+    let notifier = Notifier {
+        shared: shared.clone(),
+        wait_time,
+    };
+
+    let waiter = Waiter { shared };
+
+    (notifier, waiter)
+}
+
+impl Waiter {
+    /// Return `true` if [`Notifier::shutdown()`] has been called.
+    ///
+    /// [`Notifier::shutdown()`]: Notifier::shutdown()
+    pub fn is_shutdown(&self) -> bool {
+        self.shared.is_shutdown()
+    }
+
+    /// Wait for [`Notifier::shutdown()`] to be called.
+    pub async fn wait_shutdown(&self) {
+        while !self.is_shutdown() {
+            let shutdown = self.shared.notify_shutdown.notified();
+            if self.is_shutdown() {
+                return;
+            }
+            shutdown.await;
+        }
+    }
+
+    fn from_shared(shared: Arc<Shared>) -> Self {
+        shared.waiters.fetch_add(1, Ordering::Relaxed);
+        Self { shared }
+    }
+}
+
+impl Clone for Waiter {
+    fn clone(&self) -> Self {
+        Self::from_shared(self.shared.clone())
+    }
+}
+
+impl Drop for Waiter {
+    fn drop(&mut self) {
+        if 1 == self.shared.waiters.fetch_sub(1, Ordering::Relaxed) {
+            self.shared.notify_exit.notify_waiters();
+        }
+    }
+}
+
+impl Notifier {
+    /// Return `true` if [`Notifier::shutdown()`] has been called.
+    ///
+    /// [`Notifier::shutdown()`]: Notifier::shutdown()
+    pub fn is_shutdown(&self) -> bool {
+        self.shared.is_shutdown()
+    }
+
+    /// Notify all [`Waiter`s](Waiter) to shut down.
+    ///
+    /// It will cause all calls blocking at `Waiter::wait_shutdown().await` to return.
+    pub fn shutdown(&self) {
+        let is_shutdown = self.shared.shutdown.swap(true, Ordering::Relaxed);
+        if !is_shutdown {
+            self.shared.notify_shutdown.notify_waiters();
+        }
+    }
+
+    /// Return the number of [`Waiter`]s.
+    pub fn waiters(&self) -> usize {
+        self.shared.waiters.load(Ordering::Relaxed)
+    }
+
+    /// Create a new [`Waiter`].
+    pub fn subscribe(&self) -> Waiter {
+        Waiter::from_shared(self.shared.clone())
+    }
+
+    /// Wait for all [`Waiter`]s to drop.
+ pub async fn wait_all_exit(&self) -> Result<(), Elapsed> { + //debug_assert!(self.shared.is_shutdown()); + if self.waiters() == 0 { + return Ok(()); + } + let wait = self.wait(); + if self.waiters() == 0 { + return Ok(()); + } + wait.await + } + + async fn wait(&self) -> Result<(), Elapsed> { + if let Some(tm) = self.wait_time { + timeout(tm, self.shared.notify_exit.notified()).await + } else { + self.shared.notify_exit.notified().await; + Ok(()) + } + } +} + +impl Drop for Notifier { + fn drop(&mut self) { + self.shutdown() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn it_work() { + let (notifier, waiter) = new(); + + let task = tokio::spawn(async move { + waiter.wait_shutdown().await; + }); + + assert_eq!(notifier.waiters(), 1); + notifier.shutdown(); + task.await.unwrap(); + assert_eq!(notifier.waiters(), 0); + } + + #[tokio::test] + async fn notifier_drop() { + let (notifier, waiter) = new(); + assert_eq!(notifier.waiters(), 1); + assert!(!waiter.is_shutdown()); + drop(notifier); + assert!(waiter.is_shutdown()); + assert_eq!(waiter.shared.waiters.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn waiter_clone() { + let (notifier, waiter1) = new(); + assert_eq!(notifier.waiters(), 1); + + let waiter2 = waiter1.clone(); + assert_eq!(notifier.waiters(), 2); + + let waiter3 = notifier.subscribe(); + assert_eq!(notifier.waiters(), 3); + + drop(waiter2); + assert_eq!(notifier.waiters(), 2); + + let task = tokio::spawn(async move { + waiter3.wait_shutdown().await; + assert!(waiter3.is_shutdown()); + }); + + assert!(!waiter1.is_shutdown()); + notifier.shutdown(); + assert!(waiter1.is_shutdown()); + + task.await.unwrap(); + + assert_eq!(notifier.waiters(), 1); + } + + #[tokio::test] + async fn concurrency_notifier_shutdown() { + let (notifier, waiter) = new(); + let arc_notifier = Arc::new(notifier); + let notifier1 = arc_notifier.clone(); + let notifier2 = notifier1.clone(); + + let task1 = tokio::spawn(async move { + assert_eq!(notifier1.waiters(), 1); + + let waiter = notifier1.subscribe(); + assert_eq!(notifier1.waiters(), 2); + + notifier1.shutdown(); + waiter.wait_shutdown().await; + }); + + let task2 = tokio::spawn(async move { + assert_eq!(notifier2.waiters(), 1); + notifier2.shutdown(); + }); + waiter.wait_shutdown().await; + assert!(arc_notifier.is_shutdown()); + task1.await.unwrap(); + task2.await.unwrap(); + } + + #[tokio::test] + async fn concurrency_notifier_wait() { + let (notifier, waiter) = new(); + let arc_notifier = Arc::new(notifier); + let notifier1 = arc_notifier.clone(); + let notifier2 = notifier1.clone(); + + let task1 = tokio::spawn(async move { + notifier1.shutdown(); + notifier1.wait_all_exit().await.unwrap(); + }); + + let task2 = tokio::spawn(async move { + notifier2.shutdown(); + notifier2.wait_all_exit().await.unwrap(); + }); + + waiter.wait_shutdown().await; + drop(waiter); + task1.await.unwrap(); + task2.await.unwrap(); + } + + #[tokio::test] + async fn wait_all_exit() { + let (notifier, waiter) = new(); + let mut tasks = Vec::with_capacity(100); + for i in 0..100 { + assert_eq!(notifier.waiters(), 1 + i); + let waiter1 = waiter.clone(); + tasks.push(tokio::spawn(async move { + waiter1.wait_shutdown().await; + })); + } + drop(waiter); + assert_eq!(notifier.waiters(), 100); + notifier.shutdown(); + notifier.wait_all_exit().await.unwrap(); + for t in tasks { + t.await.unwrap(); + } + } + + #[tokio::test] + async fn wait_timeout() { + let (notifier, waiter) = with_timeout(Duration::from_millis(100)); + let task = 
tokio::spawn(async move {
+            waiter.wait_shutdown().await;
+            tokio::time::sleep(Duration::from_millis(200)).await;
+        });
+        notifier.shutdown();
+        // Elapsed
+        assert!(matches!(notifier.wait_all_exit().await, Err(_)));
+        task.await.unwrap();
+    }
+}
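Taken together, the pair behaves like a broadcast-once latch with a grace period. A minimal usage sketch from inside the crate (the module's visibility outside the crate is not shown in this patch, so the `use` path is an assumption):

```rust
use tokio::time::Duration;

use crate::r#async::shutdown; // assumed in-crate path to the new module

async fn run_workers() {
    let (notifier, waiter) = shutdown::with_timeout(Duration::from_millis(100));
    drop(waiter); // keep only the Waiters created via subscribe()
    for _ in 0..4 {
        let w = notifier.subscribe();
        tokio::spawn(async move {
            w.wait_shutdown().await; // parked until shutdown() fires
            // ...flush state, close fds, then drop w on exit...
        });
    }
    notifier.shutdown();
    // Ok(()) once every Waiter has dropped; Err(Elapsed) after the 100ms grace period.
    let _ = notifier.wait_all_exit().await;
}
```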

diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs
index c99ba336..fdf90341 100644
--- a/src/asynchronous/stream.rs
+++ b/src/asynchronous/stream.rs
@@ -1,137 +1,482 @@
+// Copyright 2022 Alibaba Cloud. All rights reserved.
 // Copyright (c) 2020 Ant Financial
 //
 // SPDX-License-Identifier: Apache-2.0
 //
 
-use crate::common::{MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE};
-use crate::error::{get_rpc_status, sock_error_msg, Error, Result};
-use crate::proto::{Code, Response, Status};
-use crate::r#async::utils;
-use crate::MessageHeader;
-use protobuf::Message;
-use tokio::io::AsyncReadExt;
+use std::collections::HashMap;
+use std::marker::PhantomData;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, Mutex};
 
-async fn receive_count<T>(reader: &mut T, count: usize) -> Result<Vec<u8>>
+use tokio::sync::mpsc;
+
+use crate::error::{Error, Result};
+use crate::proto::{
+    Code, Codec, GenMessage, MessageHeader, Response, FLAG_NO_DATA, FLAG_REMOTE_CLOSED,
+    MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
+};
+
+pub type MessageSender = mpsc::Sender<GenMessage>;
+pub type MessageReceiver = mpsc::Receiver<GenMessage>;
+
+pub type ResultSender = mpsc::Sender<Result<GenMessage>>;
+pub type ResultReceiver = mpsc::Receiver<Result<GenMessage>>;
+
+#[derive(Debug)]
+pub struct ClientStream<Q, P> {
+    tx: CSSender<Q>,
+    rx: CSReceiver<P>,
+}
+
+impl<Q, P> ClientStream<Q, P>
 where
-    T: AsyncReadExt + std::marker::Unpin,
+    Q: Codec,
+    P: Codec,
+    <Q as Codec>::E: std::fmt::Display,
+    <P as Codec>::E: std::fmt::Display,
 {
-    let mut content = vec![0u8; count];
-    if let Err(e) = reader.read_exact(&mut content).await {
-        return Err(Error::Socket(e.to_string()));
+    pub fn new(inner: StreamInner) -> Self {
+        let (tx, rx) = inner.split();
+        Self {
+            tx: CSSender {
+                tx,
+                _send: PhantomData,
+            },
+            rx: CSReceiver {
+                rx,
+                _recv: PhantomData,
+            },
+        }
+    }
+
+    pub fn split(self) -> (CSSender<Q>, CSReceiver<P>) {
+        (self.tx, self.rx)
+    }
+
+    pub async fn send(&self, req: &Q) -> Result<()> {
+        self.tx.send(req).await
+    }
+
+    pub async fn close_send(&self) -> Result<()> {
+        self.tx.close_send().await
+    }
+
+    pub async fn recv(&mut self) -> Result<P> {
+        self.rx.recv().await
     }
+}
 
-    Ok(content)
+#[derive(Clone, Debug)]
+pub struct CSSender<Q> {
+    tx: StreamSender,
+    _send: PhantomData<Q>,
 }
 
-async fn receive_header<T>(reader: &mut T) -> Result<MessageHeader>
+impl<Q> CSSender<Q>
 where
-    T: AsyncReadExt + std::marker::Unpin,
+    Q: Codec,
+    <Q as Codec>::E: std::fmt::Display,
 {
-    let buf = receive_count(reader, MESSAGE_HEADER_LENGTH).await?;
-    let size = buf.len();
-    if size != MESSAGE_HEADER_LENGTH {
-        return Err(sock_error_msg(
-            size,
-            format!("Message header length {} is too small", size),
-        ));
+    pub async fn send(&self, req: &Q) -> Result<()> {
+        let msg_buf = req
+            .encode()
+            .map_err(err_to_others_err!(e, "Encode message failed."))?;
+        self.tx.send(msg_buf).await
     }
 
-    let mh = MessageHeader::from(&buf);
+    pub async fn close_send(&self) -> Result<()> {
+        self.tx.close_send().await
+    }
+}
 
-    Ok(mh)
+#[derive(Debug)]
+pub struct CSReceiver<P> {
+    rx: StreamReceiver,
+    _recv: PhantomData<P>,
 }
 
-pub(crate) async fn receive<T>(reader: &mut T) -> Result<(MessageHeader, Vec<u8>)>
+impl<P> CSReceiver<P>
 where
-    T: AsyncReadExt + std::marker::Unpin,
+    P: Codec,
+    <P as Codec>::E: std::fmt::Display,
 {
-    let mh = receive_header(reader).await?;
-    trace!("Got Message header {:?}", mh);
+    pub async fn recv(&mut self) -> Result<P> {
+        let msg_buf = self.rx.recv().await?;
+        P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed."))
+    }
+}
 
-    if mh.length > MESSAGE_LENGTH_MAX as u32 {
-        return Err(get_rpc_status(
-            Code::INVALID_ARGUMENT,
-            format!(
-                "message length {} exceed maximum message size of {}",
-                mh.length, MESSAGE_LENGTH_MAX
-            ),
-        ));
+#[derive(Debug)]
+pub struct ServerStream<P, Q> {
+    tx: SSSender<P>,
+    rx: SSReceiver<Q>,
+}
+
+impl<P, Q> ServerStream<P, Q>
+where
+    P: Codec,
+    Q: Codec,
+    <P as Codec>::E: std::fmt::Display,
+    <Q as Codec>::E: std::fmt::Display,
+{
+    pub fn new(inner: StreamInner) -> Self {
+        let (tx, rx) = inner.split();
+        Self {
+            tx: SSSender {
+                tx,
+                _send: PhantomData,
+            },
+            rx: SSReceiver {
+                rx,
+                _recv: PhantomData,
+            },
+        }
     }
 
-    let buf = receive_count(reader, mh.length as usize).await?;
-    let size = buf.len();
-    if size != mh.length as usize {
-        return Err(sock_error_msg(
-            size,
-            format!("Message length {} is not {}", size, mh.length),
-        ));
+    pub fn split(self) -> (SSSender<P>, SSReceiver<Q>) {
+        (self.tx, self.rx)
     }
 
-    trace!("Got Message body {:?}", buf);
-    Ok((mh, buf))
+    pub async fn send(&self, resp: &P) -> Result<()> {
+        self.tx.send(resp).await
+    }
+
+    pub async fn recv(&mut self) -> Result<Option<Q>> {
+        self.rx.recv().await
+    }
 }
 
-fn header_to_buf(mh: MessageHeader) -> Vec<u8> {
-    mh.into()
+#[derive(Clone, Debug)]
+pub struct SSSender<P> {
+    tx: StreamSender,
+    _send: PhantomData<P>,
 }
 
-pub(crate) fn to_req_buf(stream_id: u32, mut body: Vec<u8>) -> Vec<u8> {
-    let header = utils::get_request_header_from_body(stream_id, &body);
-    let mut buf = header_to_buf(header);
-    buf.append(&mut body);
+impl<P> SSSender<P>
+where
+    P: Codec,
+    <P as Codec>::E: std::fmt::Display,
+{
+    pub async fn send(&self, resp: &P) -> Result<()> {
+        let msg_buf = resp
+            .encode()
+            .map_err(err_to_others_err!(e, "Encode message failed."))?;
+        self.tx.send(msg_buf).await
+    }
+}
 
-    buf
+#[derive(Debug)]
+pub struct SSReceiver<Q> {
+    rx: StreamReceiver,
+    _recv: PhantomData<Q>,
 }
 
-pub(crate) fn to_res_buf(stream_id: u32, mut body: Vec<u8>) -> Vec<u8> {
-    let header = utils::get_response_header_from_body(stream_id, &body);
-    let mut buf = header_to_buf(header);
-    buf.append(&mut body);
+impl<Q> SSReceiver<Q>
+where
+    Q: Codec,
+    <Q as Codec>::E: std::fmt::Display,
+{
+    pub async fn recv(&mut self) -> Result<Option<Q>> {
+        let res = self.rx.recv().await;
 
-    buf
+        if matches!(res, Err(Error::Eof)) {
+            return Ok(None);
+        }
+        let msg_buf = res?;
+        Q::decode(&msg_buf)
+            .map_err(err_to_others_err!(e, "Decode message failed."))
+            .map(Some)
+    }
 }
 
-fn get_response_body(res: &Response) -> Result<Vec<u8>> {
-    let mut buf = Vec::with_capacity(res.compute_size() as usize);
-    let mut s = protobuf::CodedOutputStream::vec(&mut buf);
-    res.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
-    s.flush().map_err(err_to_others_err!(e, ""))?;
+pub struct ClientStreamSender<Q, P> {
+    inner: StreamInner,
+    _send: PhantomData<Q>,
+    _recv: PhantomData<P>,
+}
+
+impl<Q, P> ClientStreamSender<Q, P>
+where
+    Q: Codec,
+    P: Codec,
+    <Q as Codec>::E: std::fmt::Display,
+    <P as Codec>::E: std::fmt::Display,
+{
+    pub fn new(inner: StreamInner) -> Self {
+        Self {
+            inner,
+            _send: PhantomData,
+            _recv: PhantomData,
+        }
+    }
 
-    Ok(buf)
+    pub async fn send(&self, req: &Q) -> Result<()> {
+        let msg_buf = req
+            .encode()
+            .map_err(err_to_others_err!(e, "Encode message failed."))?;
+        self.inner.send(msg_buf).await
+    }
+
+    pub async fn close_and_recv(&mut self) -> Result<P> {
+        self.inner.close_send().await?;
+        let msg_buf = self.inner.recv().await?;
+        P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed."))
+    }
 }
 
-pub(crate) async fn respond(
-    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-    stream_id: u32,
-    body: Vec<u8>,
-) -> Result<()> {
-    let buf = to_res_buf(stream_id, body);
+pub struct ServerStreamSender<P> {
+    inner: StreamSender,
+    _send: PhantomData<P>,
+}
+
+impl<P> ServerStreamSender<P>
+where
+    P: Codec,
+    <P as Codec>::E: std::fmt::Display,
+{
+    pub fn new(inner: StreamInner) -> Self {
+        Self {
+            inner: inner.split().0,
+            _send: PhantomData,
+        }
+    }
+
+    pub async fn send(&self, resp: &P) -> Result<()> {
+        let msg_buf = resp
+            .encode()
+            .map_err(err_to_others_err!(e, "Encode message failed."))?;
+        self.inner.send(msg_buf).await
+    }
+}
+
+pub struct ClientStreamReceiver<P> {
+    inner: StreamReceiver,
+    _recv: PhantomData<P>,
+}
+
+impl<P> ClientStreamReceiver<P>
+where
+    P: Codec,
+    <P as Codec>::E: std::fmt::Display,
+{
+    pub fn new(inner: StreamInner) -> Self {
+        Self {
+            inner: inner.split().1,
+            _recv: PhantomData,
+        }
+    }
+
+    pub async fn recv(&mut self) -> Result<Option<P>> {
+        let res = self.inner.recv().await;
+        if matches!(res, Err(Error::Eof)) {
+            return Ok(None);
+        }
+        let msg_buf = res?;
+        P::decode(&msg_buf)
+            .map_err(err_to_others_err!(e, "Decode message failed."))
+            .map(Some)
+    }
+}
+
+pub struct ServerStreamReceiver<Q> {
+    inner: StreamReceiver,
+    _recv: PhantomData<Q>,
+}
+
+impl<Q> ServerStreamReceiver<Q>
+where
+    Q: Codec,
+    <Q as Codec>::E: std::fmt::Display,
+{
+    pub fn new(inner: StreamInner) -> Self {
+        Self {
+            inner: inner.split().1,
+            _recv: PhantomData,
+        }
+    }
+
+    pub async fn recv(&mut self) -> Result<Option<Q>> {
+        let res = self.inner.recv().await;
+        if matches!(res, Err(Error::Eof)) {
+            return Ok(None);
+        }
+        let msg_buf = res?;
+        Q::decode(&msg_buf)
+            .map_err(err_to_others_err!(e, "Decode message failed."))
+            .map(Some)
+    }
+}
 
-    tx.send(buf)
+async fn _recv(rx: &mut ResultReceiver) -> Result<GenMessage> {
+    rx.recv()
         .await
-        .map_err(err_to_others_err!(e, "Send packet to sender error "))
+        .unwrap_or_else(|| Err(Error::Others("Receive packet from recver error".to_string())))
 }
 
-pub(crate) async fn respond_with_status(
-    tx: tokio::sync::mpsc::Sender<Vec<u8>>,
-    stream_id: u32,
-    status: Status,
-) -> Result<()> {
-    let mut res = Response::new();
-    res.set_status(status);
-    let mut body = get_response_body(&res)?;
-
-    let mh = MessageHeader {
-        length: body.len() as u32,
-        stream_id,
-        type_: MESSAGE_TYPE_RESPONSE,
-        flags: 0,
-    };
-
-    let mut buf = header_to_buf(mh);
-    buf.append(&mut body);
-
-    tx.send(buf)
+async fn _send(tx: &MessageSender, msg: GenMessage) -> Result<()> {
+    tx.send(msg)
         .await
-        .map_err(err_to_others_err!(e, "Send packet to sender error "))
+        .map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e)))
+}
+
+#[derive(Clone, Copy, Debug, PartialEq)]
+pub enum Kind {
+    Client,
+    Server,
+}
+
+#[derive(Debug)]
+pub struct StreamInner {
+    sender: StreamSender,
+    receiver: StreamReceiver,
+}
+
+impl StreamInner {
+    pub fn new(
+        stream_id: u32,
+        tx: MessageSender,
+        rx: ResultReceiver,
+        //waiter: shutdown::Waiter,
+        sendable: bool,
+        receivable: bool,
+        kind: Kind,
+        streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
+    ) -> Self {
+        Self {
+            sender: StreamSender {
+                tx,
+                stream_id,
+                sendable,
+                local_closed: Arc::new(AtomicBool::new(false)),
+                kind,
+            },
+            receiver: StreamReceiver {
+                rx,
+                stream_id,
+                receivable,
+                remote_closed: false,
+                kind,
+                streams,
+            },
+        }
+    }
+
+    fn split(self) -> (StreamSender, StreamReceiver) {
+        (self.sender, self.receiver)
+    }
+
+    pub async fn send(&self, buf: Vec<u8>) -> Result<()> {
+        self.sender.send(buf).await
+    }
+
+    pub async fn close_send(&self) -> Result<()> {
+        self.sender.close_send().await
+    }
+
+    pub async fn recv(&mut self) -> Result<Vec<u8>> {
+        self.receiver.recv().await
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct StreamSender {
+    tx: MessageSender,
+    stream_id: u32,
+    sendable: bool,
+    local_closed: Arc<AtomicBool>,
+    kind: Kind,
+}
+
+#[derive(Debug)]
+pub struct StreamReceiver {
+    rx: ResultReceiver,
+    stream_id: u32,
+    receivable: bool,
+    remote_closed: bool,
+    kind: Kind,
+    streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
+}
+
+impl Drop for StreamReceiver {
+    fn drop(&mut self) {
+        self.streams.lock().unwrap().remove(&self.stream_id);
+    }
+}
+
+impl StreamSender {
+    pub async fn send(&self, buf: Vec<u8>) -> Result<()> {
+        debug_assert!(self.sendable);
+        if self.local_closed.load(Ordering::Relaxed) {
+            debug_assert_eq!(self.kind, Kind::Client);
+            return Err(Error::LocalClosed);
+        }
+        let header = MessageHeader::new_data(self.stream_id, buf.len() as u32);
+        let msg = GenMessage {
+            header,
+            payload: buf,
+        };
+        _send(&self.tx, msg).await?;
+
+        Ok(())
+    }
+
+    pub async fn close_send(&self) -> Result<()> {
+        debug_assert_eq!(self.kind, Kind::Client);
+        debug_assert!(self.sendable);
+        if self.local_closed.load(Ordering::Relaxed) {
+            return Err(Error::LocalClosed);
+        }
+        let mut header = MessageHeader::new_data(self.stream_id, 0);
+        header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA);
+        let msg = GenMessage {
+            header,
+            payload: Vec::new(),
+        };
+        _send(&self.tx, msg).await?;
+        self.local_closed.store(true, Ordering::Relaxed);
+        Ok(())
+    }
+}
+
+impl StreamReceiver {
+    pub async fn recv(&mut self) -> Result<Vec<u8>> {
+        if self.remote_closed {
+            return Err(Error::RemoteClosed);
+        }
+        let msg = _recv(&mut self.rx).await?;
+        let payload = match msg.header.type_ {
+            MESSAGE_TYPE_RESPONSE => {
+                debug_assert_eq!(self.kind, Kind::Client);
+                self.remote_closed = true;
+                let resp = Response::decode(&msg.payload)
+                    .map_err(err_to_others_err!(e, "Decode message failed."))?;
+                if let Some(status) = resp.status.as_ref() {
+                    if status.get_code() != Code::OK {
+                        return Err(Error::RpcStatus((*status).clone()));
+                    }
+                }
+                resp.payload
+            }
+            MESSAGE_TYPE_DATA => {
+                if !self.receivable {
+                    self.remote_closed = true;
+                    return Err(Error::Others(
+                        "received data from non-streaming server.".to_string(),
+                    ));
+                }
+                if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
+                    self.remote_closed = true;
+                    if (msg.header.flags & FLAG_NO_DATA) == FLAG_NO_DATA {
+                        return Err(Error::Eof);
+                    }
+                }
+                msg.payload
+            }
+            _ => {
+                return Err(Error::Others("unsupported message type".to_string()));
+            }
+        };
+        Ok(payload)
+    }
+}
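These types are what the generated client stubs hand back. A rough usage sketch for a duplex call; `EchoClient`, `echo_stream`, and `EchoPayload` stand in for generated/user types that are not part of this patch:

```rust
use ttrpc::context;

async fn run(client: &EchoClient) -> ttrpc::Result<()> {
    // Hypothetical generated duplex stub: returns ClientStream<EchoPayload, EchoPayload>.
    let mut stream = client.echo_stream(context::with_timeout(0)).await?;

    let req = EchoPayload::new();
    stream.send(&req).await?;          // one DATA frame carrying an encoded Q
    let _reply = stream.recv().await?; // one decoded P from the peer

    stream.close_send().await?; // empty DATA frame, FLAG_REMOTE_CLOSED | FLAG_NO_DATA
    Ok(())
}
```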
diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs
index cbd2cc57..2e2555d1 100644
--- a/src/asynchronous/utils.rs
+++ b/src/asynchronous/utils.rs
@@ -1,18 +1,18 @@
+// Copyright 2022 Alibaba Cloud. All rights reserved.
 // Copyright (c) 2020 Ant Financial
 //
 // SPDX-License-Identifier: Apache-2.0
 //
 
-use crate::common::{MessageHeader, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE};
-use crate::error::{get_status, Result};
-use crate::proto::{Code, Request, Status};
-use async_trait::async_trait;
-use protobuf::{CodedInputStream, Message};
 use std::collections::HashMap;
 use std::os::unix::io::{FromRawFd, RawFd};
-use std::result::Result as StdResult;
+
+use async_trait::async_trait;
 use tokio::net::UnixStream;
 
+use crate::error::Result;
+use crate::proto::{MessageHeader, Request, Response};
+
 /// Handle request in async mode.
 #[macro_export]
 macro_rules! async_request_handler {
@@ -47,12 +47,97 @@ macro_rules! async_request_handler {
         },
     }
 
-    let mut buf = Vec::with_capacity(res.compute_size() as usize);
-    let mut s = protobuf::CodedOutputStream::vec(&mut buf);
-    res.write_to(&mut s).map_err(ttrpc::err_to_others!(e, ""))?;
-    s.flush().map_err(ttrpc::err_to_others!(e, ""))?;
+    return Ok(res);
+    };
+}
+
+/// Handle client streaming in async mode.
+#[macro_export]
+macro_rules! async_client_streamimg_handler {
+    ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => {
+        let stream = ::ttrpc::r#async::ServerStreamReceiver::new($inner);
+        let mut res = ::ttrpc::Response::new();
+        match $class.service.$req_fn(&$ctx, stream).await {
+            Ok(rep) => {
+                res.set_status(::ttrpc::get_status(::ttrpc::Code::OK, "".to_string()));
+                res.payload.reserve(rep.compute_size() as usize);
+                let mut s = protobuf::CodedOutputStream::vec(&mut res.payload);
+                rep.write_to(&mut s)
+                    .map_err(::ttrpc::err_to_others!(e, ""))?;
+                s.flush().map_err(::ttrpc::err_to_others!(e, ""))?;
+            }
+            Err(x) => match x {
+                ::ttrpc::Error::RpcStatus(s) => {
+                    res.set_status(s);
+                }
+                _ => {
+                    res.set_status(::ttrpc::get_status(
+                        ::ttrpc::Code::UNKNOWN,
+                        format!("{:?}", x),
+                    ));
+                }
+            },
+        }
+        return Ok(Some(res));
+    };
+}
+
+/// Handle server streaming in async mode.
+#[macro_export]
+macro_rules! async_server_streamimg_handler {
+    ($class: ident, $ctx: ident, $inner: ident, $server: ident, $req_type: ident, $req_fn: ident) => {
+        let req_buf = $inner.recv().await?;
+        let req = <super::$server::$req_type as ::ttrpc::proto::Codec>::decode(&req_buf)
+            .map_err(|e| ::ttrpc::Error::Others(e.to_string()))?;
+        let stream = ::ttrpc::r#async::ServerStreamSender::new($inner);
+        match $class.service.$req_fn(&$ctx, req, stream).await {
+            Ok(_) => {
+                return Ok(None);
+            }
+            Err(x) => {
+                let mut res = ::ttrpc::Response::new();
+                match x {
+                    ::ttrpc::Error::RpcStatus(s) => {
+                        res.set_status(s);
+                    }
+                    _ => {
+                        res.set_status(::ttrpc::get_status(
+                            ::ttrpc::Code::UNKNOWN,
+                            format!("{:?}", x),
+                        ));
+                    }
+                }
+                return Ok(Some(res));
+            }
+        }
+    };
+}
 
-    return Ok(($ctx.mh.stream_id, buf));
+/// Handle duplex streaming in async mode.
+#[macro_export]
+macro_rules! async_duplex_streamimg_handler {
+    ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => {
+        let stream = ::ttrpc::r#async::ServerStream::new($inner);
+        match $class.service.$req_fn(&$ctx, stream).await {
+            Ok(_) => {
+                return Ok(None);
+            }
+            Err(x) => {
+                let mut res = ::ttrpc::Response::new();
+                match x {
+                    ::ttrpc::Error::RpcStatus(s) => {
+                        res.set_status(s);
+                    }
+                    _ => {
+                        res.set_status(::ttrpc::get_status(
+                            ::ttrpc::Code::UNKNOWN,
+                            format!("{:?}", x),
+                        ));
+                    }
+                }
+                return Ok(Some(res));
+            }
+        }
     };
 }
 
@@ -84,10 +169,81 @@ macro_rules! async_client_request {
     };
 }
 
+/// Duplex streaming through async client.
+#[macro_export]
+macro_rules! async_client_stream {
+    ($self: ident, $ctx: ident, $server: expr, $method: expr) => {
+        let mut creq = ::ttrpc::Request::new();
+        creq.set_service($server.to_string());
+        creq.set_method($method.to_string());
+        creq.set_timeout_nano($ctx.timeout_nano);
+        let md = ::ttrpc::context::to_pb($ctx.metadata);
+        creq.set_metadata(md);
+
+        let inner = $self.client.new_stream(creq, true, true).await?;
+        let stream = ::ttrpc::r#async::ClientStream::new(inner);
+
+        return Ok(stream);
+    };
+}
+
+/// Only send streaming through async client.
+#[macro_export]
+macro_rules! async_client_stream_send {
+    ($self: ident, $ctx: ident, $server: expr, $method: expr) => {
+        let mut creq = ::ttrpc::Request::new();
+        creq.set_service($server.to_string());
+        creq.set_method($method.to_string());
+        creq.set_timeout_nano($ctx.timeout_nano);
+        let md = ::ttrpc::context::to_pb($ctx.metadata);
+        creq.set_metadata(md);
+
+        let inner = $self.client.new_stream(creq, true, false).await?;
+        let stream = ::ttrpc::r#async::ClientStreamSender::new(inner);
+
+        return Ok(stream);
+    };
+}
+
+/// Only receive streaming through async client.
+#[macro_export]
+macro_rules! async_client_stream_receive {
+    ($self: ident, $ctx: ident, $req: ident, $server: expr, $method: expr) => {
+        let mut creq = ::ttrpc::Request::new();
+        creq.set_service($server.to_string());
+        creq.set_method($method.to_string());
+        creq.set_timeout_nano($ctx.timeout_nano);
+        let md = ::ttrpc::context::to_pb($ctx.metadata);
+        creq.set_metadata(md);
+        creq.payload.reserve($req.compute_size() as usize);
+        {
+            let mut s = CodedOutputStream::vec(&mut creq.payload);
+            $req.write_to(&mut s)
+                .map_err(::ttrpc::err_to_others!(e, ""))?;
+            s.flush().map_err(::ttrpc::err_to_others!(e, ""))?;
+        }
+
+        let inner = $self.client.new_stream(creq, false, true).await?;
+        let stream = ::ttrpc::r#async::ClientStreamReceiver::new(inner);
+
+        return Ok(stream);
+    };
+}
+
 /// Trait that implements handler which is a proxy to the desired method (async).
 #[async_trait]
 pub trait MethodHandler {
-    async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<(u32, Vec<u8>)>;
+    async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<Response>;
+}
+
+/// Trait that implements handler which is a proxy to the stream (async).
+#[async_trait]
+pub trait StreamHandler {
+    async fn handler(
+        &self,
+        ctx: TtrpcContext,
+        stream: crate::r#async::StreamInner,
+    ) -> Result<Option<Response>>;
 }
 
 /// The context of ttrpc (async).
@@ -99,24 +255,6 @@ pub struct TtrpcContext {
     pub timeout_nano: i64,
 }
 
-pub(crate) fn get_response_header_from_body(stream_id: u32, body: &[u8]) -> MessageHeader {
-    MessageHeader {
-        length: body.len() as u32,
-        stream_id,
-        type_: MESSAGE_TYPE_RESPONSE,
-        flags: 0,
-    }
-}
-
-pub(crate) fn get_request_header_from_body(stream_id: u32, body: &[u8]) -> MessageHeader {
-    MessageHeader {
-        length: body.len() as u32,
-        stream_id,
-        type_: MESSAGE_TYPE_REQUEST,
-        flags: 0,
-    }
-}
-
 pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream {
     let std_stream: std::os::unix::net::UnixStream;
     unsafe {
@@ -128,23 +266,6 @@ pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream {
     UnixStream::from_std(std_stream).unwrap()
 }
 
-pub(crate) fn body_to_request(body: &[u8]) -> StdResult<Request, Status> {
-    let mut req = Request::new();
-    let merge_result;
-    {
-        let mut s = CodedInputStream::from_bytes(body);
-        merge_result = req.merge_from(&mut s);
-    }
-
-    if merge_result.is_err() {
-        return Err(get_status(Code::INVALID_ARGUMENT, "".to_string()));
-    }
-
-    trace!("Got Message request {:?}", req);
-
-    Ok(req)
-}
-
 pub(crate) fn get_path(service: &str, method: &str) -> String {
     format!("/{}/{}", service, method)
 }
diff --git a/src/common.rs b/src/common.rs
index b5be1833..3f39b070 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -6,7 +6,6 @@
 //! Common functions and macros.
 
 use crate::error::{Error, Result};
-use byteorder::{BigEndian, ByteOrder};
 #[cfg(any(feature = "async", not(target_os = "linux")))]
 use nix::fcntl::FdFlag;
 use nix::fcntl::{fcntl, FcntlArg, OFlag};
@@ -20,53 +19,6 @@ pub(crate) enum Domain {
     Vsock,
 }
 
-/// Message header of ttrpc.
-#[derive(Default, Debug)]
-pub struct MessageHeader {
-    pub length: u32,
-    pub stream_id: u32,
-    pub type_: u8,
-    pub flags: u8,
-}
-
-impl<T> From<T> for MessageHeader
-where
-    T: AsRef<[u8]>,
-{
-    fn from(buf: T) -> Self {
-        let buf = buf.as_ref();
-        debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH);
-        Self {
-            length: BigEndian::read_u32(&buf[..4]),
-            stream_id: BigEndian::read_u32(&buf[4..8]),
-            type_: buf[8],
-            flags: buf[9],
-        }
-    }
-}
-
-impl From<MessageHeader> for Vec<u8> {
-    fn from(mh: MessageHeader) -> Self {
-        let mut buf = vec![0u8; MESSAGE_HEADER_LENGTH];
-        mh.into_buf(&mut buf);
-        buf
-    }
-}
-
-impl MessageHeader {
-    pub(crate) fn into_buf(self, mut buf: impl AsMut<[u8]>) {
-        let buf = buf.as_mut();
-        debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH);
-
-        let covbuf: &mut [u8] = &mut buf[..4];
-        BigEndian::write_u32(covbuf, self.length);
-        let covbuf: &mut [u8] = &mut buf[4..8];
-        BigEndian::write_u32(covbuf, self.stream_id);
-        buf[8] = self.type_;
-        buf[9] = self.flags;
-    }
-}
-
 pub(crate) fn do_listen(listener: RawFd) -> Result<()> {
     if let Err(e) = fcntl(listener, FcntlArg::F_SETFL(OFlag::O_NONBLOCK)) {
         return Err(Error::Others(format!(
@@ -238,12 +190,6 @@ macro_rules! cfg_async {
     }
 }
 
-pub const MESSAGE_HEADER_LENGTH: usize = 10;
-pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;
-
-pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
-pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -306,23 +252,4 @@ mod tests {
             }
         }
     }
-
-    #[test]
-    fn message_header() {
-        let buf = vec![
-            0x10, 0x0, 0x0, 0x0, // length
-            0x0, 0x0, 0x0, 0x03, // stream_id
-            0x2, // type_
-            0xef, // flags
-        ];
-        let mh = MessageHeader::from(&buf);
-        assert_eq!(mh.length, 0x1000_0000);
-        assert_eq!(mh.stream_id, 0x3);
-        assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE);
-        assert_eq!(mh.flags, 0xef);
-
-        let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH];
-        mh.into_buf(&mut buf2);
-        assert_eq!(&buf, &buf2);
-    }
 }
diff --git a/src/error.rs b/src/error.rs
index 1e9afc21..4f7e2e3a 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -30,6 +30,15 @@ pub enum Error {
     #[error("Nix error: {0}")]
     Nix(#[from] nix::Error),
 
+    #[error("ttrpc err: local stream closed")]
+    LocalClosed,
+
+    #[error("ttrpc err: remote stream closed")]
+    RemoteClosed,
+
+    #[error("eof")]
+    Eof,
+
     #[error("ttrpc err: {0}")]
     Others(String),
 }
diff --git a/src/lib.rs b/src/lib.rs
index b5d2dfe6..4f913d44 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,20 +48,15 @@ extern crate log;
 pub mod error;
 #[macro_use]
 mod common;
-#[allow(soft_unstable, clippy::type_complexity, clippy::too_many_arguments)]
-mod compiled {
-    include!(concat!(env!("OUT_DIR"), "/mod.rs"));
-}
-pub use compiled::ttrpc as proto;
 pub mod context;
+pub mod proto;
 
 #[doc(inline)]
-pub use crate::common::MessageHeader;
+pub use self::proto::{Code, MessageHeader, Request, Response, Status};
+
 #[doc(inline)]
 pub use crate::error::{get_status, Error, Result};
-#[doc(inline)]
-pub use proto::{Code, Request, Response, Status};
 
 cfg_sync! {
     pub mod sync;
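The three new error variants are what the typed stream receivers key off: `Error::Eof` marks a clean half-close and is translated into `Ok(None)` by `SSReceiver`/`ServerStreamReceiver` in stream.rs above. A sketch of the resulting handler-side loop; `SumRequest`/`SumResponse` are hypothetical user messages, not part of this patch:

```rust
async fn sum(
    mut rx: ::ttrpc::r#async::ServerStreamReceiver<SumRequest>,
) -> ::ttrpc::Result<SumResponse> {
    let mut total: u64 = 0;
    // recv() yields Ok(Some(msg)) per DATA frame and Ok(None) once the
    // client half-closes (Error::Eof internally); other errors propagate.
    while let Some(req) = rx.recv().await? {
        total += req.value; // `value`/`set_sum` are assumed message fields
    }
    let mut resp = SumResponse::new();
    resp.set_sum(total);
    Ok(resp)
}
```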
diff --git a/src/proto.rs b/src/proto.rs
new file mode 100644
index 00000000..65d477b4
--- /dev/null
+++ b/src/proto.rs
@@ -0,0 +1,498 @@
+// Copyright 2022 Alibaba Cloud. All rights reserved.
+// Copyright (c) 2020 Ant Financial
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#[allow(soft_unstable, clippy::type_complexity, clippy::too_many_arguments)]
+mod compiled {
+    include!(concat!(env!("OUT_DIR"), "/mod.rs"));
+}
+pub use compiled::ttrpc::*;
+
+use byteorder::{BigEndian, ByteOrder};
+use protobuf::{CodedInputStream, CodedOutputStream};
+
+#[cfg(feature = "async")]
+use crate::error::{get_rpc_status, Error, Result as TtResult};
+
+pub const MESSAGE_HEADER_LENGTH: usize = 10;
+pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;
+
+pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
+pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;
+pub const MESSAGE_TYPE_DATA: u8 = 0x3;
+
+pub const FLAG_REMOTE_CLOSED: u8 = 0x1;
+pub const FLAG_REMOTE_OPEN: u8 = 0x2;
+pub const FLAG_NO_DATA: u8 = 0x4;
+
+/// Message header of ttrpc.
+#[derive(Default, Debug, Clone, Copy, PartialEq)]
+pub struct MessageHeader {
+    pub length: u32,
+    pub stream_id: u32,
+    pub type_: u8,
+    pub flags: u8,
+}
+
+impl<T> From<T> for MessageHeader
+where
+    T: AsRef<[u8]>,
+{
+    fn from(buf: T) -> Self {
+        let buf = buf.as_ref();
+        debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH);
+        Self {
+            length: BigEndian::read_u32(&buf[..4]),
+            stream_id: BigEndian::read_u32(&buf[4..8]),
+            type_: buf[8],
+            flags: buf[9],
+        }
+    }
+}
+
+impl From<MessageHeader> for Vec<u8> {
+    fn from(mh: MessageHeader) -> Self {
+        let mut buf = vec![0u8; MESSAGE_HEADER_LENGTH];
+        mh.into_buf(&mut buf);
+        buf
+    }
+}
+
+impl MessageHeader {
+    /// Creates a request MessageHeader from stream_id and len.
+    ///
+    /// Use the default message type MESSAGE_TYPE_REQUEST, and default flags 0.
+    pub fn new_request(stream_id: u32, len: u32) -> Self {
+        Self {
+            length: len,
+            stream_id,
+            type_: MESSAGE_TYPE_REQUEST,
+            flags: 0,
+        }
+    }
+
+    /// Creates a response MessageHeader from stream_id and len.
+    ///
+    /// Use the MESSAGE_TYPE_RESPONSE message type, and default flags 0.
+    pub fn new_response(stream_id: u32, len: u32) -> Self {
+        Self {
+            length: len,
+            stream_id,
+            type_: MESSAGE_TYPE_RESPONSE,
+            flags: 0,
+        }
+    }
+
+    /// Creates a data MessageHeader from stream_id and len.
+    ///
+    /// Use the MESSAGE_TYPE_DATA message type, and default flags 0.
+    pub fn new_data(stream_id: u32, len: u32) -> Self {
+        Self {
+            length: len,
+            stream_id,
+            type_: MESSAGE_TYPE_DATA,
+            flags: 0,
+        }
+    }
+
+    /// Set the stream_id of the message to the given value.
+    pub fn set_stream_id(&mut self, stream_id: u32) {
+        self.stream_id = stream_id;
+    }
+
+    /// Set the flags of the message to the given flags.
+    pub fn set_flags(&mut self, flags: u8) {
+        self.flags = flags;
+    }
+
+    /// Add new flags to the message.
+    pub fn add_flags(&mut self, flags: u8) {
+        self.flags |= flags;
+    }
+
+    pub(crate) fn into_buf(self, mut buf: impl AsMut<[u8]>) {
+        let buf = buf.as_mut();
+        debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH);
+
+        let covbuf: &mut [u8] = &mut buf[..4];
+        BigEndian::write_u32(covbuf, self.length);
+        let covbuf: &mut [u8] = &mut buf[4..8];
+        BigEndian::write_u32(covbuf, self.stream_id);
+        buf[8] = self.type_;
+        buf[9] = self.flags;
+    }
+}
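The header is a fixed 10-byte, big-endian frame prefix: length(4) | stream_id(4) | type_(1) | flags(1). A quick round-trip under the same values the tests at the end of this file use:

```rust
use ttrpc::proto::{MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_TYPE_REQUEST};

fn header_layout() {
    let mut mh = MessageHeader::new_request(0, 67);
    mh.set_stream_id(0x123456);
    mh.set_flags(0xe0);
    mh.add_flags(0x0f); // flags are OR-ed together: 0xe0 | 0x0f == 0xef
    let buf: Vec<u8> = mh.into();
    assert_eq!(buf.len(), MESSAGE_HEADER_LENGTH);
    assert_eq!(&buf[..4], &[0x00, 0x00, 0x00, 67]); // length, big-endian
    assert_eq!(&buf[4..8], &[0x00, 0x12, 0x34, 0x56]); // stream_id
    assert_eq!(buf[8], MESSAGE_TYPE_REQUEST);
    assert_eq!(buf[9], 0xef);
}
```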
+
+#[cfg(feature = "async")]
+impl MessageHeader {
+    /// Encodes a MessageHeader to writer.
+    pub async fn write_to(
+        &self,
+        mut writer: impl tokio::io::AsyncWriteExt + Unpin,
+    ) -> std::io::Result<()> {
+        writer.write_u32(self.length).await?;
+        writer.write_u32(self.stream_id).await?;
+        writer.write_u8(self.type_).await?;
+        writer.write_u8(self.flags).await?;
+        writer.flush().await
+    }
+
+    /// Decodes a MessageHeader from reader.
+    pub async fn read_from(
+        mut reader: impl tokio::io::AsyncReadExt + Unpin,
+    ) -> std::io::Result<MessageHeader> {
+        let mut content = vec![0; MESSAGE_HEADER_LENGTH];
+        reader.read_exact(&mut content).await?;
+        Ok(MessageHeader::from(&content))
+    }
+}
+
+/// Generic message of ttrpc.
+#[derive(Default, Debug, Clone, PartialEq)]
+pub struct GenMessage {
+    pub header: MessageHeader,
+    pub payload: Vec<u8>,
+}
+
+#[cfg(feature = "async")]
+impl GenMessage {
+    /// Encodes a GenMessage to writer.
+    pub async fn write_to(
+        &self,
+        mut writer: impl tokio::io::AsyncWriteExt + Unpin,
+    ) -> TtResult<()> {
+        self.header
+            .write_to(&mut writer)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+        writer
+            .write_all(&self.payload)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+        Ok(())
+    }
+
+    /// Decodes a GenMessage from reader.
+    pub async fn read_from(mut reader: impl tokio::io::AsyncReadExt + Unpin) -> TtResult<Self> {
+        let header = MessageHeader::read_from(&mut reader)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+
+        if header.length > MESSAGE_LENGTH_MAX as u32 {
+            return Err(get_rpc_status(
+                Code::INVALID_ARGUMENT,
+                format!(
+                    "message length {} exceed maximum message size of {}",
+                    header.length, MESSAGE_LENGTH_MAX
+                ),
+            ));
+        }
+
+        let mut content = vec![0; header.length as usize];
+        reader
+            .read_exact(&mut content)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+
+        Ok(Self {
+            header,
+            payload: content,
+        })
+    }
+}
+
+/// TTRPC codec, only protobuf is supported.
+pub trait Codec {
+    type E;
+
+    fn size(&self) -> u32;
+    fn encode(&self) -> Result<Vec<u8>, Self::E>;
+    fn decode(buf: impl AsRef<[u8]>) -> Result<Self, Self::E>
+    where
+        Self: Sized;
+}
+
+impl<M: protobuf::Message> Codec for M {
+    type E = protobuf::error::ProtobufError;
+
+    fn size(&self) -> u32 {
+        self.compute_size()
+    }
+
+    fn encode(&self) -> Result<Vec<u8>, Self::E> {
+        let mut buf = vec![0; self.compute_size() as usize];
+        let mut s = CodedOutputStream::bytes(&mut buf);
+        self.write_to(&mut s)?;
+        s.flush()?;
+        Ok(buf)
+    }
+
+    fn decode(buf: impl AsRef<[u8]>) -> Result<Self, Self::E> {
+        let mut s = CodedInputStream::from_bytes(buf.as_ref());
+        M::parse_from(&mut s)
+    }
+}
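The blanket impl means every generated protobuf type gets `encode()`/`decode()` for free. A small round-trip, mirroring the `protobuf_codec` test below:

```rust
use ttrpc::proto::{Codec, Request};

fn roundtrip() -> Result<(), protobuf::error::ProtobufError> {
    let mut req = Request::new();
    req.set_service("grpc.TestServices".to_string());
    req.set_method("Test".to_string());
    let buf = req.encode()?; // Vec<u8> whose length equals req.size()
    let back = Request::decode(&buf)?;
    assert_eq!(req, back);
    Ok(())
}
```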
+
+/// Message of ttrpc.
+#[derive(Default, Debug, Clone, PartialEq)]
+pub struct Message<C> {
+    pub header: MessageHeader,
+    pub payload: C,
+}
+
+impl<C> std::convert::TryFrom<GenMessage> for Message<C>
+where
+    C: Codec,
+{
+    type Error = C::E;
+    fn try_from(gen: GenMessage) -> Result<Self, Self::Error> {
+        Ok(Self {
+            header: gen.header,
+            payload: C::decode(&gen.payload)?,
+        })
+    }
+}
+
+impl<C> std::convert::TryFrom<Message<C>> for GenMessage
+where
+    C: Codec,
+{
+    type Error = C::E;
+    fn try_from(msg: Message<C>) -> Result<Self, Self::Error> {
+        Ok(Self {
+            header: msg.header,
+            payload: msg.payload.encode()?,
+        })
+    }
+}
+
+impl<C: Codec> Message<C> {
+    pub fn new_request(stream_id: u32, message: C) -> Self {
+        Self {
+            header: MessageHeader::new_request(stream_id, message.size()),
+            payload: message,
+        }
+    }
+}
+
+#[cfg(feature = "async")]
+impl<C> Message<C>
+where
+    C: Codec,
+    C::E: std::fmt::Display,
+{
+    /// Encodes a Message to writer.
+    pub async fn write_to(
+        &self,
+        mut writer: impl tokio::io::AsyncWriteExt + Unpin,
+    ) -> TtResult<()> {
+        self.header
+            .write_to(&mut writer)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+        let content = self
+            .payload
+            .encode()
+            .map_err(err_to_others_err!(e, "Encode payload failed."))?;
+        writer
+            .write_all(&content)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+        Ok(())
+    }
+
+    /// Decodes a Message from reader.
+    pub async fn read_from(mut reader: impl tokio::io::AsyncReadExt + Unpin) -> TtResult<Self> {
+        let header = MessageHeader::read_from(&mut reader)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+
+        if header.length > MESSAGE_LENGTH_MAX as u32 {
+            return Err(get_rpc_status(
+                Code::INVALID_ARGUMENT,
+                format!(
+                    "message length {} exceed maximum message size of {}",
+                    header.length, MESSAGE_LENGTH_MAX
+                ),
+            ));
+        }
+
+        let mut content = vec![0; header.length as usize];
+        reader
+            .read_exact(&mut content)
+            .await
+            .map_err(|e| Error::Socket(e.to_string()))?;
+        let payload =
+            C::decode(content).map_err(err_to_others_err!(e, "Decode payload failed."))?;
+        Ok(Self { header, payload })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::convert::{TryFrom, TryInto};
+
+    use super::*;
+
+    static MESSAGE_HEADER: [u8; MESSAGE_HEADER_LENGTH] = [
+        0x10, 0x0, 0x0, 0x0, // length
+        0x0, 0x0, 0x0, 0x03, // stream_id
+        0x2, // type_
+        0xef, // flags
+    ];
+
+    #[test]
+    fn message_header() {
+        let mh = MessageHeader::from(&MESSAGE_HEADER);
+        assert_eq!(mh.length, 0x1000_0000);
+        assert_eq!(mh.stream_id, 0x3);
+        assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE);
+        assert_eq!(mh.flags, 0xef);
+
+        let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH];
+        mh.into_buf(&mut buf2);
+        assert_eq!(&MESSAGE_HEADER, &buf2[..]);
+
+        let mh = MessageHeader::from(&PROTOBUF_MESSAGE_HEADER);
+        assert_eq!(mh.length as usize, TEST_PAYLOAD_LEN);
+    }
+
+    #[rustfmt::skip]
+    static PROTOBUF_MESSAGE_HEADER: [u8; MESSAGE_HEADER_LENGTH] = [
+        0x00, 0x0, 0x0, TEST_PAYLOAD_LEN as u8, // length
+        0x0, 0x12, 0x34, 0x56, // stream_id
+        0x1, // type_
+        0xef, // flags
+    ];
+
+    const TEST_PAYLOAD_LEN: usize = 67;
+    static PROTOBUF_REQUEST: [u8; TEST_PAYLOAD_LEN] = [
+        10, 17, 103, 114, 112, 99, 46, 84, 101, 115, 116, 83, 101, 114, 118, 105, 99, 101, 115, 18,
+        4, 84, 101, 115, 116, 26, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 32, 128, 218, 196, 9, 42, 24, 10,
+        9, 116, 101, 115, 116, 95, 107, 101, 121, 49, 18, 11, 116, 101, 115, 116, 95, 118, 97, 108,
+        117, 101, 49,
+    ];
+
+    fn new_protobuf_request() -> Request {
+        let mut creq = Request::new();
+        creq.set_service("grpc.TestServices".to_string());
+        creq.set_method("Test".to_string());
+        creq.set_timeout_nano(20 * 1000 * 1000);
+        let mut meta: protobuf::RepeatedField<KeyValue> = protobuf::RepeatedField::default();
+        meta.push(KeyValue {
+            key: "test_key1".to_string(),
+            value: "test_value1".to_string(),
+            ..Default::default()
+        });
+        creq.set_metadata(meta);
+        creq.payload = vec![0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9];
+        creq
+    }
+
+    #[test]
+    fn protobuf_codec() {
+        let creq = new_protobuf_request();
+        let buf = creq.encode().unwrap();
+        assert_eq!(&buf, &PROTOBUF_REQUEST);
+        let dreq = Request::decode(&buf).unwrap();
+        assert_eq!(creq, dreq);
+        let dreq2 = Request::decode(&PROTOBUF_REQUEST).unwrap();
+        assert_eq!(creq, dreq2);
+    }
+
+    #[test]
+    fn gen_message_to_message() {
+        let req = new_protobuf_request();
+        let msg = Message::new_request(3, req);
+        let msg_clone = msg.clone();
+        let gen: GenMessage = msg.try_into().unwrap();
+        let dmsg = Message::<Request>::try_from(gen).unwrap();
+        assert_eq!(msg_clone, dmsg);
+    }
+
+    #[cfg(feature = "async")]
+    #[tokio::test]
+    async fn async_message_header() {
+        use std::io::Cursor;
+        let mut buf = vec![];
+        let mut io = Cursor::new(&mut buf);
+        let mh = MessageHeader::from(&MESSAGE_HEADER);
+        mh.write_to(&mut io).await.unwrap();
+        assert_eq!(buf, &MESSAGE_HEADER);
+
+        let dmh = MessageHeader::read_from(&buf[..]).await.unwrap();
+        assert_eq!(mh, dmh);
+    }
"async")] + #[tokio::test] + async fn async_gen_message() { + let mut buf = Vec::from(MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + let res = GenMessage::read_from(&*buf).await; + // exceed maximum message size + assert!(matches!(res, Err(Error::RpcStatus(_)))); + + let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + buf.extend_from_slice(&[0x0, 0x0]); + let gen = GenMessage::read_from(&*buf).await.unwrap(); + assert_eq!(gen.header.length as usize, TEST_PAYLOAD_LEN); + assert_eq!(gen.header.length, gen.payload.len() as u32); + assert_eq!(gen.header.stream_id, 0x123456); + assert_eq!(gen.header.type_, MESSAGE_TYPE_REQUEST); + assert_eq!(gen.header.flags, 0xef); + assert_eq!(&gen.payload, &PROTOBUF_REQUEST); + assert_eq!( + &buf[MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN..], + &[0x0, 0x0] + ); + + let mut dbuf = vec![]; + let mut io = std::io::Cursor::new(&mut dbuf); + gen.write_to(&mut io).await.unwrap(); + assert_eq!(&*dbuf, &buf[..MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN]); + } + + #[cfg(feature = "async")] + #[tokio::test] + async fn async_message() { + let mut buf = Vec::from(MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + let res = Message::::read_from(&*buf).await; + // exceed maximum message size + assert!(matches!(res, Err(Error::RpcStatus(_)))); + + let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + buf.extend_from_slice(&[0x0, 0x0]); + let msg = Message::::read_from(&*buf).await.unwrap(); + assert_eq!(msg.header.length, 67); + assert_eq!(msg.header.length, msg.payload.size() as u32); + assert_eq!(msg.header.stream_id, 0x123456); + assert_eq!(msg.header.type_, MESSAGE_TYPE_REQUEST); + assert_eq!(msg.header.flags, 0xef); + assert_eq!(&msg.payload.service, "grpc.TestServices"); + assert_eq!(&msg.payload.method, "Test"); + assert_eq!( + msg.payload.payload, + vec![0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9] + ); + assert_eq!(msg.payload.timeout_nano, 20 * 1000 * 1000); + assert_eq!(msg.payload.metadata.len(), 1); + assert_eq!(&msg.payload.metadata[0].key, "test_key1"); + assert_eq!(&msg.payload.metadata[0].value, "test_value1"); + + let req = new_protobuf_request(); + let mut dmsg = Message::new_request(u32::MAX, req); + dmsg.header.set_stream_id(0x123456); + dmsg.header.set_flags(0xe0); + dmsg.header.add_flags(0x0f); + let mut dbuf = vec![]; + let mut io = std::io::Cursor::new(&mut dbuf); + dmsg.write_to(&mut io).await.unwrap(); + assert_eq!(&dbuf, &buf[..MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN]); + } +} diff --git a/src/sync/channel.rs b/src/sync/channel.rs index 07e42e81..4c49dfc5 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -15,10 +15,8 @@ use nix::sys::socket::*; use std::os::unix::io::RawFd; -use crate::common::{MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; -use crate::proto::Code; -use crate::MessageHeader; +use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; fn retryable(e: nix::Error) -> bool { use ::nix::Error; diff --git a/src/sync/client.rs b/src/sync/client.rs index 1d2587b3..ece9ee0b 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -16,7 +16,6 @@ use nix::sys::socket::*; use nix::unistd::close; -use protobuf::{CodedInputStream, CodedOutputStream, Message}; use std::collections::HashMap; use std::os::unix::io::RawFd; use std::sync::mpsc; @@ -25,11 +24,10 @@ use std::{io, thread}; #[cfg(target_os = "macos")] use 
 use crate::common::set_fd_close_exec;
-use crate::common::{client_connect, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, SOCK_CLOEXEC};
+use crate::common::{client_connect, SOCK_CLOEXEC};
 use crate::error::{Error, Result};
-use crate::proto::{Code, Request, Response};
+use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
 use crate::sync::channel::{read_message, write_message};
-use crate::MessageHeader;
 use std::time::Duration;
 
 type Sender = mpsc::Sender<(Vec<u8>, mpsc::SyncSender<Result<Vec<u8>>>)>;
 
@@ -80,12 +78,8 @@ impl Client {
                 let mut map = recver_map.lock().unwrap();
                 map.insert(current_stream_id, recver_tx.clone());
             }
-            let mh = MessageHeader {
-                length: buf.len() as u32,
-                stream_id: current_stream_id,
-                type_: MESSAGE_TYPE_REQUEST,
-                flags: 0,
-            };
+            let mut mh = MessageHeader::new_request(0, buf.len() as u32);
+            mh.set_stream_id(current_stream_id);
             if let Err(e) = write_message(fd, mh, buf) {
                 //Remove current_stream_id and recver_tx to recver_map
                 {
@@ -214,10 +208,7 @@ impl Client {
         }
     }
     pub fn request(&self, req: Request) -> Result<Response> {
-        let mut buf = Vec::with_capacity(req.compute_size() as usize);
-        let mut s = CodedOutputStream::vec(&mut buf);
-        req.write_to(&mut s).map_err(err_to_others_err!(e, ""))?;
-        s.flush().map_err(err_to_others_err!(e, ""))?;
+        let buf = req.encode().map_err(err_to_others_err!(e, ""))?;
 
         let (tx, rx) = mpsc::sync_channel(0);
 
@@ -237,10 +228,8 @@ impl Client {
         };
         let buf = result?;
-        let mut s = CodedInputStream::from_bytes(&buf);
-        let mut res = Response::new();
-        res.merge_from(&mut s)
-            .map_err(err_to_others_err!(e, "Unpack response error "))?;
+        let res =
+            Response::decode(&buf).map_err(err_to_others_err!(e, "Unpack response error "))?;
 
         let status = res.get_status();
         if status.get_code() != Code::OK {
diff --git a/src/sync/server.rs b/src/sync/server.rs
index 282c072e..9bee608c 100644
--- a/src/sync/server.rs
+++ b/src/sync/server.rs
@@ -26,14 +26,13 @@ use std::thread::JoinHandle;
 use std::{io, thread};
 
 use super::utils::response_to_channel;
+use crate::common;
 #[cfg(not(target_os = "linux"))]
 use crate::common::set_fd_close_exec;
-use crate::common::{self, MESSAGE_TYPE_REQUEST};
 use crate::context;
 use crate::error::{get_status, Error, Result};
-use crate::proto::{Code, Request, Response};
+use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST};
 use crate::sync::channel::{read_message, write_message};
-use crate::MessageHeader;
 use crate::{MethodHandler, TtrpcContext};
 
 // poll_queue will create WAIT_THREAD_COUNT_DEFAULT threads in begin.
diff --git a/src/sync/utils.rs b/src/sync/utils.rs
index 2155f180..b607b98b 100644
--- a/src/sync/utils.rs
+++ b/src/sync/utils.rs
@@ -3,9 +3,8 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-use crate::common::{MessageHeader, MESSAGE_TYPE_RESPONSE};
 use crate::error::{Error, Result};
-use crate::proto::{Request, Response};
+use crate::proto::{MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE};
 use protobuf::Message;
 use std::collections::HashMap;
diff --git a/ttrpc-codegen/Cargo.toml b/ttrpc-codegen/Cargo.toml
index fc6eb32b..2805bcf7 100644
--- a/ttrpc-codegen/Cargo.toml
+++ b/ttrpc-codegen/Cargo.toml
@@ -16,4 +16,4 @@ readme = "README.md"
 protobuf = { version = "2.14.0" }
 protobuf-codegen-pure = "2.14.0"
 protobuf-codegen = "2.14.0"
-ttrpc-compiler = "0.5.0"
+ttrpc-compiler = { version = "0.5.0", path = "../compiler" }