diff --git a/compiler/src/codegen.rs b/compiler/src/codegen.rs index 6fd95df1..2ee34900 100644 --- a/compiler/src/codegen.rs +++ b/compiler/src/codegen.rs @@ -487,6 +487,12 @@ impl<'a> ServiceGen<'a> { .any(|method| !matches!(method.method_type().0, MethodType::Unary)) } + fn has_unary_method(&self) -> bool { + self.methods + .iter() + .any(|method| matches!(method.method_type().0, MethodType::Unary)) + } + fn write_client(&self, w: &mut CodeWriter) { if async_on(self.customize, "client") { self.write_async_client(w) @@ -589,9 +595,14 @@ impl<'a> ServiceGen<'a> { ); let has_stream_method = self.has_stream_method(); + let has_unary_method = self.has_unary_method(); w.pub_fn(&s, |w| { w.write_line("let mut ret = HashMap::new();"); - w.write_line("let mut methods = HashMap::new();"); + if has_unary_method { + w.write_line("let mut methods = HashMap::new();"); + } else { + w.write_line("let methods = HashMap::new();"); + } if has_stream_method { w.write_line("let mut streams = HashMap::new();"); } else { diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index d0683e61..eeaec054 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -27,6 +27,8 @@ use crate::r#async::stream::{ }; use crate::r#async::utils; +use super::stream::SendingMessage; + /// A ttrpc Client (async). #[derive(Clone)] pub struct Client { @@ -78,7 +80,7 @@ impl Client { self.streams.lock().unwrap().insert(stream_id, tx); self.req_tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; @@ -131,7 +133,7 @@ impl Client { // TODO: check return self.streams.lock().unwrap().insert(stream_id, tx); self.req_tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; @@ -196,7 +198,7 @@ struct ClientWriter { #[async_trait] impl WriterDelegate for ClientWriter { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { self.rx.recv().await } diff --git a/src/asynchronous/connection.rs b/src/asynchronous/connection.rs index 8de87d3b..7f099bb2 100644 --- a/src/asynchronous/connection.rs +++ b/src/asynchronous/connection.rs @@ -16,6 +16,8 @@ use tokio::{ use crate::error::Error; use crate::proto::GenMessage; +use super::stream::SendingMessage; + pub trait Builder { type Reader; type Writer; @@ -25,7 +27,7 @@ pub trait Builder { #[async_trait] pub trait WriterDelegate { - async fn recv(&mut self) -> Option; + async fn recv(&mut self) -> Option; async fn disconnect(&self, msg: &GenMessage, e: Error); async fn exit(&self); } @@ -57,12 +59,14 @@ where let (reader_delegate, mut writer_delegate) = builder.build(); let writer_task = tokio::spawn(async move { - while let Some(msg) = writer_delegate.recv().await { - trace!("write message: {:?}", msg); - if let Err(e) = msg.write_to(&mut writer).await { + while let Some(mut sending_msg) = writer_delegate.recv().await { + trace!("write message: {:?}", sending_msg.msg); + if let Err(e) = sending_msg.msg.write_to(&mut writer).await { error!("write_message got error: {:?}", e); - writer_delegate.disconnect(&msg, e).await; + sending_msg.send_result(Err(e.clone())); + writer_delegate.disconnect(&sending_msg.msg, e).await; } + sending_msg.send_result(Ok(())); } writer_delegate.exit().await; trace!("Writer task exit."); diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index a19dac30..4e43092f 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -30,7 +30,7 @@ use tokio::{ #[cfg(target_os = "linux")] use tokio_vsock::VsockListener; -use crate::asynchronous::unix_incoming::UnixIncoming; +use crate::asynchronous::{stream::SendingMessage, unix_incoming::UnixIncoming}; use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; @@ -329,7 +329,7 @@ struct ServerWriter { #[async_trait] impl WriterDelegate for ServerWriter { - async fn recv(&mut self) -> Option { + async fn recv(&mut self) -> Option { self.rx.recv().await } async fn disconnect(&self, _msg: &GenMessage, _: Error) {} @@ -371,12 +371,14 @@ impl ReaderDelegate for ServerReader { async fn handle_msg(&self, msg: GenMessage) { let handler_shutdown_waiter = self.handler_shutdown.subscribe(); let context = self.context(); + let (wait_tx, wait_rx) = tokio::sync::oneshot::channel::<()>(); spawn(async move { select! { - _ = context.handle_msg(msg) => {} + _ = context.handle_msg(msg, wait_tx) => {} _ = handler_shutdown_waiter.wait_shutdown() => {} } }); + wait_rx.await.unwrap_or_default(); } } @@ -402,7 +404,7 @@ struct HandlerContext { } impl HandlerContext { - async fn handle_msg(&self, msg: GenMessage) { + async fn handle_msg(&self, msg: GenMessage, wait_tx: tokio::sync::oneshot::Sender<()>) { let stream_id = msg.header.stream_id; if (stream_id % 2) != 1 { @@ -416,7 +418,7 @@ impl HandlerContext { } match msg.header.type_ { - MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await { + MESSAGE_TYPE_REQUEST => match self.handle_request(msg, wait_tx).await { Ok(opt_msg) => match opt_msg { Some(msg) => { Self::respond(self.tx.clone(), stream_id, msg) @@ -435,7 +437,7 @@ impl HandlerContext { }; self.tx - .send(msg) + .send(SendingMessage::new(msg)) .await .map_err(err_to_others_err!(e, "Send packet to sender error ")) .ok(); @@ -444,6 +446,8 @@ impl HandlerContext { Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, }, MESSAGE_TYPE_DATA => { + // no need to wait data message handling + drop(wait_tx); // TODO(wllenyj): Compatible with golang behavior. if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED && !msg.payload.is_empty() @@ -492,7 +496,11 @@ impl HandlerContext { } } - async fn handle_request(&self, msg: GenMessage) -> StdResult, Status> { + async fn handle_request( + &self, + msg: GenMessage, + wait_tx: tokio::sync::oneshot::Sender<()>, + ) -> StdResult, Status> { //TODO: //if header.stream_id <= self.last_stream_id { // return Err; @@ -513,10 +521,11 @@ impl HandlerContext { })?; if let Some(method) = srv.get_method(&req.method) { + drop(wait_tx); return self.handle_method(method, req_msg).await; } if let Some(stream) = srv.get_stream(&req.method) { - return self.handle_stream(stream, req_msg).await; + return self.handle_stream(stream, req_msg, wait_tx).await; } Err(get_status( Code::UNIMPLEMENTED, @@ -572,6 +581,7 @@ impl HandlerContext { &self, stream: Arc, req_msg: Message, + wait_tx: tokio::sync::oneshot::Sender<()>, ) -> StdResult, Status> { let stream_id = req_msg.header.stream_id; let req = req_msg.payload; @@ -583,6 +593,9 @@ impl HandlerContext { let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED; let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN; + + drop(wait_tx); + let si = StreamInner::new( stream_id, self.tx.clone(), @@ -631,7 +644,7 @@ impl HandlerContext { header: MessageHeader::new_response(stream_id, payload.len() as u32), payload, }; - tx.send(msg) + tx.send(SendingMessage::new(msg)) .await .map_err(err_to_others_err!(e, "Send packet to sender error ")) } diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 27256aa8..e75ce041 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -17,12 +17,42 @@ use crate::proto::{ MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE, }; -pub type MessageSender = mpsc::Sender; -pub type MessageReceiver = mpsc::Receiver; +pub type MessageSender = mpsc::Sender; +pub type MessageReceiver = mpsc::Receiver; pub type ResultSender = mpsc::Sender>; pub type ResultReceiver = mpsc::Receiver>; +#[derive(Debug)] +pub struct SendingMessage { + pub msg: GenMessage, + pub result_chan: Option>>, +} + +impl SendingMessage { + pub fn new(msg: GenMessage) -> Self { + Self { + msg, + result_chan: None, + } + } + pub fn new_with_result( + msg: GenMessage, + result_chan: tokio::sync::oneshot::Sender>, + ) -> Self { + Self { + msg, + result_chan: Some(result_chan), + } + } + + pub fn send_result(&mut self, result: Result<()>) { + if let Some(result_ch) = self.result_chan.take() { + result_ch.send(result).unwrap_or_default(); + } + } +} + #[derive(Debug)] pub struct ClientStream { tx: CSSender, @@ -317,9 +347,13 @@ async fn _recv(rx: &mut ResultReceiver) -> Result { } async fn _send(tx: &MessageSender, msg: GenMessage) -> Result<()> { - tx.send(msg) + let (res_tx, res_rx) = tokio::sync::oneshot::channel(); + tx.send(SendingMessage::new_with_result(msg, res_tx)) + .await + .map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e)))?; + res_rx .await - .map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e))) + .map_err(|e| Error::Others(format!("Failed to wait send result {:?}", e)))? } #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/ttrpc-codegen/Cargo.toml b/ttrpc-codegen/Cargo.toml index fa0fc2e8..8d98c9c8 100644 --- a/ttrpc-codegen/Cargo.toml +++ b/ttrpc-codegen/Cargo.toml @@ -16,4 +16,4 @@ readme = "README.md" protobuf-support = "3.1.0" protobuf = { version = "2.27.1" } protobuf-codegen = "3.1.0" -ttrpc-compiler = "0.6.1" +ttrpc-compiler = { path = "../ttrpc-compiler" }