From 5da506101f1c6d6c2dd106643b8053a2c1092aa2 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Sun, 24 Apr 2022 17:14:15 +0800 Subject: [PATCH 01/15] Add proto.rs to include protocol related. Move the MessageHeader from `common.rs` to the `proto.rs` file. Move the `compiled::ttrpc` from `common.rs` to `proto.rs` file and re-export. Signed-off-by: wllenyj --- src/asynchronous/client.rs | 5 +- src/asynchronous/server.rs | 5 +- src/asynchronous/stream.rs | 5 +- src/asynchronous/utils.rs | 5 +- src/common.rs | 73 ---------------------- src/lib.rs | 11 +--- src/proto.rs | 122 +++++++++++++++++++++++++++++++++++++ src/sync/channel.rs | 4 +- src/sync/client.rs | 7 ++- src/sync/server.rs | 5 +- src/sync/utils.rs | 3 +- 11 files changed, 144 insertions(+), 101 deletions(-) create mode 100644 src/proto.rs diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index c5668694..f232fe81 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -9,9 +9,9 @@ use std::collections::HashMap; use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; -use crate::common::{client_connect, MESSAGE_TYPE_RESPONSE}; +use crate::common::client_connect; use crate::error::{Error, Result}; -use crate::proto::{Code, Request, Response}; +use crate::proto::{Code, Request, Response, MESSAGE_TYPE_RESPONSE}; use crate::asynchronous::stream::{receive, to_req_buf}; use crate::r#async::utils; @@ -163,6 +163,7 @@ impl Client { Client { req_tx } } + /// Requsts a unary request and returns with response. pub async fn request(&self, req: Request) -> Result { let mut buf = Vec::with_capacity(req.compute_size() as usize); { diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 9afd1836..17722411 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -13,12 +13,11 @@ use std::time::Duration; use crate::asynchronous::stream::{receive, respond, respond_with_status}; use crate::asynchronous::unix_incoming::UnixIncoming; -use crate::common::{self, Domain, MESSAGE_TYPE_REQUEST}; +use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, Status}; +use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST}; use crate::r#async::{MethodHandler, TtrpcContext}; -use crate::MessageHeader; use futures::stream::Stream; use futures::StreamExt as _; use std::marker::Unpin; diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index c99ba336..2cf94cc4 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -3,9 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::common::{MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE}; use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; -use crate::proto::{Code, Response, Status}; +use crate::proto::{ + Code, Response, Status, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE, +}; use crate::r#async::utils; use crate::MessageHeader; use protobuf::Message; diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index cbd2cc57..faffdafd 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -3,9 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::common::{MessageHeader, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE}; use crate::error::{get_status, Result}; -use crate::proto::{Code, Request, Status}; +use crate::proto::{ + Code, MessageHeader, Request, Status, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, +}; use async_trait::async_trait; use 
protobuf::{CodedInputStream, Message}; use std::collections::HashMap; diff --git a/src/common.rs b/src/common.rs index b5be1833..3f39b070 100644 --- a/src/common.rs +++ b/src/common.rs @@ -6,7 +6,6 @@ //! Common functions and macros. use crate::error::{Error, Result}; -use byteorder::{BigEndian, ByteOrder}; #[cfg(any(feature = "async", not(target_os = "linux")))] use nix::fcntl::FdFlag; use nix::fcntl::{fcntl, FcntlArg, OFlag}; @@ -20,53 +19,6 @@ pub(crate) enum Domain { Vsock, } -/// Message header of ttrpc. -#[derive(Default, Debug)] -pub struct MessageHeader { - pub length: u32, - pub stream_id: u32, - pub type_: u8, - pub flags: u8, -} - -impl From for MessageHeader -where - T: AsRef<[u8]>, -{ - fn from(buf: T) -> Self { - let buf = buf.as_ref(); - debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH); - Self { - length: BigEndian::read_u32(&buf[..4]), - stream_id: BigEndian::read_u32(&buf[4..8]), - type_: buf[8], - flags: buf[9], - } - } -} - -impl From for Vec { - fn from(mh: MessageHeader) -> Self { - let mut buf = vec![0u8; MESSAGE_HEADER_LENGTH]; - mh.into_buf(&mut buf); - buf - } -} - -impl MessageHeader { - pub(crate) fn into_buf(self, mut buf: impl AsMut<[u8]>) { - let buf = buf.as_mut(); - debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH); - - let covbuf: &mut [u8] = &mut buf[..4]; - BigEndian::write_u32(covbuf, self.length); - let covbuf: &mut [u8] = &mut buf[4..8]; - BigEndian::write_u32(covbuf, self.stream_id); - buf[8] = self.type_; - buf[9] = self.flags; - } -} - pub(crate) fn do_listen(listener: RawFd) -> Result<()> { if let Err(e) = fcntl(listener, FcntlArg::F_SETFL(OFlag::O_NONBLOCK)) { return Err(Error::Others(format!( @@ -238,12 +190,6 @@ macro_rules! cfg_async { } } -pub const MESSAGE_HEADER_LENGTH: usize = 10; -pub const MESSAGE_LENGTH_MAX: usize = 4 << 20; - -pub const MESSAGE_TYPE_REQUEST: u8 = 0x1; -pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2; - #[cfg(test)] mod tests { use super::*; @@ -306,23 +252,4 @@ mod tests { } } } - - #[test] - fn message_header() { - let buf = vec![ - 0x10, 0x0, 0x0, 0x0, // length - 0x0, 0x0, 0x0, 0x03, // stream_id - 0x2, // type_ - 0xef, // flags - ]; - let mh = MessageHeader::from(&buf); - assert_eq!(mh.length, 0x1000_0000); - assert_eq!(mh.stream_id, 0x3); - assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE); - assert_eq!(mh.flags, 0xef); - - let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH]; - mh.into_buf(&mut buf2); - assert_eq!(&buf, &buf2); - } } diff --git a/src/lib.rs b/src/lib.rs index b5d2dfe6..4f913d44 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,20 +48,15 @@ extern crate log; pub mod error; #[macro_use] mod common; -#[allow(soft_unstable, clippy::type_complexity, clippy::too_many_arguments)] -mod compiled { - include!(concat!(env!("OUT_DIR"), "/mod.rs")); -} -pub use compiled::ttrpc as proto; pub mod context; +pub mod proto; #[doc(inline)] -pub use crate::common::MessageHeader; +pub use self::proto::{Code, MessageHeader, Request, Response, Status}; + #[doc(inline)] pub use crate::error::{get_status, Error, Result}; -#[doc(inline)] -pub use proto::{Code, Request, Response, Status}; cfg_sync! { pub mod sync; diff --git a/src/proto.rs b/src/proto.rs new file mode 100644 index 00000000..227ddc9a --- /dev/null +++ b/src/proto.rs @@ -0,0 +1,122 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. 
+// Copyright (c) 2020 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +#[allow(soft_unstable, clippy::type_complexity, clippy::too_many_arguments)] +mod compiled { + include!(concat!(env!("OUT_DIR"), "/mod.rs")); +} +pub use compiled::ttrpc::*; + +use byteorder::{BigEndian, ByteOrder}; + +pub const MESSAGE_HEADER_LENGTH: usize = 10; +pub const MESSAGE_LENGTH_MAX: usize = 4 << 20; + +pub const MESSAGE_TYPE_REQUEST: u8 = 0x1; +pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2; + +/// Message header of ttrpc. +#[derive(Default, Debug)] +pub struct MessageHeader { + pub length: u32, + pub stream_id: u32, + pub type_: u8, + pub flags: u8, +} + +impl From for MessageHeader +where + T: AsRef<[u8]>, +{ + fn from(buf: T) -> Self { + let buf = buf.as_ref(); + debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH); + Self { + length: BigEndian::read_u32(&buf[..4]), + stream_id: BigEndian::read_u32(&buf[4..8]), + type_: buf[8], + flags: buf[9], + } + } +} + +impl From for Vec { + fn from(mh: MessageHeader) -> Self { + let mut buf = vec![0u8; MESSAGE_HEADER_LENGTH]; + mh.into_buf(&mut buf); + buf + } +} + +impl MessageHeader { + /// Creates a request MessageHeader from stream_id and len. + /// Use the default message type MESSAGE_TYPE_REQUEST, and default flags 0. + pub fn new_request(stream_id: u32, len: u32) -> Self { + Self { + length: len, + stream_id, + type_: MESSAGE_TYPE_REQUEST, + flags: 0, + } + } + + /// Creates a response MessageHeader from stream_id and len. + /// Use the default message type MESSAGE_TYPE_REQUEST, and default flags 0. + pub fn new_response(stream_id: u32, len: u32) -> Self { + Self { + length: len, + stream_id, + type_: MESSAGE_TYPE_RESPONSE, + flags: 0, + } + } + + /// Set the flags of message using the given flags. + pub fn set_flags(&mut self, flags: u8) { + self.flags = flags; + } + + /// Add a new flags to the message. 
+ pub fn add_flags(&mut self, flags: u8) { + self.flags |= flags; + } + + pub(crate) fn into_buf(self, mut buf: impl AsMut<[u8]>) { + let buf = buf.as_mut(); + debug_assert!(buf.len() >= MESSAGE_HEADER_LENGTH); + + let covbuf: &mut [u8] = &mut buf[..4]; + BigEndian::write_u32(covbuf, self.length); + let covbuf: &mut [u8] = &mut buf[4..8]; + BigEndian::write_u32(covbuf, self.stream_id); + buf[8] = self.type_; + buf[9] = self.flags; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn message_header() { + let buf = vec![ + 0x10, 0x0, 0x0, 0x0, // length + 0x0, 0x0, 0x0, 0x03, // stream_id + 0x2, // type_ + 0xef, // flags + ]; + let mh = MessageHeader::from(&buf); + assert_eq!(mh.length, 0x1000_0000); + assert_eq!(mh.stream_id, 0x3); + assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE); + assert_eq!(mh.flags, 0xef); + + let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH]; + mh.into_buf(&mut buf2); + assert_eq!(&buf, &buf2); + } +} diff --git a/src/sync/channel.rs b/src/sync/channel.rs index 07e42e81..4c49dfc5 100644 --- a/src/sync/channel.rs +++ b/src/sync/channel.rs @@ -15,10 +15,8 @@ use nix::sys::socket::*; use std::os::unix::io::RawFd; -use crate::common::{MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; -use crate::proto::Code; -use crate::MessageHeader; +use crate::proto::{Code, MessageHeader, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX}; fn retryable(e: nix::Error) -> bool { use ::nix::Error; diff --git a/src/sync/client.rs b/src/sync/client.rs index 1d2587b3..aea0ca2b 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -25,11 +25,12 @@ use std::{io, thread}; #[cfg(target_os = "macos")] use crate::common::set_fd_close_exec; -use crate::common::{client_connect, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, SOCK_CLOEXEC}; +use crate::common::{client_connect, SOCK_CLOEXEC}; use crate::error::{Error, Result}; -use crate::proto::{Code, Request, Response}; +use crate::proto::{ + Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, +}; use crate::sync::channel::{read_message, write_message}; -use crate::MessageHeader; use std::time::Duration; type Sender = mpsc::Sender<(Vec, mpsc::SyncSender>>)>; diff --git a/src/sync/server.rs b/src/sync/server.rs index 282c072e..9bee608c 100644 --- a/src/sync/server.rs +++ b/src/sync/server.rs @@ -26,14 +26,13 @@ use std::thread::JoinHandle; use std::{io, thread}; use super::utils::response_to_channel; +use crate::common; #[cfg(not(target_os = "linux"))] use crate::common::set_fd_close_exec; -use crate::common::{self, MESSAGE_TYPE_REQUEST}; use crate::context; use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, Request, Response}; +use crate::proto::{Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST}; use crate::sync::channel::{read_message, write_message}; -use crate::MessageHeader; use crate::{MethodHandler, TtrpcContext}; // poll_queue will create WAIT_THREAD_COUNT_DEFAULT threads in begin. 
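As a reference for reviewers, a minimal sketch of the 10-byte wire layout that proto.rs now owns: length and stream_id are big-endian u32 values, followed by the type and flags bytes. This mirrors the MessageHeader code moved in this patch, assumes only the byteorder crate, and uses made-up field values for illustration; it is not part of the patch itself.

use byteorder::{BigEndian, ByteOrder};

// 10-byte ttrpc header: length (u32 BE) | stream_id (u32 BE) | type_ (u8) | flags (u8)
fn encode_header(length: u32, stream_id: u32, type_: u8, flags: u8) -> [u8; 10] {
    let mut buf = [0u8; 10];
    BigEndian::write_u32(&mut buf[..4], length);
    BigEndian::write_u32(&mut buf[4..8], stream_id);
    buf[8] = type_;
    buf[9] = flags;
    buf
}

fn main() {
    // A request header for a 67-byte payload on stream 3 (MESSAGE_TYPE_REQUEST = 0x1).
    let buf = encode_header(67, 3, 0x1, 0);
    assert_eq!(&buf[..4], &[0x00, 0x00, 0x00, 0x43]); // length, big-endian
    assert_eq!(&buf[4..8], &[0x00, 0x00, 0x00, 0x03]); // stream_id, big-endian
    println!("{:?}", buf);
}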
diff --git a/src/sync/utils.rs b/src/sync/utils.rs index 2155f180..b607b98b 100644 --- a/src/sync/utils.rs +++ b/src/sync/utils.rs @@ -3,9 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::common::{MessageHeader, MESSAGE_TYPE_RESPONSE}; use crate::error::{Error, Result}; -use crate::proto::{Request, Response}; +use crate::proto::{MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE}; use protobuf::Message; use std::collections::HashMap; From 912667430e60ceba0bbd1a68cb663ba0dd918e31 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Thu, 28 Apr 2022 12:07:37 +0800 Subject: [PATCH 02/15] Improve Message codec. Introduce Codec trait for protobuf message encoding/decoding. Signed-off-by: wllenyj --- src/asynchronous/client.rs | 60 +++++----- src/asynchronous/stream.rs | 8 -- src/asynchronous/utils.rs | 11 +- src/proto.rs | 236 +++++++++++++++++++++++++++++++++---- src/sync/client.rs | 24 +--- 5 files changed, 248 insertions(+), 91 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index f232fe81..fe6fb9d1 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -4,26 +4,25 @@ // use nix::unistd::close; -use protobuf::{CodedInputStream, CodedOutputStream, Message}; use std::collections::HashMap; +use std::convert::TryInto; use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; use crate::common::client_connect; use crate::error::{Error, Result}; -use crate::proto::{Code, Request, Response, MESSAGE_TYPE_RESPONSE}; +use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; -use crate::asynchronous::stream::{receive, to_req_buf}; use crate::r#async::utils; use tokio::{ self, - io::{split, AsyncWriteExt}, + io::split, sync::mpsc::{channel, Receiver, Sender}, sync::Notify, }; -type RequestSender = Sender<(Vec, Sender>>)>; -type RequestReceiver = Receiver<(Vec, Sender>>)>; +type RequestSender = Sender<(GenMessage, Sender>>)>; +type RequestReceiver = Receiver<(GenMessage, Sender>>)>; type ResponseSender = Sender>>; type ResponseReceiver = Receiver>>; @@ -57,8 +56,9 @@ impl Client { let request_sender = tokio::spawn(async move { let mut stream_id: u32 = 1; - while let Some((body, resp_tx)) = rx.recv().await { + while let Some((mut msg, resp_tx)) = rx.recv().await { let current_stream_id = stream_id; + msg.header.set_stream_id(current_stream_id); stream_id += 2; { @@ -66,8 +66,7 @@ impl Client { map.insert(current_stream_id, resp_tx.clone()); } - let buf = to_req_buf(current_stream_id, body); - if let Err(e) = writer.write_all(&buf).await { + if let Err(e) = msg.write_to(&mut writer).await { error!("write_message got error: {:?}", e); { @@ -97,41 +96,42 @@ impl Client { _ = notify2.notified() => { break; } - res = receive(&mut reader) => { + res = GenMessage::read_from(&mut reader) => { match res { - Ok((header, body)) => { + Ok(msg) => { + trace!("Got Message body {:?}", msg.payload); let req_map = req_map.clone(); tokio::spawn(async move { let resp_tx2; { let mut map = req_map.lock().unwrap(); - let resp_tx = match map.get(&header.stream_id) { + let resp_tx = match map.get(&msg.header.stream_id) { Some(tx) => tx, None => { debug!( - "Receiver got unknown packet {:?} {:?}", - header, body + "Receiver got unknown packet {:?}", + msg ); return; } }; resp_tx2 = resp_tx.clone(); - map.remove(&header.stream_id); // Forget the result, just remove. + map.remove(&msg.header.stream_id); // Forget the result, just remove. 
} - if header.type_ != MESSAGE_TYPE_RESPONSE { + if msg.header.type_ != MESSAGE_TYPE_RESPONSE { resp_tx2 .send(Err(Error::Others(format!( - "Recver got malformed packet {:?} {:?}", - header, body + "Recver got malformed packet {:?}", + msg )))) .await .unwrap_or_else(|_e| error!("The request has returned")); return; } - resp_tx2.send(Ok(body)).await.unwrap_or_else(|_e| error!("The request has returned")); + resp_tx2.send(Ok(msg.payload)).await.unwrap_or_else(|_e| error!("The request has returned")); }); } Err(e) => { @@ -165,26 +165,24 @@ impl Client { /// Requsts a unary request and returns with response. pub async fn request(&self, req: Request) -> Result { - let mut buf = Vec::with_capacity(req.compute_size() as usize); - { - let mut s = CodedOutputStream::vec(&mut buf); - req.write_to(&mut s).map_err(err_to_others_err!(e, ""))?; - s.flush().map_err(err_to_others_err!(e, ""))?; - } + let timeout_nano = req.timeout_nano; + let msg: GenMessage = Message::new_request(0, req) + .try_into() + .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; let (tx, mut rx): (ResponseSender, ResponseReceiver) = channel(100); self.req_tx - .send((buf, tx)) + .send((msg, tx)) .await .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; - let result = if req.timeout_nano == 0 { + let result = if timeout_nano == 0 { rx.recv() .await .ok_or_else(|| Error::Others("Receive packet from receiver error".to_string()))? } else { tokio::time::timeout( - std::time::Duration::from_nanos(req.timeout_nano as u64), + std::time::Duration::from_nanos(timeout_nano as u64), rx.recv(), ) .await @@ -193,10 +191,8 @@ impl Client { }; let buf = result?; - let mut s = CodedInputStream::from_bytes(&buf); - let mut res = Response::new(); - res.merge_from(&mut s) - .map_err(err_to_others_err!(e, "Unpack response error "))?; + let res = + Response::decode(&buf).map_err(err_to_others_err!(e, "Unpack response error "))?; let status = res.get_status(); if status.get_code() != Code::OK { diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 2cf94cc4..f3d12abc 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -76,14 +76,6 @@ fn header_to_buf(mh: MessageHeader) -> Vec { mh.into() } -pub(crate) fn to_req_buf(stream_id: u32, mut body: Vec) -> Vec { - let header = utils::get_request_header_from_body(stream_id, &body); - let mut buf = header_to_buf(header); - buf.append(&mut body); - - buf -} - pub(crate) fn to_res_buf(stream_id: u32, mut body: Vec) -> Vec { let header = utils::get_response_header_from_body(stream_id, &body); let mut buf = header_to_buf(header); diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index faffdafd..91d47bcc 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -5,7 +5,7 @@ use crate::error::{get_status, Result}; use crate::proto::{ - Code, MessageHeader, Request, Status, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, + Code, MessageHeader, Request, Status, MESSAGE_TYPE_RESPONSE, }; use async_trait::async_trait; use protobuf::{CodedInputStream, Message}; @@ -109,15 +109,6 @@ pub(crate) fn get_response_header_from_body(stream_id: u32, body: &[u8]) -> Mess } } -pub(crate) fn get_request_header_from_body(stream_id: u32, body: &[u8]) -> MessageHeader { - MessageHeader { - length: body.len() as u32, - stream_id, - type_: MESSAGE_TYPE_REQUEST, - flags: 0, - } -} - pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream { let std_stream: std::os::unix::net::UnixStream; unsafe { diff 
--git a/src/proto.rs b/src/proto.rs index 227ddc9a..e8e78b30 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -11,6 +11,10 @@ mod compiled { pub use compiled::ttrpc::*; use byteorder::{BigEndian, ByteOrder}; +use protobuf::{CodedInputStream, CodedOutputStream}; + +#[cfg(feature = "async")] +use crate::error::{get_rpc_status, Error, Result as TtResult}; pub const MESSAGE_HEADER_LENGTH: usize = 10; pub const MESSAGE_LENGTH_MAX: usize = 4 << 20; @@ -19,7 +23,7 @@ pub const MESSAGE_TYPE_REQUEST: u8 = 0x1; pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2; /// Message header of ttrpc. -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone, Copy, PartialEq)] pub struct MessageHeader { pub length: u32, pub stream_id: u32, @@ -64,7 +68,7 @@ impl MessageHeader { } /// Creates a response MessageHeader from stream_id and len. - /// Use the default message type MESSAGE_TYPE_REQUEST, and default flags 0. + /// Use the default message type MESSAGE_TYPE_RESPONSE, and default flags 0. pub fn new_response(stream_id: u32, len: u32) -> Self { Self { length: len, @@ -74,6 +78,11 @@ impl MessageHeader { } } + /// Set the stream_id of message using the given value. + pub fn set_stream_id(&mut self, stream_id: u32) { + self.stream_id = stream_id; + } + /// Set the flags of message using the given flags. pub fn set_flags(&mut self, flags: u8) { self.flags = flags; @@ -97,26 +106,207 @@ impl MessageHeader { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn message_header() { - let buf = vec![ - 0x10, 0x0, 0x0, 0x0, // length - 0x0, 0x0, 0x0, 0x03, // stream_id - 0x2, // type_ - 0xef, // flags - ]; - let mh = MessageHeader::from(&buf); - assert_eq!(mh.length, 0x1000_0000); - assert_eq!(mh.stream_id, 0x3); - assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE); - assert_eq!(mh.flags, 0xef); - - let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH]; - mh.into_buf(&mut buf2); - assert_eq!(&buf, &buf2); +#[cfg(feature = "async")] +impl MessageHeader { + /// Encodes a MessageHeader to writer. + pub async fn write_to( + &self, + mut writer: impl tokio::io::AsyncWriteExt + Unpin, + ) -> std::io::Result<()> { + writer.write_u32(self.length).await?; + writer.write_u32(self.stream_id).await?; + writer.write_u8(self.type_).await?; + writer.write_u8(self.flags).await?; + writer.flush().await + } + + /// Decodes a MessageHeader from reader. + pub async fn read_from( + mut reader: impl tokio::io::AsyncReadExt + Unpin, + ) -> std::io::Result { + let mut content = vec![0; MESSAGE_HEADER_LENGTH]; + reader.read_exact(&mut content).await?; + Ok(MessageHeader::from(&content)) + } +} + +/// Generic message of ttrpc. +#[derive(Default, Debug, Clone, PartialEq)] +pub struct GenMessage { + pub header: MessageHeader, + pub payload: Vec, +} + +#[cfg(feature = "async")] +impl GenMessage { + /// Encodes a MessageHeader to writer. + pub async fn write_to( + &self, + mut writer: impl tokio::io::AsyncWriteExt + Unpin, + ) -> TtResult<()> { + self.header + .write_to(&mut writer) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + writer + .write_all(&self.payload) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + Ok(()) + } + + /// Decodes a MessageHeader from reader. 
+ pub async fn read_from(mut reader: impl tokio::io::AsyncReadExt + Unpin) -> TtResult { + let header = MessageHeader::read_from(&mut reader) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + + if header.length > MESSAGE_LENGTH_MAX as u32 { + return Err(get_rpc_status( + Code::INVALID_ARGUMENT, + format!( + "message length {} exceed maximum message size of {}", + header.length, MESSAGE_LENGTH_MAX + ), + )); + } + + let mut content = vec![0; header.length as usize]; + reader + .read_exact(&mut content) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + + Ok(Self { + header, + payload: content, + }) + } +} + +/// TTRPC codec, only protobuf is supported. +pub trait Codec { + type E; + + fn size(&self) -> u32; + fn encode(&self) -> Result, Self::E>; + fn decode(buf: impl AsRef<[u8]>) -> Result + where + Self: Sized; +} + +impl Codec for M { + type E = protobuf::error::ProtobufError; + + fn size(&self) -> u32 { + self.compute_size() + } + + fn encode(&self) -> Result, Self::E> { + let mut buf = vec![0; self.compute_size() as usize]; + let mut s = CodedOutputStream::bytes(&mut buf); + self.write_to(&mut s)?; + s.flush()?; + Ok(buf) + } + + fn decode(buf: impl AsRef<[u8]>) -> Result { + let mut s = CodedInputStream::from_bytes(buf.as_ref()); + M::parse_from(&mut s) + } +} + +/// Message of ttrpc. +#[derive(Default, Debug, Clone, PartialEq)] +pub struct Message { + pub header: MessageHeader, + pub payload: C, +} + +impl std::convert::TryFrom for Message +where + C: Codec, +{ + type Error = C::E; + fn try_from(gen: GenMessage) -> Result { + Ok(Self { + header: gen.header, + payload: C::decode(&gen.payload)?, + }) + } +} + +impl std::convert::TryFrom> for GenMessage +where + C: Codec, +{ + type Error = C::E; + fn try_from(msg: Message) -> Result { + Ok(Self { + header: msg.header, + payload: msg.payload.encode()?, + }) + } +} + +impl Message { + pub fn new_request(stream_id: u32, message: C) -> Self { + Self { + header: MessageHeader::new_request(stream_id, message.size()), + payload: message, + } + } +} + +#[cfg(feature = "async")] +impl Message +where + C: Codec, + C::E: std::fmt::Display, +{ + /// Encodes a MessageHeader to writer. + pub async fn write_to( + &self, + mut writer: impl tokio::io::AsyncWriteExt + Unpin, + ) -> TtResult<()> { + self.header + .write_to(&mut writer) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + let content = self + .payload + .encode() + .map_err(err_to_others_err!(e, "Encode payload failed."))?; + writer + .write_all(&content) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + Ok(()) + } + + /// Decodes a MessageHeader from reader. 
+ pub async fn read_from(mut reader: impl tokio::io::AsyncReadExt + Unpin) -> TtResult { + let header = MessageHeader::read_from(&mut reader) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + + if header.length > MESSAGE_LENGTH_MAX as u32 { + return Err(get_rpc_status( + Code::INVALID_ARGUMENT, + format!( + "message length {} exceed maximum message size of {}", + header.length, MESSAGE_LENGTH_MAX + ), + )); + } + + let mut content = vec![0; header.length as usize]; + reader + .read_exact(&mut content) + .await + .map_err(|e| Error::Socket(e.to_string()))?; + let payload = + C::decode(content).map_err(err_to_others_err!(e, "Decode payload failed."))?; + Ok(Self { header, payload }) } } diff --git a/src/sync/client.rs b/src/sync/client.rs index aea0ca2b..ece9ee0b 100644 --- a/src/sync/client.rs +++ b/src/sync/client.rs @@ -16,7 +16,6 @@ use nix::sys::socket::*; use nix::unistd::close; -use protobuf::{CodedInputStream, CodedOutputStream, Message}; use std::collections::HashMap; use std::os::unix::io::RawFd; use std::sync::mpsc; @@ -27,9 +26,7 @@ use std::{io, thread}; use crate::common::set_fd_close_exec; use crate::common::{client_connect, SOCK_CLOEXEC}; use crate::error::{Error, Result}; -use crate::proto::{ - Code, MessageHeader, Request, Response, MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE, -}; +use crate::proto::{Code, Codec, MessageHeader, Request, Response, MESSAGE_TYPE_RESPONSE}; use crate::sync::channel::{read_message, write_message}; use std::time::Duration; @@ -81,12 +78,8 @@ impl Client { let mut map = recver_map.lock().unwrap(); map.insert(current_stream_id, recver_tx.clone()); } - let mh = MessageHeader { - length: buf.len() as u32, - stream_id: current_stream_id, - type_: MESSAGE_TYPE_REQUEST, - flags: 0, - }; + let mut mh = MessageHeader::new_request(0, buf.len() as u32); + mh.set_stream_id(current_stream_id); if let Err(e) = write_message(fd, mh, buf) { //Remove current_stream_id and recver_tx to recver_map { @@ -215,10 +208,7 @@ impl Client { } } pub fn request(&self, req: Request) -> Result { - let mut buf = Vec::with_capacity(req.compute_size() as usize); - let mut s = CodedOutputStream::vec(&mut buf); - req.write_to(&mut s).map_err(err_to_others_err!(e, ""))?; - s.flush().map_err(err_to_others_err!(e, ""))?; + let buf = req.encode().map_err(err_to_others_err!(e, ""))?; let (tx, rx) = mpsc::sync_channel(0); @@ -238,10 +228,8 @@ impl Client { }; let buf = result?; - let mut s = CodedInputStream::from_bytes(&buf); - let mut res = Response::new(); - res.merge_from(&mut s) - .map_err(err_to_others_err!(e, "Unpack response error "))?; + let res = + Response::decode(&buf).map_err(err_to_others_err!(e, "Unpack response error "))?; let status = res.get_status(); if status.get_code() != Code::OK { From f4b3d90bb1d47e4b09bd35445d57dde6fff89986 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 15:18:00 +0800 Subject: [PATCH 03/15] Add more unit test for proto.rs Added more unit test for proto.rs. 
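The round trip these tests exercise looks roughly like the sketch below, using the Codec trait and the Message/GenMessage conversions from src/proto.rs; the service and method names are placeholders, not values used by the test suite.

use std::convert::TryInto;
use ttrpc::proto::{Codec, GenMessage, Message, Request};

fn roundtrip() -> Result<(), protobuf::error::ProtobufError> {
    let mut req = Request::new();
    req.set_service("example.Service".to_string()); // placeholder service name
    req.set_method("Ping".to_string());             // placeholder method name

    // Codec::encode/decode wrap protobuf serialization of the payload.
    let bytes = req.encode()?;
    let decoded = Request::decode(&bytes)?;
    assert_eq!(req, decoded);

    // Message<Request> carries a typed payload; GenMessage carries the raw bytes
    // plus a header whose length matches the encoded payload.
    let msg = Message::new_request(1, req);
    let gen: GenMessage = msg.try_into()?;
    assert_eq!(gen.header.length as usize, gen.payload.len());
    Ok(())
}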
Signed-off-by: wllenyj --- src/proto.rs | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/src/proto.rs b/src/proto.rs index e8e78b30..be8c4add 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -310,3 +310,170 @@ where Ok(Self { header, payload }) } } + +#[cfg(test)] +mod tests { + use std::convert::{TryFrom, TryInto}; + + use super::*; + + static MESSAGE_HEADER: [u8; MESSAGE_HEADER_LENGTH] = [ + 0x10, 0x0, 0x0, 0x0, // length + 0x0, 0x0, 0x0, 0x03, // stream_id + 0x2, // type_ + 0xef, // flags + ]; + + #[test] + fn message_header() { + let mh = MessageHeader::from(&MESSAGE_HEADER); + assert_eq!(mh.length, 0x1000_0000); + assert_eq!(mh.stream_id, 0x3); + assert_eq!(mh.type_, MESSAGE_TYPE_RESPONSE); + assert_eq!(mh.flags, 0xef); + + let mut buf2 = vec![0; MESSAGE_HEADER_LENGTH]; + mh.into_buf(&mut buf2); + assert_eq!(&MESSAGE_HEADER, &buf2[..]); + + let mh = MessageHeader::from(&PROTOBUF_MESSAGE_HEADER); + assert_eq!(mh.length as usize, TEST_PAYLOAD_LEN); + } + + #[rustfmt::skip] + static PROTOBUF_MESSAGE_HEADER: [u8; MESSAGE_HEADER_LENGTH] = [ + 0x00, 0x0, 0x0, TEST_PAYLOAD_LEN as u8, // length + 0x0, 0x12, 0x34, 0x56, // stream_id + 0x1, // type_ + 0xef, // flags + ]; + + const TEST_PAYLOAD_LEN: usize = 67; + static PROTOBUF_REQUEST: [u8; TEST_PAYLOAD_LEN] = [ + 10, 17, 103, 114, 112, 99, 46, 84, 101, 115, 116, 83, 101, 114, 118, 105, 99, 101, 115, 18, + 4, 84, 101, 115, 116, 26, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 32, 128, 218, 196, 9, 42, 24, 10, + 9, 116, 101, 115, 116, 95, 107, 101, 121, 49, 18, 11, 116, 101, 115, 116, 95, 118, 97, 108, + 117, 101, 49, + ]; + + fn new_protobuf_request() -> Request { + let mut creq = Request::new(); + creq.set_service("grpc.TestServices".to_string()); + creq.set_method("Test".to_string()); + creq.set_timeout_nano(20 * 1000 * 1000); + let mut meta: protobuf::RepeatedField = protobuf::RepeatedField::default(); + meta.push(KeyValue { + key: "test_key1".to_string(), + value: "test_value1".to_string(), + ..Default::default() + }); + creq.set_metadata(meta); + creq.payload = vec![0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9]; + creq + } + + #[test] + fn protobuf_codec() { + let creq = new_protobuf_request(); + let buf = creq.encode().unwrap(); + assert_eq!(&buf, &PROTOBUF_REQUEST); + let dreq = Request::decode(&buf).unwrap(); + assert_eq!(creq, dreq); + let dreq2 = Request::decode(&PROTOBUF_REQUEST).unwrap(); + assert_eq!(creq, dreq2); + } + + #[test] + fn gen_message_to_message() { + let req = new_protobuf_request(); + let msg = Message::new_request(3, req); + let msg_clone = msg.clone(); + let gen: GenMessage = msg.try_into().unwrap(); + let dmsg = Message::::try_from(gen).unwrap(); + assert_eq!(msg_clone, dmsg); + } + + #[cfg(feature = "async")] + #[tokio::test] + async fn async_message_header() { + use std::io::Cursor; + let mut buf = vec![]; + let mut io = Cursor::new(&mut buf); + let mh = MessageHeader::from(&MESSAGE_HEADER); + mh.write_to(&mut io).await.unwrap(); + assert_eq!(buf, &MESSAGE_HEADER); + + let dmh = MessageHeader::read_from(&buf[..]).await.unwrap(); + assert_eq!(mh, dmh); + } + + #[cfg(feature = "async")] + #[tokio::test] + async fn async_gen_message() { + let mut buf = Vec::from(MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + let res = GenMessage::read_from(&*buf).await; + // exceed maximum message size + assert!(matches!(res, Err(Error::RpcStatus(_)))); + + let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + 
buf.extend_from_slice(&[0x0, 0x0]); + let gen = GenMessage::read_from(&*buf).await.unwrap(); + assert_eq!(gen.header.length as usize, TEST_PAYLOAD_LEN); + assert_eq!(gen.header.length, gen.payload.len() as u32); + assert_eq!(gen.header.stream_id, 0x123456); + assert_eq!(gen.header.type_, MESSAGE_TYPE_REQUEST); + assert_eq!(gen.header.flags, 0xef); + assert_eq!(&gen.payload, &PROTOBUF_REQUEST); + assert_eq!( + &buf[MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN..], + &[0x0, 0x0] + ); + + let mut dbuf = vec![]; + let mut io = std::io::Cursor::new(&mut dbuf); + gen.write_to(&mut io).await.unwrap(); + assert_eq!(&*dbuf, &buf[..MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN]); + } + + #[cfg(feature = "async")] + #[tokio::test] + async fn async_message() { + let mut buf = Vec::from(MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + let res = Message::::read_from(&*buf).await; + // exceed maximum message size + assert!(matches!(res, Err(Error::RpcStatus(_)))); + + let mut buf = Vec::from(PROTOBUF_MESSAGE_HEADER); + buf.extend_from_slice(&PROTOBUF_REQUEST); + buf.extend_from_slice(&[0x0, 0x0]); + let msg = Message::::read_from(&*buf).await.unwrap(); + assert_eq!(msg.header.length, 67); + assert_eq!(msg.header.length, msg.payload.size() as u32); + assert_eq!(msg.header.stream_id, 0x123456); + assert_eq!(msg.header.type_, MESSAGE_TYPE_REQUEST); + assert_eq!(msg.header.flags, 0xef); + assert_eq!(&msg.payload.service, "grpc.TestServices"); + assert_eq!(&msg.payload.method, "Test"); + assert_eq!( + msg.payload.payload, + vec![0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9] + ); + assert_eq!(msg.payload.timeout_nano, 20 * 1000 * 1000); + assert_eq!(msg.payload.metadata.len(), 1); + assert_eq!(&msg.payload.metadata[0].key, "test_key1"); + assert_eq!(&msg.payload.metadata[0].value, "test_value1"); + + let req = new_protobuf_request(); + let mut dmsg = Message::new_request(u32::MAX, req); + dmsg.header.set_stream_id(0x123456); + dmsg.header.set_flags(0xe0); + dmsg.header.add_flags(0x0f); + let mut dbuf = vec![]; + let mut io = std::io::Cursor::new(&mut dbuf); + dmsg.write_to(&mut io).await.unwrap(); + assert_eq!(&dbuf, &buf[..MESSAGE_HEADER_LENGTH + TEST_PAYLOAD_LEN]); + } +} From 041bf30c0a43cfb8527d153fd1927faaf1aaf8d5 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 16:03:56 +0800 Subject: [PATCH 04/15] async: add shutdown module. It is used for server-side graceful shutdown. Signed-off-by: wanglei01 --- src/asynchronous/mod.rs | 1 + src/asynchronous/shutdown.rs | 311 +++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+) create mode 100644 src/asynchronous/shutdown.rs diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index f63aae86..371aaa21 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -12,6 +12,7 @@ mod stream; #[doc(hidden)] mod utils; mod unix_incoming; +pub mod shutdown; #[doc(inline)] pub use crate::r#async::client::Client; diff --git a/src/asynchronous/shutdown.rs b/src/asynchronous/shutdown.rs new file mode 100644 index 00000000..9d1bb136 --- /dev/null +++ b/src/asynchronous/shutdown.rs @@ -0,0 +1,311 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. 
+// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; + +use tokio::sync::Notify; +use tokio::time::{error::Elapsed, timeout, Duration}; + +#[derive(Debug)] +struct Shared { + shutdown: AtomicBool, + notify_shutdown: Notify, + + waiters: AtomicUsize, + notify_exit: Notify, +} + +impl Shared { + fn is_shutdown(&self) -> bool { + self.shutdown.load(Ordering::Relaxed) + } +} + +/// Wait for the shutdown notification. +#[derive(Debug)] +pub struct Waiter { + shared: Arc, +} + +/// Used to Notify all [`Waiter`s](Waiter) shutdown. +/// +/// No `Clone` is provided. If you want multiple instances, you can use Arc. +/// Notifier will automatically call shutdown when dropping. +#[derive(Debug)] +pub struct Notifier { + shared: Arc, + wait_time: Option, +} + +/// Create a new shutdown pair([`Notifier`], [`Waiter`]) without timeout. +/// +/// The [`Notifier`] +pub fn new() -> (Notifier, Waiter) { + _with_timeout(None) +} + +/// Create a new shutdown pair with the specified [`Duration`]. +/// +/// The [`Duration`] is used to specify the timeout of the [`Notifier::wait_all_exit()`]. +/// +/// [`Duration`]: tokio::time::Duration +pub fn with_timeout(wait_time: Duration) -> (Notifier, Waiter) { + _with_timeout(Some(wait_time)) +} + +fn _with_timeout(wait_time: Option) -> (Notifier, Waiter) { + let shared = Arc::new(Shared { + shutdown: AtomicBool::new(false), + waiters: AtomicUsize::new(1), + notify_shutdown: Notify::new(), + notify_exit: Notify::new(), + }); + + let notifier = Notifier { + shared: shared.clone(), + wait_time, + }; + + let waiter = Waiter { shared }; + + (notifier, waiter) +} + +impl Waiter { + /// Return `true` if the [`Notifier::shutdown()`] has been called. + /// + /// [`Notifier::shutdown()`]: Notifier::shutdown() + pub fn is_shutdown(&self) -> bool { + self.shared.is_shutdown() + } + + /// Waiting for the [`Notifier::shutdown()`] to be called. + pub async fn wait_shutdown(&self) { + while !self.is_shutdown() { + let shutdown = self.shared.notify_shutdown.notified(); + if self.is_shutdown() { + return; + } + shutdown.await; + } + } + + fn from_shared(shared: Arc) -> Self { + shared.waiters.fetch_add(1, Ordering::Relaxed); + Self { shared } + } +} + +impl Clone for Waiter { + fn clone(&self) -> Self { + Self::from_shared(self.shared.clone()) + } +} + +impl Drop for Waiter { + fn drop(&mut self) { + if 1 == self.shared.waiters.fetch_sub(1, Ordering::Relaxed) { + self.shared.notify_exit.notify_waiters(); + } + } +} + +impl Notifier { + /// Return `true` if the [`Notifier::shutdown()`] has been called. + /// + /// [`Notifier::shutdown()`]: Notifier::shutdown() + pub fn is_shutdown(&self) -> bool { + self.shared.is_shutdown() + } + + /// Notify all [`Waiter`s](Waiter) shutdown. + /// + /// It will cause all calls blocking at `Waiter::wait_shutdown().await` to return. + pub fn shutdown(&self) { + let is_shutdown = self.shared.shutdown.swap(true, Ordering::Relaxed); + if !is_shutdown { + self.shared.notify_shutdown.notify_waiters(); + } + } + + /// Return the num of all [`Waiter`]s. + pub fn waiters(&self) -> usize { + self.shared.waiters.load(Ordering::Relaxed) + } + + /// Create a new [`Waiter`]. + pub fn subscribe(&self) -> Waiter { + Waiter::from_shared(self.shared.clone()) + } + + /// Wait for all [`Waiter`]s to drop. 
+ pub async fn wait_all_exit(&self) -> Result<(), Elapsed> { + //debug_assert!(self.shared.is_shutdown()); + if self.waiters() == 0 { + return Ok(()); + } + let wait = self.wait(); + if self.waiters() == 0 { + return Ok(()); + } + wait.await + } + + async fn wait(&self) -> Result<(), Elapsed> { + if let Some(tm) = self.wait_time { + timeout(tm, self.shared.notify_exit.notified()).await + } else { + self.shared.notify_exit.notified().await; + Ok(()) + } + } +} + +impl Drop for Notifier { + fn drop(&mut self) { + self.shutdown() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn it_work() { + let (notifier, waiter) = new(); + + let task = tokio::spawn(async move { + waiter.wait_shutdown().await; + }); + + assert_eq!(notifier.waiters(), 1); + notifier.shutdown(); + task.await.unwrap(); + assert_eq!(notifier.waiters(), 0); + } + + #[tokio::test] + async fn notifier_drop() { + let (notifier, waiter) = new(); + assert_eq!(notifier.waiters(), 1); + assert!(!waiter.is_shutdown()); + drop(notifier); + assert!(waiter.is_shutdown()); + assert_eq!(waiter.shared.waiters.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn waiter_clone() { + let (notifier, waiter1) = new(); + assert_eq!(notifier.waiters(), 1); + + let waiter2 = waiter1.clone(); + assert_eq!(notifier.waiters(), 2); + + let waiter3 = notifier.subscribe(); + assert_eq!(notifier.waiters(), 3); + + drop(waiter2); + assert_eq!(notifier.waiters(), 2); + + let task = tokio::spawn(async move { + waiter3.wait_shutdown().await; + assert!(waiter3.is_shutdown()); + }); + + assert!(!waiter1.is_shutdown()); + notifier.shutdown(); + assert!(waiter1.is_shutdown()); + + task.await.unwrap(); + + assert_eq!(notifier.waiters(), 1); + } + + #[tokio::test] + async fn concurrency_notifier_shutdown() { + let (notifier, waiter) = new(); + let arc_notifier = Arc::new(notifier); + let notifier1 = arc_notifier.clone(); + let notifier2 = notifier1.clone(); + + let task1 = tokio::spawn(async move { + assert_eq!(notifier1.waiters(), 1); + + let waiter = notifier1.subscribe(); + assert_eq!(notifier1.waiters(), 2); + + notifier1.shutdown(); + waiter.wait_shutdown().await; + }); + + let task2 = tokio::spawn(async move { + assert_eq!(notifier2.waiters(), 1); + notifier2.shutdown(); + }); + waiter.wait_shutdown().await; + assert!(arc_notifier.is_shutdown()); + task1.await.unwrap(); + task2.await.unwrap(); + } + + #[tokio::test] + async fn concurrency_notifier_wait() { + let (notifier, waiter) = new(); + let arc_notifier = Arc::new(notifier); + let notifier1 = arc_notifier.clone(); + let notifier2 = notifier1.clone(); + + let task1 = tokio::spawn(async move { + notifier1.shutdown(); + notifier1.wait_all_exit().await.unwrap(); + }); + + let task2 = tokio::spawn(async move { + notifier2.shutdown(); + notifier2.wait_all_exit().await.unwrap(); + }); + + waiter.wait_shutdown().await; + drop(waiter); + task1.await.unwrap(); + task2.await.unwrap(); + } + + #[tokio::test] + async fn wait_all_exit() { + let (notifier, waiter) = new(); + let mut tasks = Vec::with_capacity(100); + for i in 0..100 { + assert_eq!(notifier.waiters(), 1 + i); + let waiter1 = waiter.clone(); + tasks.push(tokio::spawn(async move { + waiter1.wait_shutdown().await; + })); + } + drop(waiter); + assert_eq!(notifier.waiters(), 100); + notifier.shutdown(); + notifier.wait_all_exit().await.unwrap(); + for t in tasks { + t.await.unwrap(); + } + } + + #[tokio::test] + async fn wait_timeout() { + let (notifier, waiter) = with_timeout(Duration::from_millis(100)); + let task = 
tokio::spawn(async move { + waiter.wait_shutdown().await; + tokio::time::sleep(Duration::from_millis(200)).await; + }); + notifier.shutdown(); + // Elapsed + assert!(matches!(notifier.wait_all_exit().await, Err(_))); + task.await.unwrap(); + } +} From 21b8e1accb5fc410c241dc2133e5a592abeaf282 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 16:18:15 +0800 Subject: [PATCH 05/15] async: Relayout use crate. The first block uses the standard library. The second block uses the third party crates, and the third block uses this crate. Signed-off-by: wllenyj --- src/asynchronous/client.rs | 12 ++++++------ src/asynchronous/server.rs | 26 +++++++++++++------------- src/asynchronous/stream.rs | 5 +++-- src/asynchronous/utils.rs | 12 ++++++------ 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index fe6fb9d1..abf61c92 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -3,17 +3,12 @@ // SPDX-License-Identifier: Apache-2.0 // -use nix::unistd::close; use std::collections::HashMap; use std::convert::TryInto; use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; -use crate::common::client_connect; -use crate::error::{Error, Result}; -use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; - -use crate::r#async::utils; +use nix::unistd::close; use tokio::{ self, io::split, @@ -21,6 +16,11 @@ use tokio::{ sync::Notify, }; +use crate::common::client_connect; +use crate::error::{Error, Result}; +use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; +use crate::r#async::utils; + type RequestSender = Sender<(GenMessage, Sender>>)>; type RequestReceiver = Receiver<(GenMessage, Sender>>)>; diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 17722411..e9dfee70 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -3,26 +3,18 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::r#async::utils; -use nix::unistd; use std::collections::HashMap; +use std::marker::Unpin; use std::os::unix::io::RawFd; +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::os::unix::net::UnixListener as SysUnixListener; use std::result::Result as StdResult; use std::sync::Arc; use std::time::Duration; -use crate::asynchronous::stream::{receive, respond, respond_with_status}; -use crate::asynchronous::unix_incoming::UnixIncoming; -use crate::common::{self, Domain}; -use crate::context; -use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST}; -use crate::r#async::{MethodHandler, TtrpcContext}; use futures::stream::Stream; use futures::StreamExt as _; -use std::marker::Unpin; -use std::os::unix::io::{AsRawFd, FromRawFd}; -use std::os::unix::net::UnixListener as SysUnixListener; +use nix::unistd; use tokio::{ self, io::{split, AsyncRead, AsyncWrite, AsyncWriteExt}, @@ -32,10 +24,18 @@ use tokio::{ sync::watch, time::timeout, }; - #[cfg(target_os = "linux")] use tokio_vsock::VsockListener; +use crate::asynchronous::stream::{receive, respond, respond_with_status}; +use crate::asynchronous::unix_incoming::UnixIncoming; +use crate::common::{self, Domain}; +use crate::context; +use crate::error::{get_status, Error, Result}; +use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST}; +use crate::r#async::utils; +use crate::r#async::{MethodHandler, TtrpcContext}; + /// A ttrpc Server (async). 
pub struct Server { listeners: Vec, diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index f3d12abc..f281da25 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -3,14 +3,15 @@ // SPDX-License-Identifier: Apache-2.0 // +use protobuf::Message; +use tokio::io::AsyncReadExt; + use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; use crate::proto::{ Code, Response, Status, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE, }; use crate::r#async::utils; use crate::MessageHeader; -use protobuf::Message; -use tokio::io::AsyncReadExt; async fn receive_count(reader: &mut T, count: usize) -> Result> where diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 91d47bcc..68ce29cc 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -3,17 +3,17 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::error::{get_status, Result}; -use crate::proto::{ - Code, MessageHeader, Request, Status, MESSAGE_TYPE_RESPONSE, -}; -use async_trait::async_trait; -use protobuf::{CodedInputStream, Message}; use std::collections::HashMap; use std::os::unix::io::{FromRawFd, RawFd}; use std::result::Result as StdResult; + +use async_trait::async_trait; +use protobuf::{CodedInputStream, Message}; use tokio::net::UnixStream; +use crate::error::{get_status, Result}; +use crate::proto::{Code, MessageHeader, Request, Status, MESSAGE_TYPE_RESPONSE}; + /// Handle request in async mode. #[macro_export] macro_rules! async_request_handler { From 54ab070db134fb059283fbe590d6c1d755d5ba14 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 16:43:40 +0800 Subject: [PATCH 06/15] async: use shutdown to implement server graceful shutdown. Using shutdown can be used instead of channel composition. Signed-off-by: wllenyj --- src/asynchronous/server.rs | 79 ++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 41 deletions(-) diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index e9dfee70..671e6b97 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -21,7 +21,6 @@ use tokio::{ net::UnixListener, select, spawn, sync::mpsc::{channel, Receiver, Sender}, - sync::watch, time::timeout, }; #[cfg(target_os = "linux")] @@ -33,16 +32,20 @@ use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST}; +use crate::r#async::shutdown; use crate::r#async::utils; use crate::r#async::{MethodHandler, TtrpcContext}; +const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000); +const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000); + /// A ttrpc Server (async). 
pub struct Server { listeners: Vec, methods: Arc>>, domain: Option, - disconnect_tx: Option>, - all_conn_done_rx: Option>, + + shutdown: shutdown::Notifier, stop_listen_tx: Option>>, } @@ -52,8 +55,7 @@ impl Default for Server { listeners: Vec::with_capacity(1), methods: Arc::new(HashMap::new()), domain: None, - disconnect_tx: None, - all_conn_done_rx: None, + shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0, stop_listen_tx: None, } } @@ -151,11 +153,7 @@ impl Server { { let methods = self.methods.clone(); - let (disconnect_tx, close_conn_rx) = watch::channel(0); - self.disconnect_tx = Some(disconnect_tx); - - let (conn_done_tx, all_conn_done_rx) = channel::(1); - self.all_conn_done_rx = Some(all_conn_done_rx); + let shutdown_waiter = self.shutdown.subscribe(); let (stop_listen_tx, mut stop_listen_rx) = channel(1); self.stop_listen_tx = Some(stop_listen_tx); @@ -174,8 +172,7 @@ impl Server { fd, stream, methods.clone(), - close_conn_rx.clone(), - conn_done_tx.clone() + shutdown_waiter.clone(), ).await; } Err(e) => { @@ -201,7 +198,6 @@ impl Server { } } } - drop(conn_done_tx); }); Ok(()) } @@ -214,13 +210,16 @@ impl Server { } pub async fn disconnect(&mut self) { - if let Some(tx) = self.disconnect_tx.take() { - tx.send(1).ok(); - } + self.shutdown.shutdown(); - if let Some(mut rx) = self.all_conn_done_rx.take() { - rx.recv().await; - } + self.shutdown + .wait_all_exit() + .await + .map_err(|e| { + trace!("wait connection exit error: {}", e); + }) + .ok(); + trace!("wait connection exit."); } pub async fn stop_listen(&mut self) { @@ -239,17 +238,17 @@ async fn spawn_connection_handler( fd: RawFd, stream: S, methods: Arc>>, - mut close_conn_rx: watch::Receiver, - conn_done_tx: Sender, + shutdown_waiter: shutdown::Waiter, ) where S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { - let (req_done_tx, mut all_req_done_rx) = channel::(1); - spawn(async move { let (mut reader, mut writer) = split(stream); let (tx, mut rx): (Sender>, Receiver>) = channel(100); - let (client_disconnected_tx, client_disconnected_rx) = watch::channel(false); + + let server_shutdown = shutdown_waiter.clone(); + let (disconnect_notifier, disconnect_waiter) = + shutdown::with_timeout(DEFAULT_CONN_SHUTDOWN_TIMEOUT); spawn(async move { while let Some(buf) = rx.recv().await { @@ -262,8 +261,7 @@ async fn spawn_connection_handler( loop { let tx = tx.clone(); let methods = methods.clone(); - let req_done_tx2 = req_done_tx.clone(); - let mut client_disconnected_rx2 = client_disconnected_rx.clone(); + let handler_shutdown_waiter = disconnect_waiter.clone(); select! { resp = receive(&mut reader) => { @@ -272,33 +270,32 @@ async fn spawn_connection_handler( spawn(async move { select! { _ = handle_request(tx, fd, methods, message) => {} - _ = client_disconnected_rx2.changed() => {} + _ = handler_shutdown_waiter.wait_shutdown() => {} } - - drop(req_done_tx2); }); } Err(e) => { - let _ = client_disconnected_tx.send(true); + disconnect_notifier.shutdown(); trace!("error {:?}", e); break; } } } - v = close_conn_rx.changed() => { - // 0 is the init value of this watch, not a valid signal - // is_err means the tx was dropped. - if v.is_err() || *close_conn_rx.borrow() != 0 { - info!("Stop accepting new connections."); - break; - } + _ = server_shutdown.wait_shutdown() => { + trace!("Receive shutdown."); + break; } } } - - drop(req_done_tx); - all_req_done_rx.recv().await; - drop(conn_done_tx); + // TODO: Don't disconnect_notifier.shutdown(); + // Wait pedding request/stream to exit. 
+ disconnect_notifier + .wait_all_exit() + .await + .map_err(|e| { + trace!("wait handler exit error: {}", e); + }) + .ok(); }); } From 69bd03d4886a630900423ebea8fe7a959ef7f6fb Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 17:55:31 +0800 Subject: [PATCH 07/15] async: Refactor message encoding/deocding. Make message encoding/decoding uniform. Signed-off-by: wllenyj --- src/asynchronous/client.rs | 27 +++----- src/asynchronous/server.rs | 49 ++++++++------ src/asynchronous/stream.rs | 133 ++++++------------------------------- src/asynchronous/utils.rs | 20 +----- 4 files changed, 64 insertions(+), 165 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index abf61c92..c9def8f2 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -9,23 +9,16 @@ use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; use nix::unistd::close; -use tokio::{ - self, - io::split, - sync::mpsc::{channel, Receiver, Sender}, - sync::Notify, -}; +use tokio::{self, io::split, sync::mpsc, sync::Notify}; use crate::common::client_connect; use crate::error::{Error, Result}; use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; +use crate::r#async::stream::{ResultReceiver, ResultSender}; use crate::r#async::utils; -type RequestSender = Sender<(GenMessage, Sender>>)>; -type RequestReceiver = Receiver<(GenMessage, Sender>>)>; - -type ResponseSender = Sender>>; -type ResponseReceiver = Receiver>>; +type RequestSender = mpsc::Sender<(GenMessage, ResultSender)>; +type RequestReceiver = mpsc::Receiver<(GenMessage, ResultSender)>; /// A ttrpc Client (async). #[derive(Clone)] @@ -44,7 +37,7 @@ impl Client { let stream = utils::new_unix_stream_from_raw_fd(fd); let (mut reader, mut writer) = split(stream); - let (req_tx, mut rx): (RequestSender, RequestReceiver) = channel(100); + let (req_tx, mut rx): (RequestSender, RequestReceiver) = mpsc::channel(100); let req_map = Arc::new(Mutex::new(HashMap::new())); let req_map2 = req_map.clone(); @@ -131,7 +124,7 @@ impl Client { return; } - resp_tx2.send(Ok(msg.payload)).await.unwrap_or_else(|_e| error!("The request has returned")); + resp_tx2.send(Ok(msg)).await.unwrap_or_else(|_e| error!("The request has returned")); }); } Err(e) => { @@ -170,7 +163,7 @@ impl Client { .try_into() .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; - let (tx, mut rx): (ResponseSender, ResponseReceiver) = channel(100); + let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100); self.req_tx .send((msg, tx)) .await @@ -190,9 +183,9 @@ impl Client { .ok_or_else(|| Error::Others("Receive packet from receiver error".to_string()))? 
}; - let buf = result?; - let res = - Response::decode(&buf).map_err(err_to_others_err!(e, "Unpack response error "))?; + let msg = result?; + let res = Response::decode(&msg.payload) + .map_err(err_to_others_err!(e, "Unpack response error "))?; let status = res.get_status(); if status.get_code() != Code::OK { diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 671e6b97..62253fbf 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -17,22 +17,23 @@ use futures::StreamExt as _; use nix::unistd; use tokio::{ self, - io::{split, AsyncRead, AsyncWrite, AsyncWriteExt}, + io::{split, AsyncRead, AsyncWrite}, net::UnixListener, select, spawn, - sync::mpsc::{channel, Receiver, Sender}, + sync::mpsc::{channel, Sender}, time::timeout, }; #[cfg(target_os = "linux")] use tokio_vsock::VsockListener; -use crate::asynchronous::stream::{receive, respond, respond_with_status}; +use crate::asynchronous::stream::{respond, respond_with_status}; use crate::asynchronous::unix_incoming::UnixIncoming; use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, MessageHeader, Status, MESSAGE_TYPE_REQUEST}; +use crate::proto::{Code, GenMessage, MessageHeader, Response, Status, MESSAGE_TYPE_REQUEST}; use crate::r#async::shutdown; +use crate::r#async::stream::{MessageReceiver, MessageSender}; use crate::r#async::utils; use crate::r#async::{MethodHandler, TtrpcContext}; @@ -244,15 +245,15 @@ async fn spawn_connection_handler( { spawn(async move { let (mut reader, mut writer) = split(stream); - let (tx, mut rx): (Sender>, Receiver>) = channel(100); + let (tx, mut rx): (MessageSender, MessageReceiver) = channel(100); let server_shutdown = shutdown_waiter.clone(); let (disconnect_notifier, disconnect_waiter) = shutdown::with_timeout(DEFAULT_CONN_SHUTDOWN_TIMEOUT); spawn(async move { - while let Some(buf) = rx.recv().await { - if let Err(e) = writer.write_all(&buf).await { + while let Some(msg) = rx.recv().await { + if let Err(e) = msg.write_to(&mut writer).await { error!("write_message got error: {:?}", e); } } @@ -264,8 +265,8 @@ async fn spawn_connection_handler( let handler_shutdown_waiter = disconnect_waiter.clone(); select! { - resp = receive(&mut reader) => { - match resp { + res = GenMessage::read_from(&mut reader) => { + match res { Ok(message) => { spawn(async move { select! 
{ @@ -304,7 +305,7 @@ async fn do_handle_request( methods: Arc>>, header: MessageHeader, body: &[u8], -) -> StdResult<(u32, Vec), Status> { +) -> StdResult, Status> { let req = utils::body_to_request(body)?; let path = utils::get_path(&req.service, &req.method); let method = methods @@ -328,6 +329,7 @@ async fn do_handle_request( .handler(ctx, req) .await .map_err(get_unknown_status_and_log_err) + .map(Some) } else { timeout( Duration::from_nanos(req.timeout_nano as u64), @@ -343,16 +345,20 @@ async fn do_handle_request( // Handler finished r.map_err(get_unknown_status_and_log_err) }) + .map(Some) } } async fn handle_request( - tx: Sender>, + tx: MessageSender, fd: RawFd, methods: Arc>>, - message: (MessageHeader, Vec), + message: GenMessage, ) { - let (header, body) = message; + let GenMessage { + header, + payload: body, + } = message; let stream_id = header.stream_id; if header.type_ != MESSAGE_TYPE_REQUEST { @@ -360,15 +366,18 @@ async fn handle_request( } match do_handle_request(fd, methods, header, &body).await { - Ok((stream_id, resp_body)) => { - if let Err(x) = respond(tx.clone(), stream_id, resp_body).await { - error!("respond got error {:?}", x); + Ok(opt_msg) => match opt_msg { + Some(msg) => { + if let Err(x) = respond(tx.clone(), stream_id, msg).await { + error!("respond got error {:?}", x); + } } - } - Err(status) => { - if let Err(x) = respond_with_status(tx.clone(), stream_id, status).await { - error!("respond got error {:?}", x); + None => { + unimplemented!(); } + }, + Err(status) => { + respond_with_status(tx.clone(), stream_id, status).await; } } } diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index f281da25..150b2585 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -3,129 +3,40 @@ // SPDX-License-Identifier: Apache-2.0 // -use protobuf::Message; -use tokio::io::AsyncReadExt; +use tokio::sync::mpsc; -use crate::error::{get_rpc_status, sock_error_msg, Error, Result}; -use crate::proto::{ - Code, Response, Status, MESSAGE_HEADER_LENGTH, MESSAGE_LENGTH_MAX, MESSAGE_TYPE_RESPONSE, -}; -use crate::r#async::utils; +use crate::error::{Error, Result}; +use crate::proto::{Codec, GenMessage, Response, Status}; use crate::MessageHeader; -async fn receive_count(reader: &mut T, count: usize) -> Result> -where - T: AsyncReadExt + std::marker::Unpin, -{ - let mut content = vec![0u8; count]; - if let Err(e) = reader.read_exact(&mut content).await { - return Err(Error::Socket(e.to_string())); - } +pub type MessageSender = mpsc::Sender; +pub type MessageReceiver = mpsc::Receiver; - Ok(content) -} - -async fn receive_header(reader: &mut T) -> Result -where - T: AsyncReadExt + std::marker::Unpin, -{ - let buf = receive_count(reader, MESSAGE_HEADER_LENGTH).await?; - let size = buf.len(); - if size != MESSAGE_HEADER_LENGTH { - return Err(sock_error_msg( - size, - format!("Message header length {} is too small", size), - )); - } - - let mh = MessageHeader::from(&buf); - - Ok(mh) -} - -pub(crate) async fn receive(reader: &mut T) -> Result<(MessageHeader, Vec)> -where - T: AsyncReadExt + std::marker::Unpin, -{ - let mh = receive_header(reader).await?; - trace!("Got Message header {:?}", mh); - - if mh.length > MESSAGE_LENGTH_MAX as u32 { - return Err(get_rpc_status( - Code::INVALID_ARGUMENT, - format!( - "message length {} exceed maximum message size of {}", - mh.length, MESSAGE_LENGTH_MAX - ), - )); - } - - let buf = receive_count(reader, mh.length as usize).await?; - let size = buf.len(); - if size != mh.length as usize { - return 
Err(sock_error_msg( - size, - format!("Message length {} is not {}", size, mh.length), - )); - } - trace!("Got Message body {:?}", buf); - - Ok((mh, buf)) -} - -fn header_to_buf(mh: MessageHeader) -> Vec { - mh.into() -} +pub type ResultSender = mpsc::Sender>; +pub type ResultReceiver = mpsc::Receiver>; -pub(crate) fn to_res_buf(stream_id: u32, mut body: Vec) -> Vec { - let header = utils::get_response_header_from_body(stream_id, &body); - let mut buf = header_to_buf(header); - buf.append(&mut body); - - buf -} - -fn get_response_body(res: &Response) -> Result> { - let mut buf = Vec::with_capacity(res.compute_size() as usize); - let mut s = protobuf::CodedOutputStream::vec(&mut buf); - res.write_to(&mut s).map_err(err_to_others_err!(e, ""))?; - s.flush().map_err(err_to_others_err!(e, ""))?; - - Ok(buf) -} - -pub(crate) async fn respond( - tx: tokio::sync::mpsc::Sender>, - stream_id: u32, - body: Vec, -) -> Result<()> { - let buf = to_res_buf(stream_id, body); +pub(crate) async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { + let payload = resp + .encode() + .map_err(err_to_others_err!(e, "Encode Response failed."))?; + let msg = GenMessage { + header: MessageHeader::new_response(stream_id, payload.len() as u32), + payload, + }; - tx.send(buf) + tx.send(msg) .await .map_err(err_to_others_err!(e, "Send packet to sender error ")) } -pub(crate) async fn respond_with_status( - tx: tokio::sync::mpsc::Sender>, - stream_id: u32, - status: Status, -) -> Result<()> { +pub(crate) async fn respond_with_status(tx: MessageSender, stream_id: u32, status: Status) { let mut res = Response::new(); res.set_status(status); - let mut body = get_response_body(&res)?; - - let mh = MessageHeader { - length: body.len() as u32, - stream_id, - type_: MESSAGE_TYPE_RESPONSE, - flags: 0, - }; - let mut buf = header_to_buf(mh); - buf.append(&mut body); - - tx.send(buf) + respond(tx, stream_id, res) .await - .map_err(err_to_others_err!(e, "Send packet to sender error ")) + .map_err(|e| { + error!("respond with status got error {:?}", e); + }) + .ok(); } diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 68ce29cc..d171ea43 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -12,7 +12,7 @@ use protobuf::{CodedInputStream, Message}; use tokio::net::UnixStream; use crate::error::{get_status, Result}; -use crate::proto::{Code, MessageHeader, Request, Status, MESSAGE_TYPE_RESPONSE}; +use crate::proto::{Code, MessageHeader, Request, Response, Status}; /// Handle request in async mode. #[macro_export] @@ -48,12 +48,7 @@ macro_rules! async_request_handler { }, } - let mut buf = Vec::with_capacity(res.compute_size() as usize); - let mut s = protobuf::CodedOutputStream::vec(&mut buf); - res.write_to(&mut s).map_err(ttrpc::err_to_others!(e, ""))?; - s.flush().map_err(ttrpc::err_to_others!(e, ""))?; - - return Ok(($ctx.mh.stream_id, buf)); + return Ok(res); }; } @@ -88,7 +83,7 @@ macro_rules! async_client_request { /// Trait that implements handler which is a proxy to the desired method (async). #[async_trait] pub trait MethodHandler { - async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<(u32, Vec)>; + async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result; } /// The context of ttrpc (async). 
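Editor's note on the hunk above: with this refactoring a unary handler no longer serializes its reply or threads the stream id through; it simply returns a Response, and the connection layer builds the wire frame via MessageHeader::new_response. A minimal hand-written sketch of the new MethodHandler shape (EchoMethod and its echo behavior are made up for illustration; real implementations are produced by the async_request_handler! macro in generated code):

    use async_trait::async_trait;
    use ttrpc::r#async::{MethodHandler, TtrpcContext};
    use ttrpc::{Request, Response, Result};

    struct EchoMethod;

    #[async_trait]
    impl MethodHandler for EchoMethod {
        async fn handler(&self, _ctx: TtrpcContext, req: Request) -> Result<Response> {
            // Echo the request payload back. No CodedOutputStream and no
            // (stream_id, Vec<u8>) plumbing is needed any more; the server
            // frames the returned Response itself.
            let mut resp = Response::new();
            resp.set_payload(req.payload);
            Ok(resp)
        }
    }
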
@@ -100,15 +95,6 @@ pub struct TtrpcContext { pub timeout_nano: i64, } -pub(crate) fn get_response_header_from_body(stream_id: u32, body: &[u8]) -> MessageHeader { - MessageHeader { - length: body.len() as u32, - stream_id, - type_: MESSAGE_TYPE_RESPONSE, - flags: 0, - } -} - pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream { let std_stream: std::os::unix::net::UnixStream; unsafe { From 9f97207e473e9a55752ea5fa8a9a26908ab8767b Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 17:02:24 +0800 Subject: [PATCH 08/15] async: add connection module. For abstract connections, our connections are handled as one sending task and one receiving task. We can use the same logic to handle it. Signed-off-by: wllenyj --- src/asynchronous/connection.rs | 110 +++++++++++++++++++++++++++++++++ src/asynchronous/mod.rs | 3 +- 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 src/asynchronous/connection.rs diff --git a/src/asynchronous/connection.rs b/src/asynchronous/connection.rs new file mode 100644 index 00000000..8de87d3b --- /dev/null +++ b/src/asynchronous/connection.rs @@ -0,0 +1,110 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// Copyright (c) 2020 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::os::unix::io::AsRawFd; + +use async_trait::async_trait; +use log::{error, trace}; +use tokio::{ + io::{split, AsyncRead, AsyncWrite, ReadHalf}, + select, task, +}; + +use crate::error::Error; +use crate::proto::GenMessage; + +pub trait Builder { + type Reader; + type Writer; + + fn build(&mut self) -> (Self::Reader, Self::Writer); +} + +#[async_trait] +pub trait WriterDelegate { + async fn recv(&mut self) -> Option; + async fn disconnect(&self, msg: &GenMessage, e: Error); + async fn exit(&self); +} + +#[async_trait] +pub trait ReaderDelegate { + async fn wait_shutdown(&self); + async fn disconnect(&self, e: Error, task: &mut task::JoinHandle<()>); + async fn exit(&self); + async fn handle_msg(&self, msg: GenMessage); +} + +pub struct Connection { + reader: ReadHalf, + writer_task: task::JoinHandle<()>, + reader_delegate: B::Reader, +} + +impl Connection +where + S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, + B: Builder, + B::Reader: ReaderDelegate + Send + Sync + 'static, + B::Writer: WriterDelegate + Send + Sync + 'static, +{ + pub fn new(conn: S, mut builder: B) -> Self { + let (reader, mut writer) = split(conn); + + let (reader_delegate, mut writer_delegate) = builder.build(); + + let writer_task = tokio::spawn(async move { + while let Some(msg) = writer_delegate.recv().await { + trace!("write message: {:?}", msg); + if let Err(e) = msg.write_to(&mut writer).await { + error!("write_message got error: {:?}", e); + writer_delegate.disconnect(&msg, e).await; + } + } + writer_delegate.exit().await; + trace!("Writer task exit."); + }); + + Self { + reader, + writer_task, + reader_delegate, + } + } + + pub async fn run(self) -> std::io::Result<()> { + let Connection { + mut reader, + mut writer_task, + reader_delegate, + } = self; + loop { + select! 
{ + res = GenMessage::read_from(&mut reader) => { + match res { + Ok(msg) => { + trace!("Got Message {:?}", msg); + reader_delegate.handle_msg(msg).await; + } + Err(e) => { + trace!("Read msg err: {:?}", e); + reader_delegate.disconnect(e, &mut writer_task).await; + break; + } + } + } + _v = reader_delegate.wait_shutdown() => { + trace!("Receive shutdown."); + break; + } + } + } + reader_delegate.exit().await; + trace!("Reader task exit."); + + Ok(()) + } +} diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index 371aaa21..dd6e5fae 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -11,8 +11,9 @@ mod stream; #[macro_use] #[doc(hidden)] mod utils; -mod unix_incoming; +mod connection; pub mod shutdown; +mod unix_incoming; #[doc(inline)] pub use crate::r#async::client::Client; From 4cafce1b37f7fba8bac2589373000cf0e79fa292 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 18:45:29 +0800 Subject: [PATCH 09/15] async: Refactor connection handling. The client and server handle connections almost identically. Both use a sender side and a receiver side to handle the connection. Their respective differences are implemented using the delegate. Signed-off-by: wllenyj --- src/asynchronous/client.rs | 260 ++++++++++++++------------ src/asynchronous/server.rs | 367 ++++++++++++++++++++++++------------- src/asynchronous/stream.rs | 31 +--- src/asynchronous/utils.rs | 23 +-- 4 files changed, 390 insertions(+), 291 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index c9def8f2..63a92d45 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -8,12 +8,15 @@ use std::convert::TryInto; use std::os::unix::io::RawFd; use std::sync::{Arc, Mutex}; +use async_trait::async_trait; use nix::unistd::close; -use tokio::{self, io::split, sync::mpsc, sync::Notify}; +use tokio::{self, sync::mpsc, task}; use crate::common::client_connect; use crate::error::{Error, Result}; use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; +use crate::r#async::connection::*; +use crate::r#async::shutdown; use crate::r#async::stream::{ResultReceiver, ResultSender}; use crate::r#async::utils; @@ -36,122 +39,12 @@ impl Client { pub fn new(fd: RawFd) -> Client { let stream = utils::new_unix_stream_from_raw_fd(fd); - let (mut reader, mut writer) = split(stream); - let (req_tx, mut rx): (RequestSender, RequestReceiver) = mpsc::channel(100); + let (req_tx, rx): (RequestSender, RequestReceiver) = mpsc::channel(100); - let req_map = Arc::new(Mutex::new(HashMap::new())); - let req_map2 = req_map.clone(); - - let notify = Arc::new(Notify::new()); - let notify2 = notify.clone(); - - // Request sender - let request_sender = tokio::spawn(async move { - let mut stream_id: u32 = 1; - - while let Some((mut msg, resp_tx)) = rx.recv().await { - let current_stream_id = stream_id; - msg.header.set_stream_id(current_stream_id); - stream_id += 2; - - { - let mut map = req_map2.lock().unwrap(); - map.insert(current_stream_id, resp_tx.clone()); - } - - if let Err(e) = msg.write_to(&mut writer).await { - error!("write_message got error: {:?}", e); - - { - let mut map = req_map2.lock().unwrap(); - map.remove(¤t_stream_id); - } + let delegate = ClientBuilder { rx: Some(rx) }; - let e = Error::Socket(format!("{:?}", e)); - resp_tx - .send(Err(e)) - .await - .unwrap_or_else(|_e| error!("The request has returned")); - - break; // The stream is dead, exit the loop. 
- } - } - - // rx.recv will abort when client.req_tx and client is dropped. - // notify the response-receiver to quit at this time. - notify.notify_one(); - }); - - // Response receiver - tokio::spawn(async move { - loop { - tokio::select! { - _ = notify2.notified() => { - break; - } - res = GenMessage::read_from(&mut reader) => { - match res { - Ok(msg) => { - trace!("Got Message body {:?}", msg.payload); - let req_map = req_map.clone(); - tokio::spawn(async move { - let resp_tx2; - { - let mut map = req_map.lock().unwrap(); - let resp_tx = match map.get(&msg.header.stream_id) { - Some(tx) => tx, - None => { - debug!( - "Receiver got unknown packet {:?}", - msg - ); - return; - } - }; - - resp_tx2 = resp_tx.clone(); - map.remove(&msg.header.stream_id); // Forget the result, just remove. - } - - if msg.header.type_ != MESSAGE_TYPE_RESPONSE { - resp_tx2 - .send(Err(Error::Others(format!( - "Recver got malformed packet {:?}", - msg - )))) - .await - .unwrap_or_else(|_e| error!("The request has returned")); - return; - } - - resp_tx2.send(Ok(msg)).await.unwrap_or_else(|_e| error!("The request has returned")); - }); - } - Err(e) => { - debug!("Connection closed by the ttRPC server: {}", e); - - // Abort the request sender task to prevent incoming RPC requests - // from being processed. - request_sender.abort(); - let _ = request_sender.await; - - // Take all items out of `req_map`. - let mut map = std::mem::take(&mut *req_map.lock().unwrap()); - // Terminate outstanding RPC requests with the error. - for (_stream_id, resp_tx) in map.drain() { - if let Err(_e) = resp_tx.send(Err(e.clone())).await { - warn!("Failed to terminate pending RPC: \ - the request has returned"); - } - } - - break; - } - } - } - }; - } - }); + let conn = Connection::new(stream, delegate); + tokio::spawn(async move { conn.run().await }); Client { req_tx } } @@ -208,3 +101,140 @@ impl Drop for ClientClose { trace!("All client is droped"); } } + +#[derive(Debug)] +struct ClientBuilder { + rx: Option, +} + +impl Builder for ClientBuilder { + type Reader = ClientReader; + type Writer = ClientWriter; + + fn build(&mut self) -> (Self::Reader, Self::Writer) { + let (notifier, waiter) = shutdown::new(); + let req_map = Arc::new(Mutex::new(HashMap::new())); + ( + ClientReader { + shutdown_waiter: waiter, + req_map: req_map.clone(), + }, + ClientWriter { + stream_id: 1, + rx: self.rx.take().unwrap(), + shutdown_notifier: notifier, + req_map, + }, + ) + } +} + +struct ClientWriter { + stream_id: u32, + rx: RequestReceiver, + shutdown_notifier: shutdown::Notifier, + req_map: Arc>>, +} + +#[async_trait] +impl WriterDelegate for ClientWriter { + async fn recv(&mut self) -> Option { + if let Some((mut msg, resp_tx)) = self.rx.recv().await { + let current_stream_id = self.stream_id; + msg.header.set_stream_id(current_stream_id); + self.stream_id += 2; + { + let mut map = self.req_map.lock().unwrap(); + map.insert(current_stream_id, resp_tx); + } + return Some(msg); + } else { + return None; + } + } + + async fn disconnect(&self, msg: &GenMessage, e: Error) { + let resp_tx = { + let mut map = self.req_map.lock().unwrap(); + map.remove(&msg.header.stream_id) + }; + + if let Some(resp_tx) = resp_tx { + let e = Error::Socket(format!("{:?}", e)); + resp_tx + .send(Err(e)) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + } + } + + async fn exit(&self) { + self.shutdown_notifier.shutdown(); + } +} + +struct ClientReader { + shutdown_waiter: shutdown::Waiter, + req_map: Arc>>, +} + +#[async_trait] +impl ReaderDelegate for 
ClientReader { + async fn wait_shutdown(&self) { + self.shutdown_waiter.wait_shutdown().await + } + + async fn disconnect(&self, e: Error, sender: &mut task::JoinHandle<()>) { + // Abort the request sender task to prevent incoming RPC requests + // from being processed. + sender.abort(); + let _ = sender.await; + + // Take all items out of `req_map`. + let mut map = std::mem::take(&mut *self.req_map.lock().unwrap()); + // Terminate outstanding RPC requests with the error. + for (_stream_id, resp_tx) in map.drain() { + if let Err(_e) = resp_tx.send(Err(e.clone())).await { + warn!("Failed to terminate pending RPC: the request has returned"); + } + } + } + + async fn exit(&self) {} + + async fn handle_msg(&self, msg: GenMessage) { + let req_map = self.req_map.clone(); + tokio::spawn(async move { + let resp_tx2; + { + let mut map = req_map.lock().unwrap(); + let resp_tx = match map.get(&msg.header.stream_id) { + Some(tx) => tx, + None => { + debug!("Receiver got unknown packet {:?}", msg); + return; + } + }; + + resp_tx2 = resp_tx.clone(); + map.remove(&msg.header.stream_id); // Forget the result, just remove. + } + + if msg.header.type_ != MESSAGE_TYPE_RESPONSE { + resp_tx2 + .send(Err(Error::Others(format!( + "Recver got malformed packet {:?}", + msg + )))) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + return; + } + + resp_tx2 + .send(Ok(msg)) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + }); + } +} diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 62253fbf..1da858f8 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -4,6 +4,7 @@ // use std::collections::HashMap; +use std::convert::TryFrom; use std::marker::Unpin; use std::os::unix::io::RawFd; use std::os::unix::io::{AsRawFd, FromRawFd}; @@ -12,26 +13,31 @@ use std::result::Result as StdResult; use std::sync::Arc; use std::time::Duration; +use async_trait::async_trait; use futures::stream::Stream; use futures::StreamExt as _; use nix::unistd; use tokio::{ self, - io::{split, AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite}, net::UnixListener, select, spawn, sync::mpsc::{channel, Sender}, + task, time::timeout, }; #[cfg(target_os = "linux")] use tokio_vsock::VsockListener; -use crate::asynchronous::stream::{respond, respond_with_status}; use crate::asynchronous::unix_incoming::UnixIncoming; use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; -use crate::proto::{Code, GenMessage, MessageHeader, Response, Status, MESSAGE_TYPE_REQUEST}; +use crate::proto::{ + Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, + MESSAGE_TYPE_REQUEST, +}; +use crate::r#async::connection::*; use crate::r#async::shutdown; use crate::r#async::stream::{MessageReceiver, MessageSender}; use crate::r#async::utils; @@ -235,161 +241,270 @@ impl Server { } } -async fn spawn_connection_handler( +async fn spawn_connection_handler( fd: RawFd, - stream: S, + conn: C, methods: Arc>>, shutdown_waiter: shutdown::Waiter, ) where - S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, + C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { + let delegate = ServerBuilder { + fd, + methods, + shutdown_waiter, + }; + let conn = Connection::new(conn, delegate); spawn(async move { - let (mut reader, mut writer) = split(stream); - let (tx, mut rx): (MessageSender, MessageReceiver) = channel(100); + conn.run() + .await + .map_err(|e| { + trace!("connection run error. 
{}", e); + }) + .ok(); + }); +} + +impl FromRawFd for Server { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self::default().add_listener(fd).unwrap() + } +} + +impl AsRawFd for Server { + fn as_raw_fd(&self) -> RawFd { + self.listeners[0] + } +} - let server_shutdown = shutdown_waiter.clone(); - let (disconnect_notifier, disconnect_waiter) = +struct ServerBuilder { + fd: RawFd, + methods: Arc>>, + shutdown_waiter: shutdown::Waiter, +} + +impl Builder for ServerBuilder { + type Reader = ServerReader; + type Writer = ServerWriter; + + fn build(&mut self) -> (Self::Reader, Self::Writer) { + let (tx, rx): (MessageSender, MessageReceiver) = channel(100); + let (disconnect_notifier, _disconnect_waiter) = shutdown::with_timeout(DEFAULT_CONN_SHUTDOWN_TIMEOUT); - spawn(async move { - while let Some(msg) = rx.recv().await { - if let Err(e) = msg.write_to(&mut writer).await { - error!("write_message got error: {:?}", e); - } - } - }); + ( + ServerReader { + fd: self.fd, + tx, + methods: self.methods.clone(), + server_shutdown: self.shutdown_waiter.clone(), + handler_shutdown: disconnect_notifier, + }, + ServerWriter { rx }, + ) + } +} - loop { - let tx = tx.clone(); - let methods = methods.clone(); - let handler_shutdown_waiter = disconnect_waiter.clone(); +struct ServerWriter { + rx: MessageReceiver, +} - select! { - res = GenMessage::read_from(&mut reader) => { - match res { - Ok(message) => { - spawn(async move { - select! { - _ = handle_request(tx, fd, methods, message) => {} - _ = handler_shutdown_waiter.wait_shutdown() => {} - } - }); - } - Err(e) => { - disconnect_notifier.shutdown(); - trace!("error {:?}", e); - break; - } - } - } - _ = server_shutdown.wait_shutdown() => { - trace!("Receive shutdown."); - break; - } - } - } - // TODO: Don't disconnect_notifier.shutdown(); +#[async_trait] +impl WriterDelegate for ServerWriter { + async fn recv(&mut self) -> Option { + self.rx.recv().await + } + async fn disconnect(&self, _msg: &GenMessage, _: Error) {} + async fn exit(&self) {} +} + +struct ServerReader { + fd: RawFd, + tx: MessageSender, + methods: Arc>>, + server_shutdown: shutdown::Waiter, + handler_shutdown: shutdown::Notifier, +} + +#[async_trait] +impl ReaderDelegate for ServerReader { + async fn wait_shutdown(&self) { + self.server_shutdown.wait_shutdown().await + } + + async fn disconnect(&self, _: Error, _: &mut task::JoinHandle<()>) { + self.handler_shutdown.shutdown(); + // TODO: Don't wait for all requests to complete? when the connection is disconnected. + } + + async fn exit(&self) { + // TODO: Don't self.conn_shutdown.shutdown(); // Wait pedding request/stream to exit. - disconnect_notifier + self.handler_shutdown .wait_all_exit() .await .map_err(|e| { trace!("wait handler exit error: {}", e); }) .ok(); - }); + } + + async fn handle_msg(&self, msg: GenMessage) { + let handler_shutdown_waiter = self.handler_shutdown.subscribe(); + let context = self.context(); + spawn(async move { + select! 
{ + _ = context.handle_msg(msg) => {} + _ = handler_shutdown_waiter.wait_shutdown() => {} + } + }); + } +} + +impl ServerReader { + fn context(&self) -> HandlerContext { + HandlerContext { + fd: self.fd, + tx: self.tx.clone(), + methods: self.methods.clone(), + _handler_shutdown_waiter: self.handler_shutdown.subscribe(), + } + } } -async fn do_handle_request( +struct HandlerContext { fd: RawFd, + tx: MessageSender, methods: Arc>>, - header: MessageHeader, - body: &[u8], -) -> StdResult, Status> { - let req = utils::body_to_request(body)?; - let path = utils::get_path(&req.service, &req.method); - let method = methods - .get(&path) - .ok_or_else(|| get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path)))?; - - let ctx = TtrpcContext { - fd, - mh: header, - metadata: context::from_pb(&req.metadata), - timeout_nano: req.timeout_nano, - }; + // Used for waiting handler exit. + _handler_shutdown_waiter: shutdown::Waiter, +} - let get_unknown_status_and_log_err = |e| { - error!("method handle {} got error {:?}", path, &e); - get_status(Code::UNKNOWN, e) - }; +impl HandlerContext { + async fn handle_msg(&self, msg: GenMessage) { + let stream_id = msg.header.stream_id; + + if (stream_id % 2) != 1 { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status(Code::INVALID_ARGUMENT, "stream id must be odd"), + ) + .await; + return; + } - if req.timeout_nano == 0 { - method - .handler(ctx, req) - .await - .map_err(get_unknown_status_and_log_err) - .map(Some) - } else { - timeout( - Duration::from_nanos(req.timeout_nano as u64), - method.handler(ctx, req), - ) - .await - .map_err(|_| { - // Timed out - error!("method handle {} got error timed out", path); - get_status(Code::DEADLINE_EXCEEDED, "timeout") - }) - .and_then(|r| { - // Handler finished - r.map_err(get_unknown_status_and_log_err) - }) - .map(Some) + match msg.header.type_ { + MESSAGE_TYPE_REQUEST => match self.handle_request(msg).await { + Ok(opt_msg) => match opt_msg { + Some(msg) => { + Self::respond(self.tx.clone(), stream_id, msg) + .await + .map_err(|e| { + error!("respond got error {:?}", e); + }) + .ok(); + } + None => { + unimplemented!(); + } + }, + Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, + }, + _ => { + // TODO: else we must ignore this for future compat. log this? + // TODO(wllenyj): Compatible with golang behavior. + error!("Unknown message type. 
{:?}", msg.header); + } + } } -} -async fn handle_request( - tx: MessageSender, - fd: RawFd, - methods: Arc>>, - message: GenMessage, -) { - let GenMessage { - header, - payload: body, - } = message; - let stream_id = header.stream_id; - - if header.type_ != MESSAGE_TYPE_REQUEST { - return; + async fn handle_request(&self, msg: GenMessage) -> StdResult, Status> { + //TODO: + //if header.stream_id <= self.last_stream_id { + // return Err; + //} + // self.last_stream_id = header.stream_id; + + let req_msg = Message::::try_from(msg) + .map_err(|e| get_status(Code::INVALID_ARGUMENT, e.to_string()))?; + + let req = &req_msg.payload; + trace!("Got Message request {} {}", req.service, req.method); + + let path = utils::get_path(&req.service, &req.method); + let method = self.methods.get(&path).ok_or_else(|| { + get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path)) + })?; + + return self.handle_method(method.as_ref(), req_msg).await; } - match do_handle_request(fd, methods, header, &body).await { - Ok(opt_msg) => match opt_msg { - Some(msg) => { - if let Err(x) = respond(tx.clone(), stream_id, msg).await { - error!("respond got error {:?}", x); - } - } - None => { - unimplemented!(); - } - }, - Err(status) => { - respond_with_status(tx.clone(), stream_id, status).await; + async fn handle_method( + &self, + method: &(dyn MethodHandler + Send + Sync), + req_msg: Message, + ) -> StdResult, Status> { + let req = req_msg.payload; + let path = utils::get_path(&req.service, &req.method); + + let ctx = TtrpcContext { + fd: self.fd, + mh: req_msg.header, + metadata: context::from_pb(&req.metadata), + timeout_nano: req.timeout_nano, + }; + + let get_unknown_status_and_log_err = |e| { + error!("method handle {} got error {:?}", path, &e); + get_status(Code::UNKNOWN, e) + }; + if req.timeout_nano == 0 { + method + .handler(ctx, req) + .await + .map_err(get_unknown_status_and_log_err) + .map(Some) + } else { + timeout( + Duration::from_nanos(req.timeout_nano as u64), + method.handler(ctx, req), + ) + .await + .map_err(|_| { + // Timed out + error!("method handle {} got error timed out", path); + get_status(Code::DEADLINE_EXCEEDED, "timeout") + }) + .and_then(|r| { + // Handler finished + r.map_err(get_unknown_status_and_log_err) + }) + .map(Some) } } -} -impl FromRawFd for Server { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - Self::default().add_listener(fd).unwrap() + async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { + let payload = resp + .encode() + .map_err(err_to_others_err!(e, "Encode Response failed."))?; + let msg = GenMessage { + header: MessageHeader::new_response(stream_id, payload.len() as u32), + payload, + }; + tx.send(msg) + .await + .map_err(err_to_others_err!(e, "Send packet to sender error ")) } -} -impl AsRawFd for Server { - fn as_raw_fd(&self) -> RawFd { - self.listeners[0] + async fn respond_with_status(tx: MessageSender, stream_id: u32, status: Status) { + let mut resp = Response::new(); + resp.set_status(status); + Self::respond(tx, stream_id, resp) + .await + .map_err(|e| { + error!("respond with status got error {:?}", e); + }) + .ok(); } } diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 150b2585..03d10756 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -5,38 +5,11 @@ use tokio::sync::mpsc; -use crate::error::{Error, Result}; -use crate::proto::{Codec, GenMessage, Response, Status}; -use crate::MessageHeader; +use crate::error::Result; +use crate::proto::GenMessage; pub type 
MessageSender = mpsc::Sender; pub type MessageReceiver = mpsc::Receiver; pub type ResultSender = mpsc::Sender>; pub type ResultReceiver = mpsc::Receiver>; - -pub(crate) async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { - let payload = resp - .encode() - .map_err(err_to_others_err!(e, "Encode Response failed."))?; - let msg = GenMessage { - header: MessageHeader::new_response(stream_id, payload.len() as u32), - payload, - }; - - tx.send(msg) - .await - .map_err(err_to_others_err!(e, "Send packet to sender error ")) -} - -pub(crate) async fn respond_with_status(tx: MessageSender, stream_id: u32, status: Status) { - let mut res = Response::new(); - res.set_status(status); - - respond(tx, stream_id, res) - .await - .map_err(|e| { - error!("respond with status got error {:?}", e); - }) - .ok(); -} diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index d171ea43..7cb038b4 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -5,14 +5,12 @@ use std::collections::HashMap; use std::os::unix::io::{FromRawFd, RawFd}; -use std::result::Result as StdResult; use async_trait::async_trait; -use protobuf::{CodedInputStream, Message}; use tokio::net::UnixStream; -use crate::error::{get_status, Result}; -use crate::proto::{Code, MessageHeader, Request, Response, Status}; +use crate::error::Result; +use crate::proto::{MessageHeader, Request, Response}; /// Handle request in async mode. #[macro_export] @@ -106,23 +104,6 @@ pub(crate) fn new_unix_stream_from_raw_fd(fd: RawFd) -> UnixStream { UnixStream::from_std(std_stream).unwrap() } -pub(crate) fn body_to_request(body: &[u8]) -> StdResult { - let mut req = Request::new(); - let merge_result; - { - let mut s = CodedInputStream::from_bytes(body); - merge_result = req.merge_from(&mut s); - } - - if merge_result.is_err() { - return Err(get_status(Code::INVALID_ARGUMENT, "".to_string())); - } - - trace!("Got Message request {:?}", req); - - Ok(req) -} - pub(crate) fn get_path(service: &str, method: &str) -> String { format!("/{}/{}", service, method) } From 314ec95a5f37dd5acf63a388fd2d69a25c4bd9a8 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 23:16:51 +0800 Subject: [PATCH 10/15] async: add streaming support for the base component Added protocol support for streaming. Added `StreamInner` struct for streaming operations and Streaming related errors. 
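Editor's note: in wire terms the streaming support boils down to one new frame type (MESSAGE_TYPE_DATA) plus three flag bits. A quick sketch of the two kinds of DATA frames StreamSender emits, using the constants and constructors this patch adds to src/proto.rs (illustration only, assuming these items are publicly exported like the existing header helpers):

    use ttrpc::proto::{GenMessage, MessageHeader, FLAG_NO_DATA, FLAG_REMOTE_CLOSED};

    // A frame carrying one chunk of stream data.
    fn data_frame(stream_id: u32, payload: Vec<u8>) -> GenMessage {
        GenMessage {
            header: MessageHeader::new_data(stream_id, payload.len() as u32),
            payload,
        }
    }

    // A half-close: an empty DATA frame whose flags tell the peer that this
    // side will send no more data (what StreamSender::close_send writes).
    fn half_close_frame(stream_id: u32) -> GenMessage {
        let mut header = MessageHeader::new_data(stream_id, 0);
        header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA);
        GenMessage {
            header,
            payload: Vec::new(),
        }
    }

On the receive side, StreamReceiver treats FLAG_REMOTE_CLOSED as end of the peer's stream and FLAG_NO_DATA as "this close frame carries no payload", which is why close_send sets both.
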
Signed-off-by: wllenyj --- src/asynchronous/stream.rs | 184 ++++++++++++++++++++++++++++++++++++- src/error.rs | 9 ++ src/proto.rs | 21 ++++- 3 files changed, 211 insertions(+), 3 deletions(-) diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 03d10756..404a73be 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -3,13 +3,193 @@ // SPDX-License-Identifier: Apache-2.0 // +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + use tokio::sync::mpsc; -use crate::error::Result; -use crate::proto::GenMessage; +use crate::error::{Error, Result}; +use crate::proto::{ + Code, Codec, GenMessage, MessageHeader, Response, FLAG_NO_DATA, FLAG_REMOTE_CLOSED, + MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE, +}; pub type MessageSender = mpsc::Sender; pub type MessageReceiver = mpsc::Receiver; pub type ResultSender = mpsc::Sender>; pub type ResultReceiver = mpsc::Receiver>; + +async fn _recv(rx: &mut ResultReceiver) -> Result { + rx.recv() + .await + .unwrap_or_else(|| Err(Error::Others("Receive packet from recver error".to_string()))) +} + +async fn _send(tx: &MessageSender, msg: GenMessage) -> Result<()> { + tx.send(msg) + .await + .map_err(|e| Error::Others(format!("Send data packet to sender error {:?}", e))) +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Kind { + Client, + Server, +} + +#[derive(Debug)] +pub struct StreamInner { + sender: StreamSender, + receiver: StreamReceiver, +} + +impl StreamInner { + pub fn new( + stream_id: u32, + tx: MessageSender, + rx: ResultReceiver, + //waiter: shutdown::Waiter, + sendable: bool, + recveivable: bool, + kind: Kind, + streams: Arc>>, + ) -> Self { + Self { + sender: StreamSender { + tx, + stream_id, + sendable, + local_closed: Arc::new(AtomicBool::new(false)), + kind, + }, + receiver: StreamReceiver { + rx, + stream_id, + recveivable, + remote_closed: false, + kind, + streams, + }, + } + } + + fn split(self) -> (StreamSender, StreamReceiver) { + (self.sender, self.receiver) + } + + pub async fn send(&self, buf: Vec) -> Result<()> { + self.sender.send(buf).await + } + + pub async fn close_send(&self) -> Result<()> { + self.sender.close_send().await + } + + pub async fn recv(&mut self) -> Result> { + self.receiver.recv().await + } +} + +#[derive(Clone, Debug)] +pub struct StreamSender { + tx: MessageSender, + stream_id: u32, + sendable: bool, + local_closed: Arc, + kind: Kind, +} + +#[derive(Debug)] +pub struct StreamReceiver { + rx: ResultReceiver, + stream_id: u32, + recveivable: bool, + remote_closed: bool, + kind: Kind, + streams: Arc>>, +} + +impl Drop for StreamReceiver { + fn drop(&mut self) { + self.streams.lock().unwrap().remove(&self.stream_id); + } +} + +impl StreamSender { + pub async fn send(&self, buf: Vec) -> Result<()> { + debug_assert!(self.sendable); + if self.local_closed.load(Ordering::Relaxed) { + debug_assert_eq!(self.kind, Kind::Client); + return Err(Error::LocalClosed); + } + let header = MessageHeader::new_data(self.stream_id, buf.len() as u32); + let msg = GenMessage { + header, + payload: buf, + }; + _send(&self.tx, msg).await?; + + Ok(()) + } + + pub async fn close_send(&self) -> Result<()> { + debug_assert_eq!(self.kind, Kind::Client); + debug_assert!(self.sendable); + if self.local_closed.load(Ordering::Relaxed) { + return Err(Error::LocalClosed); + } + let mut header = MessageHeader::new_data(self.stream_id, 0); + header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA); + let msg = GenMessage { + header, + payload: Vec::new(), 
+ }; + _send(&self.tx, msg).await?; + self.local_closed.store(true, Ordering::Relaxed); + Ok(()) + } +} + +impl StreamReceiver { + pub async fn recv(&mut self) -> Result> { + if self.remote_closed { + return Err(Error::RemoteClosed); + } + let msg = _recv(&mut self.rx).await?; + let payload = match msg.header.type_ { + MESSAGE_TYPE_RESPONSE => { + debug_assert_eq!(self.kind, Kind::Client); + self.remote_closed = true; + let resp = Response::decode(&msg.payload) + .map_err(err_to_others_err!(e, "Decode message failed."))?; + if let Some(status) = resp.status.as_ref() { + if status.get_code() != Code::OK { + return Err(Error::RpcStatus((*status).clone())); + } + } + resp.payload + } + MESSAGE_TYPE_DATA => { + if !self.recveivable { + self.remote_closed = true; + return Err(Error::Others( + "received data from non-streaming server.".to_string(), + )); + } + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED { + self.remote_closed = true; + if (msg.header.flags & FLAG_NO_DATA) == FLAG_NO_DATA { + return Err(Error::Eof); + } + } + msg.payload + } + _ => { + return Err(Error::Others("not support".to_string())); + } + }; + Ok(payload) + } +} diff --git a/src/error.rs b/src/error.rs index 1e9afc21..4f7e2e3a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,6 +30,15 @@ pub enum Error { #[error("Nix error: {0}")] Nix(#[from] nix::Error), + #[error("ttrpc err: local stream closed")] + LocalClosed, + + #[error("ttrpc err: remote stream closed")] + RemoteClosed, + + #[error("eof")] + Eof, + #[error("ttrpc err: {0}")] Others(String), } diff --git a/src/proto.rs b/src/proto.rs index be8c4add..65d477b4 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -21,6 +21,11 @@ pub const MESSAGE_LENGTH_MAX: usize = 4 << 20; pub const MESSAGE_TYPE_REQUEST: u8 = 0x1; pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2; +pub const MESSAGE_TYPE_DATA: u8 = 0x3; + +pub const FLAG_REMOTE_CLOSED: u8 = 0x1; +pub const FLAG_REMOTE_OPEN: u8 = 0x2; +pub const FLAG_NO_DATA: u8 = 0x4; /// Message header of ttrpc. #[derive(Default, Debug, Clone, Copy, PartialEq)] @@ -57,6 +62,7 @@ impl From for Vec { impl MessageHeader { /// Creates a request MessageHeader from stream_id and len. + /// /// Use the default message type MESSAGE_TYPE_REQUEST, and default flags 0. pub fn new_request(stream_id: u32, len: u32) -> Self { Self { @@ -68,7 +74,8 @@ impl MessageHeader { } /// Creates a response MessageHeader from stream_id and len. - /// Use the default message type MESSAGE_TYPE_RESPONSE, and default flags 0. + /// + /// Use the MESSAGE_TYPE_RESPONSE message type, and default flags 0. pub fn new_response(stream_id: u32, len: u32) -> Self { Self { length: len, @@ -78,6 +85,18 @@ impl MessageHeader { } } + /// Creates a data MessageHeader from stream_id and len. + /// + /// Use the MESSAGE_TYPE_DATA message type, and default flags 0. + pub fn new_data(stream_id: u32, len: u32) -> Self { + Self { + length: len, + stream_id, + type_: MESSAGE_TYPE_DATA, + flags: 0, + } + } + /// Set the stream_id of message using the given value. pub fn set_stream_id(&mut self, stream_id: u32) { self.stream_id = stream_id; From 1126e4356bb706244bd8882faa9b35748e8641c9 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 23:27:03 +0800 Subject: [PATCH 11/15] async: add streaming support for client. Added streaming support for client-side. 
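Editor's note on the client change below: stream ids now come from an AtomicU32 that starts at 1 and is bumped by 2 per call, so every client-initiated stream (unary or streaming) gets an odd id, matching the (stream_id % 2) != 1 check the server performs in HandlerContext::handle_msg. A tiny self-contained illustration of that allocation scheme:

    use std::sync::atomic::{AtomicU32, Ordering};

    fn main() {
        // Mirrors Client::next_stream_id: start at 1, advance by 2 per request.
        let next_stream_id = AtomicU32::new(1);
        let ids: Vec<u32> = (0..4)
            .map(|_| next_stream_id.fetch_add(2, Ordering::Relaxed))
            .collect();
        // Client-initiated stream ids are always odd.
        assert_eq!(ids, vec![1, 3, 5, 7]);
        println!("allocated stream ids: {:?}", ids);
    }
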
Signed-off-by: wanglei01 --- src/asynchronous/client.rs | 195 +++++++++++++++++++++++++------------ 1 file changed, 133 insertions(+), 62 deletions(-) diff --git a/src/asynchronous/client.rs b/src/asynchronous/client.rs index 63a92d45..bf6e44ed 100644 --- a/src/asynchronous/client.rs +++ b/src/asynchronous/client.rs @@ -1,3 +1,4 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 @@ -6,6 +7,7 @@ use std::collections::HashMap; use std::convert::TryInto; use std::os::unix::io::RawFd; +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use async_trait::async_trait; @@ -14,19 +16,23 @@ use tokio::{self, sync::mpsc, task}; use crate::common::client_connect; use crate::error::{Error, Result}; -use crate::proto::{Code, Codec, GenMessage, Message, Request, Response, MESSAGE_TYPE_RESPONSE}; +use crate::proto::{ + Code, Codec, GenMessage, Message, Request, Response, FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, + MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE, +}; use crate::r#async::connection::*; use crate::r#async::shutdown; -use crate::r#async::stream::{ResultReceiver, ResultSender}; +use crate::r#async::stream::{ + Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner, +}; use crate::r#async::utils; -type RequestSender = mpsc::Sender<(GenMessage, ResultSender)>; -type RequestReceiver = mpsc::Receiver<(GenMessage, ResultSender)>; - /// A ttrpc Client (async). #[derive(Clone)] pub struct Client { - req_tx: RequestSender, + req_tx: MessageSender, + next_stream_id: Arc, + streams: Arc>>, } impl Client { @@ -39,26 +45,40 @@ impl Client { pub fn new(fd: RawFd) -> Client { let stream = utils::new_unix_stream_from_raw_fd(fd); - let (req_tx, rx): (RequestSender, RequestReceiver) = mpsc::channel(100); + let (req_tx, rx): (MessageSender, MessageReceiver) = mpsc::channel(100); - let delegate = ClientBuilder { rx: Some(rx) }; + let req_map = Arc::new(Mutex::new(HashMap::new())); + let delegate = ClientBuilder { + rx: Some(rx), + streams: req_map.clone(), + }; let conn = Connection::new(stream, delegate); tokio::spawn(async move { conn.run().await }); - Client { req_tx } + Client { + req_tx, + next_stream_id: Arc::new(AtomicU32::new(1)), + streams: req_map, + } } /// Requsts a unary request and returns with response. pub async fn request(&self, req: Request) -> Result { let timeout_nano = req.timeout_nano; - let msg: GenMessage = Message::new_request(0, req) + let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed); + + let msg: GenMessage = Message::new_request(stream_id, req) .try_into() .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100); + + // TODO: check return. + self.streams.lock().unwrap().insert(stream_id, tx); + self.req_tx - .send((msg, tx)) + .send(msg) .await .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; @@ -87,6 +107,44 @@ impl Client { Ok(res) } + + /// Creates a StreamInner instance. 
+ pub async fn new_stream( + &self, + req: Request, + streaming_client: bool, + streaming_server: bool, + ) -> Result { + let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed); + + let mut msg: GenMessage = Message::new_request(stream_id, req) + .try_into() + .map_err(|e: protobuf::error::ProtobufError| Error::Others(e.to_string()))?; + + if streaming_client { + msg.header.add_flags(FLAG_REMOTE_OPEN); + } else { + msg.header.add_flags(FLAG_REMOTE_CLOSED); + } + + let (tx, rx): (ResultSender, ResultReceiver) = mpsc::channel(100); + // TODO: check return + self.streams.lock().unwrap().insert(stream_id, tx); + self.req_tx + .send(msg) + .await + .map_err(|e| Error::Others(format!("Send packet to sender error {:?}", e)))?; + + Ok(StreamInner::new( + stream_id, + self.req_tx.clone(), + rx, + streaming_client, + streaming_server, + Kind::Client, + self.streams.clone(), + )) + } } struct ClientClose { @@ -104,7 +162,8 @@ impl Drop for ClientClose { #[derive(Debug)] struct ClientBuilder { - rx: Option, + rx: Option, + streams: Arc>>, } impl Builder for ClientBuilder { @@ -113,52 +172,43 @@ impl Builder for ClientBuilder { fn build(&mut self) -> (Self::Reader, Self::Writer) { let (notifier, waiter) = shutdown::new(); - let req_map = Arc::new(Mutex::new(HashMap::new())); ( ClientReader { shutdown_waiter: waiter, - req_map: req_map.clone(), + streams: self.streams.clone(), }, ClientWriter { - stream_id: 1, rx: self.rx.take().unwrap(), shutdown_notifier: notifier, - req_map, + + streams: self.streams.clone(), }, ) } } struct ClientWriter { - stream_id: u32, - rx: RequestReceiver, + rx: MessageReceiver, shutdown_notifier: shutdown::Notifier, - req_map: Arc>>, + + streams: Arc>>, } #[async_trait] impl WriterDelegate for ClientWriter { async fn recv(&mut self) -> Option { - if let Some((mut msg, resp_tx)) = self.rx.recv().await { - let current_stream_id = self.stream_id; - msg.header.set_stream_id(current_stream_id); - self.stream_id += 2; - { - let mut map = self.req_map.lock().unwrap(); - map.insert(current_stream_id, resp_tx); - } - return Some(msg); - } else { - return None; - } + self.rx.recv().await } async fn disconnect(&self, msg: &GenMessage, e: Error) { + // TODO: + // At this point, a new request may have been received. let resp_tx = { - let mut map = self.req_map.lock().unwrap(); + let mut map = self.streams.lock().unwrap(); map.remove(&msg.header.stream_id) }; + // TODO: if None if let Some(resp_tx) = resp_tx { let e = Error::Socket(format!("{:?}", e)); resp_tx @@ -174,8 +224,8 @@ impl WriterDelegate for ClientWriter { } struct ClientReader { + streams: Arc>>, shutdown_waiter: shutdown::Waiter, - req_map: Arc>>, } #[async_trait] @@ -191,8 +241,8 @@ impl ReaderDelegate for ClientReader { let _ = sender.await; // Take all items out of `req_map`. - let mut map = std::mem::take(&mut *self.req_map.lock().unwrap()); - // Terminate outstanding RPC requests with the error. + let mut map = std::mem::take(&mut *self.streams.lock().unwrap()); + // Terminate undone RPC requests with the error. 
for (_stream_id, resp_tx) in map.drain() { if let Err(_e) = resp_tx.send(Err(e.clone())).await { warn!("Failed to terminate pending RPC: the request has returned"); @@ -203,35 +253,56 @@ impl ReaderDelegate for ClientReader { async fn exit(&self) {} async fn handle_msg(&self, msg: GenMessage) { - let req_map = self.req_map.clone(); + let req_map = self.streams.clone(); tokio::spawn(async move { - let resp_tx2; - { - let mut map = req_map.lock().unwrap(); - let resp_tx = match map.get(&msg.header.stream_id) { - Some(tx) => tx, - None => { - debug!("Receiver got unknown packet {:?}", msg); - return; + let resp_tx = match msg.header.type_ { + MESSAGE_TYPE_RESPONSE => { + match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx, + None => { + debug!("Receiver got unknown response packet {:?}", msg); + return; + } } - }; - - resp_tx2 = resp_tx.clone(); - map.remove(&msg.header.stream_id); // Forget the result, just remove. - } - - if msg.header.type_ != MESSAGE_TYPE_RESPONSE { - resp_tx2 - .send(Err(Error::Others(format!( - "Recver got malformed packet {:?}", - msg - )))) - .await - .unwrap_or_else(|_e| error!("The request has returned")); - return; - } - - resp_tx2 + } + MESSAGE_TYPE_DATA => { + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED { + match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx.clone(), + None => { + debug!("Receiver got unknown data packet {:?}", msg); + return; + } + } + } else { + match req_map.lock().unwrap().get(&msg.header.stream_id) { + Some(tx) => tx.clone(), + None => { + debug!("Receiver got unknown data packet {:?}", msg); + return; + } + } + } + } + _ => { + let resp_tx = match req_map.lock().unwrap().remove(&msg.header.stream_id) { + Some(tx) => tx, + None => { + debug!("Receiver got unknown packet {:?}", msg); + return; + } + }; + resp_tx + .send(Err(Error::Others(format!( + "Recver got malformed packet {:?}", + msg + )))) + .await + .unwrap_or_else(|_e| error!("The request has returned")); + return; + } + }; + resp_tx .send(Ok(msg)) .await .unwrap_or_else(|_e| error!("The request has returned")); From 78ffcb8ad7bdee2bf0229ef5ba7d3e03948bbbaa Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 23:42:45 +0800 Subject: [PATCH 12/15] async: add streaming support for server Added streaming support for server-side. Signed-off-by: wanglei01 --- src/asynchronous/mod.rs | 5 +- src/asynchronous/server.rs | 199 +++++++++++++++++++++++++++++++------ src/asynchronous/utils.rs | 10 ++ 3 files changed, 182 insertions(+), 32 deletions(-) diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index dd6e5fae..2e976b5f 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -15,9 +15,10 @@ mod connection; pub mod shutdown; mod unix_incoming; +pub use self::stream::{Kind, StreamInner}; #[doc(inline)] pub use crate::r#async::client::Client; #[doc(inline)] -pub use crate::r#async::server::Server; +pub use crate::r#async::server::{Server, Service}; #[doc(inline)] -pub use utils::{MethodHandler, TtrpcContext}; +pub use utils::{MethodHandler, StreamHandler, TtrpcContext}; diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index 1da858f8..a19dac30 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -1,3 +1,4 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. 
// Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 @@ -10,7 +11,7 @@ use std::os::unix::io::RawFd; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net::UnixListener as SysUnixListener; use std::result::Result as StdResult; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use async_trait::async_trait; @@ -34,22 +35,39 @@ use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; use crate::proto::{ - Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, - MESSAGE_TYPE_REQUEST, + Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA, + FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST, }; use crate::r#async::connection::*; use crate::r#async::shutdown; -use crate::r#async::stream::{MessageReceiver, MessageSender}; +use crate::r#async::stream::{ + Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner, +}; use crate::r#async::utils; -use crate::r#async::{MethodHandler, TtrpcContext}; +use crate::r#async::{MethodHandler, StreamHandler, TtrpcContext}; const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000); const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000); +pub struct Service { + pub methods: HashMap>, + pub streams: HashMap>, +} + +impl Service { + pub(crate) fn get_method(&self, name: &str) -> Option<&(dyn MethodHandler + Send + Sync)> { + self.methods.get(name).map(|b| b.as_ref()) + } + + pub(crate) fn get_stream(&self, name: &str) -> Option> { + self.streams.get(name).cloned() + } +} + /// A ttrpc Server (async). pub struct Server { listeners: Vec, - methods: Arc>>, + services: Arc>, domain: Option, shutdown: shutdown::Notifier, @@ -60,7 +78,7 @@ impl Default for Server { fn default() -> Self { Server { listeners: Vec::with_capacity(1), - methods: Arc::new(HashMap::new()), + services: Arc::new(HashMap::new()), domain: None, shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0, stop_listen_tx: None, @@ -105,12 +123,9 @@ impl Server { Ok(self) } - pub fn register_service( - mut self, - methods: HashMap>, - ) -> Server { - let mut_methods = Arc::get_mut(&mut self.methods).unwrap(); - mut_methods.extend(methods); + pub fn register_service(mut self, new: HashMap) -> Server { + let services = Arc::get_mut(&mut self.services).unwrap(); + services.extend(new); self } @@ -158,7 +173,7 @@ impl Server { I: Stream> + Unpin + Send + 'static + AsRawFd, S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { - let methods = self.methods.clone(); + let services = self.services.clone(); let shutdown_waiter = self.shutdown.subscribe(); @@ -172,13 +187,13 @@ impl Server { if let Some(conn) = conn { // Accept a new connection match conn { - Ok(stream) => { - let fd = stream.as_raw_fd(); + Ok(conn) => { + let fd = conn.as_raw_fd(); // spawn a connection handler, would not block spawn_connection_handler( fd, - stream, - methods.clone(), + conn, + services.clone(), shutdown_waiter.clone(), ).await; } @@ -244,14 +259,15 @@ impl Server { async fn spawn_connection_handler( fd: RawFd, conn: C, - methods: Arc>>, + services: Arc>, shutdown_waiter: shutdown::Waiter, ) where C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { let delegate = ServerBuilder { fd, - methods, + services, + streams: Arc::new(Mutex::new(HashMap::new())), shutdown_waiter, }; let conn = Connection::new(conn, delegate); @@ -279,7 +295,8 @@ impl AsRawFd for 
Server { struct ServerBuilder { fd: RawFd, - methods: Arc>>, + services: Arc>, + streams: Arc>>, shutdown_waiter: shutdown::Waiter, } @@ -296,7 +313,8 @@ impl Builder for ServerBuilder { ServerReader { fd: self.fd, tx, - methods: self.methods.clone(), + services: self.services.clone(), + streams: self.streams.clone(), server_shutdown: self.shutdown_waiter.clone(), handler_shutdown: disconnect_notifier, }, @@ -321,7 +339,8 @@ impl WriterDelegate for ServerWriter { struct ServerReader { fd: RawFd, tx: MessageSender, - methods: Arc>>, + services: Arc>, + streams: Arc>>, server_shutdown: shutdown::Waiter, handler_shutdown: shutdown::Notifier, } @@ -366,7 +385,8 @@ impl ServerReader { HandlerContext { fd: self.fd, tx: self.tx.clone(), - methods: self.methods.clone(), + services: self.services.clone(), + streams: self.streams.clone(), _handler_shutdown_waiter: self.handler_shutdown.subscribe(), } } @@ -375,7 +395,8 @@ impl ServerReader { struct HandlerContext { fd: RawFd, tx: MessageSender, - methods: Arc>>, + services: Arc>, + streams: Arc>>, // Used for waiting handler exit. _handler_shutdown_waiter: shutdown::Waiter, } @@ -406,11 +427,63 @@ impl HandlerContext { .ok(); } None => { - unimplemented!(); + let mut header = MessageHeader::new_data(stream_id, 0); + header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA); + let msg = GenMessage { + header, + payload: Vec::new(), + }; + + self.tx + .send(msg) + .await + .map_err(err_to_others_err!(e, "Send packet to sender error ")) + .ok(); } }, Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, }, + MESSAGE_TYPE_DATA => { + // TODO(wllenyj): Compatible with golang behavior. + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED + && !msg.payload.is_empty() + { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!( + "Stream id {}: data close message connot include data", + stream_id + ), + ), + ) + .await; + return; + } + let stream_tx = self.streams.lock().unwrap().get(&stream_id).cloned(); + if let Some(stream_tx) = stream_tx { + if let Err(e) = stream_tx.send(Ok(msg)).await { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!("Stream id {}: handling data error: {}", stream_id, e), + ), + ) + .await; + } + } else { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status(Code::INVALID_ARGUMENT, "Stream is no longer active"), + ) + .await; + } + } _ => { // TODO: else we must ignore this for future compat. log this? // TODO(wllenyj): Compatible with golang behavior. 
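Editor's note: the DATA arm added above is the server half of the stream plumbing. Each incoming data frame is routed by stream id to the per-stream channel registered when the stream was opened (see handle_stream in the next hunk), and a close frame may not carry a payload. A simplified, self-contained sketch of that routing rule, with plain types standing in for GenMessage and ResultSender:

    use std::collections::HashMap;
    use std::sync::mpsc::Sender;

    const FLAG_REMOTE_CLOSED: u8 = 0x1;

    // Mirrors the checks in HandlerContext::handle_msg for MESSAGE_TYPE_DATA.
    fn route_data_frame(
        streams: &HashMap<u32, Sender<Vec<u8>>>,
        stream_id: u32,
        flags: u8,
        payload: Vec<u8>,
    ) -> Result<(), String> {
        // A close frame must not carry data.
        if (flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED && !payload.is_empty() {
            return Err(format!(
                "Stream id {}: data close message cannot include data",
                stream_id
            ));
        }
        // The frame must belong to a stream that is still active.
        match streams.get(&stream_id) {
            Some(tx) => tx
                .send(payload)
                .map_err(|e| format!("Stream id {}: handling data error: {}", stream_id, e)),
            None => Err("Stream is no longer active".to_string()),
        }
    }

In the real code these error strings become INVALID_ARGUMENT statuses sent back on the connection rather than values returned to the caller.
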
@@ -432,12 +505,23 @@ impl HandlerContext { let req = &req_msg.payload; trace!("Got Message request {} {}", req.service, req.method); - let path = utils::get_path(&req.service, &req.method); - let method = self.methods.get(&path).ok_or_else(|| { - get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path)) + let srv = self.services.get(&req.service).ok_or_else(|| { + get_status( + Code::INVALID_ARGUMENT, + format!("{} service does not exist", &req.service), + ) })?; - return self.handle_method(method.as_ref(), req_msg).await; + if let Some(method) = srv.get_method(&req.method) { + return self.handle_method(method, req_msg).await; + } + if let Some(stream) = srv.get_stream(&req.method) { + return self.handle_stream(stream, req_msg).await; + } + Err(get_status( + Code::UNIMPLEMENTED, + format!("{} method", &req.method), + )) } async fn handle_method( @@ -484,6 +568,61 @@ impl HandlerContext { } } + async fn handle_stream( + &self, + stream: Arc, + req_msg: Message, + ) -> StdResult, Status> { + let stream_id = req_msg.header.stream_id; + let req = req_msg.payload; + let path = utils::get_path(&req.service, &req.method); + + let (tx, rx): (ResultSender, ResultReceiver) = channel(100); + let stream_tx = tx.clone(); + self.streams.lock().unwrap().insert(stream_id, tx); + + let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED; + let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN; + let si = StreamInner::new( + stream_id, + self.tx.clone(), + rx, + true, // TODO + true, + Kind::Server, + self.streams.clone(), + ); + + let ctx = TtrpcContext { + fd: self.fd, + mh: req_msg.header, + metadata: context::from_pb(&req.metadata), + timeout_nano: req.timeout_nano, + }; + + let task = spawn(async move { stream.handler(ctx, si).await }); + + if !req.payload.is_empty() { + // Fake the first data message. + let msg = GenMessage { + header: MessageHeader::new_data(stream_id, req.payload.len() as u32), + payload: req.payload, + }; + stream_tx.send(Ok(msg)).await.map_err(|e| { + error!("send stream data {} got error {:?}", path, &e); + get_status(Code::UNKNOWN, e) + })?; + } + task.await + .unwrap_or_else(|e| { + Err(Error::Others(format!( + "stream {} task got error {:?}", + path, e + ))) + }) + .map_err(|e| get_status(Code::UNKNOWN, e)) + } + async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { let payload = resp .encode() diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 7cb038b4..17716ae8 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -84,6 +84,16 @@ pub trait MethodHandler { async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result; } +/// Trait that implements handler which is a proxy to the stream (async). +#[async_trait] +pub trait StreamHandler { + async fn handler( + &self, + ctx: TtrpcContext, + stream: crate::r#async::StreamInner, + ) -> Result>; +} + /// The context of ttrpc (async). #[derive(Debug)] pub struct TtrpcContext { From 1cc2e5988bbce602ea59630717d81bd3608866c3 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 23:57:46 +0800 Subject: [PATCH 13/15] async: add generator supprot for streaming Added streaming support for generator. 
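Editor's note before the generator diff: after this patch the generated async service trait exposes one method shape per RPC kind, reconstructed here from the write_service signatures below. The Echo service and its EchoRequest/EchoReply message types are hypothetical; treat this as an orientation sketch, not verbatim generator output:

    use async_trait::async_trait;

    #[async_trait]
    pub trait Echo {
        // Unary: one request in, one reply out.
        async fn unary(
            &self,
            _ctx: &::ttrpc::r#async::TtrpcContext,
            _: EchoRequest,
        ) -> ::ttrpc::Result<EchoReply>;

        // Client streaming: the handler drains a receiver, then returns one reply.
        async fn client_stream(
            &self,
            _ctx: &::ttrpc::r#async::TtrpcContext,
            _: ::ttrpc::r#async::ServerStreamReceiver<EchoRequest>,
        ) -> ::ttrpc::Result<EchoReply>;

        // Server streaming: one request in, replies pushed through a sender.
        async fn server_stream(
            &self,
            _ctx: &::ttrpc::r#async::TtrpcContext,
            _: EchoRequest,
            _: ::ttrpc::r#async::ServerStreamSender<EchoReply>,
        ) -> ::ttrpc::Result<()>;

        // Duplex: a single bidirectional stream handle.
        async fn duplex(
            &self,
            _ctx: &::ttrpc::r#async::TtrpcContext,
            _: ::ttrpc::r#async::ServerStream<EchoReply, EchoRequest>,
        ) -> ::ttrpc::Result<()>;
    }

On the client side the same split shows up in the generated call signatures: client-streaming calls return a ClientStreamSender, server-streaming calls return a ClientStreamReceiver, and duplex calls return a ClientStream, each wrapped in ::ttrpc::Result.
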
Signed-off-by: wanglei01 --- compiler/src/codegen.rs | 315 +++++++++++++++++++++++-------------- src/asynchronous/mod.rs | 5 +- src/asynchronous/stream.rs | 287 +++++++++++++++++++++++++++++++++ src/asynchronous/utils.rs | 152 ++++++++++++++++++ 4 files changed, 636 insertions(+), 123 deletions(-) diff --git a/compiler/src/codegen.rs b/compiler/src/codegen.rs index 1f90d726..f9ec2b06 100644 --- a/compiler/src/codegen.rs +++ b/compiler/src/codegen.rs @@ -55,7 +55,6 @@ struct MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, } @@ -65,7 +64,6 @@ impl<'a> MethodGen<'a> { proto: &'a MethodDescriptorProto, package_name: String, service_name: String, - service_path: String, root_scope: &'a RootScope<'a>, customize: &'a Customize, ) -> MethodGen<'a> { @@ -73,7 +71,6 @@ impl<'a> MethodGen<'a> { proto, package_name, service_name, - service_path, root_scope, customize, } @@ -127,10 +124,6 @@ impl<'a> MethodGen<'a> { to_camel_case(self.proto.get_name()) } - fn fq_name(&self) -> String { - format!("\"{}/{}\"", self.service_path, &self.proto.get_name()) - } - fn const_method_name(&self) -> String { format!( "METHOD_{}_{}", @@ -139,36 +132,13 @@ impl<'a> MethodGen<'a> { ) } - fn write_definition(&self, w: &mut CodeWriter) { - let head = format!( - "const {}: {}<{}, {}> = {} {{", - self.const_method_name(), - fq_grpc("Method"), - self.input(), - self.output(), - fq_grpc("Method") - ); - let pb_mar = format!( - "{} {{ ser: {}, de: {} }}", - fq_grpc("Marshaller"), - fq_grpc("pb_ser"), - fq_grpc("pb_de") - ); - w.block(&head, "};", |w| { - w.field_entry("ty", &self.method_type().1); - w.field_entry("name", &self.fq_name()); - w.field_entry("req_mar", &pb_mar); - w.field_entry("resp_mar", &pb_mar); - }); - } - fn write_handler(&self, w: &mut CodeWriter) { w.block( &format!("struct {}Method {{", self.struct_name()), "}", |w| { w.write_line(&format!( - "service: Arc>,", + "service: Arc>,", self.service_name )); }, @@ -197,16 +167,55 @@ impl<'a> MethodGen<'a> { fn write_handler_impl_async(&self, w: &mut CodeWriter) { w.write_line("#[async_trait]"); - w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", - |w| { - w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<(u32, Vec)> {", "}", - |w| { - w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", + match self.method_type().0 { + MethodType::Unary => { + w.block(&format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});", proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), self.root_scope.find_message(self.proto.get_input_type()).rust_name(), self.name())); + }); }); - }); + } + // only receive + MethodType::ClientStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_client_streamimg_handler!(self, ctx, 
inner, {});", + self.name())); + }); + }); + } + // only send + MethodType::ServerStreaming => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, mut inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_server_streamimg_handler!(self, ctx, inner, {}, {}, {});", + proto_path_to_rust_mod(self.root_scope.find_message(self.proto.get_input_type()).get_scope().get_file_descriptor().get_name()), + self.root_scope.find_message(self.proto.get_input_type()).rust_name(), + self.name())); + }); + }); + } + // receive and send + MethodType::Duplex => { + w.block(&format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}", + |w| { + w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}", + |w| { + w.write_line(&format!("::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, {});", + self.name())); + }); + }); + } + } }
// Method signatures @@ -222,73 +231,33 @@ impl<'a> MethodGen<'a> { fn client_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), + fq_grpc("r#async::ClientStreamSender"), self.input(), - fq_grpc("ClientCStreamReceiver"), - self.output() - ) - } - - fn client_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientCStreamSender"), - self.input(), - fq_grpc("ClientCStreamReceiver"), self.output() ) } fn server_streaming(&self, method_name: &str) -> String { format!( - "{}(&self, req: &{}) -> {}<{}<{}>>", + "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}<{}>>", method_name, self.input(), fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), - self.output() - ) - } - - fn server_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, req: &{}, opt: {}) -> {}<{}<{}>>", - method_name, - self.input(), - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientSStreamReceiver"), + fq_grpc("r#async::ClientStreamReceiver"), self.output() ) } fn duplex_streaming(&self, method_name: &str) -> String { format!( - "{}(&self) -> {}<({}<{}>, {}<{}>)>", + "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>", method_name, fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), + fq_grpc("r#async::ClientStream"), self.input(), - fq_grpc("ClientDuplexReceiver"), - self.output() - ) - } - - fn duplex_streaming_opt(&self, method_name: &str) -> String { - format!( - "{}_opt(&self, opt: {}) -> {}<({}<{}>, {}<{}>)>", - method_name, - fq_grpc("CallOption"), - fq_grpc("Result"), - fq_grpc("ClientDuplexSender"), - self.input(), - fq_grpc("ClientDuplexReceiver"), self.output() ) }
@@ -317,7 +286,7 @@ impl<'a> MethodGen<'a> { fn write_async_client(&self, w: &mut CodeWriter) { let method_name = self.name(); match self.method_type().0 { - // Unary + // Unary RPC MethodType::Unary => { pub_async_fn(w, &self.unary(&method_name), |w| { w.write_line(&format!("let mut cres = {}::new();", self.output())); @@ -329,28 +298,79 @@ impl<'a> MethodGen<'a> { )); }); } - - _ => {} + // Client Streaming RPC + MethodType::ClientStreaming => { + pub_async_fn(w,
&self.client_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_send!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // Server Streaming RPC + MethodType::ServerStreaming => { + pub_async_fn(w, &self.server_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream_receive!(self, ctx, req, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } + // Bidirectional streaming RPC + MethodType::Duplex => { + pub_async_fn(w, &self.duplex_streaming(&method_name), |w| { + w.write_line(&format!( + "::ttrpc::async_client_stream!(self, ctx, \"{}.{}\", \"{}\");", + self.package_name, + self.service_name, + &self.proto.get_name(), + )); + }); + } }; }
fn write_service(&self, w: &mut CodeWriter) { - let req_stream_type = format!("{}<{}>", fq_grpc("RequestStream"), self.input()); - let (req, req_type, _resp_type) = match self.method_type().0 { - MethodType::Unary => ("req", self.input(), "UnarySink"), - MethodType::ClientStreaming => ("stream", req_stream_type, "ClientStreamingSink"), - MethodType::ServerStreaming => ("req", self.input(), "ServerStreamingSink"), - MethodType::Duplex => ("stream", req_stream_type, "DuplexSink"), + let (_req, req_type, resp_type) = match self.method_type().0 { + MethodType::Unary => ("req", self.input(), self.output()), + MethodType::ClientStreaming => ( + "stream", + format!("::ttrpc::r#async::ServerStreamReceiver<{}>", self.input()), + self.output(), + ), + MethodType::ServerStreaming => ( + "req", + format!( + "{}, _: {}<{}>", + self.input(), + "::ttrpc::r#async::ServerStreamSender", + self.output() + ), + "()".to_string(), + ), + MethodType::Duplex => ( + "stream", + format!( + "{}<{}, {}>", + "::ttrpc::r#async::ServerStream", + self.output(), + self.input(), + ), + "()".to_string(), + ), }; let get_sig = |context_name| { format!( - "{}(&self, _ctx: &{}, _{}: {}) -> ::ttrpc::Result<{}>", + "{}(&self, _ctx: &{}, _: {}) -> ::ttrpc::Result<{}>", self.name(), fq_grpc(context_name), - req, req_type, - self.output() + resp_type, ) }; @@ -370,20 +390,38 @@ impl<'a> MethodGen<'a> { }
fn write_bind(&self, w: &mut CodeWriter) { - let mut method_handler_name = "::ttrpc::MethodHandler"; + let method_handler_name = "::ttrpc::MethodHandler"; - if async_on(self.customize, "server") { - method_handler_name = "::ttrpc::r#async::MethodHandler"; - } + let s = format!( + "methods.insert(\"/{}.{}/{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as Box<dyn {} + Send + Sync>);", + self.package_name, + self.service_name, + self.proto.get_name(), + self.struct_name(), + method_handler_name, + ); + w.write_line(&s); + } - let s = format!("methods.insert(\"/{}.{}/{}\".to_string(), - std::boxed::Box::new({}Method{{service: service.clone()}}) as std::boxed::Box<dyn {} + Send + Sync>);", - self.package_name, - self.service_name, - self.proto.get_name(), - self.struct_name(), - method_handler_name, - ); + fn write_async_bind(&self, w: &mut CodeWriter) { + let s = if matches!(self.method_type().0, MethodType::Unary) { + format!( + "methods.insert(\"{}\".to_string(), + Box::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Box<dyn ::ttrpc::r#async::MethodHandler + Send + Sync>" + ) + } else { + format!( + "streams.insert(\"{}\".to_string(), + Arc::new({}Method{{service: service.clone()}}) as {});", + self.proto.get_name(), + self.struct_name(), + "Arc<dyn ::ttrpc::r#async::StreamHandler + Send + Sync>" + ) }; w.write_line(&s); } } @@ -392,6 +430,7 @@ struct ServiceGen<'a> { proto: &'a
ServiceDescriptorProto, methods: Vec<MethodGen<'a>>, customize: &'a Customize, + package_name: String, }
impl<'a> ServiceGen<'a> { @@ -401,11 +440,6 @@ impl<'a> ServiceGen<'a> { root_scope: &'a RootScope, customize: &'a Customize, ) -> ServiceGen<'a> { - let service_path = if file.get_package().is_empty() { - format!("{}", proto.get_name()) - } else { - format!("{}.{}", file.get_package(), proto.get_name()) - }; let methods = proto .get_method() .iter() @@ -414,7 +448,6 @@ impl<'a> ServiceGen<'a> { m, file.get_package().to_string(), util::to_camel_case(proto.get_name()), - service_path.clone(), root_scope, &customize, ) @@ -425,6 +458,7 @@ impl<'a> ServiceGen<'a> { proto, methods, customize, + package_name: file.get_package().to_string(), } }
@@ -432,10 +466,20 @@ impl<'a> ServiceGen<'a> { util::to_camel_case(self.proto.get_name()) } + fn service_path(&self) -> String { + format!("{}.{}", self.package_name, self.service_name()) + } + fn client_name(&self) -> String { format!("{}Client", self.service_name()) } + fn has_stream_method(&self) -> bool { + self.methods + .iter() + .any(|method| !matches!(method.method_type().0, MethodType::Unary)) + } + fn write_client(&self, w: &mut CodeWriter) { if async_on(self.customize, "client") { self.write_async_client(w)
@@ -490,11 +534,9 @@ impl<'a> ServiceGen<'a> { fn write_server(&self, w: &mut CodeWriter) { let mut trait_name = self.service_name(); - let mut method_handler_name = "::ttrpc::MethodHandler"; if async_on(self.customize, "server") { w.write_line("#[async_trait]"); trait_name = format!("{}: Sync", &self.service_name()); - method_handler_name = "::ttrpc::r#async::MethodHandler"; } w.pub_trait(&trait_name.to_owned(), |w| { @@ -504,9 +546,17 @@ impl<'a> ServiceGen<'a> { }); w.write_line(""); + if async_on(self.customize, "server") { + self.write_async_server_create(w); + } else { + self.write_sync_server_create(w); + } + } + fn write_sync_server_create(&self, w: &mut CodeWriter) { + let method_handler_name = "::ttrpc::MethodHandler"; let s = format!( - "create_{}(service: Arc<std::boxed::Box<dyn {} + Send + Sync>>) -> HashMap <String, Box<dyn {} + Send + Sync>>", + "create_{}(service: Arc<Box<dyn {} + Send + Sync>>) -> HashMap<String, Box<dyn {} + Send + Sync>>", to_snake_case(&self.service_name()), self.service_name(), method_handler_name, ); @@ -523,14 +573,35 @@ impl<'a> ServiceGen<'a> { }); }
- fn write_method_definitions(&self, w: &mut CodeWriter) { - for (i, method) in self.methods.iter().enumerate() { - if i != 0 { + fn write_async_server_create(&self, w: &mut CodeWriter) { + let s = format!( + "create_{}(service: Arc<Box<dyn {} + Send + Sync>>) -> HashMap<String, {}>", + to_snake_case(&self.service_name()), + self.service_name(), + "::ttrpc::r#async::Service" + ); + + let has_stream_method = self.has_stream_method(); + w.pub_fn(&s, |w| { + w.write_line("let mut ret = HashMap::new();"); + w.write_line("let mut methods = HashMap::new();"); + if has_stream_method { + w.write_line("let mut streams = HashMap::new();"); + } else { + w.write_line("let streams = HashMap::new();"); + } + for method in &self.methods[0..self.methods.len()] { w.write_line(""); + method.write_async_bind(w); } - - method.write_definition(w); - } + w.write_line(""); + w.write_line(format!( + "ret.insert(\"{}\".to_string(), {});", + self.service_path(), + "::ttrpc::r#async::Service{ methods, streams }" + )); + w.write_line("ret"); + }); } fn write_method_handlers(&self, w: &mut CodeWriter) {
diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index 2e976b5f..fb6d39cb 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -15,7 +15,10 @@ mod connection; pub mod shutdown; mod unix_incoming; -pub use self::stream::{Kind,
StreamInner}; +pub use self::stream::{ + ClientStream, ClientStreamReceiver, ClientStreamSender, Kind, ServerStream, + ServerStreamReceiver, ServerStreamSender, StreamInner, +}; #[doc(inline)] pub use crate::r#async::client::Client; #[doc(inline)]
diff --git a/src/asynchronous/stream.rs b/src/asynchronous/stream.rs index 404a73be..fdf90341 100644 --- a/src/asynchronous/stream.rs +++ b/src/asynchronous/stream.rs @@ -1,9 +1,11 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 // use std::collections::HashMap; +use std::marker::PhantomData; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; @@ -21,6 +23,291 @@ pub type MessageReceiver = mpsc::Receiver<GenMessage>; pub type ResultSender = mpsc::Sender<Result<GenMessage>>; pub type ResultReceiver = mpsc::Receiver<Result<GenMessage>>;
+#[derive(Debug)] +pub struct ClientStream<Q, P> { + tx: CSSender<Q>, + rx: CSReceiver<P>, +} + +impl<Q, P> ClientStream<Q, P> +where + Q: Codec, + P: Codec, + <Q as Codec>::E: std::fmt::Display, + <P as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + let (tx, rx) = inner.split(); + Self { + tx: CSSender { + tx, + _send: PhantomData, + }, + rx: CSReceiver { + rx, + _recv: PhantomData, + }, + } + } + + pub fn split(self) -> (CSSender<Q>, CSReceiver<P>) { + (self.tx, self.rx) + } + + pub async fn send(&self, req: &Q) -> Result<()> { + self.tx.send(req).await + } + + pub async fn close_send(&self) -> Result<()> { + self.tx.close_send().await + } + + pub async fn recv(&mut self) -> Result<P> { + self.rx.recv().await + } +} +
+#[derive(Clone, Debug)] +pub struct CSSender<Q> { + tx: StreamSender, + _send: PhantomData<Q>, +} + +impl<Q> CSSender<Q> +where + Q: Codec, + <Q as Codec>::E: std::fmt::Display, +{ + pub async fn send(&self, req: &Q) -> Result<()> { + let msg_buf = req + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.tx.send(msg_buf).await + } + + pub async fn close_send(&self) -> Result<()> { + self.tx.close_send().await + } +} +
+#[derive(Debug)] +pub struct CSReceiver<P> { + rx: StreamReceiver, + _recv: PhantomData<P>, +} + +impl<P> CSReceiver<P> +where + P: Codec, + <P as Codec>::E: std::fmt::Display, +{ + pub async fn recv(&mut self) -> Result<P> { + let msg_buf = self.rx.recv().await?; + P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed.")) + } +} +
+#[derive(Debug)] +pub struct ServerStream<P, Q> { + tx: SSSender<P>, + rx: SSReceiver<Q>, +} + +impl<P, Q> ServerStream<P, Q> +where + P: Codec, + Q: Codec, + <P as Codec>::E: std::fmt::Display, + <Q as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + let (tx, rx) = inner.split(); + Self { + tx: SSSender { + tx, + _send: PhantomData, + }, + rx: SSReceiver { + rx, + _recv: PhantomData, + }, + } + } + + pub fn split(self) -> (SSSender<P>, SSReceiver<Q>) { + (self.tx, self.rx) + } + + pub async fn send(&self, resp: &P) -> Result<()> { + self.tx.send(resp).await + } + + pub async fn recv(&mut self) -> Result<Option<Q>> { + self.rx.recv().await + } +} +
+#[derive(Clone, Debug)] +pub struct SSSender<P> { + tx: StreamSender, + _send: PhantomData<P>, +} + +impl<P> SSSender<P> +where + P: Codec, + <P as Codec>::E: std::fmt::Display, +{ + pub async fn send(&self, resp: &P) -> Result<()> { + let msg_buf = resp + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.tx.send(msg_buf).await + } +} +
+#[derive(Debug)] +pub struct SSReceiver<Q> { + rx: StreamReceiver, + _recv: PhantomData<Q>, +} + +impl<Q> SSReceiver<Q> +where + Q: Codec, + <Q as Codec>::E: std::fmt::Display, +{ + pub async fn recv(&mut self) -> Result<Option<Q>> { + let res = self.rx.recv().await; + + if matches!(res, Err(Error::Eof)) { + return Ok(None); + } + let msg_buf = res?; + Q::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} +
+pub struct ClientStreamSender<Q, P> { + inner: StreamInner, + _send: PhantomData<Q>, + _recv: PhantomData<P>, +} + +impl<Q, P> ClientStreamSender<Q, P> +where + Q: Codec, + P: Codec, + <Q as Codec>::E: std::fmt::Display, + <P as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner, + _send: PhantomData, + _recv: PhantomData, + } + } + + pub async fn send(&self, req: &Q) -> Result<()> { + let msg_buf = req + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.inner.send(msg_buf).await + } + + pub async fn close_and_recv(&mut self) -> Result<P> { + self.inner.close_send().await?; + let msg_buf = self.inner.recv().await?; + P::decode(&msg_buf).map_err(err_to_others_err!(e, "Decode message failed.")) + } +} +
+pub struct ServerStreamSender<P> { + inner: StreamSender, + _send: PhantomData<P>, +} + +impl<P> ServerStreamSender<P> +where + P: Codec, + <P as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().0, + _send: PhantomData, + } + } + + pub async fn send(&self, resp: &P) -> Result<()> { + let msg_buf = resp + .encode() + .map_err(err_to_others_err!(e, "Encode message failed."))?; + self.inner.send(msg_buf).await + } +} +
+pub struct ClientStreamReceiver<P> { + inner: StreamReceiver, + _recv: PhantomData<P>, +} + +impl<P> ClientStreamReceiver<P> +where + P: Codec, + <P as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().1, + _recv: PhantomData, + } + } + + pub async fn recv(&mut self) -> Result<Option<P>> { + let res = self.inner.recv().await; + if matches!(res, Err(Error::Eof)) { + return Ok(None); + } + let msg_buf = res?; + P::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} +
+pub struct ServerStreamReceiver<Q> { + inner: StreamReceiver, + _recv: PhantomData<Q>, +} + +impl<Q> ServerStreamReceiver<Q> +where + Q: Codec, + <Q as Codec>::E: std::fmt::Display, +{ + pub fn new(inner: StreamInner) -> Self { + Self { + inner: inner.split().1, + _recv: PhantomData, + } + } + + pub async fn recv(&mut self) -> Result<Option<Q>> { + let res = self.inner.recv().await; + if matches!(res, Err(Error::Eof)) { + return Ok(None); + } + let msg_buf = res?; + Q::decode(&msg_buf) + .map_err(err_to_others_err!(e, "Decode message failed.")) + .map(Some) + } +} +
 async fn _recv(rx: &mut ResultReceiver) -> Result<GenMessage> { rx.recv() .await
diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 17716ae8..2e2555d1 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -1,3 +1,4 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 @@ -50,6 +51,96 @@ macro_rules! async_request_handler { }; }
+/// Handle client streaming in async mode. +#[macro_export] +macro_rules! async_client_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => { + let stream = ::ttrpc::r#async::ServerStreamReceiver::new($inner); + let mut res = ::ttrpc::Response::new(); + match $class.service.$req_fn(&$ctx, stream).await { + Ok(rep) => { + res.set_status(::ttrpc::get_status(::ttrpc::Code::OK, "".to_string())); + res.payload.reserve(rep.compute_size() as usize); + let mut s = protobuf::CodedOutputStream::vec(&mut res.payload); + rep.write_to(&mut s) + .map_err(::ttrpc::err_to_others!(e, ""))?; + s.flush().map_err(::ttrpc::err_to_others!(e, ""))?; + } + Err(x) => match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + }, + } + return Ok(Some(res)); + }; +} +
+/// Handle server streaming in async mode. +#[macro_export] +macro_rules! async_server_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $server: ident, $req_type: ident, $req_fn: ident) => { + let req_buf = $inner.recv().await?; + let req = <super::$server::$req_type as ::ttrpc::proto::Codec>::decode(&req_buf) + .map_err(|e| ::ttrpc::Error::Others(e.to_string()))?; + let stream = ::ttrpc::r#async::ServerStreamSender::new($inner); + match $class.service.$req_fn(&$ctx, req, stream).await { + Ok(_) => { + return Ok(None); + } + Err(x) => { + let mut res = ::ttrpc::Response::new(); + match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + } + return Ok(Some(res)); + } + } + }; +} +
+/// Handle duplex streaming in async mode. +#[macro_export] +macro_rules!
async_duplex_streamimg_handler { + ($class: ident, $ctx: ident, $inner: ident, $req_fn: ident) => { + let stream = ::ttrpc::r#async::ServerStream::new($inner); + match $class.service.$req_fn(&$ctx, stream).await { + Ok(_) => { + return Ok(None); + } + Err(x) => { + let mut res = ::ttrpc::Response::new(); + match x { + ::ttrpc::Error::RpcStatus(s) => { + res.set_status(s); + } + _ => { + res.set_status(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + format!("{:?}", x), + )); + } + } + return Ok(Some(res)); + } + } + }; +} + /// Send request through async client. #[macro_export] macro_rules! async_client_request { @@ -78,6 +169,67 @@ macro_rules! async_client_request { }; }
+/// Duplex streaming through async client. +#[macro_export] +macro_rules! async_client_stream { + ($self: ident, $ctx: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + + let inner = $self.client.new_stream(creq, true, true).await?; + let stream = ::ttrpc::r#async::ClientStream::new(inner); + + return Ok(stream); + }; +} +
+/// Only send streaming through async client. +#[macro_export] +macro_rules! async_client_stream_send { + ($self: ident, $ctx: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + + let inner = $self.client.new_stream(creq, true, false).await?; + let stream = ::ttrpc::r#async::ClientStreamSender::new(inner); + + return Ok(stream); + }; +} +
+/// Only receive streaming through async client. +#[macro_export] +macro_rules! async_client_stream_receive { + ($self: ident, $ctx: ident, $req: ident, $server: expr, $method: expr) => { + let mut creq = ::ttrpc::Request::new(); + creq.set_service($server.to_string()); + creq.set_method($method.to_string()); + creq.set_timeout_nano($ctx.timeout_nano); + let md = ::ttrpc::context::to_pb($ctx.metadata); + creq.set_metadata(md); + creq.payload.reserve($req.compute_size() as usize); + { + let mut s = CodedOutputStream::vec(&mut creq.payload); + $req.write_to(&mut s) + .map_err(::ttrpc::err_to_others!(e, ""))?; + s.flush().map_err(::ttrpc::err_to_others!(e, ""))?; + } + + let inner = $self.client.new_stream(creq, false, true).await?; + let stream = ::ttrpc::r#async::ClientStreamReceiver::new(inner); + + return Ok(stream); + }; +} + /// Trait that implements handler which is a proxy to the desired method (async). #[async_trait] pub trait MethodHandler { async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result<Response>; }
From 6877207da53e16622bb88cedf050f674d1727ad3 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Tue, 7 Jun 2022 00:04:24 +0800 Subject: [PATCH 14/15] example: add streaming example This example is the same as the golang version. See `https://github.com/containerd/ttrpc/tree/main/integration/streaming` for details.
Signed-off-by: wllenyj --- example/Cargo.toml | 9 ++ example/Makefile | 2 + example/async-stream-client.rs | 174 +++++++++++++++++++++++ example/async-stream-server.rs | 170 ++++++++++++++++++++++ example/build.rs | 5 +- example/protocols/asynchronous/mod.rs | 2 + example/protocols/protos/streaming.proto | 49 +++++++ ttrpc-codegen/Cargo.toml | 2 +- 8 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 example/async-stream-client.rs create mode 100644 example/async-stream-server.rs create mode 100644 example/protocols/protos/streaming.proto diff --git a/example/Cargo.toml b/example/Cargo.toml index fe9a099f..c873be51 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -22,6 +22,7 @@ ttrpc = { path = "../", features = ["async"] } ctrlc = { version = "3.0", features = ["termination"] } tokio = { version = "1.0.1", features = ["signal", "time"] } async-trait = "0.1.42" +rand = "0.8.5" [[example]] @@ -40,5 +41,13 @@ path = "./async-server.rs" name = "async-client" path = "./async-client.rs" +[[example]] +name = "async-stream-server" +path = "./async-stream-server.rs" + +[[example]] +name = "async-stream-client" +path = "./async-stream-client.rs" + [build-dependencies] ttrpc-codegen = { path = "../ttrpc-codegen"} diff --git a/example/Makefile b/example/Makefile index c2d695c5..518a45bf 100644 --- a/example/Makefile +++ b/example/Makefile @@ -8,6 +8,8 @@ build: cargo build --example client cargo build --example async-server cargo build --example async-client + cargo build --example async-stream-server + cargo build --example async-stream-client .PHONY: deps deps: diff --git a/example/async-stream-client.rs b/example/async-stream-client.rs new file mode 100644 index 00000000..ba953596 --- /dev/null +++ b/example/async-stream-client.rs @@ -0,0 +1,174 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. 
+// +// SPDX-License-Identifier: Apache-2.0 +// + +mod protocols; +mod utils; + +use protocols::r#async::{empty, streaming, streaming_ttrpc}; +use ttrpc::context::{self, Context}; +use ttrpc::r#async::Client; + +#[tokio::main(flavor = "current_thread")] +async fn main() { + simple_logging::log_to_stderr(log::LevelFilter::Info); + + let c = Client::connect(utils::SOCK_ADDR).unwrap(); + let sc = streaming_ttrpc::StreamingClient::new(c); + + let _now = std::time::Instant::now(); + + let sc1 = sc.clone(); + let t1 = tokio::spawn(echo_request(sc1)); + + let sc1 = sc.clone(); + let t2 = tokio::spawn(echo_stream(sc1)); + + let sc1 = sc.clone(); + let t3 = tokio::spawn(sum_stream(sc1)); + + let sc1 = sc.clone(); + let t4 = tokio::spawn(divide_stream(sc1)); + + let sc1 = sc.clone(); + let t5 = tokio::spawn(echo_null(sc1)); + + let t6 = tokio::spawn(echo_null_stream(sc)); + + let _ = tokio::join!(t1, t2, t3, t4, t5, t6); +} + +fn default_ctx() -> Context { + let mut ctx = context::with_timeout(0); + ctx.add("key-1".to_string(), "value-1-1".to_string()); + ctx.add("key-1".to_string(), "value-1-2".to_string()); + ctx.set("key-2".to_string(), vec!["value-2".to_string()]); + + ctx +} + +async fn echo_request(cli: streaming_ttrpc::StreamingClient) { + let echo1 = streaming::EchoPayload { + seq: 1, + msg: "Echo Me".to_string(), + ..Default::default() + }; + let resp = cli.echo(default_ctx(), &echo1).await.unwrap(); + assert_eq!(resp.msg, echo1.msg); + assert_eq!(resp.seq, echo1.seq + 1); +} + +async fn echo_stream(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.echo_stream(default_ctx()).await.unwrap(); + + let mut i = 0; + while i < 100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: format!("{}: Echo in a stream", i), + ..Default::default() + }; + stream.send(&echo).await.unwrap(); + let resp = stream.recv().await.unwrap(); + assert_eq!(resp.msg, echo.msg); + assert_eq!(resp.seq, echo.seq + 1); + + i += 2; + } + stream.close_send().await.unwrap(); + let ret = stream.recv().await; + assert!(matches!(ret, Err(ttrpc::Error::Eof))); +} + +async fn sum_stream(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.sum_stream(default_ctx()).await.unwrap(); + + let mut sum = streaming::Sum::new(); + stream.send(&streaming::Part::new()).await.unwrap(); + + sum.num += 1; + let mut i = -99i32; + while i <= 100 { + let addi = streaming::Part { + add: i, + ..Default::default() + }; + stream.send(&addi).await.unwrap(); + sum.sum += i; + sum.num += 1; + + i += 1; + } + stream.send(&streaming::Part::new()).await.unwrap(); + sum.num += 1; + + let ssum = stream.close_and_recv().await.unwrap(); + assert_eq!(ssum.sum, sum.sum); + assert_eq!(ssum.num, sum.num); +} + +async fn divide_stream(cli: streaming_ttrpc::StreamingClient) { + let expected = streaming::Sum { + sum: 392, + num: 4, + ..Default::default() + }; + let mut stream = cli.divide_stream(default_ctx(), &expected).await.unwrap(); + + let mut actual = streaming::Sum::new(); + + // NOTE: `for part in stream.recv().await.unwrap()` can't work. 
+ while let Some(part) = stream.recv().await.unwrap() { + actual.sum += part.add; + actual.num += 1; + } + assert_eq!(actual.sum, expected.sum); + assert_eq!(actual.num, expected.num); +} + +async fn echo_null(cli: streaming_ttrpc::StreamingClient) { + let mut stream = cli.echo_null(default_ctx()).await.unwrap(); + + for i in 0..100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: "non-empty empty".to_string(), + ..Default::default() + }; + stream.send(&echo).await.unwrap(); + } + let res = stream.close_and_recv().await.unwrap(); + assert_eq!(res, empty::Empty::new()); +} + +async fn echo_null_stream(cli: streaming_ttrpc::StreamingClient) { + let stream = cli.echo_null_stream(default_ctx()).await.unwrap(); + + let (tx, mut rx) = stream.split(); + + let task = tokio::spawn(async move { + loop { + let ret = rx.recv().await; + if matches!(ret, Err(ttrpc::Error::Eof)) { + break; + } + } + }); + + for i in 0..100 { + let echo = streaming::EchoPayload { + seq: i as u32, + msg: "non-empty empty".to_string(), + ..Default::default() + }; + tx.send(&echo).await.unwrap(); + } + + tx.close_send().await.unwrap(); + + tokio::time::timeout(tokio::time::Duration::from_secs(10), task) + .await + .unwrap() + .unwrap(); +}
diff --git a/example/async-stream-server.rs b/example/async-stream-server.rs new file mode 100644 index 00000000..d828e2d3 --- /dev/null +++ b/example/async-stream-server.rs @@ -0,0 +1,170 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +mod protocols; +mod utils; + +use std::sync::Arc; + +use log::{info, LevelFilter}; + +use protocols::r#async::{empty, streaming, streaming_ttrpc}; +use ttrpc::asynchronous::Server; + +use async_trait::async_trait; +use tokio::signal::unix::{signal, SignalKind}; +use tokio::time::sleep; + +struct StreamingService; + +#[async_trait] +impl streaming_ttrpc::Streaming for StreamingService { + async fn echo( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut e: streaming::EchoPayload, + ) -> ::ttrpc::Result<streaming::EchoPayload> { + e.seq += 1; + Ok(e) + } + + async fn echo_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStream<streaming::EchoPayload, streaming::EchoPayload>, + ) -> ::ttrpc::Result<()> { + while let Some(mut e) = s.recv().await? { + e.seq += 1; + s.send(&e).await?; + } + + Ok(()) + } + + async fn sum_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStreamReceiver<streaming::Part>, + ) -> ::ttrpc::Result<streaming::Sum> { + let mut sum = streaming::Sum::new(); + while let Some(part) = s.recv().await? { + sum.sum += part.add; + sum.num += 1; + } + + Ok(sum) + } + + async fn divide_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + sum: streaming::Sum, + s: ::ttrpc::r#async::ServerStreamSender<streaming::Part>, + ) -> ::ttrpc::Result<()> { + let mut parts = vec![streaming::Part::new(); sum.num as usize]; + + let mut total = 0i32; + for i in 1..(sum.num - 2) { + let add = (rand::random::<u32>() % 1000) as i32 - 500; + parts[i as usize].add = add; + total += add; + } + + parts[sum.num as usize - 2].add = sum.sum - total; + + for part in parts { + s.send(&part).await.unwrap(); + } + + Ok(()) + } + + async fn echo_null( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + mut s: ::ttrpc::r#async::ServerStreamReceiver<streaming::EchoPayload>, + ) -> ::ttrpc::Result<empty::Empty> { + let mut seq = 0; + while let Some(e) = s.recv().await?
{ + assert_eq!(e.seq, seq); + assert_eq!(e.msg.as_str(), "non-empty empty"); + seq += 1; + } + Ok(empty::Empty::new()) + } + + async fn echo_null_stream( + &self, + _ctx: &::ttrpc::r#async::TtrpcContext, + s: ::ttrpc::r#async::ServerStream<empty::Empty, streaming::EchoPayload>, + ) -> ::ttrpc::Result<()> { + let msg = "non-empty empty".to_string(); + + let mut tasks = Vec::new(); + + let (tx, mut rx) = s.split(); + let mut seq = 0u32; + while let Some(e) = rx.recv().await? { + assert_eq!(e.seq, seq); + assert_eq!(e.msg, msg); + seq += 1; + + for _i in 0..10 { + let tx = tx.clone(); + tasks.push(tokio::spawn( + async move { tx.send(&empty::Empty::new()).await }, + )); + } + } + + for t in tasks { + t.await.unwrap().map_err(|e| { + ::ttrpc::Error::RpcStatus(::ttrpc::get_status( + ::ttrpc::Code::UNKNOWN, + e.to_string(), + )) + })?; + } + Ok(()) + } +} + +#[tokio::main(flavor = "current_thread")] +async fn main() { + simple_logging::log_to_stderr(LevelFilter::Info); + + let s = Box::new(StreamingService {}) as Box<dyn streaming_ttrpc::Streaming + Send + Sync>; + let s = Arc::new(s); + let service = streaming_ttrpc::create_streaming(s); + + utils::remove_if_sock_exist(utils::SOCK_ADDR).unwrap(); + + let mut server = Server::new() + .bind(utils::SOCK_ADDR) + .unwrap() + .register_service(service); + + let mut hangup = signal(SignalKind::hangup()).unwrap(); + let mut interrupt = signal(SignalKind::interrupt()).unwrap(); + server.start().await.unwrap(); + + tokio::select! { + _ = hangup.recv() => { + // test stop_listen -> start + info!("stop listen"); + server.stop_listen().await; + info!("start listen"); + server.start().await.unwrap(); + + // hold some time for the new test connection. + sleep(std::time::Duration::from_secs(100)).await; + } + _ = interrupt.recv() => { + // test graceful shutdown + info!("graceful shutdown"); + server.shutdown().await.unwrap(); + } + }; +}
diff --git a/example/build.rs b/example/build.rs index b5550e2b..7ebb9fb5 100644 --- a/example/build.rs +++ b/example/build.rs @@ -9,7 +9,7 @@ use ttrpc_codegen::Codegen; use ttrpc_codegen::Customize; fn main() { - let protos = vec![ + let mut protos = vec![ "protocols/protos/github.com/kata-containers/agent/pkg/types/types.proto", "protocols/protos/agent.proto", "protocols/protos/health.proto", @@ -28,6 +28,9 @@ fn main() { .run() .expect("Gen sync code failed."); + // Only async support stream currently. + protos.push("protocols/protos/streaming.proto"); + Codegen::new() .out_dir("protocols/asynchronous") .inputs(&protos)
diff --git a/example/protocols/asynchronous/mod.rs b/example/protocols/asynchronous/mod.rs index fd7082cc..34df78b2 100644 --- a/example/protocols/asynchronous/mod.rs +++ b/example/protocols/asynchronous/mod.rs @@ -10,3 +10,5 @@ pub mod health; pub mod health_ttrpc; mod oci; pub mod types; +pub mod streaming; +pub mod streaming_ttrpc;
diff --git a/example/protocols/protos/streaming.proto b/example/protocols/protos/streaming.proto new file mode 100644 index 00000000..fce29dd6 --- /dev/null +++ b/example/protocols/protos/streaming.proto @@ -0,0 +1,49 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and + limitations under the License. +*/ + +syntax = "proto3"; + +package ttrpc.test.streaming; + +import "google/protobuf/empty.proto"; + +// Shim service is launched for each container and is responsible for owning the IO +// for the container and its additional processes. The shim is also the parent of +// each container and allows reattaching to the IO and receiving the exit status +// for the container processes. + +service Streaming { + rpc Echo(EchoPayload) returns (EchoPayload); + rpc EchoStream(stream EchoPayload) returns (stream EchoPayload); + rpc SumStream(stream Part) returns (Sum); + rpc DivideStream(Sum) returns (stream Part); + rpc EchoNull(stream EchoPayload) returns (google.protobuf.Empty); + rpc EchoNullStream(stream EchoPayload) returns (stream google.protobuf.Empty); +} + +message EchoPayload { + uint32 seq = 1; + string msg = 2; +} + +message Part { + int32 add = 1; +} + +message Sum { + int32 sum = 1; + int32 num = 2; +} diff --git a/ttrpc-codegen/Cargo.toml b/ttrpc-codegen/Cargo.toml index fc6eb32b..2805bcf7 100644 --- a/ttrpc-codegen/Cargo.toml +++ b/ttrpc-codegen/Cargo.toml @@ -16,4 +16,4 @@ readme = "README.md" protobuf = { version = "2.14.0" } protobuf-codegen-pure = "2.14.0" protobuf-codegen = "2.14.0" -ttrpc-compiler = "0.5.0" +ttrpc-compiler = { version = "0.5.0", path = "../compiler" } From 98d128d4b3efc92ece13193a5f3bdc6b9aa1bd2f Mon Sep 17 00:00:00 2001 From: wanglei01 Date: Wed, 10 Aug 2022 16:30:27 +0800 Subject: [PATCH 15/15] deny: add Unicode-DFS-2016 license MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is required by unicode-ident crate. For details: ``` error[L001]: failed to satisfy license requirements ┌─ unicode-ident 1.0.3 (registry+https://github.com/rust-lang/crates.io-index):4:13 │ 4 │ license = "(MIT OR Apache-2.0) AND Unicode-DFS-2016" │ -^^^----^^^^^^^^^^------^^^^^^^^^^^^^^^^ │ ││ │ │ │ ││ │ rejected: not explicitly allowed │ ││ accepted: license is explicitly allowed │ │accepted: license is explicitly allowed │ license expression retrieved via Cargo.toml `license` │ = unicode-ident v1.0.3 ├── proc-macro2 v1.0.43 │ ├── async-trait v0.1.57 │ │ └── ttrpc v0.6.1 │ ├── futures-macro v0.3.21 │ │ └── futures-util v0.3.21 │ │ ├── futures v0.3.21 │ │ │ ├── tokio-vsock v0.3.2 │ │ │ │ └── ttrpc v0.6.1 (*) │ │ │ └── ttrpc v0.6.1 (*) │ │ └── futures-executor v0.3.21 │ │ └── futures v0.3.21 (*) │ ├── quote v1.0.21 │ │ ├── async-trait v0.1.57 (*) │ │ ├── futures-macro v0.3.21 (*) │ │ ├── syn v1.0.99 │ │ │ ├── async-trait v0.1.57 (*) │ │ │ ├── futures-macro v0.3.21 (*) │ │ │ ├── thiserror-impl v1.0.32 │ │ │ │ └── thiserror v1.0.32 │ │ │ │ └── ttrpc v0.6.1 (*) │ │ │ └── tokio-macros v1.8.0 │ │ │ └── tokio v1.20.1 │ │ │ ├── tokio-vsock v0.3.2 (*) │ │ │ └── ttrpc v0.6.1 (*) │ │ ├── thiserror-impl v1.0.32 (*) │ │ └── tokio-macros v1.8.0 (*) │ ├── syn v1.0.99 (*) │ ├── thiserror-impl v1.0.32 (*) │ └── tokio-macros v1.8.0 (*) └── syn v1.0.99 (*) ``` Signed-off-by: wanglei01 --- deny.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/deny.toml b/deny.toml index c7f46580..6223c48d 100644 --- a/deny.toml +++ b/deny.toml @@ -72,6 +72,7 @@ unlicensed = "deny" allow = [ "MIT", "Apache-2.0", + "Unicode-DFS-2016", #"Apache-2.0 WITH LLVM-exception", ] # List of explictly disallowed licenses
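
For readers who want to see the effect of the codegen changes in patch 13, the sketch below approximates the server-side code the generator would emit for the example `Streaming` service (one unary method plus one duplex streaming method). It is hand-assembled from the format strings in compiler/src/codegen.rs rather than copied from real generator output, so the exact struct names (`EchoMethod`, `EchoStreamMethod`), the assumed preamble (use statements, #[async_trait]) and the formatting are illustrative assumptions, not part of the patches.

// Approximate shape of generated streaming_ttrpc.rs (server side); a sketch, not generator output.
// Assumes the usual generated-file preamble: HashMap, Arc, async_trait, and the Streaming trait.
struct EchoMethod {
    service: Arc<Box<dyn Streaming + Send + Sync>>,
}

#[async_trait]
impl ::ttrpc::r#async::MethodHandler for EchoMethod {
    async fn handler(
        &self,
        ctx: ::ttrpc::r#async::TtrpcContext,
        req: ::ttrpc::Request,
    ) -> ::ttrpc::Result<::ttrpc::Response> {
        // Unary methods keep using the existing request handler macro.
        ::ttrpc::async_request_handler!(self, ctx, req, streaming, EchoPayload, echo);
    }
}

struct EchoStreamMethod {
    service: Arc<Box<dyn Streaming + Send + Sync>>,
}

#[async_trait]
impl ::ttrpc::r#async::StreamHandler for EchoStreamMethod {
    async fn handler(
        &self,
        ctx: ::ttrpc::r#async::TtrpcContext,
        inner: ::ttrpc::r#async::StreamInner,
    ) -> ::ttrpc::Result<Option<::ttrpc::Response>> {
        // Duplex methods are dispatched through the new streaming handler macros.
        ::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, echo_stream);
    }
}

pub fn create_streaming(
    service: Arc<Box<dyn Streaming + Send + Sync>>,
) -> HashMap<String, ::ttrpc::r#async::Service> {
    let mut ret = HashMap::new();
    let mut methods = HashMap::new();
    let mut streams = HashMap::new();

    // Unary methods are registered by method name as MethodHandler entries...
    methods.insert(
        "Echo".to_string(),
        Box::new(EchoMethod { service: service.clone() })
            as Box<dyn ::ttrpc::r#async::MethodHandler + Send + Sync>,
    );
    // ...while streaming methods go into the separate streams map as StreamHandler entries.
    streams.insert(
        "EchoStream".to_string(),
        Arc::new(EchoStreamMethod { service: service.clone() })
            as Arc<dyn ::ttrpc::r#async::StreamHandler + Send + Sync>,
    );

    // The whole service is registered once under its "package.Service" path.
    ret.insert(
        "ttrpc.test.streaming.Streaming".to_string(),
        ::ttrpc::r#async::Service { methods, streams },
    );
    ret
}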