diff --git a/compiler/src/codegen.rs b/compiler/src/codegen.rs index 1f90d726..c871ceaf 100644 --- a/compiler/src/codegen.rs +++ b/compiler/src/codegen.rs @@ -55,7 +55,6 @@ struct MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, } @@ -65,7 +64,6 @@ impl<'a> MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, ) -> MethodGen<'a> { @@ -73,7 +71,6 @@ impl<'a> MethodGen<'a> { proto, package_name, service_name, - service_path, root_scope, customize, } @@ -127,10 +124,6 @@ impl<'a> MethodGen<'a> { to_camel_case(self.proto.get_name()) } - fn fq_name(&self) -> String { - format!("\"{}/{}\"", self.service_path, &self.proto.get_name()) - } - fn const_method_name(&self) -> String { format!( "METHOD_{}_{}", @@ -139,36 +132,13 @@ impl<'a> MethodGen<'a> { ) } - fn write_definition(&self, w: &mut CodeWriter) { - let head = format!( - "const {}: {}<{}, {}> = {} {{", - self.const_method_name(), - fq_grpc("Method"), - self.input(), - self.output(), - fq_grpc("Method") - ); - let pb_mar = format!( - "{} {{ ser: {}, de: {} }}", - fq_grpc("Marshaller"), - fq_grpc("pb_ser"), - fq_grpc("pb_de") - ); - w.block(&head, "};", |w| { - w.field_entry("ty", &self.method_type().1); - w.field_entry("name", &self.fq_name()); - w.field_entry("req_mar", &pb_mar); - w.field_entry("resp_mar", &pb_mar); - }); - } - fn write_handler(&self, w: &mut CodeWriter) { w.block( &format!("struct {}Method {{", self.struct_name()), "}", |w| { w.write_line(&format!( - "service: Arc>,", + "service: Arc>,", self.service_name )); }, @@ -197,16 +167,55 @@ impl<'a> MethodGen<'a> { fn write_handler_impl_async(&self, w: &mut CodeWriter) { w.write_line("#[async_trait]"); - w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", - |w| { - w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<(u32, Vec)> {", "}", - |w| { - w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", + match self.method_type().0 { + MethodType::Unary => { + w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), self.root_scope.find_message(self.proto.get_input_type()).rust_name(), self.name())); + }); }); - }); + } + // only receive + MethodType::ClientStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_client_streamimg_handler!(self, ctx, inner, {});", + self.name())); + }); + }); + } + // only send + MethodType::ServerStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, mut inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_server_streamimg_handler!(self, ctx, inner, {}, {}, {});", + proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), + self.root_scope.find_message(self.proto.get_input_type()).rust_name(), + self.name())); + }); + }); + } + // receive and send + MethodType::Duplex => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, {});", + self.name())); + }); + }); + } + } } // Method signatures @@ -222,73 +231,33 @@ impl<'a> MethodGen<'a> { fn client_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), + fq_grpc("r#async::ClientStreamSender"), self.input(), - fq_grpc("ClientCStreamReceiver"), - self.output() - ) - } - - fn client_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), - self.input(), - fq_grpc("ClientCStreamReceiver"), self.output() ) } fn server_streaming(&self, method_name: &str) -> String { format!( - "{}(&self, req: &{}) -> {}<{}<{}>>", + "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}<{}>>", method_name, self.input(), fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), - self.output() - ) - } - - fn server_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, req: &{}, opt: {}) -> {}<{}<{}>>", - method_name, - self.input(), - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), + fq_grpc("r#async::ClientStreamReceiver"), self.output() ) } fn duplex_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), + fq_grpc("r#async::ClientStream"), self.input(), - fq_grpc("ClientDuplexReceiver"), - self.output() - ) - } - - fn duplex_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), - self.input(), - fq_grpc("ClientDuplexReceiver"), self.output() ) } @@ -329,28 +298,79 @@ impl<'a> MethodGen<'a> { )); }); } - - _ => {} + // ClientStreaming + MethodType::ClientStreaming => { + pub_async_fn(w, &self.client_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_send!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // ServerStreaming + MethodType::ServerStreaming => { + pub_async_fn(w, &self.server_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_receive!(self, ctx, req, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // Duplex + MethodType::Duplex => { + pub_async_fn(w, &self.duplex_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } }; } fn write_service(&self, w: &mut CodeWriter) { - let req_stream_type = format!("{}<{}>", fq_grpc("RequestStream"), self.input()); - let (req, req_type, _resp_type) = match self.method_type().0 { - MethodType::Unary => ("req", self.input(), "UnarySink"), - MethodType::ClientStreaming => ("stream", req_stream_type, "ClientStreamingSink"), - MethodType::ServerStreaming => ("req", self.input(), "ServerStreamingSink"), - MethodType::Duplex => ("stream", req_stream_type, "DuplexSink"), + let (_req, req_type, resp_type) = match self.method_type().0 { + MethodType::Unary => ("req", self.input(), self.output()), + MethodType::ClientStreaming => ( + "stream", + format!("::ttrpc::r#async::ServerStreamReceiver<{}>", self.input()), + self.output(), + ), + MethodType::ServerStreaming => ( + "req", + format!( + "{}, _: {}<{}>", + self.input(), + "::ttrpc::r#async::ServerStreamSender", + self.output() + ), + "()".to_string(), + ), + MethodType::Duplex => ( + "stream", + format!( + "{}<{}, {}>", + "::ttrpc::r#async::ServerStream", // Send type first + self.output(), + self.input(), + ), + "()".to_string(), + ), }; let get_sig = |context_name| { format!( - "{}(&self, _ctx: &{}, _{}: {}) -> ::ttrpc::Result<{}>", + "{}(&self, _ctx: &{}, _: {}) -> ::ttrpc::Result<{}>", self.name(), fq_grpc(context_name), - req, req_type, - self.output() + resp_type, ) }; @@ -370,20 +390,38 @@ impl<'a> MethodGen<'a> { } fn write_bind(&self, w: &mut CodeWriter) { - let mut method_handler_name = "::ttrpc::MethodHandler"; + let method_handler_name = "::ttrpc::MethodHandler"; - if async_on(self.customize, "server") { - method_handler_name = "::ttrpc::r#async::MethodHandler"; - } + let s = format!( + "methods.insert(\"/{}.{}/{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as Box);", + self.package_name, + self.service_name, + self.proto.get_name(), + self.struct_name(), + method_handler_name, + ); + w.write_line(&s); + } - let s = format!("methods.insert(\"/{}.{}/{}\".to_string(), - std::boxed::Box::new({}Method{{service: service.clone()}}) as std::boxed::Box);", - self.package_name, - self.service_name, - self.proto.get_name(), - self.struct_name(), - method_handler_name, - ); + fn write_async_bind(&self, w: &mut CodeWriter) { + let s = if matches!(self.method_type().0, MethodType::Unary) { + format!( + "methods.insert(\"{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Box" + ) + } else { + format!( + "streams.insert(\"{}\".to_string(), + Arc::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Arc" + ) + }; w.write_line(&s); } } @@ -392,6 +430,7 @@ struct ServiceGen<'a> { proto: &'a ServiceDescriptorProto, methods: Vec>, customize: &'a Customize, + package_name: String, } impl<'a> ServiceGen<'a> { @@ -401,11 +440,6 @@ impl<'a> ServiceGen<'a> { root_scope: &'a RootScope, customize: &'a Customize, ) -> ServiceGen<'a> { - let service_path = if file.get_package().is_empty() { - format!("{}", proto.get_name()) - } else { - format!("{}.{}", file.get_package(), proto.get_name()) - }; let methods = proto .get_method() .iter() @@ -414,7 +448,6 @@ impl<'a> ServiceGen<'a> { m, file.get_package().to_string(), util::to_camel_case(proto.get_name()), - service_path.clone(), root_scope, &customize, ) @@ -425,6 +458,7 @@ impl<'a> ServiceGen<'a> { proto, methods, customize, + package_name: file.get_package().to_string(), } } @@ -432,10 +466,20 @@ impl<'a> ServiceGen<'a> { util::to_camel_case(self.proto.get_name()) } + fn service_path(&self) -> String { + format!("{}.{}", self.package_name, self.service_name()) + } + fn client_name(&self) -> String { format!("{}Client", self.service_name()) } + fn has_stream_method(&self) -> bool { + self.methods + .iter() + .any(|method| !matches!(method.method_type().0, MethodType::Unary)) + } + fn write_client(&self, w: &mut CodeWriter) { if async_on(self.customize, "client") { self.write_async_client(w) @@ -490,11 +534,9 @@ impl<'a> ServiceGen<'a> { fn write_server(&self, w: &mut CodeWriter) { let mut trait_name = self.service_name(); - let mut method_handler_name = "::ttrpc::MethodHandler"; if async_on(self.customize, "server") { w.write_line("#[async_trait]"); trait_name = format!("{}: Sync", &self.service_name()); - method_handler_name = "::ttrpc::r#async::MethodHandler"; } w.pub_trait(&trait_name.to_owned(), |w| { @@ -504,9 +546,17 @@ impl<'a> ServiceGen<'a> { }); w.write_line(""); + if async_on(self.customize, "server") { + self.write_async_server_create(w); + } else { + self.write_sync_server_create(w); + } + } + fn write_sync_server_create(&self, w: &mut CodeWriter) { + let method_handler_name = "::ttrpc::MethodHandler"; let s = format!( - "create_{}(service: Arc>) -> HashMap >", + "create_{}(service: Arc>) -> HashMap>", to_snake_case(&self.service_name()), self.service_name(), method_handler_name, @@ -523,14 +573,35 @@ impl<'a> ServiceGen<'a> { }); } - fn write_method_definitions(&self, w: &mut CodeWriter) { - for (i, method) in self.methods.iter().enumerate() { - if i != 0 { + fn write_async_server_create(&self, w: &mut CodeWriter) { + let s = format!( + "create_{}(service: Arc>) -> HashMap", + to_snake_case(&self.service_name()), + self.service_name(), + "::ttrpc::r#async::Service" + ); + + let has_stream_method = self.has_stream_method(); + w.pub_fn(&s, |w| { + w.write_line("let mut ret = HashMap::new();"); + w.write_line("let mut methods = HashMap::new();"); + if has_stream_method { + w.write_line("let mut streams = HashMap::new();"); + } else { + w.write_line("let streams = HashMap::new();"); + } + for method in &self.methods[0..self.methods.len()] { w.write_line(""); + method.write_async_bind(w); } - - method.write_definition(w); - } + w.write_line(""); + w.write_line(format!( + "ret.insert(\"{}\".to_string(), {});", + self.service_path(), + "::ttrpc::r#async::Service{ methods, streams }" + )); + w.write_line("ret"); + }); } fn write_method_handlers(&self, w: &mut CodeWriter) { diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index 2e976b5f..fb6d39cb 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -15,7 +15,10 @@ mod connection; pub mod shutdown; mod unix_incoming; -pub use self::stream::{Kind, StreamInner}; +pub use self::stream::{ + ClientStream, ClientStreamReceiver, ClientStreamSender, Kind, ServerStream, + ServerStreamReceiver, ServerStreamSender, StreamInner, +}; #[doc(inline)] pub use crate::r#async::client::Client; #[doc(inline)] diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 26fc8ae7..7b78b916 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -1,9 +1,11 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // use std::collections::HashMap; +use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; @@ -21,6 +23,291 @@ pub type MessageReceiver = mpsc::Receiver; pub type ResultSender = mpsc::Sender>; pub type ResultReceiver = mpsc::Receiver>; +#[derive(Debug)] +pub struct ClientStream { + tx: CSSender, + rx: CSReceiver

, +} + +impl ClientStream +where + Q: Codec, + P: Codec, + ::E: std::fmt::Display, +

::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + let (tx, rx) = inner.split(); + Self { + tx: CSSender { + tx, + _send: PhantomData, + }, + rx: CSReceiver { + rx, + _recv: PhantomData, + }, + } + } + + pub fn split(self) -> (CSSender, CSReceiver

) { + (self.tx, self.rx) + } + + pub async fn send(&self, req: &Q) -> Result<()> { + self.tx.send(req).await + } + + pub async fn close_send(&self) -> Result<()> { + self.tx.close_send().await + } + + pub async fn recv(&mut self) -> Result

{ + self.rx.recv().await + } +} + +#[derive(Clone, Debug)] +pub struct CSSender { + tx: StreamSender, + _send: PhantomData, +} + +impl CSSender +where + Q: Codec, + ::E: std::fmt::Display, +{ + pub async fn send(&self, req: &Q) -> Result<()> { + let msg_buf = req + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.tx.send(msg_buf).await + } + + pub async fn close_send(&self) -> Result<()> { + self.tx.close_send().await + } +} + +#[derive(Debug)] +pub struct CSReceiver

{ + rx: StreamReceiver, + _recv: PhantomData

, +} + +impl

CSReceiver

+where + P: Codec, +

::E: std::fmt::Display, +{ + pub async fn recv(&mut self) -> Result

{ + let msg_buf = self.rx.recv().await?; + P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed.")) + } +} + +#[derive(Debug)] +pub struct ServerStream { + tx: SSSender

, + rx: SSReceiver, +} + +impl ServerStream +where + P: Codec, + Q: Codec, +

::E: std::fmt::Display, + ::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + let (tx, rx) = inner.split(); + Self { + tx: SSSender { + tx, + _send: PhantomData, + }, + rx: SSReceiver { + rx, + _recv: PhantomData, + }, + } + } + + pub fn split(self) -> (SSSender

, SSReceiver) { + (self.tx, self.rx) + } + + pub async fn send(&self, resp: &P) -> Result<()> { + self.tx.send(resp).await + } + + pub async fn recv(&mut self) -> Result> { + self.rx.recv().await + } +} + +#[derive(Clone, Debug)] +pub struct SSSender

{ + tx: StreamSender, + _send: PhantomData

, +} + +impl

SSSender

+where + P: Codec, +

::E: std::fmt::Display, +{ + pub async fn send(&self, resp: &P) -> Result<()> { + let msg_buf = resp + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.tx.send(msg_buf).await + } +} + +#[derive(Debug)] +pub struct SSReceiver { + rx: StreamReceiver, + _recv: PhantomData, +} + +impl SSReceiver +where + Q: Codec, + ::E: std::fmt::Display, +{ + pub async fn recv(&mut self) -> Result> { + let res = self.rx.recv().await; + + if matches!(res, Err(Error::EOF)) { + return Ok(None); + } + let msg_buf = res?; + Q::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} + +pub struct ClientStreamSender { + inner: StreamInner, + _send: PhantomData, + _recv: PhantomData

, +} + +impl ClientStreamSender +where + Q: Codec, + P: Codec, + ::E: std::fmt::Display, +

::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner, + _send: PhantomData, + _recv: PhantomData, + } + } + + pub async fn send(&self, req: &Q) -> Result<()> { + let msg_buf = req + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.inner.send(msg_buf).await + } + + pub async fn close_and_recv(&mut self) -> Result

{ + self.inner.close_send().await?; + let msg_buf = self.inner.recv().await?; + P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed.")) + } +} + +pub struct ServerStreamSender

{ + inner: StreamSender, + _send: PhantomData

, +} + +impl

ServerStreamSender

+where + P: Codec, +

::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().0, + _send: PhantomData, + } + } + + pub async fn send(&self, resp: &P) -> Result<()> { + let msg_buf = resp + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.inner.send(msg_buf).await + } +} + +pub struct ClientStreamReceiver

{ + inner: StreamReceiver, + _recv: PhantomData

, +} + +impl

ClientStreamReceiver

+where + P: Codec, +

::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().1, + _recv: PhantomData, + } + } + + pub async fn recv(&mut self) -> Result> { + let res = self.inner.recv().await; + if matches!(res, Err(Error::EOF)) { + return Ok(None); + } + let msg_buf = res?; + P::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} + +pub struct ServerStreamReceiver { + inner: StreamReceiver, + _recv: PhantomData, +} + +impl ServerStreamReceiver +where + Q: Codec, + ::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().1, + _recv: PhantomData, + } + } + + pub async fn recv(&mut self) -> Result> { + let res = self.inner.recv().await; + if matches!(res, Err(Error::EOF)) { + return Ok(None); + } + let msg_buf = res?; + Q::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} + async fn _recv(rx: &mut ResultReceiver) -> Result { rx.recv() .await diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 17716ae8..2e2555d1 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -1,3 +1,4 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 @@ -50,6 +51,96 @@ macro_rules! async_request_handler { }; } +/// Handle client streaming in async mode. +#[macro_export] +macro_rules! async_client_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => { + let stream = ::ttrpc::r#async::ServerStreamReceiver::new($inner); + let mut res = ::ttrpc::Response::new(); + match $class.service.$req_fn(&$ctx, stream).await { + Ok(rep) => { + res.set_status(::ttrpc::get_status(::ttrpc::Code::OK, "".to_string())); + res.payload.reserve(rep.compute_size() as usize); + let mut s = protobuf::CodedOutputStream::vec(&mut res.payload); + rep.write_to(&mut s) + .map_err(::ttrpc::err_to_others!(e, ""))?; + s.flush().map_err(::ttrpc::err_to_others!(e, ""))?; + } + Err(x) => match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + }, + } + return Ok(Some(res)); + }; +} + +/// Handle server streaming in async mode. +#[macro_export] +macro_rules! async_server_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $server: ident, $req_type: ident, $req_fn: ident) => { + let req_buf = $inner.recv().await?; + let req = ::decode(&req_buf) + .map_err(|e| ::ttrpc::Error::Others(e.to_string()))?; + let stream = ::ttrpc::r#async::ServerStreamSender::new($inner); + match $class.service.$req_fn(&$ctx, req, stream).await { + Ok(_) => { + return Ok(None); + } + Err(x) => { + let mut res = ::ttrpc::Response::new(); + match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + } + return Ok(Some(res)); + } + } + }; +} + +/// Handle duplex streaming in async mode. +#[macro_export] +macro_rules! async_duplex_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => { + let stream = ::ttrpc::r#async::ServerStream::new($inner); + match $class.service.$req_fn(&$ctx, stream).await { + Ok(_) => { + return Ok(None); + } + Err(x) => { + let mut res = ::ttrpc::Response::new(); + match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + } + return Ok(Some(res)); + } + } + }; +} + /// Send request through async client. #[macro_export] macro_rules! async_client_request { @@ -78,6 +169,67 @@ macro_rules! async_client_request { }; } +/// Duplex streaming through async client. +#[macro_export] +macro_rules! async_client_stream { + ($self: ident, $ctx: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + + let inner = $self.client.new_stream(creq, true, true).await?; + let stream = ::ttrpc::r#async::ClientStream::new(inner); + + return Ok(stream); + }; +} + +/// Only send streaming through async client. +#[macro_export] +macro_rules! async_client_stream_send { + ($self: ident, $ctx: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + + let inner = $self.client.new_stream(creq, true, false).await?; + let stream = ::ttrpc::r#async::ClientStreamSender::new(inner); + + return Ok(stream); + }; +} + +/// Only receive streaming through async client. +#[macro_export] +macro_rules! async_client_stream_receive { + ($self: ident, $ctx: ident, $req: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + creq.payload.reserve($req.compute_size() as usize); + { + let mut s = CodedOutputStream::vec(&mut creq.payload); + $req.write_to(&mut s) + .map_err(::ttrpc::err_to_others!(e, ""))?; + s.flush().map_err(::ttrpc::err_to_others!(e, ""))?; + } + + let inner = $self.client.new_stream(creq, false, true).await?; + let stream = ::ttrpc::r#async::ClientStreamReceiver::new(inner); + + return Ok(stream); + }; +} + /// Trait that implements handler which is a proxy to the desired method (async). #[async_trait] pub trait MethodHandler {