From 8b1b1d7fb7fb4a44044969e036f3fb8d60661dc6 Mon Sep 17 00:00:00 2001 From: wllenyj Date: Mon, 6 Jun 2022 23:42:45 +0800 Subject: [PATCH] async: add streaming support for server Added streaming support for server-side. Signed-off-by: wllenyj --- src/asynchronous/mod.rs | 5 +- src/asynchronous/server.rs | 198 +++++++++++++++++++++++++++++++------ src/asynchronous/utils.rs | 10 ++ 3 files changed, 181 insertions(+), 32 deletions(-) diff --git a/src/asynchronous/mod.rs b/src/asynchronous/mod.rs index dd6e5fae..2e976b5f 100644 --- a/src/asynchronous/mod.rs +++ b/src/asynchronous/mod.rs @@ -15,9 +15,10 @@ mod connection; pub mod shutdown; mod unix_incoming; +pub use self::stream::{Kind, StreamInner}; #[doc(inline)] pub use crate::r#async::client::Client; #[doc(inline)] -pub use crate::r#async::server::Server; +pub use crate::r#async::server::{Server, Service}; #[doc(inline)] -pub use utils::{MethodHandler, TtrpcContext}; +pub use utils::{MethodHandler, StreamHandler, TtrpcContext}; diff --git a/src/asynchronous/server.rs b/src/asynchronous/server.rs index f7c1c1c6..c4978772 100644 --- a/src/asynchronous/server.rs +++ b/src/asynchronous/server.rs @@ -1,3 +1,4 @@ +// Copyright 2022 Alibaba Cloud. All rights reserved. // Copyright (c) 2020 Ant Financial // // SPDX-License-Identifier: Apache-2.0 @@ -10,7 +11,7 @@ use std::os::unix::io::RawFd; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::net::UnixListener as SysUnixListener; use std::result::Result as StdResult; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; use async_trait::async_trait; @@ -34,22 +35,39 @@ use crate::common::{self, Domain}; use crate::context; use crate::error::{get_status, Error, Result}; use crate::proto::{ - Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, - MESSAGE_TYPE_REQUEST, + Code, Codec, GenMessage, Message, MessageHeader, Request, Response, Status, FLAG_NO_DATA, + FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_REQUEST, }; use crate::r#async::connection::*; use crate::r#async::shutdown; -use crate::r#async::stream::{MessageReceiver, MessageSender}; +use crate::r#async::stream::{ + Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner, +}; use crate::r#async::utils; -use crate::r#async::{MethodHandler, TtrpcContext}; +use crate::r#async::{MethodHandler, StreamHandler, TtrpcContext}; const DEFAULT_CONN_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5000); const DEFAULT_SERVER_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000); +pub struct Service { + pub methods: HashMap>, + pub streams: HashMap>, +} + +impl Service { + pub(crate) fn get_method(&self, name: &String) -> Option<&(dyn MethodHandler + Send + Sync)> { + self.methods.get(name).map(|b| b.as_ref()) + } + + pub(crate) fn get_stream(&self, name: &String) -> Option> { + self.streams.get(name).cloned() + } +} + /// A ttrpc Server (async). pub struct Server { listeners: Vec, - methods: Arc>>, + services: Arc>, domain: Option, shutdown: shutdown::Notifier, @@ -60,7 +78,7 @@ impl Default for Server { fn default() -> Self { Server { listeners: Vec::with_capacity(1), - methods: Arc::new(HashMap::new()), + services: Arc::new(HashMap::new()), domain: None, shutdown: shutdown::with_timeout(DEFAULT_SERVER_SHUTDOWN_TIMEOUT).0, stop_listen_tx: None, @@ -105,12 +123,9 @@ impl Server { Ok(self) } - pub fn register_service( - mut self, - methods: HashMap>, - ) -> Server { - let mut_methods = Arc::get_mut(&mut self.methods).unwrap(); - mut_methods.extend(methods); + pub fn register_service(mut self, new: HashMap) -> Server { + let services = Arc::get_mut(&mut self.services).unwrap(); + services.extend(new); self } @@ -158,7 +173,7 @@ impl Server { I: Stream> + Unpin + Send + 'static + AsRawFd, S: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { - let methods = self.methods.clone(); + let services = self.services.clone(); let shutdown_waiter = self.shutdown.subscribe(); @@ -172,13 +187,13 @@ impl Server { if let Some(conn) = conn { // Accept a new connection match conn { - Ok(stream) => { - let fd = stream.as_raw_fd(); + Ok(conn) => { + let fd = conn.as_raw_fd(); // spawn a connection handler, would not block spawn_connection_handler( fd, - stream, - methods.clone(), + conn, + services.clone(), shutdown_waiter.clone(), ).await; } @@ -244,14 +259,15 @@ impl Server { async fn spawn_connection_handler( fd: RawFd, conn: C, - methods: Arc>>, + services: Arc>, shutdown_waiter: shutdown::Waiter, ) where C: AsyncRead + AsyncWrite + AsRawFd + Send + 'static, { let delegate = ServerBuilder { fd, - methods, + services, + streams: Arc::new(Mutex::new(HashMap::new())), shutdown_waiter, }; let conn = Connection::new(conn, delegate); @@ -279,7 +295,8 @@ impl AsRawFd for Server { struct ServerBuilder { fd: RawFd, - methods: Arc>>, + services: Arc>, + streams: Arc>>, shutdown_waiter: shutdown::Waiter, } @@ -296,7 +313,8 @@ impl Builder for ServerBuilder { ServerReader { fd: self.fd, tx, - methods: self.methods.clone(), + services: self.services.clone(), + streams: self.streams.clone(), server_shutdown: self.shutdown_waiter.clone(), handler_shutdown: disconnect_notifier, }, @@ -321,7 +339,8 @@ impl WriterDelegate for ServerWriter { struct ServerReader { fd: RawFd, tx: MessageSender, - methods: Arc>>, + services: Arc>, + streams: Arc>>, server_shutdown: shutdown::Waiter, handler_shutdown: shutdown::Notifier, } @@ -366,7 +385,8 @@ impl ServerReader { HandlerContext { fd: self.fd, tx: self.tx.clone(), - methods: self.methods.clone(), + services: self.services.clone(), + streams: self.streams.clone(), _handler_shutdown_waiter: self.handler_shutdown.subscribe(), } } @@ -375,7 +395,8 @@ impl ServerReader { struct HandlerContext { fd: RawFd, tx: MessageSender, - methods: Arc>>, + services: Arc>, + streams: Arc>>, // Used for waiting handler exit. _handler_shutdown_waiter: shutdown::Waiter, } @@ -406,11 +427,62 @@ impl HandlerContext { .ok(); } None => { - unimplemented!(); + let mut header = MessageHeader::new_data(stream_id, 0); + header.set_flags(FLAG_REMOTE_CLOSED | FLAG_NO_DATA); + let msg = GenMessage { + header, + payload: Vec::new(), + }; + + self.tx + .send(msg) + .await + .map_err(err_to_others_err!(e, "Send packet to sender error ")) + .ok(); } }, Err(status) => Self::respond_with_status(self.tx.clone(), stream_id, status).await, }, + MESSAGE_TYPE_DATA => { + // TODO(wllenyj): Compatible with golang behavior. + if (msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED + && !msg.payload.is_empty() + { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!( + "Stream id {}: data close message connot include data", + stream_id + ), + ), + ) + .await; + } + let stream_tx = self.streams.lock().unwrap().get(&stream_id).cloned(); + if let Some(stream_tx) = stream_tx { + if let Err(e) = stream_tx.send(Ok(msg)).await { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status( + Code::INVALID_ARGUMENT, + format!("Stream id {}: handling data error: {}", stream_id, e), + ), + ) + .await; + } + } else { + Self::respond_with_status( + self.tx.clone(), + stream_id, + get_status(Code::INVALID_ARGUMENT, "Stream is no longger active"), + ) + .await; + } + } _ => { // TODO: else we must ignore this for future compat. log this? // TODO(wllenyj): Compatible with golang behavior. @@ -432,12 +504,23 @@ impl HandlerContext { let req = &req_msg.payload; trace!("Got Message request {} {}", req.service, req.method); - let path = utils::get_path(&req.service, &req.method); - let method = self.methods.get(&path).ok_or_else(|| { - get_status(Code::INVALID_ARGUMENT, format!("{} does not exist", &path)) + let srv = self.services.get(&req.service).ok_or_else(|| { + get_status( + Code::INVALID_ARGUMENT, + format!("{} service does not exist", &req.service), + ) })?; - return self.handle_method(method.as_ref(), req_msg).await; + if let Some(method) = srv.get_method(&req.method) { + return self.handle_method(method, req_msg).await; + } + if let Some(stream) = srv.get_stream(&req.method) { + return self.handle_stream(stream, req_msg).await; + } + Err(get_status( + Code::UNIMPLEMENTED, + format!("{} method", &req.method), + )) } async fn handle_method( @@ -484,6 +567,61 @@ impl HandlerContext { } } + async fn handle_stream( + &self, + stream: Arc, + req_msg: Message, + ) -> StdResult, Status> { + let stream_id = req_msg.header.stream_id; + let req = req_msg.payload; + let path = utils::get_path(&req.service, &req.method); + + let (tx, rx): (ResultSender, ResultReceiver) = channel(100); + let stream_tx = tx.clone(); + self.streams.lock().unwrap().insert(stream_id, tx); + + let _remote_close = (req_msg.header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED; + let _remote_open = (req_msg.header.flags & FLAG_REMOTE_OPEN) == FLAG_REMOTE_OPEN; + let si = StreamInner::new( + stream_id, + self.tx.clone(), + rx, + true, // TODO + true, + Kind::Server, + self.streams.clone(), + ); + + let ctx = TtrpcContext { + fd: self.fd, + mh: req_msg.header, + metadata: context::from_pb(&req.metadata), + timeout_nano: req.timeout_nano, + }; + + let task = spawn(async move { stream.handler(ctx, si).await }); + + if !req.payload.is_empty() { + // Fake the first data message. + let msg = GenMessage { + header: MessageHeader::new_data(stream_id, req.payload.len() as u32), + payload: req.payload, + }; + stream_tx.send(Ok(msg)).await.map_err(|e| { + error!("send stream data {} got error {:?}", path, &e); + get_status(Code::UNKNOWN, e) + })?; + } + task.await + .unwrap_or_else(|e| { + Err(Error::Others(format!( + "stream {} task got error {:?}", + path, e + ))) + }) + .map_err(|e| get_status(Code::UNKNOWN, e)) + } + async fn respond(tx: MessageSender, stream_id: u32, resp: Response) -> Result<()> { let payload = resp .encode() diff --git a/src/asynchronous/utils.rs b/src/asynchronous/utils.rs index 7cb038b4..17716ae8 100644 --- a/src/asynchronous/utils.rs +++ b/src/asynchronous/utils.rs @@ -84,6 +84,16 @@ pub trait MethodHandler { async fn handler(&self, ctx: TtrpcContext, req: Request) -> Result; } +/// Trait that implements handler which is a proxy to the stream (async). +#[async_trait] +pub trait StreamHandler { + async fn handler( + &self, + ctx: TtrpcContext, + stream: crate::r#async::StreamInner, + ) -> Result>; +} + /// The context of ttrpc (async). #[derive(Debug)] pub struct TtrpcContext {